Source code for cellmil.datamodels.transforms.base_transform
"""
Base classes for feature transforms.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, List
import torch
from pathlib import Path
import json
class Transform(ABC):
    """Base class for all feature transforms.

    Concrete subclasses implement ``transform``, ``get_config`` and
    ``from_config``. ``save`` and ``load`` build JSON persistence on top of
    the latter two.
    """

    def __init__(self, name: str):
        """
        Initialize the transform.

        Args:
            name: Name of the transform for identification
        """
        self.name = name

    @abstractmethod
    def transform(self, features: torch.Tensor) -> torch.Tensor:
        """
        Apply the transform to features.

        Args:
            features: Input features tensor of shape (n_instances, n_features)

        Returns:
            Transformed features tensor
        """

    @abstractmethod
    def get_config(self) -> Dict[str, Any]:
        """
        Get the configuration dictionary for this transform.

        The result is passed to ``json.dump`` by ``save``, so it must be
        JSON-serializable for persistence to work.

        Returns:
            Dictionary containing transform configuration
        """

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict[str, Any]) -> "Transform":
        """
        Create transform instance from configuration dictionary.

        Args:
            config: Configuration dictionary

        Returns:
            Transform instance
        """

    def save(self, path: Path) -> None:
        """
        Save the transform configuration to disk as JSON.

        The concrete class name is stored next to the configuration under
        the ``"transform_class"`` key.

        Args:
            path: Path to save the transform
        """
        # Work on a shallow copy: get_config() may hand back the transform's
        # own internal dict, and adding the class-name key to it would leak
        # a persistence detail into the live object's state.
        config = dict(self.get_config())
        config["transform_class"] = self.__class__.__name__
        # Explicit encoding so the file reads back identically regardless of
        # the platform's default locale encoding.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

    @classmethod
    def load(cls, path: Path) -> "Transform":
        """
        Load transform from disk.

        The instance is built through ``cls.from_config``, so this must be
        called on the concrete subclass that should be instantiated; the
        stored ``"transform_class"`` entry is not used to pick a class.

        Args:
            path: Path to load the transform from

        Returns:
            Transform instance
        """
        with open(path, "r", encoding="utf-8") as f:
            config = json.load(f)
        # Remove class name from config since it's not needed for instantiation
        config.pop("transform_class", None)
        return cls.from_config(config)
class FittableTransform(Transform):
    """Transform that has to learn its parameters from training data first.

    ``transform`` refuses to run while ``is_fitted`` is falsy. Nothing in
    this base class sets the flag to ``True``; presumably concrete ``fit``
    implementations do so once fitting succeeds (confirm in subclasses).
    """

    def __init__(self, name: str):
        """
        Set up an unfitted transform.

        Args:
            name: Name of the transform for identification
        """
        super().__init__(name)
        self.is_fitted = False

    @abstractmethod
    def fit(
        self, features: torch.Tensor, feature_names: Optional[List[str]] = None
    ) -> "FittableTransform":
        """
        Learn the transform parameters from training data.

        Args:
            features: Training features tensor of shape (n_instances, n_features)
            feature_names: Optional list of feature names for logging

        Returns:
            Self for method chaining
        """

    def fit_transform(
        self, features: torch.Tensor, feature_names: Optional[List[str]] = None
    ) -> torch.Tensor:
        """
        Fit on the given features, then transform those same features.

        Args:
            features: Features to fit and transform
            feature_names: Optional list of feature names for logging

        Returns:
            Transformed features
        """
        self.fit(features, feature_names)
        return self.transform(features)

    def transform(self, features: torch.Tensor) -> torch.Tensor:
        """
        Apply the fitted transform to features.

        Delegates to ``_transform_impl`` once the fitted check passes.

        Args:
            features: Input features tensor

        Returns:
            Transformed features tensor

        Raises:
            RuntimeError: If transform hasn't been fitted yet
        """
        if self.is_fitted:
            return self._transform_impl(features)
        message = f"Transform '{self.name}' must be fitted before transforming data"
        raise RuntimeError(message)

    @abstractmethod
    def _transform_impl(self, features: torch.Tensor) -> torch.Tensor:
        """
        Perform the actual transform; called only after the fitted check.

        Args:
            features: Input features tensor

        Returns:
            Transformed features tensor
        """