"""Metrics for survival analysis training (cellmil.utils.train.metrics)."""

import torch
import numpy as np
from torchmetrics import Metric
from typing import Any, cast
from sksurv.metrics import concordance_index_censored, integrated_brier_score  # type: ignore

class ConcordanceIndex(Metric):
    """
    Concordance Index (C-index) for survival analysis.

    The C-index measures the model's ability to correctly order pairs of
    samples by their survival times. A C-index of 1.0 indicates perfect
    concordance, while 0.5 indicates random predictions.

    Args:
        compute_on_cpu (bool): Whether to compute on CPU. Default: False.
        **kwargs: Additional keyword arguments passed to the parent Metric class.
    """

    is_differentiable: bool | None = False
    higher_is_better: bool | None = True
    full_state_update: bool | None = False

    def __init__(self, compute_on_cpu: bool = False, **kwargs: Any):
        super().__init__(**kwargs)
        self.compute_on_cpu = compute_on_cpu
        # Accumulate per-batch tensors; "cat" concatenates them across processes.
        self.add_state("preds_list", default=[], dist_reduce_fx="cat")  # type: ignore
        self.add_state("durations_list", default=[], dist_reduce_fx="cat")  # type: ignore
        self.add_state("events_list", default=[], dist_reduce_fx="cat")  # type: ignore

    def update(
        self, preds: torch.Tensor, target: tuple[torch.Tensor, torch.Tensor] | tuple[int, int]
    ) -> None:
        """
        Update state with predictions and targets.

        Args:
            preds (torch.Tensor): Logits for hazard at each time bin.
                Shape: [batch_size, num_bins] for discrete-time hazard model
                or [batch_size] for single risk score.
            target: Either:
                - tuple[torch.Tensor, torch.Tensor]: (durations, events) where
                  both are tensors of shape [batch_size]
                - tuple[int, int]: (duration, event) for single sample
        """
        # Promote plain-int targets to one-element float tensors.
        if isinstance(target[0], int):
            durations = torch.tensor([target[0]], dtype=torch.float32)
            events = torch.tensor([target[1]], dtype=torch.float32)
        else:
            durations, events = target[0], target[1]
        assert isinstance(durations, torch.Tensor)
        assert isinstance(events, torch.Tensor)

        ndim = preds.dim()
        if ndim == 2:
            # Discrete-time hazard head: reduce per-bin hazards to one
            # expected-event-time score per sample.
            bin_hazards = torch.sigmoid(preds)
            n_samples, n_bins = bin_hazards.shape
            dev, dt = bin_hazards.device, bin_hazards.dtype
            # Survival probability at the start of each bin (prepend S=1).
            leading_ones = torch.ones(n_samples, 1, device=dev, dtype=dt)
            surv_cum = torch.cumprod(
                torch.cat([leading_ones, 1.0 - bin_hazards], dim=1), dim=1
            )
            surv_before = surv_cum[:, :-1]
            # Probability mass of the event landing in each bin.
            bin_event_mass = surv_before * bin_hazards
            bin_index = torch.arange(n_bins, device=dev, dtype=dt)
            expected = (bin_event_mass * bin_index).sum(dim=1)
            # Mass not absorbed by any bin is placed at a virtual bin n_bins.
            leftover = torch.clamp(1.0 - bin_event_mass.sum(dim=1), min=0.0)
            # Expected event time; conversion to a risk happens in compute().
            scores = expected + leftover * n_bins
        elif ndim == 1:
            # Single score per sample.
            # NOTE: This is treated as a survival score (higher = longer survival)
            # because it is negated in compute(). If passing risk (hazard), negate it first.
            scores = preds
        else:
            raise ValueError(f"Unexpected prediction shape: {preds.shape}")

        if self.compute_on_cpu:
            scores, durations, events = scores.cpu(), durations.cpu(), events.cpu()
        # Boolean event indicators are stored as floats.
        if events.dtype == torch.bool:
            events = events.float()

        self.preds_list.append(scores)  # type: ignore
        self.durations_list.append(durations)  # type: ignore
        self.events_list.append(events)  # type: ignore

    def compute(self) -> torch.Tensor:
        """Compute the final C-index using scikit-survival."""
        if len(self.preds_list) == 0:  # type: ignore
            # No samples seen: fall back to the chance-level value.
            return torch.tensor(0.5)

        scores_np = torch.cat(self.preds_list).detach().cpu().numpy()  # type: ignore
        durations_np = torch.cat(self.durations_list).detach().cpu().numpy()  # type: ignore
        events_np = torch.cat(self.events_list).detach().cpu().numpy().astype(bool)  # type: ignore

        # Higher risk should predict shorter survival, so negate the stored
        # survival-oriented scores before handing them to sksurv.
        try:
            outcome = cast(
                tuple[float | int, ...],
                concordance_index_censored(events_np, durations_np, -scores_np),
            )
            # Returns (c_index, concordant, discordant, tied_risk, tied_time).
            return torch.tensor(outcome[0], dtype=torch.float32)
        except Exception:
            # Fallback if computation fails.
            return torch.tensor(0.5, dtype=torch.float32)
class BrierScore(Metric):
    """
    Integrated Brier Score for survival analysis.

    The Brier score measures the accuracy of probabilistic predictions.
    For survival analysis, we compute a time-integrated Brier score.
    Lower values indicate better predictions (0 is perfect, 1 is worst).

    Note: This is a simplified, unweighted approximation: for each sample
    it compares the model's cumulative event probability per time bin
    against the observed event indicator, masking out bins after the
    censoring time for censored samples, and averages the squared errors.
    (It does not apply inverse-probability-of-censoring weights as the
    standard IPCW integrated Brier score does.)

    Args:
        compute_on_cpu (bool): Whether to compute on CPU. Default: False.
        **kwargs: Additional keyword arguments passed to the parent Metric class.
    """

    is_differentiable: bool | None = False
    higher_is_better: bool | None = False
    full_state_update: bool | None = False

    def __init__(self, compute_on_cpu: bool = False, **kwargs: Any):
        super().__init__(**kwargs)
        self.compute_on_cpu = compute_on_cpu
        # Aggregate Brier statistics instead of storing entire tensors.
        self.add_state("brier_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("weight_sum", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(
        self, preds: torch.Tensor, target: tuple[torch.Tensor, torch.Tensor] | tuple[int, int]
    ) -> None:
        """
        Update state with predictions and targets.

        Args:
            preds (torch.Tensor): Logits for hazard at each time bin.
                Shape: [batch_size, num_bins] for discrete-time hazard model.
            target: Either:
                - tuple[torch.Tensor, torch.Tensor]: (durations, events) where
                  both are tensors of shape [batch_size]
                - tuple[int, int]: (duration, event) for single sample

        Raises:
            ValueError: If ``preds`` is not a 2-D [batch_size, num_bins] tensor.
        """
        # Handle both tensor and int inputs.
        if isinstance(target[0], int):
            durations = torch.tensor([target[0]], dtype=torch.float32)
            events = torch.tensor([target[1]], dtype=torch.float32)
        else:
            durations = target[0]
            events = target[1]
        assert isinstance(durations, torch.Tensor)
        assert isinstance(events, torch.Tensor)

        if preds.dim() != 2:
            raise ValueError(f"Unexpected prediction shape: {preds.shape}")

        hazards = torch.sigmoid(preds)
        batch_size, num_bins = hazards.shape
        device = hazards.device
        dtype = hazards.dtype

        # Survival curve including the S(0)=1 prefix, then drop the prefix so
        # survival_after_bin[:, t] = P(survive past bin t).
        survival_full = torch.cumprod(
            torch.cat(
                [torch.ones(batch_size, 1, device=device, dtype=dtype), 1.0 - hazards],
                dim=1,
            ),
            dim=1,
        )
        survival_after_bin = survival_full[:, 1:]
        # cumulative_event[:, t] = predicted P(event by end of bin t).
        cumulative_event = 1.0 - survival_after_bin

        time_positions = torch.arange(num_bins, device=device, dtype=dtype).unsqueeze(0)
        durations = durations.to(device=device, dtype=time_positions.dtype)
        events = events.to(device=device, dtype=cumulative_event.dtype)
        # Clip durations into the valid bin range.
        durations = torch.clamp(durations, min=0, max=num_bins - 1)
        duration_expanded = durations.unsqueeze(1)
        event_expanded = events.unsqueeze(1)

        # Event indicator becomes 1 once the event has occurred.
        target_event = (time_positions >= duration_expanded).float() * event_expanded
        # Only evaluate bins up to the observed time for censored samples.
        # For event samples, we know the status for all future bins (dead).
        valid_mask = (event_expanded.bool() | (time_positions <= duration_expanded)).float()

        squared_error = (cumulative_event - target_event) ** 2 * valid_mask
        batch_error = squared_error.sum()
        # clamp_min(1.0) guards the later division against an all-masked batch.
        batch_weight = valid_mask.sum().clamp_min(1.0)

        if self.compute_on_cpu:
            batch_error = batch_error.cpu()
            batch_weight = batch_weight.cpu()

        self.brier_sum += batch_error
        self.weight_sum += batch_weight

    def compute(self) -> torch.Tensor:
        """Compute the final Brier score from the accumulated error/weight sums.

        Returns:
            torch.Tensor: Scalar float32 mean squared error over all valid
            (sample, time-bin) pairs seen so far; 0.25 (the chance-level
            Brier score) if no samples have been accumulated.
        """
        if self.weight_sum <= 0:
            return torch.tensor(0.25, dtype=torch.float32)
        return (self.brier_sum / self.weight_sum).to(dtype=torch.float32)