cellmil.datamodels.datasets.patch_gnn_mil_dataset¶
Classes

- PatchGNNMILDataset
- SubsetPatchGNNMILDataset: A lightweight wrapper for creating subsets of PatchGNNMILDataset.
- class cellmil.datamodels.datasets.patch_gnn_mil_dataset.PatchGNNMILDataset(root: Union[str, Path], folder: Union[str, Path], label: Union[str, Tuple[str, str]], data: DataFrame, extractor: ExtractorType, split: Literal['train', 'val', 'test', 'all'] = 'all', transform: Optional[Callable[[Data], Data]] = None, pre_transform: Optional[Callable[[Data], Data]] = None, pre_filter: Optional[Callable[[Data], bool]] = None, force_reload: bool = False, transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None)[source]¶
Bases: InMemoryDataset

- __init__(root: Union[str, Path], folder: Union[str, Path], label: Union[str, Tuple[str, str]], data: DataFrame, extractor: ExtractorType, split: Literal['train', 'val', 'test', 'all'] = 'all', transform: Optional[Callable[[Data], Data]] = None, pre_transform: Optional[Callable[[Data], Data]] = None, pre_filter: Optional[Callable[[Data], bool]] = None, force_reload: bool = False, transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None)[source]¶
- property raw_file_names: List[str]¶
Required by InMemoryDataset. Since we’re working with pre-existing data, we return an empty list.
- property processed_file_names: List[str]¶
Return the name of the processed file that will contain our data. The name is a stable hash derived from the dataset configuration, so the same configuration always resolves to the same cached file.
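Such a configuration-derived filename is typically built by serializing the configuration deterministically and hashing it. The sketch below illustrates the idea with `hashlib` and `json`; the function name, key names, and filename format are illustrative, not cellmil's actual implementation:

```python
import hashlib
import json

def processed_file_name(config: dict) -> str:
    """Derive a deterministic filename from a dataset configuration.

    sort_keys=True makes the JSON serialization (and hence the hash)
    independent of dict insertion order, so the same configuration
    always maps to the same cached file.
    """
    payload = json.dumps(config, sort_keys=True).encode("utf-8")
    digest = hashlib.sha256(payload).hexdigest()[:16]
    return f"data_{digest}.pt"

# Example: the filename only depends on the configuration's contents.
cfg = {"folder": "features", "label": "subtype", "split": "all"}
print(processed_file_name(cfg))
```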
- download() None[source]¶
Required by InMemoryDataset. Since we’re working with pre-existing data, we don’t need to download anything.
- get_num_labels() int[source]¶
Get the number of unique labels in the dataset.
Note: For survival prediction tasks, this returns 0 as there are no discrete classes.
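The classification/survival distinction can be sketched as follows. This is an illustrative standalone helper (not cellmil's code) assuming the label layout described above: an `int` per slide for classification, a `(duration, event)` tuple for survival:

```python
from typing import Dict, Tuple, Union

# A label is either a class index (classification)
# or a (duration, event) pair (survival analysis).
Label = Union[int, Tuple[float, int]]

def num_labels(labels: Dict[str, Label]) -> int:
    """Count distinct class labels; survival tuples yield 0 (no discrete classes)."""
    values = list(labels.values())
    if values and isinstance(values[0], tuple):
        return 0
    return len(set(values))
```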
- process() None[source]¶
Required by InMemoryDataset. Process all the raw data and save it as a list of Data objects. This is where the real work happens and it’s only done once.
- _get_slides() List[str][source]¶
Extract slide names from the cached graph data.
- Returns:
List of slide names corresponding to all processed graphs
- _get_labels() Dict[str, Union[int, Tuple[float, int]]][source]¶
Extract labels from the DataFrame based on current configuration. This allows labels to be extracted fresh without being cached with features.
- Returns:
Dictionary mapping slide names to labels (either int for classification or (duration, event) for survival)
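A minimal sketch of this slide-to-label extraction, assuming a pandas DataFrame with a `"slide"` column (the column names here are illustrative; the real schema lives in the cellmil DataFrame):

```python
import pandas as pd
from typing import Dict, Tuple, Union

def extract_labels(
    df: pd.DataFrame,
    label: Union[str, Tuple[str, str]],
) -> Dict[str, Union[int, Tuple[float, int]]]:
    """Map slide names to labels.

    `label` is either a single column name (classification) or a
    (duration_col, event_col) pair (survival), mirroring the `label`
    parameter of the dataset constructor.
    """
    if isinstance(label, tuple):
        duration_col, event_col = label
        return {
            row["slide"]: (float(row[duration_col]), int(row[event_col]))
            for _, row in df.iterrows()
        }
    return {row["slide"]: int(row[label]) for _, row in df.iterrows()}
```

Because labels are read fresh from the DataFrame rather than baked into the cached graphs, the same processed features can serve several classification tasks.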
- _process_single_graph(slide_name: str) Optional[Data][source]¶
Process a single slide into a graph Data object without labels. Labels are handled separately to enable caching across different classification tasks.
- Parameters:
slide_name – Name of the slide
- Returns:
Processed Data object without labels or None if processing fails
- _merge_graph_with_features(features: Tensor, positions: Tensor) Data[source]¶
Create graph with spatial connectivity based on patch positions.
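The documentation does not state which connectivity rule is used (k-nearest-neighbor and radius graphs are both common over patch positions). As one possible sketch, a radius graph connects patches whose centers lie within a threshold distance; this pure-Python version produces edges in the bidirectional form PyG's `edge_index` expects for undirected graphs:

```python
from itertools import combinations
from typing import List, Tuple

def radius_edges(
    positions: List[Tuple[float, float]], radius: float
) -> List[Tuple[int, int]]:
    """Connect patches whose centers are within `radius` of each other.

    Emits each undirected edge as two directed edges (i, j) and (j, i).
    """
    edges: List[Tuple[int, int]] = []
    for i, j in combinations(range(len(positions)), 2):
        (xi, yi), (xj, yj) = positions[i], positions[j]
        if (xi - xj) ** 2 + (yi - yj) ** 2 <= radius ** 2:
            edges += [(i, j), (j, i)]
    return edges
```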
- get(idx: int) Data[source]¶
Override get method to apply feature transforms and attach labels dynamically.
- Parameters:
idx – Index of the sample to retrieve (based on slides with labels)
- Returns:
Data object with transforms and labels applied
- create_subset(indices: List[int]) SubsetPatchGNNMILDataset[source]¶
Create a subset of the dataset using the specified indices.
This is useful for creating train/val/test splits when using split="all". Note: This creates a lightweight wrapper that references the original data.
- Parameters:
indices – List of indices to include in the subset
- Returns:
New SubsetPatchGNNMILDataset instance containing only the specified samples
- Raises:
ValueError – If any index is out of range
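The "lightweight wrapper that references the original data" pattern can be sketched generically: the subset stores a reference to the parent plus an index list, validating indices up front. This toy class mirrors the documented behavior of SubsetPatchGNNMILDataset but is not its actual implementation:

```python
from typing import List, Sequence

class Subset:
    """A lightweight view over a parent dataset (no data is copied)."""

    def __init__(self, parent: Sequence, indices: List[int]):
        out_of_range = [i for i in indices if not 0 <= i < len(parent)]
        if out_of_range:
            raise ValueError(f"indices out of range: {out_of_range}")
        self.parent = parent          # reference to the original data
        self.indices = list(indices)  # defensive copy of the index list

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int):
        # Translate subset-local index to parent index.
        return self.parent[self.indices[i]]
```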
- create_train_val_datasets(train_indices: List[int], val_indices: List[int], transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None) Tuple[SubsetPatchGNNMILDataset, SubsetPatchGNNMILDataset][source]¶
Create train and validation datasets with transforms fitted on training data only.
This function prevents data leakage by ensuring transforms are fitted only on the training set before being applied to both training and validation sets.
- Parameters:
train_indices – List of indices for training data
val_indices – List of indices for validation data
transforms – Optional Transform or TransformPipeline for features (fitted on training data)
label_transforms – Optional LabelTransform or LabelTransformPipeline for labels (e.g., TimeDiscretizerTransform for survival analysis)
- Returns:
Tuple of (train_dataset, val_dataset) with properly fitted transforms
- Raises:
ValueError – If indices are invalid or transforms cannot be fitted
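The leakage-prevention idea can be illustrated with a toy transform: statistics are fitted on the training subset only, then the fitted transform is applied to both splits. The `Standardizer` class and helper below are illustrative, not cellmil's Transform API:

```python
from typing import List, Tuple

class Standardizer:
    """Toy feature transform: center values around the training mean."""

    def fit(self, values: List[float]) -> "Standardizer":
        self.mean = sum(values) / len(values)
        return self

    def __call__(self, x: float) -> float:
        return x - self.mean

def fit_transform_split(
    values: List[float], train_idx: List[int], val_idx: List[int]
) -> Tuple[List[float], List[float]]:
    """Fit the transform on the training subset only, then apply to both splits.

    Fitting on train+val would leak validation statistics into the
    transform, which is exactly what create_train_val_datasets avoids.
    """
    t = Standardizer().fit([values[i] for i in train_idx])
    train = [t(values[i]) for i in train_idx]
    val = [t(values[i]) for i in val_idx]
    return train, val
```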
- class cellmil.datamodels.datasets.patch_gnn_mil_dataset.SubsetPatchGNNMILDataset(parent_dataset: PatchGNNMILDataset, indices: List[int])[source]¶
Bases: object

A lightweight wrapper for creating subsets of PatchGNNMILDataset. This avoids the complexity of properly initializing InMemoryDataset subsets.
- __init__(parent_dataset: PatchGNNMILDataset, indices: List[int])[source]¶