cellmil.models.mil.cellconv
Classes
- class cellmil.models.mil.cellconv.CellConv(embed_dim: int, n_classes: int = 2, convolution_depth: int = 3, size_arg: list[int] = [512, 128], attention_branches: int = 1, temperature: float = 1.0, dropout: float = 0.0, kernel_size: int = 3)
Bases: Module
- __init__(embed_dim: int, n_classes: int = 2, convolution_depth: int = 3, size_arg: list[int] = [512, 128], attention_branches: int = 1, temperature: float = 1.0, dropout: float = 0.0, kernel_size: int = 3) → None
Initialize internal Module state, shared by both nn.Module and ScriptModule.
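The signature exposes temperature and attention_branches, but this page does not say how CellConv uses them. A common pattern in attention-based MIL is to divide the per-instance attention scores by a temperature before the softmax, so values below 1.0 sharpen the weights and values above 1.0 flatten them. The sketch below shows only that idea for a single attention branch in plain Python; temperature_softmax and attention_pool are hypothetical helpers, not part of cellmil, and CellConv's actual pooling may differ.

```python
import math


def temperature_softmax(scores: list[float], temperature: float = 1.0) -> list[float]:
    """Softmax over per-instance attention scores, scaled by a temperature (assumed usage)."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [s / temperature for s in scores]
    # Subtract the maximum before exponentiating for numerical stability.
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(features: list[list[float]], scores: list[float],
                   temperature: float = 1.0) -> list[float]:
    """Weighted average of instance (cell) embeddings using temperature-scaled attention."""
    weights = temperature_softmax(scores, temperature)
    dim = len(features[0])
    return [sum(w * f[d] for w, f in zip(weights, features)) for d in range(dim)]
```

With scores [2.0, 1.0, 0.0], the largest weight grows as the temperature drops from 2.0 to 0.5, and equal scores reduce the pooled embedding to a plain mean.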
- forward(x: Tensor) → tuple[torch.Tensor, dict[str, torch.Tensor]]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
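The note can be illustrated with a toy model of the call path. TinyModule and Doubler are hypothetical, simplified stand-ins written in plain Python, not PyTorch's implementation: calling the instance runs registered forward hooks after forward(), while calling forward() directly bypasses them.

```python
from typing import Any, Callable

# A forward hook receives (module, inputs, output) and may return a replacement output.
Hook = Callable[[Any, tuple, Any], Any]


class TinyModule:
    """Toy model of nn.Module's call path; not PyTorch's real implementation."""

    def __init__(self) -> None:
        self._forward_hooks: list = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, x: Any) -> Any:
        raise NotImplementedError

    def __call__(self, *args: Any) -> Any:
        out = self.forward(*args)
        # Hooks run only on this path; calling .forward() directly skips them.
        for hook in self._forward_hooks:
            result = hook(self, args, out)
            if result is not None:
                out = result
        return out


class Doubler(TinyModule):
    def forward(self, x: Any) -> Any:
        return 2 * x
```

After registering a hook that records outputs on an instance m = Doubler(), m(3) triggers the hook, whereas m.forward(3) returns the same value without triggering it.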
- class cellmil.models.mil.cellconv.LitCellConv(model: Module, optimizer: Optimizer, loss: Module = CrossEntropyLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)
Bases: LitGeneral
- __init__(model: Module, optimizer: Optimizer, loss: Module = CrossEntropyLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25) → None
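This page does not describe how subsampling or the AEM weight are applied during training. As an illustration only, the sketch assumes that the AEM weight moves linearly from aem_weight_initial to aem_weight_final over aem_annealing_epochs and then stays at the final value, and that subsampling is the fraction of instances kept from each bag. Both helpers are hypothetical and may not match LitCellConv's behaviour.

```python
import random
from typing import Optional


def aem_weight(epoch: int, initial: float = 1e-4, final: float = 0.0,
               annealing_epochs: int = 25) -> float:
    """Hypothetical linear annealing of the AEM loss weight (assumed schedule)."""
    if annealing_epochs <= 0 or epoch >= annealing_epochs:
        return final
    frac = epoch / annealing_epochs
    return initial + (final - initial) * frac


def subsample_bag(instances: list, fraction: float = 1.0,
                  rng: Optional[random.Random] = None) -> list:
    """Hypothetical random subsampling of a bag's instances (assumed meaning of subsampling)."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    if fraction == 1.0:
        return list(instances)
    rng = rng or random.Random()
    k = max(1, int(round(fraction * len(instances))))
    # Keep the original instance order among the sampled indices.
    idx = sorted(rng.sample(range(len(instances)), k))
    return [instances[i] for i in idx]
```

Under these assumptions, the defaults give a weight of 1e-4 at epoch 0 that reaches 0.0 at epoch 25, and subsampling=1.0 keeps every instance.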
- classmethod load_from_checkpoint(checkpoint_path: Union[str, Path, IO[bytes]], map_location: Optional[Union[device, str, int, Callable[[UntypedStorage, str], Optional[UntypedStorage]], dict[torch.device | str | int, torch.device | str | int]]] = None, hparams_file: Optional[Union[str, Path]] = None, strict: Optional[bool] = None, **kwargs: Any) → Self
Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint, it stores the arguments passed to __init__ in the checkpoint under "hyper_parameters".
Any arguments specified through **kwargs will override args stored in "hyper_parameters".
- Parameters:
checkpoint_path – Path to checkpoint. This can also be a URL or a file-like object.
map_location – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().
hparams_file – Optional path to a .yaml or .csv file with hierarchical structure as in this example:
drop_prob: 0.2
dataloader:
  batch_size: 32
You most likely won't need this since Lightning will always save the hyperparameters to the checkpoint. However, if your checkpoint weights don't have the hyperparameters saved, use this method to pass in a .yaml file with the hparams you'd like to use. These will be converted into a dict and passed into your LightningModule for use.
If your model's hparams argument is a Namespace and the .yaml file has hierarchical structure, you need to refactor your model to treat hparams as a dict.
strict – Whether to strictly enforce that the keys in checkpoint_path match the keys returned by this module's state dict. Defaults to True unless LightningModule.strict_loading is set, in which case it defaults to the value of LightningModule.strict_loading.
weights_only – If True, restricts loading to state_dicts of plain torch.Tensor and other primitive types. If loading a checkpoint from a trusted source that contains an nn.Module, use weights_only=False. If loading a checkpoint from an untrusted source, we recommend using weights_only=True. For more information, please refer to the PyTorch Developer Notes on Serialization Semantics.
**kwargs – Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.
- Returns:
LightningModule instance with loaded weights and hyperparameters (if available).
Note
load_from_checkpoint is a class method. You should use your LightningModule class to call it instead of the LightningModule instance, or a TypeError will be raised.
Note
To ensure all layers can be loaded from the checkpoint, this function will call configure_model() directly after instantiating the model if this hook is overridden in your LightningModule. However, note that load_from_checkpoint does not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this case, consider loading through the Trainer via .fit(ckpt_path=...).
Example:
# load weights without mapping ...
model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1': 'cuda:0'}
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    map_location=map_location
)

# or load weights and hyperparameters from separate files.
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# override some of the params with new values
model = MyLightningModule.load_from_checkpoint(
    PATH,
    num_layers=128,
    pretrained_ckpt_path=NEW_PATH,
)

# predict
model.eval()
model.freeze()
y_hat = model(x)
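The precedence rules described for **kwargs and strict can be restated as two small functions. These are hypothetical helpers that mirror the documented behaviour in plain Python; they are not Lightning APIs, and the hyperparameter values used with them are illustrative.

```python
from typing import Any, Optional


def resolve_init_kwargs(saved_hparams: dict, **overrides: Any) -> dict:
    """Keyword arguments win over the hyperparameters stored in the checkpoint."""
    # The right-hand dict takes precedence; saved_hparams itself is not modified.
    return {**saved_hparams, **overrides}


def resolve_strict(strict: Optional[bool], strict_loading: Optional[bool] = None) -> bool:
    """An explicit strict wins; otherwise fall back to strict_loading, then to True."""
    if strict is not None:
        return strict
    if strict_loading is not None:
        return strict_loading
    return True
```

For example, with saved hyperparameters {"subsampling": 1.0, "use_aem": False}, passing subsampling=0.5 yields {"subsampling": 0.5, "use_aem": False}.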
- forward(x: Tensor) → Tensor
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- transfer_batch_to_device(batch: tuple[torch.Tensor, torch.Tensor], device: device, dataloader_idx: int) → tuple[torch.Tensor, torch.Tensor]
Override this hook if your DataLoader returns tensors wrapped in a custom data structure.
The data types listed below (and any arbitrary nesting of them) are supported out of the box:
torch.Tensor or anything that implements .to(…)
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).
Note
This hook should only transfer the data and not modify it, nor should it move the data to any other device than the one passed in as an argument (unless you know what you are doing). To check the current state of execution of this hook, you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.
- Parameters:
batch – A batch of data that needs to be transferred to a new device.
device – The target device as defined in PyTorch.
dataloader_idx – The index of the dataloader to which the batch belongs.
- Returns:
A reference to the data on the new device.
Example:
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch
See also
move_data_to_device(), apply_to_collection()
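A simplified, pure-Python picture of "anything that implements .to(…)" combined with arbitrary nesting: recurse through dicts, lists, tuples and namedtuples, and call .to(device) on any leaf that has it. move_to_device and FakeTensor are hypothetical stand-ins; Lightning's move_data_to_device() and apply_to_collection() handle more cases than this sketch.

```python
from typing import Any


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only tracks its device."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


def move_to_device(data: Any, device: Any) -> Any:
    """Recursively call .to(device) on leaves that support it, rebuilding containers."""
    if hasattr(data, "to") and callable(data.to):
        return data.to(device)
    if isinstance(data, dict):
        return {k: move_to_device(v, device) for k, v in data.items()}
    if isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
        return type(data)(*(move_to_device(v, device) for v in data))
    if isinstance(data, (list, tuple)):
        return type(data)(move_to_device(v, device) for v in data)
    return data  # strings, numbers, None, ... are returned unchanged
```

The original batch is left untouched; a new structure with moved leaves is returned.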
- class cellmil.models.mil.cellconv.LitSurvCellConv(model: CellConv, optimizer: Optimizer, loss: Module = NegativeLogLikelihoodSurvLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)
Bases: LitCellConv
- __init__(model: CellConv, optimizer: Optimizer, loss: Module = NegativeLogLikelihoodSurvLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)
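LitSurvCellConv defaults to NegativeLogLikelihoodSurvLoss, whose definition is not shown on this page. A widely used discrete-time formulation, assumed here, turns each interval's logit into a hazard h with a sigmoid, builds survival S as the running product of (1 - h), and scores an event observed in interval y by -log S(y-1) - log h(y) and a patient censored in interval y by -log S(y). The function is a hypothetical single-sample sketch without any extra weighting terms.

```python
import math


def discrete_nll_surv_loss(logits: list[float], interval: int, censored: bool,
                           eps: float = 1e-7) -> float:
    """Hypothetical single-sample discrete-time NLL survival loss (assumed formulation)."""
    hazards = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    # survival[k] = P(T > interval k) = prod_{j<=k} (1 - h_j)
    survival, s = [], 1.0
    for h in hazards:
        s *= 1.0 - h
        survival.append(s)
    if censored:
        # Censored in interval y: the patient is known to have survived through it.
        return -math.log(max(survival[interval], eps))
    # Event in interval y: survived all earlier intervals, then the event occurred.
    s_before = 1.0 if interval == 0 else survival[interval - 1]
    return -(math.log(max(s_before, eps)) + math.log(max(hazards[interval], eps)))
```

With all logits at 0.0 (every hazard 0.5), an event in interval 0 costs log 2, and censoring in interval 1 costs log 4.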
- predict_step(batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int)
Prediction step returns logits for discrete-time hazard intervals.
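Given the per-interval logits returned by predict_step, a survival curve and a scalar risk score can be derived. The sketch assumes hazards are obtained with a sigmoid, which this page does not state, and uses the common convention risk = -sum of the survival curve; hazards_to_survival is a hypothetical helper, not part of cellmil.

```python
import math


def hazards_to_survival(logits: list[float]) -> tuple[list[float], list[float], float]:
    """Map per-interval hazard logits to hazards, a survival curve, and a risk score."""
    hazards = [1.0 / (1.0 + math.exp(-z)) for z in logits]
    survival, s = [], 1.0
    for h in hazards:
        s *= 1.0 - h  # P(T > interval k) = prod_{j<=k} (1 - h_j)
        survival.append(s)
    # Higher risk means lower expected survival across the intervals.
    risk = -sum(survival)
    return hazards, survival, risk
```

The survival curve is non-increasing by construction, so larger hazard logits yield a higher risk score.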