cellmil.datamodels.datasets.patch_gnn_mil_dataset

Classes

PatchGNNMILDataset(root, folder, label, ...)

SubsetPatchGNNMILDataset(parent_dataset, indices)

A lightweight wrapper for creating subsets of PatchGNNMILDataset.

class cellmil.datamodels.datasets.patch_gnn_mil_dataset.PatchGNNMILDataset(root: Union[str, Path], folder: Union[str, Path], label: Union[str, Tuple[str, str]], data: DataFrame, extractor: ExtractorType, split: Literal['train', 'val', 'test', 'all'] = 'all', transform: Optional[Callable[[Data], Data]] = None, pre_transform: Optional[Callable[[Data], Data]] = None, pre_filter: Optional[Callable[[Data], bool]] = None, force_reload: bool = False, transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None)[source]

Bases: InMemoryDataset

__init__(root: Union[str, Path], folder: Union[str, Path], label: Union[str, Tuple[str, str]], data: DataFrame, extractor: ExtractorType, split: Literal['train', 'val', 'test', 'all'] = 'all', transform: Optional[Callable[[Data], Data]] = None, pre_transform: Optional[Callable[[Data], Data]] = None, pre_filter: Optional[Callable[[Data], bool]] = None, force_reload: bool = False, transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None)[source]

property raw_file_names: List[str]

Required by InMemoryDataset. Since we’re working with pre-existing data, we return an empty list.

property processed_file_names: List[str]

Return the name of the processed file that will contain our data. The name is built from a stable hash of the dataset configuration.
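As an illustration of what a stable, configuration-based file name can look like, here is a minimal sketch. The helper name processed_file_name, the configuration keys, the "data_" prefix, and the choice of SHA-256 are all assumptions made for this example; the source does not state how the dataset actually builds its hash.

```python
import hashlib
import json


def processed_file_name(config: dict, prefix: str = "data") -> str:
    """Build a cache file name from a stable hash of the configuration."""
    # sort_keys makes the serialization independent of dict insertion order;
    # default=str covers values such as Path objects that json cannot encode.
    payload = json.dumps(config, sort_keys=True, default=str)
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]
    return f"{prefix}_{digest}.pt"


# Hypothetical configuration keys, for illustration only.
name_a = processed_file_name({"folder": "slides", "extractor": "resnet50"})
name_b = processed_file_name({"extractor": "resnet50", "folder": "slides"})
name_c = processed_file_name({"folder": "slides", "extractor": "uni"})
```

Here name_a and name_b are identical because key order does not affect the hash, while name_c differs because the configuration changed. A cryptographic digest is used rather than Python's built-in hash(), whose results for strings vary between interpreter runs.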

download() None[source]

Required by InMemoryDataset. Since we’re working with pre-existing data, we don’t need to download anything.

get_config() dict[str, Any][source]

Get dataset configuration as a dictionary.

get_num_labels() int[source]

Get the number of unique labels in the dataset.

Note: For survival prediction tasks, this returns 0 as there are no discrete classes.
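The behaviour described in the note can be sketched as follows. The function count_labels and the example slide names are hypothetical; only the label shapes (int for classification, (duration, event) for survival) come from this page.

```python
from typing import Dict, Tuple, Union

Label = Union[int, Tuple[float, int]]


def count_labels(labels: Dict[str, Label]) -> int:
    """Count distinct class labels; return 0 when labels are survival tuples."""
    values = list(labels.values())
    # Survival labels are (duration, event) tuples, so there are no classes.
    if any(isinstance(v, tuple) for v in values):
        return 0
    return len(set(values))


classification = {"slide_a": 0, "slide_b": 1, "slide_c": 1}
survival = {"slide_a": (12.5, 1), "slide_b": (30.0, 0)}

n_classes = count_labels(classification)
n_survival = count_labels(survival)
```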

process() None[source]

Required by InMemoryDataset. Process all the raw data and save it as a list of Data objects. This is the expensive step, and it runs only once: later instantiations load the processed file instead.

_get_slides() List[str][source]

Extract slide names from the cached graph data.

Returns:

List of slide names corresponding to all processed graphs

_get_labels() Dict[str, Union[int, Tuple[float, int]]][source]

Extract labels from the DataFrame based on current configuration. This allows labels to be extracted fresh without being cached with features.

Returns:

Dictionary mapping slide names to labels (either int for classification or (duration, event) for survival)

_process_single_graph(slide_name: str) Optional[Data][source]

Process a single slide into a graph Data object without labels. Labels are handled separately to enable caching across different classification tasks.

Parameters:

slide_name – Name of the slide

Returns:

Processed Data object without labels or None if processing fails

_get_positions_tensor(coordinates: list[str]) Tensor[source]

Get positions tensor for a slide.

_merge_graph_with_features(features: Tensor, positions: Tensor) Data[source]

Create graph with spatial connectivity based on patch positions.
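The connectivity rule is not specified on this page. As one plausible reading, the sketch below links every pair of patches whose positions lie within a fixed radius and returns the edges in the two-row [sources, targets] layout that PyTorch Geometric uses for edge_index. The function radius_edges, the radius value, and the example coordinates are assumptions; a k-nearest-neighbour rule would be an equally reasonable choice.

```python
from typing import List, Tuple


def radius_edges(
    positions: List[Tuple[float, float]], radius: float
) -> List[List[int]]:
    """Return a COO-style edge list [sources, targets] linking nearby patches."""
    sources: List[int] = []
    targets: List[int] = []
    r2 = radius * radius
    for i, (xi, yi) in enumerate(positions):
        for j, (xj, yj) in enumerate(positions):
            if i == j:
                continue  # no self-loops
            # Compare squared distances to avoid a square root.
            if (xi - xj) ** 2 + (yi - yj) ** 2 <= r2:
                sources.append(i)
                targets.append(j)
    return [sources, targets]


# Three patches on a 256-pixel grid plus one isolated, far-away patch.
positions = [(0.0, 0.0), (256.0, 0.0), (0.0, 256.0), (2048.0, 2048.0)]
edge_index = radius_edges(positions, radius=256.0)
```

Each undirected link appears in both directions, and the isolated patch gets no edges. The double loop is quadratic in the number of patches, so a spatial index or a library routine would be preferable for real slides.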

__len__() int[source]

Return the number of slides with labels for this task.

get(idx: int) Data[source]

Override get method to apply feature transforms and attach labels dynamically.

Parameters:

idx – Index of the sample to retrieve (based on slides with labels)

Returns:

Data object with transforms and labels applied
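The pattern of keeping cached graphs label-free and attaching labels at access time can be sketched in plain Python. The Graph and LabelledView classes are hypothetical stand-ins, not the library's types; they only illustrate why indices refer to slides that have a label and why the cache can be shared across tasks.

```python
from dataclasses import dataclass, replace
from typing import Callable, Dict, List, Optional


@dataclass(frozen=True)
class Graph:
    slide: str
    features: List[float]
    label: Optional[int] = None


class LabelledView:
    """Serve cached, label-free graphs with task labels attached on access."""

    def __init__(
        self,
        cached: List[Graph],
        labels: Dict[str, int],
        transform: Optional[Callable[[Graph], Graph]] = None,
    ) -> None:
        self._cached = {g.slide: g for g in cached}
        self._labels = labels
        # Only slides that have a label for the current task are indexable.
        self._slides = [g.slide for g in cached if g.slide in labels]
        self._transform = transform

    def __len__(self) -> int:
        return len(self._slides)

    def get(self, idx: int) -> Graph:
        slide = self._slides[idx]
        graph = self._cached[slide]
        if self._transform is not None:
            graph = self._transform(graph)
        # Attach the label to a copy so the cached graph stays label-free.
        return replace(graph, label=self._labels[slide])


def double(g: Graph) -> Graph:
    return replace(g, features=[2 * x for x in g.features])


cached = [Graph("a", [1.0, 2.0]), Graph("b", [3.0, 4.0]), Graph("c", [5.0, 6.0])]
labels = {"a": 0, "c": 1}  # slide "b" has no label for this task
view = LabelledView(cached, labels, transform=double)
sample = view.get(1)
```

Index 1 resolves to slide "c" because the unlabelled slide "b" is skipped, and switching to another task only requires a different labels dictionary.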

create_subset(indices: List[int]) SubsetPatchGNNMILDataset[source]

Create a subset of the dataset using the specified indices.

This is useful for creating train/val/test splits when using split='all'.

Note: This creates a lightweight wrapper that references the original data.

Parameters:

indices – List of indices to include in the subset

Returns:

SubsetPatchGNNMILDataset wrapping the parent dataset and exposing only the specified samples

Raises:

ValueError – If any index is out of range
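A lightweight subset wrapper of this kind can be sketched as below. SubsetView is a hypothetical stand-in for SubsetPatchGNNMILDataset, and rejecting negative indices is an assumption of this sketch; the page only states that out-of-range indices raise ValueError.

```python
from typing import Any, List, Sequence


class SubsetView:
    """Index into a parent dataset through a list of positions, without copying."""

    def __init__(self, parent: Sequence[Any], indices: List[int]) -> None:
        n = len(parent)
        # Assumption: negative indices count as out of range.
        bad = [i for i in indices if not 0 <= i < n]
        if bad:
            raise ValueError(f"indices out of range for dataset of size {n}: {bad}")
        self.parent = parent
        self.indices = list(indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Any:
        # Translate the subset position into a parent position.
        return self.parent[self.indices[idx]]

    def get(self, idx: int) -> Any:
        # Alias mirroring the get() naming used by PyTorch Geometric datasets.
        return self[idx]


parent = ["g0", "g1", "g2", "g3"]
subset = SubsetView(parent, [3, 1])
```

Because the wrapper stores only the index list and a reference to the parent, creating many subsets (for example, one per cross-validation fold) does not duplicate the graph data.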

create_train_val_datasets(train_indices: List[int], val_indices: List[int], transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None) Tuple[SubsetPatchGNNMILDataset, SubsetPatchGNNMILDataset][source]

Create train and validation datasets with transforms fitted on training data only.

This function prevents data leakage by ensuring transforms are fitted only on the training set before being applied to both training and validation sets.

Parameters:
  • train_indices – List of indices for training data

  • val_indices – List of indices for validation data

  • transforms – Optional Transform or TransformPipeline for features (fitted on training data)

  • label_transforms – Optional LabelTransform or LabelTransformPipeline for labels (e.g., TimeDiscretizerTransform for survival analysis)

Returns:

Tuple of (train_dataset, val_dataset) with properly fitted transforms

Raises:

ValueError – If indices are invalid or transforms cannot be fitted
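The leakage-prevention idea, fitting on the training indices only and then applying the fitted transform to both splits, can be illustrated with a toy standardizer. Standardizer and split_and_scale are hypothetical names for this sketch and do not correspond to the library's Transform classes.

```python
from statistics import mean, pstdev
from typing import List, Tuple


class Standardizer:
    """Toy feature transform: subtract the mean, divide by the std."""

    def fit(self, values: List[float]) -> "Standardizer":
        self.mean_ = mean(values)
        # Fall back to 1.0 so constant features do not divide by zero.
        self.std_ = pstdev(values) or 1.0
        return self

    def transform(self, values: List[float]) -> List[float]:
        return [(v - self.mean_) / self.std_ for v in values]


def split_and_scale(
    values: List[float], train_idx: List[int], val_idx: List[int]
) -> Tuple[List[float], List[float], Standardizer]:
    train = [values[i] for i in train_idx]
    val = [values[i] for i in val_idx]
    # Statistics come from the training split only, so nothing about the
    # validation samples influences the fitted parameters.
    scaler = Standardizer().fit(train)
    return scaler.transform(train), scaler.transform(val), scaler


values = [1.0, 3.0, 5.0, 100.0]
train_scaled, val_scaled, scaler = split_and_scale(values, [0, 1, 2], [3])
```

The outlier at index 3 belongs to the validation split, so it does not shift the fitted mean; fitting on all four values instead would let validation data leak into the preprocessing.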

get_normalization_params() None[source]

get_correlation_mask() None[source]

class cellmil.datamodels.datasets.patch_gnn_mil_dataset.SubsetPatchGNNMILDataset(parent_dataset: PatchGNNMILDataset, indices: List[int])[source]

Bases: object

A lightweight wrapper for creating subsets of PatchGNNMILDataset. This avoids the complexity of properly initializing InMemoryDataset subsets.

__init__(parent_dataset: PatchGNNMILDataset, indices: List[int])[source]

get(idx: int) Data[source]

Alias for __getitem__, provided to match the PyTorch Geometric dataset interface.

property num_classes: int

Get number of classes from parent dataset.

get_num_labels() int[source]

Get number of labels from parent dataset.