cellmil.utils.train.metrics
Classes
- BrierScore – Integrated Brier Score for survival analysis.
- ConcordanceIndex – Concordance Index (C-index) for survival analysis.
- class cellmil.utils.train.metrics.ConcordanceIndex(compute_on_cpu: bool = False, **kwargs: Any)
Bases: Metric
Concordance Index (C-index) for survival analysis.
The C-index measures the model’s ability to correctly order pairs of samples by their survival times. A C-index of 1.0 indicates perfect concordance, while 0.5 indicates random predictions.
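The pairwise definition above can be sketched in plain Python as Harrell's C-index. This is a minimal sketch assuming the common conventions that a higher risk score means an earlier expected event, that only pairs anchored by an observed event are comparable, and that tied risks count as half-concordant. `harrell_c_index` is a hypothetical helper, not the `ConcordanceIndex` implementation, which may treat ties and censoring differently.

```python
from typing import Sequence


def harrell_c_index(
    risks: Sequence[float],
    durations: Sequence[float],
    events: Sequence[int],
) -> float:
    """Fraction of comparable pairs whose risks are ordered correctly.

    A pair (i, j) is comparable when sample i had an observed event
    (events[i] == 1) strictly before sample j's time. It is concordant
    when the earlier-failing sample i has the higher risk; tied risks
    count as half-concordant.
    """
    concordant = 0.0
    comparable = 0
    n = len(risks)
    for i in range(n):
        if not events[i]:
            continue  # a censored sample cannot anchor a comparable pair
        for j in range(n):
            if durations[i] < durations[j]:
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


durations = [1.0, 2.0, 3.0, 4.0]
events = [1, 1, 0, 1]
print(harrell_c_index([4.0, 3.0, 2.0, 1.0], durations, events))  # 1.0
print(harrell_c_index([1.0, 2.0, 3.0, 4.0], durations, events))  # 0.0
print(harrell_c_index([1.0, 1.0, 1.0, 1.0], durations, events))  # 0.5
```

Perfectly ordered risks give 1.0, reversed risks give 0.0, and a constant prediction gives 0.5, matching the interpretation of the C-index given above.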
- Parameters:
compute_on_cpu (bool) – Whether to compute on CPU. Default: False.
**kwargs – Additional keyword arguments passed to the parent Metric class.
- __init__(compute_on_cpu: bool = False, **kwargs: Any)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- update(preds: Tensor, target: tuple[torch.Tensor, torch.Tensor] | tuple[int, int]) → None
Update state with predictions and targets.
- Parameters:
preds (torch.Tensor) – Logits for the hazard at each time bin. Shape: [batch_size, num_bins] for a discrete-time hazard model, or [batch_size] for a single risk score per sample.
target – Either:
  - tuple[torch.Tensor, torch.Tensor]: (durations, events), where both are tensors of shape [batch_size]; or
  - tuple[int, int]: (duration, event) for a single sample.
- class cellmil.utils.train.metrics.BrierScore(compute_on_cpu: bool = False, **kwargs: Any)
Bases: Metric
Integrated Brier Score for survival analysis.
The Brier score measures the accuracy of probabilistic predictions. For survival analysis, we compute the time-integrated Brier score. Lower values indicate better predictions (0 is perfect, 1 is worst).
Note: This is a simplified implementation that approximates the integrated Brier score by computing the mean squared error between predicted risk and actual outcomes at observed time points.
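The simplification described in the note can be sketched as a plain mean squared error. This assumes the predicted quantity is a per-sample event probability in [0, 1] and the outcome is the event indicator; `simplified_brier` is a hypothetical helper, and the exact time points and risk transform used by `BrierScore` are not specified here. A full integrated Brier score would additionally weight each term by the inverse probability of censoring and average over a grid of time points.

```python
from typing import Sequence


def simplified_brier(event_probs: Sequence[float], events: Sequence[int]) -> float:
    """Mean squared error between predicted event probability and outcome.

    Ignores censoring weights and integration over time, which is what
    makes it only an approximation of the integrated Brier score.
    """
    if len(event_probs) != len(events):
        raise ValueError("event_probs and events must have the same length")
    if not events:
        raise ValueError("need at least one sample")
    return sum((p - y) ** 2 for p, y in zip(event_probs, events)) / len(events)


print(simplified_brier([0.9, 0.2, 0.6], [1, 0, 1]))  # approximately 0.07
print(simplified_brier([1.0, 0.0], [1, 0]))          # 0.0
```

With probabilities bounded in [0, 1], the score stays in [0, 1], matching the range stated above.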
- Parameters:
compute_on_cpu (bool) – Whether to compute on CPU. Default: False.
**kwargs – Additional keyword arguments passed to the parent Metric class.
- __init__(compute_on_cpu: bool = False, **kwargs: Any)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- update(preds: Tensor, target: tuple[torch.Tensor, torch.Tensor] | tuple[int, int]) → None
Update state with predictions and targets.
- Parameters:
preds (torch.Tensor) – Logits for the hazard at each time bin. Shape: [batch_size, num_bins] for a discrete-time hazard model, or [batch_size] for a single risk score per sample.
target – Either:
  - tuple[torch.Tensor, torch.Tensor]: (durations, events), where both are tensors of shape [batch_size]; or
  - tuple[int, int]: (duration, event) for a single sample.
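Both classes accept the two target forms listed above. The sketch below, a hypothetical `SurvivalStateSketch` that uses Python lists in place of tensors, shows one way an `update` method can normalize them and accumulate state across batches, in the spirit of a torchmetrics-style `Metric`; it is not the cellmil implementation. Accumulating raw values rather than averaging per-batch scores matters for pairwise metrics such as the C-index, because comparable pairs can span batches.

```python
from typing import Sequence, Union


class SurvivalStateSketch:
    """Hypothetical accumulator mirroring the two documented target forms.

    Python lists stand in for tensors so the sketch runs without PyTorch.
    """

    def __init__(self) -> None:
        self.risks: list[float] = []
        self.durations: list[float] = []
        self.events: list[int] = []

    def update(self, preds: Union[float, Sequence[float]], target: tuple) -> None:
        durations, events = target
        if isinstance(durations, (int, float)):
            # Single-sample form: (duration, event) given as scalars.
            self.risks.append(float(preds))
            self.durations.append(float(durations))
            self.events.append(int(events))
        else:
            # Batch form: (durations, events), each of length batch_size.
            if not (len(preds) == len(durations) == len(events)):
                raise ValueError("preds, durations and events must align")
            self.risks.extend(float(p) for p in preds)
            self.durations.extend(float(d) for d in durations)
            self.events.extend(int(e) for e in events)

    def reset(self) -> None:
        # Clear accumulated state, e.g. at the start of a new epoch.
        self.risks.clear()
        self.durations.clear()
        self.events.clear()


state = SurvivalStateSketch()
state.update([0.8, 0.1], ([2.0, 5.0], [1, 0]))  # batch of two samples
state.update(0.5, (3, 1))                       # single sample
print(state.durations)  # [2.0, 5.0, 3.0]
print(state.events)     # [1, 0, 1]
```

The metric value would then be computed once over the full accumulated state, for example at the end of an epoch.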