cellmil.models.mil.cellconv

Classes

CellConv(embed_dim[, n_classes, ...])

LitCellConv(model, optimizer[, loss, ...])

LitSurvCellConv(model, optimizer[, loss, ...])

class cellmil.models.mil.cellconv.CellConv(embed_dim: int, n_classes: int = 2, convolution_depth: int = 3, size_arg: list[int] = [512, 128], attention_branches: int = 1, temperature: float = 1.0, dropout: float = 0.0, kernel_size: int = 3)

Bases: Module

__init__(embed_dim: int, n_classes: int = 2, convolution_depth: int = 3, size_arg: list[int] = [512, 128], attention_branches: int = 1, temperature: float = 1.0, dropout: float = 0.0, kernel_size: int = 3) → None

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor) → tuple[torch.Tensor, dict[str, torch.Tensor]]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the forward-pass recipe must be defined in this method, call the Module instance (for example, model(x)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.
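
The difference between calling the instance and calling forward() can be seen with a forward hook. The sketch below uses a hypothetical toy module (not CellConv) and only standard PyTorch calls: the hook fires once, for the instance call, and is skipped when forward() is invoked directly.

```python
import torch
from torch import nn


class Toy(nn.Module):
    """Hypothetical stand-in for any nn.Module; not CellConv."""

    def __init__(self) -> None:
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


calls: list[str] = []

model = Toy()
# The hook is attached to the outer module and records each time it runs.
model.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.randn(3, 4)
model(x)          # goes through __call__, so the hook runs
model.forward(x)  # bypasses __call__, so the hook is silently skipped

print(calls)  # ['hook']
```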

class cellmil.models.mil.cellconv.LitCellConv(model: Module, optimizer: Optimizer, loss: Module = CrossEntropyLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)

Bases: LitGeneral

__init__(model: Module, optimizer: Optimizer, loss: Module = CrossEntropyLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25) → None

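
The parameters use_aem, aem_weight_initial, aem_weight_final and aem_annealing_epochs suggest that an AEM loss weight is annealed from an initial to a final value over a number of epochs. The schedule shape is not documented here, so the sketch below assumes a linear schedule that holds the final value afterwards; aem_weight is a hypothetical helper for illustration, not the class's actual code.

```python
def aem_weight(epoch: int, initial: float = 1e-4, final: float = 0.0,
               annealing_epochs: int = 25) -> float:
    """Hypothetical AEM weight schedule (assumption: linear annealing).

    Interpolates linearly from `initial` at epoch 0 to `final` at
    `annealing_epochs`, then stays at `final`.
    """
    if annealing_epochs <= 0 or epoch >= annealing_epochs:
        return final
    progress = epoch / annealing_epochs
    return initial + (final - initial) * progress


print(aem_weight(0))   # 0.0001
print(aem_weight(25))  # 0.0
```

If the real schedule differs (for example cosine or step-wise), only the interpolation line would change.
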
classmethod load_from_checkpoint(checkpoint_path: Union[str, Path, IO[bytes]], map_location: Optional[Union[device, str, int, Callable[[UntypedStorage, str], Optional[UntypedStorage]], dict[torch.device | str | int, torch.device | str | int]]] = None, hparams_file: Optional[Union[str, Path]] = None, strict: Optional[bool] = None, **kwargs: Any) → Self

Primary way of loading a model from a checkpoint. When Lightning saves a checkpoint, it stores the arguments passed to __init__ in the checkpoint under "hyper_parameters".

Any arguments specified through **kwargs will override args stored in "hyper_parameters".
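
The override rule amounts to a dictionary merge in which explicit keyword arguments win over the saved values. The helper below, resolve_init_kwargs, is a hypothetical illustration of that precedence only, not Lightning's internal code.

```python
from typing import Any


def resolve_init_kwargs(checkpoint: dict[str, Any], **kwargs: Any) -> dict[str, Any]:
    """Hypothetical helper: merge saved hyperparameters with explicit overrides.

    Values passed through **kwargs replace those stored under
    "hyper_parameters"; all other saved values are kept.
    """
    saved = checkpoint.get("hyper_parameters", {})
    return {**saved, **kwargs}


ckpt = {"hyper_parameters": {"subsampling": 1.0, "use_aem": False}}
merged = resolve_init_kwargs(ckpt, use_aem=True)
print(merged)  # {'subsampling': 1.0, 'use_aem': True}
```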

Parameters:
  • checkpoint_path – Path to the checkpoint. This can also be a URL or a file-like object.

  • map_location – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().

  • hparams_file

    Optional path to a .yaml or .csv file with hierarchical structure as in this example:

    drop_prob: 0.2
    dataloader:
        batch_size: 32
    

    You most likely won’t need this, since Lightning always saves the hyperparameters to the checkpoint. However, if your checkpoint weights don’t have the hyperparameters saved, use this argument to pass in a .yaml file with the hparams you’d like to use. These will be converted into a dict and passed into your LightningModule for use.

    If your model’s hparams argument is a Namespace and the .yaml file has a hierarchical structure, you need to refactor your model to treat hparams as a dict.

  • strict – Whether to strictly enforce that the keys in checkpoint_path match the keys returned by this module’s state dict. Defaults to True unless LightningModule.strict_loading is set, in which case it defaults to the value of LightningModule.strict_loading.

  • weights_only – If True, restricts loading to state_dicts of plain torch.Tensor and other primitive types. If loading a checkpoint from a trusted source that contains an nn.Module, use weights_only=False. If loading a checkpoint from an untrusted source, we recommend using weights_only=True. For more information, please refer to the PyTorch Developer Notes on Serialization Semantics.

  • **kwargs – Any extra keyword args needed to init the model. Can also be used to override saved hyperparameter values.
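
The strict and weights_only options correspond to behaviour that can be reproduced with plain PyTorch. The sketch below uses only torch.save, torch.load and nn.Module.load_state_dict (no Lightning): it loads a tensor-only state dict with weights_only=True, then shows that strict=True rejects a key mismatch while strict=False loads the matching keys and reports the rest. The two nn.Linear layers are hypothetical stand-ins for a saved and a target model.

```python
import io

import torch
from torch import nn

# Serialize a plain state dict (tensors only) into an in-memory buffer.
source = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)
buffer.seek(0)

# weights_only=True refuses to unpickle arbitrary objects; tensors are fine.
state_dict = torch.load(buffer, weights_only=True)

# Target without a bias, so the checkpoint's "bias" key is unexpected.
target = nn.Linear(4, 2, bias=False)

strict_failed = False
try:
    target.load_state_dict(state_dict, strict=True)
except RuntimeError:
    strict_failed = True  # strict loading rejects the key mismatch

# strict=False loads the matching keys and reports the mismatches.
result = target.load_state_dict(state_dict, strict=False)
print(strict_failed, result.missing_keys, result.unexpected_keys)  # True [] ['bias']
```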

Returns:

LightningModule instance with loaded weights and hyperparameters (if available).

Note

load_from_checkpoint is a class method. Call it on your LightningModule class, not on an instance; calling it on an instance raises a TypeError.

Note

To ensure all layers can be loaded from the checkpoint, this function will call configure_model() directly after instantiating the model if this hook is overridden in your LightningModule. However, note that load_from_checkpoint does not support loading sharded checkpoints, and you may run out of memory if the model is too large. In this case, consider loading through the Trainer via .fit(ckpt_path=...).

Example:

# load weights without mapping ...
model = MyLightningModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load weights mapping all weights from GPU 1 to GPU 0 ...
map_location = {'cuda:1':'cuda:0'}
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    map_location=map_location
)

# or load weights and hyperparameters from separate files.
model = MyLightningModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

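# override some of the params with new values
# (PATH, NEW_PATH, num_layers and pretrained_ckpt_path are placeholders)
PATH = 'path/to/checkpoint.ckpt'
NEW_PATH = 'path/to/other_checkpoint.ckpt'
model = MyLightningModule.load_from_checkpoint(
    PATH,
    num_layers=128,
    pretrained_ckpt_path=NEW_PATH,
)

# predict (x is an input batch for your model)
model.eval()
model.freeze()
y_hat = model(x)
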
forward(x: Tensor) → Tensor

Same as torch.nn.Module.forward().

Parameters:
  • x – Input tensor passed through the model.

Returns:

The model’s output tensor.

get_attention_weights(x: Tensor) → Tensor

transfer_batch_to_device(batch: tuple[torch.Tensor, torch.Tensor], device: device, dataloader_idx: int) → tuple[torch.Tensor, torch.Tensor]

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

The data types listed below (and any arbitrary nesting of them) are supported out of the box:

  • torch.Tensor or anything that implements .to(...)

  • list

  • dict

  • tuple

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).
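
Out-of-the-box support for nested collections boils down to a recursive walk that calls .to(device) on every leaf that has it. The sketch below is a hypothetical helper, move_nested, written in pure Python to show that traversal; it is not Lightning's move_data_to_device() implementation, and FakeTensor is a stand-in for torch.Tensor so the example runs without a GPU.

```python
from typing import Any


def move_nested(data: Any, device: str) -> Any:
    """Hypothetical sketch: recursively move every leaf that has a .to() method."""
    if isinstance(data, dict):
        # Rebuilds as a plain dict (subclasses such as OrderedDict are not preserved).
        return {key: move_nested(value, device) for key, value in data.items()}
    if isinstance(data, tuple) and hasattr(data, "_fields"):  # namedtuple
        return type(data)(*(move_nested(item, device) for item in data))
    if isinstance(data, (list, tuple)):
        return type(data)(move_nested(item, device) for item in data)
    if hasattr(data, "to"):
        return data.to(device)
    return data  # leaves without .to() (ints, strings, ...) are returned unchanged


class FakeTensor:
    """Stand-in for torch.Tensor that records the device it was moved to."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


batch = (FakeTensor(), {"label": FakeTensor(), "slide_id": "abc"})
moved = move_nested(batch, "cuda:0")
print(moved[0].device, moved[1]["label"].device, moved[1]["slide_id"])  # cuda:0 cuda:0 abc
```

A custom batch class that is none of these types would fall through to the final branch unchanged, which is why such classes need an explicit override of this hook.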

Note

This hook should only transfer the data and not modify it, nor should it move the data to any device other than the one passed in as an argument (unless you know what you are doing). To check the current state of execution of this hook, you can use self.trainer.training/testing/validating/predicting and add different logic as required.

Parameters:
  • batch – A batch of data that needs to be transferred to a new device.

  • device – The target device as defined in PyTorch.

  • dataloader_idx – The index of the dataloader to which the batch belongs.

Returns:

A reference to the data on the new device.

Example:

def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch

See also

  • move_data_to_device()

  • apply_to_collection()

class cellmil.models.mil.cellconv.LitSurvCellConv(model: CellConv, optimizer: Optimizer, loss: Module = NegativeLogLikelihoodSurvLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)

Bases: LitCellConv

__init__(model: CellConv, optimizer: Optimizer, loss: Module = NegativeLogLikelihoodSurvLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)

_setup_metrics()

Set up the C-index (concordance index) and Brier score metrics for survival analysis.
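
For reference, Harrell's C-index counts, over all comparable pairs, how often the subject with the earlier observed event was assigned the higher risk. The sketch below is a plain-Python version of that standard definition (tied risks count as one half, pairs with tied times are skipped); the metric object actually configured by _setup_metrics may handle ties or censoring differently.

```python
def harrell_c_index(times: list[float], events: list[int], risks: list[float]) -> float:
    """Harrell's concordance index (sketch of the standard definition).

    A pair (i, j) is comparable when subject i had an observed event
    (events[i] == 1) strictly before times[j]. The pair is concordant when
    risks[i] > risks[j]; equal risks count as half a concordant pair.
    """
    concordant = 0.0
    comparable = 0
    for i in range(len(times)):
        if not events[i]:
            continue  # censored subjects cannot anchor a comparable pair
        for j in range(len(times)):
            if times[i] < times[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


c = harrell_c_index(times=[2, 4, 6, 8], events=[1, 1, 0, 1], risks=[0.9, 0.6, 0.7, 0.1])
print(c)  # 0.8
```

Here 4 of the 5 comparable pairs are concordant; the only discordant pair is subject 1 (event at time 4) versus subject 2 (censored at time 6), which received a higher risk.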

predict_step(batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int)

Prediction step; returns the logits for the discrete-time hazard intervals.
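
Hazard logits are usually turned into a survival curve before interpretation. The sketch below assumes, as in common discrete-time negative log-likelihood survival formulations, that each logit parameterizes a per-interval hazard through a sigmoid, so that S(t_k) is the product of (1 - h_j) for j up to k. The helpers survival_from_logits and risk_score are hypothetical; check NegativeLogLikelihoodSurvLoss for the exact convention before relying on this.

```python
import math


def survival_from_logits(logits: list[float]) -> list[float]:
    """Convert per-interval hazard logits into a discrete survival curve (sketch)."""
    survival = []
    s = 1.0
    for logit in logits:
        # Assumed: hazard = P(event in this interval | survived all previous ones).
        hazard = 1.0 / (1.0 + math.exp(-logit))
        s *= 1.0 - hazard
        survival.append(s)
    return survival


def risk_score(logits: list[float]) -> float:
    """Hypothetical scalar risk: lower expected survival means higher risk."""
    return -sum(survival_from_logits(logits))


curve = survival_from_logits([0.0, 0.0])
print(curve)                   # [0.5, 0.25]
print(risk_score([0.0, 0.0]))  # -0.75
```

A scalar risk of this kind is what a concordance metric would consume, since it orders subjects by predicted survival.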