cellmil.models.mil.multifocus
Classes
- class cellmil.models.mil.multifocus.MultiFocus(embed_dim: int, n_classes: int = 2, size_arg: list[int] = [32], temperature: float = 1.0, dropout: float = 0.0)
Bases: Module
- __init__(embed_dim: int, n_classes: int = 2, size_arg: list[int] = [32], temperature: float = 1.0, dropout: float = 0.0)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) → tuple[torch.Tensor, dict[str, torch.Tensor]]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
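MultiFocus takes a temperature hyperparameter. In attention-based MIL pooling, such a parameter typically divides the attention logits before the softmax that produces per-instance weights, so lower temperatures concentrate attention on fewer instances. The sketch below illustrates that usual role in plain Python; it is not cellmil's actual implementation.

```python
import math

def softmax_with_temperature(logits, temperature=1.0):
    """Temperature-scaled softmax: lower temperature -> sharper weights.

    Illustrative sketch of the usual role of a `temperature` argument
    in attention pooling; not cellmil's actual implementation.
    """
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

soft = softmax_with_temperature([1.0, 2.0, 3.0], temperature=1.0)
sharp = softmax_with_temperature([1.0, 2.0, 3.0], temperature=0.1)
```

With temperature 1.0 the weights stay spread out; at 0.1 nearly all mass moves to the largest logit.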
- class cellmil.models.mil.multifocus.LitMultiFocus(model: Module, optimizer: Optimizer, loss: Module = CrossEntropyLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 0.8, use_aem: bool = True, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25)
Bases: LitGeneral
- __init__(model: Module, optimizer: Optimizer, loss: Module = CrossEntropyLoss(), lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 0.8, use_aem: bool = True, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 25) → None
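The aem_weight_initial, aem_weight_final, and aem_annealing_epochs arguments suggest the AEM loss term's weight is annealed from its initial to its final value over a fixed number of epochs. A linear schedule is one plausible reading; the helper below is a hypothetical sketch under that assumption, not cellmil's confirmed behavior.

```python
def aem_weight(epoch, initial=1e-4, final=0.0, annealing_epochs=25):
    """Linearly interpolate an auxiliary-loss weight across epochs.

    Hypothetical sketch of how aem_weight_initial / aem_weight_final /
    aem_annealing_epochs could interact; cellmil may use another schedule.
    """
    if epoch >= annealing_epochs:
        return final
    frac = epoch / annealing_epochs
    return initial + frac * (final - initial)
```

With the defaults, the AEM contribution decays from 1e-4 toward 0 and stays at the final value once annealing finishes.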
- classmethod load_from_checkpoint(checkpoint_path: Union[str, Path, IO[bytes]], map_location: Optional[Union[device, str, int, Callable[[UntypedStorage, str], Optional[UntypedStorage]], dict[torch.device | str | int, torch.device | str | int]]] = None, hparams_file: Optional[Union[str, Path]] = None, strict: Optional[bool] = None, **kwargs: Any) → Self
Load a model from a checkpoint.
- Parameters:
checkpoint_path (str | Path | IO[bytes]) – Path to the checkpoint file or a file-like object.
map_location (optional) – Device mapping for loading the model.
hparams_file (optional) – Path to a YAML file containing hyperparameters.
strict (optional) – Whether to strictly enforce that the keys in state_dict match the keys returned by the model’s state_dict function.
**kwargs – Additional keyword arguments passed to the model’s constructor.
- Returns:
An instance of LitMultiFocus.
- forward(x: Tensor) → Tensor
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- get_attention_weights(x: Tensor) → Tensor
Get attention weights for the input instances.
- Parameters:
x (torch.Tensor) – Input tensor of shape [n_instances, feat_dim].
- Returns:
Attention weights of shape [embed_dim, n_instances].
- Return type:
torch.Tensor
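Attention weights like those returned by get_attention_weights are commonly used to rank instances by how much they contributed to the bag-level prediction, for example to inspect the most-attended cells. The top_attended helper below is a hypothetical, library-independent sketch of that use, operating on one row of a weight matrix as a plain list.

```python
def top_attended(weights_row, k=2):
    """Return indices of the k instances with the largest attention weight.

    `weights_row` is one row of an attention-weight matrix, i.e. a list of
    n_instances scores. Hypothetical helper for illustration only; it is
    not part of cellmil.
    """
    order = sorted(range(len(weights_row)),
                   key=lambda i: weights_row[i], reverse=True)
    return order[:k]
```

For a row [0.1, 0.5, 0.4], the two most-attended instance indices are 1 and 2.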