cellmil.utils.train.metrics

Classes

BrierScore([compute_on_cpu])

Integrated Brier Score for survival analysis.

ConcordanceIndex([compute_on_cpu])

Concordance Index (C-index) for survival analysis.

class cellmil.utils.train.metrics.ConcordanceIndex(compute_on_cpu: bool = False, **kwargs: Any)

Bases: Metric

Concordance Index (C-index) for survival analysis.

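The C-index measures how well the model ranks samples by survival time. Among comparable pairs (pairs in which the sample with the shorter observed time had an event), it is the fraction in which that sample also received the higher predicted risk. A C-index of 1.0 indicates perfect concordance, while 0.5 is the expected value for random predictions.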
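As a reference point, assuming the scikit-survival computation is Harrell's estimator (scikit-survival also provides Uno's IPCW-weighted variant, and this docstring does not say which one is used), the C-index over samples with observed times T_i, event indicators δ_i (1 = event, 0 = censored) and predicted risks r_i can be written as:

```latex
C = \frac{\sum_{i \neq j} \delta_i \, \mathbb{1}[T_i < T_j] \left( \mathbb{1}[r_i > r_j] + \tfrac{1}{2}\,\mathbb{1}[r_i = r_j] \right)}
         {\sum_{i \neq j} \delta_i \, \mathbb{1}[T_i < T_j]}
```

Only pairs whose earlier member had an observed event enter the denominator, and ties in predicted risk count as half-concordant, which is why random scores give an expected value of 0.5.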

Parameters:
  • compute_on_cpu (bool) – Whether to compute on CPU. Default: False.

  • **kwargs – Additional keyword arguments passed to the parent Metric class.

is_differentiable: bool | None = False
higher_is_better: bool | None = True
full_state_update: bool | None = False
__init__(compute_on_cpu: bool = False, **kwargs: Any)

Initialize the metric. Additional keyword arguments are passed to the parent Metric class.

update(preds: Tensor, target: tuple[torch.Tensor, torch.Tensor] | tuple[int, int]) → None

Update state with predictions and targets.

Parameters:
  • preds (torch.Tensor) – Model outputs: either hazard logits for each time bin, with shape [batch_size, num_bins] (discrete-time hazard model), or one risk score per sample, with shape [batch_size].

  • target – Either:
      ◦ tuple[torch.Tensor, torch.Tensor]: (durations, events), where both are tensors of shape [batch_size]; or
      ◦ tuple[int, int]: (duration, event) for a single sample.

compute() → Tensor

Compute the final C-index using scikit-survival.
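The update/compute split follows the usual torchmetrics pattern: each update() call appends per-batch state, and compute() evaluates the metric once over everything accumulated. The sketch below mimics that lifecycle in plain Python. `SketchConcordance` is a hypothetical class, not part of cellmil. It computes a simplified Harrell's C-index directly instead of calling scikit-survival, and it omits scikit-survival's handling of tied observed times, so results can differ slightly when times are tied.

```python
from typing import Sequence, Union

Number = Union[int, float]


class SketchConcordance:
    """Hypothetical stand-in for an accumulate-then-compute C-index metric."""

    def __init__(self) -> None:
        # List states grow with every update() call, mirroring per-batch accumulation.
        self.risks: list[float] = []
        self.durations: list[float] = []
        self.events: list[int] = []

    def update(self, risks: Union[Number, Sequence[Number]], target: tuple) -> None:
        durations, events = target
        # Single-sample form: target is (duration, event) with scalar entries.
        if isinstance(durations, (int, float)):
            durations, events = [durations], [events]
        if isinstance(risks, (int, float)):
            risks = [risks]
        self.risks.extend(float(r) for r in risks)
        self.durations.extend(float(d) for d in durations)
        self.events.extend(int(e) for e in events)

    def compute(self) -> float:
        # Simplified Harrell's C: (i, j) is comparable when i had an event and T_i < T_j;
        # tied risks count as half-concordant.
        concordant = 0.0
        comparable = 0
        n = len(self.durations)
        for i in range(n):
            if not self.events[i]:
                continue  # a censored sample cannot be the earlier member of a pair
            for j in range(n):
                if self.durations[i] < self.durations[j]:
                    comparable += 1
                    if self.risks[i] > self.risks[j]:
                        concordant += 1.0
                    elif self.risks[i] == self.risks[j]:
                        concordant += 0.5
        if comparable == 0:
            raise ValueError("no comparable pairs: C-index is undefined")
        return concordant / comparable


metric = SketchConcordance()
metric.update([0.9, 0.5], ([1.0, 2.0], [1, 1]))  # a batch of two events
metric.update(0.1, (3.0, 0))                      # a single censored sample
result = metric.compute()  # 1.0: higher risk always pairs with the earlier event
```

Accumulating raw values and deferring the pairwise computation matters here because the C-index is not a per-batch average: pairs that span different batches must also be compared.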

class cellmil.utils.train.metrics.BrierScore(compute_on_cpu: bool = False, **kwargs: Any)

Bases: Metric

Integrated Brier Score for survival analysis.

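The Brier score measures the accuracy of probabilistic predictions as the mean squared difference between predicted probabilities and observed outcomes. For survival analysis, the target quantity is the time-integrated Brier score. Lower values indicate better predictions (0 is perfect, 1 is the worst possible score).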

Note: This is a simplified implementation that approximates the integrated Brier score by computing the mean squared error between predicted risk and actual outcomes at observed time points.
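To make the approximation concrete, here is a minimal sketch of a Brier score at a single time t. It is an assumption about the general technique, not this class's code: it uses a complete-case simplification (samples censored at or before t are skipped) instead of the inverse-probability-of-censoring weights that scikit-survival's `brier_score` applies. The function name `brier_score_at` is hypothetical.

```python
from typing import Sequence


def brier_score_at(
    t: float,
    durations: Sequence[float],
    events: Sequence[int],
    surv_at_t: Sequence[float],
) -> float:
    """Complete-case Brier score at time t, without IPCW censoring weights.

    surv_at_t[i] is the predicted probability that sample i is event-free at t.
    """
    total = 0.0
    count = 0
    for d, e, s in zip(durations, events, surv_at_t):
        if d > t:
            outcome = 1.0  # still event-free at t
        elif e:
            outcome = 0.0  # event observed at or before t
        else:
            continue  # censored at or before t: true status at t is unknown
        total += (outcome - s) ** 2
        count += 1
    if count == 0:
        raise ValueError("no sample has a known status at time t")
    return total / count


# Sample 0 had its event at t=1, sample 1 is still event-free at t=2,
# and sample 2 was censored at 1.5, so it is skipped.
score = brier_score_at(2.0, [1.0, 3.0, 1.5], [1, 0, 0], [0.2, 0.9, 0.5])
# (0.2**2 + 0.1**2) / 2 = 0.025, up to floating-point rounding
```

Dropping censored samples biases the estimate when censoring is heavy, which is the problem IPCW weighting is designed to correct.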

Parameters:
  • compute_on_cpu (bool) – Whether to compute on CPU. Default: False.

  • **kwargs – Additional keyword arguments passed to the parent Metric class.

is_differentiable: bool | None = False
higher_is_better: bool | None = False
full_state_update: bool | None = False
__init__(compute_on_cpu: bool = False, **kwargs: Any)

Initialize the metric. Additional keyword arguments are passed to the parent Metric class.

update(preds: Tensor, target: tuple[torch.Tensor, torch.Tensor] | tuple[int, int]) → None

Update state with predictions and targets.

Parameters:
  • preds (torch.Tensor) – Model outputs: either hazard logits for each time bin, with shape [batch_size, num_bins] (discrete-time hazard model), or one risk score per sample, with shape [batch_size].

  • target – Either:
      ◦ tuple[torch.Tensor, torch.Tensor]: (durations, events), where both are tensors of shape [batch_size]; or
      ◦ tuple[int, int]: (duration, event) for a single sample.
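A Brier score needs a predicted survival probability at a given time, not just a ranking. With discrete-time hazard logits, one standard way to obtain it, assumed here rather than confirmed by the source, is to chain per-bin hazards into a survival curve S_k = prod_{j<=k} (1 - sigmoid(logit_j)) and read it at the bin containing t. The `bin_edges` argument (ascending upper edges, one per bin) and both function names are hypothetical.

```python
import bisect
import math
from typing import Sequence


def _sigmoid(x: float) -> float:
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def survival_curve(hazard_logits: Sequence[float]) -> list[float]:
    """Return S_k, the probability of surviving past bin k, for one sample."""
    surv: list[float] = []
    s = 1.0
    for logit in hazard_logits:
        hazard = _sigmoid(logit)  # P(event in bin k | survived to bin k)
        s *= 1.0 - hazard
        surv.append(s)
    return surv


def survival_at(t: float, hazard_logits: Sequence[float], bin_edges: Sequence[float]) -> float:
    """Predicted P(T > t), read from the last bin whose upper edge is <= t."""
    surv = survival_curve(hazard_logits)
    k = bisect.bisect_right(bin_edges, t)  # number of bins fully elapsed by time t
    return 1.0 if k == 0 else surv[k - 1]


# With zero logits every bin has hazard 0.5, so survival halves per elapsed bin.
curve = survival_curve([0.0, 0.0])  # [0.5, 0.25]
```

Reading the curve as a step function keeps the prediction constant within a bin, which matches the discrete-time model's resolution.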

compute() → Tensor

Compute the final Brier score using scikit-survival.
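For reference, the integrated Brier score is usually defined over a time grid t_1 < t_2 < ... < t_m as the time average of BS(t), with the integral evaluated by the trapezoid rule (as in scikit-survival's `integrated_brier_score`). The uncensored form of BS(t) is shown below, where Ŝ_i(t) is the predicted survival probability of sample i at time t. scikit-survival additionally weights each term by the inverse probability of censoring, and how closely this class's simplified computation follows either form is not specified.

```latex
\mathrm{BS}(t) = \frac{1}{n} \sum_{i=1}^{n} \left( \mathbb{1}[T_i > t] - \hat{S}_i(t) \right)^2

\mathrm{IBS} = \frac{1}{t_m - t_1} \int_{t_1}^{t_m} \mathrm{BS}(t)\,dt
\approx \frac{1}{t_m - t_1} \sum_{k=1}^{m-1} \frac{\mathrm{BS}(t_k) + \mathrm{BS}(t_{k+1})}{2}\,(t_{k+1} - t_k)
```

Since each squared term lies in [0, 1], both BS(t) and the time-averaged IBS stay in [0, 1], consistent with the range stated above.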