cellmil.models.mil.head4type¶
Classes

- Head4Type
- LitHead4Type – Lightning wrapper for Head4Type model.
- LitSurvHead4Type
- class cellmil.models.mil.head4type.Head4Type(embed_dim: int, n_classes: int = 2, size_arg: list[int] = [512, 128], temperature: float = 1.0, cell_types: int = 5, heads_aggregation: Literal['weighted_mean', 'attention', 'mean', 'concatenation', 'custom'] = 'custom', dropout: float = 0.0, custom_aggregation_weights: list[float] | None = [3.0, 2.0, 1.0, 0.0, 0.0])[source]¶
Bases:
Module

- __init__(embed_dim: int, n_classes: int = 2, size_arg: list[int] = [512, 128], temperature: float = 1.0, cell_types: int = 5, heads_aggregation: Literal['weighted_mean', 'attention', 'mean', 'concatenation', 'custom'] = 'custom', dropout: float = 0.0, custom_aggregation_weights: list[float] | None = [3.0, 2.0, 1.0, 0.0, 0.0])[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, cell_types: Tensor) → tuple[torch.Tensor, dict[str, torch.Tensor]][source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
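As a sketch of the 'custom' heads_aggregation mode, the per-cell-type head embeddings might be combined with the fixed custom_aggregation_weights (default [3.0, 2.0, 1.0, 0.0, 0.0]). The function below is a hypothetical illustration of that weighted-mean semantics, not the library's own implementation:

```python
import torch

def custom_aggregate(head_embeddings: torch.Tensor,
                     weights: list[float]) -> torch.Tensor:
    """Weighted mean over per-cell-type head embeddings.

    head_embeddings: [n_cell_types, dim] - one embedding per cell-type head
    weights: one scalar weight per cell type (assumed semantics)
    """
    w = torch.tensor(weights, dtype=head_embeddings.dtype)
    w = w / w.sum()  # normalize so the result is a weighted mean
    return (w[:, None] * head_embeddings).sum(dim=0)

# 5 cell-type heads producing 128-dim embeddings, matching the defaults above
heads = torch.ones(5, 128)
out = custom_aggregate(heads, [3.0, 2.0, 1.0, 0.0, 0.0])
```

With zero weights on the last two cell types, those heads contribute nothing to the bag-level embedding.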
- class cellmil.models.mil.head4type.LitHead4Type(model: Module, optimizer: Optimizer, loss: Module = CrossEntropyLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 0.8, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)[source]¶
Bases:
LitGeneral

Lightning wrapper for Head4Type model.
This class extends the base LitGeneral class to provide Lightning-specific functionality for the Head4Type model.
- Parameters:
model (nn.Module) – The Head4Type model instance.
optimizer (torch.optim.Optimizer) – Optimizer for training.
loss (nn.Module, optional) – Loss function. Defaults to nn.CrossEntropyLoss().
lr_scheduler (LRScheduler | None, optional) – Learning rate scheduler. Defaults to None.
subsampling (float, optional) – Instance subsampling ratio applied during training. Defaults to 0.8.
use_aem (bool, optional) – Whether to use AEM regularization. Defaults to False.
aem_weight_initial (float, optional) – Initial weight for AEM loss. Defaults to 0.0001.
aem_weight_final (float, optional) – Final weight for AEM loss after annealing. Defaults to 0.0.
aem_annealing_epochs (int, optional) – Number of epochs over which the AEM weight is annealed. Defaults to 25.
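The AEM weight moves from aem_weight_initial to aem_weight_final over aem_annealing_epochs. The schedule shape is not documented here; the sketch below assumes a simple linear anneal (an assumption, not taken from the library source):

```python
def aem_weight(epoch: int,
               initial: float = 1e-4,
               final: float = 0.0,
               annealing_epochs: int = 25) -> float:
    """Linearly anneal the AEM loss weight (assumed schedule).

    Matches the documented defaults: starts at 1e-4, reaches 0.0
    after 25 epochs, and stays at the final value afterwards.
    """
    if epoch >= annealing_epochs:
        return final
    frac = epoch / annealing_epochs  # fraction of annealing completed
    return initial + frac * (final - initial)
```

At the defaults this decays the regularization weight to zero by epoch 25, so AEM only shapes early training.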
- __init__(model: Module, optimizer: Optimizer, loss: Module = CrossEntropyLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 0.8, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25) → None[source]¶
- classmethod load_from_checkpoint(checkpoint_path: Union[str, Path, IO[bytes]], map_location: Optional[Union[device, str, int, Callable[[UntypedStorage, str], Optional[UntypedStorage]], dict[torch.device | str | int, torch.device | str | int]]] = None, hparams_file: Optional[Union[str, Path]] = None, strict: Optional[bool] = None, **kwargs: Any) → Self[source]¶
Load a model from a checkpoint.
- Parameters:
checkpoint_path (str | Path | IO[bytes]) – Path to the checkpoint file or a file-like object.
map_location (optional) – Device mapping for loading the model.
hparams_file (optional) – Path to a YAML file containing hyperparameters.
strict (optional) – Whether to strictly enforce that the keys in state_dict match the keys returned by the model’s state_dict function.
**kwargs – Additional keyword arguments passed to the model’s constructor.
- Returns:
An instance of LitHead4Type loaded from the checkpoint.
- forward(x: Tensor, cell_types: Tensor) → Tensor[source]¶
Same as torch.nn.Module.forward().

- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- get_attention_weights(x: Tensor, cell_types: Tensor) → Tensor[source]¶
Get attention weights for the input instances.
- Parameters:
x (torch.Tensor) – Input tensor of shape [n_instances, feat_dim].
cell_types (torch.Tensor) – Cell type tensor of shape [n_instances, n_cell_types].
- Returns:
Attention weights of shape [cell_types, n_instances].
- Return type:
torch.Tensor
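The documented output shape [cell_types, n_instances] suggests a separate attention distribution over instances for each cell-type head. The scoring network itself is internal to the model; the sketch below only illustrates the assumed normalization, a temperature-scaled softmax over the instance axis:

```python
import torch

def attention_weights(scores: torch.Tensor,
                      temperature: float = 1.0) -> torch.Tensor:
    """Normalize raw attention scores per cell-type head.

    scores: [n_cell_types, n_instances] - hypothetical raw scores
    Returns a tensor of the same shape whose rows each sum to 1.
    """
    return torch.softmax(scores / temperature, dim=-1)

# 5 cell types, 10 instances; equal scores give uniform attention
scores = torch.zeros(5, 10)
attn = attention_weights(scores)
```

A lower temperature sharpens each row toward the highest-scoring instances; temperature=1.0 matches the class default.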
- class cellmil.models.mil.head4type.LitSurvHead4Type(model: Head4Type, optimizer: Optimizer, loss: Module = NegativeLogLikelihoodSurvLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 0.8, use_aem: bool = True, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)[source]¶
Bases:
LitHead4Type

- __init__(model: Head4Type, optimizer: Optimizer, loss: Module = NegativeLogLikelihoodSurvLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 0.8, use_aem: bool = True, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)[source]¶
- predict_step(batch: tuple[torch.Tensor, torch.Tensor, torch.Tensor], batch_idx: int)[source]¶
Prediction step returns logits for discrete-time hazard intervals.
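Under the usual discrete-time convention paired with a negative log-likelihood survival loss (an assumption about this model's convention, not confirmed by the source), each logit maps to a per-interval hazard via a sigmoid, and the survival curve is the running product of the complements:

```python
import torch

def survival_from_hazard_logits(logits: torch.Tensor) -> torch.Tensor:
    """Convert discrete-time hazard logits to a survival curve.

    logits: [n_intervals] - raw hazard logits, one per time interval
    Returns S(k) = prod_{j<=k} (1 - sigmoid(logit_j)).
    """
    hazards = torch.sigmoid(logits)            # per-interval hazard in (0, 1)
    return torch.cumprod(1.0 - hazards, dim=-1)

# 4 intervals with zero logits -> hazard 0.5 in every interval
logits = torch.zeros(4)
surv = survival_from_hazard_logits(logits)
```

The survival curve is monotonically non-increasing by construction, which makes it convenient for deriving risk scores or median survival times from the predicted logits.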