Source code for cellmil.interfaces.SHAPExplainerConfig

from pydantic import BaseModel, Field, field_validator
from pathlib import Path
from typing import Optional
from enum import Enum


class SHAPExplainerType(str, Enum):
    """SHAP explainer types."""

    gradient = "gradient"  # GradientExplainer (fast, requires gradients)
    deep = "deep"  # DeepExplainer (for deep learning models)
    kernel = "kernel"  # KernelExplainer (model-agnostic, slow)

    @classmethod
    def values(cls):
        """Return the list of valid explainer type strings."""
        return [member.value for member in cls]

    def __str__(self):
        """Render the member as its underlying string value."""
        return self.value
class SHAPExplainerConfig(BaseModel):
    """Configuration for SHAP explainability method."""

    output_path: Path = Field(
        ..., description="Path where the SHAP explanations will be saved"
    )

    # Sampling configuration
    num_bins: int = Field(
        default=5,
        description="Number of quantile bins for stratified sampling based on attention scores",
    )
    samples_per_bin: int = Field(
        default=10000,
        description="Number of cells to sample from each attention quantile bin",
    )
    max_total_samples: Optional[int] = Field(
        default=None,
        description="Maximum total number of samples to use (overrides samples_per_bin if specified)",
    )

    # SHAP computation configuration
    explainer_type: SHAPExplainerType = Field(
        default=SHAPExplainerType.gradient,
        description="Type of SHAP explainer to use (gradient is fastest for neural networks)",
    )
    background_percentage: float = Field(
        default=0.2,
        description="Percentage of sampled cells to use as background data (0.0 to 1.0)",
    )
    nsamples: int = Field(
        default=500,
        description="Number of coalitions per cell for the SHAP kernel explainer (only used by the kernel explainer)",
    )
    explain_top_cells: Optional[int] = Field(
        default=None,
        description="Number of top cells to explain (None = explain all sampled cells)",
    )
    explain_per_head: bool = Field(
        default=True,
        description="Whether to compute SHAP values for each attention head separately",
    )
    explain_mean_head: bool = Field(
        default=True,
        description="Whether to also compute SHAP values for mean attention across heads",
    )

    # Analysis configuration
    top_features: int = Field(
        default=20,
        description="Number of top features to highlight in summary visualizations",
    )

    # Advanced options
    random_seed: int = Field(
        default=42, description="Random seed for reproducible sampling"
    )
    batch_size: int = Field(
        default=1024,
        description="Batch size for computing attention weights on the full dataset",
    )
    save_raw_shap_values: bool = Field(
        default=True,
        description="Whether to save raw SHAP values (can be large for many features)",
    )
    create_summary_plots: bool = Field(
        default=True,
        description="Whether to create summary plots (bar plot, beeswarm plot, etc.)",
    )

    @field_validator("num_bins")
    def validate_num_bins(cls, v: int) -> int:
        if v < 2:
            raise ValueError("num_bins must be at least 2")
        if v > 20:
            raise ValueError(
                "num_bins should not exceed 20 for meaningful stratification"
            )
        return v

    @field_validator("samples_per_bin")
    def validate_samples_per_bin(cls, v: int) -> int:
        if v < 10:
            raise ValueError("samples_per_bin must be at least 10")
        return v

    @field_validator("background_percentage")
    def validate_background_percentage(cls, v: float) -> float:
        if not 0.0 < v <= 1.0:
            raise ValueError("background_percentage must be between 0.0 and 1.0")
        return v

    @field_validator("nsamples")
    def validate_nsamples(cls, v: int) -> int:
        if v < 10:
            raise ValueError("nsamples must be at least 10 for reliable SHAP values")
        return v
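A minimal usage sketch (not part of the module source above): it assumes the package is installed so the module is importable as cellmil.interfaces.SHAPExplainerConfig, and shows the required output_path field, an overridden default, and the string-valued enum. The output path used here is purely illustrative.

from pathlib import Path

from cellmil.interfaces.SHAPExplainerConfig import SHAPExplainerConfig, SHAPExplainerType

# Only output_path is required; all other fields fall back to the defaults above.
config = SHAPExplainerConfig(
    output_path=Path("results/shap"),         # hypothetical output directory
    explainer_type=SHAPExplainerType.kernel,  # switch from the default gradient explainer
    num_bins=10,                              # accepted by validate_num_bins (2 <= v <= 20)
)

print(config.explainer_type)       # "kernel", since __str__ returns the enum's value
print(SHAPExplainerType.values())  # ['gradient', 'deep', 'kernel']
print(config.samples_per_bin)      # 10000 (default)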