cellmil.models.mil.histobistro

Classes

Attention([dim, heads, dim_head, dropout])

Multi-head self-attention mechanism for transformers.

FeedForward([dim, hidden_dim, dropout])

Feed-forward network used in transformer blocks.

HistoBistro(num_classes[, input_dim, dim, ...])

HistoBistro: Histopathology Bi-level Transformer for MIL classification.

PreNorm(dim, fn[, context_dim])

Layer normalization module that applies normalization before a function.

TransformerBlocks(dim, depth, heads, ...[, ...])

Stack of transformer encoder blocks.

class cellmil.models.mil.histobistro.Attention(dim: int = 512, heads: int = 8, dim_head: int = 64, dropout: float = 0.1)[source]

Bases: Module

Multi-head self-attention mechanism for transformers.

This class implements the standard multi-head self-attention as described in “Attention Is All You Need” (Vaswani et al., 2017). It includes methods for extracting and saving attention maps for interpretability.

Parameters:
  • dim – Input dimension of the features.

  • heads – Number of attention heads.

  • dim_head – Dimension of each attention head. The default of 64 equals dim / heads for the default configuration (512 / 8).

  • dropout – Dropout rate applied to the output projection.

__init__(dim: int = 512, heads: int = 8, dim_head: int = 64, dropout: float = 0.1)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor, register_hook: bool = False)[source]

Forward pass for self-attention.

Parameters:
  • x – Input tensor of shape [batch_size, num_tokens, dim].

  • register_hook – If True, registers a gradient hook on the attention map for interpretability.

Returns:

Output tensor of shape [batch_size, num_tokens, dim].

Return type:

torch.Tensor

save_attn_gradients(attn_gradients: Tensor)[source]

Save attention gradients during backpropagation for interpretability.

Parameters:

attn_gradients – Gradient tensor for attention maps.

get_attn_gradients()[source]

Retrieve saved attention gradients.

Returns:

The saved attention gradients.

Return type:

torch.Tensor

save_attention_map(attention_map: Tensor)[source]

Save attention map from the forward pass.

Parameters:

attention_map – Attention map tensor from the forward pass.

get_attention_map()[source]

Retrieve the saved attention map.

Returns:

The saved attention map.

Return type:

torch.Tensor

get_self_attention(x: Tensor)[source]

Calculate self-attention maps without performing the full forward pass.

This is useful for visualization and interpretability.

Parameters:

x – Input tensor of shape [batch_size, num_tokens, dim].

Returns:

Self-attention map of shape [batch_size, heads, num_tokens, num_tokens].

Return type:

torch.Tensor
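
As a concrete illustration, the mechanism documented above can be sketched in a few lines of PyTorch. This is a hypothetical reimplementation for clarity, not the cellmil source; it keeps the documented defaults (dim=512, heads=8, dim_head=64) and stores the attention map the way save_attention_map describes:

```python
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Sketch of multi-head self-attention; illustrative, not the cellmil code."""

    def __init__(self, dim=512, heads=8, dim_head=64, dropout=0.1):
        super().__init__()
        inner = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner * 3, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner, dim), nn.Dropout(dropout))
        self.attention_map = None  # saved on each forward pass

    def forward(self, x, register_hook=False):
        # register_hook (gradient hooks for interpretability) omitted in this sketch
        b, n, _ = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = (t.reshape(b, n, self.heads, -1).transpose(1, 2) for t in qkv)
        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)
        self.attention_map = attn  # shape [b, heads, n, n], kept for inspection
        out = (attn @ v).transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

x = torch.randn(2, 16, 512)
y = Attention()(x)
print(y.shape)  # torch.Size([2, 16, 512])
```

After a forward pass, the saved map has shape [batch_size, heads, num_tokens, num_tokens], matching the shape documented for get_self_attention.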

class cellmil.models.mil.histobistro.FeedForward(dim: int = 512, hidden_dim: int = 1024, dropout: float = 0.1)[source]

Bases: Module

Feed-forward network used in transformer blocks.

This implements a standard MLP with GELU activation and dropout, as commonly used in transformer architectures.

Parameters:
  • dim – Input dimension.

  • hidden_dim – Hidden dimension, typically 2-4x the input dimension.

  • dropout – Dropout rate applied after each layer.

__init__(dim: int = 512, hidden_dim: int = 1024, dropout: float = 0.1)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor)[source]

Forward pass through the feed-forward network.

Parameters:

x – Input tensor of shape [batch_size, seq_length, dim].

Returns:

Output tensor with the same shape as input.

Return type:

torch.Tensor
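
The MLP described here can be assembled directly from stock torch.nn layers. This is an illustrative stand-in for FeedForward (Linear, GELU, Dropout, Linear, Dropout, per "dropout applied after each layer"), not the cellmil implementation:

```python
import torch
import torch.nn as nn

# Sketch of the transformer feed-forward block with the documented defaults.
ff = nn.Sequential(
    nn.Linear(512, 1024),   # dim -> hidden_dim
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(1024, 512),   # hidden_dim -> dim
    nn.Dropout(0.1),
)

x = torch.randn(2, 16, 512)
print(ff(x).shape)  # torch.Size([2, 16, 512])
```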

class cellmil.models.mil.histobistro.PreNorm(dim: int, fn: Module, context_dim: Optional[int] = None)[source]

Bases: Module

Layer normalization module that applies normalization before a function.

This implements the “Pre-LN” (Pre-LayerNorm) variant of transformers, which applies layer normalization before the attention or feed-forward operations. This variant is known to be more stable during training compared to Post-LN.

Parameters:
  • dim – Input dimension for the main tensor.

  • fn – Function to apply after normalization.

  • context_dim – Optional dimension for a context tensor that will also be normalized.

__init__(dim: int, fn: Module, context_dim: Optional[int] = None)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor, **kwargs: dict[str, Any])[source]

Apply normalization and then the function.

Parameters:
  • x – Input tensor.

  • **kwargs – Additional keyword arguments passed to the function. May include a ‘context’ tensor if context_dim was provided.

Returns:

Output after normalization and function application.

Return type:

torch.Tensor
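
The Pre-LN pattern amounts to normalizing the input before handing it to the wrapped module. A minimal sketch of that idea, ignoring the optional context_dim branch (illustrative, not the cellmil source):

```python
import torch
import torch.nn as nn

class PreNorm(nn.Module):
    """Sketch: apply LayerNorm, then the wrapped function."""

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # Normalize first (Pre-LN), then run the wrapped attention/MLP.
        return self.fn(self.norm(x), **kwargs)

wrapped = PreNorm(512, nn.Linear(512, 512))
print(wrapped(torch.randn(2, 16, 512)).shape)  # torch.Size([2, 16, 512])
```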

class cellmil.models.mil.histobistro.TransformerBlocks(dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int, dropout: float = 0.0)[source]

Bases: Module

Stack of transformer encoder blocks.

Each block consists of a multi-head self-attention layer followed by a feed-forward network, both with residual connections and layer normalization.

Parameters:
  • dim – Input dimension.

  • depth – Number of transformer blocks.

  • heads – Number of attention heads.

  • dim_head – Dimension of each attention head.

  • mlp_dim – Hidden dimension of the feed-forward network.

  • dropout – Dropout rate applied in both attention and feed-forward.

__init__(dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int, dropout: float = 0.0)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor, register_hook: bool = False)[source]

Forward pass through the transformer blocks.

Parameters:
  • x – Input tensor of shape [batch_size, seq_length, dim].

  • register_hook – Whether to register gradient hooks for interpretability.

Returns:

Output tensor after passing through all transformer blocks.

Return type:

torch.Tensor
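
A single block of the kind described, attention then feed-forward, each pre-normalized and wrapped in a residual connection, can be sketched with stock PyTorch layers (nn.MultiheadAttention standing in for the Attention class above; mlp_dim of 512 is assumed here, matching the HistoBistro default):

```python
import torch
import torch.nn as nn

dim, heads = 512, 8
attn = nn.MultiheadAttention(dim, heads, batch_first=True)
norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
mlp = nn.Sequential(nn.Linear(dim, 512), nn.GELU(), nn.Linear(512, dim))

def block(x):
    # Pre-norm attention with residual connection
    h = norm1(x)
    x = x + attn(h, h, h, need_weights=False)[0]
    # Pre-norm feed-forward with residual connection
    return x + mlp(norm2(x))

print(block(torch.randn(2, 16, dim)).shape)  # torch.Size([2, 16, 512])
```

A depth-N stack simply applies this block N times; the output keeps the input shape throughout.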

class cellmil.models.mil.histobistro.HistoBistro(num_classes: int, input_dim: int = 2048, dim: int = 512, depth: int = 2, heads: int = 8, mlp_dim: int = 512, pool: Literal['cls', 'mean'] = 'cls', dim_head: int = 64, dropout: float = 0.0, emb_dropout: float = 0.0, pos_enc: Optional[Module] = None)[source]

Bases: Module

HistoBistro: Histopathology Bi-level Transformer for MIL classification.

This is the main model class implementing the HistoBistro architecture, a transformer-based model for multiple instance learning (MIL) in histopathology images. The model processes a bag of instances (patches/cells) and makes a bag-level prediction. It includes interpretability methods to identify important instances.

Parameters:
  • num_classes – Number of output classes for classification.

  • input_dim – Input dimension of instance features.

  • dim – Internal dimension for transformer processing.

  • depth – Number of transformer blocks.

  • heads – Number of attention heads.

  • mlp_dim – Hidden dimension of the feed-forward network.

  • pool – Pooling strategy (‘cls’ for class token or ‘mean’ for mean pooling).

  • dim_head – Dimension of each attention head.

  • dropout – Dropout rate in transformer blocks.

  • emb_dropout – Dropout rate applied to the input embeddings.

  • pos_enc – Optional positional encoding module.

__init__(num_classes: int, input_dim: int = 2048, dim: int = 512, depth: int = 2, heads: int = 8, mlp_dim: int = 512, pool: Literal['cls', 'mean'] = 'cls', dim_head: int = 64, dropout: float = 0.0, emb_dropout: float = 0.0, pos_enc: Optional[Module] = None)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

save_gradient(grad: Tensor) → None[source]

Save gradients during backpropagation for interpretability.

Parameters:

grad – The gradient tensor to save.

The gradient is processed to obtain a single importance value per instance by taking the mean across the feature dimension.

get_instance_gradients() → torch.Tensor | None[source]

Return the saved gradients for instance importance.

Returns:

The gradients for each instance, if available.

get_normalized_gradients() → torch.Tensor | None[source]

Return the absolute and normalized gradients for instance importance.

Returns:

Normalized gradients (values between 0 and 1) for each instance, if available.
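
The rescaling described, absolute per-instance gradients mapped into [0, 1], can be sketched with a min-max scheme. This is a hypothetical normalization; the exact formula used by the model is not specified in this documentation:

```python
import torch

# One saved importance value per instance (mean over the feature dimension,
# as described for save_gradient).
grads = torch.randn(100)

g = grads.abs()
# Min-max rescale to [0, 1]; epsilon guards against a constant gradient bag.
normalized = (g - g.min()) / (g.max() - g.min() + 1e-8)
print(float(normalized.min()) >= 0.0, float(normalized.max()) <= 1.0)  # True True
```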

forward(x: Tensor, coords: Optional[Tensor] = None, register_hook: bool = False)[source]

Forward pass through the HistoBistro model.

Parameters:
  • x – Input tensor of shape [batch_size, num_instances, input_dim] containing instance features.

  • coords – Optional tensor of shape [batch_size, num_instances, 2] containing spatial coordinates.

  • register_hook – Whether to register gradient hooks for interpretability.

Returns:

A tuple containing:
  • logits: Raw classification scores of shape [batch_size, num_classes]

  • Y_prob: Softmax probabilities of shape [batch_size, num_classes]

  • Y_hat: Predicted class indices of shape [batch_size]

  • gradients: Instance gradients for interpretability (or None)

  • results_dict: Dictionary with additional results for analysis

Return type:

tuple
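
Putting the pieces together, the forward path with pool='cls' can be sketched end to end: project instance features to dim, prepend a class token, run the transformer, and classify from the class token. Everything below is an illustrative reconstruction from the parameter descriptions, using stock PyTorch layers rather than the cellmil classes:

```python
import torch
import torch.nn as nn

num_classes, input_dim, dim = 2, 2048, 512
proj = nn.Linear(input_dim, dim)                 # instance features -> dim
cls_token = nn.Parameter(torch.zeros(1, 1, dim)) # learnable class token
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(dim, nhead=8, batch_first=True, norm_first=True),
    num_layers=2,                                # depth=2, as in the defaults
)
head = nn.Linear(dim, num_classes)

x = torch.randn(1, 100, input_dim)               # one bag of 100 instances
tokens = torch.cat([cls_token, proj(x)], dim=1)  # [1, 101, dim]
logits = head(encoder(tokens)[:, 0])             # classify from the CLS token
Y_prob = logits.softmax(dim=-1)
Y_hat = Y_prob.argmax(dim=-1)
print(logits.shape, Y_hat.shape)  # torch.Size([1, 2]) torch.Size([1])
```

The real model additionally returns instance gradients and a results dictionary, and accepts optional coords for positional encoding; those parts are omitted from this sketch.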