cellmil.models.mil.graphmil.pool

Classes

Attention(input_dim, dropout, n_classes, ...)

AttentionDeepMIL pooling classifier.

CLAM(input_dim, dropout, n_classes, size_arg)

CLAM pooling classifier with attention-based multiple instance learning.

GlobalPooling_Classifier(input_dim, dropout, ...)

Abstract base class for global pooling classifiers in GraphMIL.

Mean_MLP(input_dim, dropout, n_classes, ...)

Mean pooling followed by MLP classifier.

Standard(input_dim, dropout, n_classes[, ...])

Standard MIL pooling classifier.

class cellmil.models.mil.graphmil.pool.GlobalPooling_Classifier(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str, **kwargs: Any)

Bases: Module, ABC

Abstract base class for global pooling classifiers in GraphMIL.

This class defines the interface for pooling classifiers that aggregate node features into graph-level predictions. Each subclass handles its own specific arguments and validates them appropriately.

__init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str, **kwargs: Any)

Initialize the pooling classifier with the input feature dimension, dropout rate, number of output classes, and size configuration (size_arg).

get_hyperparameters() → dict[str, Any]

Get all hyperparameters for this pooling classifier.

_get_specific_hyperparameters() → dict[str, Any]

Override in subclasses to add specific hyperparameters.
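The two hyperparameter methods above suggest a template-method pattern: the base class reports the shared constructor arguments and each subclass contributes its own through the hook. The sketch below uses hypothetical stand-in classes (HyperparamSketch, AttentionHyperparamSketch), not the cellmil implementation, and it assumes that subclass entries are merged on top of the shared ones, which this page does not state.

```python
from __future__ import annotations

from typing import Any


class HyperparamSketch:
    """Hypothetical stand-in for the hyperparameter hooks of GlobalPooling_Classifier."""

    def __init__(self, input_dim: int, dropout: float, n_classes: int,
                 size_arg: list[int] | str) -> None:
        self.input_dim = input_dim
        self.dropout = dropout
        self.n_classes = n_classes
        self.size_arg = size_arg

    def get_hyperparameters(self) -> dict[str, Any]:
        # Arguments shared by every pooling classifier.
        params: dict[str, Any] = {
            "input_dim": self.input_dim,
            "dropout": self.dropout,
            "n_classes": self.n_classes,
            "size_arg": self.size_arg,
        }
        # Assumption: subclass-specific entries are merged on top of the shared ones.
        params.update(self._get_specific_hyperparameters())
        return params

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        # Default hook: no extra hyperparameters.
        return {}


class AttentionHyperparamSketch(HyperparamSketch):
    """Hypothetical subclass mirroring the extra arguments of Attention."""

    def __init__(self, *args: Any, attention_branches: int = 1,
                 temperature: float = 1.0) -> None:
        super().__init__(*args)
        self.attention_branches = attention_branches
        self.temperature = temperature

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        return {"attention_branches": self.attention_branches,
                "temperature": self.temperature}


hp = AttentionHyperparamSketch(512, 0.25, 2, [256, 128],
                               attention_branches=2).get_hyperparameters()
print(hp)
```

Keeping the shared keys in the base class means a subclass only has to describe what it adds, so a logged hyperparameter dictionary stays complete even for new pooling types.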

abstract forward(x: Tensor, batch: Optional[Tensor] = None, **kwargs: Any) → tuple[torch.Tensor, dict[str, torch.Tensor]]

Forward pass for the pooling classifier.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor for global pooling (if applicable)

  • **kwargs – Additional arguments specific to each pooling classifier

Returns:

A tuple (logits, output_dict) containing:

  • logits: Raw (unnormalized) model outputs of shape (1, n_classes)

  • output_dict: Dictionary containing instance-level information

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]
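A minimal sketch of the forward() contract, assuming PyTorch: a hypothetical abstract stand-in (PoolingClassifierSketch) declares forward() abstract, a toy subclass (MaxPoolSketch) max-pools the node features into a single (1, input_dim) vector, and the caller unpacks (logits, output_dict) and applies a softmax. The max-pooling head and the "pooled" dictionary key are illustrative only.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Optional

import torch
from torch import Tensor, nn


class PoolingClassifierSketch(nn.Module, ABC):
    """Hypothetical stand-in mirroring the documented forward() contract."""

    @abstractmethod
    def forward(self, x: Tensor, batch: Optional[Tensor] = None,
                **kwargs: Any) -> tuple[Tensor, dict[str, Tensor]]:
        """Return (logits of shape (1, n_classes), output_dict)."""


class MaxPoolSketch(PoolingClassifierSketch):
    """Toy subclass: max-pool node features, then apply a linear head."""

    def __init__(self, input_dim: int, n_classes: int) -> None:
        super().__init__()
        self.head = nn.Linear(input_dim, n_classes)

    def forward(self, x: Tensor, batch: Optional[Tensor] = None,
                **kwargs: Any) -> tuple[Tensor, dict[str, Tensor]]:
        pooled = x.max(dim=0, keepdim=True).values  # (1, input_dim)
        logits = self.head(pooled)                  # (1, n_classes)
        return logits, {"pooled": pooled}


torch.manual_seed(0)
model = MaxPoolSketch(input_dim=16, n_classes=3)
logits, output_dict = model(torch.randn(50, 16))  # 50 nodes, one graph
probs = torch.softmax(logits, dim=-1)             # class probabilities
```

Because logits are returned unnormalized, the caller chooses the normalization (softmax here) or passes them straight to a loss such as cross-entropy.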

get_attention_weights(x: Tensor, batch: Optional[Tensor] = None) → torch.Tensor | None

Extract attention weights from the pooling classifier.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor (if applicable)

Returns:

Attention weights tensor, or None if this classifier does not compute attention weights
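Because get_attention_weights() may return None, calling code can branch on that before using the weights. The helper below (top_attended_nodes) is hypothetical; it assumes a 2-D weight tensor whose last dimension indexes nodes, as documented for CLAM and Attention on this page, and averages over the first dimension before picking the top-k nodes.

```python
from __future__ import annotations

from typing import Optional

import torch
from torch import Tensor


def top_attended_nodes(weights: Optional[Tensor], k: int) -> Optional[Tensor]:
    """Hypothetical helper: indices of the k most-attended nodes, or None."""
    if weights is None:
        # Classifiers without attention have nothing to rank.
        return None
    scores = weights.mean(dim=0)  # collapse branches/classes -> (num_nodes,)
    k = min(k, scores.numel())    # guard against k > num_nodes
    return torch.topk(scores, k).indices  # sorted by descending score


weights = torch.tensor([[0.1, 0.6, 0.3]])  # (1, num_nodes) toy example
print(top_attended_nodes(weights, 2))      # tensor([1, 2])
print(top_attended_nodes(None, 2))         # None
```

Averaging rows is one simple choice; for per-class attention you may prefer to select the row of the predicted class instead.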

class cellmil.models.mil.graphmil.pool.Mean_MLP(input_dim: int, dropout: float, n_classes: int, size_arg: list[int], **kwargs: Any)

Bases: GlobalPooling_Classifier

Mean pooling followed by MLP classifier.

__init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int], **kwargs: Any)

Initialize the Mean_MLP pooling classifier.

forward(x: Tensor, batch: Optional[Tensor] = None, **kwargs: Any) → tuple[torch.Tensor, dict[str, torch.Tensor]]

Forward pass for Mean_MLP pooling classifier.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor for global pooling

  • **kwargs – Additional arguments (ignored by Mean_MLP)

Returns:

A tuple (logits, output_dict) containing:

  • logits: Raw (unnormalized) model outputs of shape (1, n_classes)

  • output_dict: Dictionary containing instance-level information

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]
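A minimal sketch of mean pooling followed by an MLP, assuming PyTorch. It assumes, without confirmation from this page, that size_arg lists the MLP's hidden-layer widths and that batch follows the common PyTorch Geometric convention where batch[i] is the graph index of node i. With a batch tensor the sketch returns one row of logits per graph; the documented shape (1, n_classes) corresponds to the single-graph case. The names mean_pool and MeanMLPSketch are hypothetical.

```python
from __future__ import annotations

from typing import Optional

import torch
from torch import Tensor, nn


def mean_pool(x: Tensor, batch: Optional[Tensor] = None) -> Tensor:
    """Average node features per graph; batch[i] is the graph index of node i."""
    if batch is None:
        return x.mean(dim=0, keepdim=True)  # single graph -> (1, input_dim)
    num_graphs = int(batch.max().item()) + 1
    sums = torch.zeros(num_graphs, x.size(1), dtype=x.dtype, device=x.device)
    sums.index_add_(0, batch, x)  # sum node features into their graph's row
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).to(x.dtype)
    return sums / counts.unsqueeze(1)  # (num_graphs, input_dim)


class MeanMLPSketch(nn.Module):
    """Hypothetical: size_arg is read as the MLP's hidden-layer widths."""

    def __init__(self, input_dim: int, dropout: float, n_classes: int,
                 size_arg: list[int]) -> None:
        super().__init__()
        layers: list[nn.Module] = []
        in_dim = input_dim
        for hidden in size_arg:
            layers += [nn.Linear(in_dim, hidden), nn.ReLU(), nn.Dropout(dropout)]
            in_dim = hidden
        layers.append(nn.Linear(in_dim, n_classes))  # output layer
        self.mlp = nn.Sequential(*layers)

    def forward(self, x: Tensor,
                batch: Optional[Tensor] = None) -> tuple[Tensor, dict[str, Tensor]]:
        pooled = mean_pool(x, batch)
        return self.mlp(pooled), {"pooled": pooled}


x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [10.0, 20.0]])
batch = torch.tensor([0, 0, 1])  # nodes 0-1 form graph 0, node 2 forms graph 1
model = MeanMLPSketch(input_dim=2, dropout=0.1, n_classes=3, size_arg=[8, 4]).eval()
print(mean_pool(x, batch))  # per-graph means [[2, 3], [10, 20]]
```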

class cellmil.models.mil.graphmil.pool.CLAM(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str, gate: bool = True, k_sample: int = 8, instance_loss_fn: Module = SmoothTop1SVM(), subtyping: bool = False, clam_type: str = 'SB', temperature: float = 1.0, **kwargs: Any)

Bases: GlobalPooling_Classifier

CLAM pooling classifier with attention-based multiple instance learning.

__init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str, gate: bool = True, k_sample: int = 8, instance_loss_fn: Module = SmoothTop1SVM(), subtyping: bool = False, clam_type: str = 'SB', temperature: float = 1.0, **kwargs: Any)

Initialize the CLAM pooling classifier.

forward(x: Tensor, batch: Optional[Tensor] = None, label: Optional[Tensor] = None, instance_eval: bool = False, **kwargs: Any) → tuple[torch.Tensor, dict[str, torch.Tensor]]

Forward pass for CLAM pooling classifier.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor (not used by CLAM; the input must be a single graph)

  • label – Ground truth labels for instance evaluation

  • instance_eval – Whether to perform instance-level evaluation

  • **kwargs – Additional arguments (ignored by CLAM)

Returns:

A tuple (logits, output_dict) containing:

  • logits: Raw (unnormalized) model outputs of shape (1, n_classes)

  • output_dict: Dictionary containing instance-level information

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]
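The sketch below shows the gated-attention pooling that CLAM's single-branch form is built on, assuming PyTorch. It is a simplified stand-in, not the cellmil implementation: it assumes temperature divides the attention scores before the softmax, and it only selects the k_sample highest- and lowest-attended nodes (the candidates that CLAM's instance-level evaluation trains an auxiliary classifier on) without computing the instance loss itself. GatedAttentionSketch and its dictionary keys are hypothetical.

```python
from __future__ import annotations

import torch
from torch import Tensor, nn


class GatedAttentionSketch(nn.Module):
    """Hypothetical single-branch gated-attention MIL pooling (CLAM-SB style)."""

    def __init__(self, input_dim: int, hidden_dim: int, n_classes: int,
                 k_sample: int = 8, temperature: float = 1.0) -> None:
        super().__init__()
        self.attn_a = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.Tanh())
        self.attn_b = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.Sigmoid())
        self.attn_c = nn.Linear(hidden_dim, 1)
        self.classifier = nn.Linear(input_dim, n_classes)
        self.k_sample = k_sample
        self.temperature = temperature

    def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Tensor]]:
        # Gated attention: tanh branch modulated elementwise by a sigmoid gate.
        scores = self.attn_c(self.attn_a(x) * self.attn_b(x)).T   # (1, num_nodes)
        weights = torch.softmax(scores / self.temperature, dim=1)  # row sums to 1
        bag = weights @ x                                          # (1, input_dim)
        logits = self.classifier(bag)                              # (1, n_classes)
        # Candidate nodes for instance-level evaluation.
        k = min(self.k_sample, x.size(0))
        top_ids = torch.topk(scores[0], k).indices      # highest attention
        bottom_ids = torch.topk(-scores[0], k).indices  # lowest attention
        return logits, {"attention": weights, "top_ids": top_ids,
                        "bottom_ids": bottom_ids}


torch.manual_seed(0)
clam_like = GatedAttentionSketch(input_dim=16, hidden_dim=8, n_classes=2, k_sample=4)
logits, out = clam_like(torch.randn(30, 16))
```

The sigmoid gate lets the model suppress nodes whose tanh features would otherwise score highly, which is the motivation for gated attention over the plain tanh form.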

get_attention_weights(x: Tensor, batch: Optional[Tensor] = None) → torch.Tensor | None

Extract attention weights from CLAM.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor (the input must be a single graph)

Returns:

Attention weights tensor of shape (1, num_nodes) or (n_classes, num_nodes)
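Since CLAM's attention tensor can have either one row or one row per class, code that overlays the weights on a graph needs to pick a row first. The hypothetical helper below (class_attention_map) assumes that a single row is shared by all classes, as in CLAM's single-branch variant that the default clam_type='SB' suggests, and that multi-row tensors are indexed by class; it then min-max scales the row to [0, 1] for visualization.

```python
import torch
from torch import Tensor


def class_attention_map(weights: Tensor, class_idx: int) -> Tensor:
    """Hypothetical helper: per-node attention for one class, scaled to [0, 1]."""
    # A single row is shared by all classes; otherwise pick the class's row.
    row = weights[0] if weights.size(0) == 1 else weights[class_idx]
    span = row.max() - row.min()
    if span == 0:
        # Uniform attention: return zeros rather than dividing by zero.
        return torch.zeros_like(row)
    return (row - row.min()) / span


single = torch.tensor([[0.2, 0.5, 0.3]])     # shape (1, num_nodes)
per_class = torch.tensor([[0.1, 0.1, 0.8],
                          [0.6, 0.3, 0.1]])  # shape (n_classes, num_nodes)
print(class_attention_map(single, class_idx=1))
print(class_attention_map(per_class, class_idx=1))
```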

class cellmil.models.mil.graphmil.pool.Standard(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str = 'small', standard_type: str = 'fc', **kwargs: Any)

Bases: GlobalPooling_Classifier

Standard MIL pooling classifier.

__init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int] | str = 'small', standard_type: str = 'fc', **kwargs: Any)

Initialize the Standard MIL pooling classifier.

forward(x: Tensor, batch: Optional[Tensor] = None, **kwargs: Any) → tuple[torch.Tensor, dict[str, torch.Tensor]]

Forward pass for Standard MIL pooling classifier.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor (not used by Standard; the input must be a single graph)

  • **kwargs – Additional arguments (ignored by Standard)

Returns:

A tuple (logits, output_dict) containing:

  • logits: Raw (unnormalized) model outputs of shape (1, n_classes)

  • output_dict: Dictionary containing instance-level information

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]
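This page does not say what standard_type='fc' computes. One common "standard" MIL baseline, similar to the MIL_fc models in the original CLAM repository, scores every node with a small fully connected network and lets the single highest-scoring instance decide the bag prediction. The PyTorch sketch below implements that idea under that assumption; StandardMILSketch, select_top_instance, and the hidden size are hypothetical.

```python
from __future__ import annotations

import torch
from torch import Tensor, nn


def select_top_instance(instance_logits: Tensor) -> int:
    """Index of the instance holding the single highest class probability."""
    probs = torch.softmax(instance_logits, dim=1)  # (num_nodes, n_classes)
    flat_idx = int(torch.argmax(probs.flatten()))  # over all (node, class) pairs
    return flat_idx // probs.size(1)               # recover the node index


class StandardMILSketch(nn.Module):
    """Hypothetical instance-level MIL: per-node classifier + top-1 instance pooling."""

    def __init__(self, input_dim: int, hidden_dim: int, dropout: float,
                 n_classes: int) -> None:
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU(),
                                nn.Dropout(dropout))
        self.classifier = nn.Linear(hidden_dim, n_classes)

    def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Tensor]]:
        instance_logits = self.classifier(self.fc(x))  # (num_nodes, n_classes)
        top = select_top_instance(instance_logits.detach())
        logits = instance_logits[top : top + 1]        # (1, n_classes)
        return logits, {"instance_logits": instance_logits}


toy_logits = torch.tensor([[0.0, 0.0], [5.0, 0.0], [0.0, 1.0]])
print(select_top_instance(toy_logits))  # 1

torch.manual_seed(0)
model = StandardMILSketch(input_dim=16, hidden_dim=8, dropout=0.25, n_classes=3)
logits, out = model(torch.randn(12, 16))
```

Unlike attention pooling, this baseline keeps no weights over nodes, which is consistent with Standard not documenting a get_attention_weights() override.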

class cellmil.models.mil.graphmil.pool.Attention(input_dim: int, dropout: float, n_classes: int, size_arg: list[int], attention_branches: int = 1, temperature: float = 1.0, **kwargs: Any)

Bases: GlobalPooling_Classifier

AttentionDeepMIL pooling classifier.

__init__(input_dim: int, dropout: float, n_classes: int, size_arg: list[int], attention_branches: int = 1, temperature: float = 1.0, **kwargs: Any)

Initialize the AttentionDeepMIL pooling classifier.

forward(x: Tensor, batch: Optional[Tensor] = None, **kwargs: Any) → tuple[torch.Tensor, dict[str, torch.Tensor]]

Forward pass for AttentionDeepMIL pooling classifier.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor (not used by AttentionDeepMIL; the input must be a single graph)

  • **kwargs – Additional arguments (ignored by AttentionDeepMIL)

Returns:

A tuple (logits, output_dict) containing:

  • logits: Raw (unnormalized) model outputs of shape (1, n_classes)

  • output_dict: Dictionary containing instance-level information and attention weights

Return type:

tuple[torch.Tensor, dict[str, torch.Tensor]]
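Attention-based deep MIL (Ilse et al., 2018) learns one softmax-normalized weight per node and classifies the weighted sum of node features. The PyTorch sketch below follows the multi-branch form of that model, in which each of the attention_branches rows yields its own bag embedding and the embeddings are concatenated before the classifier. It assumes that temperature divides the scores before the softmax and that the weights are returned under an "attention" key; AttentionMILSketch and that key are hypothetical, not taken from cellmil.

```python
from __future__ import annotations

import torch
from torch import Tensor, nn


class AttentionMILSketch(nn.Module):
    """Hypothetical attention-based deep MIL pooling with several attention branches."""

    def __init__(self, input_dim: int, hidden_dim: int, n_classes: int,
                 attention_branches: int = 1, temperature: float = 1.0,
                 dropout: float = 0.0) -> None:
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.Tanh(),
            nn.Linear(hidden_dim, attention_branches),  # one score per branch
        )
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(input_dim * attention_branches, n_classes),
        )
        self.temperature = temperature

    def forward(self, x: Tensor) -> tuple[Tensor, dict[str, Tensor]]:
        scores = self.attention(x).T                                # (branches, num_nodes)
        weights = torch.softmax(scores / self.temperature, dim=1)  # each row sums to 1
        bags = weights @ x                                          # (branches, input_dim)
        logits = self.classifier(bags.reshape(1, -1))               # (1, n_classes)
        return logits, {"attention": weights}


torch.manual_seed(0)
attn_mil = AttentionMILSketch(input_dim=16, hidden_dim=8, n_classes=2, attention_branches=3)
logits, out = attn_mil(torch.randn(20, 16))
```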

get_attention_weights(x: Tensor, batch: Optional[Tensor] = None) → torch.Tensor | None

Extract attention weights from AttentionDeepMIL.

Parameters:
  • x – Node features tensor of shape (num_nodes, input_dim)

  • batch – Batch assignment tensor (the input must be a single graph)

Returns:

Attention weights tensor of shape (attention_branches, num_nodes)
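Both CLAM and Attention take a temperature argument. Assuming it is applied the usual way, by dividing attention scores before the softmax, a larger temperature spreads each row of the attention weights more evenly across nodes and a smaller one concentrates it. The hypothetical helper below (attention_entropy) measures this with the per-row entropy of the weights.

```python
import math

import torch
from torch import Tensor


def attention_entropy(scores: Tensor, temperature: float) -> Tensor:
    """Per-row entropy (in nats) of softmax(scores / temperature)."""
    log_w = torch.log_softmax(scores / temperature, dim=1)  # numerically stable log-weights
    return -(log_w.exp() * log_w).sum(dim=1)                # shape (branches,)


scores = torch.tensor([[1.0, 2.0, 3.0, 4.0]])  # one branch, four nodes
for t in (0.5, 1.0, 5.0):
    print(t, attention_entropy(scores, t).item())
# Entropy rises with temperature and is bounded above by log(num_nodes).
print(math.log(scores.size(1)))
```

Low entropy means a few nodes dominate the pooled representation; entropy close to log(num_nodes) means the pooling behaves almost like a plain mean.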