# -*- coding: utf-8 -*-
# TransMIL Model Implementation
#
# References:
# TransMIL: Transformer based Correlated Multiple Instance Learning for Whole Slide Image Classification
# Shao, Zhuchen et al., Advances in Neural Information Processing Systems, 2021
# URL: https://proceedings.neurips.cc/paper/2021/hash/10c272d06794d3e5785d5e7c5356e9ff-Abstract.html
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from nystrom_attention import NystromAttention # type: ignore


class TransLayer(nn.Module):
    """Transformer layer with Nystrom attention.

    This layer implements a transformer block using Nystrom attention, an efficient
    approximation of the standard self-attention mechanism. It is particularly useful
    for processing long sequences, as it reduces the computational complexity from
    O(n²) to approximately O(n).

    Args:
        norm_layer (type[nn.LayerNorm], optional): Normalization layer class. Defaults to nn.LayerNorm.
        dim (int, optional): Feature dimension. Defaults to 512.
    """

    def __init__(self, norm_layer: type[nn.LayerNorm] = nn.LayerNorm, dim: int = 512):
        super().__init__()  # type: ignore
        self.norm = norm_layer(dim)
        self.attn = NystromAttention(
            dim=dim,
            dim_head=dim // 8,
            heads=8,
            num_landmarks=dim // 2,  # number of landmarks
            pinv_iterations=6,  # number of Moore-Penrose iterations for approximating the pseudoinverse; 6 is recommended by the paper
            residual=True,  # whether to add an extra residual with the value; supposedly faster convergence if turned on
            dropout=0.1,
        )

    def forward(
        self, x: torch.Tensor, return_attention: bool = False
    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
        """Forward pass for the transformer layer.

        Args:
            x (torch.Tensor): Input tensor of shape [B, N, C] where B is batch size,
                N is sequence length, and C is feature dimension.
            return_attention (bool, optional): Whether to return attention maps. Defaults to False.

        Returns:
            torch.Tensor | tuple[torch.Tensor, torch.Tensor]: If return_attention is False,
                returns the output tensor of the same shape as the input. If True, returns a
                tuple containing the output tensor and the attention map.
        """
        # Pre-norm attention with a residual connection
        norm_x = self.norm(x)
        output, attention = self.attn(norm_x, return_attn=True)
        x = x + output
        if return_attention:
            return x, attention
        return x
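

# A minimal usage sketch for TransLayer (illustrative only; the batch size,
# sequence length, and feature dimension below are arbitrary assumptions):
#
#     >>> layer = TransLayer(dim=512)
#     >>> tokens = torch.randn(2, 196, 512)                 # [B, N, C]
#     >>> out = layer(tokens)                               # same shape as the input
#     >>> out, attn = layer(tokens, return_attention=True)  # output plus attention map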


class PPEG(nn.Module):
    """Pyramid Position Encoding Generator.

    PPEG is a positional encoding module that uses depthwise convolutional layers with
    different kernel sizes to capture positional information at multiple scales. It
    reshapes the tokens into a 2D spatial grid, applies the convolutions, and flattens
    the result back into sequence format.

    This module helps the transformer model become aware of the spatial relationships
    between tokens, which is crucial for vision tasks.

    Args:
        dim (int, optional): Feature dimension. Defaults to 512.
    """

    def __init__(self, dim: int = 512):
        super(PPEG, self).__init__()  # type: ignore
        # Depthwise convolutions (groups=dim) with 7x7, 5x5, and 3x3 kernels,
        # each padded to preserve the spatial resolution of the H x W grid
        self.proj = nn.Conv2d(dim, dim, 7, 1, 7 // 2, groups=dim)
        self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim)
        self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim)

    def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
        """Forward pass for the position encoding generator.

        Args:
            x (torch.Tensor): Input tensor of shape [B, N, C] where B is batch size,
                N is sequence length (including the class token), and C is feature dimension.
            H (int): Height of the feature map when arranged in a 2D grid.
            W (int): Width of the feature map when arranged in a 2D grid.

        Returns:
            torch.Tensor: Output tensor with positional encoding information added,
                same shape as the input [B, N, C].
        """
        B, _, C = x.shape
        # Split off the class token; the remaining tokens form an H x W grid
        cls_token, feat_token = x[:, 0], x[:, 1:]
        cnn_feat = feat_token.transpose(1, 2).view(B, C, H, W)
        # Multi-scale convolutional position encoding with a residual connection
        x = self.proj(cnn_feat) + cnn_feat + self.proj1(cnn_feat) + self.proj2(cnn_feat)
        # Flatten back to sequence format and re-attach the class token
        x = x.flatten(2).transpose(1, 2)
        x = torch.cat((cls_token.unsqueeze(1), x), dim=1)
        return x
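

# A minimal usage sketch for PPEG (illustrative only; a 7x7 grid of patch tokens
# preceded by one class token is an arbitrary assumption):
#
#     >>> ppeg = PPEG(dim=512)
#     >>> x = torch.randn(1, 1 + 7 * 7, 512)  # [B, 1 + H*W, C]
#     >>> out = ppeg(x, H=7, W=7)             # same shape as the input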


class TransMIL(nn.Module):
    """Transformer-based Multiple Instance Learning model.

    TransMIL is a model for MIL tasks using transformers, as presented in the paper
    "TransMIL: Transformer based Correlated Multiple Instance Learning for Whole Slide
    Image Classification". It processes a bag of instances (e.g., patches from a whole
    slide image) using a transformer architecture with Nystrom attention to efficiently
    handle large sets of instances.

    The model includes:

    - A learnable class token, similar to ViT
    - Positional encoding through a convolutional approach (PPEG)
    - Multiple transformer layers with Nystrom attention

    Args:
        n_classes (int): Number of output classes for classification.
        embed_dim (int, optional): Dimension of the input instance features. Defaults to 1024.
    """

    def __init__(self, n_classes: int, embed_dim: int = 1024):
        super(TransMIL, self).__init__()  # type: ignore
        d = 512
        self.pos_layer = PPEG(dim=d)
        # Project the input instance features (embed_dim) down to the model dimension d
        self._fc1 = nn.Sequential(nn.Linear(embed_dim, d), nn.ReLU())
        self.cls_token = nn.Parameter(torch.randn(1, 1, d))
        self.n_classes = n_classes
        self.layer1 = TransLayer(dim=d)
        self.layer2 = TransLayer(dim=d)
        self.norm = nn.LayerNorm(d)
        self._fc2 = nn.Linear(d, self.n_classes)
        # TODO: Review this, the embedding is much smaller

    def forward(
        self, data: torch.Tensor
    ) -> tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]
    ]:
        """Forward pass for the TransMIL model.

        The model processes a bag of instances (features) using a transformer-based
        architecture. It first projects the features to a lower dimension, adds a class
        token, applies transformer layers with positional encoding, and finally produces
        classification outputs.

        Args:
            data (torch.Tensor): Input tensor of shape [B, n, D] where B is batch size,
                n is the number of instances in each bag, and D is the input feature
                dimension.

        Returns:
            tuple containing:
                - logits (torch.Tensor): Raw classification scores [B, n_classes]
                - Y_prob (torch.Tensor): Probability distribution over classes [B, n_classes]
                - Y_hat (torch.Tensor): Predicted class indices [B]
                - cls_attn (torch.Tensor): Class token's attention to each instance [B, n+1]
                - results_dict (dict[str, torch.Tensor]): Additional results/metrics
        """
        h = data.float()  # [B, n, D]
        h = self._fc1(h)  # [B, n, d]

        # ----> pad the sequence so the instances fill a square _H x _W grid
        # (e.g. n = 600 instances gives _H = _W = 25, so 25 instances are
        # duplicated from the start of the bag to reach 625 tokens)
        H = h.shape[1]
        _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
        add_length = _H * _W - H
        h = torch.cat([h, h[:, :add_length, :]], dim=1)  # [B, N, d]

        # ----> cls_token, expanded across the batch on the same device as the features
        B = h.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1).to(h.device)
        h = torch.cat((cls_tokens, h), dim=1)

        # ----> TransLayer x1
        h = self.layer1(h)  # [B, N, d]

        # ----> PPEG
        h = self.pos_layer(h, _H, _W)  # [B, N, d]

        # ----> TransLayer x2
        h, attention = self.layer2(h, return_attention=True)  # [B, N, d]

        # Extract the CLS token's attention (the first token's attention to all others).
        # The attention map has shape [B, h, N, N] where h is the number of heads;
        # we take the first query row for each head (CLS token attending to all tokens).
        cls_attention = attention[:, :, 0, :]  # [B, h, N]
        # Average across heads to get a single attention vector per batch item
        # (the heads could also be kept separate if per-head maps are needed)
        cls_attention_avg = cls_attention.mean(dim=1)  # [B, N]
        # Drop the attention to padded tokens, keeping only the CLS token and real instances
        orig_seq_len = H + 1  # +1 for the cls token
        cls_attn = cls_attention_avg[:, :orig_seq_len]

        # ----> cls_token
        h = self.norm(h)[:, 0]

        # ----> predict
        logits = self._fc2(h)  # [B, n_classes]
        Y_hat = torch.argmax(logits, dim=1)
        Y_prob = F.softmax(logits, dim=1)
        results_dict: dict[str, torch.Tensor] = {}
        return logits, Y_prob, Y_hat, cls_attn, results_dict
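

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the published model);
    # the bag size of 600 instances and the 1024-dim features below are arbitrary assumptions.
    model = TransMIL(n_classes=2, embed_dim=1024)
    bag = torch.randn(1, 600, 1024)  # [B, n, D] bag of instance features
    logits, Y_prob, Y_hat, cls_attn, _ = model(bag)
    print(logits.shape)    # torch.Size([1, 2])
    print(Y_prob.shape)    # torch.Size([1, 2])
    print(Y_hat.shape)     # torch.Size([1])
    print(cls_attn.shape)  # torch.Size([1, 601]) -- cls token + 600 instances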