cellmil.datamodels.transforms.pipeline

Transform pipeline for chaining multiple feature transforms.

Classes

TransformPipeline(transforms)

Pipeline for chaining multiple feature transforms.

class cellmil.datamodels.transforms.pipeline.TransformPipeline(transforms: List[Transform])

Bases: object

Pipeline for chaining multiple feature transforms.

Supports both fittable and non-fittable transforms, automatic fitting on training data, and serialization for reuse on new samples.

__init__(transforms: List[Transform])

Initialize the transform pipeline.

Parameters:

transforms – List of transforms to apply in order

fit(features: Tensor, feature_names: Optional[List[str]] = None) -> TransformPipeline

Fit all fittable transforms in the pipeline on training data.

Parameters:
  • features – Training features tensor of shape (n_instances, n_features)

  • feature_names – Optional list of feature names for logging

Returns:

Self for method chaining
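
One way a chain can be fitted is to fit each fittable step on the output of the steps before it. The sketch below illustrates that idea only; it is not the cellmil implementation. `SketchPipeline`, `Log1p`, and `Standardize` are hypothetical stand-ins, detecting fittable steps via a `fit` attribute is an assumption, and plain Python lists replace the `Tensor` input so the example has no dependencies.

```python
import math
from typing import List, Optional

Rows = List[List[float]]  # stand-in for a (n_instances, n_features) tensor


class Log1p:
    """Hypothetical non-fittable transform: element-wise log(1 + x)."""

    def transform(self, rows: Rows) -> Rows:
        return [[math.log1p(x) for x in row] for row in rows]


class Standardize:
    """Hypothetical fittable transform: per-feature zero mean, unit variance."""

    def __init__(self) -> None:
        self.mean: Optional[List[float]] = None
        self.std: Optional[List[float]] = None

    def fit(self, rows: Rows) -> "Standardize":
        cols = list(zip(*rows))
        self.mean = [sum(c) / len(c) for c in cols]
        # Fall back to 1.0 for constant features to avoid dividing by zero.
        self.std = [
            math.sqrt(sum((x - m) ** 2 for x in c) / len(c)) or 1.0
            for c, m in zip(cols, self.mean)
        ]
        return self

    def transform(self, rows: Rows) -> Rows:
        if self.mean is None or self.std is None:
            raise RuntimeError("Standardize has not been fitted")
        return [
            [(x - m) / s for x, m, s in zip(row, self.mean, self.std)]
            for row in rows
        ]


class SketchPipeline:
    """Hypothetical stand-in mirroring the documented fit/transform interface."""

    def __init__(self, transforms: list) -> None:
        self.transforms = list(transforms)

    def fit(self, rows: Rows) -> "SketchPipeline":
        current = rows
        for t in self.transforms:
            if hasattr(t, "fit"):  # only fittable steps learn from the data
                t.fit(current)
            current = t.transform(current)  # later steps see transformed data
        return self  # returning self allows method chaining

    def transform(self, rows: Rows) -> Rows:
        for t in self.transforms:
            rows = t.transform(rows)
        return rows


train = [[0.0, 1.0], [3.0, 7.0], [8.0, 15.0]]
pipeline = SketchPipeline([Log1p(), Standardize()])
out = pipeline.fit(train).transform(train)
column_means = [sum(col) / len(col) for col in zip(*out)]
```

Because `fit` returns the pipeline itself, fitting and transforming can be written as one chained expression, as in the last lines of the sketch.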

transform(features: Tensor) -> Tensor

Apply all transforms in the pipeline to features.

Parameters:

features – Input features tensor

Returns:

Transformed features tensor

Raises:

RuntimeError – If pipeline contains unfitted transforms
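
Callers can treat this `RuntimeError` as a signal that `fit` has not been run yet. The sketch below shows one possible guard of this kind. `GuardedPipeline` and `Scale` are hypothetical stand-ins, the `is_fitted` flag is an assumed attribute name, and plain lists stand in for tensors.

```python
from typing import List


class Scale:
    """Hypothetical fittable transform: divide by the largest training magnitude."""

    def __init__(self) -> None:
        self.is_fitted = False  # assumed flag; the real attribute may differ
        self.max_value = 1.0

    def fit(self, values: List[float]) -> "Scale":
        self.max_value = max(abs(v) for v in values) or 1.0
        self.is_fitted = True
        return self

    def transform(self, values: List[float]) -> List[float]:
        return [v / self.max_value for v in values]


class GuardedPipeline:
    """Hypothetical stand-in that refuses to transform before fitting."""

    def __init__(self, transforms: List[Scale]) -> None:
        self.transforms = list(transforms)

    def fit(self, values: List[float]) -> "GuardedPipeline":
        for t in self.transforms:
            t.fit(values)
            values = t.transform(values)
        return self

    def transform(self, values: List[float]) -> List[float]:
        # Collect the names of every step that has not been fitted yet.
        unfitted = [type(t).__name__ for t in self.transforms if not t.is_fitted]
        if unfitted:
            raise RuntimeError(f"Unfitted transforms: {unfitted}")
        for t in self.transforms:
            values = t.transform(values)
        return values


pipeline = GuardedPipeline([Scale()])
try:
    pipeline.transform([1.0, 2.0, 4.0])
    raised = False
except RuntimeError:
    raised = True

scaled = pipeline.fit([1.0, 2.0, 4.0]).transform([1.0, 2.0, 4.0])
```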

fit_transform(features: Tensor, feature_names: Optional[List[str]] = None) -> Tensor

Fit the pipeline and apply it to the features.

Parameters:
  • features – Features to fit and transform

  • feature_names – Optional list of feature names for logging

Returns:

Transformed features

save(path: Union[str, Path]) -> None

Save the entire pipeline to disk.

Parameters:

path – Path to save the pipeline (should be a directory)

classmethod load(path: Union[str, Path]) -> TransformPipeline

Load a transform pipeline from disk.

Parameters:

path – Path to load the pipeline from (should be a directory)

Returns:

Loaded transform pipeline
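
The save/load round trip to a directory can be sketched as follows. The on-disk layout used by cellmil is not documented here, so the single `pipeline.json` file, the `DirPipeline` class, and the `Offset` transform are all hypothetical and only illustrate the pattern.

```python
import json
import tempfile
from pathlib import Path
from typing import List, Union


class Offset:
    """Hypothetical transform whose fitted state is a single float."""

    def __init__(self, offset: float = 0.0) -> None:
        self.offset = offset


class DirPipeline:
    """Hypothetical stand-in that persists its state inside a directory."""

    FILE_NAME = "pipeline.json"  # assumed file name, not cellmil's layout

    def __init__(self, transforms: List[Offset]) -> None:
        self.transforms = list(transforms)

    def save(self, path: Union[str, Path]) -> None:
        directory = Path(path)
        directory.mkdir(parents=True, exist_ok=True)  # path is a directory
        payload = {"transforms": [{"offset": t.offset} for t in self.transforms]}
        (directory / self.FILE_NAME).write_text(json.dumps(payload))

    @classmethod
    def load(cls, path: Union[str, Path]) -> "DirPipeline":
        payload = json.loads((Path(path) / cls.FILE_NAME).read_text())
        return cls([Offset(**cfg) for cfg in payload["transforms"]])


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "fitted_pipeline"
    DirPipeline([Offset(1.5), Offset(-2.0)]).save(target)
    restored = DirPipeline.load(target)

restored_offsets = [t.offset for t in restored.transforms]
```

Since `load` is a classmethod, it is called on the class rather than on an existing instance, which is how a pipeline fitted during training can be rebuilt for new samples.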

get_config() -> Dict[str, Any]

Get the configuration dictionary for the entire pipeline.

Returns:

Dictionary containing pipeline and all transform configurations

get_config_for_hashing() -> Dict[str, Any]

Get stable configuration for hashing that excludes fitted state.

Returns:

Dictionary containing pipeline configuration without fitted state
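
A configuration dictionary like this can be turned into a deterministic cache key by serializing it canonically and hashing the result. The sketch below uses only the standard library. The `config_hash` helper and the example dictionaries are hypothetical, and it assumes the configuration is JSON-serializable; the keys returned by the real method may differ.

```python
import hashlib
import json
from typing import Any, Dict


def config_hash(config: Dict[str, Any]) -> str:
    """Hash a JSON-serializable config deterministically."""
    # sort_keys makes the serialization independent of dict insertion order.
    canonical = json.dumps(config, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


# Hypothetical configs; the real keys returned by the pipeline may differ.
full_config = {
    "transforms": [
        {"name": "standardize", "eps": 1e-8, "fitted_mean": [0.4, 1.2]}
    ]
}
stable_config = {"transforms": [{"name": "standardize", "eps": 1e-8}]}
reordered = {"transforms": [{"eps": 1e-8, "name": "standardize"}]}

same = config_hash(stable_config) == config_hash(reordered)
differs = config_hash(stable_config) != config_hash(full_config)
```

Leaving fitted state out of the hashed dictionary means the key identifies how the pipeline is configured rather than what data it was fitted on, so it stays the same before and after fitting.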

__len__() -> int

Return the number of transforms in the pipeline.

__getitem__(index: int) -> Transform

Get a transform by index.

__iter__()

Iterate over transforms.
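
Together, `__len__`, `__getitem__`, and `__iter__` let a pipeline be inspected like a read-only sequence of its steps. A minimal sketch of that protocol, using the hypothetical stand-ins `SequencePipeline` and `Step` rather than the cellmil classes:

```python
from typing import Iterator, List


class Step:
    """Hypothetical transform placeholder that only carries a name."""

    def __init__(self, name: str) -> None:
        self.name = name


class SequencePipeline:
    """Hypothetical stand-in exposing the three documented dunder methods."""

    def __init__(self, transforms: List[Step]) -> None:
        self.transforms = list(transforms)

    def __len__(self) -> int:
        return len(self.transforms)

    def __getitem__(self, index: int) -> Step:
        return self.transforms[index]

    def __iter__(self) -> Iterator[Step]:
        return iter(self.transforms)


pipeline = SequencePipeline([Step("log"), Step("clip"), Step("standardize")])
count = len(pipeline)                      # number of steps
first_name = pipeline[0].name              # access by position
names = [step.name for step in pipeline]   # iterate in application order
```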

add_transform(transform: Transform) -> None

Add a transform to the end of the pipeline.

Parameters:

transform – Transform to add

Note

If the pipeline has already been fitted, call fit again before using it.

remove_transform(index: int) -> Transform

Remove and return a transform by index.

Parameters:

index – Index of transform to remove

Returns:

Removed transform

Note

If the pipeline has already been fitted, call fit again before using it.
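
The need to refit after changing the list of steps can be illustrated with a small sketch. `MutablePipeline` and `Center` are hypothetical stand-ins, plain lists replace tensors, and the behavior shown (a newly added step has no fitted state until `fit` is called again) is an assumption about how such a pipeline typically behaves, not a description of cellmil internals.

```python
from typing import List, Optional


class Center:
    """Hypothetical fittable transform that subtracts the training mean."""

    def __init__(self) -> None:
        self.mean: Optional[float] = None

    def fit(self, values: List[float]) -> "Center":
        self.mean = sum(values) / len(values)
        return self

    def transform(self, values: List[float]) -> List[float]:
        if self.mean is None:
            raise RuntimeError("Center has not been fitted")
        return [v - self.mean for v in values]


class MutablePipeline:
    """Hypothetical stand-in with the documented add/remove methods."""

    def __init__(self, transforms: List[Center]) -> None:
        self.transforms = list(transforms)

    def add_transform(self, transform: Center) -> None:
        self.transforms.append(transform)

    def remove_transform(self, index: int) -> Center:
        return self.transforms.pop(index)

    def fit(self, values: List[float]) -> "MutablePipeline":
        for t in self.transforms:
            values = t.fit(values).transform(values)
        return self

    def transform(self, values: List[float]) -> List[float]:
        for t in self.transforms:
            values = t.transform(values)
        return values


train = [1.0, 2.0, 6.0]
pipeline = MutablePipeline([Center()]).fit(train)
pipeline.add_transform(Center())  # the new step has no fitted state yet

try:
    pipeline.transform(train)
    stale = False
except RuntimeError:
    stale = True

# Refitting makes every step consistent with the new chain.
refit_output = pipeline.fit(train).transform(train)
removed = pipeline.remove_transform(1)
```

Removing a step calls for the same care: any step that followed the removed one was fitted on inputs it will no longer receive.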