Source code for cellmil.datamodels.transforms.normalization

"""
Robust scaler transform for feature normalization.
"""

from typing import Any, Dict, Optional, List, Tuple, cast
import torch

from cellmil.utils import logger
from .base_transform import FittableTransform


class RobustScalerTransform(FittableTransform):
    """
    Transform that applies robust scaling to features using median and IQR.

    Robust scaling is less sensitive to outliers than standard scaling.
    Formula: (x - median) / IQR, where IQR = Q3 - Q1

    Features are optionally log-transformed first (controlled by
    ``apply_log_transform``) to handle skewed distributions and outliers,
    then robust scaling is applied.
    """
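    # Worked numbers (illustrative, ignoring the optional clip/log step): for
    # a feature column x = [1, 2, 3, 4, 100] with quantile_range=(0.25, 0.75),
    # median = 3 and IQR = Q3 - Q1 = 4 - 2 = 2, so x scales to
    # [-1.0, -0.5, 0.0, 0.5, 48.5]. The outlier barely affects the center and
    # scale estimates, unlike the mean/std used by standard scaling.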
    def __init__(
        self,
        apply_log_transform: bool = True,
        quantile_range: Tuple[float, float] = (0.25, 0.75),
        clip_quantiles: Tuple[float, float] = (0.005, 0.995),
        constant_threshold: float = 1e-8,
    ):
        """
        Initialize the robust scaler transform.

        Args:
            apply_log_transform: Whether to apply log transformation before scaling
            quantile_range: Tuple of (lower_quantile, upper_quantile) for IQR computation
            clip_quantiles: Tuple of (lower_clip, upper_clip) for outlier clipping
            constant_threshold: Threshold below which IQR is considered too small
        """
        super().__init__("robust_scaler")
        self.apply_log_transform = apply_log_transform
        self.quantile_range = quantile_range
        self.clip_quantiles = clip_quantiles
        self.constant_threshold = constant_threshold

        # Fitted parameters
        self.median_values_: Optional[torch.Tensor] = None
        self.iqr_values_: Optional[torch.Tensor] = None
        self.clip_min_values_: Optional[torch.Tensor] = None
        self.clip_max_values_: Optional[torch.Tensor] = None
        self.num_features_: Optional[int] = None
        self.constant_features_mask_: Optional[torch.Tensor] = None
    def fit(
        self, features: torch.Tensor, feature_names: Optional[List[str]] = None
    ) -> "RobustScalerTransform":
        """
        Fit the robust scaler on training data.

        Args:
            features: Training features tensor of shape (n_instances, n_features)
            feature_names: Optional list of feature names for logging

        Returns:
            Self for method chaining
        """
        self.num_features_ = features.shape[1]

        logger.info(
            f"Computing robust scaling parameters for {features.shape[1]} features "
            f"using {features.shape[0]} instances..."
        )

        # Apply log transformation if requested
        if self.apply_log_transform:
            # First clip extreme outliers
            self.clip_min_values_ = torch.quantile(
                features, self.clip_quantiles[0], dim=0
            )
            self.clip_max_values_ = torch.quantile(
                features, self.clip_quantiles[1], dim=0
            )
            features_processed = torch.clamp(
                features, min=self.clip_min_values_, max=self.clip_max_values_
            )

            # Apply log transformation: sign(x) * log(1 + |x| + epsilon)
            epsilon = 1e-8
            features_processed = torch.sign(features_processed) * torch.log1p(
                torch.abs(features_processed) + epsilon
            )
        else:
            features_processed = features.clone()
            self.clip_min_values_ = None
            self.clip_max_values_ = None

        # Compute median and IQR
        self.median_values_ = torch.median(features_processed, dim=0)[0]  # Shape: (n_features,)
        q1 = torch.quantile(features_processed, self.quantile_range[0], dim=0)
        q3 = torch.quantile(features_processed, self.quantile_range[1], dim=0)
        self.iqr_values_ = q3 - q1

        # Handle features with zero/small IQR (near-constant features)
        self.constant_features_mask_ = self.iqr_values_ <= self.constant_threshold
        if self.constant_features_mask_.sum() > 0:
            num_constant = self.constant_features_mask_.sum().item()
            logger.info(
                f"Found {num_constant} near-constant features with very small IQR"
            )
            # For constant features, set IQR to 1 to avoid division by zero
            self.iqr_values_[self.constant_features_mask_] = 1.0

            # Log constant feature names if available
            if feature_names is not None:
                constant_indices = cast(
                    List[int],
                    torch.where(self.constant_features_mask_)[0].tolist(),  # type: ignore
                )
                constant_names = [
                    feature_names[i]
                    for i in constant_indices
                    if i < len(feature_names)
                ]
                if constant_names:
                    logger.info(f"Near-constant features: {constant_names}")

        logger.info(
            f"Computed robust scaling parameters: "
            f"median range = {self.median_values_.min():.6f} to {self.median_values_.max():.6f}, "
            f"IQR range = {self.iqr_values_.min():.6f} to {self.iqr_values_.max():.6f}"
        )

        self.is_fitted = True
        return self
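    # Note on the signed log used in fit(): sign(x) * log1p(|x| + eps) stays
    # defined for negative inputs (a plain log would not) and compresses
    # magnitudes symmetrically, e.g. 100 -> ~4.6 and -100 -> ~-4.6, while
    # values near zero remain near zero.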
    def _transform_impl(self, features: torch.Tensor) -> torch.Tensor:
        """
        Apply the robust scaling to features.

        Args:
            features: Input features tensor

        Returns:
            Normalized features tensor
        """
        if features.size(1) != self.num_features_:
            raise ValueError(
                f"Feature dimension mismatch: expected {self.num_features_} features, "
                f"got {features.size(1)}"
            )

        # Apply the same preprocessing as during fitting
        if self.apply_log_transform:
            # Clip outliers using fitted values
            features_processed = torch.clamp(
                features, min=self.clip_min_values_, max=self.clip_max_values_
            )

            # Apply log transformation
            epsilon = 1e-8
            features_processed = torch.sign(features_processed) * torch.log1p(
                torch.abs(features_processed) + epsilon
            )
        else:
            features_processed = features.clone()

        if self.median_values_ is None or self.iqr_values_ is None:
            raise RuntimeError("Transform has not been fitted yet")

        # Apply robust scaling: (x - median) / IQR
        scaled_features = (features_processed - self.median_values_) / self.iqr_values_

        return scaled_features
    def get_config(self) -> Dict[str, Any]:
        """Get the configuration dictionary for this transform."""
        config: Dict[str, Any] = {
            "name": self.name,
            "apply_log_transform": self.apply_log_transform,
            "quantile_range": self.quantile_range,
            "clip_quantiles": self.clip_quantiles,
            "constant_threshold": self.constant_threshold,
        }

        if self.is_fitted:
            config.update(
                {
                    "median_values": self.median_values_.tolist()  # type: ignore
                    if self.median_values_ is not None
                    else None,
                    "iqr_values": self.iqr_values_.tolist()  # type: ignore
                    if self.iqr_values_ is not None
                    else None,
                    "clip_min_values": self.clip_min_values_.tolist()  # type: ignore
                    if self.clip_min_values_ is not None
                    else None,
                    "clip_max_values": self.clip_max_values_.tolist()  # type: ignore
                    if self.clip_max_values_ is not None
                    else None,
                    "num_features": self.num_features_,
                    "constant_features_mask": self.constant_features_mask_.tolist()  # type: ignore
                    if self.constant_features_mask_ is not None
                    else None,
                    "is_fitted": True,
                }
            )
        else:
            config["is_fitted"] = False

        return config
    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "RobustScalerTransform":
        """Create transform instance from configuration dictionary."""
        # Create instance with hyperparameters
        transform = cls(
            apply_log_transform=config["apply_log_transform"],
            quantile_range=tuple(config["quantile_range"]),
            clip_quantiles=tuple(config["clip_quantiles"]),
            constant_threshold=config.get("constant_threshold", 1e-8),
        )

        # Restore fitted state if available
        if config.get("is_fitted", False):
            transform.median_values_ = (
                torch.tensor(config["median_values"], dtype=torch.float32)
                if config["median_values"] is not None
                else None
            )
            transform.iqr_values_ = (
                torch.tensor(config["iqr_values"], dtype=torch.float32)
                if config["iqr_values"] is not None
                else None
            )
            transform.clip_min_values_ = (
                torch.tensor(config["clip_min_values"], dtype=torch.float32)
                if config["clip_min_values"] is not None
                else None
            )
            transform.clip_max_values_ = (
                torch.tensor(config["clip_max_values"], dtype=torch.float32)
                if config["clip_max_values"] is not None
                else None
            )
            transform.num_features_ = config["num_features"]
            transform.constant_features_mask_ = (
                torch.tensor(config["constant_features_mask"], dtype=torch.bool)
                if config["constant_features_mask"] is not None
                else None
            )
            transform.is_fitted = True

        return transform
    def get_scaling_parameters(self) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Get the scaling parameters (median, IQR) computed from training data.

        Returns:
            Tuple of (median_values, iqr_values) if fitted, None otherwise
        """
        if not self.is_fitted:
            return None
        if self.median_values_ is None or self.iqr_values_ is None:
            return None
        return self.median_values_, self.iqr_values_
    def get_constant_features_mask(self) -> Optional[torch.Tensor]:
        """
        Get a boolean mask indicating which features were considered constant.

        Returns:
            Boolean tensor of shape (n_features,) where True indicates the
            feature was considered constant during fitting.
        """
        if not self.is_fitted:
            return None
        return self.constant_features_mask_
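
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the library). The public transform
# entry point presumably lives on the FittableTransform base class, which is
# not shown in this module, so the sketch calls `_transform_impl` directly
# after restoring the fitted state.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)

    # Synthetic training data: 1000 instances x 8 features, with a
    # heavy-tailed first column to exercise the clip + log preprocessing.
    train = torch.randn(1000, 8)
    train[:, 0] = torch.distributions.Exponential(0.5).sample((1000,))

    scaler = RobustScalerTransform(apply_log_transform=True).fit(train)

    # Round-trip the fitted state through the config dict, as a checkpoint
    # loader would.
    restored = RobustScalerTransform.from_config(scaler.get_config())
    assert restored.is_fitted

    batch = torch.randn(16, 8)
    scaled = restored._transform_impl(batch)
    print(scaled.shape)  # torch.Size([16, 8])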