Source code for cellmil.datamodels.datasets.patch_mil_dataset

import hashlib
import json
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import pandas as pd
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from cellmil.interfaces.FeatureExtractorConfig import (
    ExtractorType,
    FeatureExtractionType,
)
from cellmil.utils import logger
from ..transforms import (
    FittableLabelTransform,
    LabelTransform,
    LabelTransformPipeline,
    Transform,
    TransformPipeline,
)
from .utils import (
    wsl_preprocess,
    column_sanity_check,
    filter_split,
    preprocess_row,
    get_feature_path,
    weights_for_sampler,
)


class PatchMILDataset(Dataset[Tuple[torch.Tensor, int | Tuple[float, int]]]):
    """
    An optimized PyTorch Dataset for patch-based MIL (Multiple Instance Learning) tasks.

    This dataset follows PyTorch best practices by preprocessing all data once
    during initialization and storing it efficiently. This provides significant
    speed improvements over the previous implementation by avoiding repeated
    feature loading in __getitem__.

    Returns:
        For classification tasks: Tuple[torch.Tensor, int] (features, label)
        For survival prediction tasks: Tuple[torch.Tensor, Tuple[float, int]]
            (features, (duration, event))
    """
    def __init__(
        self,
        root: Union[str, Path],
        label: Union[str, Tuple[str, str]],
        folder: Path,
        data: pd.DataFrame,
        extractor: ExtractorType,
        split: Literal["train", "val", "test", "all"] = "all",
        force_reload: bool = False,
        transforms: Optional[Union[Transform, TransformPipeline]] = None,
        label_transforms: Optional[
            Union[LabelTransform, LabelTransformPipeline]
        ] = None,
    ):
        """
        Initialize the optimized Patch MIL dataset.

        Args:
            root: Root directory where the processed dataset will be cached
            label: Label for the dataset. Either:
                - A single string (e.g., "dcr") for classification tasks
                - A tuple of two strings (e.g., ("duration", "event")) for
                  survival prediction tasks
            folder: Path to the dataset folder
            data: DataFrame containing metadata
            extractor: Feature extractor type (must be an embedding type)
            split: Dataset split (train/val/test/all). Use "all" to include all
                data regardless of split
            force_reload: Whether to force reprocessing even if processed files exist
            transforms: Optional Transform or TransformPipeline to apply to
                features before returning them
            label_transforms: Optional LabelTransform or LabelTransformPipeline
                to apply to labels (e.g., TimeDiscretizerTransform for survival
                analysis)
        """
        self.root = Path(root)
        self.label = label
        self.folder = folder
        self.raw_data = data
        self.extractor = extractor
        self.split = split
        self.force_reload = force_reload
        self.transforms = transforms
        self.label_transforms = label_transforms

        if self.extractor not in FeatureExtractionType.Embedding:
            raise ValueError(
                f"Extractor {self.extractor} not supported for PatchMILDataset. "
                "Use embedding extractors."
            )

        # TODO: Make this configurable
        self.wsl = False

        # Data structures
        self.all_slides: List[str] = []  # All slides with features
        self.slides: List[str] = []  # Only slides with labels for this task
        self.labels: Dict[str, Union[int, Tuple[float, int]]] = {}

        # Preprocessed data storage
        self.features: Dict[str, torch.Tensor] = {}

        # Create root directory
        self.root.mkdir(parents=True, exist_ok=True)

        # Check if we need to process or can load from cache
        processed_path = self._get_processed_path()
        if self.force_reload or not processed_path.exists():
            logger.info("Processing patch dataset from scratch...")
            self._process_dataset()
            self._save_processed_data(processed_path)
        else:
            logger.info(f"Loading preprocessed patch dataset from {processed_path}")
            self._load_processed_data(processed_path)
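    # Usage sketch (illustrative, not part of this module): the CSV path, the
    # "features/" folder, the "dcr" column, and the ExtractorType member are
    # hypothetical placeholders.
    #
    #   df = pd.read_csv("metadata.csv")
    #   dataset = PatchMILDataset(
    #       root="cache/patch_mil",       # processed tensors are cached here
    #       label="dcr",                  # or ("duration", "event") for survival
    #       folder=Path("features/"),
    #       data=df,
    #       extractor=ExtractorType.UNI,  # must be an embedding-type extractor
    #       split="train",
    #   )
    #   features, label = dataset[0]      # features: (n_patches, n_features)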
    def get_num_labels(self) -> int:
        """
        Get the number of unique labels in the dataset.

        Note: For survival prediction tasks, this returns 0 as there are no
        discrete classes.
        """
        if self.labels:
            # Check if we have survival data by looking at the first label
            first_label = next(iter(self.labels.values()))
            if isinstance(first_label, tuple):
                # Survival data - no discrete classes
                return 0
        # Classification data
        return len(set(self.labels.values()))
    def _get_processed_path(self) -> Path:
        """Get the path for the processed dataset file."""
        config_dict: Dict[str, Any] = {
            "extractor": str(self.extractor),
            "split": self.split,
        }
        config_str = json.dumps(config_dict, sort_keys=True)
        config_hash = hashlib.md5(config_str.encode("utf-8")).hexdigest()[:8]
        logger.info(f"Dataset configuration: {config_dict}")
        logger.info(f"Patch dataset configuration hash: {config_hash}")
        return self.root / f"data_{self.split}_{config_hash}.pt"
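    # Cache-naming sketch: the hash can be reproduced outside the class (the
    # extractor string and split below are illustrative values):
    #
    #   cfg = json.dumps({"extractor": "uni", "split": "train"}, sort_keys=True)
    #   suffix = hashlib.md5(cfg.encode("utf-8")).hexdigest()[:8]
    #   # -> cache file "data_train_<suffix>.pt" under self.root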
    def _process_dataset(self) -> None:
        """Process the entire dataset and cache results."""
        try:
            if self.wsl:
                processed_data = wsl_preprocess(self.raw_data)
            else:
                processed_data = self.raw_data.copy()

            # Perform sanity check
            column_sanity_check(processed_data, self.label)

            # Filter by split type (unless split is "all")
            if self.split != "all":
                logger.info(f"Filtering data for split: {self.split}")
                processed_data = filter_split(processed_data, self.split)
            else:
                logger.info(f"Using all data: {len(processed_data)} slides")

            # Extract valid slides and labels
            logger.info("Validating patch slides...")
            for _, row in tqdm(
                processed_data.iterrows(),
                total=len(processed_data),
                desc="Validating slides",
            ):
                try:
                    slide_name, _ = preprocess_row(
                        row,
                        self.label,  # Still need to validate that the label exists
                        self.folder,
                        self.extractor,
                    )
                    if slide_name is not None:
                        self.all_slides.append(slide_name)
                except Exception as e:
                    raise ValueError(f"Failed to validate row: {e}") from e

            logger.info(
                f"Found {len(self.all_slides)} valid patch slides out of "
                f"{len(processed_data)} total slides"
            )
            if not self.all_slides:
                raise ValueError("No valid slides found")

            # Preprocess all features
            logger.info("Preprocessing patch features for all slides...")
            valid_slides: List[str] = []
            for slide_name in tqdm(self.all_slides, desc="Processing patch features"):
                try:
                    features = self._preprocess_slide_features(slide_name)
                    if features is not None:
                        self.features[slide_name] = features
                        valid_slides.append(slide_name)
                except Exception as e:
                    logger.error(f"Failed to preprocess slide {slide_name}: {e}")

            # Keep only slides whose features were successfully processed
            self.all_slides = valid_slides
            logger.info(f"Successfully preprocessed {len(self.features)} patch slides")

            # Now get labels and filter slides to only those with labels
            self.labels = self._get_labels()
            self.slides = [slide for slide in self.all_slides if slide in self.labels]
            logger.info(
                f"Extracted {len(self.labels)} labels for {len(self.slides)} "
                "slides with labels"
            )
            if not self.slides:
                raise ValueError("No slides found with labels for this task")
        except Exception as e:
            logger.error(f"Failed to process patch dataset: {e}")
            raise ValueError(f"Failed to process patch dataset: {e}") from e
    def _get_labels(self) -> Dict[str, Union[int, Tuple[float, int]]]:
        """
        Extract labels from the DataFrame based on the current configuration.

        This allows labels to be extracted fresh without being cached with features.

        Returns:
            Dictionary mapping slide names to labels (either int for
            classification or (duration, event) for survival)
        """
        try:
            # Process the DataFrame the same way as in _process_dataset()
            if self.wsl:
                processed_data = wsl_preprocess(self.raw_data)
            else:
                processed_data = self.raw_data.copy()

            # Filter by split type (unless split is "all")
            if self.split != "all":
                processed_data = filter_split(processed_data, self.split)

            # Extract labels for valid slides as a dictionary
            labels: Dict[str, Union[int, Tuple[float, int]]] = {}
            for _, row in tqdm(
                processed_data.iterrows(),
                total=len(processed_data),
                desc="Extracting labels",
            ):
                try:
                    slide_name, label_value = preprocess_row(
                        row,
                        self.label,
                        self.folder,
                        self.extractor,
                        do_validate_features=False,
                    )
                    if slide_name in self.all_slides and label_value is not None:
                        labels[slide_name] = label_value
                except Exception as e:
                    raise ValueError(f"Failed to extract label from row: {e}") from e
            return labels
        except Exception as e:
            logger.error(f"Failed to extract labels: {e}")
            raise
    def _preprocess_slide_features(self, slide_name: str) -> Optional[torch.Tensor]:
        """
        Preprocess features for a single slide.

        Args:
            slide_name: Name of the slide to process

        Returns:
            Preprocessed features tensor, or None if processing fails
        """
        try:
            features = torch.load(
                get_feature_path(self.folder, slide_name, self.extractor),
                map_location="cpu",
                weights_only=False,
            )["features"]
            if features is None or features.numel() == 0:
                raise ValueError(
                    f"No features found for slide {slide_name} "
                    f"with extractor {self.extractor}"
                )
            return features
        except Exception as e:
            logger.error(f"Error preprocessing slide {slide_name}: {e}")
            return None
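    # On-disk format implied by the load above: each slide's file is a dict
    # saved with torch.save containing at least a "features" tensor, e.g.
    # (shape is illustrative):
    #
    #   torch.save({"features": torch.randn(500, 768)}, feature_path)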
    def _save_processed_data(self, path: Path) -> None:
        """Save preprocessed data to disk (all slides with features, label-independent)."""
        data_dict: Dict[str, Any] = {
            "all_slides": self.all_slides,
            "features": self.features,
        }
        torch.save(data_dict, path)
        logger.info(f"Saved preprocessed patch dataset to {path}")
    def _load_processed_data(self, path: Path) -> None:
        """Load preprocessed data from disk and extract labels for the current task."""
        data_dict = torch.load(path, map_location="cpu", weights_only=False)
        self.all_slides = data_dict["all_slides"]
        self.features = data_dict.get("features", {})

        # Extract labels for the current task and filter slides
        self.labels = self._get_labels()
        self.slides = [slide for slide in self.all_slides if slide in self.labels]
        logger.info(
            f"Loaded {len(self.all_slides)} total slides, "
            f"{len(self.slides)} with labels for the current task"
        )
    def get_weights_for_sampler(self) -> torch.Tensor:
        """
        Get weights for a WeightedRandomSampler to handle class imbalance.

        Note: Only applicable for classification tasks. For survival prediction,
        returns uniform weights.
        """
        # Convert dictionary values to a list in the same order as slides
        labels_list = [self.labels[slide] for slide in self.slides]

        # Check if we have survival data
        if labels_list and isinstance(labels_list[0], tuple):
            # Survival data - return uniform weights
            logger.warning("Uniform weights used for survival prediction tasks")
            return torch.ones(len(labels_list), dtype=torch.float32)

        # Classification data
        return weights_for_sampler(labels_list)  # type: ignore
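    # Usage sketch with torch's WeightedRandomSampler (batch_size=1 because MIL
    # bags have variable numbers of patches; all values illustrative):
    #
    #   from torch.utils.data import DataLoader, WeightedRandomSampler
    #
    #   weights = dataset.get_weights_for_sampler()
    #   sampler = WeightedRandomSampler(weights, num_samples=len(dataset))
    #   loader = DataLoader(dataset, batch_size=1, sampler=sampler)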
    def get_config(self) -> Dict[str, Any]:
        """Get the dataset configuration as a dictionary."""
        return {
            "dataset_type": self.__class__.__name__,
            "label": str(self.label),
            "extractor": str(self.extractor),
            "split": self.split,
        }
    def create_subset(self, indices: List[int]) -> "PatchMILDataset":
        """
        Create a subset of the dataset using the specified indices.

        This is useful for creating train/val/test splits when using split="all".
        The subset will share the same cached features but only include the
        specified samples.

        Args:
            indices: List of indices to include in the subset

        Returns:
            New PatchMILDataset instance containing only the specified samples

        Raises:
            ValueError: If any index is out of range
        """
        # Create a new instance without re-running __init__ and copy the configuration
        subset = PatchMILDataset.__new__(PatchMILDataset)
        subset.root = self.root
        subset.label = self.label
        subset.folder = self.folder
        subset.raw_data = self.raw_data
        subset.extractor = self.extractor
        subset.split = self.split
        subset.force_reload = self.force_reload
        subset.transforms = self.transforms
        subset.label_transforms = self.label_transforms
        subset.wsl = self.wsl
        subset.all_slides = self.all_slides  # Share the same all_slides
        subset.features = self.features  # Share the same features cache

        # Allow empty indices for validation-less training (e.g., final model on all data)
        if not indices:
            subset.slides = []
            subset.labels = {}
            logger.info("Created empty subset (0 samples)")
            return subset

        max_idx = len(self.slides) - 1
        invalid_indices = [idx for idx in indices if idx < 0 or idx > max_idx]
        if invalid_indices:
            raise ValueError(
                f"Invalid indices {invalid_indices}. Valid range is 0-{max_idx}"
            )

        # Create subset data
        subset.slides = [self.slides[i] for i in indices]

        # Extract labels fresh for the subset
        subset.labels = subset._get_labels()

        logger.info(
            f"Created subset with {len(subset.slides)} samples from "
            f"{len(self.slides)} total samples"
        )
        return subset
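    # Usage sketch for cross-validation on a split="all" dataset (the sklearn
    # import is an assumption; any source of index lists works):
    #
    #   from sklearn.model_selection import KFold
    #
    #   for train_idx, val_idx in KFold(n_splits=5).split(range(len(dataset))):
    #       fold_train = dataset.create_subset(list(train_idx))
    #       fold_val = dataset.create_subset(list(val_idx))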
    def __len__(self) -> int:
        return len(self.slides)
    def __getitem__(
        self, index: int
    ) -> Tuple[torch.Tensor, Union[int, Tuple[float, int]]]:
        """
        Get a sample from the dataset.

        Args:
            index: Index of the sample to retrieve

        Returns:
            For classification tasks: Tuple of (features, label) where features
            is a tensor of shape (n_patches, n_features) and label is an int.
            For survival prediction tasks: Tuple of (features, (duration, event))
            where features is a tensor of shape (n_patches, n_features), duration
            is a float, and event is an int.
        """
        slide_name = self.slides[index]
        label = self.labels[slide_name]
        try:
            # Get features from the cache; clone to avoid modifying cached data
            features = self.features[slide_name].clone()

            # Apply transforms if provided
            if self.transforms is not None:
                features = self.transforms.transform(features)

            # Apply label transforms if provided
            if self.label_transforms is not None:
                labels_dict = {slide_name: label}
                transformed = self.label_transforms.transform_labels(labels_dict)
                label = transformed[slide_name]

            return features, label
        except KeyError:
            logger.error(f"No features found for slide {slide_name}")
            raise ValueError(f"No features found for slide {slide_name}")
        except Exception as e:
            logger.error(f"Failed to get item for index {index}: {e}")
            raise ValueError(f"Failed to get item for index {index}: {e}")
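    # Consumption sketch for the two return shapes (illustrative):
    #
    #   features, target = dataset[0]
    #   if isinstance(target, tuple):   # survival task
    #       duration, event = target
    #   else:                           # classification task
    #       class_index = target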
    def get_normalization_params(self) -> None:
        """Return None for compatibility - patch datasets don't use normalization."""
        return None
    def get_correlation_mask(self) -> None:
        """Return None for compatibility - patch datasets don't use correlation filtering."""
        return None
    def create_train_val_datasets(
        self,
        train_indices: List[int],
        val_indices: List[int],
        transforms: Optional[Union[Transform, TransformPipeline]] = None,
        label_transforms: Optional[
            Union[LabelTransform, LabelTransformPipeline]
        ] = None,
    ) -> Tuple["PatchMILDataset", "PatchMILDataset"]:
        """
        Create train and validation datasets with transforms fitted on training data only.

        This function prevents data leakage by ensuring transforms are fitted only
        on the training set before being applied to both training and validation sets.

        Args:
            train_indices: List of indices for training data
            val_indices: List of indices for validation data
            transforms: Optional Transform or TransformPipeline for features
                (fitted on training data)
            label_transforms: Optional LabelTransform or LabelTransformPipeline
                for labels (e.g., TimeDiscretizerTransform for survival analysis)

        Returns:
            Tuple of (train_dataset, val_dataset) with properly fitted transforms

        Raises:
            ValueError: If indices are invalid or transforms cannot be fitted
        """
        if transforms is not None:
            logger.warning("Feature transforms are currently ignored by this method")

        # Create the train dataset
        train_dataset = self.create_subset(train_indices)

        # Fit label transforms on training labels only
        if label_transforms is not None:
            # Get training labels
            train_labels = {
                slide: train_dataset.labels[slide] for slide in train_dataset.slides
            }

            # Fit label transforms on training data
            if isinstance(label_transforms, FittableLabelTransform):
                label_transforms.fit(train_labels)
                logger.info(
                    f"Fitted label transform on {len(train_labels)} training labels"
                )
            elif isinstance(label_transforms, LabelTransformPipeline):
                label_transforms.fit(train_labels)
                logger.info(
                    f"Fitted label transform pipeline on "
                    f"{len(train_labels)} training labels"
                )

            # Apply to both datasets
            train_dataset.label_transforms = label_transforms

        # Create the validation dataset with the same fitted transforms
        val_dataset = self.create_subset(val_indices)

        # Share the fitted label transforms with the validation dataset
        if label_transforms is not None:
            val_dataset.label_transforms = label_transforms

        logger.info(
            f"Created train dataset with {len(train_dataset)} samples and "
            f"val dataset with {len(val_dataset)} samples"
        )
        return train_dataset, val_dataset
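# Usage sketch (TimeDiscretizerTransform is referenced in the docstrings above,
# but its constructor arguments here are assumptions; indices illustrative):
#
#   train_idx, val_idx = list(range(0, 80)), list(range(80, 100))
#   train_ds, val_ds = dataset.create_train_val_datasets(
#       train_idx,
#       val_idx,
#       label_transforms=TimeDiscretizerTransform(n_bins=4),  # fitted on train only
#   )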