cellmil.datamodels.datasets.cell_mil_dataset
Classes
CellMILDataset | A PyTorch Dataset for MIL (Multiple Instance Learning) tasks.
- class cellmil.datamodels.datasets.cell_mil_dataset.CellMILDataset(root: Union[str, Path], label: Union[str, Tuple[str, str]], folder: Path, data: DataFrame, extractor: Union[ExtractorType, List[ExtractorType]], graph_creator: Optional[GraphCreatorType] = None, segmentation_model: Optional[ModelType] = None, split: Literal['train', 'val', 'test', 'all'] = 'all', transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None, cell_type: bool = False, cell_types_to_keep: Optional[List[str]] = None, return_cell_types: bool = True, roi_folder: Optional[Path] = None, max_workers: int = 8, force_reload: bool = False)
Bases: Dataset[Union[Tuple[Tensor, Union[int, Tuple[float, int]]], Tuple[Tensor, Tensor, Union[int, Tuple[float, int]]]]]
A PyTorch Dataset for MIL (Multiple Instance Learning) tasks.
- Returns:
  - For classification tasks:
    When cell_type=False or return_cell_types=False: Tuple[torch.Tensor, int] (features, label)
    When cell_type=True and return_cell_types=True: Tuple[torch.Tensor, torch.Tensor, int] (features, cell_types, label)
  - For survival prediction tasks:
    When cell_type=False or return_cell_types=False: Tuple[torch.Tensor, Tuple[float, int]] (features, (duration, event))
    When cell_type=True and return_cell_types=True: Tuple[torch.Tensor, torch.Tensor, Tuple[float, int]] (features, cell_types, (duration, event))
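Because both the tuple length and the label type depend on the configuration, training code typically has to branch on them. The following is a minimal sketch, not part of cellmil: the helpers unpack_sample and is_survival_label are hypothetical, and plain Python objects stand in for tensors.

```python
from typing import Any, Optional, Tuple


def unpack_sample(sample: Tuple[Any, ...]) -> Tuple[Any, Optional[Any], Any]:
    """Normalize a dataset sample to (features, cell_types, label).

    cell_types is None when the dataset returned a 2-tuple
    (cell_type=False or return_cell_types=False).
    """
    if len(sample) == 2:
        features, label = sample
        return features, None, label
    if len(sample) == 3:
        features, cell_types, label = sample
        return features, cell_types, label
    raise ValueError(f"expected a 2- or 3-tuple, got length {len(sample)}")


def is_survival_label(label: Any) -> bool:
    """Survival labels are (duration, event) tuples; classification labels are ints."""
    return isinstance(label, tuple) and len(label) == 2
```

For example, unpack_sample((features, 1)) yields (features, None, 1), so the rest of a training loop can treat both configurations uniformly.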
- __init__(root: Union[str, Path], label: Union[str, Tuple[str, str]], folder: Path, data: DataFrame, extractor: Union[ExtractorType, List[ExtractorType]], graph_creator: Optional[GraphCreatorType] = None, segmentation_model: Optional[ModelType] = None, split: Literal['train', 'val', 'test', 'all'] = 'all', transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None, cell_type: bool = False, cell_types_to_keep: Optional[List[str]] = None, return_cell_types: bool = True, roi_folder: Optional[Path] = None, max_workers: int = 8, force_reload: bool = False)
Initialize the optimized MIL dataset.
- Parameters:
root – Root directory where the processed dataset will be cached
label – Label column(s) in data. Either a single string (e.g., “dcr”) for classification tasks, or a tuple of two strings (e.g., (“duration”, “event”)) for survival prediction tasks.
folder – Path to the dataset folder
data – DataFrame containing metadata. If roi_folder is provided, must contain ‘ID’, ‘I3LUNG_ID’, and ‘CENTER’ columns.
extractor – Feature extractor type or list of types to use for feature extraction.
graph_creator – Optional graph creator type, needed for some extractors
segmentation_model – Optional segmentation model type, needed for some extractors
split – Dataset split (train/val/test/all). Use “all” to include all data regardless of split
cell_type – Whether to add cell types as one-hot encoded columns to the feature tensor. Only available for ‘cellvit’ and ‘hovernet’ segmentation models.
cell_types_to_keep – Optional list of cell type names to keep (e.g., [“Neoplastic”, “Connective”]). Valid names: “Neoplastic”, “Inflammatory”, “Connective”, “Dead”, “Epithelial” (case-insensitive). If provided, only cells of these types are included, and slides with no cells matching these types are excluded from the dataset. Requires cell_type=True.
roi_folder – Optional path to the directory containing ROI CSV files organized by center folders. If provided, cells are filtered to include only those within ROI boundaries. Requires ‘ID’, ‘I3LUNG_ID’, and ‘CENTER’ columns in the data DataFrame.
return_cell_types – Whether to return cell types in __getitem__. If False, only returns (features, label). If True and cell_type=True, returns (features, cell_types, label). Default is True for backward compatibility.
max_workers – Maximum number of threads for parallel processing
force_reload – Whether to force reprocessing even if processed files exist
transforms – Optional Transform or TransformPipeline applied to the features in __getitem__
label_transforms – Optional LabelTransform or LabelTransformPipeline applied to the labels (e.g., TimeDiscretizerTransform for binning survival times)
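Several of the parameters above constrain one another. The sketch below restates those documented constraints as explicit checks; check_config and task_kind are hypothetical helpers, not the class's actual validation code, and segmentation_model is passed as a plain string rather than a ModelType value.

```python
from typing import Iterable, Optional, Sequence, Tuple, Union

# Constraints taken from the parameter descriptions above.
VALID_CELL_TYPES = {"neoplastic", "inflammatory", "connective", "dead", "epithelial"}
CELL_TYPE_MODELS = {"cellvit", "hovernet"}
ROI_COLUMNS = {"ID", "I3LUNG_ID", "CENTER"}


def task_kind(label: Union[str, Tuple[str, str]]) -> str:
    """'classification' for one column name, 'survival' for (duration, event)."""
    if isinstance(label, str):
        return "classification"
    if isinstance(label, tuple) and len(label) == 2 and all(isinstance(x, str) for x in label):
        return "survival"
    raise ValueError(f"invalid label spec: {label!r}")


def check_config(
    label: Union[str, Tuple[str, str]],
    columns: Iterable[str],
    segmentation_model: Optional[str] = None,
    cell_type: bool = False,
    cell_types_to_keep: Optional[Sequence[str]] = None,
    roi_folder: Optional[str] = None,
) -> str:
    """Raise ValueError on an inconsistent configuration; return the task kind."""
    kind = task_kind(label)
    if cell_type and (segmentation_model or "").lower() not in CELL_TYPE_MODELS:
        raise ValueError("cell_type=True needs a 'cellvit' or 'hovernet' segmentation model")
    if cell_types_to_keep is not None:
        if not cell_type:
            raise ValueError("cell_types_to_keep requires cell_type=True")
        # Names are compared case-insensitively.
        unknown = [t for t in cell_types_to_keep if t.lower() not in VALID_CELL_TYPES]
        if unknown:
            raise ValueError(f"unknown cell types: {unknown}")
    if roi_folder is not None:
        missing = ROI_COLUMNS - set(columns)
        if missing:
            raise ValueError(f"roi_folder needs columns: {sorted(missing)}")
    return kind
```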
- _get_labels() → Dict[str, Union[int, Tuple[float, int]]]
Extract labels from the DataFrame based on the current configuration. Labels are extracted fresh on each call instead of being cached alongside the features.
- Returns:
Dictionary mapping slide names to labels (either int for classification or (duration, event) for survival)
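A minimal sketch of the label extraction logic, with a list of row dictionaries standing in for the DataFrame. The slide_name key is a hypothetical column name; the description above does not say which column identifies slides.

```python
from typing import Any, Dict, List, Mapping, Tuple, Union

Label = Union[int, Tuple[float, int]]


def get_labels(
    rows: List[Mapping[str, Any]],
    label: Union[str, Tuple[str, str]],
    slide_key: str = "slide_name",  # hypothetical slide identifier column
) -> Dict[str, Label]:
    """Map each slide to an int class label or a (duration, event) tuple."""
    labels: Dict[str, Label] = {}
    for row in rows:
        slide = str(row[slide_key])
        if isinstance(label, str):
            # Classification: a single label column.
            labels[slide] = int(row[label])
        else:
            # Survival: (duration column, event column).
            duration_col, event_col = label
            labels[slide] = (float(row[duration_col]), int(row[event_col]))
    return labels
```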
- _load_slide_features(slide_name: str) → Optional[Tensor]
Load raw features for a single slide without applying transforms.
- Parameters:
slide_name – Name of the slide to process
- Returns:
Raw features tensor or None if loading fails
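Since a per-slide loader returns None on failure and the constructor takes max_workers threads for parallel processing, loading many slides can be sketched with concurrent.futures. load_all is a hypothetical helper and load_one stands in for a per-slide loader; this is not the class's actual code path.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Dict, List, Optional, TypeVar

T = TypeVar("T")


def load_all(
    slide_names: List[str],
    load_one: Callable[[str], Optional[T]],
    max_workers: int = 8,
) -> Dict[str, T]:
    """Load slides on a thread pool, dropping slides whose loader returned None."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # pool.map preserves input order, so results line up with slide_names.
        results = list(pool.map(load_one, slide_names))
    return {name: res for name, res in zip(slide_names, results) if res is not None}
```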
- _get_cell_types_tensor(slide_name: str, cell_indices: Dict[int, int]) → Tensor
Get cell types for a given slide and convert to one-hot encoding.
- Parameters:
slide_name – Name of the slide
cell_indices – Dictionary mapping cell_id to tensor index
- Returns:
Tensor of shape (n_cells, n_cell_types) containing one-hot encoded cell types
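The one-hot encoding can be sketched with nested lists in place of a tensor. The column order of CELL_TYPES is an assumption (the names come from the cell_types_to_keep description, but the library's actual order is not documented here), as is the input mapping cell_type_of from cell_id to type name.

```python
from typing import Dict, List

# Assumed column order for the one-hot encoding.
CELL_TYPES = ["Neoplastic", "Inflammatory", "Connective", "Dead", "Epithelial"]


def one_hot_cell_types(
    cell_type_of: Dict[int, str],  # hypothetical: cell_id -> type name
    cell_indices: Dict[int, int],  # cell_id -> row index in the feature tensor
) -> List[List[float]]:
    """Return an (n_cells, n_cell_types) one-hot matrix as nested lists.

    Assumes the row indices in cell_indices are exactly 0..n_cells-1.
    """
    lookup = {name.lower(): i for i, name in enumerate(CELL_TYPES)}
    rows = [[0.0] * len(CELL_TYPES) for _ in range(len(cell_indices))]
    for cell_id, row in cell_indices.items():
        rows[row][lookup[cell_type_of[cell_id].lower()]] = 1.0
    return rows
```

With PyTorch, torch.nn.functional.one_hot produces the same encoding from a tensor of class indices.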
- _save_data(path: Path) → None
Save the processed data for all slides to path. Only features are saved; labels are not cached, so the saved data is independent of the label configuration.
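Together with the root and force_reload parameters, this supports a cache-or-rebuild pattern. The sketch below uses pickle purely as an assumption; the actual on-disk format is not described here, and load_or_build is a hypothetical helper.

```python
import pickle
from pathlib import Path
from typing import Any, Callable, Dict


def load_or_build(
    cache_path: Path,
    build: Callable[[], Dict[str, Any]],
    force_reload: bool = False,
) -> Dict[str, Any]:
    """Reuse cached, label-independent data unless force_reload is set."""
    if cache_path.exists() and not force_reload:
        with cache_path.open("rb") as fh:
            return pickle.load(fh)
    data = build()  # expensive processing of all slides
    cache_path.parent.mkdir(parents=True, exist_ok=True)
    with cache_path.open("wb") as fh:
        pickle.dump(data, fh)
    return data
```

Because labels are not part of the cached data, the same cache can serve both a classification label and a (duration, event) survival label.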
- _apply_roi_filtering() → None
Apply ROI filtering to the loaded dataset. This filters cells in each slide to keep only those within ROI boundaries. Creates roi_filtered versions of features, cell_types_tensors, and cell_indices.
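A minimal sketch of ROI filtering, under loudly stated assumptions: ROIs are polygons of (x, y) vertices, each cell has a centroid, and membership is tested with even-odd ray casting. The ROI CSV format and the library's geometric test are not documented here, so all names below are hypothetical. The sketch also shows why cell_indices must be rebuilt: surviving rows are renumbered from 0.

```python
from typing import Dict, List, Sequence, Tuple

Point = Tuple[float, float]


def point_in_polygon(x: float, y: float, polygon: Sequence[Point]) -> bool:
    """Even-odd ray casting: count edge crossings of a ray cast towards +x."""
    inside = False
    n = len(polygon)
    for i in range(n):
        x1, y1 = polygon[i]
        x2, y2 = polygon[(i + 1) % n]
        if (y1 > y) != (y2 > y):  # edge straddles the ray, so y1 != y2
            x_cross = x1 + (y - y1) * (x2 - x1) / (y2 - y1)
            if x < x_cross:
                inside = not inside
    return inside


def filter_cells_by_roi(
    features: List[List[float]],
    cell_indices: Dict[int, int],
    centroids: Dict[int, Point],
    rois: List[Sequence[Point]],
) -> Tuple[List[List[float]], Dict[int, int]]:
    """Keep cells whose centroid lies in any ROI; renumber their rows from 0."""
    kept = sorted(
        (row, cell_id)
        for cell_id, row in cell_indices.items()
        if any(point_in_polygon(*centroids[cell_id], roi) for roi in rois)
    )
    new_features = [features[row] for row, _ in kept]
    new_indices = {cell_id: new_row for new_row, (_, cell_id) in enumerate(kept)}
    return new_features, new_indices
```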
- get_num_labels() → int
Get the number of unique labels in the dataset.
Note: For survival prediction tasks, this returns 0 as there are no discrete classes.
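The documented behavior can be sketched as follows; num_labels is a hypothetical helper operating on a plain list of labels.

```python
from typing import Iterable, Tuple, Union


def num_labels(labels: Iterable[Union[int, Tuple[float, int]]]) -> int:
    """Count distinct classes; survival labels have no discrete classes."""
    items = list(labels)
    if any(isinstance(lab, tuple) for lab in items):
        return 0  # (duration, event) pairs: survival task
    return len(set(items))
```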
- get_weights_for_sampler() → Tensor
Get weights for WeightedRandomSampler to handle class imbalance.
Note: Only applicable for classification tasks. For survival prediction, returns uniform weights.
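The description does not say how the weights are computed. A common choice, assumed here, is inverse class frequency, so that every class contributes equal total weight; sampler_weights is a hypothetical helper.

```python
from collections import Counter
from typing import List, Sequence, Tuple, Union


def sampler_weights(labels: Sequence[Union[int, Tuple[float, int]]]) -> List[float]:
    """Per-sample weights: 1 / class count for classification, uniform for survival."""
    if any(isinstance(lab, tuple) for lab in labels):
        return [1.0] * len(labels)
    counts = Counter(labels)
    return [1.0 / counts[lab] for lab in labels]
```

Converted to a tensor, such weights would be passed to torch.utils.data.WeightedRandomSampler(weights, num_samples=len(weights), replacement=True).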
- __getitem__(idx: int) → Union[Tuple[Tensor, Union[int, Tuple[float, int]]], Tuple[Tensor, Tensor, Union[int, Tuple[float, int]]]]
Get a sample from the dataset.
- Parameters:
idx – Index of the sample to retrieve
- Returns:
  - For classification tasks:
    - If cell_type=False or return_cell_types=False:
      Tuple of (features, label), where features is a tensor of shape (n_instances, n_features) and label is the sample label (int)
    - If cell_type=True and return_cell_types=True:
      Tuple of (features, cell_types, label), where:
      - features is a tensor of shape (n_instances, n_features)
      - cell_types is a tensor of shape (n_instances, n_cell_types) with one-hot encoded cell types
      - label is the sample label (int)
  - For survival prediction tasks:
    - If cell_type=False or return_cell_types=False:
      Tuple of (features, (duration, event)), where features is a tensor of shape (n_instances, n_features) and (duration, event) is the survival data tuple
    - If cell_type=True and return_cell_types=True:
      Tuple of (features, cell_types, (duration, event)), where:
      - features is a tensor of shape (n_instances, n_features)
      - cell_types is a tensor of shape (n_instances, n_cell_types) with one-hot encoded cell types
      - (duration, event) is the survival data tuple
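How the returned tuple depends on the two flags can be sketched from the producer side. build_sample is hypothetical; it assumes the feature transform is applied to the features only and leaves out label_transforms, since when those are applied is not stated here.

```python
from typing import Any, Callable, Optional, Tuple


def build_sample(
    features: Any,
    cell_types: Optional[Any],
    label: Any,
    cell_type: bool,
    return_cell_types: bool = True,
    transform: Optional[Callable[[Any], Any]] = None,
) -> Tuple[Any, ...]:
    """Assemble a 2- or 3-tuple following the flag rules documented above."""
    if transform is not None:
        features = transform(features)  # applied at access time, not cached
    if cell_type and return_cell_types:
        return features, cell_types, label
    return features, label
```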
- create_subset(indices: List[int]) → CellMILDataset
Create a subset of the dataset using the specified indices.
This is useful for creating train/val/test splits when using split=“all”. The subset shares the same cached features but includes only the specified samples.
- Parameters:
indices – List of indices to include in the subset
- Returns:
New CellMILDataset instance containing only the specified samples
- Raises:
ValueError – If any index is out of range
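A minimal sketch of the subsetting idea: only the list of selected slides changes, while the feature store is shared by reference rather than copied. make_subset is hypothetical, and rejecting negative indices is this sketch's own choice.

```python
from typing import Dict, List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def make_subset(
    slides: Sequence[str],
    features: Dict[str, T],
    indices: List[int],
) -> Tuple[List[str], Dict[str, T]]:
    """Select slides by index; return the same (shared) feature dict."""
    n = len(slides)
    bad = [i for i in indices if not 0 <= i < n]  # negatives are rejected too
    if bad:
        raise ValueError(f"indices out of range for dataset of size {n}: {bad}")
    return [slides[i] for i in indices], features  # no copy of the features
```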
- create_train_val_datasets(train_indices: List[int], val_indices: List[int], transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None) → Tuple[CellMILDataset, CellMILDataset]
Create train and validation datasets with transforms fitted only on training data.
This method prevents data leakage by ensuring that any fittable transforms (like normalization or feature selection) are fitted only on the training set and then applied to both train and validation sets.
- Parameters:
train_indices – List of indices for training set
val_indices – List of indices for validation set
transforms – Optional transforms to apply. If provided, any FittableTransform will be fitted on training data only
label_transforms – Optional label transform (e.g., TimeDiscretizerTransform) to apply. Will be fitted on training labels only
- Returns:
Tuple of (train_dataset, val_dataset) with properly fitted transforms
- Raises:
ValueError – If indices lists are empty or contain invalid indices
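The fit-on-train-only idea can be sketched in plain Python. StandardizeTransform and train_val_split are hypothetical stand-ins for a FittableTransform and for this method. Pooling the per-feature statistics over all instances of all training bags is an assumption about how normalization is computed.

```python
from typing import List, Sequence, Tuple

Bag = Sequence[Sequence[float]]  # (n_instances, n_features)


class StandardizeTransform:
    """Hypothetical fittable transform: per-feature z-score normalization."""

    def __init__(self) -> None:
        self.mean: List[float] = []
        self.std: List[float] = []

    def fit(self, rows: Sequence[Sequence[float]]) -> "StandardizeTransform":
        n, dims = len(rows), len(rows[0])
        self.mean = [sum(r[d] for r in rows) / n for d in range(dims)]
        var = [sum((r[d] - self.mean[d]) ** 2 for r in rows) / n for d in range(dims)]
        self.std = [v ** 0.5 if v > 0 else 1.0 for v in var]  # guard constant features
        return self

    def __call__(self, rows: Bag) -> List[List[float]]:
        return [[(x - m) / s for x, m, s in zip(r, self.mean, self.std)] for r in rows]


def train_val_split(
    bags: Sequence[Bag], train_idx: List[int], val_idx: List[int]
) -> Tuple[List[List[List[float]]], List[List[List[float]]], StandardizeTransform]:
    if not train_idx or not val_idx:
        raise ValueError("train_indices and val_indices must be non-empty")
    if any(not 0 <= i < len(bags) for i in [*train_idx, *val_idx]):
        raise ValueError("index out of range")
    # Fit on training instances only, so no statistics leak from validation bags.
    train_rows = [row for i in train_idx for row in bags[i]]
    transform = StandardizeTransform().fit(train_rows)
    train = [transform(bags[i]) for i in train_idx]
    val = [transform(bags[i]) for i in val_idx]
    return train, val, transform
```

Fitting on the union of both index lists instead would let validation statistics shape the normalization, which is the leakage this method is meant to prevent.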