cellmil.datamodels.transforms

Feature and label transforms for preprocessing.

class cellmil.datamodels.transforms.Transform(name: str)

Bases: ABC

Base class for all feature transforms.

__init__(name: str)

Initialize the transform.

Parameters:

name – Name of the transform for identification

abstract transform(features: Tensor) → Tensor

Apply the transform to features.

Parameters:

features – Input features tensor of shape (n_instances, n_features)

Returns:

Transformed features tensor

abstract get_config() → Dict[str, Any]

Get the configuration dictionary for this transform.

Returns:

Dictionary containing transform configuration

abstract classmethod from_config(config: Dict[str, Any]) → Transform

Create transform instance from configuration dictionary.

Parameters:

config – Configuration dictionary

Returns:

Transform instance

save(path: Path) → None

Save the transform to disk.

Parameters:

path – Path to save the transform

classmethod load(path: Path) → Transform

Load transform from disk.

Parameters:

path – Path to load the transform from

Returns:

Transform instance
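The abstract interface above can be illustrated with a dependency-free sketch. It mirrors the documented method names (transform, get_config, from_config, save, load) on a stand-in base class and uses plain lists of rows instead of torch Tensors. TransformSketch and ScaleTransform are hypothetical. The sketch also assumes that save writes the get_config() dictionary as JSON and that the config carries a "name" key. The real on-disk format is not described in this reference.

```python
import json
import tempfile
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List

Rows = List[List[float]]  # stand-in for a (n_instances, n_features) Tensor


class TransformSketch(ABC):
    """Hypothetical stand-in for the Transform base class."""

    def __init__(self, name: str) -> None:
        self.name = name

    @abstractmethod
    def transform(self, features: Rows) -> Rows: ...

    @abstractmethod
    def get_config(self) -> Dict[str, Any]: ...

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict[str, Any]) -> "TransformSketch": ...

    def save(self, path: Path) -> None:
        # Assumption: persist the configuration dictionary as JSON.
        path.write_text(json.dumps(self.get_config()))

    @classmethod
    def load(cls, path: Path) -> "TransformSketch":
        # Rebuild the instance from the saved configuration.
        return cls.from_config(json.loads(path.read_text()))


class ScaleTransform(TransformSketch):
    """Hypothetical stateless transform: multiply every value by a constant."""

    def __init__(self, factor: float) -> None:
        super().__init__(name="scale")
        self.factor = factor

    def transform(self, features: Rows) -> Rows:
        return [[v * self.factor for v in row] for row in features]

    def get_config(self) -> Dict[str, Any]:
        return {"name": self.name, "factor": self.factor}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "ScaleTransform":
        return cls(factor=config["factor"])


# Round trip: save to a temporary file, load it back, then apply it.
with tempfile.TemporaryDirectory() as tmp:
    cfg_path = Path(tmp) / "scale.json"
    ScaleTransform(2.0).save(cfg_path)
    restored = ScaleTransform.load(cfg_path)

print(restored.transform([[1.0, 2.0], [3.0, 4.0]]))  # [[2.0, 4.0], [6.0, 8.0]]
```

Keeping the configuration separate from the object means a transform can be reconstructed from a plain dictionary alone. That is what makes the from_config/load pair possible.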

class cellmil.datamodels.transforms.FittableTransform(name: str)

Bases: Transform

Base class for transforms that need to be fitted on training data.

__init__(name: str)

Initialize the fittable transform.

Parameters:

name – Name of the transform for identification

abstract fit(features: Tensor, feature_names: Optional[List[str]] = None) → FittableTransform

Fit the transform on training data.

Parameters:
  • features – Training features tensor of shape (n_instances, n_features)

  • feature_names – Optional list of feature names for logging

Returns:

Self for method chaining

fit_transform(features: Tensor, feature_names: Optional[List[str]] = None) → Tensor

Fit the transform and apply it to the features.

Parameters:
  • features – Features to fit and transform

  • feature_names – Optional list of feature names for logging

Returns:

Transformed features

transform(features: Tensor) → Tensor

Apply the transform to features.

Parameters:

features – Input features tensor

Returns:

Transformed features tensor

Raises:

RuntimeError – If transform hasn’t been fitted yet

abstract _transform_impl(features: Tensor) → Tensor

Implementation of the transform operation.

Parameters:

features – Input features tensor

Returns:

Transformed features tensor
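The split between transform and _transform_impl follows the template-method pattern. The public transform checks that fit has run and only then delegates to the subclass hook. The sketch below shows that contract on plain lists instead of Tensors. The is_fitted flag, FittableSketch and MeanCenter are hypothetical names used only for illustration.

```python
from abc import ABC, abstractmethod
from typing import List, Optional

Rows = List[List[float]]  # stand-in for a (n_instances, n_features) Tensor


class FittableSketch(ABC):
    def __init__(self, name: str) -> None:
        self.name = name
        self.is_fitted = False  # hypothetical flag tracking fit state

    @abstractmethod
    def fit(self, features: Rows, feature_names: Optional[List[str]] = None) -> "FittableSketch": ...

    @abstractmethod
    def _transform_impl(self, features: Rows) -> Rows: ...

    def transform(self, features: Rows) -> Rows:
        # Guard: refuse to transform before fit() has learned its statistics.
        if not self.is_fitted:
            raise RuntimeError(f"Transform '{self.name}' has not been fitted yet")
        return self._transform_impl(features)

    def fit_transform(self, features: Rows, feature_names: Optional[List[str]] = None) -> Rows:
        return self.fit(features, feature_names).transform(features)


class MeanCenter(FittableSketch):
    """Hypothetical fittable transform: subtract per-feature training means."""

    def __init__(self) -> None:
        super().__init__(name="mean_center")
        self.means: List[float] = []

    def fit(self, features: Rows, feature_names: Optional[List[str]] = None) -> "MeanCenter":
        n = len(features)
        self.means = [sum(col) / n for col in zip(*features)]
        self.is_fitted = True
        return self  # enables method chaining

    def _transform_impl(self, features: Rows) -> Rows:
        return [[v - m for v, m in zip(row, self.means)] for row in features]


centerer = MeanCenter()
try:
    centerer.transform([[1.0, 2.0]])
except RuntimeError as err:
    print(err)  # Transform 'mean_center' has not been fitted yet

print(centerer.fit_transform([[1.0, 10.0], [3.0, 20.0]]))  # [[-1.0, -5.0], [1.0, 5.0]]
```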

class cellmil.datamodels.transforms.TransformPipeline(transforms: List[Transform])

Bases: object

Pipeline for chaining multiple feature transforms.

Supports both fittable and non-fittable transforms, automatic fitting on training data, and serialization for reuse on new samples.

__init__(transforms: List[Transform])

Initialize the transform pipeline.

Parameters:

transforms – List of transforms to apply in order

fit(features: Tensor, feature_names: Optional[List[str]] = None) → TransformPipeline

Fit all fittable transforms in the pipeline on training data.

Parameters:
  • features – Training features tensor of shape (n_instances, n_features)

  • feature_names – Optional list of feature names for logging

Returns:

Self for method chaining

transform(features: Tensor) → Tensor

Apply all transforms in the pipeline to features.

Parameters:

features – Input features tensor

Returns:

Transformed features tensor

Raises:

RuntimeError – If pipeline contains unfitted transforms

fit_transform(features: Tensor, feature_names: Optional[List[str]] = None) → Tensor

Fit the pipeline and apply it to the features.

Parameters:
  • features – Features to fit and transform

  • feature_names – Optional list of feature names for logging

Returns:

Transformed features
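A common way to implement a pipeline's fit is to walk the transforms in order and fit each fittable step on the output of the steps before it. Later steps then see the same representation at fit time as at inference time. The sketch below assumes that behaviour, which this reference does not state explicitly. It uses hypothetical step classes on plain lists and detects fittable steps by duck typing (hasattr) rather than the real class hierarchy.

```python
from typing import List, Optional, Sequence

Rows = List[List[float]]  # stand-in for a (n_instances, n_features) Tensor


class Square:
    """Hypothetical non-fittable step: square every value."""

    def transform(self, features: Rows) -> Rows:
        return [[v * v for v in row] for row in features]


class MeanCenter:
    """Hypothetical fittable step: subtract per-feature training means."""

    def __init__(self) -> None:
        self.means: Optional[List[float]] = None

    def fit(self, features: Rows) -> "MeanCenter":
        self.means = [sum(col) / len(features) for col in zip(*features)]
        return self

    def transform(self, features: Rows) -> Rows:
        if self.means is None:
            raise RuntimeError("MeanCenter has not been fitted yet")
        return [[v - m for v, m in zip(row, self.means)] for row in features]


class PipelineSketch:
    def __init__(self, transforms: Sequence[object]) -> None:
        self.transforms = list(transforms)

    def fit(self, features: Rows) -> "PipelineSketch":
        current = features
        for step in self.transforms:
            if hasattr(step, "fit"):
                step.fit(current)  # fit on the output of the earlier steps
            current = step.transform(current)
        return self

    def transform(self, features: Rows) -> Rows:
        # Apply every step in order.
        for step in self.transforms:
            features = step.transform(features)
        return features

    def fit_transform(self, features: Rows) -> Rows:
        return self.fit(features).transform(features)

    def __len__(self) -> int:
        return len(self.transforms)


pipeline = PipelineSketch([Square(), MeanCenter()])
print(pipeline.fit_transform([[1.0], [3.0]]))  # [[-4.0], [4.0]]
print(pipeline.transform([[2.0]]))             # [[-1.0]]
```

Here MeanCenter learns the mean of the squared values (5.0), not of the raw inputs. This is why fitting each step on the upstream output matters.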

save(path: Union[str, Path]) → None

Save the entire pipeline to disk.

Parameters:

path – Path to save the pipeline (should be a directory)

classmethod load(path: Union[str, Path]) → TransformPipeline

Load a transform pipeline from disk.

Parameters:

path – Path to load the pipeline from (should be a directory)

Returns:

Loaded transform pipeline

get_config() → Dict[str, Any]

Get the configuration dictionary for the entire pipeline.

Returns:

Dictionary containing pipeline and all transform configurations

get_config_for_hashing() → Dict[str, Any]

Get stable configuration for hashing that excludes fitted state.

Returns:

Dictionary containing pipeline configuration without fitted state
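A configuration without fitted state lets two pipelines with identical settings map to the same key, presumably for caching or identifying preprocessed outputs. One way to turn such a dictionary into a stable key is to serialize it with sorted keys and hash the bytes. The config layouts below, including the "removed_indices" fitted-state key, are hypothetical.

```python
import hashlib
import json
from typing import Any, Dict


def config_hash(config: Dict[str, Any]) -> str:
    # sort_keys makes the JSON text independent of dict insertion order;
    # compact separators avoid whitespace differences.
    payload = json.dumps(config, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


# Hypothetical hashing configs: settings only, listed in different key orders.
a = {"transforms": [{"name": "correlation_filter", "correlation_threshold": 0.9}]}
b = {"transforms": [{"correlation_threshold": 0.9, "name": "correlation_filter"}]}
# Hypothetical config that also carries fitted state.
fitted = {"transforms": [{"name": "correlation_filter", "correlation_threshold": 0.9,
                          "removed_indices": [1, 2]}]}

print(config_hash(a) == config_hash(b))       # True: key order does not matter
print(config_hash(a) == config_hash(fitted))  # False: fitted state changes the key
```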

__len__() → int

Return the number of transforms in the pipeline.

__getitem__(index: int) → Transform

Get a transform by index.

__iter__()

Iterate over transforms.

add_transform(transform: Transform) → None

Add a transform to the end of the pipeline.

Parameters:

transform – Transform to add

Note

If the pipeline is already fitted, you’ll need to refit it.

remove_transform(index: int) → Transform

Remove and return a transform by index.

Parameters:

index – Index of transform to remove

Returns:

Removed transform

Note

If the pipeline is already fitted, you’ll need to refit it.

class cellmil.datamodels.transforms.CorrelationFilterTransform(correlation_threshold: float = 0.9, plot_correlation_matrix: bool = False, constant_threshold: float = 1e-08)

Bases: FittableTransform

Transform that removes highly correlated features based on a correlation threshold.

For each pair of features whose correlation exceeds the threshold, one feature of the pair is removed. Constant features (features whose standard deviation falls below constant_threshold) are removed as well.
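The fitting procedure described above can be sketched without torch. This sketch makes three assumptions about what the documented "iterative approach" does. It uses the population standard deviation for the constant check, compares absolute Pearson correlation against the threshold, and drops the later feature of each offending pair. None of these details are confirmed by this reference.

```python
import math
from typing import List, Set

Rows = List[List[float]]  # stand-in for a (n_instances, n_features) Tensor


def pearson(x: List[float], y: List[float]) -> float:
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    sx = math.sqrt(sum((a - mx) ** 2 for a in x))
    sy = math.sqrt(sum((b - my) ** 2 for b in y))
    return cov / (sx * sy)


def find_features_to_remove(features: Rows, correlation_threshold: float = 0.9,
                            constant_threshold: float = 1e-8) -> Set[int]:
    columns = [list(col) for col in zip(*features)]
    n = len(features)
    removed: Set[int] = set()
    # 1) Constant features: population standard deviation below the threshold.
    for j, col in enumerate(columns):
        mean = sum(col) / n
        std = math.sqrt(sum((v - mean) ** 2 for v in col) / n)
        if std < constant_threshold:
            removed.add(j)
    # 2) Greedy pass: for each still-kept pair (i, j) with i < j whose absolute
    #    correlation exceeds the threshold, drop the later feature j.
    for i in range(len(columns)):
        if i in removed:
            continue
        for j in range(i + 1, len(columns)):
            if j in removed:
                continue
            if abs(pearson(columns[i], columns[j])) > correlation_threshold:
                removed.add(j)
    return removed


# Column 1 is almost a linear copy of column 0; column 2 is constant.
X = [
    [1.0, 2.0, 5.0, 0.3],
    [2.0, 4.1, 5.0, 0.1],
    [3.0, 5.9, 5.0, 0.4],
    [4.0, 8.0, 5.0, 0.2],
]
print(sorted(find_features_to_remove(X)))  # [1, 2]
```

Removing constant features first also protects the correlation step, because a zero standard deviation would make the Pearson denominator zero.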

__init__(correlation_threshold: float = 0.9, plot_correlation_matrix: bool = False, constant_threshold: float = 1e-08)

Initialize the correlation filter transform.

Parameters:
  • correlation_threshold – Correlation threshold above which features will be removed

  • plot_correlation_matrix – Whether to plot the correlation matrix during fitting

  • constant_threshold – Threshold below which features are considered constant

fit(features: Tensor, feature_names: Optional[List[str]] = None) → CorrelationFilterTransform

Fit the correlation filter on training data.

Parameters:
  • features – Training features tensor of shape (n_instances, n_features)

  • feature_names – Optional list of feature names for logging

Returns:

Self for method chaining

_transform_impl(features: Tensor) → Tensor

Apply the correlation filter to features.

Parameters:

features – Input features tensor

Returns:

Filtered features tensor

_compute_correlation_matrix(features: Tensor) → Tensor

Compute correlation matrix efficiently.

Parameters:

features – Feature tensor

Returns:

Correlation matrix

_find_features_to_remove(corr_matrix: Tensor) → set[int]

Find the features to remove based on the correlation threshold, using an iterative approach.

Parameters:

corr_matrix – Correlation matrix

Returns:

Set of feature indices to remove

_plot_correlation_matrix(corr_matrix: Tensor) → None

Plot the correlation matrix.

Parameters:

corr_matrix – Correlation matrix to plot

get_config() → Dict[str, Any]

Get the configuration dictionary for this transform.

classmethod from_config(config: Dict[str, Any]) → CorrelationFilterTransform

Create transform instance from configuration dictionary.

get_feature_importance_mask() → Optional[Tensor]

Get a boolean mask indicating which original features are kept.

Returns:

Boolean tensor of shape (n_original_features,) where True indicates the feature is kept after correlation filtering, or None if not fitted.

get_removed_feature_indices() → Optional[List[int]]

Get the indices of features that were removed.

Returns:

List of feature indices that were removed, or None if not fitted.
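The keep-mask and the list of removed indices carry the same information: the mask is the complement of the removed set. The same mask can select columns of new data or recover the names of the surviving features. The plain-Python sketch below uses hypothetical feature names. With torch Tensors, the column selection would be features[:, mask].

```python
from typing import List, Sequence, Set


def keep_mask(n_features: int, removed: Set[int]) -> List[bool]:
    # True where a feature survives filtering.
    return [j not in removed for j in range(n_features)]


def apply_mask(features: Sequence[Sequence[float]], mask: Sequence[bool]) -> List[List[float]]:
    # Keep only the columns whose mask entry is True.
    return [[v for v, keep in zip(row, mask) if keep] for row in features]


removed = {1, 2}
mask = keep_mask(4, removed)
print(mask)                                      # [True, False, False, True]
print(apply_mask([[1.0, 2.0, 5.0, 0.3]], mask))  # [[1.0, 0.3]]

names = ["area", "perimeter", "eccentricity", "solidity"]  # hypothetical feature names
print([n for n, keep in zip(names, mask) if keep])          # ['area', 'solidity']
```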

class cellmil.datamodels.transforms.RobustScalerTransform(apply_log_transform: bool = True, quantile_range: Tuple[float, float] = (0.25, 0.75), clip_quantiles: Tuple[float, float] = (0.005, 0.995), constant_threshold: float = 1e-08)

Bases: FittableTransform

Transform that applies robust scaling to features using the median and the interquartile range (IQR).

Robust scaling is less sensitive to outliers than standard (mean and standard deviation) scaling. Formula: (x - median) / IQR, where IQR = Q3 - Q1 and Q1, Q3 are the quantiles given by quantile_range (0.25 and 0.75 by default).

When apply_log_transform is True (the default), features are first log-transformed to reduce skew and the influence of outliers, and robust scaling is then applied.
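The scaling steps can be sketched for a single feature in plain Python. Several details here are assumptions, not taken from cellmil:

- the sign-preserving log1p variant;
- the order log, then clip, then scale;
- computing the median and IQR after clipping;
- falling back to a unit scale when the IQR is below constant_threshold.

RobustScaler1D is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple


def quantile(values: Sequence[float], q: float) -> float:
    """Linear-interpolation quantile (the default method of numpy.quantile and torch.quantile)."""
    s = sorted(values)
    pos = q * (len(s) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (pos - lo) * (s[hi] - s[lo])


def signed_log1p(x: float) -> float:
    # Hypothetical log variant that is defined for zero and negative values.
    return math.copysign(math.log1p(abs(x)), x)


class RobustScaler1D:
    """Single-feature sketch: optional log, quantile clipping, then (x - median) / IQR."""

    def __init__(self, apply_log_transform: bool = True,
                 quantile_range: Tuple[float, float] = (0.25, 0.75),
                 clip_quantiles: Tuple[float, float] = (0.005, 0.995),
                 constant_threshold: float = 1e-8) -> None:
        self.apply_log_transform = apply_log_transform
        self.quantile_range = quantile_range
        self.clip_quantiles = clip_quantiles
        self.constant_threshold = constant_threshold

    def _prepare(self, values: Sequence[float]) -> List[float]:
        return [signed_log1p(v) for v in values] if self.apply_log_transform else list(values)

    def _clip(self, values: List[float]) -> List[float]:
        return [min(max(v, self.clip_lo), self.clip_hi) for v in values]

    def fit(self, values: Sequence[float]) -> "RobustScaler1D":
        x = self._prepare(values)
        # Clipping bounds are learned on training data and reused at transform time.
        self.clip_lo = quantile(x, self.clip_quantiles[0])
        self.clip_hi = quantile(x, self.clip_quantiles[1])
        x = self._clip(x)
        self.median = quantile(x, 0.5)
        iqr = quantile(x, self.quantile_range[1]) - quantile(x, self.quantile_range[0])
        # Near-constant feature: use a unit scale instead of dividing by ~0.
        self.is_constant = iqr < self.constant_threshold
        self.iqr = 1.0 if self.is_constant else iqr
        return self

    def transform(self, values: Sequence[float]) -> List[float]:
        x = self._clip(self._prepare(values))
        return [(v - self.median) / self.iqr for v in x]


scaler = RobustScaler1D(apply_log_transform=False, clip_quantiles=(0.0, 1.0))
print(scaler.fit([1.0, 2.0, 3.0, 4.0, 5.0]).transform([1.0, 3.0, 5.0, 100.0]))
# [-1.0, 0.0, 1.0, 1.0]  (100.0 is clipped to the training maximum, 5.0)
```

With median 3 and IQR 4 - 2 = 2, the training extremes map to -1 and +1. The clipping bounds learned during fit keep an unseen outlier such as 100.0 from producing an extreme scaled value.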

__init__(apply_log_transform: bool = True, quantile_range: Tuple[float, float] = (0.25, 0.75), clip_quantiles: Tuple[float, float] = (0.005, 0.995), constant_threshold: float = 1e-08)

Initialize the robust scaler transform.

Parameters:
  • apply_log_transform – Whether to apply log transformation before scaling

  • quantile_range – Tuple of (lower_quantile, upper_quantile) for IQR computation

  • clip_quantiles – Tuple of (lower_clip, upper_clip) for outlier clipping

  • constant_threshold – Threshold below which IQR is considered too small

fit(features: Tensor, feature_names: Optional[List[str]] = None) → RobustScalerTransform

Fit the robust scaler on training data.

Parameters:
  • features – Training features tensor of shape (n_instances, n_features)

  • feature_names – Optional list of feature names for logging

Returns:

Self for method chaining

_transform_impl(features: Tensor) → Tensor

Apply the robust scaling to features.

Parameters:

features – Input features tensor

Returns:

Normalized features tensor

get_config() → Dict[str, Any]

Get the configuration dictionary for this transform.

classmethod from_config(config: Dict[str, Any]) → RobustScalerTransform

Create transform instance from configuration dictionary.

get_scaling_parameters() → Optional[Tuple[Tensor, Tensor]]

Get the scaling parameters (median, IQR) computed from training data.

Returns:

Tuple of (median_values, iqr_values) if fitted, None otherwise

get_constant_features_mask() → Optional[Tensor]

Get a boolean mask indicating which features were considered constant.

Returns:

Boolean tensor of shape (n_features,) where True indicates the feature was considered constant during fitting, or None if not fitted.

class cellmil.datamodels.transforms.LabelTransform(name: str)

Bases: ABC

Base class for all label transforms.

__init__(name: str)

Initialize the transform.

Parameters:

name – Name of the transform for identification

abstract transform_labels(labels: Dict[str, Union[int, Tuple[float, int]]]) → Dict[str, Union[int, Tuple[float, int]]]

Apply the transform to labels.

Parameters:

labels – Dictionary mapping slide IDs to labels: int labels for classification, (duration, event) tuples for survival

Returns:

Transformed labels dictionary
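The sketch below shows the two label shapes a label transform receives, together with a hypothetical stateless transform. CensorAtHorizon applies administrative censoring at a fixed time horizon and passes classification labels through unchanged. The class, its "horizon" parameter and the slide IDs are illustrative assumptions, not part of cellmil.

```python
from typing import Any, Dict, Tuple, Union

Label = Union[int, Tuple[float, int]]  # int for classification, (duration, event) for survival


class CensorAtHorizon:
    """Hypothetical stateless label transform: censor survival labels at a fixed horizon."""

    def __init__(self, horizon: float) -> None:
        self.name = "censor_at_horizon"
        self.horizon = horizon

    def transform_labels(self, labels: Dict[str, Label]) -> Dict[str, Label]:
        out: Dict[str, Label] = {}
        for slide_id, label in labels.items():
            if isinstance(label, tuple):
                duration, event = label
                # Anything observed after the horizon is treated as censored at the horizon.
                out[slide_id] = (self.horizon, 0) if duration > self.horizon else (duration, event)
            else:
                out[slide_id] = label  # classification labels pass through unchanged
        return out

    def get_config(self) -> Dict[str, Any]:
        return {"name": self.name, "horizon": self.horizon}

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CensorAtHorizon":
        return cls(horizon=config["horizon"])


classification = {"slide_A": 0, "slide_B": 1}
survival = {"slide_A": (12.5, 1), "slide_B": (80.0, 1), "slide_C": (30.0, 0)}

t = CensorAtHorizon(horizon=60.0)
print(t.transform_labels(survival))
# {'slide_A': (12.5, 1), 'slide_B': (60.0, 0), 'slide_C': (30.0, 0)}
print(t.transform_labels(classification))  # {'slide_A': 0, 'slide_B': 1}
```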

abstract get_config() → Dict[str, Any]

Get the configuration dictionary for this transform.

Returns:

Dictionary containing transform configuration

abstract classmethod from_config(config: Dict[str, Any]) → LabelTransform

Create transform instance from configuration dictionary.

Parameters:

config – Configuration dictionary

Returns:

LabelTransform instance

save(path: Path) → None

Save the transform to disk.

Parameters:

path – Path to save the transform

classmethod load(path: Path) → LabelTransform

Load transform from disk.

Parameters:

path – Path to load the transform from

Returns:

LabelTransform instance

class cellmil.datamodels.transforms.FittableLabelTransform(name: str)

Bases: LabelTransform

Base class for label transforms that need to be fitted on training data.

__init__(name: str)

Initialize the fittable label transform.

Parameters:

name – Name of the transform for identification

abstract fit(labels: Dict[str, Union[int, Tuple[float, int]]], **kwargs: Any) → FittableLabelTransform

Fit the transform on training labels.

Parameters:
  • labels – Training labels dictionary mapping slide IDs to labels

  • **kwargs – Additional keyword arguments for fitting

Returns:

Self for method chaining

fit_transform(labels: Dict[str, Union[int, Tuple[float, int]]], **kwargs: Any) → Dict[str, Union[int, Tuple[float, int]]]

Fit the transform and apply it to the labels.

Parameters:
  • labels – Labels to fit and transform

  • **kwargs – Additional keyword arguments for fitting

Returns:

Transformed labels

transform_labels(labels: Dict[str, Union[int, Tuple[float, int]]]) → Dict[str, Union[int, Tuple[float, int]]]

Apply the transform to labels.

Parameters:

labels – Labels dictionary to transform

Returns:

Transformed labels dictionary

Raises:

RuntimeError – If transform hasn’t been fitted yet

abstract _transform_labels_impl(labels: Dict[str, Union[int, Tuple[float, int]]]) → Dict[str, Union[int, Tuple[float, int]]]

Implementation of the label transform operation.

Parameters:

labels – Input labels dictionary

Returns:

Transformed labels dictionary
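A fittable label transform learns its state from training labels only and then applies that state unchanged to other splits. The sketch below shows the fit / transform_labels / _transform_labels_impl contract with a hypothetical ContiguousLabelEncoder. It maps the class labels seen in training to indices 0..K-1 and ignores **kwargs, which are accepted only for API compatibility. Both the class and its behaviour are illustrative assumptions.

```python
from typing import Any, Dict, Optional


class ContiguousLabelEncoder:
    """Hypothetical fittable label transform: map training classes to 0..K-1."""

    def __init__(self) -> None:
        self.name = "contiguous_label_encoder"
        self.mapping: Optional[Dict[int, int]] = None

    def fit(self, labels: Dict[str, int], **kwargs: Any) -> "ContiguousLabelEncoder":
        # **kwargs is accepted for API compatibility and ignored here.
        classes = sorted(set(labels.values()))
        self.mapping = {c: i for i, c in enumerate(classes)}
        return self

    def transform_labels(self, labels: Dict[str, int]) -> Dict[str, int]:
        # Public entry point: refuse to run before fit().
        if self.mapping is None:
            raise RuntimeError(f"Label transform '{self.name}' has not been fitted yet")
        return self._transform_labels_impl(labels)

    def _transform_labels_impl(self, labels: Dict[str, int]) -> Dict[str, int]:
        mapping = self.mapping or {}
        # A class never seen during fit raises KeyError.
        return {slide_id: mapping[y] for slide_id, y in labels.items()}

    def fit_transform(self, labels: Dict[str, int], **kwargs: Any) -> Dict[str, int]:
        return self.fit(labels, **kwargs).transform_labels(labels)


train = {"s1": 3, "s2": 7, "s3": 3}
encoder = ContiguousLabelEncoder()
print(encoder.fit_transform(train))         # {'s1': 0, 's2': 1, 's3': 0}
print(encoder.transform_labels({"s4": 7}))  # {'s4': 1}
```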

class cellmil.datamodels.transforms.LabelTransformPipeline(transforms: List[LabelTransform])

Bases: object

Pipeline for applying multiple label transforms in sequence.

This class manages a sequence of label transforms, ensuring that fittable transforms are properly fitted on training data before being applied.

__init__(transforms: List[LabelTransform])

Initialize the pipeline with a list of transforms.

Parameters:

transforms – List of LabelTransform instances to apply in sequence

fit(labels: Dict[str, Union[int, Tuple[float, int]]], **kwargs: Any) → LabelTransformPipeline

Fit all fittable transforms in the pipeline on training labels.

Parameters:
  • labels – Training labels dictionary

  • **kwargs – Additional keyword arguments passed to each transform’s fit method

Returns:

Self for method chaining

transform_labels(labels: Dict[str, Union[int, Tuple[float, int]]]) → Dict[str, Union[int, Tuple[float, int]]]

Apply all transforms in the pipeline sequentially.

Parameters:

labels – Labels dictionary to transform

Returns:

Transformed labels dictionary after applying all transforms

fit_transform(labels: Dict[str, Union[int, Tuple[float, int]]], **kwargs: Any) → Dict[str, Union[int, Tuple[float, int]]]

Fit the pipeline and apply it to labels.

Parameters:
  • labels – Labels to fit and transform

  • **kwargs – Additional keyword arguments for fitting

Returns:

Transformed labels

get_config() → Dict[str, Any]

Get configuration for all transforms in the pipeline.

Returns:

Dictionary containing pipeline configuration

save(directory: Path) → None

Save the pipeline and all its transforms to disk.

Parameters:

directory – Directory to save the pipeline configuration and transforms

classmethod load(directory: Path) → LabelTransformPipeline

Load a pipeline from disk.

Parameters:

directory – Directory containing the saved pipeline

Returns:

LabelTransformPipeline instance

__len__() → int

Return the number of transforms in the pipeline.

__getitem__(idx: int) → LabelTransform

Get a transform by index.
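The main job of a label pipeline is to fit on training labels once and then reuse the fitted state on held-out labels. The sketch below assumes that each fittable step is fitted on the labels produced by the steps before it, which this reference does not state explicitly. It uses two hypothetical survival-label steps, MonthsToYears and MedianSplit, and detects fittable steps by duck typing.

```python
from statistics import median
from typing import Any, Dict, List, Optional, Tuple

Surv = Dict[str, Tuple[float, int]]  # slide ID -> (duration, event)


class MonthsToYears:
    """Hypothetical stateless step: convert durations from months to years."""

    def transform_labels(self, labels: Surv) -> Surv:
        return {k: (d / 12.0, e) for k, (d, e) in labels.items()}


class MedianSplit:
    """Hypothetical fittable step: bin 0 below the median uncensored duration, else bin 1."""

    def __init__(self) -> None:
        self.cut: Optional[float] = None

    def fit(self, labels: Surv, **kwargs: Any) -> "MedianSplit":
        self.cut = median(d for d, e in labels.values() if e == 1)
        return self

    def transform_labels(self, labels: Surv) -> Dict[str, Tuple[int, int]]:
        if self.cut is None:
            raise RuntimeError("MedianSplit has not been fitted yet")
        return {k: (int(d >= self.cut), e) for k, (d, e) in labels.items()}


class LabelPipelineSketch:
    def __init__(self, transforms: List[Any]) -> None:
        self.transforms = transforms

    def fit(self, labels: Dict[str, Any], **kwargs: Any) -> "LabelPipelineSketch":
        current = labels
        for step in self.transforms:
            if hasattr(step, "fit"):
                step.fit(current, **kwargs)  # fit on the output of earlier steps
            current = step.transform_labels(current)
        return self

    def transform_labels(self, labels: Dict[str, Any]) -> Dict[str, Any]:
        for step in self.transforms:
            labels = step.transform_labels(labels)
        return labels

    def fit_transform(self, labels: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
        return self.fit(labels, **kwargs).transform_labels(labels)

    def __len__(self) -> int:
        return len(self.transforms)

    def __getitem__(self, idx: int) -> Any:
        return self.transforms[idx]


train = {"a": (12.0, 1), "b": (36.0, 1), "c": (60.0, 1), "d": (90.0, 0)}
test = {"e": (24.0, 1), "f": (48.0, 0)}

pipe = LabelPipelineSketch([MonthsToYears(), MedianSplit()])
print(pipe.fit_transform(train))    # {'a': (0, 1), 'b': (1, 1), 'c': (1, 1), 'd': (1, 0)}
print(pipe.transform_labels(test))  # {'e': (0, 1), 'f': (1, 0)}
print(pipe[1].cut)                  # 3.0 (median of 1, 3, 5 years)
```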

class cellmil.datamodels.transforms.TimeDiscretizerTransform(n_bins: int = 4, eps: float = 1e-08)

Bases: FittableLabelTransform

Transform for discretizing continuous survival times into bins.

This transform bins survival times using quantiles computed on uncensored patients only. It is designed for discrete-time survival analysis, where the continuous survival-time regression problem is converted into a classification problem over time bins.

Parameters:
  • n_bins (int) – Number of bins to create (e.g., 4 for quartiles)

  • eps (float) – Small epsilon value to adjust bin boundaries
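The fitting idea can be sketched in plain Python. Cut points are taken from the quantiles of uncensored durations only. Every sample, censored or not, is then assigned the index of the bin its duration falls into, and the event indicator is kept alongside it. Several details are assumptions for illustration:

- the class name TimeDiscretizerSketch;
- the linear-interpolation quantile;
- the half-open [a, b) bin convention.

The sketch also omits eps, because its outermost bins are open-ended and the exact way eps adjusts the boundaries is not specified here.

```python
from bisect import bisect_right
from typing import Dict, List, Optional, Tuple

SurvivalLabels = Dict[str, Tuple[float, int]]  # slide ID -> (duration, event)


def quantile(sorted_vals: List[float], q: float) -> float:
    """Linear-interpolation quantile of an already sorted list."""
    pos = q * (len(sorted_vals) - 1)
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    return sorted_vals[lo] + (pos - lo) * (sorted_vals[hi] - sorted_vals[lo])


class TimeDiscretizerSketch:
    def __init__(self, n_bins: int = 4) -> None:
        self.n_bins = n_bins
        self.cut_points: Optional[List[float]] = None

    def fit(self, labels: SurvivalLabels) -> "TimeDiscretizerSketch":
        # Quantiles come from uncensored patients (event == 1) only.
        uncensored = sorted(d for d, event in labels.values() if event == 1)
        if not uncensored:
            raise ValueError("Need at least one uncensored sample to fit")
        # Interior cut points at the 1/n, 2/n, ..., (n-1)/n quantiles.
        self.cut_points = [quantile(uncensored, k / self.n_bins) for k in range(1, self.n_bins)]
        return self

    def transform_labels(self, labels: SurvivalLabels) -> Dict[str, Tuple[int, int]]:
        if self.cut_points is None:
            raise RuntimeError("TimeDiscretizerSketch has not been fitted yet")
        # bisect_right puts a duration equal to a cut point into the upper bin.
        # The outer bins are open-ended, so censored times beyond the largest
        # uncensored time still land in the last bin.
        return {sid: (bisect_right(self.cut_points, d), event) for sid, (d, event) in labels.items()}


train = {"s1": (1.0, 1), "s2": (2.0, 1), "s3": (3.0, 1), "s4": (4.0, 1), "s5": (5.0, 1), "s6": (10.0, 0)}
disc = TimeDiscretizerSketch(n_bins=4).fit(train)
print(disc.cut_points)  # [2.0, 3.0, 4.0]
print(disc.transform_labels({"a": (1.0, 1), "b": (2.5, 0), "c": (10.0, 0)}))
# {'a': (0, 1), 'b': (1, 0), 'c': (3, 0)}
```

Computing the cut points on uncensored durations only avoids biasing the bin edges toward follow-up times, which are not event times. The censored sample s6 still receives a bin index at transform time.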

__init__(n_bins: int = 4, eps: float = 1e-08)

Initialize the time discretizer transform.

Parameters:
  • n_bins – Number of bins to create (e.g., 4 for quartiles)

  • eps – Small epsilon value to adjust bin boundaries

fit(labels: Dict[str, Union[int, Tuple[float, int]]], **kwargs: Any) → TimeDiscretizerTransform

Fit the discretizer on survival data.

Parameters:
  • labels – Dictionary mapping slide IDs to (duration, event) tuples

  • **kwargs – Additional keyword arguments (not used, for API compatibility)

Returns:

Self for method chaining

_transform_labels_impl(labels: Dict[str, Union[int, Tuple[float, int]]]) → Dict[str, Union[int, Tuple[float, int]]]

Implementation of the label transform operation.

Parameters:

labels – Dictionary mapping slide IDs to (duration, event) tuples

Returns:

Dictionary mapping slide IDs to (bin_index, event) tuples

get_config() → Dict[str, Any]

Get configuration for saving.

classmethod from_config(config: dict[str, Any])

Create instance from configuration.

save(path: Path) → None

Save transform to disk.

classmethod load(path: Path)

Load transform from disk.

Modules

cellmil.datamodels.transforms.base_label_transform

Base classes for label transforms.

cellmil.datamodels.transforms.base_transform

Base classes for feature transforms.

cellmil.datamodels.transforms.correlation_filter

Correlation filter transform for removing highly correlated features.

cellmil.datamodels.transforms.label_pipeline

Pipeline for composing multiple label transforms.

cellmil.datamodels.transforms.normalization

Robust scaler transform for feature normalization.

cellmil.datamodels.transforms.pipeline

Transform pipeline for chaining multiple feature transforms.

cellmil.datamodels.transforms.time_discretizer

Survival discretization transform for converting continuous survival times to discrete bins.