cellmil.utils.train.losses
Classes
- FocalLoss
- NegativeLogLikelihoodSurvLoss: Negative Log-Likelihood Loss for Discrete-Time Survival Analysis using Logistic Hazard.
- class cellmil.utils.train.losses.FocalLoss(alpha: Optional[Union[float, List[float], Tensor]] = None, gamma: float = 2.0, label_smoothing: float = 0.0)
Bases: Module
- __init__(alpha: Optional[Union[float, List[float], Tensor]] = None, gamma: float = 2.0, label_smoothing: float = 0.0)
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(inputs: Tensor, targets: Tensor) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance rather than this method, since the former takes care of running the registered hooks while the latter silently ignores them.
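FocalLoss is documented here only by its signature. As a hedged illustration of what alpha, gamma and label_smoothing usually control, the framework-free sketch below computes the standard focal loss FL = -alpha_t (1 - p_t)^gamma log p_t on top of a label-smoothed cross-entropy. It is not cellmil's implementation; the function name focal_loss is hypothetical, and it assumes that inputs are raw logits of shape [batch_size, n_classes], targets are integer class indices, the result is averaged over the batch, a list alpha is indexed by the target class, and label smoothing follows the PyTorch cross-entropy convention.

```python
import math
from typing import List, Optional, Sequence, Union


def focal_loss(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    alpha: Optional[Union[float, List[float]]] = None,
    gamma: float = 2.0,
    label_smoothing: float = 0.0,
) -> float:
    """Mean focal loss over a batch (hypothetical sketch, not cellmil's code)."""
    losses = []
    for row, target in zip(logits, targets):
        n = len(row)
        # Numerically stable log-softmax over the class logits.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        log_p = [x - log_z for x in row]
        # Cross-entropy against a smoothed target: (1 - eps) one-hot + eps / n.
        smooth = label_smoothing / n
        ce = -sum(
            (smooth + (1.0 - label_smoothing if c == target else 0.0)) * log_p[c]
            for c in range(n)
        )
        # p_t = exp(-CE); equals the true-class probability when label_smoothing == 0.
        p_t = math.exp(-ce)
        # Modulating factor (1 - p_t)^gamma down-weights well-classified samples.
        loss = (1.0 - p_t) ** gamma * ce
        # Assumed class weighting: a scalar scales every sample,
        # a list is indexed by the sample's target class.
        if alpha is not None:
            loss *= alpha[target] if isinstance(alpha, list) else alpha
        losses.append(loss)
    return sum(losses) / len(losses)


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]
targets = [0, 2]
print(focal_loss(logits, targets, gamma=2.0))
```

With gamma=0 and alpha=None the sketch reduces to mean cross-entropy, which makes a convenient sanity check when comparing against the real module.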
- class cellmil.utils.train.losses.NegativeLogLikelihoodSurvLoss(alpha: Optional[float] = None, epsilon: float = 1e-08, reduction: Literal['sum', 'mean'] = 'sum')
Bases: Module
Negative Log-Likelihood Loss for Discrete-Time Survival Analysis using Logistic Hazard.
This loss function is designed for survival analysis tasks where the model predicts the hazard probabilities for discrete time intervals. It computes the negative log-likelihood based on the predicted hazards and the observed survival times and event indicators.
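For reference, the standard discrete-time logistic-hazard negative log-likelihood can be written as below. This is the textbook form, not a statement of this module's exact code; it assumes $h_{i,j}$ is the predicted hazard of sample $i$ in bin $j$, $k_i$ is its observed duration bin, $\delta_i \in \{0, 1\}$ is its event indicator, and a censored sample is taken to have survived its observed bin. The alpha weighting is not covered. With reduction='mean' the sum would be divided by the batch size $B$.

$$
\mathcal{L} = -\sum_{i=1}^{B}\Big[\delta_i \log h_{i,k_i} + (1-\delta_i)\log\big(1-h_{i,k_i}\big) + \sum_{j<k_i}\log\big(1-h_{i,j}\big)\Big]
$$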
- Parameters:
- __init__(alpha: Optional[float] = None, epsilon: float = 1e-08, reduction: Literal['sum', 'mean'] = 'sum')
Initialize internal Module state, shared by both nn.Module and ScriptModule.
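The computation behind this loss can be sketched without a framework as follows. This is a minimal sketch under stated assumptions, not cellmil's code: the function name logistic_hazard_nll is hypothetical, durations are assumed to be 0-based bin indices, a censored sample is treated as having survived its observed bin, epsilon is assumed to act as a lower clamp before each logarithm, and alpha is left out because its role is not documented on this page. Both target forms accepted by __call__ (a batch of durations and events, or a single (duration, event) pair) are handled.

```python
import math
from typing import List, Literal, Sequence, Tuple, Union

BatchTarget = Tuple[Sequence[int], Sequence[int]]


def logistic_hazard_nll(
    hazards: Sequence[Sequence[float]],
    target: Union[BatchTarget, Tuple[int, int]],
    epsilon: float = 1e-8,
    reduction: Literal["sum", "mean"] = "sum",
) -> float:
    """Discrete-time NLL from hazard probabilities (hypothetical sketch)."""
    durations, events = target
    # Assumed behaviour: a single (duration, event) pair becomes a batch of one.
    if isinstance(durations, int):
        durations, events = [durations], [events]
    losses: List[float] = []
    for h, k, d in zip(hazards, durations, events):
        # Survive every bin strictly before the observed bin k.
        ll = sum(math.log(max(1.0 - h[j], epsilon)) for j in range(k))
        if d:
            # Event observed in bin k: contributes the hazard h_k.
            ll += math.log(max(h[k], epsilon))
        else:
            # Censored in bin k: assumed to have survived bin k as well.
            ll += math.log(max(1.0 - h[k], epsilon))
        losses.append(-ll)
    total = sum(losses)
    return total if reduction == "sum" else total / len(losses)


hazards = [[0.1, 0.2, 0.5], [0.3, 0.3, 0.3]]
print(logistic_hazard_nll(hazards, ([1, 2], [1, 0])))
```

Clamping with epsilon keeps the logarithm finite when a predicted hazard saturates at 0 or 1; the default reduction='sum' mirrors the signature above.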
- __call__(inputs: Tensor, target: tuple[torch.Tensor, torch.Tensor] | tuple[int, int]) → Tensor
Compute the Negative Log-Likelihood Loss for Discrete-Time Survival Analysis.
- Parameters:
inputs (torch.Tensor) – Predicted hazard probabilities for each time bin. Shape: [batch_size, n_bins]
target – Either:
  - tuple[torch.Tensor, torch.Tensor]: (durations, events), where both are tensors of shape [batch_size]; or
  - tuple[int, int]: (duration, event) for a single sample (converted to tensors internally).
- Returns:
Computed loss value.
- Return type: