# Source code for cellmil.features.extractor.topological

import torch
import networkx as nx
import numpy as np
import cv2
from scipy.spatial import cKDTree, ConvexHull  # type: ignore
from typing import Any, cast, Mapping
from cellmil.interfaces.FeatureExtractorConfig import ExtractorType
from cellmil.utils import logger


class TopologicalExtractor:
    """Dispatch wrapper that selects a concrete topological feature extractor.

    Maps an ``ExtractorType`` to one of ``ConnectivityExtractor``,
    ``StructureExtractor`` or ``GeometricExtractor`` and delegates
    ``extract_features`` calls to the selected instance.
    """

    def __init__(self, extractor_name: ExtractorType):
        """Create the concrete extractor for ``extractor_name``.

        Args:
            extractor_name: Which extractor family to instantiate.

        Raises:
            ValueError: If ``extractor_name`` is not a known extractor type.
        """
        self.extractor_name = extractor_name
        if self.extractor_name == ExtractorType.connectivity:
            self.extractor = ConnectivityExtractor()
        elif self.extractor_name == ExtractorType.structure:
            self.extractor = StructureExtractor()
        elif self.extractor_name == ExtractorType.geometric:
            self.extractor = GeometricExtractor()
        else:
            raise ValueError(f"Unknown extractor type: {self.extractor_name}")

    def extract_features(
        self,
        cell_id: torch.Tensor,
        graph: dict[str, torch.Tensor],
        cells: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Delegate feature extraction to the configured extractor.

        Args:
            cell_id: Tensor containing the target cell ID.
            graph: Dictionary with 'edge_indices', 'edge_features',
                'node_features' tensors.
            cells: List of cell dictionaries (usage depends on the extractor).

        Returns:
            The feature dictionary produced by the underlying extractor.

        Raises:
            RuntimeError: Wrapping any exception raised by the underlying
                extractor; the original exception is chained as the cause so
                the full traceback is preserved.
        """
        try:
            return self.extractor.extract_features(cell_id, graph, cells)
        except Exception as e:
            # Chain with `from e` so callers see the root cause, not just
            # the wrapping message.
            raise RuntimeError(f"Error extracting features: {e}") from e
[docs]class ConnectivityExtractor: """Extract connectivity features from the graph. * Degree * Weighted degree (by distance) * K-core number * PageRank * Eigenvector centrality (Approx.) """
[docs] def __init__(self): """Initialize the connectivity extractor.""" # Cache for computed global metrics self._cached_k_core: dict[int, float] | None = None self._cached_pagerank: dict[int, float] | None = None self._cached_eigenvector: dict[int, float] | None = None # Cache for degree metrics self._cached_degree: dict[int, float] | None = None self._cached_weighted_degree: dict[int, float] | None = None
[docs] def _ensure_global_metrics_computed( self, edge_indices: torch.Tensor, edge_features: torch.Tensor, num_nodes: int ) -> tuple[ dict[int, float], dict[int, float], dict[int, float], dict[int, float], dict[int, float], ]: """Compute global metrics (k-core, pagerank, eigenvector, degree, weighted_degree) if not cached.""" if ( self._cached_k_core is None or self._cached_pagerank is None or self._cached_eigenvector is None or self._cached_degree is None or self._cached_weighted_degree is None ): # Build NetworkX graph once G = nx.Graph() G.add_nodes_from(range(num_nodes)) # type: ignore edges = cast(np.ndarray[Any, Any], edge_indices.t().cpu().numpy()) # type: ignore if edge_features.shape[0] > 0: # Weighted graph distances = cast( np.ndarray[Any, Any], edge_features[:, 0].cpu().numpy(), # type: ignore ) weights = 1.0 / (distances + 1e-8) weighted_edges = [ (int(edges[i, 0]), int(edges[i, 1]), float(weights[i])) for i in range(len(edges)) ] G.add_weighted_edges_from(weighted_edges) # type: ignore else: # Unweighted graph G.add_edges_from(edges) # type: ignore # Compute all global metrics at once try: self._cached_k_core = dict(cast(Mapping[int, float], nx.core_number(G))) # type: ignore except Exception: self._cached_k_core = {i: 0.0 for i in range(num_nodes)} try: if edge_features.shape[0] > 0: self._cached_pagerank = dict( cast( Mapping[int, float], nx.pagerank( # type: ignore G, alpha=0.85, max_iter=100, tol=1e-6, weight="weight" ), ) ) # type: ignore else: self._cached_pagerank = dict( cast( Mapping[int, float], nx.pagerank(G, alpha=0.85, max_iter=100, tol=1e-6), # type: ignore ) ) except Exception: default_pr = 1.0 / num_nodes if num_nodes > 0 else 0.0 self._cached_pagerank = {i: default_pr for i in range(num_nodes)} try: if edge_features.shape[0] > 0: self._cached_eigenvector = dict( cast( Mapping[int, float], nx.eigenvector_centrality( # type: ignore G, max_iter=100, tol=1e-6, weight="weight" ), ) ) else: self._cached_eigenvector = dict( cast( 
Mapping[int, float], nx.eigenvector_centrality(G, max_iter=100, tol=1e-6), # type: ignore ) ) except Exception: # Fallback to degree centrality if eigenvector fails try: self._cached_eigenvector = dict( cast(Mapping[int, float], nx.degree_centrality(G)) # type: ignore ) except Exception: self._cached_eigenvector = {i: 0.0 for i in range(num_nodes)} # Compute degree metrics efficiently self._cached_degree = {} self._cached_weighted_degree = {} # Use NetworkX degree method for all nodes at once (much faster) degree_dict = cast(dict[int, int], dict(G.degree())) # type: ignore for node in range(num_nodes): self._cached_degree[node] = float(degree_dict.get(node, 0)) # Compute weighted degree for all nodes at once if edge_features.shape[0] > 0: weighted_degree_dict = cast( dict[int, float], dict(G.degree(weight="weight")) # type: ignore ) # type: ignore for node in range(num_nodes): self._cached_weighted_degree[node] = float( weighted_degree_dict.get(node, 0.0) ) else: # If no edge features, weighted degree equals regular degree self._cached_weighted_degree = self._cached_degree.copy() return ( self._cached_k_core, self._cached_pagerank, self._cached_eigenvector, self._cached_degree, self._cached_weighted_degree, )
[docs] def clear_cache(self) -> None: """Clear the cached global metrics. Useful for testing or memory management.""" self._cached_k_core = None self._cached_pagerank = None self._cached_eigenvector = None self._cached_degree = None self._cached_weighted_degree = None
[docs] def extract_features( self, cell_id: torch.Tensor, graph: dict[str, torch.Tensor], cells: list[dict[str, Any]], ) -> dict[str, Any]: """Extract connectivity features for a specific cell from the graph. Args: cell_id: Tensor containing the target cell ID graph: Dictionary containing 'edge_indices', 'edge_features', 'node_features' cells: List of cell dictionaries (not used in connectivity extraction) Returns: Dictionary containing connectivity features """ try: edge_indices = graph["edge_indices"] # [2, num_edges] edge_features = graph[ "edge_features" ] # [num_edges, 3] - [distance, direction_x, direction_y] node_features = graph["node_features"] # [num_nodes, 1] - contains cell_ids # Extract the cell_id from the tensor # cell_id is the node feature tensor which contains the cell_id value in its first element target_cell_id = int( cell_id.item() if cell_id.numel() == 1 else cell_id[0].item() ) # Find the node index that corresponds to this cell_id # node_features[:, 0] contains the actual cell_ids node_index = self._find_index(target_cell_id, node_features) if node_index is None: # Cell ID not found in the graph return { "degree": 0.0, "weighted_degree": 0.0, "k_core": 0.0, "pagerank": 0.0, "eigenvector_centrality": 0.0, } # Get total number of nodes num_nodes = node_features.shape[0] # Extract features using the node index features: dict[str, Any] = {} # Get cached global metrics (computed once for all nodes) ( k_core_scores, pagerank_scores, eigenvector_scores, degree_scores, weighted_degree_scores, ) = self._ensure_global_metrics_computed( edge_indices, edge_features, num_nodes ) # 1. Degree (number of connections) - from cache features["degree"] = degree_scores.get(node_index, 0.0) # 2. 
Weighted degree (sum of inverse distances) - from cache features["weighted_degree"] = weighted_degree_scores.get(node_index, 0.0) features["k_core"] = k_core_scores.get(node_index, 0.0) features["pagerank"] = pagerank_scores.get( node_index, 1.0 / num_nodes if num_nodes > 0 else 0.0 ) features["eigenvector_centrality"] = eigenvector_scores.get(node_index, 0.0) return features except Exception: # Return zero features if computation fails return { "degree": 0.0, "weighted_degree": 0.0, "k_core": 0.0, "pagerank": 0.0, "eigenvector_centrality": 0.0, }
[docs] def _find_index(self, cell_id: int, node_features: torch.Tensor) -> int | None: """Find the node index that corresponds to the given cell_id. Args: cell_id: The cell ID to find node_features: Tensor with shape [num_nodes, 1] containing cell_ids in first column Returns: Node index if found, None otherwise """ # node_features[:, 0] contains the cell_ids cell_ids = node_features[:, 0] matches = (cell_ids == cell_id).nonzero(as_tuple=True)[0] if len(matches) > 0: return int(matches[0].item()) return None
[docs]class StructureExtractor: """Extract structural features from the graph. * Weighted clustering coefficient (by distance) * Local efficiency * Ego-network density """
[docs] def __init__(self): """Initialize the structure extractor.""" # Cache for computed global metrics self._cached_weighted_clustering: dict[int, float] | None = None
[docs] def _ensure_global_metrics_computed( self, edge_indices: torch.Tensor, edge_features: torch.Tensor, num_nodes: int ) -> dict[int, float]: """Compute global structural metrics if not cached.""" if self._cached_weighted_clustering is None: # Build NetworkX graph G = nx.Graph() G.add_nodes_from(range(num_nodes)) # type: ignore edges = cast(np.ndarray[Any, Any], edge_indices.t().cpu().numpy()) # type: ignore if edge_features.shape[0] > 0: # Weighted graph distances = cast( np.ndarray[Any, Any], edge_features[:, 0].cpu().numpy(), # type: ignore ) weights = 1.0 / (distances + 1e-8) weighted_edges = [ (int(edges[i, 0]), int(edges[i, 1]), float(weights[i])) for i in range(len(edges)) ] G.add_weighted_edges_from(weighted_edges) # type: ignore else: G.add_edges_from(edges) # type: ignore try: if edge_features.shape[0] > 0: self._cached_weighted_clustering = dict( cast(Mapping[int, float], nx.clustering(G, weight="weight")) # type: ignore ) else: self._cached_weighted_clustering = dict( cast(Mapping[int, float], nx.clustering(G)) # type: ignore ) except Exception: self._cached_weighted_clustering = {i: 0.0 for i in range(num_nodes)} return self._cached_weighted_clustering
[docs] def extract_features( self, cell_id: torch.Tensor, graph: dict[str, torch.Tensor], cells: list[dict[str, Any]], ) -> dict[str, Any]: """Extract structural features for a specific cell from the graph.""" try: edge_indices = graph["edge_indices"] edge_features = graph["edge_features"] node_features = graph["node_features"] # Extract the cell_id from the tensor target_cell_id = int( cell_id.item() if cell_id.numel() == 1 else cell_id[0].item() ) # Find the node index node_index = self._find_index(target_cell_id, node_features) if node_index is None: return { "weighted_clustering_coefficient": 0.0, "local_efficiency": 0.0, "ego_network_density": 0.0, } num_nodes = node_features.shape[0] features: dict[str, Any] = {} # Vectorized neighbor finding (compute once, use multiple times) neighbors = self._get_neighbors_vectorized(node_index, edge_indices) # Get cached global metrics weighted_clustering_scores = self._ensure_global_metrics_computed( edge_indices, edge_features, num_nodes ) # 1. Weighted clustering coefficient features["weighted_clustering_coefficient"] = ( weighted_clustering_scores.get(node_index, 0.0) ) # 2. Local efficiency (using pre-computed neighbors) features["local_efficiency"] = self._calculate_local_efficiency( node_index, edge_indices, edge_features, neighbors ) # 3. Ego-network density (using pre-computed neighbors) features["ego_network_density"] = self._calculate_ego_network_density( node_index, edge_indices, neighbors ) return features except Exception: return { "weighted_clustering_coefficient": 0.0, "local_efficiency": 0.0, "ego_network_density": 0.0, }
[docs] def _find_index(self, cell_id: int, node_features: torch.Tensor) -> int | None: """Find the node index that corresponds to the given cell_id.""" cell_ids = node_features[:, 0] matches = (cell_ids == cell_id).nonzero(as_tuple=True)[0] if len(matches) > 0: return int(matches[0].item()) return None
[docs] def _get_neighbors_vectorized( self, node_idx: int, edge_indices: torch.Tensor ) -> torch.Tensor: """Vectorized neighbor finding using torch operations.""" # Find edges where node_idx appears as source or target mask_src = edge_indices[0] == node_idx mask_tgt = edge_indices[1] == node_idx # Get neighbors from both directions neighbors_from_src = edge_indices[1][mask_src] # When node_idx is source neighbors_from_tgt = edge_indices[0][mask_tgt] # When node_idx is target # Combine and get unique neighbors all_neighbors = torch.cat([neighbors_from_src, neighbors_from_tgt]) # Get unique values manually to avoid type issues if len(all_neighbors) == 0: return torch.tensor([], dtype=torch.long) neighbors_array = cast(np.ndarray[Any, Any], all_neighbors.cpu().numpy()) # type: ignore unique_neighbors_list = list(set(neighbors_array.astype(int))) unique_neighbors = torch.tensor( unique_neighbors_list, dtype=torch.long, device=all_neighbors.device ) return unique_neighbors
[docs] def _calculate_local_efficiency( self, node_idx: int, edge_indices: torch.Tensor, edge_features: torch.Tensor, neighbors: torch.Tensor, ) -> float: """Calculate local efficiency using pre-computed neighbors.""" if len(neighbors) < 2: return 0.0 # Convert to set for fast lookup neighbors_array = cast(np.ndarray[Any, Any], neighbors.cpu().numpy()) # type: ignore neighbors_set = set(neighbors_array.astype(int)) # Build subgraph of neighbors G = nx.Graph() G.add_nodes_from(neighbors_set) # type: ignore # Vectorized edge filtering between neighbors edges = edge_indices.t() # [num_edges, 2] # Create masks for edges between neighbors src_in_neighbors = torch.isin(edges[:, 0], neighbors) tgt_in_neighbors = torch.isin(edges[:, 1], neighbors) neighbor_edge_mask = src_in_neighbors & tgt_in_neighbors neighbor_edges_indices = edges[neighbor_edge_mask] if edge_features.shape[0] > 0 and neighbor_edge_mask.sum() > 0: # Weighted graph - get corresponding distances neighbor_distances = edge_features[neighbor_edge_mask, 0] weights = 1.0 / (neighbor_distances + 1e-8) weighted_edges = [ (int(edge[0].item()), int(edge[1].item()), float(weight.item())) for edge, weight in zip(neighbor_edges_indices, weights) ] G.add_weighted_edges_from(weighted_edges) # type: ignore else: # Unweighted graph unweighted_edges = [ (int(edge[0].item()), int(edge[1].item())) for edge in neighbor_edges_indices ] G.add_edges_from(unweighted_edges) # type: ignore # Calculate efficiency try: return float(nx.local_efficiency(G)) # type: ignore except Exception as e: logger.error(f"Error in local_efficiency calculation: {e}", exc_info=True) return 0.0
[docs] def _calculate_ego_network_density( self, node_idx: int, edge_indices: torch.Tensor, neighbors: torch.Tensor ) -> float: """Calculate ego-network density using pre-computed neighbors.""" # Ego network = node + its neighbors ego_nodes = torch.cat( [torch.tensor([node_idx], device=neighbors.device), neighbors] ) ego_size = len(ego_nodes) if ego_size < 2: return 0.0 # Vectorized edge counting within ego network edges = edge_indices.t() # [num_edges, 2] # Create masks for edges within ego network src_in_ego = torch.isin(edges[:, 0], ego_nodes) tgt_in_ego = torch.isin(edges[:, 1], ego_nodes) ego_edge_mask = src_in_ego & tgt_in_ego # Count unique edges (undirected graph has each edge twice) edges_in_ego = ego_edge_mask.sum().item() // 2 # Maximum possible edges in ego network max_edges = ego_size * (ego_size - 1) // 2 return float(edges_in_ego / max_edges) if max_edges > 0 else 0.0
[docs] def clear_cache(self) -> None: """Clear the cached global metrics.""" self._cached_weighted_clustering = None
class GeometricExtractor:
    """Extract geometric features from the graph.

    * Distance to nearest neighbor
    * Distance to nearest neighbor of each type
    * Mean distance to neighbors
    * Edge length variance
    * Anisotropy → Dominant direction of nearest neighbors
    * Local density (number of nodes in a radius)
    * Spatial entropy of neighbors
    * Shape of local convex hull
    * Area/perimeter ratio of local neighborhood
    * Nucleus size relative to local density
    * Anisotropy of neighborhood
    * Relative orientation of neighbors
    """

    def __init__(self):
        """Initialize the geometric extractor."""
        self.radius_for_density = 200.0  # Default radius for local density calculation
        # Cache mapping from cell_id to cell dict to avoid rebuilding per call
        self._cell_id_to_cell_cache: dict[int, dict[str, Any]] | None = None
        self._cells_obj_id: int | None = None  # track identity of current cells list for caches
        # Spatial index caches for local density
        self._positions_array: np.ndarray[Any, Any] | None = None
        self._cell_ids_array: np.ndarray[Any, Any] | None = None
        self._cell_id_to_pos_index: dict[int, int] | None = None
        self._kdtree: cKDTree[np.ndarray[Any, Any]] | None = None
        # Cache for local density queries: key = (cells_obj_id, int(radius*1000), cell_id)
        self._local_density_cache: dict[tuple[int, int, int], float] = {}
        # Graph-scoped caches (reset when tensors change identity)
        self._neighbor_cache_key: tuple[int, int, int] | None = None
        self._neighbor_cache: dict[int, dict[str, Any]] = {}
        # Node index cache keyed by node_features identity
        self._node_index_cache_key: int | None = None
        self._node_index_cache: dict[int, int] = {}

    def _get_cell_mapping(
        self, cells: list[dict[str, Any]]
    ) -> dict[int, dict[str, Any]]:
        """Return cached mapping cell_id -> cell; rebuild only when cells list identity changes."""
        # NOTE(review): cache keying uses id(cells) — in-place mutation of the
        # same list would NOT invalidate this cache.
        cells_id = id(cells)
        if self._cell_id_to_cell_cache is None or self._cells_obj_id != cells_id:
            # Cells without a "cell_id" key are silently skipped.
            self._cell_id_to_cell_cache = {
                int(cell["cell_id"]): cell for cell in cells if "cell_id" in cell
            }
            self._cells_obj_id = cells_id
        return self._cell_id_to_cell_cache

    def _ensure_spatial_index(self, cells: list[dict[str, Any]]) -> None:
        """Build and cache KDTree and arrays from cells when the list identity changes."""
        if (
            self._kdtree is not None
            and self._positions_array is not None
            and self._cells_obj_id == id(cells)
        ):
            return
        positions: list[list[float]] = []
        cell_ids: list[int] = []
        # Collect (x, y) centroids; cells without a centroid are excluded
        # from the spatial index entirely.
        for cell in cells:
            if "centroid" in cell and cell["centroid"] is not None:
                cx, cy = cell["centroid"][0], cell["centroid"][1]
                positions.append([float(cx), float(cy)])
                cell_ids.append(int(cell.get("cell_id", -1)))
        if not positions:
            # Clear caches if no positions
            self._positions_array = None
            self._cell_ids_array = None
            self._cell_id_to_pos_index = None
            self._kdtree = None
            self._cells_obj_id = id(cells)
            self._local_density_cache.clear()
            return
        self._positions_array = np.asarray(positions, dtype=np.float64)
        self._cell_ids_array = np.asarray(cell_ids, dtype=np.int64)
        # Build mapping from cell_id to the last seen index (stable enough for our use)
        self._cell_id_to_pos_index = {
            int(cid): idx for idx, cid in enumerate(self._cell_ids_array.tolist())
        }
        # Build KDTree once
        self._kdtree = cKDTree(self._positions_array)  # type: ignore
        self._cells_obj_id = id(cells)
        self._local_density_cache.clear()

    def _graph_key(
        self,
        edge_indices: torch.Tensor,
        edge_features: torch.Tensor,
        node_features: torch.Tensor,
    ) -> tuple[int, int, int]:
        """Create a lightweight identity key for current graph tensors (no hashing)."""
        return (id(edge_indices), id(edge_features), id(node_features))

    def extract_features(
        self,
        cell_id: torch.Tensor,
        graph: dict[str, torch.Tensor],
        cells: list[dict[str, Any]],
    ) -> dict[str, Any]:
        """Extract geometric features for a specific cell from the graph.

        Args:
            cell_id: Tensor containing the target cell ID.
            graph: Dict with 'edge_indices' [2, E], 'edge_features'
                [E, 3] = [distance, direction_x, direction_y], and
                'node_features' [N, 1] holding cell_ids.
            cells: List of cell dicts; each is expected to carry 'cell_id'
                and 'centroid' (and optionally 'area'/'contour'/'size').

        Returns:
            Dict of geometric features; all-zero dict on any failure
            (missing cell, no neighbors, etc. — the error is logged).
        """
        try:
            edge_indices = graph["edge_indices"]
            edge_features = graph[
                "edge_features"
            ]  # [distance, direction_x, direction_y]
            node_features = graph["node_features"]

            # Extract the cell_id from the tensor
            target_cell_id = int(
                cell_id.item() if cell_id.numel() == 1 else cell_id[0].item()
            )

            # Find the node index
            node_index = self._find_index(target_cell_id, node_features)
            if node_index is None:
                raise ValueError(
                    f"Cell ID {target_cell_id} not found in node features."
                )

            # Get cell information for this node using cached mapping
            cell_id_to_cell = self._get_cell_mapping(cells)
            target_cell = cell_id_to_cell.get(target_cell_id)
            if target_cell is None:
                raise ValueError(f"Cell ID {target_cell_id} not found in cells list.")

            features: dict[str, Any] = {}

            # Get neighbors and their properties
            neighbor_data = self._get_neighbour_data(
                node_index, edge_indices, edge_features, cells, node_features
            )
            if not neighbor_data["neighbors"]:
                raise ValueError(f"Cell ID {target_cell_id} has no neighbors.")

            # 1. Distance to nearest neighbor (using centroid distances)
            features["distance_to_nearest_neighbor"] = (
                min(neighbor_data["distances"]) if neighbor_data["distances"] else 0.0
            )
            # 2. Mean distance to neighbors (using centroid distances)
            features["mean_distance_to_neighbors"] = (
                float(np.mean(neighbor_data["distances"]))  # type: ignore
                if neighbor_data["distances"]
                else 0.0
            )
            # 3. Edge length variance (using centroid distances)
            features["edge_length_variance"] = (
                float(np.var(neighbor_data["distances"]))
                if neighbor_data["distances"]
                else 0.0
            )
            # 4. Anisotropy - dominant direction of nearest neighbors
            features.update(self._calculate_anisotropy(neighbor_data))
            # 5. Local density (number of nodes in a radius)
            features["local_density"] = self._calculate_local_density(
                target_cell, cells, self.radius_for_density
            )
            # 6. Spatial entropy of neighbors
            features["spatial_entropy"] = self._calculate_spatial_entropy(neighbor_data)
            # 7. Convex hull features
            features.update(
                self._calculate_convex_hull_features(neighbor_data, target_cell)
            )
            # 8. Nucleus size relative to local density
            features["nucleus_size_relative_to_density"] = (
                self._calculate_relative_nucleus_size(
                    target_cell, features["local_density"]
                )
            )
            # 9. Relative orientation of neighbors
            features["mean_neighbor_orientation"] = self._calculate_mean_orientation(
                neighbor_data
            )
            return features
        except Exception:
            logger.error("Error extracting features", exc_info=True)
            return {
                "distance_to_nearest_neighbor": 0.0,
                "mean_distance_to_neighbors": 0.0,
                "edge_length_variance": 0.0,
                "anisotropy_ratio": 0.0,
                "dominant_direction_x": 0.0,
                "dominant_direction_y": 0.0,
                "local_density": 0.0,
                "spatial_entropy": 0.0,
                "convex_hull_area": 0.0,
                "convex_hull_perimeter": 0.0,
                "area_perimeter_ratio": 0.0,
                "nucleus_size_relative_to_density": 0.0,
                "mean_neighbor_orientation": 0.0,
            }

    def _find_index(self, cell_id: int, node_features: torch.Tensor) -> int | None:
        """Find the node index that corresponds to the given cell_id.

        Builds a cell_id -> first node index cache per node_features tensor
        identity; duplicates keep the first occurrence.
        """
        # Rebuild cache if the node_features tensor identity changed
        nf_id = id(node_features)
        if self._node_index_cache_key != nf_id:
            try:
                ids_tensor = node_features[:, 0]
                ids_list = cast(list[int], ids_tensor.detach().cpu().tolist())  # type: ignore[arg-type]
                self._node_index_cache = {}
                for idx, cid in enumerate(ids_list):
                    if cid not in self._node_index_cache:
                        self._node_index_cache[int(cid)] = int(idx)
                self._node_index_cache_key = nf_id
            except Exception:
                # Fallback: direct tensor search without caching.
                cell_ids = node_features[:, 0]
                matches = (cell_ids == cell_id).nonzero(as_tuple=True)[0]
                if len(matches) > 0:
                    return int(matches[0].item())
                return None
        return self._node_index_cache.get(int(cell_id))

    def _get_neighbour_data(
        self,
        node_idx: int,
        edge_indices: torch.Tensor,
        edge_features: torch.Tensor,
        cells: list[dict[str, Any]],
        node_features: torch.Tensor,
    ) -> dict[str, Any]:
        """Get comprehensive neighbor data using vectorized ops (undirected graphs) with simple caching.

        Returns a dict with keys: "neighbors" (node indices),
        "distances", "directions_x", "directions_y", "neighbor_cells".
        Results are cached per node index, scoped to the current graph
        tensors' identities.
        """
        # Ensure cache corresponds to current graph tensors
        gkey = self._graph_key(edge_indices, edge_features, node_features)
        if self._neighbor_cache_key != gkey:
            self._neighbor_cache_key = gkey
            self._neighbor_cache.clear()
        # Fast path: return cached result
        if node_idx in self._neighbor_cache:
            return self._neighbor_cache[node_idx]
        # Connections mask (node as source or target)
        mask_connected = (edge_indices[0] == node_idx) | (edge_indices[1] == node_idx)
        # Early return if no connections
        if not bool(mask_connected.any().item()):
            result: dict[str, Any] = {
                "neighbors": [],
                "distances": [],
                "directions_x": [],
                "directions_y": [],
                "neighbor_cells": [],
            }
            self._neighbor_cache[node_idx] = result
            return result
        # Indices of all connected edges (preserve order)
        connected_edges = mask_connected.nonzero(as_tuple=True)[0]
        # Source/target tensors for connected edges
        src_nodes = edge_indices[0, connected_edges]
        tgt_nodes = edge_indices[1, connected_edges]
        # Neighbor index for each connected edge (preserves original ordering)
        neighbor_idx_tensor = torch.where(src_nodes == node_idx, tgt_nodes, src_nodes)
        # Convert neighbors to Python list[int]
        neighbors: list[int] = [
            int(x) for x in cast(list[int], neighbor_idx_tensor.tolist())  # type: ignore
        ]
        # Distances and directions (if features exist)
        distances: list[float] = []
        directions_x: list[float] = []
        directions_y: list[float] = []
        if edge_features.shape[0] > 0 and len(connected_edges) > 0:
            # Distances for connected edges
            dist_tensor = edge_features[connected_edges, 0]
            distances = [float(x) for x in cast(list[float], dist_tensor.tolist())]  # type: ignore
            # Direction sign: +1 if node is source, -1 if node is target
            # cast bool to numeric on same device/dtype without creating new tensors
            sign = (src_nodes == node_idx).to(dtype=edge_features.dtype)
            sign = sign.mul_(2).sub_(
                1
            )  # 1 for True -> (2*1-1)=1, 0 for False -> (0*2-1)=-1
            # Apply sign to (dx, dy) for each connected edge
            dir_tensor = edge_features[connected_edges, 1:3] * sign.view(-1, 1)
            directions_x = [
                float(x) for x in cast(list[float], dir_tensor[:, 0].tolist())  # type: ignore
            ]
            directions_y = [
                float(y) for y in cast(list[float], dir_tensor[:, 1].tolist())  # type: ignore
            ]
        # Neighbor cell lookup (keep order; include only found cells)
        neighbor_cells: list[dict[str, Any]] = []
        cell_id_to_cell = self._get_cell_mapping(cells)
        if len(neighbors) > 0:
            neighbor_idx_cpu = neighbor_idx_tensor.cpu()
            valid_mask = (neighbor_idx_cpu >= 0) & (
                neighbor_idx_cpu < node_features.shape[0]
            )
            if bool(valid_mask.any().item()):
                valid_indices = neighbor_idx_cpu[valid_mask]
                neighbor_cell_ids = cast(
                    list[int], node_features[valid_indices, 0].cpu().tolist()  # type: ignore
                )
                # Preserve order; skip missing ids without branching per element
                neighbor_cells = [
                    cell_id_to_cell[cid]
                    for cid in neighbor_cell_ids
                    if cid in cell_id_to_cell
                ]
        result: dict[str, Any] = {
            "neighbors": neighbors,
            "distances": distances,
            "directions_x": directions_x,
            "directions_y": directions_y,
            "neighbor_cells": neighbor_cells,
        }
        self._neighbor_cache[node_idx] = result
        return result

    def clear_cache(self) -> None:
        """Clear all caches maintained by GeometricExtractor."""
        self._cell_id_to_cell_cache = None
        self._cells_obj_id = None
        self._positions_array = None
        self._cell_ids_array = None
        self._cell_id_to_pos_index = None
        self._kdtree = None
        self._neighbor_cache_key = None
        self._neighbor_cache.clear()
        self._node_index_cache_key = None
        self._node_index_cache.clear()

    def _calculate_anisotropy(self, neighbor_data: dict[str, Any]) -> dict[str, float]:
        """Calculate anisotropy and dominant direction.

        Uses PCA of the neighbor direction vectors: the anisotropy ratio is
        the ratio of the largest to smallest eigenvalue of the 2x2 covariance
        matrix, and the dominant direction is the leading eigenvector.
        """
        directions_x = np.array(neighbor_data["directions_x"])
        directions_y = np.array(neighbor_data["directions_y"])
        if len(directions_x) == 0:
            return {
                "anisotropy_ratio": 0.0,
                "dominant_direction_x": 0.0,
                "dominant_direction_y": 0.0,
            }
        # Calculate covariance matrix of directions
        directions = np.column_stack([directions_x, directions_y])
        if directions.shape[0] < 2:
            # Covariance undefined for a single sample.
            return {
                "anisotropy_ratio": 0.0,
                "dominant_direction_x": 0.0,
                "dominant_direction_y": 0.0,
            }
        try:
            cov_matrix = np.cov(directions.T)
            eigenvalues, eigenvectors = np.linalg.eigh(cov_matrix)
            # Sort eigenvalues and eigenvectors
            idx = np.argsort(eigenvalues)[::-1]
            eigenvalues = eigenvalues[idx]
            eigenvectors = eigenvectors[:, idx]
            # Anisotropy ratio (ratio of major to minor axis)
            anisotropy_ratio = eigenvalues[0] / (eigenvalues[1] + 1e-8)
            # Dominant direction (first eigenvector)
            dominant_direction = eigenvectors[:, 0]
            return {
                "anisotropy_ratio": float(anisotropy_ratio),
                "dominant_direction_x": float(dominant_direction[0]),
                "dominant_direction_y": float(dominant_direction[1]),
            }
        except Exception:
            return {
                "anisotropy_ratio": 0.0,
                "dominant_direction_x": 0.0,
                "dominant_direction_y": 0.0,
            }

    def _calculate_local_density(
        self, target_cell: dict[str, Any], cells: list[dict[str, Any]], radius: float
    ) -> float:
        """Calculate number of cells within radius using cached KDTree (no Python loops)."""
        # Early return if no centroid
        if "centroid" not in target_cell or target_cell["centroid"] is None:
            return 0.0
        # Ensure spatial index is ready (built once per cells list identity)
        self._ensure_spatial_index(cells)
        if self._kdtree is None or self._positions_array is None:
            return 0.0
        # Query count of neighbors within radius (prefer return_length for speed)
        # Check cache first
        cells_id = self._cells_obj_id if self._cells_obj_id is not None else id(cells)
        target_id = int(target_cell.get("cell_id", -1))
        # Radius quantized to 1/1000 units for a hashable cache key.
        radius_key = int(round(float(radius) * 1000))
        cache_key = (cells_id, radius_key, target_id)
        cached = self._local_density_cache.get(cache_key)
        if cached is not None:
            return cached
        target_pos = np.asarray(
            [target_cell["centroid"][0], target_cell["centroid"][1]], dtype=np.float64
        )
        try:
            cnt = self._kdtree.query_ball_point(
                target_pos, radius, p=2.0, return_length=True
            )
            count_val = (
                int(cnt)
                if not isinstance(cnt, (list, tuple, np.ndarray))
                else (int(cnt[0]) if len(cnt) > 0 else 0)  # type: ignore
            )
        except TypeError:
            # Older SciPy: no return_length, fall back to indices length
            indices = self._kdtree.query_ball_point(target_pos, radius, p=2.0)
            count_val = (
                len(indices)  # type: ignore
                if isinstance(indices, (list, tuple, np.ndarray))  # type: ignore
                else int(indices)
            )
        # Exclude the target cell itself only if we know it's indexed in the tree
        if (
            self._cell_id_to_pos_index is not None
            and target_id in self._cell_id_to_pos_index
        ):
            count_val = max(0, count_val - 1)
        result = float(count_val)
        self._local_density_cache[cache_key] = result
        return result

    def _calculate_spatial_entropy(self, neighbor_data: dict[str, Any]) -> float:
        """Calculate spatial entropy of neighbor distribution.

        Bins neighbor directions into 8 angular sectors and returns the
        Shannon entropy (base 2) of the sector occupancy; max is 3 bits.
        """
        if not neighbor_data["directions_x"] or not neighbor_data["directions_y"]:
            return 0.0
        # Divide space into angular bins
        angles = np.arctan2(
            neighbor_data["directions_y"], neighbor_data["directions_x"]
        )
        # Normalize to [0, 2π]
        angles = (angles + 2 * np.pi) % (2 * np.pi)
        # Create bins (8 directions)
        n_bins = 8
        bin_edges = np.linspace(0, 2 * np.pi, n_bins + 1)
        hist, _ = np.histogram(angles, bins=bin_edges)
        # Normalize to probabilities
        prob = hist / np.sum(hist)
        prob = prob[prob > 0]  # Remove zero probabilities
        # Calculate entropy
        if len(prob) > 0:
            entropy = -np.sum(prob * np.log2(prob))
            return float(entropy)
        return 0.0

    def _calculate_convex_hull_features(
        self, neighbor_data: dict[str, Any], target_cell: dict[str, Any]
    ) -> dict[str, float]:
        """Calculate convex hull area and perimeter of neighborhood.

        The area/perimeter ratio is the circularity measure
        4*pi*A / P^2 (1.0 for a perfect circle).
        """
        try:
            # Get positions of neighbors + target cell
            points: list[list[float]] = []
            # Add target cell position
            if "centroid" in target_cell:
                points.append([target_cell["centroid"][0], target_cell["centroid"][1]])
            # Add neighbor positions
            for cell in neighbor_data["neighbor_cells"]:
                if "centroid" in cell:
                    points.append([cell["centroid"][0], cell["centroid"][1]])
            if len(points) < 3:  # Need at least 3 points for convex hull
                return {
                    "convex_hull_area": 0.0,
                    "convex_hull_perimeter": 0.0,
                    "area_perimeter_ratio": 0.0,
                }
            points_array = np.array(points)
            hull = ConvexHull(points_array)
            area = float(hull.volume)  # In 2D, volume is actually area
            perimeter = 0.0
            # Calculate perimeter
            for simplex in hull.simplices:
                p1 = points_array[simplex[0]]
                p2 = points_array[simplex[1]]
                perimeter += np.linalg.norm(p2 - p1)
            area_perimeter_ratio = (
                (4 * np.pi * area) / (perimeter**2) if perimeter > 0 else 0.0
            )
            return {
                "convex_hull_area": area,
                "convex_hull_perimeter": float(perimeter),
                "area_perimeter_ratio": float(area_perimeter_ratio),
            }
        except Exception:
            return {
                "convex_hull_area": 0.0,
                "convex_hull_perimeter": 0.0,
                "area_perimeter_ratio": 0.0,
            }

    def _calculate_relative_nucleus_size(
        self, target_cell: dict[str, Any], local_density: float
    ) -> float:
        """Calculate nucleus size relative to local density.

        Size precedence: 'area' field, then cv2.contourArea of 'contour',
        then 'size' (default 1.0). Returns size / density, or the raw size
        when density is zero.
        """
        # Get area from contour if available, otherwise fall back to area or size fields
        nucleus_size = target_cell.get("area", 0.0)
        # Use contour to calculate area if available and area not already set
        if nucleus_size == 0.0 and "contour" in target_cell:
            contour = target_cell["contour"]
            if contour is not None and len(contour) > 0:
                try:
                    # Convert contour points to the format expected by cv2.contourArea
                    # It should be an array of points with shape (n, 1, 2) where n is number of points
                    if isinstance(contour, list):
                        contour_array = np.array(contour, dtype=np.int32)
                        # Reshape if needed
                        if (
                            len(contour_array.shape) == 2
                            and contour_array.shape[1] == 2
                        ):
                            # Already in format [[x1, y1], [x2, y2], ...] - convert to required format
                            contour_array = contour_array.reshape((-1, 1, 2))
                        nucleus_size = float(cv2.contourArea(contour_array))
                    else:
                        # Assuming contour is already a numpy array
                        nucleus_size = float(cv2.contourArea(contour))
                except Exception:
                    # Fall back to area or size if contour processing fails
                    nucleus_size = 0.0
        # Fall back to size if contour and area are not available
        if nucleus_size == 0.0:
            nucleus_size = target_cell.get("size", 1.0)
        if local_density > 0:
            return float(nucleus_size / local_density)
        return float(nucleus_size)

    def _calculate_mean_orientation(self, neighbor_data: dict[str, Any]) -> float:
        """Calculate mean orientation of neighbors.

        Averages the direction vectors and converts the resultant vector back
        to an angle (radians in [-pi, pi]); robust to angle wrap-around.
        """
        if not neighbor_data["directions_x"] or not neighbor_data["directions_y"]:
            return 0.0
        directions_x_array = np.array(neighbor_data["directions_x"])
        directions_y_array = np.array(neighbor_data["directions_y"])
        # Directly average the vectors
        mean_x = np.mean(directions_x_array)
        mean_y = np.mean(directions_y_array)
        # Convert average vector back to an angle
        mean_angle = np.arctan2(mean_y, mean_x)
        return float(mean_angle)