cellmil.datamodels.datasets.cell_mil_dataset

Classes

CellMILDataset(root, label, folder, data, ...)

A PyTorch Dataset for MIL (Multiple Instance Learning) tasks.

class cellmil.datamodels.datasets.cell_mil_dataset.CellMILDataset(root: Union[str, Path], label: Union[str, Tuple[str, str]], folder: Path, data: DataFrame, extractor: Union[ExtractorType, List[ExtractorType]], graph_creator: Optional[GraphCreatorType] = None, segmentation_model: Optional[ModelType] = None, split: Literal['train', 'val', 'test', 'all'] = 'all', transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None, cell_type: bool = False, cell_types_to_keep: Optional[List[str]] = None, return_cell_types: bool = True, roi_folder: Optional[Path] = None, max_workers: int = 8, force_reload: bool = False)[source]

Bases: Dataset[Union[Tuple[Tensor, Union[int, Tuple[float, int]]], Tuple[Tensor, Tensor, Union[int, Tuple[float, int]]]]]

A PyTorch Dataset for MIL (Multiple Instance Learning) tasks.

Returns:

For classification tasks:

When cell_type=False or return_cell_types=False: Tuple[torch.Tensor, int] (features, label)

When cell_type=True and return_cell_types=True: Tuple[torch.Tensor, torch.Tensor, int] (features, cell_types, label)

For survival prediction tasks:

When cell_type=False or return_cell_types=False: Tuple[torch.Tensor, Tuple[float, int]] (features, (duration, event))

When cell_type=True and return_cell_types=True: Tuple[torch.Tensor, torch.Tensor, Tuple[float, int]] (features, cell_types, (duration, event))

__init__(root: Union[str, Path], label: Union[str, Tuple[str, str]], folder: Path, data: DataFrame, extractor: Union[ExtractorType, List[ExtractorType]], graph_creator: Optional[GraphCreatorType] = None, segmentation_model: Optional[ModelType] = None, split: Literal['train', 'val', 'test', 'all'] = 'all', transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None, cell_type: bool = False, cell_types_to_keep: Optional[List[str]] = None, return_cell_types: bool = True, roi_folder: Optional[Path] = None, max_workers: int = 8, force_reload: bool = False)[source]

Initialize the optimized MIL dataset.

Parameters:
  • root – Root directory where the processed dataset will be cached

  • label – Label for the dataset. Either a single string (e.g., “dcr”) for classification tasks, or a tuple of two strings (e.g., (“duration”, “event”)) for survival prediction tasks

  • folder – Path to the dataset folder

  • data – DataFrame containing metadata. If roi_folder is provided, must contain ‘ID’, ‘I3LUNG_ID’, and ‘CENTER’ columns.

  • extractor – Feature extractor type or list of types to use for feature extraction.

  • graph_creator – Optional graph creator type, needed for some extractors

  • segmentation_model – Optional Segmentation model type, needed for some extractors

  • split – Dataset split (train/val/test/all). Use “all” to include all data regardless of split

  • cell_type – Whether to add cell types as one-hot encoded columns to the feature tensor. Only available for ‘cellvit’ and ‘hovernet’ segmentation models.

  • cell_types_to_keep – Optional list of cell type names to keep (e.g., [“Neoplastic”, “Connective”]). Valid names: “Neoplastic”, “Inflammatory”, “Connective”, “Dead”, “Epithelial” (case-insensitive). If provided, only cells of these types will be included; slides with no cells matching these types will be excluded from the dataset. Requires cell_type=True.

  • roi_folder – Optional path to the directory containing ROI CSV files organized by center folders. If provided, cells will be filtered to only include those within ROI boundaries. Requires ‘ID’, ‘I3LUNG_ID’, and ‘CENTER’ columns in the data DataFrame.

  • return_cell_types – Whether to return cell types in __getitem__. If False, only returns (features, label). If True and cell_type=True, returns (features, cell_types, label). Default is True for backward compatibility.

  • max_workers – Maximum number of threads for parallel processing

  • force_reload – Whether to force reprocessing even if processed files exist

  • transforms – Optional TransformPipeline to apply to features at getitem time

  • label_transforms – Optional transform to apply to labels (e.g., TimeDiscretizerTransform for binning survival times)
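A minimal sketch of the metadata DataFrame the data parameter expects. The slide IDs, center names, and label values below are purely illustrative; only the column names ‘ID’, ‘I3LUNG_ID’, and ‘CENTER’ (required when roi_folder is set) and the label columns come from the documentation above.

```python
import pandas as pd

# Minimal metadata frame for a classification task (label="dcr") with
# ROI filtering enabled, so 'ID', 'I3LUNG_ID', and 'CENTER' are required.
data = pd.DataFrame(
    {
        "ID": ["slide_001", "slide_002"],
        "I3LUNG_ID": ["P001", "P002"],
        "CENTER": ["center_a", "center_b"],
        "dcr": [0, 1],             # classification label column
        "duration": [12.5, 30.0],  # survival label columns
        "event": [1, 0],
    }
)

# Columns required whenever roi_folder is provided.
required = {"ID", "I3LUNG_ID", "CENTER"}
assert required.issubset(data.columns)
```

For survival prediction, label would be the column pair ("duration", "event") instead of the single string "dcr".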

_get_data_path() Path[source]

Get the path for the processed dataset file.

_process_dataset() None[source]

Process the entire dataset and cache results.

_get_labels() Dict[str, Union[int, Tuple[float, int]]][source]

Extract labels from the DataFrame based on current configuration. This allows labels to be extracted fresh without being cached with features.

Returns:

Dictionary mapping slide names to labels (either int for classification or (duration, event) for survival)

_load_slide_features(slide_name: str) Optional[Tensor][source]

Load raw features for a single slide without applying transforms.

Parameters:

slide_name – Name of the slide to process

Returns:

Raw features tensor or None if loading fails

_get_cell_types_tensor(slide_name: str, cell_indices: Dict[int, int]) Tensor[source]

Get cell types for a given slide and convert to one-hot encoding.

Parameters:
  • slide_name – Name of the slide

  • cell_indices – Dictionary mapping cell_id to tensor index

Returns:

Tensor of shape (n_cells, n_cell_types) containing one-hot encoded cell types
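A sketch of the one-hot encoding this method's output describes, assuming the five cell-type names listed for cell_types_to_keep form the vocabulary (the actual vocabulary and ordering used internally are not specified here).

```python
import torch
import torch.nn.functional as F

# Assumed cell-type vocabulary; the real ordering may differ.
CELL_TYPES = ["Neoplastic", "Inflammatory", "Connective", "Dead", "Epithelial"]

def one_hot_cell_types(type_names: list) -> torch.Tensor:
    """One-hot encode per-cell type names into (n_cells, n_cell_types)."""
    indices = torch.tensor([CELL_TYPES.index(t) for t in type_names])
    return F.one_hot(indices, num_classes=len(CELL_TYPES)).float()

encoded = one_hot_cell_types(["Neoplastic", "Connective", "Neoplastic"])
print(encoded.shape)  # torch.Size([3, 5])
```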

_save_data(path: Path) None[source]

Save data to disk (all slides with features, label-independent).

_load_data(path: Path) None[source]

Load data from disk and extract labels for current task.

_apply_roi_filtering() None[source]

Apply ROI filtering to the loaded dataset. This filters cells in each slide to keep only those within ROI boundaries. Creates roi_filtered versions of features, cell_types_tensors, and cell_indices.

get_num_labels() int[source]

Get the number of unique labels in the dataset.

Note: For survival prediction tasks, this returns 0 as there are no discrete classes.

get_weights_for_sampler() Tensor[source]

Get weights for WeightedRandomSampler to handle class imbalance.

Note: Only applicable for classification tasks. For survival prediction, returns uniform weights.
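One common weighting scheme such a method can implement is inverse class frequency; the sketch below illustrates that idea with toy labels and is not necessarily the exact formula used internally.

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Toy imbalanced labels: three samples of class 0, one of class 1.
labels = torch.tensor([0, 0, 0, 1])

# Inverse-frequency weighting: each sample's weight is 1 / count(its class),
# so every class contributes equally in expectation.
counts = torch.bincount(labels).float()
weights = 1.0 / counts[labels]

sampler = WeightedRandomSampler(weights, num_samples=len(labels), replacement=True)
```

The resulting weights tensor can be passed to a DataLoader via its sampler argument to rebalance batches during training.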

get_config() dict[str, Any][source]

Get dataset configuration as a dictionary.

__len__() int[source]

Return the number of samples in the dataset.

__getitem__(idx: int) Union[Tuple[Tensor, Union[int, Tuple[float, int]]], Tuple[Tensor, Tensor, Union[int, Tuple[float, int]]]][source]

Get a sample from the dataset.

Parameters:

idx – Index of the sample to retrieve

Returns:

For classification tasks:

If cell_type=False or return_cell_types=False:

Tuple of (features, label) where features is a tensor of shape (n_instances, n_features) and label is an int

If cell_type=True and return_cell_types=True:

Tuple of (features, cell_types, label) where:
  • features is a tensor of shape (n_instances, n_features)

  • cell_types is a tensor of shape (n_instances, n_cell_types) with one-hot encoded cell types

  • label is the sample label (int)

For survival prediction tasks:

If cell_type=False or return_cell_types=False:

Tuple of (features, (duration, event)) where features is a tensor and (duration, event) is survival data

If cell_type=True and return_cell_types=True:

Tuple of (features, cell_types, (duration, event)) where:
  • features is a tensor of shape (n_instances, n_features)

  • cell_types is a tensor of shape (n_instances, n_cell_types) with one-hot encoded cell types

  • (duration, event) is the survival data tuple
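Since the sample arity depends on configuration, downstream code can dispatch on tuple length. The toy tuples below stand in for dataset[idx] in each mode; the tensor sizes are illustrative.

```python
import torch

# Stand-ins for dataset[idx] in each configuration.
cls_sample = (torch.randn(10, 256), 1)                      # (features, label)
cls_ct_sample = (
    torch.randn(10, 256),
    torch.eye(5)[torch.zeros(10, dtype=torch.long)],        # one-hot cell types
    1,
)
surv_sample = (torch.randn(10, 256), (12.5, 1))             # (features, (duration, event))

def unpack(sample):
    """Dispatch on tuple arity: 3-tuples carry the cell-type tensor."""
    if len(sample) == 3:
        features, cell_types, label = sample
    else:
        features, label = sample
        cell_types = None
    return features, cell_types, label

features, cell_types, label = unpack(cls_ct_sample)
print(features.shape, cell_types.shape, label)
```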

create_subset(indices: List[int]) CellMILDataset[source]

Create a subset of the dataset using the specified indices.

This is useful for creating train/val/test splits when using split=”all”. The subset will share the same cached features but only include the specified samples.

Parameters:

indices – List of indices to include in the subset

Returns:

New CellMILDataset instance containing only the specified samples

Raises:

ValueError – If any index is out of range
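A hypothetical sketch of the out-of-range check this docstring describes; the real implementation may structure the validation differently.

```python
from typing import List

def validate_indices(indices: List[int], n: int) -> None:
    """Raise ValueError if any index falls outside [0, n)."""
    bad = [i for i in indices if not 0 <= i < n]
    if bad:
        raise ValueError(f"indices out of range for dataset of length {n}: {bad}")

validate_indices([0, 3, 4], n=5)  # ok: all indices within range
```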

create_train_val_datasets(train_indices: List[int], val_indices: List[int], transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None) Tuple[CellMILDataset, CellMILDataset][source]

Create train and validation datasets with transforms fitted only on training data.

This method prevents data leakage by ensuring that any fittable transforms (like normalization or feature selection) are fitted only on the training set and then applied to both train and validation sets.

Parameters:
  • train_indices – List of indices for training set

  • val_indices – List of indices for validation set

  • transforms – Optional transforms to apply. If provided, any FittableTransform will be fitted on training data only

  • label_transforms – Optional label transform (e.g., TimeDiscretizerTransform) to apply. Will be fitted on training labels only

Returns:

Tuple of (train_dataset, val_dataset) with properly fitted transforms

Raises:

ValueError – If indices lists are empty or contain invalid indices
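The leakage-prevention idea this method describes, fitting transforms on training data only, can be sketched with a minimal standardizer. The FittableStandardizer class below is illustrative and not part of the library; it only demonstrates why fit statistics must come from the training split alone.

```python
import torch

class FittableStandardizer:
    """Illustrative fittable transform: mean/std learned from training features only."""

    def fit(self, feats: torch.Tensor) -> "FittableStandardizer":
        self.mean = feats.mean(dim=0)
        self.std = feats.std(dim=0).clamp_min(1e-8)
        return self

    def __call__(self, feats: torch.Tensor) -> torch.Tensor:
        return (feats - self.mean) / self.std

all_feats = torch.randn(100, 16)
train_idx, val_idx = list(range(80)), list(range(80, 100))

# Fit on training rows only, then apply the *same* fitted statistics to
# both splits -- the validation set never influences the parameters.
tf = FittableStandardizer().fit(all_feats[train_idx])
train_out, val_out = tf(all_feats[train_idx]), tf(all_feats[val_idx])
```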