cellmil.models.mil.graphmil.gnn

Classes

CHIMERA(input_dim, dropout[, heads, ...])

CHIMERA_block(in_channels, out_channels, ...)

EGNN(input_dim, hidden_dim, n_layers, ...)

GAT(input_dim, hidden_dim, n_layers, dropout)

GATv2(input_dim, hidden_dim, n_layers, ...)

GNN(input_dim, hidden_dim, n_layers, ...)

Abstract base class for Graph Neural Networks.

SAGE(input_dim, hidden_dim, n_layers, ...)

SGFormer(input_dim, hidden_dim, n_layers[, ...])

SGFormerAttention(channels[, heads, ...])

Implements the logic from the SGFormer paper's official repository (function full_attention_conv), adapted for non-batched input and numerical stability.

SmallWorld(input_dim, hidden_dim, n_layers, ...)

SmallWorld GNN that creates additional connections between high-attention nodes.

class cellmil.models.mil.graphmil.gnn.GNN(input_dim: int, hidden_dim: int | list[int], n_layers: int, dropout: float, **kwargs: Any)[source]

Bases: Module, ABC

Abstract base class for Graph Neural Networks.

This class defines the interface for GNN models but cannot be instantiated directly. Subclasses must implement the layer creation logic in their __init__ method.

__init__(input_dim: int, hidden_dim: int | list[int], n_layers: int, dropout: float, **kwargs: Any)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

get_hyperparameters() dict[str, Any][source]

Get all hyperparameters for this GNN.

_get_specific_hyperparameters() dict[str, Any][source]

Override in subclasses to add subclass-specific hyperparameters; the base implementation returns an empty dictionary.
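
The two methods above follow a template-method pattern: `get_hyperparameters` collects the shared constructor arguments and merges in whatever `_get_specific_hyperparameters` returns. A minimal sketch of that pattern (the attribute names and the exact dictionary contents are assumptions for illustration, not taken from the source):

```python
from typing import Any


class BaseSketch:
    """Illustrative stand-in for GNN's hyperparameter collection."""

    def __init__(self, input_dim: int, hidden_dim: int, n_layers: int, dropout: float):
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.dropout = dropout

    def get_hyperparameters(self) -> dict[str, Any]:
        # Shared hyperparameters, plus anything a subclass contributes.
        params = {
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "n_layers": self.n_layers,
            "dropout": self.dropout,
        }
        params.update(self._get_specific_hyperparameters())
        return params

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        # Base class adds nothing extra; subclasses override this hook.
        return {}


class GATSketch(BaseSketch):
    """A subclass with one extra hyperparameter, mirroring GAT's `heads`."""

    def __init__(self, *args: Any, heads: int = 1, **kwargs: Any):
        super().__init__(*args, **kwargs)
        self.heads = heads

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        return {"heads": self.heads}
```

This keeps logging and checkpoint metadata uniform across GNN variants: callers only ever invoke `get_hyperparameters`, and each subclass declares just its own additions.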

forward(data: Data) Data[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

get_attention_weights(data: Data) dict[str, torch.Tensor][source]

Extract attention weights from GNN layers.

Parameters:

data – Input graph data

Returns:

Dictionary with attention weights from each layer (empty for non-attention GNNs)

class cellmil.models.mil.graphmil.gnn.GAT(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, heads: int = 1, **kwargs: Any)[source]

Bases: GNN

__init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, heads: int = 1, **kwargs: Any)[source]

get_attention_weights(data: Data) dict[str, torch.Tensor][source]

Extract attention weights from each GAT layer.

Parameters:

data – Input graph data

Returns:

{'gnn_attention_layer_{i}': weights}

Return type:

Dictionary with attention weights

class cellmil.models.mil.graphmil.gnn.GATv2(input_dim: int, hidden_dim: int | list[int], n_layers: int, dropout: float, **kwargs: Any)[source]

Bases: GNN

__init__(input_dim: int, hidden_dim: int | list[int], n_layers: int, dropout: float, **kwargs: Any)[source]

get_attention_weights(data: Data) dict[str, torch.Tensor][source]

Extract attention weights from each GATv2 layer.

Parameters:

data – Input graph data

Returns:

{'gnn_attention_layer_{i}': weights}

Return type:

Dictionary with attention weights

class cellmil.models.mil.graphmil.gnn.SAGE(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, **kwargs: Any)[source]

Bases: GNN

__init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, **kwargs: Any)[source]

class cellmil.models.mil.graphmil.gnn.EGNN(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, **kwargs: Any)[source]

Bases: GNN

__init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, **kwargs: Any)[source]

forward(data: Data) Data[source]

class cellmil.models.mil.graphmil.gnn.CHIMERA(input_dim: int, dropout: float, heads: int = 1, residual: bool = True, n_layers: int = 3, hidden_dim: list[int] = [128, 256, 512], **kwargs: Any)[source]

Bases: GNN

__init__(input_dim: int, dropout: float, heads: int = 1, residual: bool = True, n_layers: int = 3, hidden_dim: list[int] = [128, 256, 512], **kwargs: Any)[source]

forward(data: Data) Data[source]

class cellmil.models.mil.graphmil.gnn.CHIMERA_block(in_channels: int, out_channels: int, heads: int, dropout: float, residual: bool = True)[source]

Bases: Module

__init__(in_channels: int, out_channels: int, heads: int, dropout: float, residual: bool = True)[source]

forward(data: Data) Data[source]

class cellmil.models.mil.graphmil.gnn.SmallWorld(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, top_k_ratio: float = 0.005, heads: int = 1, layer_type: Literal['GCN', 'GAT', 'SAGE'] = 'SAGE', **kwargs: Any)[source]

Bases: GNN

SmallWorld GNN that creates additional connections between high-attention nodes.

After each layer, uses SAGPooling to identify important nodes (the top fraction given by top_k_ratio, 0.5% by default) and creates additional edges between them, forming a small-world-like topology in which important nodes are more densely connected.

__init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, top_k_ratio: float = 0.005, heads: int = 1, layer_type: Literal['GCN', 'GAT', 'SAGE'] = 'SAGE', **kwargs: Any)[source]

_add_small_world_edges(x: Tensor, edge_index: Tensor, batch: Tensor, attention_pool: SAGPooling) Tensor[source]

Add edges between top-k nodes based on attention scores.

Parameters:
  • x – Node features

  • edge_index – Current edge index

  • batch – Batch assignment for nodes

  • attention_pool – SAGPooling layer to compute attention scores

Returns:

Updated edge index with additional small-world connections
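
The edge-construction idea can be sketched in a few lines of plain torch: rank nodes by an attention score, take the top-k, and densely connect them. The function below is a hypothetical, simplified version for a single graph; the real method additionally computes the scores with SAGPooling and respects the batch assignment:

```python
import torch


def add_small_world_edges_sketch(edge_index: torch.Tensor,
                                 scores: torch.Tensor,
                                 top_k_ratio: float = 0.005) -> torch.Tensor:
    """Densely connect the top-k nodes ranked by attention score.

    edge_index: [2, E] existing edges; scores: [N] per-node attention scores.
    Simplified sketch: single graph, scores supplied directly.
    """
    num_nodes = scores.numel()
    # Keep at least two nodes so there is something to connect.
    k = max(2, int(num_nodes * top_k_ratio))
    top = torch.topk(scores, k).indices
    # All ordered pairs among the selected nodes, excluding self-loops.
    src = top.repeat_interleave(k)
    dst = top.repeat(k)
    mask = src != dst
    new_edges = torch.stack([src[mask], dst[mask]], dim=0)
    return torch.cat([edge_index, new_edges], dim=1)
```

Because k is a small fraction of N, the k·(k−1) added edges stay cheap while giving the high-attention nodes short paths to each other.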

forward(data: Data) Data[source]

class cellmil.models.mil.graphmil.gnn.SGFormer(input_dim: int, hidden_dim: int, n_layers: int, dropout: float = 0.25, heads: int = 1, alpha: float = 0.5, **kwargs: Any)[source]

Bases: GNN

__init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float = 0.25, heads: int = 1, alpha: float = 0.5, **kwargs: Any)[source]

forward(data: Data) Data[source]

class cellmil.models.mil.graphmil.gnn.SGFormerAttention(channels: int, heads: int = 1, head_channels: int = 64, qkv_bias: bool = False)[source]

Bases: Module

Implements the logic from the SGFormer paper’s official repository (function full_attention_conv), adapted for non-batched input and numerical stability.

Note: This logic differs from Equations (2) & (3) in the paper text. This implementation corresponds to: Z = (Q(K^T V) + N·V) / (Q(K^T 1) + N), where N is the number of nodes and 1 is the all-ones vector.

Parameters:
  • channels (int) – Size of each input sample (hidden_dim).

  • heads (int, optional) – Number of parallel attention heads.

  • head_channels (int, optional) – Size of each attention head.

  • qkv_bias (bool, optional) – If specified, add bias to query, key and value.
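
The formula above can be evaluated without materializing the N×N attention matrix, because K^T V and K^T 1 are only head_channels-sized. A sketch of that computation on raw Q, K, V tensors (an assumption for illustration; the actual module also handles projections, masking, and its own normalization for stability):

```python
import torch


def full_attention_conv_sketch(q: torch.Tensor,
                               k: torch.Tensor,
                               v: torch.Tensor) -> torch.Tensor:
    """Z = (Q (K^T V) + N * V) / (Q (K^T 1) + N), in O(N * C^2) per head.

    q, k, v: [N, H, C] (N nodes, H heads, C head channels), non-batched.
    """
    n = q.shape[0]
    kv = torch.einsum("nhc,nhd->hcd", k, v)             # K^T V per head: [H, C, C]
    numerator = torch.einsum("nhc,hcd->nhd", q, kv) + n * v
    k_sum = k.sum(dim=0)                                # K^T 1 per head: [H, C]
    denom = torch.einsum("nhc,hc->nh", q, k_sum).unsqueeze(-1) + n
    return numerator / denom
```

This is algebraically identical to forming A = QK^T explicitly and computing (A V + N·V) / (A 1 + N), but avoids the quadratic memory cost, which is what makes the layer usable on large cell graphs.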

__init__(channels: int, heads: int = 1, head_channels: int = 64, qkv_bias: bool = False) None[source]

forward(x: Tensor, mask: Optional[Tensor] = None) Tensor[source]

reset_parameters()[source]