cellmil.datamodels.datasets.cell_gnn_mil_dataset
Classes
| CellGNNMILDataset | An optimized PyTorch Geometric InMemoryDataset for GNN+MIL (Graph Neural Network + Multiple Instance Learning) tasks. |
| SubsetCellGNNMILDataset | A lightweight wrapper for creating subsets of CellGNNMILDataset. |
- class cellmil.datamodels.datasets.cell_gnn_mil_dataset.CellGNNMILDataset(root: Union[str, Path], folder: Union[str, Path], label: Union[str, Tuple[str, str]], data: DataFrame, extractor: Union[ExtractorType, List[ExtractorType]], graph_creator: GraphCreatorType, segmentation_model: ModelType, split: Literal['train', 'val', 'test', 'all'] = 'all', cell_type: bool = False, cell_types_to_keep: Optional[List[str]] = None, return_cell_types: bool = True, centroid: bool = False, roi_folder: Optional[Path] = None, max_workers: int = 8, transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None, transform: Optional[Callable[[Data], Data]] = None, pre_transform: Optional[Callable[[Data], Data]] = None, pre_filter: Optional[Callable[[Data], bool]] = None, force_reload: bool = False)
Bases: InMemoryDataset
An optimized PyTorch Geometric InMemoryDataset for GNN+MIL (Graph Neural Network + Multiple Instance Learning) tasks.
This dataset follows the standard PyTorch Geometric pattern. All data is processed once in the process() method and stored in a processed file under root. Later instantiations with the same configuration load that file instead of rebuilding every graph, which makes loading significantly faster than the previous implementation.
- __init__(root: Union[str, Path], folder: Union[str, Path], label: Union[str, Tuple[str, str]], data: DataFrame, extractor: Union[ExtractorType, List[ExtractorType]], graph_creator: GraphCreatorType, segmentation_model: ModelType, split: Literal['train', 'val', 'test', 'all'] = 'all', cell_type: bool = False, cell_types_to_keep: Optional[List[str]] = None, return_cell_types: bool = True, centroid: bool = False, roi_folder: Optional[Path] = None, max_workers: int = 8, transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None, transform: Optional[Callable[[Data], Data]] = None, pre_transform: Optional[Callable[[Data], Data]] = None, pre_filter: Optional[Callable[[Data], bool]] = None, force_reload: bool = False)
Initialize the optimized GNN+MIL dataset.
- Parameters:
root – Root directory where the processed dataset will be cached
folder – Path to the original dataset folder containing the raw data
label – Label for the dataset. Either a single string (e.g., “dcr”) for classification tasks, or a tuple of two strings (e.g., (“duration”, “event”)) for survival prediction tasks
data – DataFrame containing per-slide metadata, including the label column(s)
extractor – Feature extractor type(s)
graph_creator – Graph creator type; required to locate the pre-computed graphs
segmentation_model – Segmentation model used for cell detection
split – Dataset split (train/val/test/all). Use “all” to include all data regardless of split
cell_type – Whether to include cell type features
cell_types_to_keep – Optional list of cell type names to keep (e.g., [“Neoplastic”, “Connective”]). Valid names: “Neoplastic”, “Inflammatory”, “Connective”, “Dead”, “Epithelial” (case-insensitive). If provided, only cells of these types will be included. Requires cell_type=True. Slides with no cells matching these types will be excluded from the dataset.
return_cell_types – Whether to store cell types in the graph data. If False, cell_types attribute is not added. If True and cell_type=True, graph data will have a cell_types attribute. Default is True for backward compatibility.
centroid – Whether to include centroid features
roi_folder – Optional path to the directory containing ROI CSV files organized by center folders. If provided, cells will be filtered to only include those within ROI boundaries. Requires ‘ID’, ‘I3LUNG_ID’, and ‘CENTER’ columns in the data DataFrame. Slides with no cells in ROI will be excluded from the dataset.
max_workers – Number of worker threads
transform – A function/transform that takes in a Data object and returns a transformed version
pre_transform – A function/transform applied before caching
pre_filter – A function that filters data objects
force_reload – Whether to force reprocessing even if processed files exist
transforms – Optional Transform or TransformPipeline to apply to node features each time a sample is retrieved (in get())
label_transforms – Optional LabelTransform or LabelTransformPipeline to apply to labels (e.g., TimeDiscretizerTransform for survival analysis)
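The cell_types_to_keep rules above can be sketched in plain Python. These rules are case-insensitive names, a dependency on cell_type=True, and dropping slides left with no matching cells. This is a minimal sketch, not the dataset's implementation. Raising ValueError on invalid input, storing each slide's cells as a list of type names, and the helper names normalize_cell_types and filter_slides are all assumptions.

```python
from typing import Dict, List, Optional, Sequence

# Valid names as listed in the cell_types_to_keep description.
VALID_CELL_TYPES = ["Neoplastic", "Inflammatory", "Connective", "Dead", "Epithelial"]


def normalize_cell_types(
    cell_types_to_keep: Optional[List[str]], cell_type: bool
) -> Optional[List[str]]:
    """Hypothetical helper: validate names case-insensitively, return canonical spelling."""
    if cell_types_to_keep is None:
        return None
    if not cell_type:
        # Assumption: the dependency on cell_type=True is enforced with ValueError.
        raise ValueError("cell_types_to_keep requires cell_type=True")
    canonical = {name.lower(): name for name in VALID_CELL_TYPES}
    result = []
    for name in cell_types_to_keep:
        key = name.lower()
        if key not in canonical:
            raise ValueError(f"Unknown cell type {name!r}; valid: {VALID_CELL_TYPES}")
        result.append(canonical[key])
    return result


def filter_slides(
    slide_cells: Dict[str, Sequence[str]], keep: List[str]
) -> Dict[str, List[int]]:
    """Hypothetical helper: per slide, indices of kept cells; slides with none are dropped."""
    keep_set = set(keep)
    kept: Dict[str, List[int]] = {}
    for slide, types in slide_cells.items():
        idx = [i for i, t in enumerate(types) if t in keep_set]
        if idx:  # slides with no matching cells are excluded entirely
            kept[slide] = idx
    return kept
```

For example, normalize_cell_types(["neoplastic"], True) returns ["Neoplastic"]. A slide whose cells are all "Dead" then disappears from the output of filter_slides.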
- property raw_file_names: List[str]
Required by InMemoryDataset. Since we’re working with pre-existing data, we return an empty list.
- property processed_file_names: List[str]
Return the names of the processed files that will contain the data. The names include a stable hash of the dataset configuration, so each configuration maps to its own cached file.
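A stable configuration hash can be sketched with hashlib over a canonical JSON dump. Python's built-in hash() is not suitable here, because string hashing is randomized per process (PYTHONHASHSEED). This is a sketch under assumptions: the config keys, the data_<digest>.pt naming scheme, and the 16-character truncation are all hypothetical. The example leaves label out, which fits the description of labels being attached at get() time, but the real key set is not documented here.

```python
import hashlib
import json
from typing import Any, Dict, List


def processed_file_names(config: Dict[str, Any]) -> List[str]:
    # sort_keys=True makes the dump independent of dict insertion order;
    # default=str lets non-JSON values such as Path or enum members be hashed.
    payload = json.dumps(config, sort_keys=True, default=str)
    digest = hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]
    return [f"data_{digest}.pt"]  # hypothetical naming scheme
```

Two configs that differ only in key order produce the same file name. Changing any value, such as the extractor, produces a new one.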
- download() → None
Required by InMemoryDataset. Since we’re working with pre-existing data, we don’t need to download anything.
- process() → None
Required by InMemoryDataset. Process all the raw data and save it as a list of Data objects. This is where the real work happens, and it runs only once unless the processed files are missing or force_reload=True.
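PyTorch Geometric's InMemoryDataset collates the processed Data objects and saves them under root/processed. The control flow it relies on is: process once, then reload unless the files are missing or a reload is forced. That flow can be sketched with pickle as a stand-in for the real storage format. The function name load_or_process, the per-slide process_one callback, and the pickle format are assumptions for illustration.

```python
import pickle
from pathlib import Path
from typing import Callable, List, Optional


def load_or_process(
    processed_path: Path,
    slides: List[str],
    process_one: Callable[[str], Optional[dict]],
    force_reload: bool = False,
) -> List[dict]:
    """Build all samples once; later calls reuse the cached file."""
    if processed_path.exists() and not force_reload:
        with processed_path.open("rb") as f:
            return pickle.load(f)
    samples = []
    for slide in slides:
        sample = process_one(slide)
        if sample is not None:  # slides that fail to process are skipped
            samples.append(sample)
    processed_path.parent.mkdir(parents=True, exist_ok=True)
    with processed_path.open("wb") as f:
        pickle.dump(samples, f)
    return samples
```

The expensive per-slide work runs only on the first call. A second call with the same path returns the cached list without invoking process_one.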
- _get_slides() → List[str]
Extract slide names from the cached graph data.
- Returns:
List of slide names corresponding to all processed graphs
- _get_labels() → Dict[str, Union[int, Tuple[float, int]]]
Extract labels from the DataFrame based on current configuration. This allows labels to be extracted fresh without being cached with features.
- Returns:
Dictionary mapping slide names to labels (either int for classification or (duration, event) for survival)
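The label dictionary described above can be sketched over plain row mappings standing in for DataFrame rows. A string label yields an int, and a (duration, event) tuple yields a (float, int) pair. The slide-name column slide_id is hypothetical. Skipping rows with missing or NaN labels is an assumption, motivated by get() indexing only slides with labels.

```python
import math
from typing import Dict, Iterable, Mapping, Tuple, Union

Label = Union[int, Tuple[float, int]]


def _missing(value: object) -> bool:
    return value is None or (isinstance(value, float) and math.isnan(value))


def get_labels(
    rows: Iterable[Mapping[str, object]],
    label: Union[str, Tuple[str, str]],
    slide_key: str = "slide_id",  # hypothetical column name
) -> Dict[str, Label]:
    labels: Dict[str, Label] = {}
    for row in rows:
        slide = str(row[slide_key])
        if isinstance(label, str):
            value = row.get(label)
            if _missing(value):
                continue  # unlabeled slides are left out
            labels[slide] = int(value)  # classification target
        else:
            duration_col, event_col = label
            duration, event = row.get(duration_col), row.get(event_col)
            if _missing(duration) or _missing(event):
                continue
            labels[slide] = (float(duration), int(event))  # survival target
    return labels
```

Because the labels are computed from the metadata on demand, switching from "dcr" to ("duration", "event") does not require reprocessing the cached graphs.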
- _apply_roi_filtering() → None
Apply ROI filtering to the loaded dataset in-memory only. This filters cells within ROI boundaries and keeps filtered graphs in memory. Also removes slides that have no cells within ROI from self.slides. Does NOT modify the cached files on disk.
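Filtering cells out of a graph means keeping the selected nodes, dropping every edge that touches a removed node, and renumbering the survivors. On tensors, torch_geometric.utils.subgraph(subset, edge_index, relabel_nodes=True) performs this step. The plain-Python sketch below shows the same idea. It assumes the ROI test has already produced a per-node boolean mask; how ROI CSVs are read and matched is not shown. A None result stands for a slide that would be dropped.

```python
from typing import Dict, List, Optional, Sequence, Tuple

Edge = Tuple[int, int]


def filter_graph_nodes(
    num_nodes: int, edges: Sequence[Edge], keep: Sequence[bool]
) -> Optional[Tuple[List[int], List[Edge]]]:
    """Keep masked nodes, drop edges touching removed nodes, renumber to 0..k-1."""
    if len(keep) != num_nodes:
        raise ValueError("keep mask must have one entry per node")
    old_to_new: Dict[int, int] = {}
    for old, flag in enumerate(keep):
        if flag:
            old_to_new[old] = len(old_to_new)
    if not old_to_new:
        return None  # no cell inside the ROI: caller removes the slide
    kept_nodes = list(old_to_new)  # original indices of surviving nodes
    new_edges = [
        (old_to_new[u], old_to_new[v])
        for u, v in edges
        if u in old_to_new and v in old_to_new
    ]
    return kept_nodes, new_edges
```

The kept_nodes list can also be used to slice node features, cell types, and centroids so they stay aligned with the renumbered edges.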
- _process_single_graph(slide_name: str) → Optional[Data]
Process a single graph by merging pre-computed graph structure with features. Labels are handled separately to enable caching across different classification tasks.
- Parameters:
slide_name – Name of the slide
- Returns:
Processed Data object without labels or None if processing fails
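Merging a pre-computed graph with features requires putting feature rows in the same order as the graph's nodes. The sketch below does this through a cell-ID-to-node-index map. It assumes that each graph node carries a cell ID and that features are keyed by that ID. It also assumes that the cell_indices: Dict[int, int] argument of the tensor helpers is such a map, which is not confirmed here. Returning None when a node has no features mirrors the "None if processing fails" contract. The function name is hypothetical.

```python
from typing import Dict, List, Optional, Sequence


def align_features(
    node_cell_ids: Sequence[int], features_by_cell: Dict[int, List[float]]
) -> Optional[Dict[str, object]]:
    """Order feature rows to match graph node order; None if any node lacks features."""
    # Assumed meaning: cell ID -> position of that cell in the node list.
    cell_indices = {cell_id: i for i, cell_id in enumerate(node_cell_ids)}
    x = []
    for cell_id in node_cell_ids:
        row = features_by_cell.get(cell_id)
        if row is None:
            return None  # treated as a processing failure for this slide
        x.append(row)
    return {"x": x, "cell_indices": cell_indices}
```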
- _get_cell_types_tensor(slide_name: str, cell_indices: Dict[int, int]) → Tensor
Get cell types tensor for a slide.
- _get_centroids_tensor(slide_name: str, cell_indices: Dict[int, int]) → Tensor
Get centroids tensor for a slide.
- get(idx: int) → Data
Override get method to apply feature transforms and attach labels dynamically.
- Parameters:
idx – Index of the sample to retrieve (based on slides with labels)
- Returns:
Data object with transforms and labels applied
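The pattern described for get() keeps the cache free of labels and per-run transforms and applies both at retrieval time. It can be sketched with a dataclass standing in for torch_geometric.data.Data. In PyTorch Geometric, Dataset.__getitem__ calls get() and then applies the standard transform argument if one is set. The class and attribute names below are hypothetical. Deep-copying the cached sample is an assumption that keeps transforms from mutating the stored data.

```python
import copy
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, Tuple, Union

Label = Union[int, Tuple[float, int]]
Matrix = List[List[float]]


@dataclass
class Sample:
    """Hypothetical stand-in for a torch_geometric.data.Data object."""
    x: Matrix
    slide: str
    y: Optional[Label] = None


class LazyLabelDataset:
    def __init__(
        self,
        samples: List[Sample],
        labels: Dict[str, Label],
        transforms: Optional[Callable[[Matrix], Matrix]] = None,
        label_transform: Optional[Callable[[Label], Label]] = None,
    ) -> None:
        # Only slides that have a label are indexable.
        self.samples = [s for s in samples if s.slide in labels]
        self.labels = labels
        self.transforms = transforms
        self.label_transform = label_transform

    def __len__(self) -> int:
        return len(self.samples)

    def get(self, idx: int) -> Sample:
        # Copy so transforms never modify the cached sample.
        sample = copy.deepcopy(self.samples[idx])
        if self.transforms is not None:
            sample.x = self.transforms(sample.x)
        label = self.labels[sample.slide]
        if self.label_transform is not None:
            label = self.label_transform(label)
        sample.y = label  # label attached only at retrieval time
        return sample
```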
- create_subset(indices: List[int]) → SubsetCellGNNMILDataset
Create a subset of the dataset using the specified indices.
This is useful for creating train/val/test splits when using split=“all”. Note: this creates a lightweight wrapper that references the original data.
- Parameters:
indices – List of indices to include in the subset
- Returns:
New SubsetCellGNNMILDataset instance containing only the specified samples
- Raises:
ValueError – If any index is out of range
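A subset that "references the original data" can be sketched as an index-remapping view, similar in spirit to torch.utils.data.Subset. Every requested index is validated up front so that an out-of-range index fails with ValueError, as documented. The class name is hypothetical. Treating negative indices as invalid is an assumption.

```python
from typing import Generic, List, Sequence, TypeVar

T = TypeVar("T")


class IndexSubset(Generic[T]):
    """Hypothetical view over a parent dataset; no samples are copied."""

    def __init__(self, parent: Sequence[T], indices: List[int]) -> None:
        n = len(parent)
        bad = [i for i in indices if not 0 <= i < n]
        if bad:
            raise ValueError(f"Indices out of range for dataset of size {n}: {bad}")
        self.parent = parent
        self.indices = list(indices)

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int) -> T:
        # Map the subset position back to the parent's index.
        return self.parent[self.indices[i]]
```

Because only the index list is stored, building many folds over one processed dataset costs almost no extra memory.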
- create_train_val_datasets(train_indices: List[int], val_indices: List[int], transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None) → Tuple[SubsetCellGNNMILDataset, SubsetCellGNNMILDataset]
Create train and validation datasets with transforms fitted only on training data.
This method prevents data leakage by ensuring that any fittable transforms (like normalization or feature selection) are fitted only on the training set and then applied to both train and validation sets.
- Parameters:
train_indices – List of indices for training set
val_indices – List of indices for validation set
transforms – Optional transforms to apply. If provided, any FittableTransform will be fitted on training data only
label_transforms – Optional LabelTransform or LabelTransformPipeline for labels (e.g., TimeDiscretizerTransform for survival analysis)
- Returns:
Tuple of (train_dataset, val_dataset) with properly fitted transforms
- Raises:
ValueError – If indices lists are empty or contain invalid indices
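The leakage-free procedure (fit on training data, then apply to both splits) can be illustrated with per-feature standardization. In this sketch, statistics are pooled over all nodes (cells) of the training slides. This is an illustrative stand-in for a FittableTransform, not the library's class. The pooling choice and the names StandardizeFeatures and make_train_val are assumptions.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


class StandardizeFeatures:
    """Hypothetical fittable transform: z-score each feature with training statistics."""

    def __init__(self) -> None:
        self.mean: List[float] = []
        self.std: List[float] = []

    def fit(self, graphs: Sequence[Matrix]) -> "StandardizeFeatures":
        rows = [row for x in graphs for row in x]  # pool nodes of all training slides
        if not rows:
            raise ValueError("cannot fit on an empty training set")
        n, d = len(rows), len(rows[0])
        self.mean = [sum(r[j] for r in rows) / n for j in range(d)]
        var = [sum((r[j] - self.mean[j]) ** 2 for r in rows) / n for j in range(d)]
        self.std = [math.sqrt(v) if v > 0 else 1.0 for v in var]  # guard constant features
        return self

    def __call__(self, x: Matrix) -> Matrix:
        return [[(v - m) / s for v, m, s in zip(row, self.mean, self.std)] for row in x]


def make_train_val(
    graphs: Sequence[Matrix], train_idx: List[int], val_idx: List[int]
) -> Tuple[List[Matrix], List[Matrix], StandardizeFeatures]:
    if not train_idx or not val_idx:
        raise ValueError("train and val indices must be non-empty")
    # Fit only on training slides; validation values never influence mean/std.
    t = StandardizeFeatures().fit([graphs[i] for i in train_idx])
    return [t(graphs[i]) for i in train_idx], [t(graphs[i]) for i in val_idx], t
```

An outlying validation slide, such as one whose feature value is 100 in the test data, shows up as a large z-score. It does not shift the mean used for training.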
- class cellmil.datamodels.datasets.cell_gnn_mil_dataset.SubsetCellGNNMILDataset(parent_dataset: CellGNNMILDataset, indices: List[int])
Bases: object
A lightweight wrapper for creating subsets of CellGNNMILDataset. This avoids the complexity of properly initializing InMemoryDataset subsets.
- __init__(parent_dataset: CellGNNMILDataset, indices: List[int])