import pandas as pd
import torch
import json
import random
import hashlib
from tqdm import tqdm
from pathlib import Path
from typing import List, Literal, Tuple, Union, Optional, Dict, Any
from torch.utils.data import Dataset
from cellmil.interfaces.FeatureExtractorConfig import (
ExtractorType,
FeatureExtractionType,
)
from cellmil.interfaces.GraphCreatorConfig import GraphCreatorType
from cellmil.interfaces.CellSegmenterConfig import ModelType, TYPE_NUCLEI_DICT
from cellmil.utils import logger
from ..transforms import (
TransformPipeline,
Transform,
FittableTransform,
LabelTransform,
LabelTransformPipeline,
)
from .utils import (
wsl_preprocess,
column_sanity_check,
filter_split,
preprocess_row,
get_cell_types,
get_centroids,
cell_types_to_tensor,
get_cell_features,
weights_for_sampler,
cell_type_name_to_index,
load_roi_for_slide,
filter_cells_by_roi,
)
class CellMILDataset(
    Dataset[
        Tuple[torch.Tensor, int | Tuple[float, int]]
        | Tuple[torch.Tensor, torch.Tensor, int | Tuple[float, int]]
    ]
):
    """
    A PyTorch Dataset for MIL (Multiple Instance Learning) tasks.

    Each item is a "bag": all cell-level feature vectors of one whole-slide
    image, plus the slide-level label.

    Returns:
        For classification tasks:
            When cell_type=False or return_cell_types=False: Tuple[torch.Tensor, int] (features, label)
            When cell_type=True and return_cell_types=True: Tuple[torch.Tensor, torch.Tensor, int] (features, cell_types, label)
        For survival prediction tasks:
            When cell_type=False or return_cell_types=False: Tuple[torch.Tensor, Tuple[float, int]] (features, (duration, event))
            When cell_type=True and return_cell_types=True: Tuple[torch.Tensor, torch.Tensor, Tuple[float, int]] (features, cell_types, (duration, event))
    """
[docs] def __init__(
self,
root: Union[str, Path],
label: Union[str, Tuple[str, str]],
folder: Path,
data: pd.DataFrame,
extractor: Union[ExtractorType, List[ExtractorType]],
graph_creator: GraphCreatorType | None = None,
segmentation_model: ModelType | None = None,
split: Literal["train", "val", "test", "all"] = "all",
transforms: Optional[TransformPipeline | Transform] = None,
label_transforms: Optional[LabelTransform | LabelTransformPipeline] = None,
cell_type: bool = False,
cell_types_to_keep: Optional[List[str]] = None,
return_cell_types: bool = True,
roi_folder: Optional[Path] = None,
max_workers: int = 8,
force_reload: bool = False,
):
"""
Initialize the optimized MIL dataset.
Args:
root: Root directory where the processed dataset will be cached
label: Label for the dataset. Either:
- A single string (e.g., "dcr") for classification tasks
- A tuple of two strings (e.g., ("duration", "event")) for survival prediction tasks
folder: Path to the dataset folder
data: DataFrame containing metadata. If roi_folder is provided, must contain 'ID', 'I3LUNG_ID', and 'CENTER' columns.
extractor: Feature extractor type or list of types to use for feature extraction.
graph_creator: Optional graph creator type, needed for some extractors
segmentation_model: Optional Segmentation model type, needed for some extractors
split: Dataset split (train/val/test/all). Use "all" to include all data regardless of split
cell_type: Whether to add cell types as one-hot encoded columns to the feature tensor.
Only available for 'cellvit' and 'hovernet' segmentation models.
cell_types_to_keep: Optional list of cell type names to keep (e.g., ["Neoplastic", "Connective"]).
Valid names: "Neoplastic", "Inflammatory", "Connective", "Dead", "Epithelial" (case-insensitive).
If provided, only cells of these types will be included. Requires cell_type=True.
roi_folder: Optional path to the directory containing ROI CSV files organized by center folders.
If provided, cells will be filtered to only include those within ROI boundaries.
Requires 'ID', 'I3LUNG_ID', and 'CENTER' columns in the data DataFrame.
Slides with no cells matching these types will be excluded from the dataset.
return_cell_types: Whether to return cell types in __getitem__. If False, only returns (features, label).
If True and cell_type=True, returns (features, cell_types, label). Default is True for backward compatibility.
max_workers: Maximum number of threads for parallel processing
force_reload: Whether to force reprocessing even if processed files exist
transforms: Optional TransformPipeline to apply to features at getitem time
label_transform: Optional transform to apply to labels (e.g., TimeDiscretizerTransform for binning survival times)
"""
self.root = Path(root)
self.label = label
self.folder = folder
self.raw_data = data
self.extractor = extractor
self.graph_creator = graph_creator
self.segmentation_model = segmentation_model
self.split = split
self.cell_type = cell_type
self.return_cell_types = return_cell_types
self.label_transforms = label_transforms
# Convert cell type names to indices
if cell_types_to_keep is not None:
self.cell_types_to_keep_indices = cell_type_name_to_index(
cell_types_to_keep
)
else:
self.cell_types_to_keep_indices = None
# ROI filtering setup
self.roi_folder = Path(roi_folder) if roi_folder is not None else None
# Validate ROI parameters
if self.roi_folder is not None:
if not self.roi_folder.exists():
raise ValueError(f"ROI folder does not exist: {self.roi_folder}")
# Check if data DataFrame has required columns for ROI filtering
required_roi_columns = ["ID", "I3LUNG_ID", "CENTER"]
missing_columns = [
col for col in required_roi_columns if col not in data.columns
]
if missing_columns:
raise ValueError(
f"ROI filtering requires the following columns in data DataFrame: {required_roi_columns}. "
f"Missing columns: {missing_columns}"
)
logger.info(f"ROI filtering enabled. ROI folder: {self.roi_folder}")
self.max_workers = max_workers
self.force_reload = force_reload
self.transforms = transforms
if (
isinstance(self.extractor, ExtractorType)
and self.extractor in FeatureExtractionType.Embedding
) or (
isinstance(self.extractor, list)
and any(ext in FeatureExtractionType.Embedding for ext in self.extractor)
):
raise ValueError("Embedding extractor is not supported for this dataset.")
# Validate that cell types are only used with compatible segmentation models
if self.cell_type and self.segmentation_model not in ["cellvit", "hovernet"]:
raise ValueError(
f"Cell types can only be used with 'cellvit' or 'hovernet' segmentation models. "
f"Got '{self.segmentation_model}'"
)
# Validate that cell_types_to_keep requires cell_type=True
if self.cell_types_to_keep_indices is not None and not self.cell_type:
raise ValueError(
"cell_types_to_keep requires cell_type=True. "
"Please set cell_type=True to enable cell type filtering."
)
# TODO: Make this configurable
self.wsl = False
# TODO ---
# Data structures
self.all_slides: List[str] = [] # All slides with features
self.slides: List[str] = [] # Only slides with labels for this task
self.labels: Dict[str, Union[int, Tuple[float, int]]] = {}
self.cell_types: Dict[str, Dict[int, int]] = {}
# Raw data storage (transforms applied at getitem time)
self.features: Dict[str, torch.Tensor] = {}
self.cell_types_tensors: Dict[
str, torch.Tensor
] = {} # Store cell types separately
self.cell_indices: Dict[str, Dict[int, int]] = {}
# Create root directory
self.root.mkdir(parents=True, exist_ok=True)
# Check if we need to process or can load from cache
data_path = self._get_data_path()
if self.force_reload or not data_path.exists():
logger.info("Processing dataset from scratch...")
self._process_dataset()
self._save_data(data_path)
else:
logger.info(f"Loading preprocessed dataset from {data_path}")
self._load_data(data_path)
# Apply ROI filtering if enabled (after loading/processing complete dataset)
if self.roi_folder is not None:
self._apply_roi_filtering()
[docs] def _get_data_path(self) -> Path:
"""Get the path for the processed dataset file."""
# Create a hash of the configuration to ensure cache invalidation
# when parameters change
extractor_str = (
json.dumps(self.extractor, sort_keys=True)
if isinstance(self.extractor, list)
else str(self.extractor)
)
config_dict: Dict[str, Any] = {
"extractor": extractor_str,
"graph_creator": self.graph_creator,
"segmentation_model": self.segmentation_model,
"split": self.split,
"cell_type": self.cell_type,
}
config_str = json.dumps(config_dict, sort_keys=True)
config_hash = hashlib.md5(config_str.encode("utf-8")).hexdigest()[:8]
logger.info(f"Dataset configuration:{config_dict}")
logger.info(f"Dataset configuration hash: {config_hash}")
return self.root / f"data_{self.split}_{config_hash}.pt"
    def _process_dataset(self) -> None:
        """Process the entire dataset and cache results.

        Pipeline: (1) optional WSL preprocessing, (2) column sanity check,
        (3) split filtering, (4) slide validation, (5) optional cell-type
        precomputation, (6) raw feature loading, (7) label extraction, and
        (8) optional filtering of slides without requested cell types.
        Populates self.all_slides, self.cell_types, self.features,
        self.cell_types_tensors, self.cell_indices, self.labels, self.slides.

        Raises:
            ValueError: on any row-processing error, if no valid slide is
                found, or wrapping any other processing failure.
        """
        try:
            if self.wsl:
                processed_data = wsl_preprocess(self.raw_data)
            else:
                # Copy so downstream mutation never touches the caller's frame.
                processed_data = self.raw_data.copy()
            # Perform sanity check
            column_sanity_check(processed_data, self.label)
            # Filter by split type (unless split is "all")
            if self.split != "all":
                processed_data = filter_split(processed_data, self.split)
            else:
                logger.info(f"Using all data: {len(processed_data)} slides")
            # Extract valid slides and labels
            logger.info("Validating slides...")
            for _, row in tqdm(
                processed_data.iterrows(),
                total=len(processed_data),
                desc="Validating slides",
            ):
                try:
                    # preprocess_row validates that the feature files exist;
                    # the label is ignored here (collected later in _get_labels).
                    slide_name, _ = preprocess_row(
                        row,
                        self.label,
                        self.folder,
                        self.extractor,
                        self.graph_creator,
                        self.segmentation_model,
                    )
                    if slide_name is not None:
                        self.all_slides.append(slide_name)
                except Exception as e:
                    # A bad row aborts processing entirely (fail fast).
                    raise ValueError(f"Error processing row: {e}")
            logger.info(
                f"Found {len(self.all_slides)} valid slides out of {len(processed_data)} total slides"
            )
            if not self.all_slides:
                raise ValueError("No valid slides found")
            # Precompute cell types if enabled
            if self.cell_type and self.segmentation_model:
                logger.info("Precomputing cell types for all slides...")
                for slide_name in tqdm(self.all_slides, desc="Loading cell types"):
                    _cell_types = get_cell_types(
                        self.folder, slide_name, self.segmentation_model
                    )
                    if _cell_types is None:
                        # Empty dict marks "looked but found nothing" so later
                        # stages can distinguish it from "never looked".
                        logger.warning(f"No cell types found for slide {slide_name}")
                        self.cell_types[slide_name] = {}
                    else:
                        self.cell_types[slide_name] = _cell_types
            # Load all features (without applying transforms)
            logger.info("Loading raw features for all slides...")
            valid_slides: List[str] = []
            for slide_name in tqdm(self.all_slides, desc="Loading features"):
                try:
                    features = self._load_slide_features(slide_name)
                    if features is not None:
                        # Store raw features
                        self.features[slide_name] = features
                        valid_slides.append(slide_name)
                except Exception as e:
                    # A single bad slide is dropped, not fatal.
                    logger.error(f"Failed to load slide {slide_name}: {e}")
            # Update all_slides to only include those with successfully loaded features
            self.all_slides = valid_slides
            # Now get labels and filter slides to only those with labels
            self.labels = self._get_labels()
            self.slides = [slide for slide in self.all_slides if slide in self.labels]
            # Filter out slides with no matching cell types if cell_types_to_keep is specified
            if self.cell_types_to_keep_indices is not None:
                logger.info("Filtering slides with no matching cell types...")
                slides_with_matching_cells: list[str] = []
                for slide_name in tqdm(self.slides, desc="Checking cell types"):
                    cell_types_tensor = self.cell_types_tensors.get(slide_name)
                    if cell_types_tensor is None or cell_types_tensor.shape[0] == 0:
                        raise ValueError(
                            f"Missing or empty cell types tensor for slide {slide_name}"
                        )
                    # Get the cell type for each cell (argmax of one-hot encoding)
                    cell_type_tensor_indices = torch.argmax(cell_types_tensor, dim=1)
                    # Convert 1-based TYPE_NUCLEI_DICT indices to 0-based tensor indices
                    tensor_indices_to_keep = [
                        idx - 1 for idx in self.cell_types_to_keep_indices
                    ]
                    # Check if any cells match the requested types
                    has_matching_cells = any(
                        (cell_type_tensor_indices == tensor_idx).any().item()
                        for tensor_idx in tensor_indices_to_keep
                    )
                    if has_matching_cells:
                        slides_with_matching_cells.append(slide_name)
                    else:
                        logger.debug(
                            f"Skipping slide {slide_name}: no cells of requested types"
                        )
                self.slides = slides_with_matching_cells
                logger.info(f"Kept {len(self.slides)} slides with matching cell types")
            logger.info(f"Successfully loaded {len(self.features)} slides")
            logger.info(f"Found {len(self.slides)} slides with labels for current task")
        except Exception as e:
            logger.error(f"Failed to process dataset: {e}")
            raise ValueError(f"Failed to process dataset: {e}")
    def _get_labels(self) -> Dict[str, Union[int, Tuple[float, int]]]:
        """
        Extract labels from the DataFrame based on current configuration.
        This allows labels to be extracted fresh without being cached with features.

        Returns:
            Dictionary mapping slide names to labels (either int for classification or (duration, event) for survival)

        Raises:
            Exception: re-raises any top-level failure after logging it.
        """
        try:
            # Process the DataFrame same way as in process()
            if self.wsl:
                processed_data = wsl_preprocess(self.raw_data)
            else:
                processed_data = self.raw_data.copy()
            # Filter by split type (unless split is "all")
            if self.split != "all":
                processed_data = filter_split(processed_data, self.split)
            # Extract labels for valid slides as a dictionary
            labels: Dict[str, Union[int, Tuple[float, int]]] = {}
            for _, row in processed_data.iterrows():
                try:
                    # do_validate_features=False: features were already
                    # validated during processing; skip the expensive check.
                    slide_name, label_value = preprocess_row(
                        row,
                        self.label,
                        self.folder,
                        self.extractor,
                        self.graph_creator,
                        self.segmentation_model,
                        do_validate_features=False,
                    )
                    # Only keep labels for slides whose features loaded OK.
                    if slide_name in self.all_slides and label_value is not None:
                        labels[slide_name] = label_value
                except Exception as e:
                    # Per-row failures are tolerated: that slide just ends up
                    # without a label and is excluded from self.slides.
                    logger.warning(f"Error extracting label: {e}")
            return labels
        except Exception as e:
            logger.error(f"Failed to extract labels: {e}")
            raise
    def _load_slide_features(self, slide_name: str) -> Optional[torch.Tensor]:
        """
        Load raw features for a single slide without applying transforms.

        Also populates self.cell_indices[slide_name] and, when cell types are
        enabled, self.cell_types_tensors[slide_name] as a side effect.

        Args:
            slide_name: Name of the slide to process

        Returns:
            Raw features tensor or None if loading fails
        """
        try:
            # Load features from one or many extractors and concatenate column-wise
            features, cell_indices, _ = get_cell_features(
                self.folder,
                slide_name,
                self.extractor,
                self.graph_creator,
                self.segmentation_model,
            )
            if cell_indices is None or features is None:
                raise ValueError("Missing features or cell_indices")
            # Store cell indices for potential future use
            self.cell_indices[slide_name] = cell_indices
            # Store cell types as a separate tensor if requested
            if self.cell_type:
                if len(cell_indices) == 0:
                    # Without indices we cannot align types to feature rows,
                    # so fall back to an all-zero one-hot matrix.
                    logger.warning(
                        f"Missing cell_indices for slide {slide_name}; adding zero cell types."
                    )
                    n_cell_types = len(TYPE_NUCLEI_DICT)
                    cell_types_onehot = torch.zeros(
                        features.size(0), n_cell_types, dtype=torch.float32
                    )
                else:
                    cell_types_onehot = self._get_cell_types_tensor(
                        slide_name, cell_indices
                    )
                # Store cell types separately instead of concatenating
                self.cell_types_tensors[slide_name] = cell_types_onehot
            return features
        except Exception as e:
            # Caller (_process_dataset) treats None as "drop this slide".
            logger.error(f"Error loading slide {slide_name}: {e}")
            return None
[docs] def _get_cell_types_tensor(
self, slide_name: str, cell_indices: Dict[int, int]
) -> torch.Tensor:
"""
Get cell types for a given slide and convert to one-hot encoding.
Args:
slide_name: Name of the slide
cell_indices: Dictionary mapping cell_id to tensor index
Returns:
Tensor of shape (n_cells, n_cell_types) containing one-hot encoded cell types
"""
n_cells = len(cell_indices)
n_cell_types = len(TYPE_NUCLEI_DICT)
# Get cached data from the dictionary cache
cell_types = self.cell_types.get(slide_name)
if cell_types is None:
logger.warning(f"No cached cell types found for slide {slide_name}")
cell_types_onehot = torch.zeros(n_cells, n_cell_types, dtype=torch.float32)
elif len(cell_types) == 0:
# Empty dict means no cell types available
cell_types_onehot = torch.zeros(n_cells, n_cell_types, dtype=torch.float32)
else:
# Convert cached dict to one-hot tensor
cell_types_onehot = cell_types_to_tensor(cell_types, cell_indices)
return cell_types_onehot
[docs] def _save_data(self, path: Path) -> None:
"""Save data to disk (all slides with features, label-independent)."""
data_dict: dict[str, Any] = {
"all_slides": self.all_slides,
"cell_types": self.cell_types,
"features": self.features,
"cell_types_tensors": self.cell_types_tensors,
"cell_indices": self.cell_indices,
}
torch.save(data_dict, path)
logger.info(f"Saved dataset to {path}")
    def _load_data(self, path: Path) -> None:
        """Load data from disk and extract labels for current task.

        The cache contains only label-independent data (see _save_data);
        labels are re-derived from the raw DataFrame on every load so one
        cache file can serve different label configurations.
        """
        # weights_only=False is required because the cache stores plain Python
        # containers alongside tensors; only load cache files you trust.
        data_dict = torch.load(path, map_location="cpu", weights_only=False)
        self.all_slides = data_dict["all_slides"]
        self.cell_types = data_dict["cell_types"]
        # .get with defaults keeps compatibility with older cache files that
        # may not contain these keys.
        self.features = data_dict.get("features", {})
        self.cell_types_tensors = data_dict.get("cell_types_tensors", {})
        self.cell_indices = data_dict.get("cell_indices", {})
        # Extract labels for current task and filter slides
        self.labels = self._get_labels()
        # Apply label transform if provided (e.g., discretize survival times);
        # a tuple label marks a survival task ((duration, event)).
        if self.label_transforms is not None and isinstance(self.label, tuple):
            # For survival tasks, apply discretization
            self.labels = self.label_transforms.transform_labels(self.labels)
        self.slides = [slide for slide in self.all_slides if slide in self.labels]
        logger.info(
            f"Loaded {len(self.all_slides)} total slides, {len(self.slides)} with labels for current task"
        )
[docs] def _apply_roi_filtering(self) -> None:
"""
Apply ROI filtering to the loaded dataset.
This filters cells in each slide to keep only those within ROI boundaries.
Creates roi_filtered versions of features, cell_types_tensors, and cell_indices.
"""
logger.info("Applying ROI filtering to loaded dataset...")
if self.segmentation_model is None:
raise ValueError("segmentation_model must be specified for ROI filtering")
if self.roi_folder is None:
raise ValueError("roi_folder must be specified for ROI filtering")
slides_to_remove: list[str] = []
total_cells_before = 0
total_cells_after = 0
for slide_name in tqdm(self.slides, desc="Applying ROI filtering"):
# Load ROI for this slide
roi_df = load_roi_for_slide(slide_name, self.roi_folder, self.raw_data)
if roi_df is None:
logger.warning(
f"No ROI found for slide {slide_name}, removing slide from dataset"
)
slides_to_remove.append(slide_name)
continue
# Get cell centroids
centroids = get_centroids(self.folder, slide_name, self.segmentation_model)
if centroids is None:
logger.warning(
f"Could not load centroids for slide {slide_name}, keeping all cells"
)
continue
# Filter cells by ROI
cells_to_keep = filter_cells_by_roi(centroids, roi_df)
total_cells_before += len(centroids)
total_cells_after += len(cells_to_keep)
if len(cells_to_keep) == 0:
logger.warning(
f"No cells within ROI for slide {slide_name}, removing slide from dataset"
)
slides_to_remove.append(slide_name)
continue
# Get original cell_indices for this slide
cell_indices = self.cell_indices[slide_name]
# Filter cell_indices to keep only cells within ROI
filtered_cell_indices = {
cell_id: idx
for cell_id, idx in cell_indices.items()
if cell_id in cells_to_keep
}
# Create a mapping from old indices to new indices
old_to_new_idx = {
old_idx: new_idx
for new_idx, old_idx in enumerate(
sorted(filtered_cell_indices.values())
)
}
# Filter features tensor
indices_to_keep = sorted(filtered_cell_indices.values())
self.features[slide_name] = self.features[slide_name][indices_to_keep]
# Update cell_indices with new sequential indices
self.cell_indices[slide_name] = {
cell_id: old_to_new_idx[old_idx]
for cell_id, old_idx in filtered_cell_indices.items()
}
# Filter cell types if they exist
if slide_name in self.cell_types_tensors:
self.cell_types_tensors[slide_name] = self.cell_types_tensors[
slide_name
][indices_to_keep]
logger.debug(
f"ROI filtering for {slide_name}: kept {len(cells_to_keep)}/{len(centroids)} cells "
f"({len(cells_to_keep) / len(centroids) * 100:.1f}%)"
)
# Remove slides with no cells in ROI
for slide_name in slides_to_remove:
self.slides.remove(slide_name)
if slide_name in self.labels:
del self.labels[slide_name]
logger.info(
f"ROI filtering complete: kept {total_cells_after}/{total_cells_before} cells "
f"({total_cells_after / total_cells_before * 100:.1f}%) across {len(self.slides)} slides"
)
if slides_to_remove:
logger.info(f"Removed {len(slides_to_remove)} slides with no cells in ROI")
[docs] def get_num_labels(self) -> int:
"""
Get the number of unique labels in the dataset.
Note: For survival prediction tasks, this returns 0 as there are no discrete classes.
"""
# Check if we have survival data by looking at the first label
if self.labels:
first_label = next(iter(self.labels.values()))
if isinstance(first_label, tuple):
# Survival data - no discrete classes
return 0
# Classification data
return len(set(self.labels.values()))
[docs] def get_weights_for_sampler(self) -> torch.Tensor:
"""
Get weights for WeightedRandomSampler to handle class imbalance.
Note: Only applicable for classification tasks. For survival prediction,
returns uniform weights.
"""
# Convert dictionary values to list in the same order as slides
labels_list = [self.labels[slide] for slide in self.slides]
# Check if we have survival data
if labels_list and isinstance(labels_list[0], tuple):
# Survival data - return uniform weights
logger.warning("Uniform weights used for survival prediction tasks")
return torch.ones(len(labels_list), dtype=torch.float32)
# Classification data
return weights_for_sampler(labels_list) # type: ignore
[docs] def get_config(self) -> dict[str, Any]:
"""Get dataset configuration as a dictionary."""
config: dict[str, Any] = {
"dataset_type": self.__class__.__name__,
"label": str(self.label),
"extractor": str(self.extractor),
"split": self.split,
"cell_type": self.cell_type,
"return_cell_types": self.return_cell_types,
}
if self.graph_creator is not None:
config["graph_creator"] = str(self.graph_creator)
if self.segmentation_model is not None:
config["segmentation_model"] = str(self.segmentation_model)
if self.cell_types_to_keep_indices is not None:
config["cell_types_to_keep_indices"] = self.cell_types_to_keep_indices
if self.roi_folder is not None:
config["roi_folder"] = str(self.roi_folder)
return config
[docs] def __len__(self) -> int:
"""
Return the number of samples in the dataset.
"""
return len(self.slides)
[docs] def __getitem__(
self, idx: int
) -> Union[
Tuple[torch.Tensor, int | Tuple[float, int]],
Tuple[torch.Tensor, torch.Tensor, int | Tuple[float, int]],
]:
"""
Get a sample from the dataset.
Args:
idx: Index of the sample to retrieve
Returns:
For classification tasks:
If cell_type=False or return_cell_types=False:
Tuple of (features, label) where features is a tensor of shape (n_instances, n_features)
If cell_type=True and return_cell_types=True:
Tuple of (features, cell_types, label) where:
- features is a tensor of shape (n_instances, n_features)
- cell_types is a tensor of shape (n_instances, n_cell_types) with one-hot encoded cell types
- label is the sample label (int)
For survival prediction tasks:
If cell_type=False or return_cell_types=False:
Tuple of (features, (duration, event)) where features is a tensor and (duration, event) is survival data
If cell_type=True and return_cell_types=True:
Tuple of (features, cell_types, (duration, event)) where:
- features is a tensor of shape (n_instances, n_features)
- cell_types is a tensor of shape (n_instances, n_cell_types) with one-hot encoded cell types
- (duration, event) is the survival data tuple
"""
slide_name = self.slides[idx]
# Get raw features from cache
features = self.features[
slide_name
].clone() # Clone to avoid modifying cached data
label = self.labels[slide_name]
# Handle cell type filtering if enabled
if self.cell_type:
# Get cell types tensor
if slide_name in self.cell_types_tensors:
cell_types = self.cell_types_tensors[slide_name].clone()
else:
raise ValueError(f"Missing cell types tensor for slide {slide_name}")
# Filter by cell types if requested
if self.cell_types_to_keep_indices is not None:
# Get the cell type for each cell (argmax of one-hot encoding)
# Note: cell_types tensor uses 0-based indexing (TYPE_NUCLEI_DICT keys are 1-based)
cell_type_tensor_indices = torch.argmax(
cell_types, dim=1
) # (n_cells,) values 0-4
# Convert 1-based TYPE_NUCLEI_DICT indices to 0-based tensor indices
tensor_indices_to_keep = [
idx - 1 for idx in self.cell_types_to_keep_indices
]
# Create mask for cells to keep
mask = torch.zeros(cell_type_tensor_indices.shape[0], dtype=torch.bool)
for tensor_idx in tensor_indices_to_keep:
mask |= cell_type_tensor_indices == tensor_idx
# Apply mask to filter both features and cell_types
features = features[mask]
cell_types = cell_types[mask]
# Apply transforms after filtering
if self.transforms is not None:
features = self.transforms.transform(features)
# Return based on return_cell_types parameter
if self.return_cell_types:
return features, cell_types, label
else:
return features, label
else:
# Apply transforms if provided
if self.transforms is not None:
features = self.transforms.transform(features)
return features, label
[docs] def create_subset(self, indices: List[int]) -> "CellMILDataset":
"""
Create a subset of the dataset using the specified indices.
This is useful for creating train/val/test splits when using split="all".
The subset will share the same cached features but only include the specified samples.
Args:
indices: List of indices to include in the subset
Returns:
New CellMILDataset instance containing only the specified samples
Raises:
ValueError: If any index is out of range
"""
# Allow empty indices for validation-less training (e.g., final model on all data)
if not indices:
# Create empty subset
subset = CellMILDataset.__new__(CellMILDataset)
subset.root = self.root
subset.label = self.label
subset.folder = self.folder
subset.raw_data = self.raw_data
subset.extractor = self.extractor
subset.graph_creator = self.graph_creator
subset.segmentation_model = self.segmentation_model
subset.split = self.split
subset.cell_type = self.cell_type
subset.return_cell_types = self.return_cell_types
subset.cell_types_to_keep_indices = self.cell_types_to_keep_indices
subset.max_workers = self.max_workers
subset.force_reload = self.force_reload
subset.transforms = self.transforms
subset.wsl = self.wsl
subset.slides = []
subset.all_slides = self.all_slides
subset.cell_types = self.cell_types
subset.features = self.features
subset.cell_types_tensors = self.cell_types_tensors
subset.cell_indices = self.cell_indices
subset.labels = {}
logger.info("Created empty subset (0 samples)")
return subset
max_idx = len(self.slides) - 1
invalid_indices = [idx for idx in indices if idx < 0 or idx > max_idx]
if invalid_indices:
raise ValueError(
f"Invalid indices {invalid_indices}. Valid range is 0-{max_idx}"
)
# Create a new instance with the same configuration
subset = CellMILDataset.__new__(CellMILDataset)
# Copy configuration
subset.root = self.root
subset.label = self.label
subset.folder = self.folder
subset.raw_data = self.raw_data
subset.extractor = self.extractor
subset.graph_creator = self.graph_creator
subset.segmentation_model = self.segmentation_model
subset.split = self.split
subset.cell_type = self.cell_type
subset.return_cell_types = self.return_cell_types
subset.cell_types_to_keep_indices = self.cell_types_to_keep_indices
subset.max_workers = self.max_workers
subset.force_reload = self.force_reload
subset.transforms = self.transforms
subset.wsl = self.wsl
# Create subset data
subset.slides = [self.slides[i] for i in indices]
subset.all_slides = self.all_slides # Share the same all_slides
subset.cell_types = self.cell_types # Share cell types
# Share the same features cache (no need to copy)
subset.features = self.features
subset.cell_types_tensors = self.cell_types_tensors
subset.cell_indices = self.cell_indices
# Extract labels fresh for the subset
subset.labels = subset._get_labels()
logger.info(
f"Created subset with {len(subset.slides)} samples from {len(self.slides)} total samples"
)
return subset
[docs] def create_train_val_datasets(
self,
train_indices: List[int],
val_indices: List[int],
transforms: Optional[Union[TransformPipeline, Transform]] = None,
label_transforms: Optional[LabelTransform | LabelTransformPipeline] = None,
) -> Tuple["CellMILDataset", "CellMILDataset"]:
"""
Create train and validation datasets with transforms fitted only on training data.
This method prevents data leakage by ensuring that any fittable transforms
(like normalization or feature selection) are fitted only on the training set
and then applied to both train and validation sets.
Args:
train_indices: List of indices for training set
val_indices: List of indices for validation set
transforms: Optional transforms to apply. If provided, any FittableTransform
will be fitted on training data only
label_transforms: Optional label transform (e.g., TimeDiscretizerTransform) to apply.
Will be fitted on training labels only
Returns:
Tuple of (train_dataset, val_dataset) with properly fitted transforms
Raises:
ValueError: If indices lists are empty or contain invalid indices
"""
if not train_indices:
raise ValueError("train_indices cannot be empty")
# Allow empty val_indices for training final model on all data
# Validate indices
max_idx = len(self.slides) - 1
invalid_train = [idx for idx in train_indices if idx < 0 or idx > max_idx]
invalid_val = [idx for idx in val_indices if idx < 0 or idx > max_idx]
if invalid_train:
raise ValueError(
f"Invalid train indices {invalid_train}. Valid range is 0-{max_idx}"
)
if invalid_val:
raise ValueError(
f"Invalid val indices {invalid_val}. Valid range is 0-{max_idx}"
)
logger.info(
f"Creating train/val datasets with {len(train_indices)} train and {len(val_indices)} val samples"
)
# Fit label transform on training labels if provided (before feature transforms)
if label_transforms is not None and isinstance(self.label, tuple):
logger.info("Fitting label transform on training survival data...")
# Get training labels
train_labels = {
self.slides[idx]: self.labels[self.slides[idx]] for idx in train_indices
}
# Fit the label transform on training labels
label_transforms.fit(train_labels) # type: ignore
logger.info(f"Label transform fitted with {label_transforms.n_bins} bins") # type: ignore
# If no transforms provided, create simple subsets
if transforms is None and label_transforms is None:
train_dataset = self.create_subset(train_indices)
val_dataset = self.create_subset(val_indices)
return train_dataset, val_dataset
# Always fit fittable transforms on training data, regardless of is_fitted
logger.info("Fitting transforms on training data to prevent data leakage...")
# Collect training features for fitting
all_train_features: List[torch.Tensor] = []
max_samples_per_slide = 100_000
sample_size = min(len(train_indices), 20)
indices = random.sample(train_indices, sample_size)
for idx in indices:
try:
result = self[idx]
features = result[0]
# Subsample if needed to prevent memory issues
if features.shape[0] > max_samples_per_slide:
perm_indices = torch.randperm(features.shape[0])[
:max_samples_per_slide
]
features = features[perm_indices]
all_train_features.append(features)
except Exception as e:
logger.warning(
f"Failed to load sample {idx} for transform fitting: {e}"
)
continue
if not all_train_features:
raise ValueError("No valid training samples found for fitting transforms")
# Concatenate all training features
combined_train_features = torch.cat(all_train_features, dim=0)
logger.info(
f"Fitting transforms on {combined_train_features.shape[0]} training instances with {combined_train_features.shape[1]} features"
)
# Fit the transforms if they are fittable
if isinstance(transforms, (TransformPipeline, FittableTransform)):
transforms.fit(combined_train_features)
else:
# If not fittable, just assign
pass
# Create datasets with fitted transforms
train_dataset = self.create_subset(train_indices)
val_dataset = self.create_subset(val_indices)
train_dataset.transforms = transforms
val_dataset.transforms = transforms
# Apply label transform to both datasets if provided
if label_transforms is not None:
train_dataset.label_transforms = label_transforms
val_dataset.label_transforms = label_transforms
# Apply label transformation to the labels
if isinstance(self.label, tuple):
train_dataset.labels = label_transforms.transform_labels(
train_dataset.labels
)
val_dataset.labels = label_transforms.transform_labels(
val_dataset.labels
)
logger.info("Successfully created train/val datasets with fitted transforms")
logger.info(f"Train dataset: {len(train_dataset)} samples")
logger.info(f"Val dataset: {len(val_dataset)} samples")
return train_dataset, val_dataset