cellmil.utils.train.losses

Classes

FocalLoss([alpha, gamma, label_smoothing])

NegativeLogLikelihoodSurvLoss([alpha, ...])

Negative Log-Likelihood Loss for Discrete-Time Survival Analysis using Logistic Hazard.

class cellmil.utils.train.losses.FocalLoss(alpha: Optional[Union[float, List[float], Tensor]] = None, gamma: float = 2.0, label_smoothing: float = 0.0)

Bases: Module

__init__(alpha: Optional[Union[float, List[float], Tensor]] = None, gamma: float = 2.0, label_smoothing: float = 0.0)

Initialize the focal loss with the given alpha, gamma, and label_smoothing.

forward(inputs: Tensor, targets: Tensor) → Tensor

Compute the focal loss between inputs and targets.

Note

Although the forward pass is defined in this method, call the module instance itself rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() silently skips them.
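FocalLoss carries no docstring of its own, so the sketch below shows the standard focal loss of Lin et al. (2017), FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), as an assumption about what forward computes, not as this class's actual implementation. It is a minimal pure-Python version for a single multi-class sample given raw logits. The helper name focal_loss_single is hypothetical, label_smoothing is deliberately not modeled, and treating a list alpha as per-class weights and a float alpha as a uniform scale is an assumed convention.

```python
import math
from typing import List, Optional, Sequence, Union


def focal_loss_single(
    logits: Sequence[float],
    target: int,
    alpha: Optional[Union[float, List[float]]] = None,
    gamma: float = 2.0,
) -> float:
    """Focal loss -alpha_t * (1 - p_t)**gamma * log(p_t) for one sample."""
    # Numerically stable log-softmax: subtract the largest logit first.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    log_p_t = logits[target] - log_z  # log-probability of the true class
    p_t = math.exp(log_p_t)

    # Assumption: a list gives per-class weights, a float scales every class.
    if alpha is None:
        alpha_t = 1.0
    elif isinstance(alpha, (int, float)):
        alpha_t = float(alpha)
    else:
        alpha_t = alpha[target]

    # (1 - p_t)**gamma down-weights samples the model already classifies well.
    return -alpha_t * (1.0 - p_t) ** gamma * log_p_t


# With gamma = 0 and no alpha, the focal loss reduces to cross-entropy.
ce = focal_loss_single([2.0, 0.5, -1.0], target=0, gamma=0.0)
fl = focal_loss_single([2.0, 0.5, -1.0], target=0, gamma=2.0)
print(f"cross-entropy={ce:.4f}  focal={fl:.4f}")
```

Setting gamma to 0 recovers ordinary cross-entropy, which makes it easy to see how much the (1 - p_t)^gamma factor shrinks the loss on confidently correct predictions.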

class cellmil.utils.train.losses.NegativeLogLikelihoodSurvLoss(alpha: Optional[float] = None, epsilon: float = 1e-08, reduction: Literal['sum', 'mean'] = 'sum')

Bases: Module

Negative Log-Likelihood Loss for Discrete-Time Survival Analysis using Logistic Hazard.

This loss function is designed for survival analysis tasks where the model predicts the hazard probabilities for discrete time intervals. It computes the negative log-likelihood based on the predicted hazards and the observed survival times and event indicators.
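For reference, the standard logistic-hazard negative log-likelihood for one sample can be written as below. This is a sketch of the usual formulation, not a transcription of this class's code. The notation is assumed: h_{ij} is inputs[i, j], \kappa_i is the sample's duration bin (0-based), d_i \in \{0, 1\} is its event indicator, and epsilon is applied as a lower clamp inside each logarithm. A censored sample is assumed to have survived its own bin \kappa_i; how alpha reweights the terms is not documented here and is therefore left out.

```latex
\ell_i = -\Big[\, d_i \log \max\!\big(h_{i\kappa_i}, \epsilon\big)
       + (1 - d_i) \log \max\!\big(1 - h_{i\kappa_i}, \epsilon\big)
       + \sum_{j=0}^{\kappa_i - 1} \log \max\!\big(1 - h_{ij}, \epsilon\big) \Big],
\qquad
\mathcal{L}_{\text{sum}} = \sum_{i=1}^{N} \ell_i,
\qquad
\mathcal{L}_{\text{mean}} = \frac{1}{N} \sum_{i=1}^{N} \ell_i
```

The sum over j < \kappa_i is the log-probability of surviving every bin before the observed one; the bracketed d_i term then scores either the event in bin \kappa_i or survival through it.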

Parameters:
  • alpha (Optional[float]) – Weighting factor for the loss. Default is None.

  • epsilon (float) – Small value to avoid log(0). Default is 1e-8.

  • reduction (str) – Specifies the reduction to apply to the output: ‘mean’ | ‘sum’. Default is ‘sum’.
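To illustrate how epsilon and reduction enter the computation, here is a minimal pure-Python sketch of the standard logistic-hazard NLL over a batch, with plain lists standing in for tensors. It assumes that epsilon is a lower clamp inside each log and that a censored sample survived its own duration bin. The function name logistic_hazard_nll is hypothetical, and alpha is omitted because its exact weighting is not documented.

```python
import math
from typing import List, Literal, Sequence


def logistic_hazard_nll(
    hazards: Sequence[Sequence[float]],
    durations: Sequence[int],
    events: Sequence[int],
    epsilon: float = 1e-8,
    reduction: Literal["sum", "mean"] = "sum",
) -> float:
    """Discrete-time survival NLL from per-bin hazard probabilities."""
    if reduction not in ("sum", "mean"):
        raise ValueError(f"unsupported reduction: {reduction!r}")
    losses: List[float] = []
    for h, k, d in zip(hazards, durations, events):
        # Log-probability of surviving every bin before bin k.
        log_lik = sum(math.log(max(1.0 - h[j], epsilon)) for j in range(k))
        if d:
            # Event observed in bin k.
            log_lik += math.log(max(h[k], epsilon))
        else:
            # Censored in bin k: assumed to have survived bin k as well.
            log_lik += math.log(max(1.0 - h[k], epsilon))
        losses.append(-log_lik)
    total = sum(losses)
    return total / len(losses) if reduction == "mean" else total


hazards = [[0.5, 0.5], [0.5, 0.5]]
print(logistic_hazard_nll(hazards, durations=[1, 0], events=[1, 0]))
print(logistic_hazard_nll(hazards, [1, 0], [1, 0], reduction="mean"))
```

Because the default reduction is 'sum', the loss magnitude grows with batch size. 'mean' keeps it comparable across batches of different sizes.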

__init__(alpha: Optional[float] = None, epsilon: float = 1e-08, reduction: Literal['sum', 'mean'] = 'sum')

Initialize the loss with the given alpha, epsilon, and reduction.

__call__(inputs: Tensor, target: tuple[torch.Tensor, torch.Tensor] | tuple[int, int]) → Tensor

Compute the Negative Log-Likelihood Loss for Discrete-Time Survival Analysis.

Parameters:
  • inputs (torch.Tensor) – Predicted hazard probabilities for each time bin. Shape: [batch_size, n_bins]

  • target – Either:
      - tuple[torch.Tensor, torch.Tensor]: (durations, events), where both are tensors of shape [batch_size]; or
      - tuple[int, int]: (duration, event) for a single sample (converted to tensors).
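The two accepted target forms can be pictured with a small normalization step that turns a single (duration, event) pair into a batch of size 1. This is a pure-Python sketch with lists standing in for tensors; the helper name normalize_target and the length check are hypothetical illustrations, not the class's actual internals.

```python
from typing import List, Sequence, Tuple, Union

# Either a batch (durations, events) or a single (duration, event) pair.
Target = Union[Tuple[Sequence[int], Sequence[int]], Tuple[int, int]]


def normalize_target(target: Target) -> Tuple[List[int], List[int]]:
    """Return (durations, events) as equal-length lists."""
    durations, events = target
    if isinstance(durations, int) and isinstance(events, int):
        # Single sample: wrap each scalar into a batch of size 1.
        return [durations], [events]
    durations_list, events_list = list(durations), list(events)
    if len(durations_list) != len(events_list):
        raise ValueError("durations and events must have the same length")
    return durations_list, events_list


print(normalize_target((3, 1)))
print(normalize_target(([1, 2], [0, 1])))
```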

Returns:

Computed loss value.

Return type:

torch.Tensor