cellmil.models.mil.graphmil.pool
Classes
Attention | AttentionDeepMIL pooling classifier.
CLAM | CLAM pooling classifier with attention-based multiple instance learning.
GlobalPooling_Classifier | Abstract base class for global pooling classifiers in GraphMIL.
Mean_MLP | Mean pooling followed by MLP classifier.
Standard | Standard MIL pooling classifier.
- class cellmil.models.mil.graphmil.pool.GlobalPooling_Classifier(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str, **kwargs: Any)
Abstract base class for global pooling classifiers in GraphMIL.
This class defines the interface for pooling classifiers that aggregate node features into graph-level predictions. Each subclass handles its own specific arguments and validates them appropriately.
- __init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str, **kwargs: Any)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _get_specific_hyperparameters() -> dict[str, Any]
Override in subclasses to add specific hyperparameters.
- abstract forward(x: Tensor, batch: Optional[Tensor] = None, **kwargs: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]
Forward pass for the pooling classifier.
- Parameters:
x – Node features tensor of shape (num_nodes, input_dim)
batch – Batch assignment tensor for global pooling (if applicable)
**kwargs – Additional arguments specific to each pooling classifier
- Returns:
A tuple (logits, output_dict):
logits: Raw model outputs of shape (1, n_classes)
output_dict: Dictionary containing instance-level information
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor]]
- class cellmil.models.mil.graphmil.pool.Mean_MLP(input_dim: int, dropout: float, n_classes: int, size_arg: list[int], **kwargs: Any)
Bases: GlobalPooling_Classifier
Mean pooling followed by MLP classifier.
- __init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int], **kwargs: Any)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, batch: Optional[Tensor] = None, **kwargs: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]
Forward pass for Mean_MLP pooling classifier.
- Parameters:
x – Node features tensor of shape (num_nodes, input_dim)
batch – Batch assignment tensor for global pooling
**kwargs – Additional arguments (ignored by Mean_MLP)
- Returns:
A tuple (logits, output_dict):
logits: Raw model outputs of shape (1, n_classes)
output_dict: Dictionary containing instance-level information
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor]]
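To illustrate what the batch assignment tensor means for mean pooling, here is a minimal sketch written in plain Python lists rather than PyTorch so it runs without dependencies. The helper `mean_pool` is hypothetical and is not part of cellmil; it only assumes that `batch` holds one graph index per node and that omitting it means all nodes belong to a single graph.

```python
def mean_pool(x, batch=None):
    """Average node feature rows per graph.

    x:     list of feature rows, shape (num_nodes, input_dim).
    batch: optional list with one graph index per node.
           None is treated as "all nodes belong to graph 0" (an assumption).
    Returns a list of pooled rows, shape (num_graphs, input_dim).
    """
    if not x:
        raise ValueError("x must contain at least one node")
    if batch is None:
        batch = [0] * len(x)
    if len(batch) != len(x):
        raise ValueError("batch must have one entry per node")

    num_graphs = max(batch) + 1
    dim = len(x[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs

    # Accumulate per-graph feature sums and node counts.
    for row, graph_id in zip(x, batch):
        counts[graph_id] += 1
        for j, value in enumerate(row):
            sums[graph_id][j] += value

    # Divide by the node count; max(..., 1) guards graphs with no nodes.
    return [[s / max(c, 1) for s in row] for row, c in zip(sums, counts)]


nodes = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
pooled_single = mean_pool(nodes)                     # one graph
pooled_batched = mean_pool(nodes, batch=[0, 0, 1])   # two graphs
```

In this sketch the pooled vector would then be passed to an MLP to produce the logits; the MLP itself is omitted here.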
- class cellmil.models.mil.graphmil.pool.CLAM(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str, gate: bool = True, k_sample: int = 8, instance_loss_fn: Module = SmoothTop1SVM(), subtyping: bool = False, clam_type: str = 'SB', temperature: float = 1.0, **kwargs: Any)
Bases: GlobalPooling_Classifier
CLAM pooling classifier with attention-based multiple instance learning.
- __init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str, gate: bool = True, k_sample: int = 8, instance_loss_fn: Module = SmoothTop1SVM(), subtyping: bool = False, clam_type: str = 'SB', temperature: float = 1.0, **kwargs: Any)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, batch: Optional[Tensor] = None, label: Optional[Tensor] = None, instance_eval: bool = False, **kwargs: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]
Forward pass for CLAM pooling classifier.
- Parameters:
x – Node features tensor of shape (num_nodes, input_dim)
batch – Batch assignment tensor (not used by CLAM, must be single graph)
label – Ground truth labels for instance evaluation
instance_eval – Whether to perform instance-level evaluation
**kwargs – Additional arguments (ignored by CLAM)
- Returns:
A tuple (logits, output_dict):
logits: Raw model outputs of shape (1, n_classes)
output_dict: Dictionary containing instance-level information
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor]]
- get_attention_weights(x: Tensor, batch: Optional[Tensor] = None) -> torch.Tensor | None
Extract attention weights from CLAM.
- Parameters:
x – Node features tensor
batch – Batch assignment tensor (should be single graph)
- Returns:
Attention weights tensor of shape [1, num_nodes] or [num_classes, num_nodes]
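The `k_sample` and `instance_eval` parameters suggest that instance-level evaluation selects a fixed number of nodes by attention score. The sketch below shows one plausible reading, following the published CLAM method: the `k_sample` highest-attention nodes serve as pseudo-positive instances and the `k_sample` lowest-attention nodes as pseudo-negative ones. This is an assumption about cellmil's behaviour, and `sample_instances` is a hypothetical helper written in plain Python, not a cellmil function.

```python
def sample_instances(attention, k_sample=8):
    """Pick the indices of the k highest and k lowest attention scores.

    attention: list of per-node attention scores for a single graph.
    k_sample:  number of instances to take from each end (clamped to
               the number of nodes).
    Returns (top_ids, bottom_ids); bottom_ids are ordered lowest first.
    """
    if k_sample <= 0:
        raise ValueError("k_sample must be positive")
    k = min(k_sample, len(attention))

    # Node indices sorted from highest to lowest attention.
    order = sorted(range(len(attention)), key=lambda i: attention[i], reverse=True)

    top_ids = order[:k]
    bottom_ids = order[-k:][::-1]
    return top_ids, bottom_ids


attn = [0.05, 0.40, 0.10, 0.30, 0.15]
top_ids, bottom_ids = sample_instances(attn, k_sample=2)
```

In the CLAM formulation the selected instances are then scored by per-class instance classifiers and compared against pseudo-labels using the instance loss (here `instance_loss_fn`); that step is not shown.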
- class cellmil.models.mil.graphmil.pool.Standard(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str = 'small', standard_type: str = 'fc', **kwargs: Any)
Bases: GlobalPooling_Classifier
Standard MIL pooling classifier.
- __init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str = 'small', standard_type: str = 'fc', **kwargs: Any)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, batch: Optional[Tensor] = None, **kwargs: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]
Forward pass for Standard MIL pooling classifier.
- Parameters:
x – Node features tensor of shape (num_nodes, input_dim)
batch – Batch assignment tensor (not used by Standard, must be single graph)
**kwargs – Additional arguments (ignored by Standard)
- Returns:
A tuple (logits, output_dict):
logits: Raw model outputs of shape (1, n_classes)
output_dict: Dictionary containing instance-level information
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor]]
- class cellmil.models.mil.graphmil.pool.Attention(input_dim: int, dropout: float, n_classes: int, size_arg: list[int], attention_branches: int = 1, temperature: float = 1.0, **kwargs: Any)
Bases: GlobalPooling_Classifier
AttentionDeepMIL pooling classifier.
- __init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int], attention_branches: int = 1, temperature: float = 1.0, **kwargs: Any)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, batch: Optional[Tensor] = None, **kwargs: Any) -> tuple[torch.Tensor, dict[str, torch.Tensor]]
Forward pass for AttentionDeepMIL pooling classifier.
- Parameters:
x – Node features tensor of shape (num_nodes, input_dim)
batch – Batch assignment tensor (not used by AttentionDeepMIL, must be single graph)
**kwargs – Additional arguments (ignored by AttentionDeepMIL)
- Returns:
A tuple (logits, output_dict):
logits: Raw model outputs of shape (1, n_classes)
output_dict: Dictionary containing instance-level information and attention weights
- Return type:
tuple[torch.Tensor, dict[str, torch.Tensor]]
- get_attention_weights(x: Tensor, batch: Optional[Tensor] = None) -> torch.Tensor | None
Extract attention weights from AttentionDeepMIL.
- Parameters:
x – Node features tensor
batch – Batch assignment tensor (must be single graph)
- Returns:
Attention weights tensor of shape [attention_branches, num_nodes]
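As a rough illustration of how attention weights of shape [attention_branches, num_nodes] can arise, the sketch below normalises raw per-node scores with a softmax over nodes and uses them for a weighted sum of node features. It assumes that `temperature` divides the scores before the softmax, which is a common convention but is not confirmed by this page. Both helpers are hypothetical, written in plain Python, and are not part of cellmil.

```python
import math


def attention_weights(scores, temperature=1.0):
    """Softmax each branch's raw scores over nodes.

    scores:      list of branches, each a list of per-node raw scores,
                 shape (attention_branches, num_nodes).
    temperature: assumed to divide the scores before the softmax;
                 values below 1 sharpen the weights, values above 1 flatten them.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    weights = []
    for branch in scores:
        scaled = [s / temperature for s in branch]
        peak = max(scaled)  # subtract the max for numerical stability
        exps = [math.exp(s - peak) for s in scaled]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


def attention_pool(x, weights):
    """Weighted sum of node features per branch.

    x:       node features, shape (num_nodes, input_dim).
    weights: attention weights, shape (attention_branches, num_nodes).
    Returns a list of shape (attention_branches, input_dim).
    """
    dim = len(x[0])
    return [
        [sum(w * row[j] for w, row in zip(branch, x)) for j in range(dim)]
        for branch in weights
    ]


raw_scores = [[0.0, 1.0, 2.0]]  # one branch, three nodes
sharp = attention_weights(raw_scores, temperature=0.5)
flat = attention_weights(raw_scores, temperature=5.0)
bag = attention_pool([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], sharp)
```

Each row of the result sums to 1 over the nodes, so the weights can be read as the relative contribution of every node to the bag-level representation.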