cellmil.models.mil.transmil

Classes

PPEG([dim])

Pyramid Position Encoding Generator.

TransLayer(norm_layer, dim)

Transformer Layer with Nystrom Attention.

TransMIL(n_classes[, embed_dim])

Transformer-based Multiple Instance Learning model.

class cellmil.models.mil.transmil.TransLayer(norm_layer: type[torch.nn.LayerNorm] = torch.nn.LayerNorm, dim: int = 512)

Bases: Module

Transformer Layer with Nystrom Attention.

This layer implements a transformer block built on Nystrom attention, an efficient approximation of standard self-attention. It is particularly useful for long sequences: standard self-attention costs O(n²) in the sequence length n, whereas Nystrom attention with a fixed number of landmarks costs O(n).
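For intuition, here is a minimal NumPy sketch of the Nystrom approximation itself, not of `cellmil`'s layer. It assumes a single head, no batch dimension, landmarks taken as segment means of the queries and keys (as in the Nystromformer paper), and an exact `np.linalg.pinv` where Nystromformer uses an iterative pseudo-inverse. The helper names are hypothetical. With m landmarks the three softmax kernels are n×m, m×m and m×n, so for fixed m the cost grows linearly in n.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax along the given axis
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def full_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    # Exact softmax attention: builds an n x n matrix, O(n^2) in sequence length
    d = q.shape[-1]
    return softmax(q @ k.T / np.sqrt(d)) @ v

def nystrom_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray,
                      num_landmarks: int) -> np.ndarray:
    # Hypothetical single-head sketch of the Nystrom approximation.
    # Assumes n is divisible by num_landmarks (segment-mean landmarks).
    n, d = q.shape
    assert n % num_landmarks == 0, "sketch assumes n divisible by num_landmarks"
    scale = 1.0 / np.sqrt(d)
    seg = n // num_landmarks
    # Landmarks: mean of each run of `seg` consecutive queries / keys
    q_land = q.reshape(num_landmarks, seg, d).mean(axis=1)  # (m, d)
    k_land = k.reshape(num_landmarks, seg, d).mean(axis=1)  # (m, d)
    f = softmax(q @ k_land.T * scale)       # (n, m)
    a = softmax(q_land @ k_land.T * scale)  # (m, m)
    b = softmax(q_land @ k.T * scale)       # (m, n)
    # Exact pseudo-inverse here; Nystromformer approximates it iteratively.
    # Multiplying right-to-left never forms an n x n matrix.
    return f @ (np.linalg.pinv(a) @ (b @ v))
```

Setting `num_landmarks` equal to n makes every landmark a single token; the three kernels then coincide with the full attention matrix S, and because S·S⁺·S = S the sketch reproduces exact attention, which is a convenient sanity check.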

Parameters:
  • norm_layer (type[nn.LayerNorm], optional) – Normalization layer class. Defaults to nn.LayerNorm.

  • dim (int, optional) – Feature dimension. Defaults to 512.

__init__(norm_layer: type[torch.nn.LayerNorm] = torch.nn.LayerNorm, dim: int = 512)

Initialize the layer. The arguments are described under Parameters above.

forward(x: Tensor, return_attention: bool = False) → torch.Tensor | tuple[torch.Tensor, torch.Tensor]

Forward pass for the transformer layer.

Parameters:
  • x (torch.Tensor) – Input tensor of shape [B, N, C] where B is batch size, N is sequence length, and C is feature dimension.

  • return_attention (bool, optional) – Whether to return attention maps. Defaults to False.

Returns:

If return_attention is False, returns the output tensor, which has the same shape as the input. If True, returns a tuple containing the output tensor and the attention map.

Return type:

torch.Tensor | tuple[torch.Tensor, torch.Tensor]
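The shape contract of forward can be illustrated with a pre-norm residual block, x + attention(norm(x)), which is the arrangement used in the original TransMIL reference code; that `cellmil` follows it exactly is an assumption. To keep the sketch short it uses exact single-head softmax attention instead of Nystrom attention, omits the attention output projection and the learnable LayerNorm parameters, and takes a hypothetical fused weight matrix `w_qkv`.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax along the given axis
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def layer_norm(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    # Normalize over the feature axis (no learnable scale/shift in this sketch)
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def trans_layer(x: np.ndarray, w_qkv: np.ndarray, return_attention: bool = False):
    # x: (B, N, C); w_qkv: (C, 3C) hypothetical fused query/key/value weights.
    # Exact single-head attention stands in for Nystrom attention here.
    _, _, c = x.shape
    h = layer_norm(x)                                       # pre-norm
    q, k, v = np.split(h @ w_qkv, 3, axis=-1)               # each (B, N, C)
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(c))   # (B, N, N)
    out = x + attn @ v                                      # residual connection
    return (out, attn) if return_attention else out
```

Because the attention output is added back to the input, the returned tensor always has the input's shape [B, N, C]. In this sketch the optional attention map is a [B, N, N] matrix whose rows sum to 1; the exact shape of the map returned by `cellmil` is not stated in this documentation.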

class cellmil.models.mil.transmil.PPEG(dim: int = 512)

Bases: Module

Pyramid Position Encoding Generator.

PPEG is a positional encoding module that uses convolutional layers with different kernel sizes to capture positional information at multiple scales. It arranges the tokens in a 2D spatial grid, applies the convolutions, and reshapes the result back to sequence format.

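This makes the transformer aware of spatial relationships between tokens, which self-attention alone cannot capture because it is insensitive to token order. Such relationships matter for image data, where neighboring tokens correspond to neighboring regions.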
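A NumPy sketch of the multi-scale idea, assuming the design described in the TransMIL paper: the class token is set aside, the remaining H·W tokens are reshaped to an H×W grid, depthwise convolutions with 7×7, 5×5 and 3×3 kernels are applied with zero padding, and their outputs are summed with the grid itself before flattening back. Biases are omitted, and `depthwise_conv2d_same` and `ppeg` are hypothetical helpers, not `cellmil` functions.

```python
import numpy as np

def depthwise_conv2d_same(x: np.ndarray, w: np.ndarray) -> np.ndarray:
    # x: (B, H, W, C); w: (k, k, C) with odd k. One k x k filter per channel
    # (cross-correlation, as PyTorch convolutions compute), zero padding
    # so the output keeps the spatial size H x W.
    k = w.shape[0]
    p = k // 2
    _, h, wd, _ = x.shape
    xp = np.pad(x, ((0, 0), (p, p), (p, p), (0, 0)))
    out = np.zeros_like(x)
    for i in range(k):
        for j in range(k):
            out += xp[:, i:i + h, j:j + wd, :] * w[i, j]
    return out

def ppeg(x: np.ndarray, H: int, W: int, kernels: list) -> np.ndarray:
    # x: (B, 1 + H*W, C); token 0 is the class token and bypasses the convolutions
    b, n, c = x.shape
    assert n == 1 + H * W, "expected one class token plus an H x W grid"
    cls_tok, feat = x[:, :1], x[:, 1:]
    grid = feat.reshape(b, H, W, c)
    # Identity path plus one depthwise convolution per kernel size
    out = grid + sum(depthwise_conv2d_same(grid, w) for w in kernels)
    return np.concatenate([cls_tok, out.reshape(b, H * W, c)], axis=1)
```

For example, `ppeg(x, 3, 4, [np.zeros((k, k, C)) for k in (7, 5, 3)])` returns `x` unchanged, since only the identity path contributes; with trained kernels each scale adds a locally aggregated signal that depends on a token's position in the grid.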

Parameters:

dim (int, optional) – Feature dimension. Defaults to 512.

__init__(dim: int = 512)

Initialize the module. The arguments are described under Parameters above.

forward(x: Tensor, H: int, W: int) → Tensor

Forward pass for the position encoding generator.

Parameters:
  • x (torch.Tensor) – Input tensor of shape [B, N, C] where B is batch size, N is sequence length (including class token), and C is feature dimension.

  • H (int) – Height of the feature map when arranged in a 2D grid.

  • W (int) – Width of the feature map when arranged in a 2D grid.

Returns:

Output tensor with positional encoding information added, with the same shape as the input [B, N, C].

Return type:

torch.Tensor
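Since H and W describe a grid of the instance tokens, a bag of arbitrary size has to be made to fill that grid. One way to do this, used in the original TransMIL reference code and only assumed here (`cellmil` may handle it differently), is to take H = W = ⌈√n⌉ and pad the bag by repeating its first instances. The helper `pad_to_square` is hypothetical.

```python
import math
import numpy as np

def pad_to_square(h: np.ndarray):
    # h: (B, n, C) instance features -> (B, H*W, C) with H = W = ceil(sqrt(n))
    _, n, _ = h.shape
    side = math.isqrt(n)          # integer square root avoids float rounding
    if side * side < n:
        side += 1
    add = side * side - n         # number of padding tokens (always <= n)
    # Fill the grid by repeating the first `add` instances of each bag
    padded = np.concatenate([h, h[:, :add]], axis=1)
    return padded, side, side
```

After a class token is prepended, the sequence has N = H·W + 1 tokens, which matches the input layout described for forward.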

class cellmil.models.mil.transmil.TransMIL(n_classes: int, embed_dim: int = 1024)

Bases: Module

Transformer-based Multiple Instance Learning model.

TransMIL applies transformers to MIL tasks, as presented in the paper “TransMIL: Transformer based Correlated Multiple Instance Learning for Whole Slide Image Classification”. It processes a bag of instances (e.g., patches from a whole slide image) with a transformer architecture that uses Nystrom attention to handle large bags efficiently.

The model includes:

  • A learnable class token, similar to ViT.

  • Positional encoding through a convolutional approach (PPEG).

  • Multiple transformer layers with Nystrom attention.

Parameters:

  • n_classes (int) – Number of output classes for classification.

  • embed_dim (int, optional) – Embedding dimension. Defaults to 1024.

__init__(n_classes: int, embed_dim: int = 1024)

Initialize the model. The arguments are described under Parameters above.

forward(data: Tensor) → tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]

Forward pass for the TransMIL model.

The model processes a bag of instances (features) using a transformer-based architecture. It first projects the features to a lower dimension, adds a class token, applies transformer layers with positional encoding, and finally produces classification outputs.

Parameters:

data (torch.Tensor) – Input tensor of shape [B, n, D] where B is batch size, n is the number of instances in each bag, and D is the input feature dimension.

Returns:

  • logits (torch.Tensor): Raw classification scores [B, n_classes]

  • Y_prob (torch.Tensor): Probability distribution over classes [B, n_classes]

  • Y_hat (torch.Tensor): Predicted class indices [B]

  • cls_attn (torch.Tensor): Class token’s attention to each instance [B, n+1]

  • results_dict (dict[str, torch.Tensor]): Additional results/metrics

Return type:

tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]]
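The bookkeeping behind the first three outputs can be traced with a shape-only NumPy sketch. It assumes a common design in which the features are projected with a linear layer followed by ReLU, the final class-token embedding goes through a linear classifier, Y_prob is the softmax of the logits, and Y_hat is their argmax. PPEG and the transformer layers are skipped, so in this stripped-down version the logits do not yet depend on the instances. All weights (`w_proj`, `cls_token`, `w_cls`) are hypothetical random stand-ins, not `cellmil` parameters.

```python
import numpy as np

def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax along the given axis
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def transmil_outputs(data: np.ndarray, w_proj: np.ndarray,
                     cls_token: np.ndarray, w_cls: np.ndarray):
    # data: (B, n, D) bag of instance features
    b = data.shape[0]
    h = np.maximum(data @ w_proj, 0.0)                     # project D -> d (ReLU assumed)
    cls = np.broadcast_to(cls_token, (b, 1, h.shape[-1]))  # one class token per bag
    h = np.concatenate([cls, h], axis=1)                   # (B, n + 1, d)
    # PPEG and the transformer layers would update h here (omitted)
    logits = h[:, 0] @ w_cls                               # (B, n_classes) from class token
    y_prob = softmax(logits, axis=-1)                      # (B, n_classes)
    y_hat = y_prob.argmax(axis=-1)                         # (B,)
    return logits, y_prob, y_hat
```

Because softmax is monotonic, taking the argmax of Y_prob or of the logits gives the same Y_hat.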