cellmil.models.mil.graphmil.gnn¶
Classes

- GNN: Abstract base class for Graph Neural Networks.
- GAT
- GATv2
- SAGE
- EGNN
- CHIMERA
- CHIMERA_block
- SmallWorld: SmallWorld GNN that creates additional connections between high-attention nodes.
- SGFormer
- SGFormerAttention: Implements the logic from the SGFormer paper's official repository (function full_attention_conv), adapted for non-batched input and numerical stability.
- class cellmil.models.mil.graphmil.gnn.GNN(input_dim: int, hidden_dim: int | list[int], n_layers: int, dropout: float, **kwargs: Any)[source]¶
Abstract base class for Graph Neural Networks.
This class defines the interface for GNN models but cannot be instantiated directly. Subclasses must implement the layer creation logic in their __init__ method.
- __init__(input_dim: int, hidden_dim: int | list[int], n_layers: int, dropout: float, **kwargs: Any)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _get_specific_hyperparameters() dict[str, Any][source]¶
Override in subclasses to add specific hyperparameters.
- forward(data: Data) Data[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
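The subclassing contract described above (an abstract base that stores shared hyperparameters, delegates layer creation to subclasses, and lets them extend the hyperparameter dict via `_get_specific_hyperparameters`) can be illustrated framework-free. This is a hypothetical sketch, not the library's code; `BaseGNN`, `ToySAGE`, `_build_layers`, and `hyperparameters` are illustrative names:

```python
from abc import ABC, abstractmethod
from typing import Any


class BaseGNN(ABC):
    """Framework-free stand-in for the GNN base class (names hypothetical)."""

    def __init__(self, input_dim: int, hidden_dim, n_layers: int, dropout: float):
        self.input_dim = input_dim
        # Accept either a single width or one width per layer,
        # mirroring the `hidden_dim: int | list[int]` signature above.
        self.hidden_dims = (
            hidden_dim if isinstance(hidden_dim, list) else [hidden_dim] * n_layers
        )
        self.n_layers = n_layers
        self.dropout = dropout
        self.layers = self._build_layers()  # subclasses supply the layer stack

    @abstractmethod
    def _build_layers(self) -> list:
        """Subclasses create their message-passing layers here."""

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        return {}  # override in subclasses to add specific hyperparameters

    def hyperparameters(self) -> dict[str, Any]:
        return {
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dims,
            "n_layers": self.n_layers,
            "dropout": self.dropout,
            **self._get_specific_hyperparameters(),
        }


class ToySAGE(BaseGNN):
    def _build_layers(self) -> list:
        # Stand-in for real conv layers: just record (in, out) widths.
        dims = [self.input_dim, *self.hidden_dims]
        return [(dims[i], dims[i + 1]) for i in range(self.n_layers)]
```

Instantiating `BaseGNN` directly raises `TypeError` because `_build_layers` is abstract, matching the "cannot be instantiated directly" behavior documented above.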
- class cellmil.models.mil.graphmil.gnn.GAT(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, heads: int = 1, **kwargs: Any)[source]¶
Bases:
GNN
- class cellmil.models.mil.graphmil.gnn.GATv2(input_dim: int, hidden_dim: int | list[int], n_layers: int, dropout: float, **kwargs: Any)[source]¶
Bases:
GNN
- class cellmil.models.mil.graphmil.gnn.SAGE(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, **kwargs: Any)[source]¶
Bases:
GNN
- class cellmil.models.mil.graphmil.gnn.EGNN(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, **kwargs: Any)[source]¶
Bases:
GNN
- __init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, **kwargs: Any)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(data: Data) Data[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class cellmil.models.mil.graphmil.gnn.CHIMERA(input_dim: int, dropout: float, heads: int = 1, residual: bool = True, n_layers: int = 3, hidden_dim: list[int] = [128, 256, 512], **kwargs: Any)[source]¶
Bases:
GNN
- __init__(input_dim: int, dropout: float, heads: int = 1, residual: bool = True, n_layers: int = 3, hidden_dim: list[int] = [128, 256, 512], **kwargs: Any)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(data: Data) Data[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class cellmil.models.mil.graphmil.gnn.CHIMERA_block(in_channels: int, out_channels: int, heads: int, dropout: float, residual: bool = True)[source]¶
Bases:
Module
- __init__(in_channels: int, out_channels: int, heads: int, dropout: float, residual: bool = True)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(data: Data) Data[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class cellmil.models.mil.graphmil.gnn.SmallWorld(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, top_k_ratio: float = 0.005, heads: int = 1, layer_type: Literal['GCN', 'GAT', 'SAGE'] = 'SAGE', **kwargs: Any)[source]¶
Bases:
GNN

SmallWorld GNN that creates additional connections between high-attention nodes.
After each layer, uses SAGPooling to identify the most important nodes (the top top_k_ratio fraction by attention score, 0.5% by default) and creates additional edges between them, forming a small-world-like topology in which important nodes are more densely connected.
- __init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float, top_k_ratio: float = 0.005, heads: int = 1, layer_type: Literal['GCN', 'GAT', 'SAGE'] = 'SAGE', **kwargs: Any)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- _add_small_world_edges(x: Tensor, edge_index: Tensor, batch: Tensor, attention_pool: SAGPooling) Tensor[source]¶
Add edges between top-k nodes based on attention scores.
- Parameters:
x – Node features
edge_index – Current edge index
batch – Batch assignment for nodes
attention_pool – SAGPooling layer to compute attention scores
- Returns:
Updated edge index with additional small-world connections
- forward(data: Data) Data[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
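The edge-densification step that _add_small_world_edges describes (pick the top-k nodes by attention score, then fully connect them while keeping the existing edges) can be sketched in NumPy. This is an illustration of the idea, not the library's implementation; in the real layer the scores come from SAGPooling, and `add_small_world_edges` is a hypothetical name:

```python
import numpy as np


def add_small_world_edges(edge_index: np.ndarray,
                          scores: np.ndarray,
                          top_k_ratio: float = 0.005) -> np.ndarray:
    """Fully connect the highest-scoring nodes, keeping existing edges.

    edge_index: (2, E) array of directed edges; scores: (N,) per-node scores.
    """
    n_nodes = scores.shape[0]
    k = max(2, int(np.ceil(top_k_ratio * n_nodes)))
    top = np.argsort(scores)[-k:]                  # indices of the top-k nodes
    src, dst = np.meshgrid(top, top, indexing="ij")
    keep = src != dst                              # no self-loops
    new_edges = np.stack([src[keep], dst[keep]])   # shape (2, k*(k-1))
    merged = np.concatenate([edge_index, new_edges], axis=1)
    return np.unique(merged, axis=1)               # deduplicate edge pairs
```

On a 4-node path graph with high scores on nodes 0 and 3 and a ratio of 0.5, the two end nodes gain a direct connection in both directions, which is the small-world shortcut the class description refers to.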
- class cellmil.models.mil.graphmil.gnn.SGFormer(input_dim: int, hidden_dim: int, n_layers: int, dropout: float = 0.25, heads: int = 1, alpha: float = 0.5, **kwargs: Any)[source]¶
Bases:
GNN
- __init__(input_dim: int, hidden_dim: int, n_layers: int, dropout: float = 0.25, heads: int = 1, alpha: float = 0.5, **kwargs: Any)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(data: Data) Data[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class cellmil.models.mil.graphmil.gnn.SGFormerAttention(channels: int, heads: int = 1, head_channels: int = 64, qkv_bias: bool = False)[source]¶
Bases:
Module

Implements the logic from the SGFormer paper's official repository (function full_attention_conv), adapted for non-batched input and numerical stability.
Note: This logic is different from Equations (2) & (3) in the paper text. This implementation corresponds to: Z = (Q(K^T V) + N*V) / (Q(K^T * 1) + N)
- __init__(channels: int, heads: int = 1, head_channels: int = 64, qkv_bias: bool = False) None[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, mask: Optional[Tensor] = None) Tensor[source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the
Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
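The formula in the note above, Z = (Q(K^T V) + N·V) / (Q(K^T 1) + N), is a linear attention: grouping the products as Q(K^T V) costs O(N·d²) instead of the O(N²) of materializing QK^T. The NumPy sketch below (not the library's code) computes it in that order and checks it against the naive quadratic form; it assumes Q and K are entrywise non-negative (sigmoid-activated here, as one option) so the denominator stays positive:

```python
import numpy as np


def full_attention_conv(q, k, v):
    """Z = (Q (K^T V) + N V) / (Q (K^T 1) + N), without the N x N score matrix."""
    n = q.shape[0]
    kv = k.T @ v                                    # (d, d_v): K^T V
    numer = q @ kv + n * v                          # (n, d_v)
    denom = q @ k.sum(axis=0, keepdims=True).T + n  # (n, 1): Q (K^T 1) + N
    return numer / denom


rng = np.random.default_rng(0)
n, d, dv = 16, 4, 3
# Non-negative Q, K keep the denominator positive.
q = 1.0 / (1.0 + np.exp(-rng.standard_normal((n, d))))
k = 1.0 / (1.0 + np.exp(-rng.standard_normal((n, d))))
v = rng.standard_normal((n, dv))

# Naive O(N^2) reference:
# Z_i = (sum_j (q_i . k_j) v_j + N v_i) / (sum_j (q_i . k_j) + N)
scores = q @ k.T                                    # (n, n)
naive = (scores @ v + n * v) / (scores.sum(axis=1, keepdims=True) + n)
assert np.allclose(full_attention_conv(q, k, v), naive)
```

The N·V and N terms add each node's own value back with weight N, which is what distinguishes this repository variant from Equations (2) and (3) in the paper text, as the note above points out.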