cellmil.models.mil.histobistro¶
Classes

| Attention | Multi-head self-attention mechanism for transformers. |
| FeedForward | Feed-forward network used in transformer blocks. |
| HistoBistro | Histopathology Bi-level Transformer for MIL classification. |
| PreNorm | Layer normalization module that applies normalization before a function. |
| TransformerBlocks | Stack of transformer encoder blocks. |
- class cellmil.models.mil.histobistro.Attention(dim: int = 512, heads: int = 8, dim_head: int = 64, dropout: float = 0.1)[source]¶
Bases: Module

Multi-head self-attention mechanism for transformers.
This class implements the standard multi-head self-attention as described in “Attention Is All You Need” (Vaswani et al., 2017). It includes methods for extracting and saving attention maps for interpretability.
- Parameters:
dim – Input dimension of the features.
heads – Number of attention heads.
dim_head – Dimension of each attention head. If not provided, it is computed as dim / heads.
dropout – Dropout rate applied to the output projection.
- __init__(dim: int = 512, heads: int = 8, dim_head: int = 64, dropout: float = 0.1)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, register_hook: bool = False)[source]¶
Forward pass for self-attention.
- Parameters:
x – Input tensor of shape [batch_size, num_tokens, dim].
register_hook – If True, registers a gradient hook on the attention map for interpretability.
- Returns:
Output tensor of shape [batch_size, num_tokens, dim].
- Return type:
Tensor
- save_attn_gradients(attn_gradients: Tensor)[source]¶
Save attention gradients during backpropagation for interpretability.
- Parameters:
attn_gradients – Gradient tensor for attention maps.
- get_attn_gradients()[source]¶
Retrieve saved attention gradients.
- Returns:
The saved attention gradients.
- Return type:
Tensor
- save_attention_map(attention_map: Tensor)[source]¶
Save attention map from the forward pass.
- Parameters:
attention_map – Attention map tensor from the forward pass.
- get_attention_map()[source]¶
Retrieve the saved attention map.
- Returns:
The saved attention map.
- Return type:
Tensor
- get_self_attention(x: Tensor)[source]¶
Calculate self-attention maps without performing the full forward pass.
This is useful for visualization and interpretability.
- Parameters:
x – Input tensor of shape [batch_size, num_tokens, dim].
- Returns:
Self-attention map of shape [batch_size, heads, num_tokens, num_tokens].
- Return type:
Tensor
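The computation performed by a class like Attention can be sketched in plain torch. This is a hedged illustration of standard multi-head self-attention, not the class's actual implementation: the real module may differ in bias terms, dropout placement, and how the QKV projection is parameterized.

```python
import torch
import torch.nn.functional as F

def self_attention(x, w_qkv, heads):
    """Minimal multi-head self-attention sketch.

    x: [batch, tokens, dim]; w_qkv: [dim, 3 * heads * dim_head].
    """
    b, n, _ = x.shape
    qkv = x @ w_qkv                                 # joint query/key/value projection
    q, k, v = qkv.chunk(3, dim=-1)
    # split heads: [batch, heads, tokens, dim_head]
    split = lambda t: t.view(b, n, heads, -1).transpose(1, 2)
    q, k, v = split(q), split(k), split(v)
    scale = q.shape[-1] ** -0.5                     # 1 / sqrt(dim_head)
    attn = F.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)  # [b, heads, n, n]
    out = attn @ v                                  # attention-weighted sum of values
    return out.transpose(1, 2).reshape(b, n, -1), attn

x = torch.randn(2, 5, 512)
w_qkv = torch.randn(512, 3 * 8 * 64)                # dim=512, heads=8, dim_head=64
out, attn = self_attention(x, w_qkv, heads=8)
```

The returned `attn` tensor of shape [batch, heads, tokens, tokens] is what the save/get attention-map methods above expose for interpretability.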
- class cellmil.models.mil.histobistro.FeedForward(dim: int = 512, hidden_dim: int = 1024, dropout: float = 0.1)[source]¶
Bases: Module

Feed-forward network used in transformer blocks.
This implements a standard MLP with GELU activation and dropout, as commonly used in transformer architectures.
- Parameters:
dim – Input dimension.
hidden_dim – Hidden dimension, typically 2-4x the input dimension.
dropout – Dropout rate applied after each layer.
- __init__(dim: int = 512, hidden_dim: int = 1024, dropout: float = 0.1)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
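The block described above can be sketched as a small `nn.Sequential`. This is an assumption about the layer order (Linear, GELU, Dropout, Linear, Dropout); the actual FeedForward class may place dropout differently.

```python
import torch
import torch.nn as nn

# Hedged sketch of a transformer feed-forward block with GELU and dropout.
ff = nn.Sequential(
    nn.Linear(512, 1024),   # dim -> hidden_dim (2x expansion, per the defaults)
    nn.GELU(),
    nn.Dropout(0.1),
    nn.Linear(1024, 512),   # project back down to dim
    nn.Dropout(0.1),
)
y = ff(torch.randn(2, 5, 512))   # shape is preserved: [batch, tokens, dim]
```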
- class cellmil.models.mil.histobistro.PreNorm(dim: int, fn: Module, context_dim: Optional[int] = None)[source]¶
Bases: Module

Layer normalization module that applies normalization before a function.
This implements the “Pre-LN” (Pre-LayerNorm) variant of transformers, which applies layer normalization before the attention or feed-forward operations. This variant is known to be more stable during training compared to Post-LN.
- Parameters:
dim – Input dimension for the main tensor.
fn – Function to apply after normalization.
context_dim – Optional dimension for a context tensor that will also be normalized.
- __init__(dim: int, fn: Module, context_dim: Optional[int] = None)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, **kwargs: dict[str, Any])[source]¶
Apply normalization and then the function.
- Parameters:
x – Input tensor.
**kwargs – Additional keyword arguments passed to the function. May include a ‘context’ tensor if context_dim was provided.
- Returns:
Output after normalization and function application.
- Return type:
Tensor
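The Pre-LN pattern reduces to a few lines: normalize first, then apply the wrapped function. The class name below is illustrative, and this sketch omits the optional `context_dim` normalization the real PreNorm supports.

```python
import torch
import torch.nn as nn

# Hedged sketch of the Pre-LN wrapper: LayerNorm is applied to the input
# before the wrapped function (attention or feed-forward) sees it.
class PreNormSketch(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

block = PreNormSketch(512, nn.Linear(512, 512))
y = block(torch.randn(2, 5, 512))
```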
- class cellmil.models.mil.histobistro.TransformerBlocks(dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int, dropout: float = 0.0)[source]¶
Bases: Module

Stack of transformer encoder blocks.
Each block consists of a multi-head self-attention layer followed by a feed-forward network, both with residual connections and layer normalization.
- Parameters:
dim – Input dimension.
depth – Number of transformer blocks.
heads – Number of attention heads.
dim_head – Dimension of each attention head.
mlp_dim – Hidden dimension of the feed-forward network.
dropout – Dropout rate applied in both attention and feed-forward.
- __init__(dim: int, depth: int, heads: int, dim_head: int, mlp_dim: int, dropout: float = 0.0)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, register_hook: bool = False)[source]¶
Forward pass through the transformer blocks.
- Parameters:
x – Input tensor of shape [batch_size, seq_length, dim].
register_hook – Whether to register gradient hooks for interpretability.
- Returns:
Output tensor after passing through all transformer blocks.
- Return type:
Tensor
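One such encoder block, with residual connections and Pre-LN around both sub-layers, can be sketched as follows; stacking `depth` copies gives the behavior described above. This uses `nn.MultiheadAttention` as a stand-in for the module's own Attention class.

```python
import torch
import torch.nn as nn

# Hedged sketch of a single Pre-LN transformer encoder block; the stack
# applies `depth` of these in sequence.
class BlockSketch(nn.Module):
    def __init__(self, dim, heads, mlp_dim):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, dim)
        )

    def forward(self, x):
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]  # residual around attention
        x = x + self.ff(self.norm2(x))                     # residual around feed-forward
        return x

blocks = nn.Sequential(*[BlockSketch(512, 8, 512) for _ in range(2)])  # depth=2
y = blocks(torch.randn(2, 5, 512))
```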
- class cellmil.models.mil.histobistro.HistoBistro(num_classes: int, input_dim: int = 2048, dim: int = 512, depth: int = 2, heads: int = 8, mlp_dim: int = 512, pool: Literal['cls', 'mean'] = 'cls', dim_head: int = 64, dropout: float = 0.0, emb_dropout: float = 0.0, pos_enc: Optional[Module] = None)[source]¶
Bases: Module

HistoBistro - Histopathology Bi-level Transformer for MIL classification.
This is the main model class implementing the HistoBistro architecture, a transformer-based model for multiple instance learning (MIL) in histopathology images. The model processes a bag of instances (patches/cells) and makes a bag-level prediction. It includes interpretability methods to identify important instances.
- Parameters:
num_classes – Number of output classes for classification.
input_dim – Input dimension of instance features.
dim – Internal dimension for transformer processing.
depth – Number of transformer blocks.
heads – Number of attention heads.
mlp_dim – Hidden dimension of the feed-forward network.
pool – Pooling strategy (‘cls’ for class token or ‘mean’ for mean pooling).
dim_head – Dimension of each attention head.
dropout – Dropout rate in transformer blocks.
emb_dropout – Dropout rate applied to the input embeddings.
pos_enc – Optional positional encoding module.
- __init__(num_classes: int, input_dim: int = 2048, dim: int = 512, depth: int = 2, heads: int = 8, mlp_dim: int = 512, pool: Literal['cls', 'mean'] = 'cls', dim_head: int = 64, dropout: float = 0.0, emb_dropout: float = 0.0, pos_enc: Optional[Module] = None)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- save_gradient(grad: Tensor) → None[source]¶
Save gradients during backpropagation for interpretability.
- Parameters:
grad – The gradient tensor to save.
The gradient is processed to obtain a single importance value per instance by taking the mean across the feature dimension.
- get_instance_gradients() → torch.Tensor | None[source]¶
Return the saved gradients for instance importance.
- Returns:
The gradients for each instance, if available.
- get_normalized_gradients() → torch.Tensor | None[source]¶
Return the absolute and normalized gradients for instance importance.
- Returns:
Normalized gradients (values between 0 and 1) for each instance, if available.
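The gradient-hook mechanism behind these methods can be sketched without the model: register a hook on an intermediate tensor, backpropagate, then reduce the absolute gradients to normalized per-instance scores. All names here are illustrative, not the actual HistoBistro internals.

```python
import torch

# Hedged sketch of gradient-hook interpretability: capture the gradient
# flowing through an intermediate tensor, then normalize per-instance.
saved = {}
feats = torch.randn(1, 4, 8, requires_grad=True)      # [batch, instances, dim]
hidden = feats * 2.0
hidden.register_hook(lambda g: saved.update(grad=g))  # fires during backward()

weights = torch.arange(4.0).view(1, 4, 1)             # make gradients differ per instance
loss = (hidden * weights).sum()
loss.backward()

# Mean over the feature dimension gives one importance value per instance,
# then min-max scaling maps the absolute values into [0, 1].
per_instance = saved["grad"].abs().mean(dim=-1)
normalized = (per_instance - per_instance.min()) / (
    per_instance.max() - per_instance.min() + 1e-8
)
```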
- forward(x: Tensor, coords: Optional[Tensor] = None, register_hook: bool = False)[source]¶
Forward pass through the HistoBistro model.
- Parameters:
x – Input tensor of shape [batch_size, num_instances, input_dim] containing instance features.
coords – Optional tensor of shape [batch_size, num_instances, 2] containing spatial coordinates.
register_hook – Whether to register gradient hooks for interpretability.
- Returns:
A tuple containing:
- logits – Raw classification scores of shape [batch_size, num_classes].
- Y_prob – Softmax probabilities of shape [batch_size, num_classes].
- Y_hat – Predicted class indices of shape [batch_size].
- gradients – Instance gradients for interpretability (or None).
- results_dict – Dictionary with additional results for analysis.
- Return type:
tuple
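The end-to-end flow of this forward pass, with 'cls' pooling, can be sketched with stock torch modules. This is a hedged approximation assuming the architecture described above (project instances, prepend a class token, encode, classify from the class token); the names and layer choices are illustrative, not the actual HistoBistro attributes.

```python
import torch
import torch.nn as nn

# Hedged sketch of a MIL forward pass with 'cls' pooling.
b, n, input_dim, dim, num_classes = 2, 100, 2048, 512, 3
proj = nn.Linear(input_dim, dim)                  # instance features -> internal dim
cls_token = nn.Parameter(torch.zeros(1, 1, dim))  # learnable bag-level token
encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(
        dim, nhead=8, dim_feedforward=512, batch_first=True, norm_first=True
    ),
    num_layers=2,                                 # depth=2, per the defaults
)
head = nn.Linear(dim, num_classes)

x = torch.randn(b, n, input_dim)                  # bag of instance features
tokens = torch.cat([cls_token.expand(b, -1, -1), proj(x)], dim=1)
z = encoder(tokens)[:, 0]                         # 'cls' pooling: take the class token
logits = head(z)                                  # [batch, num_classes]
Y_prob = logits.softmax(dim=-1)
Y_hat = Y_prob.argmax(dim=-1)
```

With `pool='mean'`, the class token would be dropped and `z` taken as the mean over instance tokens instead.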