# Source code for cellmil.models.mil.clam

# -*- coding: utf-8 -*-
# CLAM Model Implementation
#
# References:
# Data-efficient and weakly supervised computational pathology on whole-slide images
# Lu, Ming Y et al., Nature Biomedical Engineering, 2021
# DOI: https://doi.org/10.1038/s41551-021-00707-9

from typing_extensions import Self
import torch
import wandb
import torchmetrics
import numpy as np
import torch.nn as nn
import lightning as Pl
import torch.nn.functional as F
from tqdm import tqdm
from pathlib import Path
from typing import IO, Any, Callable, Literal, cast
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import LRScheduler
from sklearn.preprocessing import label_binarize  # type: ignore
from sklearn.metrics import roc_auc_score, roc_curve  # type: ignore
from sklearn.metrics import auc as calc_auc
from .utils import EarlyStopping, Accuracy_Logger, AEM
from cellmil.utils import logger
from cellmil.utils.train.losses import NegativeLogLikelihoodSurvLoss
from cellmil.utils.train.metrics import ConcordanceIndex, BrierScore
from topk.svm import SmoothTop1SVM  # type: ignore


class Attn_Net(nn.Module):
    """
    Attention network without gating.

    A small MLP (Linear -> Tanh -> [Dropout] -> Linear) that maps each
    instance embedding to an unnormalised attention score per class. Used
    as the attention-pooling component in Multiple Instance Learning (MIL).

    Args:
        L (int, optional): Input feature dimension. Defaults to 1024.
        D (int, optional): Hidden layer dimension. Defaults to 256.
        dropout (bool, optional): Whether to apply dropout (p = 0.25) after
            the tanh. Defaults to False.
        n_classes (int, optional): Number of classes (output dimension).
            Defaults to 1.
    """

    def __init__(
        self, L: int = 1024, D: int = 256, dropout: bool = False, n_classes: int = 1
    ):
        super().__init__()  # type: ignore
        layers: list[nn.Module] = [nn.Linear(L, D), nn.Tanh()]
        if dropout:
            layers.append(nn.Dropout(0.25))
        layers.append(nn.Linear(D, n_classes))
        self.module = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute per-instance attention scores.

        Args:
            x (torch.Tensor): Input tensor of shape (N, L), where N is the
                number of instances and L the input feature dimension.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - Attention scores of shape (N, n_classes).
                - The unchanged input tensor of shape (N, L).
        """
        scores = self.module(x)  # N x n_classes
        return scores, x
class Attn_Net_Gated(nn.Module):
    """
    Attention network with sigmoid gating.

    Two parallel projections process the input — Linear+Tanh ("a") and
    Linear+Sigmoid ("b") — and are combined by element-wise multiplication
    (the gate) before a final linear layer produces the attention scores.

    Args:
        L (int, optional): Input feature dimension. Defaults to 1024.
        D (int, optional): Hidden layer dimension. Defaults to 256.
        dropout (bool, optional): Whether to apply dropout (p = 0.25) in both
            pathways. Defaults to False.
        n_classes (int, optional): Number of classes (output dimension).
            Defaults to 1.
    """

    def __init__(
        self, L: int = 1024, D: int = 256, dropout: bool = False, n_classes: int = 1
    ):
        super().__init__()  # type: ignore
        path_a: list[nn.Module] = [nn.Linear(L, D), nn.Tanh()]
        path_b: list[nn.Module] = [nn.Linear(L, D), nn.Sigmoid()]
        if dropout:
            path_a.append(nn.Dropout(0.25))
            path_b.append(nn.Dropout(0.25))
        self.attention_a = nn.Sequential(*path_a)
        self.attention_b = nn.Sequential(*path_b)
        self.attention_c = nn.Linear(D, n_classes)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Compute gated per-instance attention scores.

        Args:
            x (torch.Tensor): Input tensor of shape (N, L), where N is the
                number of instances and L the input feature dimension.

        Returns:
            tuple[torch.Tensor, torch.Tensor]:
                - Attention scores of shape (N, n_classes).
                - The unchanged input tensor of shape (N, L).
        """
        gate = self.attention_b(x)  # sigmoid branch
        hidden = self.attention_a(x)  # tanh branch
        scores = self.attention_c(hidden * gate)  # N x n_classes
        return scores, x
class CLAM_SB(nn.Module):
    """
    CLAM Single Branch (SB) — Clustering-constrained Attention MIL model.

    A single attention branch aggregates instance (patch) features into a
    bag-level representation for classification. Optionally, the top/bottom
    attended instances are supervised with per-class instance classifiers.

    Args:
        gate (bool, optional): Use gated attention (Attn_Net_Gated) instead of
            plain attention (Attn_Net). Defaults to True.
        size_arg (Literal["small", "big"] | list[int], optional): Network size
            preset ('small': [embed_dim, 512, 256], 'big': [embed_dim, 512, 384])
            or an explicit [hidden_dim, attention_dim] pair. Defaults to "small".
        dropout (bool, optional): Whether to apply dropout (p = 0.25) in the
            feature layer and attention networks. Defaults to False.
        k_sample (int, optional): Number of top/bottom patches sampled for
            instance-level training (clamped to the bag size). Defaults to 8.
        n_classes (int, optional): Number of classes. Defaults to 2.
        instance_loss_fn (nn.Module | None, optional): Loss supervising
            instance-level training. If None, a SmoothTop1SVM(n_classes=2) is
            created per instance (moved to CUDA when available). Defaults to None.
        subtyping (bool, optional): Whether this is a subtyping problem; enables
            out-of-class instance evaluation. Defaults to False.
        embed_dim (int, optional): Input embedding dimension. Defaults to 1024.
        temperature (float, optional): Temperature for the attention softmax.
            Defaults to 1.0.
    """

    def __init__(
        self,
        gate: bool = True,
        size_arg: Literal["small", "big"] | list[int] = "small",
        dropout: bool = False,
        k_sample: int = 8,
        n_classes: int = 2,
        instance_loss_fn: nn.Module | None = None,
        subtyping: bool = False,
        embed_dim: int = 1024,
        temperature: float = 1.0,
    ):
        super().__init__()  # type: ignore
        # Fix: the previous default argument built SmoothTop1SVM (and called
        # .cuda()) at import time and shared ONE loss module across every
        # CLAM_SB instance. Build the default lazily, per instance, instead.
        if instance_loss_fn is None:
            instance_loss_fn = SmoothTop1SVM(n_classes=2)
            if torch.cuda.is_available():
                instance_loss_fn = instance_loss_fn.cuda()
        self.size_dict = {"small": [embed_dim, 512, 256], "big": [embed_dim, 512, 384]}
        if isinstance(size_arg, list):
            if len(size_arg) != 2:
                raise ValueError("size_arg must be a list of length 2")
            size = [embed_dim, *size_arg]
        else:
            size = self.size_dict[size_arg]
        # Fix: nn.Dropout(dropout) with a bool meant p=1.0 whenever
        # dropout=True (all activations zeroed during training). Use the same
        # p=0.25 convention as the attention networks.
        fc: list[nn.Module] = [
            nn.Linear(size[0], size[1]),
            nn.ReLU(),
            nn.Dropout(0.25 if dropout else 0.0),
        ]
        if gate:
            attention_net: nn.Module = Attn_Net_Gated(
                L=size[1], D=size[2], dropout=dropout, n_classes=1
            )
        else:
            attention_net = Attn_Net(L=size[1], D=size[2], dropout=dropout, n_classes=1)
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        self.classifiers = nn.Linear(size[1], n_classes)
        instance_classifiers = [nn.Linear(size[1], 2) for _ in range(n_classes)]
        self.instance_classifiers = nn.ModuleList(instance_classifiers)
        self.k_sample = k_sample
        self.instance_loss_fn = instance_loss_fn
        self.n_classes = n_classes
        self.subtyping = subtyping
        self.temperature = temperature

    def __str__(self) -> str:
        return "<CLAM_SB>"

    @staticmethod
    def create_positive_targets(length: int, device: torch.device) -> torch.Tensor:
        """
        Create a tensor of positive targets (all ones).

        Args:
            length (int): The length of the tensor to create.
            device (torch.device): The device to create the tensor on.

        Returns:
            torch.Tensor: A long tensor of ones of the specified length.
        """
        return torch.full((length,), 1, device=device).long()

    @staticmethod
    def create_negative_targets(length: int, device: torch.device) -> torch.Tensor:
        """
        Create a tensor of negative targets (all zeros).

        Args:
            length (int): The length of the tensor to create.
            device (torch.device): The device to create the tensor on.

        Returns:
            torch.Tensor: A long tensor of zeros of the specified length.
        """
        return torch.full((length,), 0, device=device).long()

    # instance-level evaluation for in-the-class attention branch
    def inst_eval(
        self, a: torch.Tensor, h: torch.Tensor, classifier: nn.Module
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Instance-level evaluation for the in-the-class attention branch.

        Selects the k most-attended instances as positives and the k
        least-attended as negatives, then scores them with the instance
        classifier.

        Args:
            a (torch.Tensor): Attention scores, shape (N,) or (1, N).
            h (torch.Tensor): Instance features, shape (N, hidden_dim).
            classifier (nn.Module): Instance-level binary classifier.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - instance_loss: Loss for instance-level classification.
                - all_preds: Predicted labels for all 2k selected instances.
                - all_targets: Target labels for all 2k selected instances.
        """
        device = h.device
        if len(a.shape) == 1:
            a = a.view(1, -1)
        # Fix: clamp k to the bag size so bags with fewer than k_sample
        # instances no longer crash torch.topk.
        k = min(self.k_sample, a.size(-1))
        top_p_ids = torch.topk(a, k)[1][-1]
        top_p = torch.index_select(h, dim=0, index=top_p_ids)
        top_n_ids = torch.topk(-a, k, dim=1)[1][-1]
        top_n = torch.index_select(h, dim=0, index=top_n_ids)
        p_targets = self.create_positive_targets(k, device)
        n_targets = self.create_negative_targets(k, device)
        all_targets = torch.cat([p_targets, n_targets], dim=0)
        all_instances = torch.cat([top_p, top_n], dim=0)
        logits = classifier(all_instances)
        all_preds = torch.topk(logits, 1, dim=1)[1].squeeze(1)
        instance_loss = self.instance_loss_fn(logits, all_targets)
        return instance_loss, all_preds, all_targets

    # instance-level evaluation for out-of-the-class attention branch
    def inst_eval_out(
        self, a: torch.Tensor, h: torch.Tensor, classifier: nn.Module
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Instance-level evaluation for the out-of-the-class attention branch.

        Selects the k most-attended instances and supervises them with
        negative targets (they should not belong to this class).

        Args:
            a (torch.Tensor): Attention scores, shape (N,) or (1, N).
            h (torch.Tensor): Instance features, shape (N, hidden_dim).
            classifier (nn.Module): Instance-level binary classifier.

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - instance_loss: Loss for instance-level classification.
                - p_preds: Predicted labels for the selected instances.
                - p_targets: Target labels (all zeros) for the selected instances.
        """
        device = h.device
        if len(a.shape) == 1:
            a = a.view(1, -1)
        # Fix: clamp k to the bag size (see inst_eval).
        k = min(self.k_sample, a.size(-1))
        top_p_ids = torch.topk(a, k)[1][-1]
        top_p = torch.index_select(h, dim=0, index=top_p_ids)
        p_targets = self.create_negative_targets(k, device)
        logits = classifier(top_p)
        p_preds = torch.topk(logits, 1, dim=1)[1].squeeze(1)
        instance_loss = self.instance_loss_fn(logits, p_targets)
        return instance_loss, p_preds, p_targets

    def forward(
        self,
        h: torch.Tensor,
        label: torch.Tensor | None = None,
        instance_eval: bool = False,
        return_features: bool = False,
        attention_only: bool = False,
    ) -> (
        torch.Tensor
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]
    ):
        """
        Forward pass of the CLAM Single Branch model.

        Args:
            h (torch.Tensor): Input features of shape (N, embed_dim), where N is
                the number of instances (patches).
            label (torch.Tensor | None, optional): Ground-truth label, required
                when instance_eval=True. If instance_eval is set but label is
                None, instance evaluation is silently skipped and
                results_dict stays empty. Defaults to None.
            instance_eval (bool, optional): Perform instance-level evaluation
                and compute the instance loss. Defaults to False.
            return_features (bool, optional): Include the aggregated bag
                features M in results_dict. Defaults to False.
            attention_only (bool, optional): Return only the (softmaxed)
                attention weights. Defaults to False.

        Returns:
            torch.Tensor | tuple:
                - If attention_only=True: attention weights of shape (1, N)
                  after the temperature-scaled softmax.
                - Otherwise: (logits, Y_prob, Y_hat, a, results_dict) where
                  logits has shape (1, n_classes), Y_prob its softmax,
                  Y_hat the predicted class index of shape (1, 1), and
                  a the attention weights AFTER the temperature-scaled softmax
                  (the previous docstring incorrectly said "before softmax").
                  results_dict may contain 'instance_loss', 'inst_labels',
                  'inst_preds' (when instance_eval) and 'features' (when
                  return_features).
        """
        a, h = self.attention_net(h)  # N x 1
        a = torch.transpose(a, 1, 0)  # 1 x N
        a = F.softmax(a / self.temperature, dim=1)  # softmax over N
        if attention_only:
            return a
        if instance_eval and label is not None:
            total_inst_loss = 0.0
            all_preds: list[np.ndarray[Any, Any]] = []
            all_targets: list[np.ndarray[Any, Any]] = []
            inst_labels = F.one_hot(
                label, num_classes=self.n_classes
            ).squeeze()  # binarize label
            for i in range(len(self.instance_classifiers)):
                inst_label = inst_labels[i].item()
                classifier = self.instance_classifiers[i]
                if inst_label == 1:  # in-the-class:
                    instance_loss, preds, targets = self.inst_eval(a, h, classifier)
                    all_preds.extend(preds.cpu().numpy())  # type: ignore
                    all_targets.extend(targets.cpu().numpy())  # type: ignore
                else:  # out-of-the-class
                    if self.subtyping:
                        instance_loss, preds, targets = self.inst_eval_out(
                            a, h, classifier
                        )
                        all_preds.extend(preds.cpu().numpy())  # type: ignore
                        all_targets.extend(targets.cpu().numpy())  # type: ignore
                    else:
                        continue
                total_inst_loss += instance_loss
            if self.subtyping:
                total_inst_loss /= len(self.instance_classifiers)
            results_dict: dict[str, Any] = {
                "instance_loss": total_inst_loss,
                "inst_labels": np.array(all_targets),
                "inst_preds": np.array(all_preds),
            }
        else:
            results_dict = {}
        M = torch.mm(a, h)  # attention-weighted bag representation, 1 x hidden
        logits = self.classifiers(M)
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=1)
        if return_features:
            results_dict.update({"features": M})
        return logits, Y_prob, Y_hat, a, results_dict
class CLAM_MB(CLAM_SB):
    """
    CLAM Multi-Branch (MB) — Clustering-constrained Attention MIL model.

    Extends CLAM_SB with one attention branch and one bag classifier per
    class, which suits multi-class classification.

    Args:
        gate (bool, optional): Use gated attention. Defaults to True.
        size_arg (Literal["small", "big"] | list[int], optional): Network size
            preset or an explicit [hidden_dim, attention_dim] pair. Defaults to "small".
        dropout (bool, optional): Whether to apply dropout (p = 0.25). Defaults to False.
        k_sample (int, optional): Number of top/bottom patches sampled for
            instance-level training. Defaults to 8.
        n_classes (int, optional): Number of classes. Defaults to 2.
        instance_loss_fn (nn.Module | None, optional): Instance-level loss.
            If None, a SmoothTop1SVM(n_classes=2) is created per instance
            (moved to CUDA when available, consistent with CLAM_SB). Defaults to None.
        subtyping (bool, optional): Whether it's a subtyping problem. Defaults to False.
        embed_dim (int, optional): Input embedding dimension. Defaults to 1024.
        temperature (float, optional): Temperature for the attention softmax.
            Defaults to 1.0.
    """

    def __init__(
        self,
        gate: bool = True,
        size_arg: Literal["small", "big"] | list[int] = "small",
        dropout: bool = False,
        k_sample: int = 8,
        n_classes: int = 2,
        instance_loss_fn: nn.Module | None = None,
        subtyping: bool = False,
        embed_dim: int = 1024,
        temperature: float = 1.0,
    ):
        # Deliberately skip CLAM_SB.__init__: the multi-branch model builds its
        # own attention network and per-class bag classifiers from scratch.
        nn.Module.__init__(self)  # type: ignore
        # Fix: the previous default argument instantiated SmoothTop1SVM once at
        # import time and shared it across every instance; build it lazily per
        # instance instead (and match CLAM_SB's CUDA handling).
        if instance_loss_fn is None:
            instance_loss_fn = SmoothTop1SVM(n_classes=2)
            if torch.cuda.is_available():
                instance_loss_fn = instance_loss_fn.cuda()
        self.size_dict = {"small": [embed_dim, 512, 256], "big": [embed_dim, 512, 384]}
        if isinstance(size_arg, list):
            if len(size_arg) != 2:
                raise ValueError("size_arg must be a list of length 2")
            size = [embed_dim, *size_arg]
        else:
            size = self.size_dict[size_arg]
        # Fix: nn.Dropout(dropout) with a bool meant p=1.0 when dropout=True;
        # use the same p=0.25 convention as the attention networks.
        fc: list[nn.Module] = [
            nn.Linear(size[0], size[1]),
            nn.ReLU(),
            nn.Dropout(0.25 if dropout else 0.0),
        ]
        if gate:
            attention_net: nn.Module = Attn_Net_Gated(
                L=size[1], D=size[2], dropout=dropout, n_classes=n_classes
            )
        else:
            attention_net = Attn_Net(
                L=size[1], D=size[2], dropout=dropout, n_classes=n_classes
            )
        fc.append(attention_net)
        self.attention_net = nn.Sequential(*fc)
        bag_classifiers = [
            nn.Linear(size[1], 1) for _ in range(n_classes)
        ]  # use an independent linear layer to predict each class
        self.classifiers = nn.ModuleList(bag_classifiers)
        instance_classifiers = [nn.Linear(size[1], 2) for _ in range(n_classes)]
        self.instance_classifiers = nn.ModuleList(instance_classifiers)
        self.k_sample = k_sample
        self.instance_loss_fn = instance_loss_fn
        self.n_classes = n_classes
        self.subtyping = subtyping
        self.temperature = temperature

    def __str__(self) -> str:
        return "<CLAM_MB>"

    def forward(
        self,
        h: torch.Tensor,
        label: torch.Tensor | None = None,
        instance_eval: bool = False,
        return_features: bool = False,
        attention_only: bool = False,
    ) -> (
        torch.Tensor
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]
    ):
        """
        Forward pass of the CLAM Multi-Branch model.

        Each of the n_classes attention branches produces its own attention
        distribution over instances; branch i's pooled features are scored by
        the i-th bag classifier.

        Args:
            h (torch.Tensor): Input features of shape (N, embed_dim).
            label (torch.Tensor | None, optional): Ground-truth label, required
                when instance_eval=True (silently skipped when None). Defaults to None.
            instance_eval (bool, optional): Perform instance-level evaluation
                and compute the instance loss. Defaults to False.
            return_features (bool, optional): Include the aggregated features
                in results_dict. Defaults to False.
            attention_only (bool, optional): Return only the (softmaxed)
                attention weights. Defaults to False.

        Returns:
            torch.Tensor | tuple:
                - If attention_only=True: attention weights of shape
                  (n_classes, N) after the temperature-scaled softmax.
                - Otherwise: (logits, Y_prob, Y_hat, a, results_dict), with
                  logits of shape (1, n_classes), Y_prob its softmax, Y_hat of
                  shape (1, 1), and a the attention weights AFTER the
                  temperature-scaled softmax. results_dict may contain
                  'instance_loss', 'inst_labels', 'inst_preds' (when
                  instance_eval) and 'features' (when return_features).
        """
        a, h = self.attention_net(h)  # N x K
        a = torch.transpose(a, 1, 0)  # K x N
        a = F.softmax(a / self.temperature, dim=1)  # softmax over N
        if attention_only:
            return a
        if instance_eval and label is not None:
            total_inst_loss = 0.0
            all_preds: list[np.ndarray[Any, Any]] = []
            all_targets: list[np.ndarray[Any, Any]] = []
            inst_labels = F.one_hot(
                label, num_classes=self.n_classes
            ).squeeze()  # binarize label
            for i in range(len(self.instance_classifiers)):
                inst_label = inst_labels[i].item()
                classifier = self.instance_classifiers[i]
                if inst_label == 1:  # in-the-class:
                    instance_loss, preds, targets = self.inst_eval(a[i], h, classifier)
                    all_preds.extend(preds.cpu().numpy())  # type: ignore
                    all_targets.extend(targets.cpu().numpy())  # type: ignore
                else:  # out-of-the-class
                    if self.subtyping:
                        instance_loss, preds, targets = self.inst_eval_out(
                            a[i], h, classifier
                        )
                        all_preds.extend(preds.cpu().numpy())  # type: ignore
                        all_targets.extend(targets.cpu().numpy())  # type: ignore
                    else:
                        continue
                total_inst_loss += instance_loss
            if self.subtyping:
                total_inst_loss /= len(self.instance_classifiers)
            results_dict: dict[str, Any] = {
                "instance_loss": total_inst_loss,
                "inst_labels": np.array(all_targets),
                "inst_preds": np.array(all_preds),
            }
        else:
            results_dict = {}
        M = torch.mm(a, h)  # K x hidden: one pooled representation per class
        logits = torch.empty(1, self.n_classes).float().to(M.device)
        for c in range(self.n_classes):
            logits[0, c] = self.classifiers[c](M[c])
        Y_hat = torch.topk(logits, 1, dim=1)[1]
        Y_prob = F.softmax(logits, dim=1)
        if return_features:
            results_dict.update({"features": M})
        return logits, Y_prob, Y_hat, a, results_dict
class CLAMTrainerLegacy:
    """
    Trainer class for CLAM models.

    This class handles the training loop, validation, and evaluation of CLAM models.
    It supports early stopping, metrics logging, and checkpointing.

    Args:
        model (CLAM_MB | CLAM_SB): The CLAM model to train.
        optimizer (torch.optim.Optimizer): Optimizer for model training.
        device (str): Device to use for training ("cuda", "cpu", etc.)
        ckpt_path (Path): Path to save checkpoints.
        weight_loss_slide (float, optional): Weight for slide-level loss. Defaults to 0.7.
        loss_slide (nn.Module, optional): Loss function for slide-level classification.
            Defaults to nn.CrossEntropyLoss().
        early_stopping (EarlyStopping | None, optional): Early stopping controller.
            Defaults to EarlyStopping with patience=20.
        use_wandb (bool, optional): Whether to log metrics to Weights & Biases.
            Defaults to True.
        scale_attention_grads_by_bag (bool, optional): If True, multiply the
            attention network's gradients by (attn_ref_bag_size / bag_size) ** attn_alpha
            after each backward pass. Defaults to False.
        attn_ref_bag_size (int, optional): Reference bag size used in the gradient
            scaling factor (clamped to >= 1). Defaults to 100000.
        attn_alpha (float, optional): Exponent of the gradient scaling factor.
            Defaults to 0.5.
    """

    def __init__(
        self,
        model: CLAM_MB | CLAM_SB,
        optimizer: torch.optim.Optimizer,
        device: str,
        ckpt_path: Path,
        weight_loss_slide: float = 0.7,
        loss_slide: nn.Module = nn.CrossEntropyLoss(),
        early_stopping: EarlyStopping | None = EarlyStopping(
            patience=20, stop_epoch=50, verbose=True
        ),
        use_wandb: bool = True,
        scale_attention_grads_by_bag: bool = False,
        attn_ref_bag_size: int = 100000,
        attn_alpha: float = 0.5,
    ):
        self.model = model
        self.optimizer = optimizer
        self.device = torch.device(device)
        self.weight_loss_slide = weight_loss_slide
        self.loss_slide = loss_slide
        self.early_stopping = early_stopping
        self.ckpt_path = ckpt_path
        self.use_wandb = use_wandb
        # Move the model to the training device eagerly.
        self.model.to(self.device)
        self.scale_attention_grads_by_bag = scale_attention_grads_by_bag
        # Clamp to >= 1 so the scaling factor below never divides by zero.
        self.attn_ref_bag_size = max(1, int(attn_ref_bag_size))
        self.attn_alpha = float(attn_alpha)

    def fit(
        self,
        train_loader: DataLoader[tuple[torch.Tensor, int]],
        val_loader: DataLoader[tuple[torch.Tensor, int]],
        epochs: int,
    ):
        """
        Train the CLAM model.

        This method runs the training loop for a specified number of epochs,
        with validation after each epoch. It supports early stopping and saves
        the best model based on validation loss.

        Args:
            train_loader (DataLoader[tuple[torch.Tensor, int]]): DataLoader for training data.
            val_loader (DataLoader[tuple[torch.Tensor, int]]): DataLoader for validation data.
            epochs (int): Number of epochs to train for.
        """
        for epoch in range(epochs):
            logger.info(f"Starting epoch {epoch}")
            train_metrics = self._train_epoch(epoch, train_loader)
            val_metrics = self._val(epoch, val_loader)
            # Log metrics for the epoch (both to logger and wandb if enabled)
            self._log_epoch_metrics(epoch, train_metrics, val_metrics)
            if val_metrics.get("early_stop", False):
                logger.info("Early stopping triggered")
                break
        if self.early_stopping:
            # EarlyStopping checkpoints the best model during _val; restore it.
            logger.info(f"Loading best model from {self.ckpt_path}")
            self.model.load_state_dict(torch.load(self.ckpt_path))
        else:
            # Without early stopping, keep whatever the last epoch produced.
            logger.info(f"Saving final model to {self.ckpt_path}")
            torch.save(self.model.state_dict(), self.ckpt_path)

    def _log_epoch_metrics(
        self, epoch: int, train_metrics: dict[str, Any], val_metrics: dict[str, Any]
    ) -> None:
        """
        Log metrics for the current epoch.

        This method logs training and validation metrics to the logger and
        optionally to Weights & Biases if enabled.

        Args:
            epoch (int): Current epoch number.
            train_metrics (dict[str, Any]): Dictionary of training metrics.
            val_metrics (dict[str, Any]): Dictionary of validation metrics.
        """
        # Combine metrics for logging
        metrics: dict[str, int | float | None] = {
            "epoch": epoch,
            "train/loss": train_metrics.get("loss", 0.0),
            "train/error": train_metrics.get("error", 0.0),
            "train/inst_loss": train_metrics.get("inst_loss", 0.0),
            "val/loss": val_metrics.get("loss", 0.0),
            "val/error": val_metrics.get("error", 0.0),
            "val/inst_loss": val_metrics.get("inst_loss", 0.0),
            "val/auc": val_metrics.get("auc", 0.0),
        }
        # Add class accuracy metrics
        for i in range(self.model.n_classes):
            if f"class_{i}_acc" in train_metrics:
                metrics[f"train/class_{i}_acc"] = train_metrics[f"class_{i}_acc"]
            if f"class_{i}_acc" in val_metrics:
                metrics[f"val/class_{i}_acc"] = val_metrics[f"class_{i}_acc"]
        # Add clustering accuracy metrics if available
        for i in range(min(2, self.model.n_classes)):  # Usually just binary
            if f"cluster_{i}_acc" in train_metrics:
                metrics[f"train/cluster_{i}_acc"] = train_metrics[f"cluster_{i}_acc"]
            if f"cluster_{i}_acc" in val_metrics:
                metrics[f"val/cluster_{i}_acc"] = val_metrics[f"cluster_{i}_acc"]
        # Log to wandb if enabled
        if self.use_wandb:
            wandb.log(metrics)

    def _train_epoch(
        self,
        epoch: int,
        train_loader: DataLoader[tuple[torch.Tensor, int]],
    ) -> dict[str, float | int | None]:
        """
        Train the model for one epoch.

        This method processes all batches in the training data loader for one epoch.
        It computes loss, performs backpropagation, and collects training metrics.

        Args:
            epoch (int): Current epoch number.
            train_loader (DataLoader[tuple[torch.Tensor, int]]): DataLoader for training data.

        Returns:
            dict[str, float | int | None]: Dictionary of training metrics for the epoch.
        """
        self.model.train()
        acc_logger = Accuracy_Logger(self.model.n_classes)
        inst_logger = Accuracy_Logger(self.model.n_classes)
        train_loss: float = 0.0
        train_error: float = 0.0
        train_inst_loss: float = 0.0
        inst_count: int = 0
        # Create progress bar for training
        train_pbar = tqdm(
            enumerate(train_loader),
            total=len(train_loader),
            desc=f"Epoch {epoch} - Training",
            leave=False,
        )
        # NOTE(review): each batch is treated as one bag; label.item() below
        # assumes batch_size=1 — confirm against the DataLoader configuration.
        for batch_idx, (data, label) in train_pbar:
            data, label = data.to(self.device), label.to(self.device)
            logits, _, Y_hat, _, instance_dict = self.model(
                data, label=label, instance_eval=True
            )
            acc_logger.log(Y_hat, label)
            loss = self.loss_slide(logits, label)
            loss_value = loss.item()
            instance_loss = instance_dict["instance_loss"]
            inst_count += 1
            instance_loss_value = instance_loss.item()
            train_inst_loss += instance_loss_value
            # Weighted combination of slide-level and instance-level losses.
            total_loss = (
                self.weight_loss_slide * loss
                + (1 - self.weight_loss_slide) * instance_loss
            )
            inst_preds = instance_dict["inst_preds"]
            inst_labels = instance_dict["inst_labels"]
            inst_logger.log_batch(inst_preds, inst_labels)
            train_loss += loss_value
            # Periodic console logging every 20 batches.
            if (batch_idx + 1) % 20 == 0:
                logger.info(
                    "Batch {}, loss: {:.4f}, instance_loss: {:.4f}, weighted_loss: {:.4f}, "
                    "label: {}, bag_size: {}".format(
                        batch_idx,
                        loss_value,
                        instance_loss_value,
                        total_loss.item(),
                        label.item(),
                        data.size(0),
                    )
                )
            # Log batch metrics to wandb if enabled
            if self.use_wandb:
                wandb.log(
                    {
                        "batch/loss": loss_value,
                        "batch/instance_loss": instance_loss_value,
                        "batch/total_loss": total_loss.item(),
                        "batch/bag_size": data.size(0),
                    }
                )
            error = self.calculate_error(Y_hat, label)
            train_error += error
            # Update progress bar with current metrics
            train_pbar.set_postfix(
                {  # type: ignore
                    "Loss": f"{loss_value:.4f}",
                    "Inst_Loss": f"{instance_loss_value:.4f}",
                    "Avg_Loss": f"{train_loss / (batch_idx + 1):.4f}",
                    "Error": f"{error:.4f}",
                }
            )
            # backward pass
            total_loss.backward()
            # scale attention gradients directly based on current bag size
            if self.scale_attention_grads_by_bag:
                bag_size = max(1, int(data.size(0)))
                # factor = (attn_ref_bag_size / bag_size) ** attn_alpha
                # (no clamping is applied, despite what an earlier comment said)
                factor = (self.attn_ref_bag_size / float(bag_size)) ** self.attn_alpha
                for p in self.model.attention_net.parameters():
                    if p.grad is not None:
                        p.grad.mul_(factor)
            # step (gradients are reset after every step, so no accumulation)
            self.optimizer.step()
            self.optimizer.zero_grad()
        # calculate loss and error for epoch
        train_loss /= len(train_loader)
        train_error /= len(train_loader)
        # Collect metrics to return
        metrics: dict[str, float | int | None] = {
            "loss": train_loss,
            "error": train_error,
        }
        if inst_count > 0:
            train_inst_loss /= inst_count
            metrics["inst_loss"] = train_inst_loss
            logger.info("Instance-level clustering metrics:")
            # The instance task is binary (positive/negative patches), hence range(2).
            for i in range(2):
                acc, correct, count = inst_logger.get_summary(i)
                metrics[f"cluster_{i}_acc"] = acc
                logger.info(
                    f"Class {i} clustering accuracy: {acc:.4f}, correct: {correct}/{count}"
                )
        logger.info(
            f"Epoch: {epoch}, train_loss: {train_loss:.4f}, "
            f"train_clustering_loss: {train_inst_loss:.4f}, train_error: {train_error:.4f}"
        )
        # Log class-specific accuracy
        for i in range(self.model.n_classes):
            acc, correct, count = acc_logger.get_summary(i)
            metrics[f"class_{i}_acc"] = acc
            logger.info(f"Class {i}: accuracy {acc:.4f}, correct {correct}/{count}")
        return metrics

    def _val(
        self,
        epoch: int,
        val_loader: DataLoader[tuple[torch.Tensor, int]],
    ) -> dict[str, float | int | None]:
        """
        Validate the model on the validation set.

        This method evaluates the model on the validation data and computes
        various metrics including loss, error rate, AUC, and class-specific accuracy.

        Args:
            epoch (int): Current epoch number.
            val_loader (DataLoader[tuple[torch.Tensor, int]]): DataLoader for validation data.

        Returns:
            dict[str, float | int | None]: Dictionary of validation metrics for the epoch.
        """
        self.model.eval()
        acc_logger = Accuracy_Logger(self.model.n_classes)
        inst_logger = Accuracy_Logger(self.model.n_classes)
        val_loss = 0.0
        val_error = 0.0
        val_inst_loss: float = 0.0
        inst_count: int = 0
        # Pre-allocate per-bag probability/label arrays; indexing by batch_idx
        # below assumes one bag per batch (batch_size=1).
        prob = np.zeros((len(val_loader), self.model.n_classes))
        labels = np.zeros(len(val_loader))
        # Create progress bar for validation
        val_pbar = tqdm(
            enumerate(val_loader),
            total=len(val_loader),
            desc=f"Epoch {epoch} - Validation",
            leave=False,
        )
        with torch.inference_mode():
            for batch_idx, (data, label) in val_pbar:
                data, label = data.to(self.device), label.to(self.device)
                logits, Y_prob, Y_hat, _, instance_dict = self.model(
                    data, label=label, instance_eval=True
                )
                acc_logger.log(Y_hat, label)
                loss = self.loss_slide(logits, label)
                val_loss += loss.item()
                instance_loss = instance_dict["instance_loss"]
                inst_count += 1
                instance_loss_value = instance_loss.item()
                val_inst_loss += instance_loss_value
                inst_preds = instance_dict["inst_preds"]
                inst_labels = instance_dict["inst_labels"]
                inst_logger.log_batch(inst_preds, inst_labels)
                prob[batch_idx] = Y_prob.cpu().numpy()
                labels[batch_idx] = label.item()
                error = self.calculate_error(Y_hat, label)
                val_error += error
                # Update progress bar with current metrics
                val_pbar.set_postfix(
                    {  # type: ignore
                        "Loss": f"{loss.item():.4f}",
                        "Inst_Loss": f"{instance_loss_value:.4f}",
                        "Avg_Loss": f"{val_loss / (batch_idx + 1):.4f}",
                        "Error": f"{error:.4f}",
                    }
                )
        val_error /= len(val_loader)
        val_loss /= len(val_loader)
        # Calculate AUC
        if self.model.n_classes == 2:
            auc = roc_auc_score(labels, prob[:, 1])
            aucs = []
        else:
            # Multi-class: one-vs-rest AUC per class, NaN for classes absent
            # from the validation labels, then nan-mean over classes.
            aucs: list[float] = []
            binary_labels = cast(
                np.ndarray[Any, Any],
                label_binarize(
                    labels, classes=[i for i in range(self.model.n_classes)]
                ),
            )
            for class_idx in range(self.model.n_classes):
                if class_idx in labels:
                    fpr, tpr, _ = cast(
                        tuple[
                            np.ndarray[Any, Any],
                            np.ndarray[Any, Any],
                            np.ndarray[Any, Any],
                        ],
                        roc_curve(binary_labels[:, class_idx], prob[:, class_idx]),
                    )
                    aucs.append(float(calc_auc(fpr, tpr)))
                else:
                    aucs.append(float("nan"))
            auc = np.nanmean(np.array(aucs))
        # Collect metrics to return
        metrics: dict[str, float | int | None] = {
            "loss": val_loss,
            "error": val_error,
            "auc": float(auc),
        }
        logger.info(
            f"Validation Set, val_loss: {val_loss:.4f}, val_error: {val_error:.4f}, auc: {auc:.4f}"
        )
        if inst_count > 0:
            val_inst_loss /= inst_count
            metrics["inst_loss"] = val_inst_loss
            logger.info("Instance-level clustering metrics:")
            # The instance task is binary (positive/negative patches), hence range(2).
            for i in range(2):
                acc, correct, count = inst_logger.get_summary(i)
                metrics[f"cluster_{i}_acc"] = acc
                logger.info(
                    f"Class {i} clustering accuracy: {acc:.4f}, correct: {correct}/{count}"
                )
        # Log class-specific accuracy
        for i in range(self.model.n_classes):
            acc, correct, count = acc_logger.get_summary(i)
            metrics[f"class_{i}_acc"] = acc
            logger.info(f"Class {i}: accuracy {acc:.4f}, correct {correct}/{count}")
        # Handle early stopping (the EarlyStopping callable also checkpoints
        # the best model to ckpt_path on improvement)
        early_stop = False
        if self.early_stopping:
            self.early_stopping(epoch, val_loss, self.model, str(self.ckpt_path))
            if self.early_stopping.early_stop:
                logger.info("Early stopping")
                early_stop = True
        metrics["early_stop"] = early_stop
        return metrics

    @staticmethod
    def calculate_error(Y_hat: torch.Tensor, Y: torch.Tensor) -> float:
        """
        Calculate classification error.

        Args:
            Y_hat (torch.Tensor): Predicted labels.
            Y (torch.Tensor): Ground truth labels.

        Returns:
            float: Error rate (1 - accuracy).
        """
        return 1.0 - Y_hat.float().eq(Y.float()).float().mean().item()

    def eval(
        self,
        test_loader: DataLoader[tuple[torch.Tensor, int]],
    ):
        """
        Evaluate the model on a test set.

        Args:
            test_loader (DataLoader[tuple[torch.Tensor, int]]): DataLoader for test data.

        Note:
            This is a placeholder method for model evaluation. Implementation
            is not provided.
        """
        pass
class LitCLAM(Pl.LightningModule):
    """
    PyTorch Lightning module wrapping a CLAM model (single- or multi-branch)
    for weakly supervised whole-slide classification.

    Combines a slide-level loss with CLAM's instance-level clustering loss
    (weighted by ``weight_loss_slide``), supports optional instance
    subsampling during training and optional Attention Entropy Maximization
    (AEM) regularization, and logs classification metrics (per-class
    accuracy, macro F1/precision/recall/AUROC) for train/val/test stages.
    Checkpoints store enough hyperparameters to rebuild the wrapped CLAM
    model via ``load_from_checkpoint``.
    """
[docs] @staticmethod def _is_gated_attention(model: CLAM_MB | CLAM_SB) -> bool: """Check if model uses gated attention.""" if hasattr(model, "attention_net"): return any( isinstance(m, Attn_Net_Gated) for m in model.attention_net.modules() ) return True
[docs] @staticmethod def _get_size_args(model: CLAM_MB | CLAM_SB) -> list[int]: """Extract L and D parameters from attention network (size[1] and size[2]).""" if hasattr(model, "attention_net"): for module in model.attention_net.modules(): if isinstance(module, Attn_Net_Gated): # For gated attention: attention_a and attention_b are Sequential with Linear layers linear_layer = module.attention_a[0] # First layer is Linear if isinstance(linear_layer, nn.Linear): l_dim = linear_layer.in_features # size[1] d_dim = linear_layer.out_features # size[2] return [l_dim, d_dim] elif isinstance(module, Attn_Net): # For regular attention: module is Sequential with Linear layers linear_layer = module.module[0] # First layer is Linear if isinstance(linear_layer, nn.Linear): l_dim = linear_layer.in_features # size[1] d_dim = linear_layer.out_features # size[2] return [l_dim, d_dim] return [512, 256] # default
    def __init__(
        self,
        model: CLAM_MB | CLAM_SB,
        optimizer: torch.optim.Optimizer,
        loss_slide: nn.Module = nn.CrossEntropyLoss(),
        weight_loss_slide: float = 0.7,
        lr_scheduler: LRScheduler | None = None,
        subsampling: float = 1.0,
        use_aem: bool = False,
        aem_weight_initial: float = 0.0001,
        aem_weight_final: float = 0.0,
        aem_annealing_epochs: int = 50,
    ):
        """
        Initialize the Lightning wrapper around a CLAM model.

        Args:
            model (CLAM_MB | CLAM_SB): CLAM model to wrap.
            optimizer (torch.optim.Optimizer): Optimizer already bound to the
                model's parameters.
            loss_slide (nn.Module, optional): Slide-level loss. Defaults to
                nn.CrossEntropyLoss(). NOTE(review): the default instance is
                created once at import time and shared by every LitCLAM that
                relies on it — harmless for the stateless CrossEntropyLoss,
                but pass an explicit instance for stateful losses.
            weight_loss_slide (float, optional): Weight of the slide loss in
                the total loss; the instance loss gets (1 - weight). Defaults to 0.7.
            lr_scheduler (LRScheduler | None, optional): LR scheduler. Defaults to None.
            subsampling (float, optional): Training-time instance subsampling;
                in (0, 1) treated as a fraction, >= 1 as an absolute count.
                Defaults to 1.0 (no subsampling).
            use_aem (bool, optional): Enable Attention Entropy Maximization. Defaults to False.
            aem_weight_initial (float, optional): Initial AEM weight. Defaults to 0.0001.
            aem_weight_final (float, optional): Final AEM weight after annealing. Defaults to 0.0.
            aem_annealing_epochs (int, optional): Epochs over which the AEM
                weight is annealed. Defaults to 50.
        """
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.loss_slide = loss_slide
        self.weight_loss_slide = weight_loss_slide
        self.lr_scheduler = lr_scheduler
        self.subsampling = subsampling
        self.use_aem = use_aem
        if self.use_aem:
            self.aem = AEM(
                weight_initial=aem_weight_initial,
                weight_final=aem_weight_final,
                annealing_epochs=aem_annealing_epochs,
            )

        # Save all hyperparameters including model config.
        # The model architecture is introspected here so that
        # load_from_checkpoint can rebuild an equivalent CLAM model.
        model_config: dict[str, Any] = {
            "model_class": model.__class__.__name__,
            "gated": self._is_gated_attention(model),
            "size_arg": self._get_size_args(model),
            "n_classes": model.n_classes,
            "instance_loss_fn": model.instance_loss_fn.__class__.__name__,
            "k_sample": model.k_sample,
            "subtyping": model.subtyping,
            "embed_dim": model.size_dict.get("small", [1024])[0]
            if hasattr(model, "size_dict")
            else 1024,
            "dropout": any(
                isinstance(m, nn.Dropout) for m in model.attention_net.modules()
            )
            if hasattr(model, "attention_net")
            else False,
            "temperature": model.temperature if hasattr(model, "temperature") else 1.0,
        }

        self.save_hyperparameters(
            {
                **model_config,
                "optimizer_class": optimizer.__class__.__name__,
                "optimizer_lr": optimizer.param_groups[0]["lr"],
                "loss_slide": loss_slide,
                "weight_loss_slide": weight_loss_slide,
                "lr_scheduler": lr_scheduler.__class__.__name__
                if lr_scheduler
                else None,
                "subsampling": subsampling,
                "use_aem": use_aem,
                "aem_weight_initial": aem_weight_initial,
                "aem_weight_final": aem_weight_final,
                "aem_annealing_epochs": aem_annealing_epochs,
            }
        )

        self._setup_metrics()
        # Number of instances in the most recent bag (updated each step).
        self.bag_size: int = 0
[docs] @classmethod def load_from_checkpoint( cls, checkpoint_path: str | Path | IO[bytes], map_location: torch.device | str | int | Callable[[torch.UntypedStorage, str], torch.UntypedStorage | None] | dict[torch.device | str | int, torch.device | str | int] | None = None, hparams_file: str | Path | None = None, strict: bool | None = None, **kwargs: Any, ) -> Self: """ Load a LitCLAM model from a checkpoint file. Args: checkpoint_path (str | Path | IO[bytes]): Path to the checkpoint file. map_location: Device mapping for loading the model. hparams_file (str | Path | None): Optional path to a YAML file with hyperparameters. strict (bool | None): Whether to strictly enforce that the keys in state_dict match the keys returned by the model's state_dict function. **kwargs: Additional keyword arguments. Returns: LitCLAM: The loaded LitCLAM model. """ checkpoint = torch.load( checkpoint_path, map_location=map_location, weights_only=False ) # type: ignore hparams = checkpoint.get("hyper_parameters", {}) # Reconstruct model model_class = CLAM_MB if hparams.get("model_class") == "CLAM_MB" else CLAM_SB model = model_class( gate=hparams.get("gated", True), size_arg=hparams.get("size_arg", "small"), dropout=hparams.get("dropout", False), k_sample=hparams.get("k_sample", 8), n_classes=hparams.get("n_classes", 2), instance_loss_fn=SmoothTop1SVM(n_classes=2) if hparams.get("instance_loss_fn") == "SmoothTop1SVM" else nn.CrossEntropyLoss(), subtyping=hparams.get("subtyping", False), embed_dim=hparams.get("embed_dim", 1024), temperature=hparams.get("temperature", 1.0), ) # Reconstruct optimizer optimizer_class = getattr(torch.optim, hparams.get("optimizer_class", "Adam")) optimizer = optimizer_class( model.parameters(), lr=hparams.get("optimizer_lr", 1e-4) ) # Reconstruct loss function loss_slide_param = hparams.get("loss_slide", nn.CrossEntropyLoss()) # Handle both string names (old checkpoints) and loss objects (new checkpoints) if isinstance(loss_slide_param, str): # Old 
checkpoint format - reconstruct from name from cellmil.utils.train.losses import FocalLoss if loss_slide_param == "FocalLoss": loss_slide = FocalLoss() elif loss_slide_param == "CrossEntropyLoss": loss_slide = nn.CrossEntropyLoss() else: # Default fallback loss_slide = nn.CrossEntropyLoss() else: # New checkpoint format - already a loss object loss_slide = loss_slide_param # Note: lr_scheduler is not reconstructed from checkpoint as it's typically # created fresh when loading a model for further training or evaluation lit_model = cls( model=model, optimizer=optimizer, loss_slide=loss_slide, weight_loss_slide=hparams.get("weight_loss_slide", 0.7), lr_scheduler=None, # Scheduler not restored from checkpoint subsampling=hparams.get("subsampling", 1.0), use_aem=hparams.get("use_aem", False), aem_weight_initial=hparams.get("aem_weight_initial", 0.001), aem_weight_final=hparams.get("aem_weight_final", 0.0), aem_annealing_epochs=hparams.get("aem_annealing_epochs", 50), ) lit_model.load_state_dict( checkpoint["state_dict"], strict=strict if strict is not None else True ) return lit_model
def _setup_metrics(self): metrics = torchmetrics.MetricCollection( { "accuracy": torchmetrics.Accuracy( task="multiclass", num_classes=self.model.n_classes, average="none" ), "f1": torchmetrics.F1Score( task="multiclass", num_classes=self.model.n_classes, average="macro" ), "precision": torchmetrics.Precision( task="multiclass", num_classes=self.model.n_classes, average="macro" ), "recall": torchmetrics.Recall( task="multiclass", num_classes=self.model.n_classes, average="macro" ), "auroc": torchmetrics.AUROC( task="multiclass", num_classes=self.model.n_classes, average="macro" ), } ) self.train_metrics = metrics.clone(prefix="train/") self.val_metrics = metrics.clone(prefix="val/") self.test_metrics = metrics.clone(prefix="test/")
[docs] def forward( self, x: torch.Tensor, label: torch.Tensor | None = None, instance_eval: bool = True, ): return self.model(x, label, instance_eval=instance_eval)
    def _shared_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], stage: str, log: bool = True
    ):
        """
        Run one MIL bag through the model and compute the combined loss.

        Args:
            batch: Tuple of (features, label); features is expected to have a
                leading batch dimension of 1 ([1, n_instances, feat_dim]).
            stage: One of "train"/"val"/"test"; controls subsampling, AEM and
                per-step logging behaviour.
            log: Whether to emit step metrics via ``self.log``.

        Returns:
            tuple: (total_loss, Y_prob, label).
        """
        data, label = batch
        # Ensure MIL batch size is 1
        assert data.size(0) == 1, "Batch size must be 1 for MIL"
        data = data.squeeze(0)  # [n_instances, feat_dim]

        # Apply subsampling during training
        if stage == "train" and self.subsampling != 1.0:
            # Calculate the number of samples to keep
            if 0 < self.subsampling < 1.0:
                # Treat as percentage
                num_samples = int(self.subsampling * data.shape[0])
            elif self.subsampling >= 1.0:
                # Treat as absolute count (capped at the bag size)
                num_samples = min(int(self.subsampling), data.shape[0])
            else:
                raise ValueError(f"Invalid subsampling value: {self.subsampling}")

            # Generate random permutation of indices
            indices = torch.randperm(data.shape[0], device=data.device)
            # Select the first N samples from the permuted indices
            sampled_indices = indices[:num_samples]
            # Use the sampled indices to select instances
            data = data[sampled_indices]

        self.bag_size = data.size(0)
        logits, Y_prob, Y_hat, attention_weights, instance_dict = self(
            data, label=label, instance_eval=True
        )

        slide_loss = self.loss_slide(logits, label)
        instance_loss = instance_dict["instance_loss"]

        # Calculate total loss starting with slide and instance losses
        # (convex combination weighted by weight_loss_slide).
        total_loss = (
            self.weight_loss_slide * slide_loss
            + (1 - self.weight_loss_slide) * instance_loss
        )

        # AEM (Attention Entropy Maximization) regularizer, train-only.
        current_epoch = self.current_epoch if hasattr(self, "current_epoch") else 0
        aem: torch.Tensor | None = None
        if self.use_aem and stage == "train":
            aem = self.aem.get_aem(current_epoch, attention_weights)
            total_loss = total_loss + aem

        error = self.calculate_error(Y_hat, label)

        if log:
            # Per-step logging for train, epoch aggregates for val/test.
            self.log(
                f"{stage}/slide_loss",
                slide_loss,
                prog_bar=(stage != "train"),
                on_step=(stage == "train"),
                on_epoch=True,
                batch_size=1,
            )
            self.log(
                f"{stage}/instance_loss",
                instance_loss,
                prog_bar=(stage != "train"),
                on_step=(stage == "train"),
                on_epoch=True,
                batch_size=1,
            )
            self.log(
                f"{stage}/total_loss",
                total_loss,
                prog_bar=(stage != "train"),
                on_step=(stage == "train"),
                on_epoch=True,
                batch_size=1,
            )
            self.log(
                f"{stage}/error",
                error,
                prog_bar=(stage != "train"),
                on_step=(stage == "train"),
                on_epoch=True,
                batch_size=1,
            )
            # Bag sizes are only informative once; log them on the first epoch.
            if current_epoch == 0 and stage in ["train", "val"]:
                self.log(
                    f"{stage}/num_instances",
                    self.bag_size,
                    prog_bar=False,
                    on_step=True,
                    on_epoch=False,
                    batch_size=1,
                )
            if self.use_aem and stage == "train" and aem is not None:
                self.log(
                    f"{stage}/aem",
                    aem,
                    prog_bar=True,
                    on_step=False,
                    on_epoch=True,
                    batch_size=1,
                )

        return total_loss, Y_prob, label
[docs] def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int): loss, Y_prob, label = self._shared_step(batch, stage="train") self.train_metrics(Y_prob, label) return loss
[docs] def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int): loss, Y_prob, label = self._shared_step(batch, stage="val") self.val_metrics(Y_prob, label) return loss
[docs] def test_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int): loss, Y_prob, label = self._shared_step(batch, stage="test") self.test_metrics(Y_prob, label) return loss
[docs] def on_train_epoch_end(self) -> None: computed = self.train_metrics.compute() self._flatten_and_log_metrics(computed, prefix="train") self.train_metrics.reset()
[docs] def on_validation_epoch_end(self): computed = self.val_metrics.compute() self._flatten_and_log_metrics(computed, prefix="val") self.val_metrics.reset()
[docs] def on_test_epoch_end(self): computed = self.test_metrics.compute() self._flatten_and_log_metrics(computed, prefix="test") self.test_metrics.reset()
[docs] def _flatten_and_log_metrics( self, computed: dict[str, torch.Tensor], prefix: str ) -> None: """ Convert metric dictionary produced by torchmetrics into a flat dict of scalar values and log it with `self.log_dict`. - Vector/tensor metrics (e.g. per-class accuracy) are expanded into keys like `{prefix}/class_{i}_acc`. - Scalar tensors are converted to floats. - None values are converted to NaN to satisfy loggers that expect numeric scalars. """ flat: dict[str, float] = {} for key, val in computed.items(): # Normalize key: some metrics come as 'train/accuracy' etc.; keep full key try: if val.dim() == 0: flat[key] = float(val.item()) else: vals = cast(list[torch.Tensor], val.cpu().tolist()) # type: ignore for i, v in enumerate(vals): # Special-case accuracy to use *_acc suffix if key.endswith("/accuracy"): base = key.rsplit("/", 1)[0] flat[f"{base}/class_{i}_acc"] = float(v) else: flat[f"{key}_class_{i}"] = float(v) except Exception: # Fallback: set NaN so logging doesn't fail flat[key] = float("nan") # Finally log flattened scalars self.log_dict(flat, prog_bar=True, batch_size=1)
[docs] def configure_optimizers(self): # type: ignore if self.lr_scheduler: scheduler: dict[str, Any] = { "scheduler": self.lr_scheduler, "interval": "epoch", "monitor": "val/total_loss", "frequency": 1, "strict": True, "name": "learning_rate", } return [self.optimizer], [scheduler] return [self.optimizer]
[docs] def predict_step( self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int ) -> Any: _, Y_prob, _ = self._shared_step(batch, stage="test", log=False) return Y_prob.argmax(dim=-1)
[docs] @staticmethod def calculate_error(Y_hat: torch.Tensor, Y: torch.Tensor): """Classification error = 1 - accuracy.""" return 1.0 - Y_hat.float().eq(Y.float()).float().mean().item()
[docs] def get_attention_weights(self, x: torch.Tensor) -> torch.Tensor: """ Get attention weights for a bag of instances. Args: x (torch.Tensor): Input tensor of shape [n_instances, feat_dim]. Returns: torch.Tensor: Attention weights of shape [n_classes, n_instances]. """ self.model.eval() with torch.inference_mode(): a = self.model(x, attention_only=True) # [n_classes, n_instances] return a
class LitSurvCLAM(LitCLAM):
    """
    Lightning wrapper for CLAM models adapted for survival analysis.

    This class extends LitCLAM to support survival analysis tasks using
    discrete-time survival models with logistic hazard parameterization.
    Overrides the metrics setup to use survival-specific metrics and the
    shared step/prediction logic for survival targets.

    Args:
        model (CLAM_MB | CLAM_SB): The CLAM model instance (SB or MB).
        optimizer (torch.optim.Optimizer): Optimizer for training.
        loss_slide (nn.Module, optional): Loss function for survival. Defaults to NegativeLogLikelihoodSurvLoss().
        weight_loss_slide (float, optional): Weight for slide-level loss. Defaults to 0.7.
        lr_scheduler (LRScheduler | None, optional): Learning rate scheduler. Defaults to None.
        subsampling (float, optional): Training-time instance subsampling; a fraction in (0, 1)
            or an absolute count >= 1. Defaults to 1.0 (no subsampling).
        use_aem (bool, optional): Whether to use AEM regularization. Defaults to False.
        aem_weight_initial (float, optional): Initial weight for AEM loss. Defaults to 0.0001.
        aem_weight_final (float, optional): Final weight for AEM loss after annealing. Defaults to 0.0.
        aem_annealing_epochs (int, optional): Number of epochs to anneal AEM weight. Defaults to 50.
    """
    def __init__(
        self,
        model: CLAM_MB | CLAM_SB,
        optimizer: torch.optim.Optimizer,
        loss_slide: nn.Module = NegativeLogLikelihoodSurvLoss(),
        weight_loss_slide: float = 0.7,
        lr_scheduler: LRScheduler | None = None,
        subsampling: float = 1.0,
        use_aem: bool = False,
        aem_weight_initial: float = 0.0001,
        aem_weight_final: float = 0.0,
        aem_annealing_epochs: int = 50,
    ):
        """
        Initialize the survival variant of LitCLAM.

        Same parameters as LitCLAM; the slide loss defaults to a negative
        log-likelihood survival loss instead of cross-entropy.
        """
        super().__init__(
            model,
            optimizer,
            loss_slide,
            weight_loss_slide,
            lr_scheduler,
            subsampling,
            use_aem,
            aem_weight_initial,
            aem_weight_final,
            aem_annealing_epochs,
        )
        # For logistic hazard, n_classes should equal num_bins
        self.num_bins = model.n_classes
        # Override with survival-specific metrics.
        # NOTE(review): super().__init__ already dispatches to this class's
        # overridden _setup_metrics, so this explicit call is redundant
        # (harmless — it just rebuilds the same metric collections).
        self._setup_metrics()
[docs] def _setup_metrics(self): """Setup C-index and Brier score metrics for survival analysis.""" metrics = torchmetrics.MetricCollection( { "c_index": ConcordanceIndex(), "brier_score": BrierScore(), } ) self.train_metrics = metrics.clone(prefix="train/") self.val_metrics = metrics.clone(prefix="val/") self.test_metrics = metrics.clone(prefix="test/")
    def _shared_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], stage: str, log: bool = True
    ):
        """
        Run one MIL bag through the model for survival training.

        Mirrors LitCLAM._shared_step but passes ``label[0]`` to the model's
        instance-evaluation path and skips the classification-error metric,
        which is not meaningful for discrete-time survival targets.
        NOTE(review): label is assumed to be a batch-of-one survival target
        (hence the [0] indexing) — confirm against the survival dataset.

        Args:
            batch: Tuple of (features, label); features must have a leading
                batch dimension of 1.
            stage: One of "train"/"val"/"test".
            log: Whether to emit step metrics via ``self.log``.

        Returns:
            tuple: (total_loss, Y_prob, label).
        """
        data, label = batch
        # Ensure MIL batch size is 1
        assert data.size(0) == 1, "Batch size must be 1 for MIL"
        data = data.squeeze(0)  # [n_instances, feat_dim]

        # Apply subsampling during training
        if stage == "train" and self.subsampling != 1.0:
            # Calculate the number of samples to keep
            if 0 < self.subsampling < 1.0:
                # Treat as percentage
                num_samples = int(self.subsampling * data.shape[0])
            elif self.subsampling >= 1.0:
                # Treat as absolute count
                num_samples = min(int(self.subsampling), data.shape[0])
            else:
                raise ValueError(f"Invalid subsampling value: {self.subsampling}")

            # Generate random permutation of indices
            indices = torch.randperm(data.shape[0], device=data.device)
            # Select the first N samples from the permuted indices
            sampled_indices = indices[:num_samples]
            # Use the sampled indices to select instances
            data = data[sampled_indices]

        self.bag_size = data.size(0)
        logits, Y_prob, _, attention_weights, instance_dict = self(
            data, label=label[0], instance_eval=True
        )

        slide_loss = self.loss_slide(logits, label)
        instance_loss = instance_dict["instance_loss"]

        # Calculate total loss starting with slide and instance losses
        total_loss = (
            self.weight_loss_slide * slide_loss
            + (1 - self.weight_loss_slide) * instance_loss
        )

        # AEM (Attention Entropy Maximization) regularizer, train-only.
        current_epoch = self.current_epoch if hasattr(self, "current_epoch") else 0
        aem: torch.Tensor | None = None
        if self.use_aem and stage == "train":
            aem = self.aem.get_aem(current_epoch, attention_weights)
            total_loss = total_loss + aem

        # error = self.calculate_error(Y_hat, label)

        if log:
            self.log(
                f"{stage}/slide_loss",
                slide_loss,
                prog_bar=(stage != "train"),
                on_step=(stage == "train"),
                on_epoch=True,
                batch_size=1,
            )
            self.log(
                f"{stage}/instance_loss",
                instance_loss,
                prog_bar=(stage != "train"),
                on_step=(stage == "train"),
                on_epoch=True,
                batch_size=1,
            )
            self.log(
                f"{stage}/total_loss",
                total_loss,
                prog_bar=(stage != "train"),
                on_step=(stage == "train"),
                on_epoch=True,
                batch_size=1,
            )
            # self.log(f"{stage}/error", error, prog_bar=(stage != "train"), on_step=(stage=="train"), on_epoch=True, batch_size=1)
            # Bag sizes are only informative once; log them on the first epoch.
            if current_epoch == 0 and stage in ["train", "val"]:
                self.log(
                    f"{stage}/num_instances",
                    self.bag_size,
                    prog_bar=False,
                    on_step=True,
                    on_epoch=False,
                    batch_size=1,
                )
            if self.use_aem and stage == "train" and aem is not None:
                self.log(
                    f"{stage}/aem",
                    aem,
                    prog_bar=True,
                    on_step=False,
                    on_epoch=True,
                    batch_size=1,
                )

        return total_loss, Y_prob, label
[docs] def predict_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int): """Prediction step returns logits for discrete-time hazard intervals.""" x, _ = batch # Ensure MIL batch size is 1 assert x.size(0) == 1, "Batch size must be 1 for MIL" x = x.squeeze(0) # [n_instances, feat_dim] logits, _, _, _, _ = self.model(x, instance_eval=False) return logits # Return logits, not hazards