# Source code for cellmil.datamodels.transforms.time_discretizer

"""
Survival discretization transform for converting continuous survival times to discrete bins.
"""
import json
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Tuple, Any, cast, Union, Optional
from .base_label_transform import FittableLabelTransform


class TimeDiscretizerTransform(FittableLabelTransform):
    """
    Transform for discretizing continuous survival times into bins.

    This transform bins survival times using quantiles computed on uncensored
    patients only. It's designed for discrete-time survival analysis where we
    convert the regression problem into a classification problem.

    Args:
        n_bins (int): Number of bins to create (e.g., 4 for quartiles)
        eps (float): Small epsilon value to adjust bin boundaries
    """

    def __init__(self, n_bins: int = 4, eps: float = 1e-8):
        super().__init__(name="survival_discretizer")
        self.n_bins = n_bins
        self.eps = eps
        # Bin edges; None until fit() or from_config() populates them.
        self.bins: Optional[np.ndarray[Any, Any]] = None

    def fit(
        self, labels: Dict[str, Union[int, Tuple[float, int]]], **kwargs: Any
    ) -> "TimeDiscretizerTransform":
        """
        Fit the discretizer on survival data.

        Quantile edges are computed from the durations of uncensored samples
        (``event == 1``). The first and last edges are then moved to
        ``min(all durations) - eps`` and ``max(all durations) + eps`` so that
        every duration seen here, censored or not, lies inside the edges.

        Args:
            labels: Dictionary mapping slide IDs to (duration, event) tuples
            **kwargs: Additional keyword arguments (not used, for API compatibility)

        Returns:
            Self for method chaining

        Raises:
            TypeError: If a label cannot be unpacked into (duration, event)
                because it is not iterable (e.g. a plain int label).
            ValueError: If there are fewer uncensored samples than ``n_bins``.
        """
        all_durations: List[float] = []
        uncensored_durations: List[float] = []

        for slide_id, label in labels.items():
            try:
                duration, event = label  # type: ignore
            except TypeError as err:
                # Same exception type as the bare unpacking failure, but the
                # message now says which slide and what kind of label it was.
                raise TypeError(
                    f"Expected (duration, event) label for slide {slide_id!r}, "
                    f"got {type(label)}"
                ) from err

            all_durations.append(duration)  # type: ignore
            if event == 1:  # Uncensored (event occurred)
                uncensored_durations.append(duration)  # type: ignore

        if len(uncensored_durations) < self.n_bins:
            raise ValueError(
                f"Not enough uncensored samples ({len(uncensored_durations)}) "
                f"to create {self.n_bins} bins"
            )

        # Compute quantile bins on uncensored data
        _, bins = pd.qcut(  # type: ignore
            uncensored_durations, q=self.n_bins, retbins=True, labels=False
        )

        # Adjust bin boundaries to ensure all data is included
        bins = np.array(bins)
        bins[-1] = max(all_durations) + self.eps
        bins[0] = min(all_durations) - self.eps

        self.bins = bins
        self.is_fitted = True
        return self

    def _transform_labels_impl(
        self, labels: Dict[str, Union[int, Tuple[float, int]]]
    ) -> Dict[str, Union[int, Tuple[float, int]]]:
        """
        Implementation of the label transform operation.

        Bins are left-inclusive: a duration ``d`` is assigned to bin ``i`` when
        ``bins[i] <= d < bins[i + 1]``. Durations below the first edge are
        assigned to bin 0 and durations at or above the last edge to bin
        ``n_bins - 1``.

        Args:
            labels: Dictionary mapping slide IDs to (duration, event) tuples

        Returns:
            Dictionary mapping slide IDs to (bin_index, event) tuples

        Raises:
            RuntimeError: If the transform has not been fitted.
            ValueError: If a label is not a tuple.
        """
        if self.bins is None:
            raise RuntimeError("Transform must be fitted before use")

        slide_ids: List[str] = []
        durations: List[float] = []
        events: List[int] = []
        for slide_id, label in labels.items():
            if not isinstance(label, tuple):
                raise ValueError(
                    f"Expected tuple label for survival analysis, got {type(label)}"
                )
            duration, event = label
            slide_ids.append(slide_id)
            durations.append(duration)
            events.append(int(event))

        # One vectorised lookup instead of a pd.cut call per sample.
        # side="right" yields the left-inclusive bin index; the clip folds
        # out-of-range durations into the first / last bin.
        # Relies on self.bins being sorted ascending, as fit() produces them.
        raw_indices = (
            np.searchsorted(self.bins, np.asarray(durations, dtype=float), side="right")
            - 1
        )
        bin_indices = np.clip(raw_indices, 0, self.n_bins - 1)

        discretized_labels: Dict[str, Tuple[int, int]] = {
            slide_id: (int(bin_index), event)
            for slide_id, bin_index, event in zip(slide_ids, bin_indices, events)
        }
        return discretized_labels  # type: ignore

    def get_config(self) -> Dict[str, Any]:
        """Get configuration for saving.

        Returns:
            Dictionary with ``name``, ``n_bins`` and ``eps``, plus ``bins``
            (as a plain list) when bin edges are set.
        """
        config: dict[str, Any] = {
            "name": self.name,
            "n_bins": self.n_bins,
            "eps": self.eps,
        }
        if self.bins is not None:
            config["bins"] = self.bins.tolist()
        return config

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "TimeDiscretizerTransform":
        """Create instance from configuration.

        Args:
            config: Dictionary with ``n_bins`` and ``eps``; if it also contains
                ``bins``, the edges are restored and the instance is marked
                as fitted.

        Returns:
            The reconstructed transform.
        """
        transform = cls(n_bins=config["n_bins"], eps=config["eps"])
        if "bins" in config:
            transform.bins = np.array(config["bins"])
            transform.is_fitted = True
        return transform

    def save(self, path: Path) -> None:
        """Save transform to disk.

        Writes ``get_config()`` plus a ``transform_class`` entry holding the
        class name as indented JSON.

        Args:
            path: Destination file path.
        """
        config = self.get_config()
        config["transform_class"] = self.__class__.__name__
        with open(path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

    @classmethod
    def load(cls, path: Path) -> "TimeDiscretizerTransform":
        """Load transform from disk.

        Args:
            path: JSON file previously written by ``save``.

        Returns:
            The transform rebuilt via ``from_config``.
        """
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
        # Bookkeeping entry added by save(); not a constructor argument.
        config.pop("transform_class", None)
        return cls.from_config(config)