"""Source code for cellmil.datamodels.datasets.utils."""

import torch
import ujson
import pandas as pd
import numpy as np
from typing import cast, Tuple, List, Optional, Dict, Union, Any
from pathlib import Path
from cellmil.utils import logger
import matplotlib.pyplot as plt
from cellmil.interfaces.FeatureExtractorConfig import (
    ExtractorType,
    FeatureExtractionType,
)
from cellmil.interfaces.CellSegmenterConfig import ModelType, TYPE_NUCLEI_DICT
from cellmil.interfaces.GraphCreatorConfig import GraphCreatorType
from torch_geometric.data import Data  # type: ignore
from shapely.geometry import Point, Polygon
from shapely.ops import unary_union


[docs]def column_sanity_check(data: pd.DataFrame | None, label: Union[str, Tuple[str, str]]) -> None: """Perform sanity checks on the input data.""" if data is None or data.empty: raise ValueError("Input data is empty or None.") # Handle both single label and survival data (duration, event) tuple if isinstance(label, tuple): required_columns = ["FULL_PATH", label[0], label[1]] else: required_columns = ["FULL_PATH", label] for col in required_columns: if col not in data.columns: raise ValueError(f"Missing required column: {col}")
def preprocess_row(
    row: pd.Series,
    label: Union[str, Tuple[str, str]],
    folder: Path,
    extractor: ExtractorType | List[ExtractorType],
    graph_creator: Optional[GraphCreatorType] = None,
    segmentation_model: Optional[ModelType] = None,
    do_validate_features: bool = True,
) -> Tuple[str, Union[int, Tuple[float, int]]] | Tuple[None, ...]:
    """
    Process a single slide row to extract slide name and validate features.

    Args:
        row: A pandas Series representing a row from the Excel file
        label: Either a single string (classification) or tuple of
            (duration, event) strings (survival)

    Returns:
        For classification: Tuple of (slide_name, label)
        For survival: Tuple of (slide_name, (duration, event))
        On error: Tuple of (None, None)
    """
    try:
        slide_name = extract_slide_name(Path(cast(str, row["FULL_PATH"])))
        valid = (
            validate_features(
                folder, slide_name, extractor, graph_creator, segmentation_model
            )
            if do_validate_features
            else True
        )
        # Guard clause: bail out early when the slide is unusable.
        if not (slide_name and valid):
            logger.warning(
                f"Skipping slide {slide_name}: slide has no valid features or slide name is invalid"
            )
            return None, None
        if isinstance(label, tuple):
            # Survival data: a (duration, event) pair of columns.
            duration_col, event_col = label
            return slide_name, (float(row[duration_col]), int(row[event_col]))
        # Regular classification label.
        return slide_name, row[label]
    except Exception as e:
        logger.warning(f"Error processing slide row: {e}")
        return None, None
def validate_features(
    folder: Path,
    slide_name: str,
    extractor: ExtractorType | List[ExtractorType],
    graph_creator: Optional[GraphCreatorType] = None,
    segmentation_model: Optional[ModelType] = None,
) -> bool:
    """
    Check if the feature file(s) exist and contain valid (non-empty) features.

    Args:
        folder: Root folder containing per-slide output directories.
        slide_name: Name of the slide to validate.
        extractor: One extractor, or a list whose feature files must all pass.
        graph_creator: Required for topological extractors (path resolution).
        segmentation_model: Required for morphological/topological extractors.

    Returns:
        True when every requested feature file exists, loads, and holds at
        least 200 instances — and, for multiple extractors, the per-extractor
        features can be aligned by cell id (or naively concatenated when no
        cell-index maps exist). False otherwise; failures are only logged,
        never raised.
    """

    def _check_single(
        extractor: ExtractorType,
    ) -> Tuple[bool, Optional[torch.Tensor], Dict[int, int]]:
        # Validate one extractor's file: (ok, features, cell_id -> row map).
        # NOTE: this parameter intentionally shadows the outer `extractor`.
        p = get_feature_path(
            folder, slide_name, extractor, graph_creator, segmentation_model
        )
        if not p.exists():
            logger.warning(f"Feature file does not exist for slide {slide_name}: {p}")
            return False, None, {}
        try:
            data = torch.load(p, map_location="cpu", weights_only=False)
            if "features" not in data:
                logger.warning(
                    f"No 'features' key in data for slide {slide_name} ({extractor})"
                )
                return False, None, {}
            ft = data["features"]
            if ft.numel() == 0 or ft.shape[0] == 0:
                logger.warning(
                    f"Empty features tensor for slide {slide_name} ({extractor}): shape {ft.shape}"
                )
                return False, None, {}
            # Minimum bag size of 200 instances; smaller bags are rejected.
            if ft.shape[0] < 200:
                logger.warning(
                    f"Insufficient features for slide {slide_name} ({extractor}): {ft.shape[0]} < 200"
                )
                return False, None, {}
            ci = cast(Dict[int, int], data.get("cell_indices", {}))
            return True, ft, ci
        except Exception as e:
            logger.warning(
                f"Error loading features for slide {slide_name} ({extractor}): {e}"
            )
            return False, None, {}

    if isinstance(extractor, list):
        results = [_check_single(ext) for ext in extractor]
        # Every extractor must pass individually before cross-checking.
        if not all(ok for ok, _, _ in results):
            logger.warning(f"Some extractors failed validation for slide {slide_name}")
            return False
        fts = [cast(torch.Tensor, ft) for _, ft, _ in results]
        cis = [ci for _, _, ci in results]
        have_maps = all(len(m) > 0 for m in cis)
        if have_maps:
            # Align by cell id: the usable set is the intersection of all maps.
            common = set(cis[0].keys())
            original_counts = [len(m) for m in cis]
            # Track which cells are missing from which extractors
            all_cells = cast(set[int], set().union(*[set(m.keys()) for m in cis]))  # type: ignore
            for _, m in enumerate(cis[1:], 1):
                common &= set(m.keys())
            if len(common) == 0:
                logger.warning(
                    f"No overlapping cell ids across extractors for slide {slide_name}. "
                    f"Extractor cell counts: {original_counts}"
                )
                return False
            # Log detailed information about missing cells per extractor
            total_unique_cells = len(all_cells)
            excluded_cells = total_unique_cells - len(common)
            if excluded_cells > 0:
                logger.warning(
                    f"Validation: {excluded_cells} cells will be excluded from slide {slide_name} "
                    f"due to missing features in some extractors. "
                    f"Will use {len(common)} common cells out of {total_unique_cells} total unique cells."
                )
                # Log which extractors are missing which cells
                for _, (ext, cell_map) in enumerate(zip(extractor, cis)):
                    missing_cells = all_cells - set(cell_map.keys())
                    if missing_cells:
                        logger.warning(
                            f" Extractor {ext} is missing {len(missing_cells)} cells: "
                            f"cell IDs {sorted(list(missing_cells))[:10]}{'...' if len(missing_cells) > 10 else ''}"
                        )
                    else:
                        logger.info(f" Extractor {ext} has all {len(cell_map)} cells")
            return True
        # Without mappings, require same number of instances to allow naive concat
        n0 = fts[0].size(0)
        if any(ft.size(0) != n0 for ft in fts):
            logger.warning(
                f"Mismatched instance counts without cell_indices for slide {slide_name}. "
                f"Counts: {[ft.size(0) for ft in fts]}"
            )
            return False
        return True
    else:
        ok, _, _ = _check_single(extractor)
        return ok
def get_feature_path(
    folder: Path,
    slide_name: str,
    extractor: ExtractorType,
    graph_creator: Optional[GraphCreatorType] = None,
    segmentation_model: Optional[ModelType] = None,
) -> Path:
    """
    Get the path to the feature file for the given slide.

    Embedding extractors need no extra components; morphological ones require
    the segmentation model, and topological ones additionally require the
    graph creator.

    Raises:
        ValueError: When a required component is missing or the extractor
            type is unknown.
    """
    # All feature files share the same base directory layout.
    base = folder / slide_name / "feature_extraction" / str(extractor)
    if extractor in FeatureExtractionType.Embedding:
        return base / "features.pt"
    if segmentation_model is None:
        raise ValueError("Segmentation model is not set")
    if extractor in FeatureExtractionType.Morphological:
        return base / str(segmentation_model) / "features.pt"
    if graph_creator is None:
        raise ValueError("Graph creator is not set")
    if extractor in FeatureExtractionType.Topological:
        return base / str(graph_creator) / str(segmentation_model) / "features.pt"
    raise ValueError(f"Unknown extractor type: {extractor}")
def filter_split(data: pd.DataFrame, split: str) -> pd.DataFrame:
    """Return only the rows whose SPLIT column equals *split* (logged)."""
    subset = data[data["SPLIT"] == split]
    logger.info(f"Using {split} split: {len(subset)} slides")
    return subset
def apply_permutation(features: torch.Tensor) -> torch.Tensor:
    """
    Randomly permute the order of instances (rows).

    Args:
        features: Input feature tensor of shape (n_instances, n_features)

    Returns:
        Feature tensor with rows permuted
    """
    n_rows = features.size(0)
    # Nothing to shuffle for empty or single-instance bags.
    if n_rows <= 1:
        return features
    return features[torch.randperm(n_rows)]
def subsample_and_pad(
    features: torch.Tensor,
    target_size: int,
) -> torch.Tensor:
    """
    Randomly subsample or pad the bag to a fixed target size by replicating rows.

    Args:
        features: Input feature tensor of shape (n_instances, n_features)
        target_size: Desired number of instances per bag

    Returns:
        Tensor of shape (target_size, n_features)

    Raises:
        ValueError: If the bag has zero instances.
    """
    count = features.size(0)
    if count == 0:
        raise ValueError("Features tensor is empty; cannot subsample/pad an empty bag.")
    if count > target_size:
        # Too many instances: keep a random subset.
        keep = torch.randperm(count)[:target_size]
        return features[keep]
    if count < target_size:
        # Too few instances: append rows drawn uniformly with replacement.
        extra = torch.randint(
            low=0, high=count, size=(target_size - count,), dtype=torch.long
        )
        return torch.cat([features, features[extra]], dim=0)
    return features
def wsl_preprocess(data: pd.DataFrame) -> pd.DataFrame:
    """Preprocess the data to ensure paths are correctly formatted.

    Rewrites Windows-style entries in FULL_PATH (backslashes, the ``D:``
    drive prefix) into WSL form (forward slashes, ``/mnt/d``). The input
    frame is not mutated.
    """

    def _to_wsl(path: str) -> str:
        return path.replace("\\", "/").replace("D:", "/mnt/d")

    converted = data.copy()
    converted["FULL_PATH"] = converted["FULL_PATH"].apply(_to_wsl)  # type: ignore
    return converted
def extract_slide_name(file_path: Path) -> str:
    """Extract slide name from full path.

    Returns the filename without its extension, or "" when *file_path*
    is falsy (e.g. None).
    """
    return file_path.stem if file_path else ""
def get_cell_detection_path(
    folder: Path, slide_name: str, segmentation_model: ModelType
) -> Path:
    """
    Get the path to the cell detection file for the given slide.
    """
    return folder.joinpath(
        slide_name, "cell_detection", str(segmentation_model), "cell_detection.json"
    )
def get_cell_types(
    folder: Path, slide_name: str, segmentation_model: ModelType
) -> Dict[int, int] | None:
    """Map cell_id -> cell type for a slide.

    Returns None when the cell-detection file is absent. Cells without an
    explicit "type" entry default to 0; entries with no cell_id are skipped.
    """
    detection_path = get_cell_detection_path(folder, slide_name, segmentation_model)
    if not detection_path.exists():
        return None
    with open(detection_path, "r") as handle:
        payload = ujson.load(handle)
    # Dict comprehension keeps this a single fast pass over the cell list.
    return {
        cell["cell_id"]: cell.get("type", 0)
        for cell in payload.get("cells", [])
        if cell.get("cell_id") is not None
    }
def get_centroids(
    folder: Path, slide_name: str, segmentation_model: ModelType
) -> Dict[int, Tuple[float, float]] | None:
    """
    Get centroids for cells from the segmentation data.

    Args:
        folder: Path to the dataset folder
        slide_name: Name of the slide
        segmentation_model: Segmentation model used

    Returns:
        Dictionary mapping cell_id to (x, y) centroid coordinates, or None
        if data not found
    """
    detection_path = get_cell_detection_path(folder, slide_name, segmentation_model)
    if not detection_path.exists():
        return None
    with open(detection_path, "r") as handle:
        payload = ujson.load(handle)
    centroids: Dict[int, Tuple[float, float]] = {}
    for cell in payload.get("cells", []):
        centroid = cell.get("centroid")
        # Skip malformed entries: missing id, missing centroid, or < 2 coords.
        if cell.get("cell_id") is None or centroid is None or len(centroid) < 2:
            continue
        centroids[cell["cell_id"]] = (float(centroid[0]), float(centroid[1]))
    return centroids
def centroids_to_tensor(
    centroids: Dict[int, Tuple[float, float]], cell_indices: Dict[int, int]
) -> torch.Tensor:
    """
    Convert centroids dictionary to a tensor.

    Rows without a matching centroid stay (0, 0); centroids for cell ids not
    present in *cell_indices* are ignored.

    Args:
        centroids: Dictionary mapping cell_id to (x, y) centroid coordinates
        cell_indices: Dictionary mapping cell_id to its index in the tensor

    Returns:
        A tensor containing the centroid coordinates [num_cells, 2]
    """
    out = torch.zeros(len(cell_indices), 2, dtype=torch.float32)
    for cell_id, row in cell_indices.items():
        xy = centroids.get(cell_id)
        if xy is not None:
            out[row, 0], out[row, 1] = xy
    return out
def cell_types_to_tensor(
    cell_types: Dict[int, int], cell_indices: Dict[int, int]
) -> torch.Tensor:
    """
    One-hot encode cell types into a tensor.

    NOTE(review): columns are addressed as ``cell_type - 1``, i.e. types are
    assumed to be 1-based; a type of 0 would wrap to the last column via
    negative indexing — confirm upstream never emits 0.

    Args:
        cell_types: Dictionary mapping cell_id to cell_type
        cell_indices: Dictionary mapping cell_id to its index in the tensor

    Returns:
        A [num_cells, num_types] long tensor with a 1 in the type column.
    """
    one_hot = torch.zeros(len(cell_indices), len(TYPE_NUCLEI_DICT), dtype=torch.long)
    for cell_id, cell_type in cell_types.items():
        row = cell_indices.get(cell_id)
        if row is not None:
            one_hot[row, cell_type - 1] = 1.0
    return one_hot
def get_cell_features(
    folder: Path,
    slide_name: str,
    extractor: Union[ExtractorType, List[ExtractorType]],
    graph_creator: GraphCreatorType | None = None,
    segmentation_model: ModelType | None = None,
) -> Tuple[torch.Tensor | None, Dict[int, int] | None, List[str] | None]:
    """
    Get the features for a specific slide using the specified extractor.

    For a list of extractors, per-extractor feature tensors are concatenated
    along the feature dimension. When every extractor ships a cell_indices
    map, rows are first aligned on the intersection of cell ids (sorted
    ascending); otherwise a naive concat is used, which requires identical
    instance counts.

    Args:
        folder: Path to the dataset folder
        slide_name: Name of the slide
        extractor: Feature extractor type or list of types to use for
            feature extraction.
        graph_creator: Optional graph creator type, needed for some extractors
        segmentation_model: Optional Segmentation model type, needed for some
            extractors

    Returns:
        Tuple of (features, cell_indices, feature_names). cell_indices maps
        cell_id -> row in the returned tensor ({} when no maps were stored);
        feature_names is the concatenation of each extractor's names.

    Raises:
        ValueError: If extractors share no cell ids, or instance counts
            mismatch when no cell_indices maps are available.
    """
    if isinstance(extractor, list):
        tensors: List[torch.Tensor] = []
        maps: List[Dict[int, int]] = []
        feature_names: List[str] = []
        for ext in extractor:
            data = torch.load(
                get_feature_path(
                    folder, slide_name, ext, graph_creator, segmentation_model
                ),
                map_location="cpu",
                weights_only=False,
            )
            tensors.append(data["features"])  # (N_i, F_i)
            maps.append(cast(Dict[int, int], data.get("cell_indices", {})))
            ext_feature_names = cast(list[str] | None, data.get("feature_names", None))
            if ext_feature_names:
                feature_names.extend(ext_feature_names)
        have_maps = all(len(m) > 0 for m in maps)
        if have_maps:
            # Align by common cell ids
            common = set(maps[0].keys())
            original_counts = [len(m) for m in maps]
            # Track which cells are missing from which extractors
            # (set().union over dicts unions their key sets).
            all_cells = cast(set[int], set().union(*maps))  # type: ignore
            for m in maps[1:]:
                common &= set(m.keys())
            ordered = sorted(common)
            if len(ordered) == 0:
                logger.warning(
                    f"No overlapping cells across extractors for slide {slide_name}. "
                    f"Extractor cell counts: {original_counts}"
                )
                raise ValueError(
                    f"No overlapping cells across extractors for slide {slide_name}."
                )
            # Log cells that will be excluded with detailed per-extractor information
            total_unique_cells = len(all_cells)
            excluded_cells = total_unique_cells - len(ordered)
            if excluded_cells > 0:
                logger.warning(
                    f"Excluding {excluded_cells} cells from slide {slide_name} due to missing features in some extractors. "
                    f"Using {len(ordered)} common cells out of {total_unique_cells} total unique cells."
                )
                # Log which extractors are missing which cells
                for _, (ext, cell_map) in enumerate(zip(extractor, maps)):
                    missing_cells = all_cells - set(cell_map.keys())
                    if missing_cells:
                        logger.warning(
                            f" Extractor {ext} is missing {len(missing_cells)} cells: "
                            f"cell IDs {sorted(list(missing_cells))[:10]}{'...' if len(missing_cells) > 10 else ''}"
                        )
                    else:
                        logger.info(f" Extractor {ext} has all {len(cell_map)} cells")
            # Gather each extractor's rows in the shared (sorted) cell order.
            aligned: List[torch.Tensor] = []
            for t, m in zip(tensors, maps):
                idxs = torch.tensor([m[cid] for cid in ordered], dtype=torch.long)
                aligned.append(t[idxs])
            features = torch.cat(aligned, dim=1)
            cell_indices: Dict[int, int] = {cid: i for i, cid in enumerate(ordered)}
        else:
            # Fallback: naive concat if instance counts match
            n0 = tensors[0].size(0)
            if any(t.size(0) != n0 for t in tensors):
                raise ValueError(
                    f"Cannot align features without cell_indices for slide {slide_name}: mismatched instance counts."
                )
            features = torch.cat(tensors, dim=1)
            cell_indices = {}
    else:
        features_data = torch.load(
            get_feature_path(
                folder, slide_name, extractor, graph_creator, segmentation_model
            ),
            map_location="cpu",
            weights_only=False,
        )
        features = features_data["features"]
        cell_indices = cast(Dict[int, int], features_data.get("cell_indices", {}))
        feature_names = cast(List[str], features_data.get("feature_names", []))
    return features, cell_indices, feature_names
def compute_normalization(features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute per-feature robust-scaling parameters (median, IQR).

    A signed log1p transform is applied first, so the returned statistics
    describe the transformed feature values.

    Args:
        features: Tensor of shape (n_instances, n_features).

    Returns:
        Tuple of (median, iqr), each of shape (n_features,). Near-constant
        features (IQR <= 1e-8) get an IQR of 1.0 to avoid division by zero.
    """
    logger.info(
        f"Computing robust scaling parameters for {features.shape[1]} features using {features.shape[0]} instances..."
    )
    # Signed log1p tames heavy tails while preserving sign.
    epsilon = 1e-8
    transformed = torch.sign(features) * torch.log1p(torch.abs(features) + epsilon)
    medians = torch.median(transformed, dim=0)[0]  # (n_features,)
    q1 = torch.quantile(transformed, 0.25, dim=0)
    q3 = torch.quantile(transformed, 0.75, dim=0)
    iqr = q3 - q1
    constant_mask = iqr <= 1e-8
    if constant_mask.sum() > 0:
        logger.info(
            f"Found {constant_mask.sum()} near-constant features with very small IQR"
        )
    # Unit IQR for constant features keeps later division safe.
    iqr[constant_mask] = 1.0
    return medians, iqr
def correlation_filter(
    features: torch.Tensor, correlation_threshold: float, plot: bool = True
):
    """
    Identify redundant features via pairwise Pearson correlation.

    For every pair of non-constant features whose absolute correlation exceeds
    ``correlation_threshold``, the second feature of the pair is marked for
    removal. Constant features (std <= 1e-8) are always removed.

    Args:
        features: Tensor of shape (n_instances, n_features). NOTE: this tensor
            is consumed (``del``-eted) to minimize peak memory; callers must
            not rely on it after the call.
        correlation_threshold: Absolute-correlation cutoff.
        plot: If True, display the correlation matrix as a heatmap.

    Returns:
        Tuple of (keep_mask, non_constant_mask), boolean tensors of shape
        (n_features,) indexed in the original feature space.

    Raises:
        ValueError: If every feature is constant.
    """
    # Store shape info before clearing memory
    total_features = features.shape[1]
    total_instances = features.shape[0]
    logger.info(
        f"Computing correlation matrix for {total_features} features using {total_instances} instances..."
    )
    feature_std = features.std(dim=0)
    non_constant_mask = feature_std > 1e-8
    del feature_std  # no longer needed
    if non_constant_mask.sum() == 0:
        raise ValueError("All features are constant. Skipping correlation filter.")
    logger.info(
        f"Found {non_constant_mask.sum()} non-constant features out of {total_features} total features"
    )
    # Only compute correlation for non-constant features; free the original
    # tensor immediately to keep peak memory low.
    valid_features = features[:, non_constant_mask]
    del features
    # Center the features in place, then build the covariance matrix.
    feature_means = valid_features.mean(dim=0)
    valid_features -= feature_means
    del feature_means
    n_samples = valid_features.shape[0]
    cov_matrix = torch.mm(valid_features.T, valid_features) / (n_samples - 1)
    del valid_features
    # Normalize covariance into correlation via the outer product of stds.
    std_devs = torch.sqrt(torch.diag(cov_matrix))
    corr_matrix = cov_matrix / torch.outer(std_devs, std_devs)
    del cov_matrix, std_devs
    if plot:
        try:
            corr_np = corr_matrix.detach().cpu().numpy()  # type: ignore
            _, ax = plt.subplots(figsize=(12, 10))  # type: ignore
            im = ax.imshow(corr_np, cmap="coolwarm", vmin=-1, vmax=1, aspect="auto")  # type: ignore
            cbar = plt.colorbar(im, ax=ax, shrink=0.8)  # type: ignore
            cbar.set_label("Correlation Coefficient", rotation=270, labelpad=20)  # type: ignore
            ax.set_title(  # type: ignore
                f"Feature Correlation Matrix\n({corr_np.shape[0]} non-constant features)",
                fontsize=14,
                pad=20,
            )
            ax.set_xlabel("Feature Index", fontsize=12)  # type: ignore
            ax.set_ylabel("Feature Index", fontsize=12)  # type: ignore
            ax.grid(True, alpha=0.3)  # type: ignore
            plt.tight_layout()
            plt.show()  # type: ignore
            # BUGFIX: the previous log claimed the plot was saved to
            # 'correlation_matrix.png', but plt.show() only displays it.
            logger.info("Correlation matrix plot displayed")
        except Exception as e:
            logger.warning(f"Failed to create correlation matrix plot: {e}")
    # Highly correlated pairs from the upper triangle (diagonal excluded).
    upper_triangle = torch.triu(torch.abs(corr_matrix), diagonal=1)
    high_corr_pairs = torch.where(upper_triangle > correlation_threshold)
    num_high_corr_pairs = len(high_corr_pairs[0])
    del corr_matrix, upper_triangle
    features_to_remove: set[int] = set()
    row_indices = [int(x) for x in high_corr_pairs[0]]
    col_indices = [int(x) for x in high_corr_pairs[1]]
    del high_corr_pairs
    # Greedy policy: for each offending pair keep the first feature, drop the
    # second, unless one of them is already dropped.
    for i, j in zip(row_indices, col_indices):
        if i not in features_to_remove and j not in features_to_remove:
            features_to_remove.add(j)
    # Map removals (indexed in the non-constant subspace) back to the
    # original feature space.
    keep_mask = torch.ones(total_features, dtype=torch.bool)
    valid_indices = torch.where(non_constant_mask)[0]
    for idx_to_remove in features_to_remove:
        keep_mask[valid_indices[idx_to_remove]] = False
    # Constant features are removed as well.
    keep_mask = keep_mask & non_constant_mask
    del valid_indices
    features_removed = (~keep_mask).sum().item()
    features_kept = keep_mask.sum().item()
    logger.info(
        f"Correlation filter: removed {features_removed} features, kept {features_kept} features"
    )
    logger.info(
        f"Found {num_high_corr_pairs} highly correlated pairs (threshold: {correlation_threshold})"
    )
    return keep_mask, non_constant_mask
def weights_for_sampler(labels: list[int]) -> torch.Tensor:
    """
    Compute weights for WeightedRandomSampler to handle class imbalance.

    Each sample is weighted by the inverse frequency of its class, then the
    weights are rescaled so their sum equals the number of samples.

    Returns:
        torch.Tensor: Weights for each sample in the dataset, with shape
        (len(labels),). Usable directly with
        torch.utils.data.WeightedRandomSampler.
    """
    if not labels:
        logger.warning("No labels found in dataset. Returning uniform weights.")
        return torch.ones(len(labels), dtype=torch.float32)

    labels_tensor = torch.tensor(labels, dtype=torch.long)
    unique_labels, counts = cast(
        tuple[torch.Tensor, torch.Tensor],
        torch.unique(labels_tensor, return_counts=True),  # type: ignore
    )
    # label -> number of samples in that class
    class_counts = {lbl.item(): cnt.item() for lbl, cnt in zip(unique_labels, counts)}

    # Inverse-frequency weight per sample, then rescale to sum to n.
    weights = torch.tensor(
        [1.0 / class_counts[lbl] for lbl in labels], dtype=torch.float32
    )
    weights = weights * len(weights) / weights.sum()

    logger.info(f"Computed sampling weights for {len(unique_labels)} classes:")
    for lbl, cnt in class_counts.items():
        weight_per_sample = 1.0 / cnt
        logger.info(
            f" Class {lbl}: {cnt} samples, weight per sample: {weight_per_sample:.4f}"
        )
    return weights
def load_precomputed_graph(
    folder: Path,
    slide_name: str,
    graph_creator: GraphCreatorType,
    segmentation_model: ModelType,
) -> Data:
    """
    Load pre-computed graph from disk.

    Args:
        folder: Base folder containing slide data
        slide_name: Name of the slide
        graph_creator: Graph creator type used (string or enum)
        segmentation_model: Segmentation model used (string or enum)

    Returns:
        Data object containing the loaded graph

    Raises:
        ValueError: If graph file doesn't exist or has invalid format
    """
    # BUGFIX/consistency: apply str() like get_feature_path and
    # get_cell_detection_path do, so enum-typed creators/models resolve to
    # the same directory names (a raw enum is not os.PathLike and would
    # break Path joining). str() is a no-op for plain strings.
    graph_path = (
        folder
        / slide_name
        / "graphs"
        / str(graph_creator)
        / str(segmentation_model)
        / "graph.pt"
    )
    if not graph_path.exists():
        raise ValueError(f"Pre-computed graph not found at {graph_path}")
    try:
        graph_dict = torch.load(graph_path, map_location="cpu", weights_only=False)
        if not isinstance(graph_dict, dict):
            raise ValueError(f"Expected graph dictionary, got {type(graph_dict)}")
        required_keys = ["node_features", "edge_indices", "edge_features"]
        missing_keys = [key for key in required_keys if key not in graph_dict]
        if missing_keys:
            raise ValueError(f"Graph missing required keys {missing_keys}")
        node_features = cast(torch.Tensor, graph_dict["node_features"])
        edge_indices = cast(torch.Tensor, graph_dict["edge_indices"])
        edge_features = cast(torch.Tensor, graph_dict["edge_features"])
        if edge_indices.dim() != 2 or edge_indices.shape[0] != 2:
            raise ValueError(
                f"edge_indices should have shape [2, num_edges], got {edge_indices.shape}"
            )
        if node_features.shape[1] < 1:
            raise ValueError("Graph node features seem empty")
        graph_data = Data(
            x=node_features,
            edge_index=edge_indices,
            edge_attr=edge_features,
            num_nodes=node_features.shape[0],
        )
        # Carry optional metadata through to the Data object.
        if "metadata" in graph_dict:
            graph_data.metadata = graph_dict["metadata"]
        return graph_data
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise ValueError(
            f"Failed to load pre-computed graph for slide {slide_name}: {e}"
        ) from e
def merge_graph_with_features(
    graph_data: Data,
    features: torch.Tensor,
    cell_indices: Dict[int, int],
    cell_coordinates: torch.Tensor,
) -> Data:
    """
    Merge pre-computed graph structure with extracted features, ensuring
    proper alignment.

    This function ensures features are correctly assigned to their
    corresponding graph nodes based on cell IDs, creating a proper subgraph
    with aligned features.

    Args:
        graph_data: Data object containing graph structure. The first column
            of ``graph_data.x`` is assumed to hold cell ids.
        features: Feature tensor
        cell_indices: Mapping from cell_id to feature tensor index
        cell_coordinates: Cell coordinates tensor [num_cells, 2]. Despite the
            earlier docs calling this optional, it is indexed unconditionally
            and must be provided.

    Returns:
        Data object with properly aligned features and graph structure, plus
        a ``cell_ids`` attribute listing the retained cell ids.

    Raises:
        ValueError: If the graph has no node features/edges, or no cell id
            is shared between the graph and the feature map.
    """
    # Extract graph components from Data object
    node_features = graph_data.x
    edge_indices = graph_data.edge_index
    edge_features = getattr(graph_data, "edge_attr", None)
    if node_features is None:
        raise ValueError("Graph node features are None")
    if edge_indices is None:
        raise ValueError("Graph edge indices are None")
    # Extract cell_ids from graph node features (assuming first column
    # contains cell IDs).
    graph_cell_ids = node_features[:, 0].long()
    # Create mapping from cell_id to graph node index
    cell_id_to_graph_idx = {
        cell_id.item(): idx for idx, cell_id in enumerate(graph_cell_ids)
    }
    # Find intersection of cell_ids between graph and features; keep parallel
    # lists of graph-node index and feature-row index per retained cell.
    common_cell_ids: List[int] = []
    graph_indices: List[int] = []
    feature_indices: List[int] = []
    for cell_id, feature_idx in cell_indices.items():
        if cell_id in cell_id_to_graph_idx:
            common_cell_ids.append(cell_id)
            graph_indices.append(cell_id_to_graph_idx[cell_id])
            feature_indices.append(feature_idx)
    if not common_cell_ids:
        raise ValueError("No common cell_ids found between graph and features")
    logger.info(f"Found {len(common_cell_ids)} common cells between graph and features")
    graph_indices_tensor = torch.tensor(graph_indices, dtype=torch.long)
    feature_indices_tensor = torch.tensor(feature_indices, dtype=torch.long)
    # Create subgraph by filtering edges: keep only edges whose both
    # endpoints are retained nodes.
    subgraph_edge_mask = torch.isin(edge_indices[0], graph_indices_tensor) & torch.isin(
        edge_indices[1], graph_indices_tensor
    )
    if not subgraph_edge_mask.any():
        logger.warning("No edges found in subgraph - creating isolated nodes")
        # Create empty edge_index for isolated nodes
        remapped_edges = torch.empty((2, 0), dtype=torch.long)
        subgraph_edge_attr = None
    else:
        # Remap edge indices to new node ordering (old graph index -> new
        # position in the retained-node list).
        old_to_new_idx = {
            old_idx.item(): new_idx
            for new_idx, old_idx in enumerate(graph_indices_tensor)
        }
        subgraph_edges = edge_indices[:, subgraph_edge_mask]
        remapped_edges = torch.zeros_like(subgraph_edges)
        for i in range(subgraph_edges.shape[1]):
            src_old = subgraph_edges[0, i].item()
            dst_old = subgraph_edges[1, i].item()
            remapped_edges[0, i] = old_to_new_idx[src_old]
            remapped_edges[1, i] = old_to_new_idx[dst_old]
        # Get edge attributes for subgraph
        subgraph_edge_attr = None
        if edge_features is not None:
            subgraph_edge_attr = edge_features[subgraph_edge_mask]
    # Select features for the common cells in the correct order
    selected_features = features[feature_indices_tensor]
    # Positions are taken from the same feature-row order.
    pos = cell_coordinates[feature_indices_tensor]
    # Create final merged Data object
    merged_data = Data(
        x=selected_features,
        edge_index=remapped_edges,
        edge_attr=subgraph_edge_attr,
        pos=pos,
        num_nodes=len(common_cell_ids),
    )
    # Store cell_ids for reference
    merged_data.cell_ids = torch.tensor(common_cell_ids, dtype=torch.long)
    return merged_data
def cell_type_name_to_index(cell_type_names: List[str]) -> List[int]:
    """
    Convert cell type names to their corresponding indices.

    Args:
        cell_type_names: List of cell type names (case-insensitive)

    Returns:
        List of cell type indices (1-based, as used in TYPE_NUCLEI_DICT)

    Raises:
        ValueError: If any cell type name is invalid
    """
    # Case-insensitive lookup: lowercase name -> index.
    lookup = {name.lower(): idx for idx, name in TYPE_NUCLEI_DICT.items()}

    indices: list[int] = []
    for raw_name in cell_type_names:
        key = raw_name.lower()
        if key not in lookup:
            valid_names = list(TYPE_NUCLEI_DICT.values())
            raise ValueError(
                f"Invalid cell type name: '{raw_name}'. "
                f"Valid names are: {valid_names} (case-insensitive)"
            )
        indices.append(lookup[key])
    return indices
def load_roi_for_slide(
    slide_name: str, roi_folder: Path, metadata: pd.DataFrame
) -> Optional[pd.DataFrame]:
    """
    Load ROI data for a specific slide.

    Args:
        slide_name: Name of the slide (DIG_PAT_XXXXXXXX format)
        roi_folder: Path to directory containing ROI CSV files
        metadata: DataFrame containing 'ID', 'I3LUNG_ID', and 'CENTER' columns

    Returns:
        DataFrame with ROI coordinates or None if not found (errors are
        logged, never raised).
    """
    try:
        matched = metadata[metadata["ID"] == slide_name]
        if matched.empty:
            raise ValueError(f"Slide {slide_name} not found in metadata")
        i3lung_id = cast(str, matched["I3LUNG_ID"].values[0])  # type: ignore
        center = cast(str, matched["CENTER"].values[0])  # type: ignore

        # Map center to folder name (SZMC data lives in two possible folders).
        center_to_folder: Dict[str, Union[List[str], str]] = {
            "GHD": "GHD_RoI_auto",
            "INT": "INT_RoI_auto",
            "MH": "MH_RoI_auto",
            "SZMC": ["SZMC_RoI_auto", "SZMC-unzipped_RoI_auto"],
            "UOC": "UOC_RoI_auto",
            "VHIO": "VHIO_RoI_auto",
        }
        folder = center_to_folder.get(center)
        if folder is None:
            raise ValueError(f"Unknown CENTER value '{center}' for slide {slide_name}")

        # First existing candidate wins.
        candidates = folder if isinstance(folder, list) else [folder]
        roi_path = next(
            (
                candidate
                for candidate in (
                    roi_folder / sub / f"{i3lung_id}.csv" for sub in candidates
                )
                if candidate.exists()
            ),
            None,
        )
        if roi_path is None:
            raise FileNotFoundError(f"ROI file not found for slide {slide_name}")

        roi_df = pd.read_csv(roi_path)  # type: ignore
        logger.debug(f"Loaded ROI for slide {slide_name}: {len(roi_df)} points")
        return roi_df
    except Exception as e:
        logger.error(f"Error loading ROI for slide {slide_name}: {e}")
        return None
def filter_cells_by_roi(
    centroids: Dict[int, Tuple[float, float]], roi_df: pd.DataFrame
) -> set[int]:
    """
    Filter cells to keep only those within ROI boundaries.

    Args:
        centroids: Dictionary mapping cell_id to (x, y) centroid coordinates
        roi_df: DataFrame with ROI coordinates (columns: roi_name, label,
            x_base, y_base)

    Returns:
        Set of cell IDs that are within the ROI boundaries
    """
    # Build one polygon per named ROI.
    polygons: list[Polygon] = []
    for roi_name in roi_df["roi_name"].unique():  # type: ignore
        points = cast(
            np.ndarray[Any, Any],
            roi_df[roi_df["roi_name"] == roi_name][["x_base", "y_base"]].values,  # type: ignore
        )
        # A valid polygon needs at least three vertices.
        if len(points) >= 3:
            polygons.append(Polygon(points))

    if not polygons:
        logger.warning("No valid ROI polygons found")
        return set()

    # Merge all ROI polygons into a single geometry for point-in-ROI tests.
    merged_roi = unary_union(polygons)
    kept = {
        cell_id
        for cell_id, (x, y) in centroids.items()
        if Point(x, y).within(merged_roi)
    }
    logger.debug(f"ROI filtering: kept {len(kept)}/{len(centroids)} cells")
    return kept