Source code for cellmil.datamodels.transforms.label_pipeline

"""
Pipeline for composing multiple label transforms.
"""

from typing import List, Union, Dict, Any, Tuple
from pathlib import Path
import json
from .base_label_transform import LabelTransform, FittableLabelTransform


class LabelTransformPipeline:
    """
    Pipeline for applying multiple label transforms in sequence.

    The pipeline holds an ordered list of transforms. During fitting, every
    ``FittableLabelTransform`` is fitted on the labels produced by the
    transforms that precede it, so each step is fitted on the same kind of
    labels it receives when the pipeline is later applied.
    """

    def __init__(self, transforms: List[LabelTransform]):
        """
        Initialize the pipeline with a list of transforms.

        Args:
            transforms: LabelTransform instances to apply in sequence
        """
        # Copy so that later mutation of the caller's list cannot silently
        # add, remove, or reorder steps of this pipeline.
        self.transforms: List[LabelTransform] = list(transforms)
        self.name = "label_pipeline"

    def _fit_and_apply(
        self, labels: Dict[str, Union[int, Tuple[float, int]]], **kwargs: Any
    ) -> Dict[str, Union[int, Tuple[float, int]]]:
        """Fit each fittable transform in order and return the fully transformed labels."""
        for transform in self.transforms:
            if isinstance(transform, FittableLabelTransform):
                transform.fit(labels, **kwargs)
            # Every transform (fittable or not) is applied so that the next
            # step is fitted on the labels it will actually receive.
            labels = transform.transform_labels(labels)
        return labels

    def fit(
        self, labels: Dict[str, Union[int, Tuple[float, int]]], **kwargs: Any
    ) -> "LabelTransformPipeline":
        """
        Fit all fittable transforms in the pipeline on training labels.

        Args:
            labels: Training labels dictionary
            **kwargs: Additional keyword arguments passed to each transform's
                fit method

        Returns:
            Self for method chaining
        """
        self._fit_and_apply(labels, **kwargs)
        return self

    def transform_labels(
        self, labels: Dict[str, Union[int, Tuple[float, int]]]
    ) -> Dict[str, Union[int, Tuple[float, int]]]:
        """
        Apply all transforms in the pipeline sequentially.

        Args:
            labels: Labels dictionary to transform

        Returns:
            Transformed labels dictionary after applying all transforms
        """
        result = labels
        for transform in self.transforms:
            result = transform.transform_labels(result)  # type: ignore
        return result  # type: ignore

    def fit_transform(
        self, labels: Dict[str, Union[int, Tuple[float, int]]], **kwargs: Any
    ) -> Dict[str, Union[int, Tuple[float, int]]]:
        """
        Fit the pipeline and apply it to labels.

        Fitting already pushes the labels through every transform, so the
        result of that single pass is returned instead of running the whole
        chain a second time.

        Args:
            labels: Labels to fit and transform
            **kwargs: Additional keyword arguments for fitting

        Returns:
            Transformed labels
        """
        return self._fit_and_apply(labels, **kwargs)

    def get_config(self) -> Dict[str, Any]:
        """
        Get configuration for all transforms in the pipeline.

        Returns:
            Dictionary with the pipeline ``name`` and, for each transform in
            order, its class name and its own ``get_config()`` result
        """
        return {
            "name": self.name,
            "transforms": [
                {"transform_class": t.__class__.__name__, "config": t.get_config()}
                for t in self.transforms
            ],
        }

    def save(self, directory: Union[str, Path]) -> None:
        """
        Save the pipeline and all its transforms to disk.

        Writes ``pipeline.json`` (the result of ``get_config()``) and one
        ``transform_{idx}.json`` per transform, the latter through each
        transform's own ``save`` method.

        Args:
            directory: Directory to save the pipeline configuration and
                transforms; created (with parents) if it does not exist
        """
        # Accept plain string paths as well as Path objects.
        directory = Path(directory)
        directory.mkdir(parents=True, exist_ok=True)

        # Save pipeline config. The encoding is explicit so the file does not
        # depend on the locale of the machine that wrote it.
        pipeline_path = directory / "pipeline.json"
        with open(pipeline_path, "w", encoding="utf-8") as f:
            json.dump(self.get_config(), f, indent=2)

        # Save individual transforms
        for idx, transform in enumerate(self.transforms):
            transform.save(directory / f"transform_{idx}.json")

    @classmethod
    def load(cls, directory: Union[str, Path]) -> "LabelTransformPipeline":
        """
        Load a pipeline from disk.

        Reads ``pipeline.json`` to learn the class of each transform, then
        loads ``transform_{idx}.json`` through that class's ``load`` method.
        Only ``TimeDiscretizerTransform`` is currently registered.

        Args:
            directory: Directory containing the saved pipeline

        Returns:
            LabelTransformPipeline instance

        Raises:
            ValueError: If a saved transform class name is not registered
        """
        directory = Path(directory)
        with open(directory / "pipeline.json", "r", encoding="utf-8") as f:
            pipeline_config = json.load(f)

        # Imported here rather than at module level.
        # NOTE(review): presumably to avoid a circular import with the
        # package __init__ - confirm before hoisting.
        from . import TimeDiscretizerTransform

        transform_classes = {
            "TimeDiscretizerTransform": TimeDiscretizerTransform,
        }

        # Load transforms in the order they were saved
        transforms: List[LabelTransform] = []
        for idx, transform_info in enumerate(pipeline_config["transforms"]):
            transform_class_name = transform_info.get("transform_class")
            if transform_class_name not in transform_classes:
                raise ValueError(
                    f"Unknown label transform class: {transform_class_name}"
                )
            transform_class = transform_classes[transform_class_name]
            transforms.append(
                transform_class.load(directory / f"transform_{idx}.json")
            )

        return cls(transforms)

    def __len__(self) -> int:
        """Return the number of transforms in the pipeline."""
        return len(self.transforms)

    def __getitem__(self, idx: int) -> LabelTransform:
        """Get a transform by index."""
        return self.transforms[idx]