# -*- coding: utf-8 -*-
# Standard Model Implementation
#
# References:
# Data-efficient and weakly supervised computational pathology on whole-slide images
# Lu, Ming Y et al., Nature Biomedical Engineering, 2021
# DOI: https://doi.org/10.1038/s41551-021-00707-9
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any, Literal
from .utils import LitGeneral
[docs]class MIL_fc(nn.Module):
"""
Multiple Instance Learning model with fully connected layers for binary classification.
This model processes a bag of instances, applies a feature extractor (FC layers),
and performs binary classification. It selects the top k instances based on their
probability scores for the positive class.
Args:
size_arg: Size configuration for the network architecture ('small' is the only option currently).
dropout: Dropout rate for regularization.
n_classes: Number of classes (must be 2 for binary classification).
top_k: Number of top instances to select based on positive class probability.
embed_dim: Dimension of the input feature embeddings.
"""
[docs] def __init__(
self,
size_arg: Literal["small"] | list[int] = "small",
dropout: float = 0.0,
n_classes: int = 2,
top_k: int = 1,
embed_dim: int = 1024,
):
super().__init__() # type: ignore
assert n_classes == 2
self.n_classes = n_classes
self.size_dict = {"small": [embed_dim, 512]}
if isinstance(size_arg, list):
if len(size_arg) != 1:
raise ValueError("size_arg must be a list of length 1")
size = [embed_dim, *size_arg]
else:
size = self.size_dict[size_arg]
fc: list[nn.Module] = [
nn.Linear(size[0], size[1]),
nn.ReLU(),
nn.Dropout(dropout),
]
self.fc = nn.Sequential(*fc)
self.classifier = nn.Linear(size[1], n_classes)
self.top_k = top_k
[docs] def forward(
self, h: torch.Tensor, return_features: bool = False
) -> tuple[
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, torch.Tensor]
]:
"""
Forward pass of the MIL_fc model.
Args:
h: Input tensor of shape [n_instances, embed_dim] containing instance embeddings.
return_features: If True, returns the feature representations of top instances.
Returns:
A tuple containing:
- top_instance: Logits of the top instance(s).
- Y_prob: Softmax probabilities for the top instance(s).
- Y_hat: Predicted class labels for the top instance(s).
- y_probs: Softmax probabilities for all instances.
- results_dict: Additional results, may contain feature representations if return_features is True.
"""
h = self.fc(h)
logits = self.classifier(h) # K x 2
y_probs = F.softmax(logits, dim=1)
top_instance_idx = torch.topk(y_probs[:, 1], self.top_k, dim=0)[1].view(
1,
)
top_instance = torch.index_select(logits, dim=0, index=top_instance_idx)
Y_hat = torch.topk(top_instance, 1, dim=1)[1]
Y_prob = F.softmax(top_instance, dim=1)
results_dict: dict[str, torch.Tensor] = {}
if return_features:
top_features = torch.index_select(h, dim=0, index=top_instance_idx)
results_dict.update({"features": top_features})
return top_instance, Y_prob, Y_hat, y_probs, results_dict
[docs]class MIL_fc_mc(nn.Module):
"""
Multiple Instance Learning model with fully connected layers for multi-class classification.
This model processes a bag of instances, applies a feature extractor (FC layers),
and performs multi-class classification. It selects the top instance based on
the highest probability across all classes.
Args:
size_arg: Size configuration for the network architecture ('small' is the only option currently).
dropout: Dropout rate for regularization.
n_classes: Number of classes (must be > 2 for multi-class classification).
top_k: Number of top instances to select (must be 1 for this implementation).
embed_dim: Dimension of the input feature embeddings.
"""
[docs] def __init__(
self,
size_arg: Literal["small"] | list[int] = "small",
dropout: float = 0.0,
n_classes: int = 2,
top_k: int = 1,
embed_dim: int = 1024,
):
super().__init__() # type: ignore
assert n_classes > 2
self.size_dict = {"small": [embed_dim, 512]}
if isinstance(size_arg, list):
if len(size_arg) != 1:
raise ValueError("size_arg must be a list of length 1")
size = [embed_dim, *size_arg]
else:
size = self.size_dict[size_arg]
fc: list[nn.Module] = [
nn.Linear(size[0], size[1]),
nn.ReLU(),
nn.Dropout(dropout),
]
self.fc = nn.Sequential(*fc)
self.classifiers = nn.Linear(size[1], n_classes)
self.top_k = top_k
self.n_classes = n_classes
assert self.top_k == 1
[docs] def forward(
self, h: torch.Tensor, return_features: bool = False
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, dict[str, Any]]:
"""
Forward pass of the MIL_fc_mc model for multi-class classification.
Args:
h: Input tensor of shape [n_instances, embed_dim] containing instance embeddings.
return_features: If True, returns the feature representations of top instances.
Returns:
A tuple containing:
- top_instance: Logits of the top instance.
- Y_prob: Softmax probabilities for the top instance.
- Y_hat: Predicted class label for the top instance.
- y_probs: Softmax probabilities for all instances.
- results_dict: Additional results, may contain feature representations if return_features is True.
"""
h = self.fc(h)
logits = self.classifiers(h)
y_probs = F.softmax(logits, dim=1)
m = y_probs.view(1, -1).argmax(1)
top_indices = torch.cat(
((m // self.n_classes).view(-1, 1), (m % self.n_classes).view(-1, 1)), dim=1
).view(-1, 1)
top_instance = logits[top_indices[0]]
Y_hat = top_indices[1]
Y_prob = y_probs[top_indices[0]]
results_dict: dict[str, torch.Tensor] = {}
if return_features:
top_features = torch.index_select(h, dim=0, index=top_indices[0])
results_dict.update({"features": top_features})
return top_instance, Y_prob, Y_hat, y_probs, results_dict
class LitStandard(LitGeneral):
    """Lightning wrapper for standard MIL models (e.g. MIL_fc / MIL_fc_mc)."""

    def forward(self, x: torch.Tensor):
        # The wrapped model returns a tuple; only the top-instance logits
        # (first element) are used for the loss.
        return self.model(x)[0]

    def _shared_step(
        self,
        batch: tuple[torch.Tensor, torch.Tensor],
        stage: str,
        log: bool = True,
    ):
        x, y = batch
        # Ensure MIL batch size is 1
        assert x.size(0) == 1, "Batch size must be 1 for MIL"
        # Collapse the dummy batch dimension: [1, n, d] -> [n, d].
        logits = self(x.squeeze(0))
        loss = self.loss(logits, y)
        if log:
            is_train = stage == "train"
            # Step-level logging for training, epoch-level progress bar otherwise.
            self.log(
                f"{stage}/loss",
                loss,
                prog_bar=not is_train,
                on_step=is_train,
                on_epoch=True,
            )
        return loss, logits, y