from pydantic import BaseModel, Field, field_validator
from enum import Enum
from pathlib import Path
from typing import Optional
class Aggregation(str, Enum):
    """Attention aggregation methods for GraphMIL.

    Members:
        pooling_only: use only the pooling attention.
        gnn_only: use only the GNN attention.
        gnn_layer: use the attention of one specific GNN layer.
        random_walk: combine attention across layers as
            ``a = p.T @ P_(l) @ p_(l-1) @ ... @ p_(1)``.
            NOTE(review): formula copied from the original comment; the
            mixed ``P``/``p`` notation is not explained here — confirm
            against the implementation that consumes this value.

    Being a ``str`` mixin, each member compares equal to its raw value.
    """

    pooling_only = "pooling_only"
    gnn_only = "gnn_only"
    gnn_layer = "gnn_layer"
    random_walk = "random_walk"
class Normalization(str, Enum):
    """Attention normalization methods.

    Members:
        min_max: Min-Max normalization to ``[0, 1]``.
        z_score: Z-score standardization (mean 0, std 1).
        robust: robust scaling using the median and the IQR.
        softmax: softmax normalization (values sum to 1).
        sigmoid: sigmoid normalization to ``[0, 1]``.
        none: no normalization.

    Being a ``str`` mixin, each member compares equal to its raw value.
    """

    min_max = "min_max"
    z_score = "z_score"
    robust = "robust"
    softmax = "softmax"
    sigmoid = "sigmoid"
    none = "none"
class VisualizationMode(str, Enum):
    """Visualization output modes.

    Members:
        geojson: GeoJSON output for pathology viewers (QuPath, etc.).
        graph: interactive graph with attention edges.
        all: generate every visualization type.

    Being a ``str`` mixin, each member compares equal to its raw value.
    """

    geojson = "geojson"
    graph = "graph"
    all = "all"
class AttentionExplainerConfig(BaseModel):
    """Configuration for Attention explainability method.

    Groups four kinds of options:

    * where explanations are written (``output_path``);
    * how attention is extracted (``attention_aggregation``,
      ``gnn_layer_index``, ``class_index``, ``attention_head``);
    * how it is visualised (``visualization_mode``, ``color_scheme``,
      ``visualize_cell_types``);
    * how attention weights are normalised (``normalize_attention``,
      ``normalization``).

    Validation performed by this model:
        * ``gnn_layer_index`` must be non-negative when provided.
        * Both endpoints of ``color_scheme`` must be RGB triples whose
          channels lie in ``[0, 255]``.
    """

    # Required: destination for the generated explanations.
    output_path: Path = Field(
        ..., description="Path where the explanations will be saved"
    )

    # Attention-specific configurations
    attention_aggregation: Aggregation = Field(
        default=Aggregation.pooling_only,
        description="How to aggregate attention for GraphMIL models",
    )
    # Used together with Aggregation.gnn_layer (per its description).
    # NOTE(review): nothing here requires it to be set when
    # attention_aggregation == gnn_layer — confirm how the consumer handles None.
    gnn_layer_index: Optional[int] = Field(
        default=None,
        description="Specific GNN layer index when using gnn_layer aggregation",
    )
    # None selects all classes (per the field description).
    class_index: Optional[int] = Field(
        default=None,
        description="Specific class index for multi-class attention (None for all classes)",
    )
    # None averages over heads (per the field description).
    attention_head: Optional[int] = Field(
        default=None,
        description="Specific attention head for multi-head attention (None for mean)",
    )

    # Visualization configurations
    visualization_mode: VisualizationMode = Field(
        default=VisualizationMode.geojson,
        description="Type of visualization to generate",
    )
    # (start_rgb, end_rgb); the default runs from blue to red.
    color_scheme: tuple[list[int], list[int]] = Field(
        default=([35, 92, 236], [255, 0, 0]),  # Blue to red
        description="Color gradient for attention visualization (start_rgb, end_rgb)",
    )

    # Advanced options
    # NOTE(review): overlaps with Normalization.none; which setting wins when
    # they disagree is decided by the consumer — confirm.
    normalize_attention: bool = Field(
        default=True, description="Whether to normalize attention weights"
    )
    normalization: Normalization = Field(
        default=Normalization.min_max,
        description="Type of normalization to apply to attention weights",
    )
    visualize_cell_types: bool = Field(
        default=False,
        description="Whether to create cell type visualization (graph colored by cell types)",
    )

    @field_validator("gnn_layer_index")
    @classmethod
    def validate_gnn_layer_index(cls, v: Optional[int]) -> Optional[int]:
        """Reject negative GNN layer indices; ``None`` passes through.

        Raises:
            ValueError: if ``v`` is a negative integer.
        """
        if v is not None and v < 0:
            raise ValueError("GNN layer index must be non-negative")
        return v

    @field_validator("color_scheme")
    @classmethod
    def validate_color_scheme(
        cls, v: tuple[list[int], list[int]]
    ) -> tuple[list[int], list[int]]:
        """Ensure both gradient endpoints are valid RGB triples.

        Without this check, a malformed colour (wrong channel count or a
        channel outside ``[0, 255]``) would only fail later, at render time.

        Raises:
            ValueError: if either colour does not have exactly three channels,
                or if any channel is outside ``[0, 255]``.
        """
        for label, color in zip(("start", "end"), v):
            if len(color) != 3:
                raise ValueError(
                    f"color_scheme {label} color must have exactly 3 RGB "
                    f"channels, got {len(color)}"
                )
            if not all(0 <= channel <= 255 for channel in color):
                raise ValueError(
                    f"color_scheme {label} color channels must be in "
                    f"[0, 255], got {color}"
                )
        return v