import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Any
from topk.svm import SmoothTop1SVM # type: ignore
from torch_geometric.nn import global_mean_pool # type: ignore
from ..clam import CLAM_SB, CLAM_MB
from ..standard import MIL_fc, MIL_fc_mc
from ..attentiondeepmil import AttentionDeepMIL
from abc import ABC, abstractmethod
from cellmil.utils import logger
[docs]class GlobalPooling_Classifier(nn.Module, ABC):
"""Abstract base class for global pooling classifiers in GraphMIL.
This class defines the interface for pooling classifiers that aggregate
node features into graph-level predictions. Each subclass handles its
own specific arguments and validates them appropriately.
"""
[docs] def __init__(
self,
input_dim: int,
dropout: float,
n_classes: int,
size_arg: list[int] | str,
**kwargs: Any,
):
super().__init__() # type: ignore
self.input_dim = input_dim
self.dropout = dropout
self.kwargs = kwargs
self.n_classes = n_classes
self.size_arg = size_arg
[docs] def get_hyperparameters(self) -> dict[str, Any]:
"""Get all hyperparameters for this pooling classifier."""
return {
"type": self.__class__.__name__,
"input_dim": self.input_dim,
"dropout": self.dropout,
"n_classes": self.n_classes,
"size_arg": self.size_arg,
**self.kwargs,
**self._get_specific_hyperparameters()
}
[docs] def _get_specific_hyperparameters(self) -> dict[str, Any]:
"""Override in subclasses to add specific hyperparameters."""
return {}
[docs] @abstractmethod
def forward(
self,
x: torch.Tensor,
batch: torch.Tensor | None = None,
**kwargs: Any,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
"""Forward pass for the pooling classifier.
Args:
x: Node features tensor of shape (num_nodes, input_dim)
batch: Batch assignment tensor for global pooling (if applicable)
**kwargs: Additional arguments specific to each pooling classifier
Returns:
tuple containing:
- logits: Raw model outputs of shape (1, n_classes)
- output_dict: Dictionary containing instance-level information
"""
pass
[docs] def get_attention_weights(
self,
x: torch.Tensor,
batch: torch.Tensor | None = None
) -> torch.Tensor | None:
"""
Extract attention weights from the pooling classifier.
Args:
x: Node features tensor
batch: Batch assignment tensor (if applicable)
Returns:
Attention weights tensor or None if not available
"""
return None
class Mean_MLP(GlobalPooling_Classifier):
    """Mean pooling followed by MLP classifier.

    Node features are averaged per graph with ``global_mean_pool`` and the
    pooled vector is classified by an MLP whose hidden sizes come from
    ``size_arg``.
    """

    def __init__(
        self,
        input_dim: int,
        dropout: float,
        n_classes: int,
        size_arg: list[int],
        **kwargs: Any,
    ):
        """Build the mean-pooling MLP classifier.

        Args:
            input_dim: Dimensionality of the node features.
            dropout: Dropout probability; applied once after the hidden
                stack when > 0 and at least one hidden layer exists.
            n_classes: Number of output classes (1 means a single sigmoid
                output).
            size_arg: Hidden-layer sizes; must be a list for Mean_MLP.
            **kwargs: Extra hyperparameters recorded by the base class.

        Raises:
            ValueError: If ``size_arg`` is a string preset instead of a list.
        """
        super().__init__(
            input_dim=input_dim,
            dropout=dropout,
            n_classes=n_classes,
            size_arg=size_arg,
            **kwargs,
        )
        if isinstance(self.size_arg, str):
            raise ValueError("Mean_MLP requires size_arg to be a list of integers")
        layers: list[nn.Module] = []
        prev_dim = input_dim
        if len(self.size_arg) > 0:
            for hidden_dim in self.size_arg:
                layers.append(nn.Linear(prev_dim, hidden_dim))
                layers.append(nn.ReLU())
                prev_dim = hidden_dim
            # Dropout once after the hidden stack, before the output layer.
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(prev_dim, n_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(
        self,
        x: torch.Tensor,
        batch: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Forward pass for Mean_MLP pooling classifier.

        Args:
            x: Node features tensor of shape (num_nodes, input_dim).
            batch: Batch assignment tensor for global pooling.
            **kwargs: Additional arguments (ignored by Mean_MLP).

        Returns:
            tuple containing:
                - logits: Raw model outputs of shape (num_graphs, n_classes)
                - output_dict: Dictionary with "y_hat" and "y_prob"
        """
        # Use global_mean_pool with batch assignment if provided,
        # otherwise assume a single graph.
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        x = global_mean_pool(x, batch)
        logits = self.classifier(x)
        if self.n_classes == 1:
            Y_prob = torch.sigmoid(logits)
            # Bug fix: argmax over a single sigmoid column is always 0, so
            # the positive class could never be predicted. Threshold the
            # probability at 0.5 instead.
            Y_hat = (Y_prob > 0.5).long()
        else:
            Y_prob = F.softmax(logits, dim=1)
            Y_hat = Y_prob.argmax(dim=1, keepdim=True)
        output_dict: dict[str, Any] = {
            "y_hat": Y_hat,
            "y_prob": Y_prob,
        }
        return logits, output_dict
class CLAM(GlobalPooling_Classifier):
    """CLAM pooling classifier with attention-based multiple instance learning."""

    def __init__(
        self,
        input_dim: int,
        dropout: float,
        n_classes: int,
        size_arg: list[int] | str,
        gate: bool = True,
        k_sample: int = 8,
        instance_loss_fn: nn.Module | None = None,
        subtyping: bool = False,
        clam_type: str = "SB",
        temperature: float = 1.0,
        **kwargs: Any,
    ):
        """Build the CLAM pooling classifier.

        Args:
            input_dim: Dimensionality of the node features.
            dropout: Dropout probability; CLAM receives it as the boolean
                flag ``dropout > 0``.
            n_classes: Number of output classes.
            size_arg: Either "small"/"big" or a list of 1-2 hidden sizes.
            gate: Whether to use gated attention.
            k_sample: Number of instances sampled for the instance loss.
            instance_loss_fn: Instance-level loss module. Defaults to
                ``SmoothTop1SVM(n_classes=2)``, moved to CUDA when available.
            subtyping: Whether the task is subtyping.
            clam_type: "SB" (single branch) or "MB" (multi branch).
            temperature: Attention softmax temperature.
            **kwargs: Extra hyperparameters recorded by the base class.

        Raises:
            ValueError: If ``size_arg`` or ``clam_type`` is invalid.
        """
        super().__init__(
            input_dim=input_dim,
            dropout=dropout,
            n_classes=n_classes,
            size_arg=size_arg,
            **kwargs
        )
        # Bug fix: the default loss used to be a default argument
        # (SmoothTop1SVM(...).cuda()), which was constructed once at import
        # time -- forcing a CUDA allocation on import -- and shared across
        # every CLAM instance. Build it lazily per instance instead.
        if instance_loss_fn is None:
            instance_loss_fn = SmoothTop1SVM(n_classes=2)
            if torch.cuda.is_available():
                instance_loss_fn = instance_loss_fn.cuda()
        self.gate = gate
        self.k_sample = k_sample
        self.instance_loss_fn = instance_loss_fn
        self.subtyping = subtyping
        self.clam_type = clam_type
        self.temperature = temperature
        # Validate size_arg
        if isinstance(self.size_arg, str):
            if self.size_arg not in ["small", "big"]:
                raise ValueError("size_arg must be 'small' or 'big' or a list")
        else:
            if len(self.size_arg) == 0 or len(self.size_arg) > 2:
                raise ValueError(
                    "size_arg list must not be empty and must have at most 2 elements"
                )
        # Validate CLAM type
        if self.clam_type not in ["SB", "MB"]:
            raise ValueError(f"Unknown CLAM type: {self.clam_type}")
        # Create the appropriate CLAM network; both variants share the
        # same positional constructor signature.
        net_cls = CLAM_SB if self.clam_type == "SB" else CLAM_MB
        self.net = net_cls(
            self.gate,
            self.size_arg,  # type: ignore
            self.dropout > 0,
            self.k_sample,
            self.n_classes,
            self.instance_loss_fn,
            self.subtyping,
            self.input_dim,
            self.temperature,
        )

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        """Report the CLAM-specific hyperparameters."""
        return {
            "gate": self.gate,
            "k_sample": self.k_sample,
            "instance_loss_fn": self.instance_loss_fn.__class__.__name__,
            "subtyping": self.subtyping,
            "clam_type": self.clam_type,
            "temperature": self.temperature
        }

    def forward(
        self,
        x: torch.Tensor,
        batch: torch.Tensor | None = None,
        label: torch.Tensor | None = None,
        instance_eval: bool = False,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Forward pass for CLAM pooling classifier.

        Args:
            x: Node features tensor of shape (num_nodes, input_dim).
            batch: Batch assignment tensor (not used by CLAM, must be single graph).
            label: Ground truth labels for instance evaluation.
            instance_eval: Whether to perform instance-level evaluation.
            **kwargs: Additional arguments (ignored by CLAM).

        Returns:
            tuple containing:
                - logits: Raw model outputs of shape (1, n_classes)
                - output_dict: Dictionary containing instance-level information

        Raises:
            ValueError: If ``batch`` indicates more than one graph.
        """
        # CLAM doesn't use batch assignment since it handles MIL internally.
        # For batch_size > 1, we'd need to handle this differently, but we
        # assert batch_size=1.
        if batch is not None and batch.max().item() > 0:
            raise ValueError(
                "CLAM pooling classifier requires batch_size=1. Found multiple graphs in batch."
            )
        logits, Y_prob, Y_hat, a, instance_dict = self.net(x, label, instance_eval)
        output_dict: dict[str, Any] = {
            "y_prob": Y_prob,
            "y_hat": Y_hat,
            "attention": a,
            **instance_dict,
        }
        return logits, output_dict

    def get_attention_weights(
        self,
        x: torch.Tensor,
        batch: torch.Tensor | None = None
    ) -> torch.Tensor | None:
        """Extract attention weights from CLAM.

        Args:
            x: Node features tensor.
            batch: Batch assignment tensor (should be single graph).

        Returns:
            Attention weights tensor of shape [1, num_nodes] or
            [num_classes, num_nodes], or None on failure (best-effort).
        """
        try:
            result = self.net.forward(x, attention_only=True, instance_eval=False)  # type: ignore
            if isinstance(result, torch.Tensor):
                return result
        except Exception:
            # Best-effort: some CLAM variants may not support attention_only.
            pass
        return None
class Standard(GlobalPooling_Classifier):
    """Standard MIL pooling classifier."""

    def __init__(
        self,
        input_dim: int,
        dropout: float,
        n_classes: int,
        size_arg: list[int] | str = "small",
        standard_type: str = "fc",
        **kwargs: Any,
    ):
        """Build the Standard MIL pooling classifier.

        Args:
            input_dim: Dimensionality of the node features.
            dropout: Dropout probability passed to the MIL network.
            n_classes: Number of output classes.
            size_arg: Either the preset "small" or a list of 1-2 hidden sizes.
            standard_type: "fc" (binary) or "fc_mc" (multi-class).
            **kwargs: Extra hyperparameters recorded by the base class.

        Raises:
            ValueError: If ``size_arg`` or ``standard_type`` is invalid.
        """
        super().__init__(
            input_dim=input_dim,
            dropout=dropout,
            n_classes=n_classes,
            size_arg=size_arg,
            **kwargs
        )
        self.standard_type = standard_type
        # Validate size_arg
        if isinstance(self.size_arg, str):
            if self.size_arg not in ["small"]:
                # Bug fix: the message previously advertised 'big', which the
                # check above never accepts for Standard.
                raise ValueError("size_arg must be 'small' or a list")
        else:
            if len(self.size_arg) == 0 or len(self.size_arg) > 2:
                raise ValueError(
                    "size_arg list must not be empty and must have at most 2 elements"
                )
        # Validate standard_type
        if self.standard_type not in ["fc", "fc_mc"]:
            raise ValueError(f"Unknown Standard type: {self.standard_type}")
        # Create the appropriate Standard network
        if self.standard_type == "fc":
            self.net = MIL_fc(
                size_arg=self.size_arg,  # type: ignore
                dropout=self.dropout,
                n_classes=self.n_classes,
                embed_dim=self.input_dim,
            )
        elif self.standard_type == "fc_mc":
            self.net = MIL_fc_mc(
                size_arg=self.size_arg,  # type: ignore
                dropout=self.dropout,
                n_classes=self.n_classes,
                embed_dim=self.input_dim,
            )

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        """Report the Standard-specific hyperparameters."""
        return {"standard_type": self.standard_type}

    def forward(
        self,
        x: torch.Tensor,
        batch: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Forward pass for Standard MIL pooling classifier.

        Args:
            x: Node features tensor of shape (num_nodes, input_dim).
            batch: Batch assignment tensor (not used by Standard, must be single graph).
            **kwargs: Additional arguments (ignored by Standard).

        Returns:
            tuple containing:
                - logits: Raw model outputs of shape (1, n_classes)
                - output_dict: Dictionary containing instance-level information

        Raises:
            ValueError: If ``batch`` indicates more than one graph.
        """
        # For batch_size > 1, we'd need to handle this differently, but we
        # assert batch_size=1.
        if batch is not None and batch.max().item() > 0:
            raise ValueError(
                "Standard pooling classifier requires batch_size=1. Found multiple graphs in batch."
            )
        top_instance, Y_prob, Y_hat, y_probs, results_dict = self.net(x)
        # Add consistent attention fields (Standard doesn't use attention)
        output_dict: dict[str, Any] = {
            "y_prob": Y_prob,
            "y_hat": Y_hat,
            "y_probs": y_probs,
            **results_dict,
        }
        return top_instance, output_dict
class Attention(GlobalPooling_Classifier):
    """AttentionDeepMIL pooling classifier."""

    def __init__(
        self,
        input_dim: int,
        dropout: float,
        n_classes: int,
        size_arg: list[int],
        attention_branches: int = 1,
        temperature: float = 1.0,
        **kwargs: Any,
    ):
        """Build the AttentionDeepMIL pooling classifier.

        Args:
            input_dim: Dimensionality of the node features.
            dropout: Dropout probability forwarded to AttentionDeepMIL.
            n_classes: Number of output classes.
            size_arg: Exactly two hidden sizes for the attention network.
            attention_branches: Number of attention branches (>= 1).
            temperature: Attention softmax temperature.
            **kwargs: Extra hyperparameters recorded by the base class.

        Raises:
            ValueError: If ``size_arg`` or ``attention_branches`` is invalid.
        """
        super().__init__(
            input_dim=input_dim,
            dropout=dropout,
            n_classes=n_classes,
            size_arg=size_arg,
            **kwargs
        )
        # Guard clauses: size_arg must be a two-element list and at least
        # one attention branch is required.
        if isinstance(self.size_arg, str):
            raise ValueError("AttentionDeepMIL requires size_arg to be a list of integers")
        if len(size_arg) != 2:
            raise ValueError("size_arg must be a list with exactly 2 elements for AttentionDeepMIL")
        if attention_branches < 1:
            raise ValueError("attention_branches must be a positive integer")
        self.attention_branches = attention_branches
        self.temperature = temperature
        self.net = AttentionDeepMIL(
            size_arg=self.size_arg,
            n_classes=self.n_classes,
            embed_dim=self.input_dim,
            attention_branches=self.attention_branches,
            temperature=self.temperature,
            dropout=self.dropout
        )

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        """Report the Attention-specific hyperparameters."""
        return {
            "attention_branches": self.attention_branches,
            "temperature": self.temperature
        }

    def forward(
        self,
        x: torch.Tensor,
        batch: torch.Tensor | None = None,
        **kwargs: Any,
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Forward pass for AttentionDeepMIL pooling classifier.

        Args:
            x: Node features tensor of shape (num_nodes, input_dim).
            batch: Batch assignment tensor (not used; must describe a single graph).
            **kwargs: Additional arguments (ignored by AttentionDeepMIL).

        Returns:
            tuple containing:
                - logits: Raw model outputs of shape (1, n_classes)
                - output_dict: Instance-level information and attention weights

        Raises:
            ValueError: If ``batch`` indicates more than one graph.
        """
        # AttentionDeepMIL handles the bag internally, so only a single
        # graph per call is supported.
        if batch is not None:
            if batch.max().item() > 0:
                raise ValueError(
                    "AttentionDeepMIL pooling classifier requires batch_size=1. Found multiple graphs in batch."
                )
        return self.net(x)

    def get_attention_weights(
        self,
        x: torch.Tensor,
        batch: torch.Tensor | None = None
    ) -> torch.Tensor | None:
        """Extract attention weights from AttentionDeepMIL.

        Args:
            x: Node features tensor.
            batch: Batch assignment tensor (must be single graph).

        Returns:
            Attention weights tensor of shape [attention_branches, num_nodes],
            or None if extraction fails (best-effort).

        Raises:
            ValueError: If ``batch`` indicates more than one graph.
        """
        if batch is not None and batch.max().item() > 0:
            raise ValueError(
                "AttentionDeepMIL requires batch_size=1. Found multiple graphs in batch."
            )
        try:
            logits, net_outputs = self.net(x)
            logger.info(f"Extracted logits: {logits}")
            return net_outputs.get('attention')
        except Exception:
            # Best-effort: swallow failures and report no attention.
            return None