cellmil.models.mil.attentiondeepmil¶
Classes

- AttentionDeepMIL
- LitAttentionDeepMIL – Lightning wrapper for AttentionDeepMIL model.
- LitSurvAttentionDeepMIL
- class cellmil.models.mil.attentiondeepmil.AttentionDeepMIL(embed_dim: int, n_classes: int = 2, size_arg: list[int] = [500, 128], attention_branches: int = 1, temperature: float = 1.0, dropout: float = 0.0)[source]¶
Bases: Module

- __init__(embed_dim: int, n_classes: int = 2, size_arg: list[int] = [500, 128], attention_branches: int = 1, temperature: float = 1.0, dropout: float = 0.0)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) tuple[torch.Tensor, dict[str, torch.Tensor]][source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
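For intuition, the aggregation that attention-based deep MIL performs – a temperature-scaled softmax over per-instance attention scores, followed by a weighted sum of the instance embeddings into a single bag embedding – can be sketched in plain Python. This is an illustrative sketch of the general technique, not the cellmil implementation; the function names are hypothetical.

```python
import math

def softmax(scores, temperature=1.0):
    """Temperature-scaled softmax over raw attention scores."""
    scaled = [s / temperature for s in scores]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def attention_pool(instances, scores, temperature=1.0):
    """Pool instance embeddings into one bag embedding via attention weights."""
    weights = softmax(scores, temperature)
    dim = len(instances[0])
    bag = [sum(w * inst[d] for w, inst in zip(weights, instances))
           for d in range(dim)]
    return bag, weights

# Three 2-d instance embeddings and their raw attention scores.
instances = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = [2.0, 0.0, 0.0]
bag, weights = attention_pool(instances, scores)
```

A higher `temperature` flattens the attention distribution toward uniform, while a lower one sharpens it toward the highest-scoring instance.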
- class cellmil.models.mil.attentiondeepmil.LitAttentionDeepMIL(model: Module, optimizer: Optimizer, loss: Module = CrossEntropyLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)[source]¶
Bases: LitGeneral

Lightning wrapper for AttentionDeepMIL model.
This class extends the base LitGeneral class to provide Lightning-specific functionality for the AttentionDeepMIL model.
- Parameters:
model (nn.Module) – The AttentionDeepMIL model instance.
optimizer (torch.optim.Optimizer) – Optimizer for training.
loss (nn.Module, optional) – Loss function. Defaults to nn.CrossEntropyLoss().
lr_scheduler (LRScheduler | None, optional) – Learning rate scheduler. Defaults to None.
subsampling (float, optional) – Fraction of instances to use during training (between 0 and 1). Defaults to 1.0 (no subsampling).
use_aem (bool, optional) – Whether to use AEM regularization. Defaults to False.
aem_weight_initial (float, optional) – Initial weight for AEM loss. Defaults to 0.0001.
aem_weight_final (float, optional) – Final weight for AEM loss after annealing. Defaults to 0.0.
aem_annealing_epochs (int, optional) – Number of epochs to anneal AEM weight. Defaults to 25.
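The interplay of aem_weight_initial, aem_weight_final, and aem_annealing_epochs can be illustrated with a small sketch. A linear schedule is assumed here purely for illustration – the actual annealing curve used by LitAttentionDeepMIL is not specified on this page, and the function name is hypothetical.

```python
def aem_weight(epoch, initial=1e-4, final=0.0, annealing_epochs=25):
    """Anneal the AEM loss weight from `initial` to `final`.

    A linear schedule is assumed for illustration; the real schedule
    inside LitAttentionDeepMIL may differ.
    """
    if epoch >= annealing_epochs:
        return final  # after annealing, hold the final weight
    frac = epoch / annealing_epochs
    return initial + frac * (final - initial)
```

With the defaults, the AEM contribution starts at 1e-4 and decays to 0.0 over the first 25 epochs, after which the regularizer is effectively disabled.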
- __init__(model: Module, optimizer: Optimizer, loss: Module = CrossEntropyLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)[source]¶
- classmethod load_from_checkpoint(checkpoint_path: Union[str, Path, IO[bytes]], map_location: Optional[Union[device, str, int, Callable[[UntypedStorage, str], Optional[UntypedStorage]], dict[torch.device | str | int, torch.device | str | int]]] = None, hparams_file: Optional[Union[str, Path]] = None, strict: Optional[bool] = None, **kwargs: Any) Self[source]¶
Load a model from a checkpoint.
- Parameters:
checkpoint_path (str | Path | IO[bytes]) – Path to the checkpoint file or a file-like object.
map_location (optional) – Device mapping for loading the model.
hparams_file (optional) – Path to a YAML file containing hyperparameters.
strict (optional) – Whether to strictly enforce that the keys in state_dict match the keys returned by the model’s state_dict function.
**kwargs – Additional keyword arguments passed to the model’s constructor.
- Returns:
An instance of LitAttentionDeepMIL.
- forward(x: Tensor) Tensor[source]¶
Same as torch.nn.Module.forward().

- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- get_attention_weights(x: Tensor) Tensor[source]¶
Get attention weights for the input instances.
- Parameters:
x (torch.Tensor) – Input tensor of shape [n_instances, feat_dim].
- Returns:
Attention weights of shape [attention_branches, n_instances].
- Return type:
torch.Tensor
- class cellmil.models.mil.attentiondeepmil.LitSurvAttentionDeepMIL(model: AttentionDeepMIL, optimizer: Optimizer, loss: Module = NegativeLogLikelihoodSurvLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)[source]¶
Bases: LitAttentionDeepMIL

- __init__(model: AttentionDeepMIL, optimizer: Optimizer, loss: Module = NegativeLogLikelihoodSurvLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)[source]¶
- predict_step(batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int)[source]¶
Prediction step returns logits for discrete-time hazard intervals.
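As a sketch of how such per-interval hazard logits are typically consumed downstream: in the common discrete-time survival formulation, each logit is mapped through a sigmoid to a conditional hazard, and the survival probability at interval t is the cumulative product of (1 - hazard) over intervals up to t. This is an illustration of that standard formulation, not the cellmil implementation, and the function names are hypothetical.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def hazards_to_survival(logits):
    """Convert per-interval hazard logits to survival probabilities.

    Assumes sigmoid(logit_j) is the conditional hazard in interval j,
    so S(t) = prod_{j <= t} (1 - hazard_j).
    """
    surv, s = [], 1.0
    for logit in logits:
        s *= 1.0 - sigmoid(logit)  # probability of surviving interval j
        surv.append(s)
    return surv

# Hazard logits for three discrete time intervals.
logits = [-2.0, -1.0, 0.5]
surv = hazards_to_survival(logits)
```

The resulting survival curve is monotonically non-increasing, as expected of a survival function.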