Explainability¶
This package provides two complementary explainability methods for attention-based Multiple Instance Learning (MIL) models: Attention Heatmaps for local explanations and SHAP Analysis for global feature importance.
Overview¶
Attention Heatmaps¶
Attention heatmaps provide local, instance-level explanations by visualizing which cells (instances) the model focuses on when making a prediction for a specific slide. This shows where the model is looking.
Supported Models:
CLAM (Clustering-constrained Attention MIL)
AttentionDeepMIL (ABMIL)
HEAD4TYPE (Cell type-aware attention)
GraphMIL models with attention pooling
SHAP Analysis¶
SHAP (SHapley Additive exPlanations) values provide global feature importance explanations across the entire dataset by identifying which cell features (morphological, textural, topological) drive the attention mechanism. This reveals what features the model considers important.
Key Insight: SHAP analyzes the attention module itself, explaining which features lead to high or low attention scores rather than the final prediction.
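This distinction can be made concrete with a toy sketch: the function handed to SHAP maps per-cell features to an attention score, not to the slide-level prediction. Everything below (the gated-attention form, the random weights, and the dimensions) is synthetic and only illustrates the kind of function the attention module computes; it is not cellmil's actual architecture.

```python
import numpy as np

# Toy stand-in for the attention module that SHAP explains: a gated
# attention scorer mapping per-cell features to an unnormalized score.
# Weights are random placeholders; in cellmil they come from training.
rng = np.random.default_rng(0)
n_cells, n_features, hidden = 6, 8, 4
V = rng.normal(size=(n_features, hidden))   # tanh branch
U = rng.normal(size=(n_features, hidden))   # sigmoid gate branch
w = rng.normal(size=(hidden, 1))            # scoring vector

def attention_scores(X: np.ndarray) -> np.ndarray:
    """Gated-attention scores for a bag of cells X of shape (n_cells, n_features)."""
    gate = np.tanh(X @ V) * (1.0 / (1.0 + np.exp(-(X @ U))))
    return (gate @ w).ravel()

X = rng.normal(size=(n_cells, n_features))
scores = attention_scores(X)       # this per-cell function is what SHAP explains
weights = np.exp(scores - scores.max())
weights /= weights.sum()           # softmax attention weights over the bag
```

SHAP attributions computed against `attention_scores` explain which features push a cell's attention up or down, independently of the downstream classifier.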
Attention Heatmaps¶
Attention heatmaps visualize attention weights spatially on whole slide images, highlighting regions the model deems important.
Configuration¶
Create an AttentionExplainerConfig to customize the visualization:
from cellmil.interfaces.AttentionExplainerConfig import (
    AttentionExplainerConfig,
    VisualizationMode,
    Normalization,
)
from pathlib import Path

config = AttentionExplainerConfig(
    output_path=Path("./explanations/attention"),
    visualization_mode=VisualizationMode.graph,
    normalization=Normalization.min_max,
    normalize_attention=True,
    visualize_cell_types=False,
)
Configuration Parameters:
output_path (Path, required): Directory where visualizations will be saved
visualization_mode (VisualizationMode, default: geojson): Output format
  geojson: Compatible with pathology viewers (QuPath, etc.)
  graph: Interactive graph visualization with attention edges
  all: Generate both formats
normalization (Normalization, default: min_max): Attention weight normalization
  min_max: Scale to [0, 1]
  z_score: Standardize (mean=0, std=1)
  robust: Median and IQR scaling
  softmax: Normalize to sum to 1
  sigmoid: Sigmoid transformation
  none: No normalization
normalize_attention (bool, default: True): Whether to apply normalization
color_scheme (tuple[list[int], list[int]]): RGB color gradient (start, end)
visualize_cell_types (bool, default: False): Create cell type visualizations
attention_aggregation (Aggregation): For GraphMIL models only
  pooling_only: Only pooling layer attention
  gnn_only: Only GNN layer attention
  gnn_layer: Specific GNN layer (set gnn_layer_index)
  random_walk: Aggregate via random walk
class_index (Optional[int]): For multi-class models, the class to explain
attention_head (Optional[int]): For multi-head attention, the head to explain
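For reference, the normalization modes listed above can be sketched in NumPy as follows. These are the standard definitions; cellmil's exact edge-case handling (e.g., constant inputs) may differ.

```python
import numpy as np

def normalize(a: np.ndarray, mode: str) -> np.ndarray:
    """Common attention-weight normalizations (illustrative, not cellmil's code)."""
    if mode == "min_max":      # scale to [0, 1]
        return (a - a.min()) / (a.max() - a.min())
    if mode == "z_score":      # standardize: mean 0, std 1
        return (a - a.mean()) / a.std()
    if mode == "robust":       # center on median, scale by IQR
        q1, q3 = np.percentile(a, [25, 75])
        return (a - np.median(a)) / (q3 - q1)
    if mode == "softmax":      # non-negative, sums to 1
        e = np.exp(a - a.max())
        return e / e.sum()
    if mode == "sigmoid":      # squash each value into (0, 1)
        return 1.0 / (1.0 + np.exp(-a))
    return a                   # "none": pass through unchanged

a = np.array([0.1, 0.5, 2.0, -1.0])
```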
Usage¶
Basic Example:
from cellmil.explainability.attention import AttentionExplainer
from cellmil.datamodels.model import ModelStorage
from pathlib import Path

# Load trained model
model_storage = ModelStorage.from_directory(
    "./results/TASK+FEATURES+MODEL+REG+STRA"
)

# Initialize explainer (using the config built in the Configuration section)
explainer = AttentionExplainer(config)

# Generate an explanation for a specific slide
results = explainer.generate_explanation(
    model_storage=model_storage,
    slide_path=Path("./dataset/SLIDE_1"),
)
Using a Specific Fold:
# Use fold 2 instead of the final model
results = explainer.generate_explanation(
    model_storage=model_storage,
    slide_path=Path("./dataset/SLIDE_1"),
    fold_idx=2,  # Use a specific fold
)
GraphMIL Models:
# For GraphMIL models, aggregate GNN attention
from cellmil.interfaces.AttentionExplainerConfig import (
    AttentionExplainerConfig,
    VisualizationMode,
    Aggregation,  # assumed to live alongside the other config enums
)
from pathlib import Path

config = AttentionExplainerConfig(
    output_path=Path("./explanations/graphmil"),
    visualization_mode=VisualizationMode.graph,
    attention_aggregation=Aggregation.gnn_only,
)
Model Storage¶
The ModelStorage class manages trained models from k-fold cross-validation experiments. It automatically loads:
Model checkpoint (weights)
Feature transforms (normalization, scaling)
Dataset configuration (extractors, segmentation model, graph creator)
Model configuration (architecture, hyperparameters)
Loading from Directory:
from cellmil.datamodels.model import ModelStorage

# Load from a k-fold results directory
model_storage = ModelStorage.from_directory(
    "./results/TASK+FEATURES+MODEL+REG+STRA"
)

# Check available folds
available_folds = model_storage.list_folds()  # e.g. [0, 1, 2, 3, 4]

# Check whether a final model exists
has_final = model_storage.has_final_model()  # True/False
Model directories are typically created by training scripts and follow this structure:
results/
└── TASK+FEATURES+MODEL+REG+STRA/
    ├── experiment_metadata.json
    ├── fold_0/
    │   ├── best_model.ckpt
    │   └── transforms/
    ├── fold_1/
    │   ├── best_model.ckpt
    │   └── transforms/
    └── final_model/
        ├── best_model.ckpt
        └── transforms/
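To make the layout concrete, here is a small standalone helper, not part of cellmil, that discovers fold indices from a directory tree like the one above (built here in a temporary directory so the snippet is self-contained):

```python
import re
import tempfile
from pathlib import Path

def discover_folds(root: Path) -> list[int]:
    """Return sorted fold indices found as fold_<i> subdirectories of root."""
    pattern = re.compile(r"fold_(\d+)$")
    return sorted(
        int(m.group(1))
        for d in root.iterdir()
        if d.is_dir()
        for m in [pattern.match(d.name)]
        if m
    )

# Demo on a temporary directory mimicking the documented structure
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name in ("fold_0", "fold_1", "final_model"):
        (root / name / "transforms").mkdir(parents=True)
        (root / name / "best_model.ckpt").touch()
    print(discover_folds(root))              # fold indices only
    print((root / "final_model").exists())   # final model present
```

ModelStorage.from_directory performs this kind of discovery for you; the helper only illustrates what the layout encodes.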
Output¶
The explainer generates several output files:
Visualizations:
GeoJSON (*.geojson): Cell polygons colored by attention weight
  Import into QuPath or other pathology viewers
  Each cell has an attention_weight property
Graph (*_graph.html): Interactive Plotly visualization
  Nodes colored by attention weight
  Edge opacity reflects spatial relationships
  Hover to see cell details
Metadata:
attention_weights.json: Raw attention weights for each cell
attention_metadata.json: Statistics (mean, std, entropy, sparsity)
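The statistics written to attention_metadata.json can be sketched from the raw weights as follows. The weights below are synthetic, and the exact field names and the sparsity threshold are assumptions, not cellmil's actual output format.

```python
import json
import math

# Synthetic softmax-normalized attention weights for one slide
weights = [0.05, 0.05, 0.10, 0.60, 0.20]

mean = sum(weights) / len(weights)
std = math.sqrt(sum((w - mean) ** 2 for w in weights) / len(weights))
# Shannon entropy: low when attention concentrates on a few cells
entropy = -sum(w * math.log(w) for w in weights if w > 0)
# Fraction of near-zero weights (threshold is an assumption)
sparsity = sum(1 for w in weights if w < 0.01) / len(weights)

metadata = {"mean": mean, "std": std, "entropy": entropy, "sparsity": sparsity}
print(json.dumps(metadata, indent=2))
```

Low entropy and high sparsity indicate a model that bases its prediction on a handful of cells; values near the uniform maximum suggest diffuse attention.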
SHAP Analysis¶
SHAP analysis identifies which cell features drive the attention mechanism globally across the entire dataset.
Methodology¶
1. Cell-Level Dataset Creation: Extract features from all cells across all slides in the dataset
2. Attention Computation: Compute attention weights for every cell using the trained model
3. Stratified Sampling: Sample cells from different attention quantile bins to ensure diverse representation
4. SHAP Computation: Use SHAP to explain which features contribute to high or low attention scores
5. Visualization: Generate feature importance plots and summary statistics
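Step 3 can be sketched with NumPy: bin cells by attention quantile so each bin holds the same share of cells, then draw the same number of samples from every bin. The variable names and equal-population binning here are illustrative, not cellmil's internals.

```python
import numpy as np

rng = np.random.default_rng(42)
attention = rng.random(100_000)          # one attention weight per cell
num_bins, samples_per_bin = 5, 100

# Quantile edges give equal-population bins across the attention range
edges = np.quantile(attention, np.linspace(0, 1, num_bins + 1))
bins = np.digitize(attention, edges[1:-1])   # bin index 0..num_bins-1 per cell

# Draw samples_per_bin distinct cells from each bin
sampled = np.concatenate([
    rng.choice(np.flatnonzero(bins == b), size=samples_per_bin, replace=False)
    for b in range(num_bins)
])
```

Without stratification, a uniform random sample would be dominated by the low-attention majority; binning guarantees high-attention cells are represented in the SHAP computation.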
Configuration¶
from cellmil.interfaces.SHAPExplainerConfig import (
    SHAPExplainerConfig,
    SHAPExplainerType,
)
from pathlib import Path

config = SHAPExplainerConfig(
    output_path=Path("./explanations/shap"),
    # Sampling parameters
    num_bins=5,
    samples_per_bin=10000,
    # SHAP computation
    explainer_type=SHAPExplainerType.gradient,
    background_percentage=0.2,
    nsamples=500,  # Only used by the kernel explainer
    # Analysis options
    explain_per_head=True,
    explain_mean_head=True,
    top_features=20,
    # Output options
    save_raw_shap_values=True,
    create_summary_plots=True,
)
Configuration Parameters:
Sampling Configuration:
num_bins (int, default: 5): Number of attention quantile bins (e.g., 5 = quintiles)
samples_per_bin (int, default: 10000): Cells to sample per bin
max_total_samples (Optional[int]): Cap on total samples (overrides samples_per_bin)
SHAP Computation:
explainer_type (SHAPExplainerType, default: gradient):
  gradient: Fast, uses gradients (recommended for neural networks)
  deep: DeepExplainer for deep learning models
  kernel: Model-agnostic but very slow
background_percentage (float, default: 0.2): Fraction of sampled cells to use as background
nsamples (int, default: 500): Coalitions per cell (kernel explainer only)
explain_top_cells (Optional[int]): Only explain the top N cells (None = all sampled)
explain_per_head (bool, default: True): Separate SHAP values for each attention head
explain_mean_head (bool, default: True): SHAP values for mean attention across heads
Analysis Options:
top_features (int, default: 20): Number of top features to highlight
random_seed (int, default: 42): Seed for reproducibility
batch_size (int, default: 1024): Batch size for attention computation
Output Options:
save_raw_shap_values (bool, default: True): Save raw SHAP values (can be large)
create_summary_plots (bool, default: True): Generate visualizations
Usage¶
Basic Example:
from cellmil.explainability.shap import SHAPExplainer
from cellmil.interfaces.SHAPExplainerConfig import SHAPExplainerConfig
from cellmil.datamodels.model import ModelStorage
import pandas as pd
from pathlib import Path

# Load trained model
model_storage = ModelStorage.from_directory(
    "./results/TASK+FEATURES+MODEL+REG+STRA"
)

# Load metadata (must have a 'FULL_PATH' column with slide names)
metadata = pd.read_excel("./data/metadata.xlsx")

# Initialize explainer
config = SHAPExplainerConfig(output_path=Path("./explanations/shap"))
explainer = SHAPExplainer(config)

# Generate explanation
results = explainer.generate_explanation(
    model_storage=model_storage,
    dataset_folder=Path("./dataset"),
    data=metadata,
)
Using Specific Fold:
# Explain fold 3
results = explainer.generate_explanation(
    model_storage=model_storage,
    dataset_folder=Path("./dataset"),
    data=metadata,
    fold_idx=3,
)
)
Custom Sampling:
# More aggressive sampling for a larger dataset
config = SHAPExplainerConfig(
    output_path=Path("./explanations/shap_large"),
    num_bins=10,  # Decile bins
    samples_per_bin=5000,
    max_total_samples=40000,  # Cap at 40k total
    top_features=30,
)
Kernel Explainer (Slower but Model-Agnostic):
config = SHAPExplainerConfig(
    output_path=Path("./explanations/shap_kernel"),
    explainer_type=SHAPExplainerType.kernel,
    nsamples=1000,  # More coalitions for a better approximation
    samples_per_bin=5000,  # Fewer cells due to computation cost
)
Output¶
The SHAP explainer generates comprehensive analysis files:
Visualizations (per head):
feature_importance_bar.png: Top features ranked by importance
beeswarm_plot.png: Distribution of SHAP values (feature impact on attention)
feature_distribution_top20.html: Interactive Plotly distribution plots
Interpretation¶
Feature Importance Bar Chart:
Shows the top N features ranked by mean absolute SHAP value. Higher values indicate features that strongly influence attention weights (either positively or negatively).
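The ranking itself is straightforward to reproduce on synthetic SHAP values; the feature names and the scale of the values below are placeholders, not cellmil outputs.

```python
import numpy as np

# Synthetic SHAP matrix: (n_cells, n_features). Column scales differ so the
# features have clearly different importance.
rng = np.random.default_rng(0)
shap_values = rng.normal(scale=[2.0, 0.1, 1.0], size=(1000, 3))
feature_names = ["nucleus_area", "texture_contrast", "neighbor_count"]

# Mean absolute SHAP value per feature, then rank descending
importance = np.abs(shap_values).mean(axis=0)
order = np.argsort(importance)[::-1]
ranking = [feature_names[i] for i in order]
print(ranking)   # feature with the largest SHAP spread ranks first
```

Note that the mean of the absolute values is used: a feature whose SHAP values are large but of mixed sign still ranks highly, which is why the bar chart captures both positive and negative influence.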
Beeswarm Plot:
Each point represents a cell, with:
X-axis: SHAP value (impact on attention)
  Positive: increases attention
  Negative: decreases attention
Color: Feature value (red = high, blue = low)
Y-axis: Features ranked by importance