from abc import ABC, abstractmethod
from typing import Any, cast, Literal, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.data import Data  # type: ignore
from torch_geometric.nn import SAGEConv, GATConv, GATv2Conv, SAGPooling, GCNConv  # type: ignore
# from torch_geometric.nn.attention import SGFormerAttention # type: ignore
from egnn_pytorch import EGNN_Sparse  # type: ignore

from cellmil.utils import logger
class GNN(nn.Module, ABC):
    """Abstract base class for Graph Neural Networks.

    This class defines the interface for GNN models and cannot be instantiated
    directly (``__init__`` is abstract). Subclasses must implement the layer
    creation logic in their ``__init__`` method, populating ``self.convs``.
    """

    @abstractmethod
    def __init__(
        self,
        input_dim: int,
        hidden_dim: int | list[int],
        n_layers: int,
        dropout: float,
        **kwargs: Any,
    ):
        """Store shared hyperparameters and create the empty layer list.

        Args:
            input_dim: Dimensionality of the input node features.
            hidden_dim: Hidden width (single int, or one entry per layer).
            n_layers: Number of message-passing layers.
            dropout: Dropout probability applied after each layer.
            **kwargs: Extra hyperparameters, recorded for reporting only.
        """
        super().__init__()  # type: ignore
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.dropout = dropout
        self.kwargs = kwargs
        # Subclasses append their message-passing layers here.
        self.convs = nn.ModuleList()

    def get_hyperparameters(self) -> dict[str, Any]:
        """Get all hyperparameters for this GNN."""
        return {
            "type": self.__class__.__name__,
            "input_dim": self.input_dim,
            "hidden_dim": self.hidden_dim,
            "n_layers": self.n_layers,
            "dropout": self.dropout,
            **self.kwargs,
            **self._get_specific_hyperparameters(),
        }

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        """Override in subclasses to add specific hyperparameters."""
        return {}

    def forward(self, data: Data) -> Data:
        """Apply every conv layer (each followed by ReLU and dropout).

        Updates ``data.x`` in place and returns the same ``data`` object.
        """
        x, edge_index = data.x, data.edge_index
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        data.x = x
        return data

    def get_attention_weights(self, data: Data) -> dict[str, torch.Tensor]:
        """
        Extract attention weights from GNN layers.

        Args:
            data: Input graph data

        Returns:
            Dictionary with attention weights from each layer (empty for non-attention GNNs)
        """
        return {}
class GAT(GNN):
    """Multi-layer Graph Attention Network built from ``GATConv`` layers.

    Intermediate layers concatenate their attention heads, so the hidden
    width is rounded up to a multiple of ``heads`` when necessary; the final
    layer averages heads instead of concatenating.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        n_layers: int,
        dropout: float,
        heads: int = 1,
        **kwargs: Any,
    ):
        """Build ``n_layers`` GATConv layers with head-aware sizing.

        Args:
            input_dim: Dimensionality of the input node features.
            hidden_dim: Target hidden width (rounded up to a multiple of
                ``heads`` for multi-layer networks).
            n_layers: Number of GATConv layers.
            dropout: Dropout probability (attention dropout and inter-layer).
            heads: Number of attention heads per layer.
            **kwargs: Extra hyperparameters forwarded to the base class.
        """
        super().__init__(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            dropout=dropout,
            **kwargs,
        )
        self.heads = heads
        width = hidden_dim
        if n_layers > 1 and width % self.heads:
            # Round up so that heads * (width // heads) reproduces the full
            # width after head concatenation in the intermediate layers.
            width = -(-hidden_dim // self.heads) * self.heads
            logger.warning(
                f"GAT: hidden_dim ({hidden_dim}) adjusted to {width} "
                f"to be divisible by heads ({self.heads}) for proper concatenation"
            )
            self.hidden_dim = width
        for layer_idx in range(n_layers):
            fan_in = input_dim if layer_idx == 0 else width
            is_last = layer_idx == n_layers - 1
            # Last layer averages heads (concat=False), so it keeps the full
            # width; earlier layers emit width // heads per head.
            fan_out = width if is_last else width // self.heads
            self.convs.append(
                GATConv(
                    fan_in,
                    fan_out,
                    heads=self.heads,
                    dropout=dropout,
                    concat=not is_last,
                )
            )

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        return {"heads": self.heads}

    def get_attention_weights(self, data: Data) -> dict[str, torch.Tensor]:
        """
        Extract attention weights from each GAT layer.

        Args:
            data: Input graph data

        Returns:
            Dictionary with attention weights: {'gnn_attention_layer_{i}': weights}
        """
        weights: dict[str, torch.Tensor] = {}
        if data.x is None:
            return weights
        x, edge_index = data.x, data.edge_index
        for layer_idx, conv in enumerate(self.convs):
            key = f"gnn_attention_layer_{layer_idx}"
            try:
                out = conv(x, edge_index, return_attention_weights=True)
                if isinstance(out, tuple) and len(out) == 2:  # type: ignore
                    new_x, att_info = out  # type: ignore
                    if isinstance(att_info, tuple) and len(att_info) == 2:  # type: ignore
                        # (edge_index, alpha) pair: keep only the alpha tensor.
                        alpha = att_info[1]  # type: ignore
                        if isinstance(alpha, torch.Tensor):
                            weights[key] = alpha
                    elif isinstance(att_info, torch.Tensor):
                        weights[key] = att_info
                    x = new_x  # type: ignore
                else:
                    # Unexpected return shape: plain forward pass instead.
                    x = conv(x, edge_index)
            except Exception:
                # Layer refused attention extraction: plain forward pass.
                x = conv(x, edge_index)
            x = F.relu(x)  # type: ignore
            x = F.dropout(x, p=self.dropout, training=self.training)
        return weights
class GATv2(GNN):
    """Multi-layer GATv2 network supporting per-layer hidden sizes.

    ``hidden_dim`` may be a single int (shared by all layers) or a list with
    one entry per layer.  Each width is rounded up to a multiple of ``heads``
    (taken from ``kwargs``, default 1) so head concatenation in intermediate
    layers reproduces the requested width.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int | list[int],
        n_layers: int,
        dropout: float,
        **kwargs: Any,
    ):
        """Build ``n_layers`` GATv2Conv layers.

        Args:
            input_dim: Dimensionality of the input node features.
            hidden_dim: Hidden width, scalar or one entry per layer.
            n_layers: Number of GATv2Conv layers.
            dropout: Dropout probability.
            **kwargs: Extra hyperparameters; ``heads`` is read from here.

        Raises:
            ValueError: If ``hidden_dim`` is a list whose length differs
                from ``n_layers``.
        """
        super().__init__(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            dropout=dropout,
            **kwargs,
        )
        heads = cast(int, kwargs.get("heads", 1))
        if isinstance(hidden_dim, list):
            if len(hidden_dim) != n_layers:
                raise ValueError("GATv2: hidden_dim list length must match n_layers")
            requested = list(hidden_dim)
        else:
            # A scalar width applies to every layer.
            requested = [hidden_dim] * n_layers
        widths: list[int] = []
        for layer_idx, wanted in enumerate(requested):
            if wanted % heads:
                # Round up to the next multiple of heads for clean concatenation.
                rounded = -(-wanted // heads) * heads
                logger.warning(
                    f"GATv2: hidden_dim[{layer_idx}] ({wanted}) adjusted to {rounded} "
                    f"to be divisible by heads ({heads}) for proper concatenation"
                )
                widths.append(rounded)
            else:
                widths.append(wanted)
        # Keep the (possibly adjusted) widths for hyperparameter reporting,
        # matching the scalar/list shape of the caller's argument.
        self.hidden_dim = widths if isinstance(hidden_dim, list) else widths[0]
        self.heads = heads
        for layer_idx in range(n_layers):
            fan_in = input_dim if layer_idx == 0 else widths[layer_idx - 1]
            is_last = layer_idx == n_layers - 1
            fan_out = widths[layer_idx] if is_last else widths[layer_idx] // heads
            self.convs.append(
                GATv2Conv(
                    fan_in, fan_out, heads=heads, dropout=dropout, concat=not is_last
                )
            )

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        return {"heads": self.heads}

    def get_attention_weights(self, data: Data) -> dict[str, torch.Tensor]:
        """
        Extract attention weights from each GATv2 layer.

        Args:
            data: Input graph data

        Returns:
            Dictionary with attention weights: {'gnn_attention_layer_{i}': weights}
        """
        weights: dict[str, torch.Tensor] = {}
        if data.x is None:
            return weights
        x, edge_index = data.x, data.edge_index
        for layer_idx, conv in enumerate(self.convs):
            key = f"gnn_attention_layer_{layer_idx}"
            try:
                out = conv(x, edge_index, return_attention_weights=True)
                if isinstance(out, tuple) and len(out) == 2:  # type: ignore
                    new_x, att_info = out  # type: ignore
                    if isinstance(att_info, tuple) and len(att_info) == 2:  # type: ignore
                        # (edge_index, alpha) pair: keep only the alpha tensor.
                        alpha = att_info[1]  # type: ignore
                        if isinstance(alpha, torch.Tensor):
                            weights[key] = alpha
                    elif isinstance(att_info, torch.Tensor):
                        weights[key] = att_info
                    x = new_x  # type: ignore
                else:
                    # Unexpected return shape: plain forward pass instead.
                    x = conv(x, edge_index)
            except Exception:
                # Layer refused attention extraction: plain forward pass.
                x = conv(x, edge_index)
            x = F.relu(x)  # type: ignore
            x = F.dropout(x, p=self.dropout, training=self.training)
        return weights
class SAGE(GNN):
    """Multi-layer GraphSAGE network: every layer is ``hidden_dim`` wide."""

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        n_layers: int,
        dropout: float,
        **kwargs: Any,
    ):
        """Build ``n_layers`` SAGEConv layers.

        Args:
            input_dim: Dimensionality of the input node features.
            hidden_dim: Output width of every layer.
            n_layers: Number of SAGEConv layers.
            dropout: Dropout probability used by the base-class forward pass.
            **kwargs: Extra hyperparameters forwarded to the base class.
        """
        super().__init__(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            dropout=dropout,
            **kwargs,
        )
        # Layer widths: input_dim -> hidden_dim -> ... -> hidden_dim.
        dims = [input_dim] + [hidden_dim] * n_layers
        for fan_in, fan_out in zip(dims[:-1], dims[1:]):
            self.convs.append(SAGEConv(in_channels=fan_in, out_channels=fan_out))
class EGNN(GNN):
    """E(n)-equivariant GNN built from ``EGNN_Sparse`` layers.

    Node features are first projected to ``hidden_dim``, then each layer
    operates on the concatenated ``[pos | features]`` node tensor, updating
    coordinates and features jointly.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        n_layers: int,
        dropout: float,
        **kwargs: Any,
    ):
        """Build the feature projection and ``n_layers`` EGNN_Sparse layers.

        Args:
            input_dim: Dimensionality of the input node features.
            hidden_dim: Feature width inside the EGNN layers.
            n_layers: Number of EGNN_Sparse layers.
            dropout: Dropout probability passed to each layer.
            **kwargs: Extra hyperparameters; ``pos_dim`` (default 2) is the
                number of coordinate dimensions per node.
        """
        super().__init__(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            dropout=dropout,
            **kwargs,
        )
        # Number of coordinate dimensions prepended to the feature matrix.
        self.pos_dim = int(cast(int, kwargs.get("pos_dim", 2)))
        self.proj = nn.Linear(input_dim, hidden_dim)
        self.convs.extend(
            EGNN_Sparse(
                feats_dim=hidden_dim,
                pos_dim=self.pos_dim,
                dropout=float(dropout),
            )
            for _ in range(n_layers)
        )

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        return {"pos_dim": self.pos_dim}

    def forward(self, data: Data) -> Data:
        """Run the EGNN stack, writing updated ``pos`` and ``x`` back to ``data``."""
        feats = self.proj(data.x)
        # EGNN_Sparse expects coordinates and features packed into one tensor.
        state = torch.cat([cast(torch.Tensor, data.pos), feats], dim=-1)
        edge_index = data.edge_index
        for conv in self.convs:
            state = conv(state, edge_index)
        # Split the updated state back into coordinates and features.
        data.pos = state[:, : self.pos_dim]
        data.x = state[:, self.pos_dim :]
        return data
class CHIMERA(GNN):
    """Hierarchical GNN composed of ``CHIMERA_block`` units.

    Each block applies GATv2 convolutions and then coarsens the graph with
    SAGPooling, so successive blocks see progressively smaller graphs.
    """

    def __init__(
        self,
        input_dim: int,
        dropout: float,
        heads: int = 1,
        residual: bool = True,
        n_layers: int = 3,
        hidden_dim: list[int] | None = None,
        **kwargs: Any,
    ):
        """Build ``n_layers`` CHIMERA blocks.

        Args:
            input_dim: Dimensionality of the input node features.
            dropout: Dropout probability inside each block's GATv2 layers.
            heads: Attention heads per GATv2 layer (heads are concatenated).
            residual: Whether GATv2 layers use residual connections.
            n_layers: Number of CHIMERA blocks.
            hidden_dim: Per-block output widths; ``None`` (the default) means
                ``[128, 256, 512]``.

        Raises:
            ValueError: If ``len(hidden_dim) != n_layers``.
        """
        # Avoid a mutable default argument; None stands in for the standard
        # [128, 256, 512] pyramid.
        if hidden_dim is None:
            hidden_dim = [128, 256, 512]
        super().__init__(
            input_dim=input_dim,
            hidden_dim=hidden_dim[0],
            n_layers=n_layers,
            dropout=dropout,
            **kwargs,
        )
        self.heads = heads
        self.residual = residual
        self.hidden_dim = hidden_dim
        if len(hidden_dim) != self.n_layers:
            raise ValueError("CHIMERA: hidden_dim list length must match n_layers")
        self.blocks = nn.ModuleList()
        for i in range(n_layers):
            # Each block's GATv2 layers concatenate their heads, so the next
            # block's input width is hidden_dim[i - 1] * heads.
            in_channels = input_dim if i == 0 else hidden_dim[i - 1] * self.heads
            out_channels = hidden_dim[i]
            self.blocks.append(
                CHIMERA_block(
                    in_channels,
                    out_channels,
                    heads=self.heads,
                    dropout=dropout,
                    residual=residual,
                )
            )

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        return {"heads": self.heads, "residual": self.residual}

    def forward(self, data: Data) -> Data:
        """Apply each CHIMERA block in sequence; each block coarsens the graph."""
        for block in self.blocks:
            data = block(data)
        return data
class CHIMERA_block(nn.Module):
    """One CHIMERA stage: two GATv2 layers (ELU + BatchNorm after each)
    followed by SAGPooling that keeps half the nodes."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        heads: int,
        dropout: float,
        residual: bool = True,
    ):
        """Build the two conv/norm pairs and the pooling layer.

        Args:
            in_channels: Input feature width of the first conv.
            out_channels: Per-head output width of each conv.
            heads: Attention heads (concatenated, so convs emit
                ``out_channels * heads`` features).
            dropout: Attention dropout for the GATv2 layers.
            residual: Whether the GATv2 layers use residual connections.
        """
        super().__init__()  # type: ignore
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        fan_in = in_channels
        for _ in range(2):
            self.convs.append(
                GATv2Conv(
                    fan_in,
                    out_channels,
                    heads=heads,
                    concat=True,
                    dropout=dropout,
                    residual=residual,
                )
            )
            # concat=True -> each conv emits out_channels * heads features.
            self.norms.append(nn.BatchNorm1d(out_channels * heads))
            fan_in = out_channels
        self.pool = SAGPooling(out_channels * heads, ratio=0.5)

    def forward(self, data: Data) -> Data:
        """Run conv -> ELU -> norm twice, then pool, updating ``data`` in place."""
        x, edge_index, batch = data.x, data.edge_index, data.batch
        for conv, norm in zip(self.convs, self.norms):
            x = norm(F.elu(conv(x, edge_index)))
        # SAGPooling keeps the top-scoring half of the nodes.
        x, edge_index, _, batch, _, _ = self.pool(x, edge_index, batch=batch)
        data.x = x
        data.edge_index = edge_index
        data.batch = batch
        return data
class SmallWorld(GNN):
    """
    SmallWorld GNN that creates additional connections between high-attention nodes.

    After each layer, a SAGPooling scorer identifies important nodes (the
    top ``top_k_ratio`` fraction by attention score) and extra edges are
    added between them, forming a small-world-like topology where important
    nodes are more densely connected.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        n_layers: int,
        dropout: float,
        top_k_ratio: float = 0.005,
        heads: int = 1,
        layer_type: Literal["GCN", "GAT", "SAGE"] = "SAGE",
        **kwargs: Any,
    ):
        """Build the conv stack plus one attention scorer per non-final layer.

        Args:
            input_dim: Dimensionality of the input node features.
            hidden_dim: Hidden width (rounded up to a multiple of ``heads``
                when ``layer_type == "GAT"``).
            n_layers: Number of conv layers.
            dropout: Dropout probability between layers.
            top_k_ratio: Fraction of each graph's nodes treated as hubs.
            heads: Attention heads (only used by the GAT variant).
            layer_type: Which conv operator to use per layer.
            **kwargs: Extra hyperparameters forwarded to the base class.
        """
        super().__init__(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            n_layers=n_layers,
            dropout=dropout,
            **kwargs,
        )
        self.top_k_ratio = top_k_ratio
        self.heads = heads
        self.layer_type = layer_type
        width = hidden_dim
        # Only GAT concatenates heads, so only then must the width divide evenly.
        if layer_type == "GAT" and hidden_dim % heads:
            width = -(-hidden_dim // heads) * heads
            logger.warning(
                f"SmallWorld: hidden_dim ({hidden_dim}) adjusted to {width} "
                f"to be divisible by heads ({heads}) for proper concatenation"
            )
            self.hidden_dim = width
        for layer_idx in range(n_layers):
            fan_in = input_dim if layer_idx == 0 else width
            if layer_type == "GAT":
                is_last = layer_idx == n_layers - 1
                # Last layer averages heads; earlier layers concatenate them.
                self.convs.append(
                    GATConv(
                        fan_in,
                        width if is_last else width // heads,
                        heads=heads,
                        dropout=dropout,
                        concat=not is_last,
                    )
                )
            elif layer_type == "SAGE":
                self.convs.append(SAGEConv(in_channels=fan_in, out_channels=width))
            elif layer_type == "GCN":
                self.convs.append(GCNConv(in_channels=fan_in, out_channels=width))
        # SAGPooling is used purely to score nodes (ratio=1.0 would keep them
        # all); it never prunes the graph here.  Every conv output is `width`
        # features wide, so that is the scorer's input size.  No scorer after
        # the last layer.
        self.attention_pools = nn.ModuleList(
            SAGPooling(width, ratio=1.0) for _ in range(n_layers - 1)
        )

    def _get_specific_hyperparameters(self) -> dict[str, Any]:
        return {
            "top_k_ratio": self.top_k_ratio,
            "heads": self.heads,
            "layer_type": self.layer_type,
        }

    def _add_small_world_edges(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor,
        attention_pool: SAGPooling,
    ) -> torch.Tensor:
        """
        Add edges between top-k nodes based on attention scores.

        Args:
            x: Node features
            edge_index: Current edge index
            batch: Batch assignment for nodes
            attention_pool: SAGPooling layer whose internal GNN scores nodes

        Returns:
            Updated edge index with additional small-world connections
        """
        # Score every node with the pooling layer's internal scoring GNN.
        scores = attention_pool.gnn(x, edge_index).view(-1)
        extra_edges: list[torch.Tensor] = []
        # Handle each graph of the batch independently.
        for graph_id in cast(torch.Tensor, torch.unique(batch)):  # type: ignore
            member_mask = batch == graph_id
            member_scores = scores[member_mask]
            member_idx = torch.where(member_mask)[0]
            # At least one hub per graph, even for tiny graphs.
            k = max(1, int(len(member_scores) * self.top_k_ratio))
            top_local = torch.topk(member_scores, k).indices
            hubs = member_idx[top_local]
            if len(hubs) > 1:
                # All ordered hub pairs, with self-loops filtered out.
                src = hubs.repeat_interleave(len(hubs))
                dst = hubs.repeat(len(hubs))
                keep = src != dst
                src, dst = src[keep], dst[keep]
                # Stack both directions explicitly.
                extra_edges.append(
                    torch.stack(
                        [torch.cat([src, dst]), torch.cat([dst, src])], dim=0
                    )
                )
        if not extra_edges:
            return edge_index
        merged = torch.cat([edge_index, torch.cat(extra_edges, dim=1)], dim=1)
        # Drop duplicates (new edges may already exist, and both directions
        # were appended twice).
        return cast(torch.Tensor, torch.unique(merged, dim=1))  # type: ignore

    def forward(self, data: Data) -> Data:
        """Run the conv stack, densifying hub connections between layers."""
        x, edge_index = data.x, data.edge_index
        if x is None or edge_index is None:
            raise ValueError(
                "SmallWorld GNN requires node features (data.x) and edge index (data.edge_index) to be present."
            )
        batch = getattr(data, "batch", None)
        if batch is None:
            # Single-graph input: put every node in batch 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        for layer_idx, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            # Densify connections among high-scoring nodes before the next
            # layer (never after the last one).
            if layer_idx < len(self.convs) - 1:
                edge_index = self._add_small_world_edges(
                    x,
                    edge_index,
                    batch,
                    self.attention_pools[layer_idx],  # type: ignore
                )
        data.x = x
        data.edge_index = edge_index
        return data