Source code for cellmil.datamodels.transforms.base_label_transform
"""
Base classes for label transforms.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Tuple, Union
from pathlib import Path
import json
class LabelTransform(ABC):
    """Base class for all label transforms.

    Concrete subclasses implement :meth:`transform_labels`, :meth:`get_config`
    and :meth:`from_config`. :meth:`save` writes the configuration to a JSON
    file tagged with the concrete class name under ``"transform_class"``;
    :meth:`load` reads it back and uses that tag to pick the class to build.
    """

    def __init__(self, name: str):
        """
        Initialize the transform.

        Args:
            name: Name of the transform for identification
        """
        self.name = name

    @abstractmethod
    def transform_labels(
        self, labels: Dict[str, Union[int, Tuple[float, int]]]
    ) -> Dict[str, Union[int, Tuple[float, int]]]:
        """
        Apply the transform to labels.

        Args:
            labels: Dictionary mapping slide IDs to labels
                For classification: int labels
                For survival: (duration, event) tuples

        Returns:
            Transformed labels dictionary
        """
        pass

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """
        Get the configuration dictionary for this transform.

        Returns:
            Dictionary containing transform configuration
        """
        pass

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict[str, Any]) -> "LabelTransform":
        """
        Create transform instance from configuration dictionary.

        Args:
            config: Configuration dictionary

        Returns:
            LabelTransform instance
        """
        pass

    def save(self, path: Path) -> None:
        """
        Save the transform to disk as JSON.

        The dictionary returned by :meth:`get_config` is copied before the
        ``"transform_class"`` tag is added, so a subclass that returns one of
        its own attributes is not modified by saving.

        Args:
            path: Path to save the transform
        """
        # Shallow copy: get_config() may hand back internal instance state.
        config = dict(self.get_config())
        config["transform_class"] = self.__class__.__name__
        with open(path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

    @classmethod
    def _resolve_subclass(cls, class_name: Any) -> type:
        """Return ``cls`` or the descendant of ``cls`` whose name is ``class_name``.

        Falls back to ``cls`` when ``class_name`` is empty/missing or when no
        currently defined descendant of ``cls`` carries that name.
        """
        if not class_name or cls.__name__ == class_name:
            return cls
        # Walk all transitively defined subclasses; `seen` guards against
        # visiting a class twice through diamond inheritance.
        pending = list(cls.__subclasses__())
        seen = set()
        while pending:
            candidate = pending.pop()
            if candidate in seen:
                continue
            seen.add(candidate)
            if candidate.__name__ == class_name:
                return candidate
            pending.extend(candidate.__subclasses__())
        return cls

    @classmethod
    def load(cls, path: Path) -> "LabelTransform":
        """
        Load transform from disk.

        The ``"transform_class"`` tag written by :meth:`save` selects which
        class (``cls`` itself or one of its descendants) builds the instance,
        so ``LabelTransform.load(path)`` returns the concrete transform that
        was saved. The tag is removed before the config reaches
        :meth:`from_config`.

        Args:
            path: Path to load the transform from

        Returns:
            LabelTransform instance

        Raises:
            TypeError: If the resolved class has no concrete ``from_config``
                (e.g. the tagged class is not defined/imported).
        """
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
        # The tag is only used for dispatch; from_config never sees it.
        class_name = config.pop("transform_class", None)
        target = cls._resolve_subclass(class_name)
        # An abstract from_config has a `pass` body and would silently
        # return None, so fail loudly instead.
        if "from_config" in getattr(target, "__abstractmethods__", ()):
            raise TypeError(
                f"Cannot load transform: no concrete subclass of {cls.__name__} "
                f"named {class_name!r} is defined"
            )
        return target.from_config(config)
class FittableLabelTransform(LabelTransform):
    """Label transform whose parameters are learned from training labels.

    Typical use is ``fit`` (or ``fit_transform``) on the training labels,
    followed by ``transform_labels`` on any other split. Concrete subclasses
    provide ``fit`` and ``_transform_labels_impl``. Nothing in this base class
    ever sets ``is_fitted`` to ``True``, so a subclass must do so (presumably
    inside ``fit``); until then ``transform_labels`` raises ``RuntimeError``.
    """

    def __init__(self, name: str):
        """
        Create the transform in its unfitted state.

        Args:
            name: Identifier forwarded to the base label transform
        """
        super().__init__(name)
        # Starts False; only subclasses flip it once fitting succeeds.
        self.is_fitted = False

    @abstractmethod
    def fit(
        self, labels: Dict[str, Union[int, Tuple[float, int]]], **kwargs: Any
    ) -> "FittableLabelTransform":
        """
        Learn the transform's parameters from training labels.

        Args:
            labels: Mapping of slide ID to training label
            **kwargs: Extra fitting options understood by the subclass

        Returns:
            The transform itself, to allow chaining
        """

    def fit_transform(
        self, labels: Dict[str, Union[int, Tuple[float, int]]], **kwargs: Any
    ) -> Dict[str, Union[int, Tuple[float, int]]]:
        """
        Fit on ``labels`` and immediately transform the same labels.

        Args:
            labels: Labels used both for fitting and as transform input
            **kwargs: Extra fitting options forwarded to ``fit``

        Returns:
            The transformed labels
        """
        # fit()'s return value is deliberately not relied upon here.
        self.fit(labels, **kwargs)
        return self.transform_labels(labels)

    def transform_labels(
        self, labels: Dict[str, Union[int, Tuple[float, int]]]
    ) -> Dict[str, Union[int, Tuple[float, int]]]:
        """
        Transform labels, refusing to run before the transform is fitted.

        Args:
            labels: Mapping of slide ID to label

        Returns:
            The transformed labels, as produced by ``_transform_labels_impl``

        Raises:
            RuntimeError: If ``is_fitted`` is falsy
        """
        if self.is_fitted:
            return self._transform_labels_impl(labels)
        raise RuntimeError(
            f"LabelTransform '{self.name}' must be fitted before transforming data"
        )

    @abstractmethod
    def _transform_labels_impl(
        self, labels: Dict[str, Union[int, Tuple[float, int]]]
    ) -> Dict[str, Union[int, Tuple[float, int]]]:
        """
        Subclass hook that performs the actual label transformation.

        Only reached through ``transform_labels`` once the fitted check passes.

        Args:
            labels: Mapping of slide ID to label

        Returns:
            The transformed labels
        """