cellmil.datamodels.datasets.patch_mil_dataset¶
Classes

| PatchMILDataset | An optimized PyTorch Dataset for Patch-based MIL (Multiple Instance Learning) tasks. |
- class cellmil.datamodels.datasets.patch_mil_dataset.PatchMILDataset(root: Union[str, Path], label: Union[str, Tuple[str, str]], folder: Path, data: DataFrame, extractor: ExtractorType, split: Literal['train', 'val', 'test', 'all'] = 'all', force_reload: bool = False, transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None)[source]¶
Bases:
Dataset[Tuple[Tensor, Union[int, Tuple[float, int]]]]

An optimized PyTorch Dataset for Patch-based MIL (Multiple Instance Learning) tasks.
This dataset follows PyTorch best practices by preprocessing all data once during initialization and storing it efficiently. This provides significant speed improvements over the previous implementation by avoiding repeated feature loading in __getitem__.
- Returns:
For classification tasks: Tuple[torch.Tensor, int] – (features, label). For survival prediction tasks: Tuple[torch.Tensor, Tuple[float, int]] – (features, (duration, event)).
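The preprocess-once pattern described above can be sketched in plain Python. The `CachedBagDataset` class and its nested-list features below are illustrative stand-ins, not part of cellmil: all I/O-like work happens in `__init__`, so `__getitem__` is a cheap in-memory lookup.

```python
from typing import Dict, List, Tuple

class CachedBagDataset:
    """Sketch of the preprocess-once pattern: every slide's features
    are loaded a single time during initialization, so __getitem__
    avoids repeated feature loading."""

    def __init__(self, features_by_slide: Dict[str, List[List[float]]],
                 labels: Dict[str, int]):
        # Preprocess and cache everything up front, in a stable order.
        self.slide_names = sorted(features_by_slide)
        self.features = [features_by_slide[name] for name in self.slide_names]
        self.labels = [labels[name] for name in self.slide_names]

    def __len__(self) -> int:
        return len(self.slide_names)

    def __getitem__(self, index: int) -> Tuple[List[List[float]], int]:
        # No I/O here: just index into the cached lists.
        return self.features[index], self.labels[index]
```

In the real dataset the cached features are tensors of shape (n_patches, n_features) rather than nested lists, but the access pattern is the same.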
- __init__(root: Union[str, Path], label: Union[str, Tuple[str, str]], folder: Path, data: DataFrame, extractor: ExtractorType, split: Literal['train', 'val', 'test', 'all'] = 'all', force_reload: bool = False, transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None)[source]¶
Initialize the optimized Patch MIL dataset.
- Parameters:
root – Root directory where the processed dataset will be cached
label – Label for the dataset. Either a single string (e.g., “dcr”) for classification tasks, or a tuple of two strings (e.g., (“duration”, “event”)) for survival prediction tasks
folder – Path to the dataset folder
data – DataFrame containing metadata
extractor – Feature extractor type (must be embedding type)
split – Dataset split (train/val/test/all). Use “all” to include all data regardless of split
force_reload – Whether to force reprocessing even if processed files exist
transforms – Optional Transform or TransformPipeline to apply to features before returning them
label_transforms – Optional LabelTransform or LabelTransformPipeline to apply to labels (e.g., TimeDiscretizerTransform for survival analysis)
- get_num_labels() int[source]¶
Get the number of unique labels in the dataset.
Note: For survival prediction tasks, this returns 0 as there are no discrete classes.
- _get_labels() Dict[str, Union[int, Tuple[float, int]]][source]¶
Extract labels from the DataFrame based on current configuration. This allows labels to be extracted fresh without being cached with features.
- Returns:
Dictionary mapping slide names to labels (either int for classification or (duration, event) for survival)
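The label-extraction logic can be sketched as follows. `extract_labels` and its `rows` argument are hypothetical simplifications (the real method reads from the DataFrame passed at construction), but they show the classification vs. survival branching:

```python
from typing import Dict, List, Tuple, Union

def extract_labels(rows: List[dict],
                   label: Union[str, Tuple[str, str]]
                   ) -> Dict[str, Union[int, Tuple[float, int]]]:
    """Sketch of _get_labels: map slide name -> label. A string label
    yields an int class; a (duration, event) column pair yields a
    (float, int) tuple for survival tasks."""
    out: Dict[str, Union[int, Tuple[float, int]]] = {}
    for row in rows:
        if isinstance(label, tuple):
            dur_col, evt_col = label
            out[row["slide"]] = (float(row[dur_col]), int(row[evt_col]))
        else:
            out[row["slide"]] = int(row[label])
    return out
```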
- _preprocess_slide_features(slide_name: str) Optional[Tensor][source]¶
Preprocess features for a single slide.
- Parameters:
slide_name – Name of the slide to process
- Returns:
Preprocessed features tensor or None if processing fails
- _save_processed_data(path: Path) None[source]¶
Save preprocessed data to disk (all slides with features, label-independent).
- _load_processed_data(path: Path) None[source]¶
Load preprocessed data from disk and extract labels for current task.
- get_weights_for_sampler() Tensor[source]¶
Get weights for WeightedRandomSampler to handle class imbalance.
Note: Only applicable for classification tasks. For survival prediction, returns uniform weights.
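The inverse-class-frequency weighting such a method typically computes can be sketched as below. `sampler_weights` is a hypothetical helper operating on a plain label list; the real method uses the dataset's cached labels and returns a Tensor:

```python
from collections import Counter
from typing import List

def sampler_weights(labels: List[int]) -> List[float]:
    """Sketch of class-imbalance weighting: each sample's weight is
    1 / count(its class), so a WeightedRandomSampler draws each
    class with roughly equal probability."""
    counts = Counter(labels)
    return [1.0 / counts[y] for y in labels]
```

The resulting weights would then be passed to `torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights))` when building the DataLoader.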
- create_subset(indices: List[int]) PatchMILDataset[source]¶
Create a subset of the dataset using the specified indices.
This is useful for creating train/val/test splits when using split=”all”. The subset will share the same cached features but only include the specified samples.
- Parameters:
indices – List of indices to include in the subset
- Returns:
New PatchMILDataset instance containing only the specified samples
- Raises:
ValueError – If any index is out of range
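The subset semantics can be sketched over a plain sample list. `make_subset` is a hypothetical stand-in: it validates every index before building a view that references the same cached objects (no copy), mirroring how the real subset shares cached features:

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")

def make_subset(samples: Sequence[T], indices: List[int]) -> List[T]:
    """Sketch of create_subset: reject out-of-range indices with
    ValueError, then expose only the selected samples."""
    n = len(samples)
    for i in indices:
        if not 0 <= i < n:
            raise ValueError(f"Index {i} out of range for dataset of length {n}")
    # Shares the underlying objects; only the selection is new.
    return [samples[i] for i in indices]
```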
- __getitem__(index: int) Tuple[Tensor, Union[int, Tuple[float, int]]][source]¶
Get a sample from the dataset.
- Parameters:
index – Index of the sample to retrieve
- Returns:
For classification tasks: Tuple of (features, label) where features is a tensor of shape (n_patches, n_features) and label is an int. For survival prediction tasks: Tuple of (features, (duration, event)) where features is a tensor of shape (n_patches, n_features), duration is a float, and event is an int.
- get_normalization_params() None[source]¶
Return None for compatibility - patch datasets don’t use normalization.
- get_correlation_mask() None[source]¶
Return None for compatibility - patch datasets don’t use correlation filtering.
- create_train_val_datasets(train_indices: List[int], val_indices: List[int], transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None) Tuple[PatchMILDataset, PatchMILDataset][source]¶
Create train and validation datasets with transforms fitted on training data only.
This function prevents data leakage by ensuring transforms are fitted only on the training set before being applied to both training and validation sets.
- Parameters:
train_indices – List of indices for training data
val_indices – List of indices for validation data
transforms – Optional Transform or TransformPipeline for features (fitted on training data)
label_transforms – Optional LabelTransform or LabelTransformPipeline for labels (e.g., TimeDiscretizerTransform for survival analysis)
- Returns:
Tuple of (train_dataset, val_dataset) with properly fitted transforms
- Raises:
ValueError – If indices are invalid or transforms cannot be fitted
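The leakage guard this method enforces can be sketched with a simple standardization transform. `split_without_leakage` is a hypothetical one-feature simplification: the transform's statistics come from the training indices only, and the same fitted transform is then applied to both splits.

```python
from typing import List, Tuple

def split_without_leakage(data: List[float], train_idx: List[int],
                          val_idx: List[int]
                          ) -> Tuple[List[float], List[float]]:
    """Fit normalization statistics on the training split only, then
    apply the identical fitted transform to train and validation."""
    train = [data[i] for i in train_idx]
    val = [data[i] for i in val_idx]
    mean = sum(train) / len(train)  # fitted on training data only
    std = (sum((v - mean) ** 2 for v in train) / len(train)) ** 0.5
    std = std if std > 0 else 1.0   # guard against constant features
    return ([(x - mean) / std for x in train],
            [(x - mean) / std for x in val])
```

Fitting on the combined data instead would leak validation statistics into training, which is exactly what this method is designed to prevent.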