Source code for cellmil.utils.train.evals.utils

import lightning as Pl
from lightning import Trainer
import numpy as np
import torch
from torch.utils.data import DataLoader as DataLoaderTorch
from lightning.pytorch.callbacks import ModelCheckpoint
from sklearn.metrics import classification_report  # type: ignore
from torch_geometric.loader import DataLoader as DataLoaderPyG  # type: ignore
from cellmil.datamodels.datasets.cell_mil_dataset import CellMILDataset
from cellmil.utils.train.metrics import ConcordanceIndex
from cellmil.datamodels.datasets.cell_gnn_mil_dataset import (
    CellGNNMILDataset
)
from cellmil.datamodels.datasets.patch_gnn_mil_dataset import (
    PatchGNNMILDataset
)
from cellmil.datamodels.datasets.patch_mil_dataset import PatchMILDataset
from typing import Any, Union, cast
from cellmil.utils import logger

[docs]def is_survival_model(lit_model: Pl.LightningModule) -> bool: """Check if the model is a survival analysis model.""" # Check if it's a LitSurv or similar survival model model_class_name = lit_model.__class__.__name__ return "Surv" in model_class_name
[docs]def compute_slide_cell_counts( dataset: Union[ CellMILDataset, CellGNNMILDataset, PatchGNNMILDataset, PatchMILDataset ], slides: np.ndarray[Any, Any], ): """Estimate per-slide cell counts for balancing folds.""" counts: list[int] = [] missing: list[str] = [] feature_cache = getattr(dataset, "features", None) index_lookup = cast(dict[int, str] | None, getattr(dataset, "cell_indices", None)) slide_sequence = list(slides) for idx, slide_id in enumerate(slide_sequence): count: int | None = None if isinstance(feature_cache, dict) and slide_id in feature_cache: tensor = cast(torch.Tensor, feature_cache[slide_id]) count = int(tensor.shape[0]) if ( count is None and isinstance(index_lookup, dict) and slide_id in index_lookup ): count = int(len(index_lookup[slide_id])) if count is None: try: sample = dataset[idx] except Exception: sample = None if sample is not None: if isinstance(sample, tuple) and len(sample) > 0: feature_tensor = sample[0] feature_shape = getattr(feature_tensor, "shape", None) if feature_shape is not None and len(feature_shape) > 0: count = int(feature_shape[0]) else: data_obj = sample num_nodes = getattr(data_obj, "num_nodes", None) if isinstance(num_nodes, int): count = num_nodes else: node_features = getattr(data_obj, "x", None) node_shape = getattr(node_features, "shape", None) if node_shape is not None and len(node_shape) > 0: count = int(node_shape[0]) if count is None: counts.append(0) missing.append(str(slide_id)) else: counts.append(count) if missing: preview = ", ".join(missing[:5]) suffix = "..." if len(missing) > 5 else "" logger.warning( f"Unable to determine cell counts for {len(missing)} slides; assigned 0 for balancing: {preview}{suffix}" ) return np.asarray(counts, dtype=np.int64)
[docs]class Report: """Base class for evaluation reports."""
[docs] def __init__( self, trainer: Trainer, lit_model: Pl.LightningModule, ) -> None: self.trainer = trainer self.lit_model = lit_model self._load_best_model()
[docs] def _load_best_model( self ) -> None: """Load the best model checkpoint if available.""" try: # Try to find ModelCheckpoint callback and load best checkpoint for callback in self.trainer.callbacks: # type: ignore if isinstance(callback, ModelCheckpoint) and callback.best_model_path: # type: ignore logger.info(f"Loading best model from: {callback.best_model_path}") # type: ignore # Load the state dict from the best checkpoint into the current model # Set weights_only=False to handle optimizer state and other objects checkpoint = torch.load( callback.best_model_path, # type: ignore map_location=self.lit_model.device, # type: ignore weights_only=False, ) self.lit_model.load_state_dict(checkpoint["state_dict"]) self.lit_model.eval() break except Exception as e: logger.warning( f"Could not load best checkpoint: {e}. Using current model state." )
[docs] def generate(self, dataloader: Union[DataLoaderTorch[Any], DataLoaderPyG]) -> dict[str, Any]: """Generate the report.""" # Check if this is a survival model is_survival = is_survival_model(self.lit_model) if is_survival: # For survival models, compute C-index and Brier score return self._get_survival_report(dataloader) else: # For classification models, use the existing logic return self._get_classification_report(dataloader)
[docs] def _get_classification_report( self, dataloader: Union[DataLoaderTorch[Any], DataLoaderPyG], ) -> dict[str, Any]: """Get classification report with precision, recall, f1-score.""" y_pred = self.trainer.predict(self.lit_model, dataloader) if isinstance(dataloader, DataLoaderPyG): y_true = [data.y for data in dataloader] else: # Handle both 2-element (x, y) and 3-element (x, cell_types, y) batches # Extract the last element which should always be the label y_true = [batch[-1] for batch in dataloader] if y_pred is not None and isinstance(y_pred[0], torch.Tensor): y_pred_flat = [pred.cpu().numpy().flatten()[0] for pred in y_pred] # type: ignore else: y_pred_flat = [ # type: ignore pred.flatten()[0] if hasattr(pred, "flatten") else pred # type: ignore for pred in y_pred # type: ignore ] if y_true and isinstance(y_true[0], torch.Tensor): y_true_flat = [true.cpu().numpy().flatten()[0] for true in y_true] else: y_true_flat = [ true.flatten()[0] if hasattr(true, "flatten") else true for true in y_true ] report = cast( dict[str, Any], classification_report(y_true_flat, y_pred_flat, output_dict=True), ) # type: ignore return report
[docs] def _get_survival_report( self, dataloader: Union[DataLoaderTorch[Any], DataLoaderPyG], ) -> dict[str, Any]: """Get survival analysis report with C-index""" def _extract_survival_tensors(target: Any) -> tuple[torch.Tensor, torch.Tensor] | None: """Normalize different target formats to (duration, event) tensors.""" def _to_tensor(value: Any) -> torch.Tensor: tensor = torch.as_tensor(value) if tensor.ndim == 0: tensor = tensor.unsqueeze(0) return tensor if isinstance(target, dict): key_duration = next((k for k in target if k.lower() in {"duration", "durations", "time"}), None) # type: ignore key_event = next((k for k in target if k.lower() in {"event", "events", "status"}), None) # type: ignore if key_duration is not None and key_event is not None: return _to_tensor(target[key_duration]), _to_tensor(target[key_event]) if isinstance(target, (list, tuple)) and len(target) == 2: # type: ignore return _to_tensor(target[0]), _to_tensor(target[1]) if torch.is_tensor(target): tensor = target if tensor.ndim == 1 and tensor.numel() == 2: return tensor[0].view(1), tensor[1].view(1) if tensor.ndim >= 1 and tensor.shape[-1] == 2: durations = tensor[..., 0].reshape(-1) events = tensor[..., 1].reshape(-1) return durations, events return None def _append_target(target: Any, source: str) -> None: parsed = _extract_survival_tensors(target) if parsed is None: logger.warning("Unexpected %s format for survival data", source) return dur_tensor, evt_tensor = parsed durations_list.append(dur_tensor) events_list.append(evt_tensor) # Get predictions (logits from discrete-time hazard model) y_pred = self.trainer.predict(self.lit_model, dataloader) # Extract true survival data (durations, events) from dataloader durations_list: list[torch.Tensor] = [] events_list: list[torch.Tensor] = [] if isinstance(dataloader, DataLoaderPyG): for data in dataloader: _append_target(getattr(data, "y", None), "PyG batch.y") else: for batch in dataloader: # batch is (x, y) where y is (duration, event) y = batch[-1] _append_target(y, "batch label") # Check if we have any data if not durations_list or not events_list: raise ValueError("No survival data found in dataloader") # Convert predictions to tensor if y_pred is not None and len(y_pred) > 0: if isinstance(y_pred[0], torch.Tensor): # Predictions are logits with shape [1, num_bins] per sample logits = torch.cat([pred.cpu() for pred in y_pred], dim=0) # type: ignore [batch_size, num_bins] else: logger.error("Unexpected prediction format") logits = torch.zeros((len(durations_list), 1)) # type: ignore else: logger.error("No predictions returned") logits = torch.zeros((len(durations_list), 1)) # type: ignore # Convert durations and events to tensors durations = torch.cat([d.cpu().flatten() for d in durations_list]) # type: ignore events = torch.cat([e.cpu().flatten() for e in events_list]) # type: ignore # Initialize metrics c_index_metric = ConcordanceIndex() # Update metrics with logits (not hazards) c_index_metric.update(logits, (durations, events)) report: dict[str, float] = { "c_index": float(c_index_metric.compute()), "n_samples": len(durations), "n_events": int(events.sum()), } return report