cellmil.datamodels.datasets.patch_mil_dataset

Classes

PatchMILDataset(root, label, folder, data, ...)

An optimized PyTorch Dataset for Patch-based MIL (Multiple Instance Learning) tasks.

class cellmil.datamodels.datasets.patch_mil_dataset.PatchMILDataset(root: Union[str, Path], label: Union[str, Tuple[str, str]], folder: Path, data: DataFrame, extractor: ExtractorType, split: Literal['train', 'val', 'test', 'all'] = 'all', force_reload: bool = False, transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None)[source]

Bases: Dataset[Tuple[Tensor, Union[int, Tuple[float, int]]]]

An optimized PyTorch Dataset for Patch-based MIL (Multiple Instance Learning) tasks.

This dataset follows PyTorch best practices by preprocessing all data once during initialization and storing it efficiently. This provides significant speed improvements over the previous implementation by avoiding repeated feature loading in __getitem__.

Returns:

For classification tasks: Tuple[torch.Tensor, int], i.e. (features, label). For survival prediction tasks: Tuple[torch.Tensor, Tuple[float, int]], i.e. (features, (duration, event)).

__init__(root: Union[str, Path], label: Union[str, Tuple[str, str]], folder: Path, data: DataFrame, extractor: ExtractorType, split: Literal['train', 'val', 'test', 'all'] = 'all', force_reload: bool = False, transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None)[source]

Initialize the optimized Patch MIL dataset.

Parameters:
  • root – Root directory where the processed dataset will be cached

  • label – Label for the dataset. Either a single string (e.g., “dcr”) for classification tasks, or a tuple of two strings (e.g., (“duration”, “event”)) for survival prediction tasks.

  • folder – Path to the dataset folder

  • data – DataFrame containing metadata

  • extractor – Feature extractor type (must be embedding type)

  • split – Dataset split (train/val/test/all). Use “all” to include all data regardless of split

  • force_reload – Whether to force reprocessing even if processed files exist

  • transforms – Optional Transform or TransformPipeline to apply to features before returning them

  • label_transforms – Optional LabelTransform or LabelTransformPipeline to apply to labels (e.g., TimeDiscretizerTransform for survival analysis)
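The label argument thus doubles as the task selector. A minimal sketch of that convention (the dispatch function below is an illustration of the documented behavior, not code from the library):

```python
from typing import Tuple, Union

def describe_task(label: Union[str, Tuple[str, str]]) -> str:
    """Mirror the documented label convention: a (duration, event)
    column pair selects survival prediction, a single column name
    selects classification."""
    if isinstance(label, tuple):
        duration_col, event_col = label
        return f"survival: duration={duration_col!r}, event={event_col!r}"
    return f"classification: target={label!r}"

print(describe_task("dcr"))
print(describe_task(("duration", "event")))
```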

get_num_labels() int[source]

Get the number of unique labels in the dataset.

Note: For survival prediction tasks, this returns 0 as there are no discrete classes.

_get_processed_path() Path[source]

Get the path for the processed dataset file.

_process_dataset() None[source]

Process the entire dataset and cache results.

_get_labels() Dict[str, Union[int, Tuple[float, int]]][source]

Extract labels from the DataFrame based on current configuration. This allows labels to be extracted fresh without being cached with features.

Returns:

Dictionary mapping slide names to labels (either int for classification or (duration, event) for survival)

_preprocess_slide_features(slide_name: str) Optional[Tensor][source]

Preprocess features for a single slide.

Parameters:

slide_name – Name of the slide to process

Returns:

Preprocessed features tensor or None if processing fails

_save_processed_data(path: Path) None[source]

Save preprocessed data to disk (all slides with features, label-independent).

_load_processed_data(path: Path) None[source]

Load preprocessed data from disk and extract labels for current task.

get_weights_for_sampler() Tensor[source]

Get weights for WeightedRandomSampler to handle class imbalance.

Note: Only applicable for classification tasks. For survival prediction, returns uniform weights.
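A common weighting scheme for this purpose is inverse class frequency; the sketch below (with made-up labels) shows how such per-sample weights would feed a torch.utils.data.WeightedRandomSampler. Whether PatchMILDataset uses exactly this formula is an assumption:

```python
from collections import Counter

labels = [0, 0, 0, 0, 1]  # imbalanced toy labels (made up)
counts = Counter(labels)

# Inverse-frequency weight per sample: rare classes get larger weights.
weights = [1.0 / counts[y] for y in labels]
print(weights)

# In practice these weights would be wrapped as a tensor and passed to
# torch.utils.data.WeightedRandomSampler(torch.tensor(weights),
#     num_samples=len(labels), replacement=True),
# which is then given to DataLoader via its sampler= argument.
```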

get_config() dict[str, Any][source]

Get dataset configuration as a dictionary.

create_subset(indices: List[int]) PatchMILDataset[source]

Create a subset of the dataset using the specified indices.

This is useful for creating train/val/test splits when using split=”all”. The subset will share the same cached features but only include the specified samples.

Parameters:

indices – List of indices to include in the subset

Returns:

New PatchMILDataset instance containing only the specified samples

Raises:

ValueError – If any index is out of range

__getitem__(index: int) Tuple[Tensor, Union[int, Tuple[float, int]]][source]

Get a sample from the dataset.

Parameters:

index – Index of the sample to retrieve

Returns:

For classification tasks: a tuple of (features, label), where features is a tensor of shape (n_patches, n_features) and label is an int.

For survival prediction tasks: a tuple of (features, (duration, event)), where features is a tensor of shape (n_patches, n_features), duration is a float, and event is an int.
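Because n_patches varies per slide, bags of different sizes cannot be stacked by the default collate function; a common workaround is batch_size=1 or a custom collate that keeps bags as a list. A self-contained sketch with a stand-in dataset (random tensors, not cellmil code):

```python
import torch
from torch.utils.data import DataLoader, Dataset

class ToyBagDataset(Dataset):
    """Stand-in mimicking PatchMILDataset's (features, label) contract."""
    def __init__(self):
        # Two bags with different n_patches (5 and 3), n_features=8.
        self.bags = [torch.randn(5, 8), torch.randn(3, 8)]
        self.labels = [0, 1]

    def __len__(self):
        return len(self.bags)

    def __getitem__(self, i):
        return self.bags[i], self.labels[i]

def list_collate(batch):
    # Keep variable-length bags as a list instead of stacking them.
    feats, labels = zip(*batch)
    return list(feats), torch.tensor(labels)

loader = DataLoader(ToyBagDataset(), batch_size=2, collate_fn=list_collate)
feats, labels = next(iter(loader))
print([tuple(f.shape) for f in feats])  # [(5, 8), (3, 8)]
```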

get_normalization_params() None[source]

Return None for compatibility - patch datasets don’t use normalization.

get_correlation_mask() None[source]

Return None for compatibility - patch datasets don’t use correlation filtering.

create_train_val_datasets(train_indices: List[int], val_indices: List[int], transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None) Tuple[PatchMILDataset, PatchMILDataset][source]

Create train and validation datasets with transforms fitted on training data only.

This function prevents data leakage by ensuring transforms are fitted only on the training set before being applied to both training and validation sets.

Parameters:
  • train_indices – List of indices for training data

  • val_indices – List of indices for validation data

  • transforms – Optional Transform or TransformPipeline for features (fitted on training data)

  • label_transforms – Optional LabelTransform or LabelTransformPipeline for labels (e.g., TimeDiscretizerTransform for survival analysis)

Returns:

Tuple of (train_dataset, val_dataset) with properly fitted transforms

Raises:

ValueError – If indices are invalid or transforms cannot be fitted
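The index lists themselves can come from any splitter. A minimal sketch producing disjoint train/val index lists with a seeded shuffle (the dataset size and 80/20 ratio here are arbitrary):

```python
import random

n_samples = 10                       # len(dataset) in practice
indices = list(range(n_samples))
random.Random(42).shuffle(indices)   # seeded for reproducibility

split = int(0.8 * n_samples)
train_indices, val_indices = indices[:split], indices[split:]

# The two disjoint lists would then be passed to
# dataset.create_train_val_datasets(train_indices, val_indices, ...)
print(len(train_indices), len(val_indices))
```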