import torch
import pandas as pd
import json
import hashlib
import random
from tqdm import tqdm
from pathlib import Path
from typing import List, Literal, Tuple, Union, Optional, Dict, Callable, Any
from torch_geometric.data import InMemoryDataset, Data # type: ignore
from cellmil.interfaces.FeatureExtractorConfig import (
ExtractorType,
FeatureExtractionType,
)
from cellmil.interfaces.GraphCreatorConfig import GraphCreatorType
from cellmil.interfaces.CellSegmenterConfig import ModelType, TYPE_NUCLEI_DICT
from cellmil.utils import logger
from ..transforms import (
Transform,
TransformPipeline,
FittableTransform,
LabelTransform,
LabelTransformPipeline,
)
from .utils import (
wsl_preprocess,
column_sanity_check,
filter_split,
preprocess_row,
get_cell_types,
cell_types_to_tensor,
get_centroids,
centroids_to_tensor,
get_cell_features,
load_precomputed_graph,
merge_graph_with_features,
cell_type_name_to_index,
load_roi_for_slide,
filter_cells_by_roi,
)
class CellGNNMILDataset(InMemoryDataset):
    """
    An optimized PyTorch Geometric InMemoryDataset for GNN+MIL (Graph Neural Network +
    Multiple Instance Learning) tasks.

    This dataset follows the official PyTorch Geometric pattern by processing all data
    once in the `process()` method and storing it efficiently. This provides significant
    speed improvements over the previous implementation.

    Graphs are cached to disk keyed by a hash of the feature/graph configuration
    (see ``processed_file_names``); labels are NOT cached — they are re-derived from
    the DataFrame on every construction and attached dynamically in ``get()``, so the
    same processed cache can serve multiple labeling tasks.
    """
    def __init__(
        self,
        root: Union[str, Path],
        folder: Union[str, Path],
        label: Union[str, Tuple[str, str]],
        data: pd.DataFrame,
        extractor: Union[ExtractorType, List[ExtractorType]],
        graph_creator: GraphCreatorType,
        segmentation_model: ModelType,
        split: Literal["train", "val", "test", "all"] = "all",
        cell_type: bool = False,
        cell_types_to_keep: Optional[List[str]] = None,
        return_cell_types: bool = True,
        centroid: bool = False,
        roi_folder: Optional[Path] = None,
        max_workers: int = 8,
        transforms: Optional[TransformPipeline | Transform] = None,
        label_transforms: Optional[
            Union[LabelTransform, LabelTransformPipeline]
        ] = None,
        transform: Optional[Callable[[Data], Data]] = None,
        pre_transform: Optional[Callable[[Data], Data]] = None,
        pre_filter: Optional[Callable[[Data], bool]] = None,
        force_reload: bool = False,
    ):
        """
        Initialize the optimized GNN+MIL dataset.

        Args:
            root: Root directory where the processed dataset will be cached
            folder: Path to the original dataset folder containing the raw data
            label: Label for the dataset. Either:
                - A single string (e.g., "dcr") for classification tasks
                - A tuple of two strings (e.g., ("duration", "event")) for survival prediction tasks
            data: DataFrame containing metadata
            extractor: Feature extractor type(s)
            graph_creator: Graph creator type - required for locating pre-computed graphs
            segmentation_model: Segmentation model used for cell detection
            split: Dataset split (train/val/test/all). Use "all" to include all data regardless of split
            cell_type: Whether to include cell type features
            cell_types_to_keep: Optional list of cell type names to keep (e.g., ["Neoplastic", "Connective"]).
                Valid names: "Neoplastic", "Inflammatory", "Connective", "Dead", "Epithelial" (case-insensitive).
                If provided, only cells of these types will be included. Requires cell_type=True.
                Slides with no cells matching these types will be excluded from the dataset.
            return_cell_types: Whether to store cell types in the graph data. If False, cell_types attribute is not added.
                If True and cell_type=True, graph data will have a cell_types attribute. Default is True for backward compatibility.
            centroid: Whether to include centroid features
            roi_folder: Optional path to the directory containing ROI CSV files organized by center folders.
                If provided, cells will be filtered to only include those within ROI boundaries.
                Requires 'ID', 'I3LUNG_ID', and 'CENTER' columns in the data DataFrame.
                Slides with no cells in ROI will be excluded from the dataset.
            max_workers: Number of worker threads
            transform: A function/transform that takes in a Data object and returns a transformed version
            pre_transform: A function/transform applied before caching
            pre_filter: A function that filters data objects
            force_reload: Whether to force reprocessing even if processed files exist
            transforms: Optional TransformPipeline to apply to node features at getitem time
            label_transforms: Optional LabelTransform or LabelTransformPipeline to apply to labels
                (e.g., TimeDiscretizerTransform for survival analysis)

        Raises:
            ValueError: On invalid configuration (embedding extractor, missing ROI
                folder/columns, cell_type/segmentation-model mismatch, or
                cell_types_to_keep without cell_type=True).
        """
        # Store parameters before calling super().__init__() — the parent
        # constructor may invoke process(), which reads all of these attributes.
        self.folder = Path(folder)
        self.label = label
        self.raw_data = data
        self.extractor = extractor
        self.graph_creator = graph_creator
        self.segmentation_model = segmentation_model
        self.split = split
        self.cell_type = cell_type
        self.return_cell_types = return_cell_types
        # Convert human-readable cell type names to 1-based TYPE_NUCLEI_DICT indices
        if cell_types_to_keep is not None:
            self.cell_types_to_keep_indices = cell_type_name_to_index(
                cell_types_to_keep
            )
        else:
            self.cell_types_to_keep_indices = None
        self.centroid = centroid
        # ROI filtering setup
        self.roi_folder = Path(roi_folder) if roi_folder is not None else None
        # Validate ROI parameters up front so failures happen before any processing
        if self.roi_folder is not None:
            if not self.roi_folder.exists():
                raise ValueError(f"ROI folder does not exist: {self.roi_folder}")
            # Check if data DataFrame has required columns for ROI filtering
            required_roi_columns = ["ID", "I3LUNG_ID", "CENTER"]
            missing_columns = [
                col for col in required_roi_columns if col not in data.columns
            ]
            if missing_columns:
                raise ValueError(
                    f"ROI filtering requires the following columns in data DataFrame: {required_roi_columns}. "
                    f"Missing columns: {missing_columns}"
                )
            logger.info(f"ROI filtering enabled. ROI folder: {self.roi_folder}")
        self.max_workers = max_workers
        self.force_reload = force_reload
        self.transforms = transforms
        self.label_transforms = label_transforms
        # TODO: Make this configurable
        self.wsl = False
        # TODO: -------
        # Embedding-type extractors are rejected: this dataset works on per-cell
        # features merged into graphs, not slide-level embeddings.
        if (
            isinstance(self.extractor, ExtractorType)
            and self.extractor in FeatureExtractionType.Embedding
        ) or (
            isinstance(self.extractor, list)
            and any(ext in FeatureExtractionType.Embedding for ext in self.extractor)
        ):
            raise ValueError("Embedding extractor is not supported for this dataset.")
        # Validate parameters
        if self.cell_type and self.segmentation_model not in ["cellvit", "hovernet"]:
            raise ValueError(
                f"Cell types can only be used with 'cellvit' or 'hovernet' segmentation models. "
                f"Got '{self.segmentation_model}'"
            )
        # Validate that cell_types_to_keep requires cell_type=True
        if self.cell_types_to_keep_indices is not None and not self.cell_type:
            raise ValueError(
                "cell_types_to_keep requires cell_type=True. "
                "Please set cell_type=True to enable cell type filtering."
            )
        # Initialize data structures that will be populated in process()
        self.all_slides: List[str] = []  # All slides with features
        self.slides: List[str] = []  # Only slides with labels for this task
        self.labels: Dict[str, Union[int, Tuple[float, int]]] = {}
        self.cell_types: Dict[str, Dict[int, int]] = {}
        self.centroids: Dict[str, Dict[int, Tuple[float, float]]] = {}
        self.roi_filtered_graphs: Dict[
            str, Any
        ] = {}  # ROI-filtered graphs (in-memory only)
        # Call parent constructor — triggers download()/process() when the
        # processed cache is missing or force_reload is set.
        super().__init__(  # type: ignore
            root=str(root),
            transform=transform,
            pre_transform=pre_transform,
            pre_filter=pre_filter,
            force_reload=force_reload,
        )
        # Load the processed data
        self.load(self.processed_paths[0])
        # Extract slide information.
        # NOTE(review): all_slides is only non-empty here when process() just ran;
        # when loading from cache it is re-read from the cached Data objects.
        if not hasattr(self, "all_slides") or not self.all_slides:
            # Extract all slides from cached data
            self.all_slides = self._get_slides()
        # Extract labels fresh from DataFrame (not cached with features)
        self.labels = self._get_labels()
        # Filter slides to only those with labels for this task
        self.slides = [slide for slide in self.all_slides if slide in self.labels]
        # Filter out slides with no matching cell types if cell_types_to_keep is specified
        if self.cell_types_to_keep_indices is not None:
            logger.info("Filtering slides with no matching cell types...")
            slides_with_matching_cells: list[str] = []
            for slide_name in tqdm(self.slides, desc="Checking cell types"):
                # Get the graph data for this slide
                try:
                    actual_idx = self.all_slides.index(slide_name)
                    graph_data = super().get(actual_idx)
                    if (
                        not hasattr(graph_data, "cell_types")
                        or graph_data.cell_types is None
                    ):
                        logger.debug(
                            f"Skipping slide {slide_name}: no cell types available"
                        )
                        continue
                    cell_types_tensor = graph_data.cell_types
                    if cell_types_tensor.shape[0] == 0:
                        logger.debug(
                            f"Skipping slide {slide_name}: empty cell types tensor"
                        )
                        continue
                    # Get the cell type for each cell (argmax of one-hot encoding)
                    cell_type_tensor_indices = torch.argmax(cell_types_tensor, dim=1)
                    # Convert 1-based TYPE_NUCLEI_DICT indices to 0-based tensor indices
                    tensor_indices_to_keep = [
                        idx - 1 for idx in self.cell_types_to_keep_indices
                    ]
                    # Check if any cells match the requested types
                    has_matching_cells = any(
                        (cell_type_tensor_indices == tensor_idx).any().item()
                        for tensor_idx in tensor_indices_to_keep
                    )
                    if has_matching_cells:
                        slides_with_matching_cells.append(slide_name)
                    else:
                        logger.debug(
                            f"Skipping slide {slide_name}: no cells of requested types"
                        )
                except Exception as e:
                    logger.warning(
                        f"Error checking cell types for slide {slide_name}: {e}"
                    )
                    continue
            self.slides = slides_with_matching_cells
            logger.info(f"Kept {len(self.slides)} slides with matching cell types")
        # Apply ROI filtering if enabled (after loading/processing complete dataset)
        if self.roi_folder is not None:
            self._apply_roi_filtering()
        logger.info(
            f"Loaded optimized GNN+MIL Dataset with {len(self)} total samples, {len(self.slides)} with labels for current task"
        )
@property
def raw_file_names(self) -> List[str]:
"""
Required by InMemoryDataset.
Since we're working with pre-existing data, we return an empty list.
"""
return []
@property
def processed_file_names(self) -> List[str]:
"""
Return the name of the processed files that will contain our data.
Create a stable hash based on the configuration.
"""
# Ensure that the extractor (which is a list) is represented consistently
extractor_str = json.dumps(self.extractor, sort_keys=True)
config_dict: dict[str, Any] = {
"extractor": extractor_str, # Use stringified extractor to ensure stability
"graph_creator": self.graph_creator,
"segmentation_model": self.segmentation_model,
"split": self.split,
"cell_type": self.cell_type,
"centroid": self.centroid,
}
# Convert config_dict to a JSON string, which ensures deterministic representation
config_str = json.dumps(config_dict, sort_keys=True)
# Use hashlib to generate a stable hash from the string
config_hash = hashlib.md5(config_str.encode("utf-8")).hexdigest()[
:8
] # Shorten hash to 8 characters
logger.info(f"Dataset configuration:{config_dict}")
logger.info(f"Dataset configuration hash: {config_hash}")
return [f"data_{self.split}_{config_hash}.pt"]
[docs] def download(self) -> None:
"""
Required by InMemoryDataset.
Since we're working with pre-existing data, we don't need to download anything.
"""
pass
[docs] def process(self) -> None:
"""
Required by InMemoryDataset.
Process all the raw data and save it as a list of Data objects.
This is where the real work happens and it's only done once.
"""
logger.info("Processing GNN+MIL dataset...")
# Process the DataFrame
try:
if self.wsl:
processed_data = wsl_preprocess(self.raw_data)
else:
processed_data = self.raw_data.copy()
# Perform sanity check
column_sanity_check(processed_data, self.label)
# Filter by split type (unless split is "all")
if self.split != "all":
processed_data = filter_split(processed_data, self.split)
else:
logger.info(f"Using all data: {len(processed_data)} slides")
# Extract valid slides
all_slides: list[str] = []
for _, row in tqdm(
processed_data.iterrows(),
total=len(processed_data),
desc="Validating slides",
):
try:
slide_name, _ = preprocess_row(
row,
self.label,
self.folder,
self.extractor,
self.graph_creator,
self.segmentation_model,
)
if slide_name is not None:
all_slides.append(slide_name)
except Exception as e:
raise ValueError(f"Error validating slide: {e}")
self.all_slides = all_slides
logger.info(f"Found {len(self.all_slides)} valid slides")
if not self.all_slides:
raise ValueError("No valid slides found")
# Precompute cell types if enabled
if self.cell_type:
logger.info("Precomputing cell types for all slides...")
for slide_name in tqdm(self.all_slides, desc="Loading cell types"):
cell_types = get_cell_types(
self.folder, slide_name, self.segmentation_model
)
self.cell_types[slide_name] = cell_types or {}
# Precompute centroids if enabled
if self.centroid:
logger.info("Precomputing centroids for all slides...")
for slide_name in tqdm(self.all_slides, desc="Loading centroids"):
centroids = get_centroids(
self.folder, slide_name, self.segmentation_model
)
self.centroids[slide_name] = centroids or {}
# Process all graphs
data_list: list[Data] = []
for _, slide_name in enumerate(
tqdm(
self.all_slides,
total=len(self.all_slides),
desc="Processing graphs",
)
):
try:
graph_data = self._process_single_graph(slide_name)
if graph_data is not None:
data_list.append(graph_data)
except Exception as e:
logger.error(f"Failed to process slide {slide_name}: {e}")
continue
logger.info(f"Successfully processed {len(data_list)} graphs")
if not data_list:
raise ValueError("No graphs were successfully processed")
# Apply pre_filter if provided
if self.pre_filter is not None: # type: ignore
data_list = [data for data in data_list if self.pre_filter(data)] # type: ignore
logger.info(f"After filtering: {len(data_list)} graphs remain")
# Apply pre_transform if provided
if self.pre_transform is not None: # type: ignore
data_list = [self.pre_transform(data) for data in data_list] # type: ignore
# Save the processed data
self.save(data_list, self.processed_paths[0])
logger.info(f"Saved processed dataset to {self.processed_paths[0]}")
except Exception as e:
logger.error(f"Failed to process dataset: {e}")
raise
[docs] def _get_slides(self) -> List[str]:
"""
Extract slide names from the cached graph data.
Returns:
List of slide names corresponding to all processed graphs
"""
try:
all_slides: List[str] = []
for i in range(len(self)):
# Get the data object directly from cache
data = super().get(i)
if hasattr(data, "slide_name"):
all_slides.append(data.slide_name)
else:
raise ValueError(f"Graph at index {i} missing slide_name attribute")
logger.info(f"Extracted {len(all_slides)} slide names from cached data")
return all_slides
except Exception as e:
logger.error(f"Failed to extract slide names: {e}")
raise
[docs] def _get_labels(self) -> Dict[str, Union[int, Tuple[float, int]]]:
"""
Extract labels from the DataFrame based on current configuration.
This allows labels to be extracted fresh without being cached with features.
Returns:
Dictionary mapping slide names to labels (either int for classification or (duration, event) for survival)
"""
try:
# Process the DataFrame same way as in process()
if self.wsl:
processed_data = wsl_preprocess(self.raw_data)
else:
processed_data = self.raw_data.copy()
# Filter by split type (unless split is "all")
if self.split != "all":
processed_data = filter_split(processed_data, self.split)
# Extract labels for valid slides as a dictionary
labels: Dict[str, Union[int, Tuple[float, int]]] = {}
for _, row in processed_data.iterrows():
try:
slide_name, label_value = preprocess_row(
row,
self.label,
self.folder,
self.extractor,
self.graph_creator,
self.segmentation_model,
do_validate_features=False,
)
if slide_name in self.all_slides and label_value is not None:
labels[slide_name] = label_value
except Exception as e:
raise ValueError(f"Error extracting label: {e}")
return labels
except Exception as e:
logger.error(f"Failed to extract labels: {e}")
raise
[docs] def _apply_roi_filtering(self) -> None:
"""
Apply ROI filtering to the loaded dataset in-memory only.
This filters cells within ROI boundaries and keeps filtered graphs in memory.
Also removes slides that have no cells within ROI from self.slides.
Does NOT modify the cached files on disk.
"""
logger.info("Applying ROI filtering to dataset (in-memory only)...")
if self.roi_folder is None:
raise ValueError("roi_folder must be specified for ROI filtering")
if not self.centroid:
raise ValueError("ROI filtering requires centroid=True")
# Store filtered graphs in memory
self.roi_filtered_graphs: Dict[str, Any] = {}
slides_to_remove: list[str] = []
total_cells_before = 0
total_cells_after = 0
for slide_name in tqdm(self.slides, desc="Applying ROI filtering"):
# Get the graph data for this slide
try:
actual_idx = self.all_slides.index(slide_name)
graph_data = super().get(actual_idx)
# Clone to avoid modifying cached data
graph_data = graph_data.clone()
# Load ROI for this slide
roi_df = load_roi_for_slide(slide_name, self.roi_folder, self.raw_data)
if roi_df is None:
logger.debug(
f"No ROI found for slide {slide_name}, keeping all cells"
)
self.roi_filtered_graphs[slide_name] = graph_data
continue
# Get cell centroids (original pixel coordinates)
centroids_dict = get_centroids(
self.folder, slide_name, self.segmentation_model
)
if centroids_dict is None or len(centroids_dict) == 0:
logger.warning(
f"Could not load centroids for slide {slide_name}, keeping all cells"
)
self.roi_filtered_graphs[slide_name] = graph_data
continue
# Filter cells by ROI using actual pixel coordinates
cells_to_keep = filter_cells_by_roi(centroids_dict, roi_df)
total_cells_before += len(centroids_dict)
if len(cells_to_keep) == 0:
logger.debug(
f"No cells within ROI for slide {slide_name}, removing slide"
)
slides_to_remove.append(slide_name)
continue
total_cells_after += len(cells_to_keep)
# Check if graph has required attributes
if not hasattr(graph_data, "cell_ids") or graph_data.cell_ids is None:
logger.warning(
f"Graph for {slide_name} missing cell_ids, keeping all cells"
)
self.roi_filtered_graphs[slide_name] = graph_data
continue
# Convert cell_ids tensor to list and create mask based on cells_to_keep
cell_ids_list = graph_data.cell_ids.tolist()
# Create mask for nodes to keep (cells within ROI)
mask = torch.tensor(
[cell_id in cells_to_keep for cell_id in cell_ids_list],
dtype=torch.bool,
)
if not mask.any():
logger.debug(
f"No matching cells in graph for slide {slide_name}, removing slide"
)
slides_to_remove.append(slide_name)
continue
# Filter node features
graph_data.x = graph_data.x[mask]
# Filter position data
if hasattr(graph_data, "pos") and graph_data.pos is not None:
graph_data.pos = graph_data.pos[mask]
# Filter cell_ids
graph_data.cell_ids = graph_data.cell_ids[mask]
# Filter cell types if they exist
if (
hasattr(graph_data, "cell_types")
and graph_data.cell_types is not None
):
graph_data.cell_types = graph_data.cell_types[mask]
# Update edge indices to match filtered nodes
# Create mapping from old indices to new indices
num_nodes = mask.shape[0]
old_to_new = torch.full((num_nodes,), -1, dtype=torch.long)
old_to_new[mask] = torch.arange(mask.sum().item())
# Filter edges: keep only edges where both nodes are in the mask
if (
hasattr(graph_data, "edge_index")
and graph_data.edge_index is not None
):
edge_mask = (
mask[graph_data.edge_index[0]] & mask[graph_data.edge_index[1]]
)
graph_data.edge_index = graph_data.edge_index[:, edge_mask]
# Remap edge indices to new node indices
graph_data.edge_index = old_to_new[graph_data.edge_index]
# Filter edge attributes if they exist
if (
hasattr(graph_data, "edge_attr")
and graph_data.edge_attr is not None
):
graph_data.edge_attr = graph_data.edge_attr[edge_mask]
# Store filtered graph
self.roi_filtered_graphs[slide_name] = graph_data
logger.debug(
f"ROI filtering for {slide_name}: kept {len(cells_to_keep)}/{len(centroids_dict)} cells "
f"({len(cells_to_keep) / len(centroids_dict) * 100:.1f}%)"
)
except Exception as e:
logger.error(f"Error applying ROI filtering to slide {slide_name}: {e}")
slides_to_remove.append(slide_name)
continue
# Remove slides with no cells in ROI from slides list
for slide_name in slides_to_remove:
if slide_name in self.slides:
self.slides.remove(slide_name)
if slide_name in self.labels:
del self.labels[slide_name]
logger.info(
f"ROI filtering complete: kept {total_cells_after}/{total_cells_before} cells "
f"({total_cells_after / total_cells_before * 100:.1f}% if total_cells_before > 0 else 0) across {len(self.slides)} slides"
)
if slides_to_remove:
logger.info(f"Removed {len(slides_to_remove)} slides with no cells in ROI")
[docs] def _process_single_graph(self, slide_name: str) -> Optional[Data]:
"""
Process a single graph by merging pre-computed graph structure with features.
Labels are handled separately to enable caching across different classification tasks.
Args:
slide_name: Name of the slide
Returns:
Processed Data object without labels or None if processing fails
"""
try:
# Load features
features, cell_indices, _ = get_cell_features(
self.folder,
slide_name,
self.extractor,
self.graph_creator,
self.segmentation_model,
)
if cell_indices is None or features is None:
raise ValueError(
f"Missing features or cell_indices for slide {slide_name}"
)
# Get centroids
if len(cell_indices) == 0:
raise ValueError(
f"No cell indices for slide {slide_name}, using zero centroids"
)
centroids = torch.zeros(features.size(0), 2, dtype=torch.float32)
elif not self.centroid:
centroids = torch.zeros(features.size(0), 2, dtype=torch.float32)
else:
centroids = self._get_centroids_tensor(slide_name, cell_indices)
# Normalize centroids to [0, 1] per slide
min_vals = centroids.min(dim=0).values
max_vals = centroids.max(dim=0).values
denom = max_vals - min_vals
denom[denom == 0] = 1.0 # Prevent division by zero
centroids = (centroids - min_vals) / denom
# Store cell types as separate attribute if requested (not concatenated to features)
cell_types_tensor: Optional[torch.Tensor] = None
if self.cell_type:
if len(cell_indices) == 0:
raise ValueError(
f"Missing cell_indices for slide {slide_name}; adding zero cell types."
)
n_cell_types = len(TYPE_NUCLEI_DICT)
cell_types_tensor = torch.zeros(
features.size(0), n_cell_types, dtype=torch.float32
)
else:
cell_types_tensor = self._get_cell_types_tensor(
slide_name, cell_indices
)
# Load pre-computed graph
precomputed_graph = load_precomputed_graph(
self.folder, slide_name, self.graph_creator, self.segmentation_model
)
# Merge graph with features
graph_data = merge_graph_with_features(
precomputed_graph, features, cell_indices, centroids
)
# Store slide name for reference
graph_data.slide_name = slide_name
# Store cell types as a separate attribute if enabled and return_cell_types is True
if cell_types_tensor is not None and self.return_cell_types:
graph_data.cell_types = cell_types_tensor
return graph_data
except Exception as e:
logger.error(f"Error processing slide {slide_name}: {e}")
return None
[docs] def _get_cell_types_tensor(
self, slide_name: str, cell_indices: Dict[int, int]
) -> torch.Tensor:
"""Get cell types tensor for a slide."""
cell_types = self.cell_types.get(slide_name)
if cell_types is None or len(cell_types) == 0:
n_cells = len(cell_indices)
n_cell_types = len(TYPE_NUCLEI_DICT)
return torch.zeros(n_cells, n_cell_types, dtype=torch.float32)
return cell_types_to_tensor(cell_types, cell_indices)
[docs] def _get_centroids_tensor(
self, slide_name: str, cell_indices: Dict[int, int]
) -> torch.Tensor:
"""Get centroids tensor for a slide."""
centroids = self.centroids.get(slide_name)
if centroids is None or len(centroids) == 0:
n_cells = len(cell_indices)
return torch.zeros(n_cells, 2, dtype=torch.float32)
return centroids_to_tensor(centroids, cell_indices)
[docs] def __len__(self) -> int:
"""Return the number of slides with labels for this task."""
if hasattr(self, "slides") and hasattr(self, "labels") and self.labels:
return len(self.slides)
else:
return super().__len__()
[docs] def get(self, idx: int) -> Data:
"""
Override get method to apply feature transforms and attach labels dynamically.
Args:
idx: Index of the sample to retrieve (based on slides with labels)
Returns:
Data object with transforms and labels applied
"""
# Map from logical index (slides with labels) to actual cache index (all slides)
if hasattr(self, "slides") and idx < len(self.slides):
slide_name = self.slides[idx]
# Use ROI filtered graph if available
if (
hasattr(self, "roi_filtered_graphs")
and slide_name in self.roi_filtered_graphs
):
data = self.roi_filtered_graphs[slide_name].clone()
else:
# Find the actual index in the cached data
actual_idx = self.all_slides.index(slide_name)
# Get the original data from the parent class
data = super().get(actual_idx)
# Clone to avoid modifying the cached data
data = data.clone()
else:
actual_idx = idx
# Get the original data from the parent class
data = super().get(actual_idx)
# Clone to avoid modifying the cached data
data = data.clone()
# Filter by cell types if requested (only if not already filtered during process)
if (
self.cell_types_to_keep_indices is not None
and hasattr(data, "cell_types")
and data.cell_types is not None
):
# Get the cell type for each cell (argmax of one-hot encoding)
# Note: cell_types tensor uses 0-based indexing (TYPE_NUCLEI_DICT keys are 1-based)
cell_type_tensor_indices = torch.argmax(
data.cell_types, dim=1
) # (n_cells,) values 0-4
# Convert 1-based TYPE_NUCLEI_DICT indices to 0-based tensor indices
tensor_indices_to_keep = [
idx - 1 for idx in self.cell_types_to_keep_indices
]
# Create mask for cells to keep
mask = torch.zeros(cell_type_tensor_indices.shape[0], dtype=torch.bool)
for tensor_idx in tensor_indices_to_keep:
mask |= cell_type_tensor_indices == tensor_idx
# Apply mask to filter node features and cell_types
data.x = data.x[mask]
if self.return_cell_types:
data.cell_types = data.cell_types[mask]
# Update edge indices to match filtered nodes
# Create mapping from old indices to new indices
old_to_new = torch.full((mask.shape[0],), -1, dtype=torch.long)
old_to_new[mask] = torch.arange(mask.sum().item())
# Filter edges: keep only edges where both nodes are in the mask
if hasattr(data, "edge_index") and data.edge_index is not None:
edge_mask = mask[data.edge_index[0]] & mask[data.edge_index[1]]
data.edge_index = data.edge_index[:, edge_mask]
# Remap edge indices to new node indices
data.edge_index = old_to_new[data.edge_index]
# Filter edge attributes if they exist
if hasattr(data, "edge_attr") and data.edge_attr is not None:
data.edge_attr = data.edge_attr[edge_mask]
# Attach label dynamically (not cached with features)
if (
hasattr(self, "labels")
and hasattr(data, "slide_name")
and data.slide_name in self.labels
):
label = self.labels[data.slide_name]
# Apply label transforms if provided
if self.label_transforms is not None:
labels_dict = {data.slide_name: label}
transformed = self.label_transforms.transform_labels(labels_dict)
label = transformed[data.slide_name]
data.y = (
torch.tensor([label], dtype=torch.long)
if isinstance(label, int)
else label
)
# Note: Transform application is now handled by subsets to allow
# different transforms per fold without contamination
return data # type: ignore
[docs] def create_subset(self, indices: List[int]) -> "SubsetCellGNNMILDataset":
"""
Create a subset of the dataset using the specified indices.
This is useful for creating train/val/test splits when using split="all".
Note: This creates a lightweight wrapper that references the original data.
Args:
indices: List of indices to include in the subset
Returns:
New SubsetCellGNNMILDataset instance containing only the specified samples
Raises:
ValueError: If any index is out of range
"""
# Allow empty indices for validation-less training (e.g., final model on all data)
if not indices:
subset = SubsetCellGNNMILDataset(self, [])
logger.info("Created empty GNN subset (0 samples)")
return subset
max_idx = len(self) - 1
invalid_indices = [idx for idx in indices if idx < 0 or idx > max_idx]
if invalid_indices:
raise ValueError(
f"Invalid indices {invalid_indices}. Valid range is 0-{max_idx}"
)
# Create a simple subset wrapper
subset = SubsetCellGNNMILDataset(self, indices)
logger.info(
f"Created GNN subset with {len(indices)} samples from {len(self)} total samples"
)
return subset
[docs] def create_train_val_datasets(
self,
train_indices: List[int],
val_indices: List[int],
transforms: Optional[Union[TransformPipeline, Transform]] = None,
label_transforms: Optional[
Union[LabelTransform, LabelTransformPipeline]
] = None,
) -> Tuple["SubsetCellGNNMILDataset", "SubsetCellGNNMILDataset"]:
"""
Create train and validation datasets with transforms fitted only on training data.
This method prevents data leakage by ensuring that any fittable transforms
(like normalization or feature selection) are fitted only on the training set
and then applied to both train and validation sets.
Args:
train_indices: List of indices for training set
val_indices: List of indices for validation set
transforms: Optional transforms to apply. If provided, any FittableTransform
will be fitted on training data only
label_transforms: Optional LabelTransform or LabelTransformPipeline for labels (e.g., TimeDiscretizerTransform for survival analysis)
Returns:
Tuple of (train_dataset, val_dataset) with properly fitted transforms
Raises:
ValueError: If indices lists are empty or contain invalid indices
"""
if not train_indices:
raise ValueError("train_indices cannot be empty")
# Allow empty val_indices for training final model on all data
# Validate indices
max_idx = len(self) - 1
invalid_train = [idx for idx in train_indices if idx < 0 or idx > max_idx]
invalid_val = [idx for idx in val_indices if idx < 0 or idx > max_idx]
if invalid_train:
raise ValueError(
f"Invalid train indices {invalid_train}. Valid range is 0-{max_idx}"
)
if invalid_val:
raise ValueError(
f"Invalid val indices {invalid_val}. Valid range is 0-{max_idx}"
)
logger.info(
f"Creating train/val GNN datasets with {len(train_indices)} train and {len(val_indices)} val samples"
)
# If no transforms provided, create simple subsets
if transforms is None:
train_dataset = self.create_subset(train_indices)
val_dataset = self.create_subset(val_indices)
return train_dataset, val_dataset
# Always fit fittable transforms on training data
logger.info("Fitting transforms on training data to prevent data leakage...")
# Collect training features for fitting
all_train_features: List[torch.Tensor] = []
max_samples_per_slide = 100_000
sample_size = min(len(train_indices), 20)
indices = random.sample(train_indices, sample_size)
for idx in indices:
try:
# Get graph data and extract node features
graph_data = self.get(idx)
features = graph_data.x # Node features from graph
if features is None:
raise ValueError(f"No features found for sample {idx}")
# Subsample if needed to prevent memory issues
if features.shape[0] > max_samples_per_slide:
perm_indices = torch.randperm(features.shape[0])[
:max_samples_per_slide
]
features = features[perm_indices]
all_train_features.append(features)
except Exception as e:
raise ValueError(
f"Failed to load sample {idx} for transform fitting: {e}"
)
continue
if not all_train_features:
raise ValueError("No valid training samples found for fitting transforms")
# Concatenate all training features
combined_train_features = torch.cat(all_train_features, dim=0)
logger.info(
f"Fitting transforms on {combined_train_features.shape[0]} training instances with {combined_train_features.shape[1]} features"
)
# Fit the transforms if they are fittable
if isinstance(transforms, (TransformPipeline, FittableTransform)):
transforms.fit(combined_train_features)
else:
# If not fittable, just assign
pass
# Create datasets with fitted transforms
train_dataset = self.create_subset(train_indices)
val_dataset = self.create_subset(val_indices)
# Store transforms on individual subset datasets, not parent dataset
train_dataset.transforms = transforms
val_dataset.transforms = transforms
# Fit label transforms on training labels only
if label_transforms is not None:
from ..transforms import FittableLabelTransform
# Get training labels
train_labels = {
self.slides[train_indices[i]]: self.labels[
self.slides[train_indices[i]]
]
for i in range(len(train_indices))
}
# Fit label transforms on training data
if isinstance(label_transforms, FittableLabelTransform):
label_transforms.fit(train_labels)
logger.info(
f"Fitted label transform on {len(train_labels)} training labels"
)
elif isinstance(label_transforms, LabelTransformPipeline):
label_transforms.fit(train_labels)
logger.info(
f"Fitted label transform pipeline on {len(train_labels)} training labels"
)
# Apply to parent dataset (shared by both train and val subsets)
self.label_transforms = label_transforms
logger.info(
"Successfully created train/val GNN datasets with fitted transforms"
)
logger.info(f"Train dataset: {len(train_dataset)} samples")
logger.info(f"Val dataset: {len(val_dataset)} samples")
return train_dataset, val_dataset
[docs] def get_config(self) -> dict[str, Any]:
"""Get dataset configuration as a dictionary."""
config: dict[str, Any] = {
"dataset_type": self.__class__.__name__,
"label": str(self.label),
"extractor": str(self.extractor),
"graph_creator": str(self.graph_creator),
"segmentation_model": str(self.segmentation_model),
"split": self.split,
"cell_type": self.cell_type,
"return_cell_types": self.return_cell_types,
"centroid": self.centroid,
}
if self.cell_types_to_keep_indices is not None:
config["cell_types_to_keep_indices"] = self.cell_types_to_keep_indices
if self.roi_folder is not None:
config["roi_folder"] = str(self.roi_folder)
return config
[docs] def get_num_labels(self) -> int:
"""
Get the number of unique labels in the dataset.
Note: For survival prediction tasks, this returns 0 as there are no discrete classes.
"""
if hasattr(self, "labels") and self.labels:
# Check if we have survival data by looking at the first label
first_label = next(iter(self.labels.values()))
if isinstance(first_label, tuple):
# Survival data - no discrete classes
return 0
# Classification data
return len(set(self.labels.values()))
return self.num_classes
class SubsetCellGNNMILDataset:
    """
    A lightweight index-remapping view over a CellGNNMILDataset.

    Avoids the complexity of properly initializing InMemoryDataset subsets:
    items are fetched from the parent on demand, with optional subset-local
    feature transforms applied on top.
    """

    def __init__(self, parent_dataset: CellGNNMILDataset, indices: List[int]):
        self.parent_dataset = parent_dataset
        self.indices = indices
        # Subset-local transforms (fitted per fold); None means pass-through.
        # Note: cell_types_to_keep filtering is handled by the parent's get().
        self.transforms: Optional[Union[TransformPipeline, Transform]] = None

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Data:
        if not 0 <= idx < len(self.indices):
            raise IndexError(
                f"Index {idx} out of range for subset of length {len(self.indices)}"
            )
        # Remap the subset index onto the parent dataset.
        data = self.parent_dataset.get(self.indices[idx])
        # Apply subset-specific transforms if they exist.
        if self.transforms is not None and getattr(data, "x", None) is not None:
            # Clone first so the parent's cached graph is never mutated.
            data = data.clone()
            data.x = self.transforms.transform(data.x)  # type: ignore
        return data

    def get(self, idx: int) -> Data:
        """Alias for __getitem__ to match PyTorch Geometric interface."""
        return self.__getitem__(idx)

    @property
    def num_classes(self) -> int:
        """Get number of classes from parent dataset."""
        return self.parent_dataset.num_classes

    def get_num_labels(self) -> int:
        """Get number of labels from parent dataset."""
        return self.parent_dataset.get_num_labels()