cellmil.models.mil.clam
Classes
- Attn_Net – Attention Network without Gating.
- Attn_Net_Gated – Attention Network with Sigmoid Gating.
- CLAMTrainerLegacy – Trainer class for CLAM models.
- CLAM_MB – CLAM Multi-Branch (MB) - Clustering-constrained Attention Multiple Instance Learning model.
- CLAM_SB – CLAM Single Branch (SB) - Clustering-constrained Attention Multiple Instance Learning model.
- LitCLAM
- LitSurvCLAM – Lightning wrapper for CLAM models adapted for survival analysis.
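- class cellmil.models.mil.clam.Attn_Net(L: int = 1024, D: int = 256, dropout: bool = False, n_classes: int = 1)
Bases: Module
Attention Network without Gating.
This class implements a basic attention mechanism using fully connected layers followed by a tanh activation. It is used to compute attention weights for Multiple Instance Learning (MIL).
- Parameters:
L (int, optional) – Input feature dimension. Defaults to 1024.
D (int, optional) – Hidden layer dimension. Defaults to 256.
dropout (bool, optional) – Whether to use dropout. Defaults to False.
n_classes (int, optional) – Number of classes (determines output dimension). Defaults to 1.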
- __init__(L: int = 1024, D: int = 256, dropout: bool = False, n_classes: int = 1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) → tuple[torch.Tensor, torch.Tensor]
Forward pass for the Attention Network.
- Parameters:
x (torch.Tensor) – Input tensor of shape (N, L), where N is the number of instances in the bag and L is the input feature dimension.
- Returns:
The attention scores after processing (shape: N x n_classes)
The original input tensor (shape: N x L)
- Return type:
tuple[torch.Tensor, torch.Tensor]
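The following is a minimal sketch of an ungated attention network matching the description above. It assumes the layer layout of the original CLAM reference code (Linear(L, D), Tanh, optional Dropout, Linear(D, n_classes)); the class name AttnNetSketch and the dropout probability of 0.25 are assumptions, not cellmil's actual implementation.

```python
import torch
from torch import nn


class AttnNetSketch(nn.Module):
    """Hypothetical ungated attention network: Linear -> Tanh -> Linear."""

    def __init__(self, L: int = 1024, D: int = 256, dropout: bool = False, n_classes: int = 1):
        super().__init__()
        layers: list[nn.Module] = [nn.Linear(L, D), nn.Tanh()]
        if dropout:
            layers.append(nn.Dropout(0.25))  # assumed p, mirroring Attn_Net_Gated
        layers.append(nn.Linear(D, n_classes))
        self.module = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Unnormalized attention scores (N x n_classes) plus the untouched input (N x L)
        return self.module(x), x


net = AttnNetSketch(L=16, D=8, n_classes=1)
scores, feats = net(torch.randn(5, 16))
print(scores.shape, feats.shape)  # torch.Size([5, 1]) torch.Size([5, 16])
```

The scores are unnormalized; a softmax over the N instances is applied later by the CLAM model that owns the attention network.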
- class cellmil.models.mil.clam.Attn_Net_Gated(L: int = 1024, D: int = 256, dropout: bool = False, n_classes: int = 1)
Bases: Module
Attention Network with Sigmoid Gating.
This class implements a gated attention mechanism using two parallel pathways:
- one path with a linear layer followed by a tanh activation;
- another path with a linear layer followed by a sigmoid activation.
These paths are combined via element-wise multiplication (gating mechanism) and passed through a final linear layer to compute attention weights.
- Parameters:
L (int, optional) – Input feature dimension. Defaults to 1024.
D (int, optional) – Hidden layer dimension. Defaults to 256.
dropout (bool, optional) – Whether to use dropout (p = 0.25) in both pathways. Defaults to False.
n_classes (int, optional) – Number of classes (determines output dimension). Defaults to 1.
- __init__(L: int = 1024, D: int = 256, dropout: bool = False, n_classes: int = 1)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor) → tuple[torch.Tensor, torch.Tensor]
Forward pass for the Gated Attention Network.
This method implements a gated attention mechanism where two parallel paths process the input:
- Path A: Linear -> Tanh activation
- Path B: Linear -> Sigmoid activation
These paths are then combined via element-wise multiplication and passed through a final linear layer to produce attention scores.
- Parameters:
x (torch.Tensor) – Input tensor of shape (N, L), where N is the number of instances in the bag and L is the input feature dimension.
- Returns:
The attention scores after the final linear layer (shape: N x n_classes)
The original input tensor (shape: N x L)
- Return type:
tuple[torch.Tensor, torch.Tensor]
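A minimal sketch of the gated variant, assuming each pathway is a single Linear(L, D) layer followed by its activation. The class and attribute names (AttnNetGatedSketch, attention_a, attention_b, attention_c) are hypothetical; only the tanh/sigmoid pathways, the element-wise product, the final linear layer, and p = 0.25 come from the docstring.

```python
import torch
from torch import nn


class AttnNetGatedSketch(nn.Module):
    """Hypothetical gated attention: (tanh path) * (sigmoid path) -> Linear."""

    def __init__(self, L: int = 1024, D: int = 256, dropout: bool = False, n_classes: int = 1):
        super().__init__()
        path_a: list[nn.Module] = [nn.Linear(L, D), nn.Tanh()]
        path_b: list[nn.Module] = [nn.Linear(L, D), nn.Sigmoid()]
        if dropout:
            # The docstring states p = 0.25 in both pathways
            path_a.append(nn.Dropout(0.25))
            path_b.append(nn.Dropout(0.25))
        self.attention_a = nn.Sequential(*path_a)
        self.attention_b = nn.Sequential(*path_b)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        a = self.attention_a(x)           # N x D, values in (-1, 1)
        b = self.attention_b(x)           # N x D, gate values in (0, 1)
        scores = self.attention_c(a * b)  # element-wise gating, then N x n_classes
        return scores, x


torch.manual_seed(0)
net = AttnNetGatedSketch(L=16, D=8, dropout=True, n_classes=3)
scores, feats = net(torch.randn(7, 16))
print(scores.shape)  # torch.Size([7, 3])
```

Because the sigmoid gate lies in (0, 1), it rescales each hidden unit of the tanh path independently, which is what distinguishes this network from Attn_Net.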
- class cellmil.models.mil.clam.CLAM_SB(gate: bool = True, size_arg: Union[Literal['small', 'big'], list[int]] = 'small', dropout: bool = False, k_sample: int = 8, n_classes: int = 2, instance_loss_fn: Module = SmoothTop1SVM(), subtyping: bool = False, embed_dim: int = 1024, temperature: float = 1.0)
Bases: Module
CLAM Single Branch (SB) - Clustering-constrained Attention Multiple Instance Learning model.
This model uses attention mechanisms to aggregate features from multiple instances (patches) in a bag for classification. It supports instance-level evaluation and can handle both binary and multi-class classification problems.
- Parameters:
gate (bool, optional) – Whether to use gated attention network. If True, uses Attn_Net_Gated, otherwise uses Attn_Net. Defaults to True.
size_arg (Literal['small', 'big'], list, optional) – Configuration for network size. ‘small’: [embed_dim, 512, 256], ‘big’: [embed_dim, 512, 384]. Defaults to “small”.
dropout (bool, optional) – Whether to use dropout (p = 0.25) in attention networks and feature layers. Defaults to False.
k_sample (int, optional) – Number of positive/negative patches to sample for instance-level training. Used in inst_eval methods. Defaults to 8.
n_classes (int, optional) – Number of classes for classification. Defaults to 2.
instance_loss_fn (nn.Module, optional) – Loss function to supervise instance-level training. Defaults to SmoothTop1SVM().
subtyping (bool, optional) – Whether this is a subtyping problem. Affects instance-level evaluation for out-of-class samples. Defaults to False.
embed_dim (int, optional) – Input embedding dimension. Defaults to 1024.
temperature (float, optional) – Temperature parameter for softmax. Defaults to 1.0.
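The size_arg presets listed above can be resolved as in this hypothetical helper (resolve_size is not part of cellmil). It assumes an explicit list is used verbatim; per the _get_size_args note further down, entries 1 and 2 are the L and D passed to the attention network.

```python
from typing import Literal, Union


def resolve_size(size_arg: Union[Literal["small", "big"], list[int]], embed_dim: int = 1024) -> list[int]:
    """Hypothetical helper: map size_arg to [embed_dim, attention L, attention D]."""
    presets = {
        "small": [embed_dim, 512, 256],
        "big": [embed_dim, 512, 384],
    }
    if isinstance(size_arg, str):
        return presets[size_arg]
    return list(size_arg)  # assumption: an explicit list is taken as-is


print(resolve_size("small"))               # [1024, 512, 256]
print(resolve_size("big", embed_dim=768))  # [768, 512, 384]
```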
- __init__(gate: bool = True, size_arg: Union[Literal['small', 'big'], list[int]] = 'small', dropout: bool = False, k_sample: int = 8, n_classes: int = 2, instance_loss_fn: Module = SmoothTop1SVM(), subtyping: bool = False, embed_dim: int = 1024, temperature: float = 1.0)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- static create_positive_targets(length: int, device: device) → Tensor
Create a tensor of positive targets (all ones).
- Parameters:
length (int) – The length of the tensor to create.
device (torch.device) – The device to create the tensor on.
- Returns:
A tensor of ones of the specified length.
- Return type:
torch.Tensor
- static create_negative_targets(length: int, device: device) → Tensor
Create a tensor of negative targets (all zeros).
- Parameters:
length (int) – The length of the tensor to create.
device (torch.device) – The device to create the tensor on.
- Returns:
A tensor of zeros of the specified length.
- Return type:
torch.Tensor
- inst_eval(a: Tensor, h: Tensor, classifier: Module) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Instance-level evaluation for the in-class attention branch.
This method evaluates the model at the instance level by (with k = k_sample):
1. Selecting the k instances with the highest attention scores (positive).
2. Selecting the k instances with the lowest attention scores (negative).
3. Creating targets for these instances.
4. Computing the loss and predictions using the classifier.
- Parameters:
a (torch.Tensor) – Attention scores tensor.
h (torch.Tensor) – Features tensor.
classifier (nn.Module) – Instance-level classifier.
- Returns:
instance_loss: The loss for instance-level classification
all_preds: The predicted labels for all selected instances
all_targets: The target labels for all selected instances
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
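A sketch of the four steps above, written as a free function for clarity. Assumptions: k equals k_sample, the attention tensor is a single 1 x N row, targets are long integers, and nn.CrossEntropyLoss stands in for the SmoothTop1SVM default. N must be at least k_sample. The function name is hypothetical.

```python
import torch
from torch import nn


def inst_eval_sketch(a: torch.Tensor, h: torch.Tensor, classifier: nn.Module,
                     loss_fn: nn.Module, k_sample: int = 8):
    """Hypothetical in-class instance evaluation; a holds N attention scores, h is N x D."""
    a = a.view(-1)  # flatten a 1 x N attention row to shape (N,)
    # The k highest-attention instances are pseudo-labelled positive (1) ...
    top_p_ids = torch.topk(a, k_sample).indices
    # ... and the k lowest-attention instances negative (0).
    top_n_ids = torch.topk(a, k_sample, largest=False).indices
    instances = torch.cat([h[top_p_ids], h[top_n_ids]], dim=0)  # 2k x D
    targets = torch.cat([
        torch.ones(k_sample, dtype=torch.long, device=h.device),
        torch.zeros(k_sample, dtype=torch.long, device=h.device),
    ])
    logits = classifier(instances)  # 2k x 2
    preds = logits.argmax(dim=1)
    instance_loss = loss_fn(logits, targets)
    return instance_loss, preds, targets


torch.manual_seed(0)
h = torch.randn(20, 16)
a = torch.randn(1, 20)
loss, preds, targets = inst_eval_sketch(a, h, nn.Linear(16, 2), nn.CrossEntropyLoss(), k_sample=4)
print(targets.tolist())  # [1, 1, 1, 1, 0, 0, 0, 0]
```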
- inst_eval_out(a: Tensor, h: Tensor, classifier: Module) → tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Instance-level evaluation for the out-of-class attention branch.
This method evaluates the model at the instance level for out-of-class samples by (with k = k_sample):
1. Selecting the k instances with the highest attention scores.
2. Creating negative targets for these instances (since they should be negative for an out-of-class branch).
3. Computing the loss and predictions using the classifier.
- Parameters:
a (torch.Tensor) – Attention scores tensor.
h (torch.Tensor) – Features tensor.
classifier (nn.Module) – Instance-level classifier.
- Returns:
instance_loss: The loss for instance-level classification
p_preds: The predicted labels for the selected instances
p_targets: The target labels for the selected instances
- Return type:
tuple[torch.Tensor, torch.Tensor, torch.Tensor]
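A standalone sketch of the out-of-class variant, assuming k equals k_sample, a 1 x N attention row, and nn.CrossEntropyLoss in place of SmoothTop1SVM; the function name is hypothetical. Unlike the in-class case, only the top-k instances are used, and all of them are labelled negative.

```python
import torch
from torch import nn


def inst_eval_out_sketch(a: torch.Tensor, h: torch.Tensor, classifier: nn.Module,
                         loss_fn: nn.Module, k_sample: int = 8):
    """Hypothetical out-of-class instance evaluation for a branch whose class is absent."""
    a = a.view(-1)
    # Highly attended instances of an absent class should still be classified negative.
    top_ids = torch.topk(a, k_sample).indices
    p_targets = torch.zeros(k_sample, dtype=torch.long, device=h.device)
    logits = classifier(h[top_ids])  # k x 2
    p_preds = logits.argmax(dim=1)
    instance_loss = loss_fn(logits, p_targets)
    return instance_loss, p_preds, p_targets


torch.manual_seed(0)
loss, p_preds, p_targets = inst_eval_out_sketch(
    torch.randn(1, 20), torch.randn(20, 16), nn.Linear(16, 2), nn.CrossEntropyLoss(), k_sample=4
)
print(p_targets.tolist())  # [0, 0, 0, 0]
```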
- forward(h: Tensor, label: Optional[Tensor] = None, instance_eval: bool = False, return_features: bool = False, attention_only: bool = False) → torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]
Forward pass of the CLAM Single Branch model.
- Parameters:
h (torch.Tensor) – Input feature tensor of shape (N, embed_dim), where N is the number of instances (patches) and embed_dim is the feature dimension.
label (torch.Tensor | None, optional) – Ground truth labels for instance-level evaluation. Required when instance_eval=True. Should be of shape (1,) for single class or (n_classes,) for multi-class. Defaults to None.
instance_eval (bool, optional) – Whether to perform instance-level evaluation and compute instance loss. Requires label to be provided. Defaults to False.
return_features (bool, optional) – Whether to return aggregated features (M) in the results dictionary. Defaults to False.
attention_only (bool, optional) – If True, returns only attention weights without classification. Defaults to False.
- Returns:
If attention_only=True: Returns attention weights tensor of shape (K, N), where K is the number of attention branches (a single branch for CLAM_SB).
- Otherwise: Returns tuple of (logits, Y_prob, Y_hat, a_raw, results_dict) where:
logits (torch.Tensor): Raw classification logits of shape (1, n_classes)
Y_prob (torch.Tensor): Softmax probabilities of shape (1, n_classes)
Y_hat (torch.Tensor): Predicted class indices of shape (1, 1)
a_raw (torch.Tensor): Raw attention weights before softmax of shape (K, N)
- results_dict (dict): Dictionary containing:
'instance_loss': Instance-level loss (if instance_eval=True)
'inst_labels': Instance-level target labels (if instance_eval=True)
'inst_preds': Instance-level predictions (if instance_eval=True)
'features': Aggregated features M (if return_features=True)
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]
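A sketch of the attention pooling behind the single-branch forward pass, assuming the temperature divides the raw attention scores before the softmax (the usual meaning of a softmax temperature; the docstring does not say exactly where it is applied). The function name and the plain nn.Linear classifier are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import nn


def sb_pool_and_classify(a_raw: torch.Tensor, h: torch.Tensor, classifier: nn.Module,
                         temperature: float = 1.0):
    """Hypothetical single-branch pooling: a_raw is 1 x N, h is N x D."""
    A = F.softmax(a_raw / temperature, dim=1)  # attention weights over the N instances
    M = A @ h                                  # 1 x D attention-weighted bag embedding
    logits = classifier(M)                     # 1 x n_classes
    Y_prob = F.softmax(logits, dim=1)
    Y_hat = torch.topk(logits, 1, dim=1)[1]    # 1 x 1 predicted class index
    return logits, Y_prob, Y_hat, A, M


torch.manual_seed(0)
h = torch.randn(10, 32)
a_raw = torch.randn(1, 10)
logits, Y_prob, Y_hat, A, M = sb_pool_and_classify(a_raw, h, nn.Linear(32, 2), temperature=0.5)
print(logits.shape, Y_hat.shape)  # torch.Size([1, 2]) torch.Size([1, 1])
```

Under this assumption, temperatures below 1 sharpen A toward the highest-scoring instances, while temperatures above 1 spread the weight more evenly across the bag.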
- class cellmil.models.mil.clam.CLAM_MB(gate: bool = True, size_arg: Union[Literal['small', 'big'], list[int]] = 'small', dropout: bool = False, k_sample: int = 8, n_classes: int = 2, instance_loss_fn: Module = SmoothTop1SVM(), subtyping: bool = False, embed_dim: int = 1024, temperature: float = 1.0)
Bases: CLAM_SB
CLAM Multi-Branch (MB) - Clustering-constrained Attention Multiple Instance Learning model.
This class extends CLAM_SB by using a multi-branch architecture where each class has its own attention branch and classifier. This architecture is more suitable for multi-class classification problems.
- Parameters:
gate (bool, optional) – Whether to use gated attention network. Defaults to True.
size_arg (Literal["small", "big"], list, optional) – Configuration for network size. Defaults to “small”.
dropout (bool, optional) – Whether to use dropout. Defaults to False.
k_sample (int, optional) – Number of positive/negative patches to sample for instance-level training. Defaults to 8.
n_classes (int, optional) – Number of classes. Defaults to 2.
instance_loss_fn (nn.Module, optional) – Loss function for instance-level training. Defaults to SmoothTop1SVM().
subtyping (bool, optional) – Whether it’s a subtyping problem. Defaults to False.
embed_dim (int, optional) – Input embedding dimension. Defaults to 1024.
temperature (float, optional) – Temperature parameter for softmax. Defaults to 1.0.
- __init__(gate: bool = True, size_arg: Union[Literal['small', 'big'], list[int]] = 'small', dropout: bool = False, k_sample: int = 8, n_classes: int = 2, instance_loss_fn: Module = SmoothTop1SVM(), subtyping: bool = False, embed_dim: int = 1024, temperature: float = 1.0)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(h: Tensor, label: Optional[Tensor] = None, instance_eval: bool = False, return_features: bool = False, attention_only: bool = False) → torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]
Forward pass of the CLAM Multi-Branch model.
This method extends the CLAM_SB forward pass by using multiple attention branches, one for each class. Each branch computes its own attention weights and features, which are then processed by class-specific classifiers.
- Parameters:
h (torch.Tensor) – Input feature tensor of shape (N, embed_dim), where N is the number of instances (patches) and embed_dim is the feature dimension.
label (torch.Tensor | None, optional) – Ground truth labels for instance-level evaluation. Required when instance_eval=True. Defaults to None.
instance_eval (bool, optional) – Whether to perform instance-level evaluation and compute instance loss. Defaults to False.
return_features (bool, optional) – Whether to return aggregated features in the results dictionary. Defaults to False.
attention_only (bool, optional) – If True, returns only attention weights without classification. Defaults to False.
- Returns:
If attention_only=True: Returns attention weights tensor of shape (K, N), where K = n_classes (one attention branch per class).
- Otherwise: Returns tuple of (logits, Y_prob, Y_hat, a_raw, results_dict) where:
logits (torch.Tensor): Raw classification logits of shape (1, n_classes)
Y_prob (torch.Tensor): Softmax probabilities of shape (1, n_classes)
Y_hat (torch.Tensor): Predicted class indices of shape (1, 1)
a_raw (torch.Tensor): Raw attention weights before softmax of shape (K, N)
- results_dict (dict): Dictionary containing:
'instance_loss': Instance-level loss (if instance_eval=True)
'inst_labels': Instance-level target labels (if instance_eval=True)
'inst_preds': Instance-level predictions (if instance_eval=True)
'features': Aggregated features M (if return_features=True)
- Return type:
torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]
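A sketch of the multi-branch variant, assuming K = n_classes attention rows and one nn.Linear(D, 1) classifier per class, as in the original CLAM reference code. The function name is hypothetical, and the temperature is assumed to divide the raw scores before the softmax.

```python
import torch
import torch.nn.functional as F
from torch import nn


def mb_pool_and_classify(a_raw: torch.Tensor, h: torch.Tensor,
                         classifiers: nn.ModuleList, temperature: float = 1.0):
    """Hypothetical multi-branch pooling: a_raw is K x N with K = n_classes, h is N x D."""
    A = F.softmax(a_raw / temperature, dim=1)  # one attention distribution per class
    M = A @ h                                  # K x D, one bag embedding per class
    # Classifier c scores only the embedding produced by attention branch c.
    logits = torch.stack([clf(M[c]) for c, clf in enumerate(classifiers)], dim=1)  # 1 x K
    Y_prob = F.softmax(logits, dim=1)
    Y_hat = torch.topk(logits, 1, dim=1)[1]    # 1 x 1
    return logits, Y_prob, Y_hat, A, M


torch.manual_seed(0)
n_classes, N, D = 3, 12, 16
classifiers = nn.ModuleList([nn.Linear(D, 1) for _ in range(n_classes)])
logits, Y_prob, Y_hat, A, M = mb_pool_and_classify(torch.randn(n_classes, N), torch.randn(N, D), classifiers)
print(logits.shape, M.shape)  # torch.Size([1, 3]) torch.Size([3, 16])
```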
- class cellmil.models.mil.clam.CLAMTrainerLegacy(model: cellmil.models.mil.clam.CLAM_MB | cellmil.models.mil.clam.CLAM_SB, optimizer: torch.optim.optimizer.Optimizer, device: str, ckpt_path: pathlib.Path, weight_loss_slide: float = 0.7, loss_slide: torch.nn.modules.module.Module = CrossEntropyLoss(), early_stopping: cellmil.models.mil.utils.EarlyStopping | None = <cellmil.models.mil.utils.EarlyStopping object>, use_wandb: bool = True, scale_attention_grads_by_bag: bool = False, attn_ref_bag_size: int = 100000, attn_alpha: float = 0.5)
Bases: object
Trainer class for CLAM models.
This class handles the training loop, validation, and evaluation of CLAM models. It supports early stopping, metrics logging, and checkpointing.
- Parameters:
model (CLAM_MB | CLAM_SB) – The CLAM model to train.
optimizer (torch.optim.Optimizer) – Optimizer for model training.
device (str) – Device to use for training ("cuda", "cpu", etc.).
ckpt_path (Path) – Path to save checkpoints.
weight_loss_slide (float, optional) – Weight for slide-level loss. Defaults to 0.7.
loss_slide (nn.Module, optional) – Loss function for slide-level classification. Defaults to nn.CrossEntropyLoss().
early_stopping (EarlyStopping | None, optional) – Early stopping controller. Defaults to EarlyStopping with patience=20.
use_wandb (bool, optional) – Whether to log metrics to Weights & Biases. Defaults to True.
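The docstring does not spell out how weight_loss_slide is applied. A common convention, used in the original CLAM code, is total = w * slide_loss + (1 - w) * instance_loss; the sketch below assumes that convention. combine_losses is a hypothetical name, and the arithmetic works unchanged on scalar tensors.

```python
def combine_losses(loss_slide: float, instance_loss: float, weight_loss_slide: float = 0.7) -> float:
    """Hypothetical CLAM-style blend of bag-level and instance-level losses."""
    if not 0.0 <= weight_loss_slide <= 1.0:
        raise ValueError("weight_loss_slide must lie in [0, 1]")
    return weight_loss_slide * loss_slide + (1.0 - weight_loss_slide) * instance_loss


print(round(combine_losses(1.0, 0.5), 6))  # 0.85
```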
- __init__(model: cellmil.models.mil.clam.CLAM_MB | cellmil.models.mil.clam.CLAM_SB, optimizer: torch.optim.optimizer.Optimizer, device: str, ckpt_path: pathlib.Path, weight_loss_slide: float = 0.7, loss_slide: torch.nn.modules.module.Module = CrossEntropyLoss(), early_stopping: cellmil.models.mil.utils.EarlyStopping | None = <cellmil.models.mil.utils.EarlyStopping object>, use_wandb: bool = True, scale_attention_grads_by_bag: bool = False, attn_ref_bag_size: int = 100000, attn_alpha: float = 0.5)
- fit(train_loader: DataLoader[tuple[torch.Tensor, int]], val_loader: DataLoader[tuple[torch.Tensor, int]], epochs: int)
Train the CLAM model.
This method runs the training loop for a specified number of epochs, with validation after each epoch. It supports early stopping and saves the best model based on validation loss.
- Parameters:
train_loader (DataLoader[tuple[torch.Tensor, int]]) – DataLoader for training data.
val_loader (DataLoader[tuple[torch.Tensor, int]]) – DataLoader for validation data.
epochs (int) – Number of epochs to train for.
- _log_epoch_metrics(epoch: int, train_metrics: dict[str, Any], val_metrics: dict[str, Any]) → None
Log metrics for the current epoch.
This method logs training and validation metrics to the logger and optionally to Weights & Biases if enabled.
- _train_epoch(epoch: int, train_loader: DataLoader[tuple[torch.Tensor, int]]) → dict[str, float | int | None]
Train the model for one epoch.
This method processes all batches in the training data loader for one epoch. It computes loss, performs backpropagation, and collects training metrics.
- _val(epoch: int, val_loader: DataLoader[tuple[torch.Tensor, int]]) → dict[str, float | int | None]
Validate the model on the validation set.
This method evaluates the model on the validation data and computes various metrics including loss, error rate, AUC, and class-specific accuracy.
- static calculate_error(Y_hat: Tensor, Y: Tensor)
Calculate classification error.
- Parameters:
Y_hat (torch.Tensor) – Predicted labels.
Y (torch.Tensor) – Ground truth labels.
- Returns:
Error rate (1 - accuracy).
- Return type:
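A minimal sketch of an error-rate computation consistent with "1 - accuracy". It assumes a Python float is returned (the return type is not documented here), and the function name is hypothetical.

```python
import torch


def calculate_error_sketch(Y_hat: torch.Tensor, Y: torch.Tensor) -> float:
    """Hypothetical error rate: fraction of predictions that differ from the labels."""
    accuracy = Y_hat.float().eq(Y.float()).float().mean().item()
    return 1.0 - accuracy


print(calculate_error_sketch(torch.tensor([1, 0, 1, 1]), torch.tensor([1, 1, 1, 1])))  # 0.25
```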
- eval(test_loader: DataLoader[tuple[torch.Tensor, int]])
Evaluate the model on a test set.
- Parameters:
test_loader (DataLoader[tuple[torch.Tensor, int]]) – DataLoader for test data.
Note
This is a placeholder method for model evaluation. Implementation is not provided.
- class cellmil.models.mil.clam.LitCLAM(model: cellmil.models.mil.clam.CLAM_MB | cellmil.models.mil.clam.CLAM_SB, optimizer: Optimizer, loss_slide: Module = CrossEntropyLoss(), weight_loss_slide: float = 0.7, lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 50)
Bases: LightningModule
- static _is_gated_attention(model: cellmil.models.mil.clam.CLAM_MB | cellmil.models.mil.clam.CLAM_SB) → bool
Check if model uses gated attention.
- static _get_size_args(model: cellmil.models.mil.clam.CLAM_MB | cellmil.models.mil.clam.CLAM_SB) → list[int]
Extract L and D parameters from attention network (size[1] and size[2]).
- __init__(model: cellmil.models.mil.clam.CLAM_MB | cellmil.models.mil.clam.CLAM_SB, optimizer: Optimizer, loss_slide: Module = CrossEntropyLoss(), weight_loss_slide: float = 0.7, lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 50)
- classmethod load_from_checkpoint(checkpoint_path: Union[str, Path, IO[bytes]], map_location: Optional[Union[device, str, int, Callable[[UntypedStorage, str], Optional[UntypedStorage]], dict[torch.device | str | int, torch.device | str | int]]] = None, hparams_file: Optional[Union[str, Path]] = None, strict: Optional[bool] = None, **kwargs: Any) → Self
Load a LitCLAM model from a checkpoint file.
- Parameters:
checkpoint_path (str | Path | IO[bytes]) – Path to the checkpoint file.
map_location – Device mapping for loading the model.
hparams_file (str | Path | None) – Optional path to a YAML file with hyperparameters.
strict (bool | None) – Whether to strictly enforce that the keys in state_dict match the keys returned by the model’s state_dict function.
**kwargs – Additional keyword arguments.
- Returns:
The loaded LitCLAM model.
- Return type:
Self
- forward(x: Tensor, label: Optional[Tensor] = None, instance_eval: bool = True)
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model's output
- training_step(batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int)
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch (only if multiple dataloaders are used).
- Returns:
Tensor – The loss tensor.
dict – A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
None – In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
    def training_step(self, batch, batch_idx):
        x, y, z = batch
        out = self.encoder(x)
        loss = self.loss(out, x)
        return loss
To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    # Multiple optimizers (e.g.: GANs)
    def training_step(self, batch, batch_idx):
        opt1, opt2 = self.optimizers()

        # do training_step with encoder
        ...
        opt1.step()
        # do training_step with decoder
        ...
        opt2.step()
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
- validation_step(batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int)
Operates on a single batch of data from the validation set. In this step you might generate examples or calculate anything of interest like accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch (only if multiple dataloaders are used).
- Returns:
Tensor – The loss tensor.
dict – A dictionary. Can include any keys, but must include the key 'loss'.
None – Skip to the next batch.
    # if you have one val dataloader:
    def validation_step(self, batch, batch_idx): ...

    # if you have multiple val dataloaders:
    def validation_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
    # CASE 1: A single validation dataset
    def validation_step(self, batch, batch_idx):
        x, y = batch

        # implement your own
        out = self(x)
        loss = self.loss(out, y)

        # log 6 example images
        # or generated text... or whatever
        sample_imgs = x[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('example_images', grid, 0)

        # calculate acc
        labels_hat = torch.argmax(out, dim=1)
        val_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

        # log the outputs!
        self.log_dict({'val_loss': loss, 'val_acc': val_acc})
If you pass in multiple val dataloaders, validation_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
    # CASE 2: multiple validation dataloaders
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx tells you which dataset this is.
        x, y = batch

        # implement your own
        out = self(x)

        if dataloader_idx == 0:
            loss = self.loss0(out, y)
        else:
            loss = self.loss1(out, y)

        # calculate acc
        labels_hat = torch.argmax(out, dim=1)
        acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

        # log the outputs separately for each dataloader
        self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})
Note
If you don’t need to validate you don’t need to implement this method.
Note
When validation_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of validation, the model goes back to training mode and gradients are enabled.
- test_step(batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int)
Operates on a single batch of data from the test set. In this step you'd normally generate examples or calculate anything of interest such as accuracy.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch (only if multiple dataloaders are used).
- Returns:
Tensor – The loss tensor.
dict – A dictionary. Can include any keys, but must include the key 'loss'.
None – Skip to the next batch.
    # if you have one test dataloader:
    def test_step(self, batch, batch_idx): ...

    # if you have multiple test dataloaders:
    def test_step(self, batch, batch_idx, dataloader_idx=0): ...
Examples:
    # CASE 1: A single test dataset
    def test_step(self, batch, batch_idx):
        x, y = batch

        # implement your own
        out = self(x)
        loss = self.loss(out, y)

        # log 6 example images
        # or generated text... or whatever
        sample_imgs = x[:6]
        grid = torchvision.utils.make_grid(sample_imgs)
        self.logger.experiment.add_image('example_images', grid, 0)

        # calculate acc
        labels_hat = torch.argmax(out, dim=1)
        test_acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

        # log the outputs!
        self.log_dict({'test_loss': loss, 'test_acc': test_acc})
If you pass in multiple test dataloaders, test_step() will have an additional argument. We recommend setting the default value of 0 so that you can quickly switch between single and multiple dataloaders.
    # CASE 2: multiple test dataloaders
    def test_step(self, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx tells you which dataset this is.
        x, y = batch

        # implement your own
        out = self(x)

        if dataloader_idx == 0:
            loss = self.loss0(out, y)
        else:
            loss = self.loss1(out, y)

        # calculate acc
        labels_hat = torch.argmax(out, dim=1)
        acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)

        # log the outputs separately for each dataloader
        self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})
Note
If you don’t need to test you don’t need to implement this method.
Note
When test_step() is called, the model has been put in eval mode and PyTorch gradients have been disabled. At the end of the test epoch, the model goes back to training mode and gradients are enabled.
- on_train_epoch_end() → None
Called in the training loop at the very end of the epoch.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:
    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self):
            loss = ...
            self.training_step_outputs.append(loss)
            return loss

        def on_train_epoch_end(self):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(self.training_step_outputs).mean()
            self.log("training_epoch_mean", epoch_mean)
            # free up the memory
            self.training_step_outputs.clear()
- _flatten_and_log_metrics(computed: dict[str, torch.Tensor], prefix: str) → None
Converts a metric dictionary produced by torchmetrics into a flat dict of scalar values and logs it with self.log_dict.
Vector/tensor metrics (e.g. per-class accuracy) are expanded into keys like {prefix}/class_{i}_acc.
Scalar tensors are converted to floats.
None values are converted to NaN to satisfy loggers that expect numeric scalars.
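The flattening rules above (per-class keys, scalar conversion, None to NaN) can be sketched as follows. The helper name and the key pattern for vector metrics ({prefix}/class_{i}_{name}, generalizing the documented class_{i}_acc) are assumptions; inside the LightningModule the resulting dict would be passed to self.log_dict.

```python
import math
from typing import Optional

import torch


def flatten_metrics(computed: dict[str, Optional[torch.Tensor]], prefix: str) -> dict[str, float]:
    """Hypothetical flattening of torchmetrics output into scalar floats for self.log_dict."""
    flat: dict[str, float] = {}
    for name, value in computed.items():
        if value is None:
            flat[f"{prefix}/{name}"] = math.nan  # loggers expect numbers, so None -> NaN
        elif value.numel() > 1:
            # Vector metrics (e.g. per-class accuracy) get one key per class
            for i, v in enumerate(value.flatten().tolist()):
                flat[f"{prefix}/class_{i}_{name}"] = float(v)
        else:
            flat[f"{prefix}/{name}"] = float(value.item())
    return flat


metrics = {"acc": torch.tensor([0.5, 1.0]), "auroc": torch.tensor(0.75), "f1": None}
out = flatten_metrics(metrics, prefix="val")
print(out)
# {'val/class_0_acc': 0.5, 'val/class_1_acc': 1.0, 'val/auroc': 0.75, 'val/f1': nan}
```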
- configure_optimizers()
Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.
- Returns:
Any of these 5 options:
- Single optimizer.
- List or Tuple of optimizers.
- Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).
- Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.
- None - Fit will run without any optimizer.
The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.
    lr_scheduler_config = {
        # REQUIRED: The scheduler instance
        "scheduler": lr_scheduler,
        # The unit of the scheduler's step size, could also be 'step'.
        # 'epoch' updates the scheduler on epoch end whereas 'step'
        # updates it after an optimizer update.
        "interval": "epoch",
        # How many epochs/steps should pass between calls to
        # `scheduler.step()`. 1 corresponds to updating the learning
        # rate after every epoch/step.
        "frequency": 1,
        # Metric to monitor for schedulers like `ReduceLROnPlateau`
        "monitor": "val_loss",
        # If set to `True`, will enforce that the value specified 'monitor'
        # is available when the scheduler is updated, thus stopping
        # training if not found. If set to `False`, it will only produce a warning
        "strict": True,
        # If using the `LearningRateMonitor` callback to monitor the
        # learning rate progress, this keyword can be used to specify
        # a custom logged name
        "name": None,
    }
When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.
Metrics can be made available to monitor by simply logging them using self.log('metric_to_track', metric_val) in your LightningModule.
Note
Some things to know:
- Lightning calls .backward() and .step() automatically in case of automatic optimization.
- If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default "epoch") in the scheduler configuration, Lightning will call the scheduler's .step() method automatically in case of automatic optimization.
- If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.
- If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.
- If you use multiple optimizers, you will have to switch to 'manual optimization' mode and step them yourself.
- If you need to control how often the optimizer steps, override the optimizer_step() hook.
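As a sketch of how a module that receives an optimizer and an optional lr_scheduler (as LitCLAM does) might satisfy this contract, the hypothetical function below returns either the bare optimizer or the dictionary form. It deliberately avoids a Lightning dependency; the "val_loss" monitor key is an assumption and must match a metric the module actually logs.

```python
from typing import Any, Optional

import torch
from torch import nn


def build_optim_config(optimizer: torch.optim.Optimizer,
                       lr_scheduler: Optional[Any] = None,
                       monitor: str = "val_loss") -> Any:
    """Hypothetical configure_optimizers body for a module holding an optimizer and optional scheduler."""
    if lr_scheduler is None:
        return optimizer  # option 1: a single optimizer
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": lr_scheduler,
            "interval": "epoch",
            "frequency": 1,
            # Required by Lightning for value-conditioned schedulers such as ReduceLROnPlateau
            "monitor": monitor,
        },
    }


model = nn.Linear(8, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", patience=5)
config = build_optim_config(opt, sched)
print(sorted(config["lr_scheduler"].keys()))  # ['frequency', 'interval', 'monitor', 'scheduler']
```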
- predict_step(batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int) → Any
Step function called during predict(). By default, it calls forward(). Override to add any processing logic.
The predict_step() is used to scale inference on multi-devices.
To prevent an OOM error, it is possible to use the BasePredictionWriter callback to write the predictions to disk or database after each batch or on epoch end.
The BasePredictionWriter should be used while using a spawn based accelerator. This happens for Trainer(strategy="ddp_spawn") or training on 8 TPU cores with Trainer(accelerator="tpu", devices=8) as predictions won't be returned.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch (only if multiple dataloaders are used).
- Returns:
Predicted output (optional).
Example
    class MyModel(LightningModule):
        def predict_step(self, batch, batch_idx, dataloader_idx=0):
            return self(batch)

    dm = ...
    model = MyModel()
    trainer = Trainer(accelerator="gpu", devices=2)
    predictions = trainer.predict(model, dm)
- get_attention_weights(x: Tensor) → Tensor
Get attention weights for a bag of instances.
- Parameters:
x (torch.Tensor) – Input tensor of shape [n_instances, feat_dim].
- Returns:
Attention weights of shape [n_classes, n_instances].
- Return type:
torch.Tensor
- class cellmil.models.mil.clam.LitSurvCLAM(model: cellmil.models.mil.clam.CLAM_MB | cellmil.models.mil.clam.CLAM_SB, optimizer: Optimizer, loss_slide: Module = NegativeLogLikelihoodSurvLoss(), weight_loss_slide: float = 0.7, lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 50)
Bases: LitCLAM
Lightning wrapper for CLAM models adapted for survival analysis.
This class extends LitCLAM to support survival analysis tasks using discrete-time survival models with logistic hazard parameterization. Apart from predict_step, it only overrides the metrics setup to use survival-specific metrics.
- Parameters:
model (CLAM_MB | CLAM_SB) – The CLAM model instance (SB or MB).
optimizer (torch.optim.Optimizer) – Optimizer for training.
loss_slide (nn.Module, optional) – Loss function for survival. Defaults to NegativeLogLikelihoodSurvLoss().
weight_loss_slide (float, optional) – Weight for slide-level loss. Defaults to 0.7.
lr_scheduler (LRScheduler | None, optional) – Learning rate scheduler. Defaults to None.
use_aem (bool, optional) – Whether to use AEM regularization. Defaults to False.
aem_weight_initial (float, optional) – Initial weight for AEM loss. Defaults to 0.0001.
aem_weight_final (float, optional) – Final weight for AEM loss after annealing. Defaults to 0.0.
aem_annealing_epochs (int, optional) – Number of epochs to anneal AEM weight. Defaults to 50.
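The docstring does not specify the shape of the AEM annealing schedule. The sketch below assumes a linear interpolation from aem_weight_initial to aem_weight_final over aem_annealing_epochs, holding the final value afterwards; aem_weight is a hypothetical helper, not part of cellmil.

```python
def aem_weight(epoch: int, initial: float = 1e-4, final: float = 0.0, annealing_epochs: int = 50) -> float:
    """Hypothetical linear annealing of the AEM loss weight (assumed schedule shape)."""
    if annealing_epochs <= 0 or epoch >= annealing_epochs:
        return final  # annealing finished: hold the final weight
    progress = epoch / annealing_epochs
    return initial + (final - initial) * progress


print([aem_weight(e) for e in (0, 50, 80)])  # [0.0001, 0.0, 0.0]
```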
- __init__(model: cellmil.models.mil.clam.CLAM_MB | cellmil.models.mil.clam.CLAM_SB, optimizer: Optimizer, loss_slide: Module = NegativeLogLikelihoodSurvLoss(), weight_loss_slide: float = 0.7, lr_scheduler: Optional[LRScheduler] = None, subsampling: float = 1.0, use_aem: bool = False, aem_weight_initial: float = 0.0001, aem_weight_final: float = 0.0, aem_annealing_epochs: int = 50)
- predict_step(batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int)
Prediction step returns logits for discrete-time hazard intervals.
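Under the logistic-hazard parameterization named in the class description, each logit maps to a conditional hazard through a sigmoid, and the survival curve is the running product of (1 - hazard). The sketch assumes logits of shape (batch, n_intervals); hazards_to_survival is a hypothetical post-processing helper, not part of cellmil.

```python
import torch


def hazards_to_survival(logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical post-processing of discrete-time logistic-hazard logits (batch x n_intervals)."""
    hazards = torch.sigmoid(logits)                 # P(event in interval j | survived up to j)
    survival = torch.cumprod(1.0 - hazards, dim=1)  # S(j) = prod_{i <= j} (1 - h_i)
    return hazards, survival


logits = torch.zeros(1, 4)  # every hazard = sigmoid(0) = 0.5
hazards, survival = hazards_to_survival(logits)
print(survival)  # tensor([[0.5000, 0.2500, 0.1250, 0.0625]])
```

Because every factor (1 - h_i) lies in (0, 1), the resulting survival curve is monotonically non-increasing across intervals.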