# Source code for cellmil.graph.graph_creator

import torch
import json
import random
import numpy as np
import plotly.graph_objs as go  # type: ignore
from typing import Any, cast
from pathlib import Path
from cellmil.utils import logger
from cellmil.interfaces import GraphCreatorConfig
from cellmil.interfaces.GraphCreatorConfig import GraphCreatorCategory
from cellmil.interfaces.FeatureExtractorConfig import ExtractorType
from cellmil.datamodels.datasets.utils import get_cell_features, validate_features, get_cell_types, cell_types_to_tensor
from cellmil.datamodels.transforms.correlation_filter import CorrelationFilterTransform
from cellmil.datamodels.transforms.normalization import RobustScalerTransform
from .debug_visualizer import GraphDebugVisualizer
from .creator import Creator

class GraphCreator:
    """Main class for graph creation from segmented cells."""

    def __init__(self, config: GraphCreatorConfig) -> None:
        """Initialize the creator: pick a device, load cells, and (if the
        chosen method needs them) load morphological features.

        Args:
            config: Graph-creation configuration (slide path, method, GPU id, ...).
        """
        self.config = config
        self.config.patched_slide_path = Path(self.config.patched_slide_path)
        # Prefer the configured GPU when CUDA is available, otherwise fall back to CPU.
        if torch.cuda.is_available():
            self.device = f"cuda:{self.config.gpu}"
        else:
            self.device = "cpu"
        # self.device = "cpu"

        # Create the creator with appropriate parameters
        self.creator = Creator(
            self.config.method,
            self.device,
        )

        # TODO: Make configurable
        self.correlation_threshold = 0.9
        self.feature_extractors = [ExtractorType.morphometrics, ExtractorType.pyradiomics_hed]
        self.cell_type = True
        # Number of cells to sample in debug mode
        self.debug_sample_size = 2000
        # TODO: -----

        logger.info(f"Graph creator initialized with config: {self.config}, GPU: {self.device}")
        self._load_cells()
        # Feature-dependent methods additionally need per-cell feature vectors.
        if self.config.method in GraphCreatorCategory.FeatureDependent:
            logger.info(f"Using feature-dependent graph creation method: {self.config.method}")
            self._load_morphological_features(self.feature_extractors)
[docs] def _load_morphological_features(self, feature_extractors: list[ExtractorType]) -> None: """Load morphological features for similarity-based graph creation.""" valid_features = validate_features( folder=self.config.patched_slide_path.parent, slide_name=self.config.patched_slide_path.name, extractor=feature_extractors, segmentation_model=self.config.segmentation_model ) if valid_features is False: raise ValueError("Morphological features are not valid for graph creation.") features, cell_indices, _ = get_cell_features( folder=self.config.patched_slide_path.parent, slide_name=self.config.patched_slide_path.name, extractor=feature_extractors, segmentation_model=self.config.segmentation_model ) if features is None or cell_indices is None: raise ValueError("Failed to load morphological features for graph creation.") # Filter cells to only those with available features original_count = len(self.cells) self.cells = [cell for cell in self.cells if cell["cell_id"] in cell_indices] filtered_count = len(self.cells) if filtered_count < original_count: logger.info( f"Filtered cells from {original_count} to {filtered_count} based on available features. " f"Excluded {original_count - filtered_count} cells without features." 
) # Apply correlation filtering to remove redundant features logger.info(f"Applying correlation filter to {features.shape[1]} features...") corr_filter = CorrelationFilterTransform( correlation_threshold=self.correlation_threshold, plot_correlation_matrix=False, ) corr_filter.fit(features) filtered_features = corr_filter.transform(features) logger.info(f"Features reduced to {filtered_features.shape[1]} after correlation filtering") # Apply robust scaling for normalization logger.info("Applying robust scaling to features...") robust_scaler = RobustScalerTransform( apply_log_transform=True, quantile_range=(0.25, 0.75), clip_quantiles=(0.005, 0.995), constant_threshold=1e-8, ) robust_scaler.fit(filtered_features) normalized_features = robust_scaler.transform(filtered_features) logger.info("Features normalized using robust scaling") # Load and concatenate cell types if requested if self.cell_type: logger.info("Loading cell types for feature concatenation...") cell_types_dict = get_cell_types( self.config.patched_slide_path.parent, self.config.patched_slide_path.name, self.config.segmentation_model ) if cell_types_dict is None or len(cell_types_dict) == 0: logger.warning(f"No cell types found for slide. Skipping cell type concatenation.") cell_types_tensor = None else: # Convert cell types to one-hot tensor cell_types_tensor = cell_types_to_tensor(cell_types_dict, cell_indices) logger.info(f"Loaded cell types: {cell_types_tensor.shape}") # Concatenate cell types to features normalized_features = torch.cat([normalized_features, cell_types_tensor], dim=1) # normalized_features = cell_types_tensor.float() logger.info( f"Concatenated {cell_types_tensor.shape[1]} cell type features. 
" f"New feature dimension: {normalized_features.shape[1]}" ) # Add processed features to each cell using cell_indices # cell_indices is a dict: {cell_id: row_index_in_features} for cell in self.cells: cell_id = cell["cell_id"] if cell_id in cell_indices: feature_idx = cell_indices[cell_id] cell["features"] = normalized_features[feature_idx, :] logger.info( f"Added {normalized_features.shape[1]} processed features to {len(self.cells)} cells" )
[docs] def _load_cells(self): """Load cell data from the specified path.""" json_path = ( self.config.patched_slide_path / "cell_detection" / self.config.segmentation_model / "cells.json" ) if not json_path.exists(): raise FileNotFoundError(f"Cells path {json_path} does not exist.") with open(json_path, "r") as f: self.json = json.load(f) all_cells: list[dict[str, Any]] = self.json["cells"] # Sample cells if in debug mode if self.config.debug and len(all_cells) > self.debug_sample_size: random.seed(42) # For reproducibility self.cells = random.sample(all_cells, self.debug_sample_size) logger.info(f"Debug mode: Sampled {len(self.cells)} cells from {len(all_cells)} total cells") else: self.cells = all_cells logger.info(f"Loaded {len(self.cells)} cells from {json_path}")
[docs] def create_graph(self) -> None: """Create a graph from the segmented cells.""" # Launch debug visualizer if in debug mode if self.config.debug: self._launch_debug_visualizer() return logger.info(f"Creating graph using {self.config.method} method...") # Create graph using the Creator class node_features, edge_indices, edge_features = self.creator.create( self.cells ) # Save graph self._save_graph(node_features, edge_indices, edge_features) # Plot graph if self.config.plot: self._plot_graph(node_features, edge_indices) logger.info("Graph creation completed successfully.")
[docs] def _launch_debug_visualizer(self) -> None: """Launch the interactive debug visualizer.""" logger.info("Launching debug visualizer...") logger.info(f"Using {len(self.cells)} sampled cells for visualization") visualizer = GraphDebugVisualizer( cells=self.cells, method=self.config.method, device=self.device, ) visualizer.run(port=8050, debug=True)
[docs] def _save_graph( self, node_features: torch.Tensor, edge_indices: torch.Tensor, edge_features: torch.Tensor ): """Save the graph to a PyTorch file.""" # Create output directory output_dir = ( self.config.patched_slide_path / "graphs" / self.config.method / self.config.segmentation_model ) output_dir.mkdir(parents=True, exist_ok=True) # Prepare graph data graph_data: dict[str, Any] = { "node_features": node_features, "edge_indices": edge_indices, "edge_features": edge_features, "metadata": { "n_nodes": node_features.shape[0], "n_edges": edge_indices.shape[1], "node_feature_dim": node_features.shape[1], "edge_feature_dim": edge_features.shape[1], "graph_method": self.config.method, "segmentation_model": self.config.segmentation_model, "name_node_features": "cell_id", "name_edge_features": ["distance", "direction_x", "direction_y"] } } # Save graph output_path = output_dir / "graph.pt" torch.save(graph_data, output_path) logger.info(f"Graph saved to {output_path}") logger.info(f"Graph statistics: {graph_data['metadata']['n_nodes']} nodes, {graph_data['metadata']['n_edges']} edges")
    def _plot_graph(self, node_features: torch.Tensor, edge_indices: torch.Tensor) -> None:
        """Create and save a Plotly visualization of the graph.

        Writes an interactive HTML file and a static PNG next to the saved
        graph. Best-effort: any exception is logged as a warning, not raised.
        Very large graphs are randomly subsampled to at most 100k nodes.

        Args:
            node_features: Node feature tensor; only used to look up positions.
            edge_indices: (2, n_edges) tensor of source/target node indices.
        """
        try:
            # Create output directory for images
            output_dir = (
                self.config.patched_slide_path
                / "graphs"
                / self.config.method
                / self.config.segmentation_model
            )
            output_dir.mkdir(parents=True, exist_ok=True)
            title = f'Cell Graph Visualization ({self.config.method})'
            positions = self._get_cells_position(node_features)
            # Convert positions to numpy for plotting
            pos_np = cast(
                np.ndarray[Any, Any],
                positions.cpu().numpy() if positions.is_cuda else positions.numpy(),  # type: ignore
            )
            edges_np = cast(
                np.ndarray[Any, Any],
                (
                    edge_indices.cpu().numpy()  # type: ignore
                    if edge_indices.is_cuda
                    else edge_indices.numpy()  # type: ignore
                ),
            )
            # Cap the plot size; beyond this, subsample nodes and keep only
            # edges whose both endpoints survive the sampling.
            max_nodes_to_plot = 100000
            if len(pos_np) > max_nodes_to_plot:
                # Sample nodes first
                node_sample_indices = np.random.choice(
                    len(pos_np), max_nodes_to_plot, replace=False
                )
                node_set = set(node_sample_indices)
                pos_sample = pos_np[node_sample_indices]
                # Create mapping from old indices to new indices
                old_to_new = {
                    old_idx: new_idx
                    for new_idx, old_idx in enumerate(node_sample_indices)
                }
                logger.info(
                    f"Sampling {max_nodes_to_plot} nodes for visualization from {len(pos_np)} total nodes"
                )
                # Filter edges to only include those between sampled nodes
                valid_edges: list[list[int]] = []
                if edges_np.shape[1] > 0:
                    for i in range(edges_np.shape[1]):
                        src_idx, dst_idx = edges_np[0, i], edges_np[1, i]
                        if src_idx in node_set and dst_idx in node_set:
                            # Remap indices to the sampled node indices
                            new_src = old_to_new[src_idx]
                            new_dst = old_to_new[dst_idx]
                            valid_edges.append([new_src, new_dst])
                if valid_edges:
                    edges_to_plot = np.array(valid_edges).T
                    logger.info(
                        f"Plotting {len(valid_edges)} edges between sampled nodes from {edges_np.shape[1]} total edges"
                    )
                else:
                    # No surviving edges — use an empty (2, 0) edge array.
                    edges_to_plot = np.empty((2, 0))
            else:
                pos_sample = pos_np
                edges_to_plot = edges_np
            # Create edge traces for Plotly
            # Each edge contributes (start, end, None); None breaks the line
            # so all edges can live in a single Scatter trace.
            edge_x: list[float | None] = []
            edge_y: list[float | None] = []
            if edges_to_plot.shape[1] > 0:
                for i in range(edges_to_plot.shape[1]):
                    start_idx, end_idx = int(edges_to_plot[0, i]), int(edges_to_plot[1, i])
                    start_pos = pos_sample[start_idx]
                    end_pos = pos_sample[end_idx]
                    edge_x.extend([start_pos[0], end_pos[0], None])
                    edge_y.extend([start_pos[1], end_pos[1], None])
            edge_trace = go.Scatter(
                x=edge_x,
                y=edge_y,
                line=dict(width=1.5, color='rgba(70, 130, 180, 0.6)'),
                hoverinfo='none',
                mode='lines',
                showlegend=False
            )
            # Create node trace
            node_x = pos_sample[:, 0].tolist()
            node_y = pos_sample[:, 1].tolist()
            node_trace = go.Scatter(
                x=node_x,
                y=node_y,
                mode='markers',
                hoverinfo='text',
                text=[f'Cell {i}' for i in range(len(pos_sample))],
                marker=dict(
                    size=4,
                    color='rgba(255, 100, 100, 0.7)',
                    line=dict(width=0.5, color='rgba(255, 255, 255, 0.5)'),
                    opacity=0.8
                ),
                showlegend=False
            )
            # Create figure with elegant styling
            fig = go.Figure(
                data=[edge_trace, node_trace],
                layout=go.Layout(
                    title=dict(
                        text=title,
                        font=dict(size=20, color='black', family='Arial Black'),
                        x=0.5
                    ),
                    showlegend=False,
                    hovermode='closest',
                    margin=dict(b=30, l=20, r=20, t=60),
                    annotations=[
                        dict(
                            text=f"Nodes: {len(pos_sample):,} | Edges: {edges_to_plot.shape[1]:,}",
                            showarrow=False,
                            xref="paper",
                            yref="paper",
                            x=0.02,
                            y=0.98,
                            xanchor='left',
                            yanchor='top',
                            font=dict(
                                color='rgba(0, 0, 0, 0.6)',
                                size=12,
                                family='Arial'
                            )
                        )
                    ],
                    xaxis=dict(
                        showgrid=False,
                        zeroline=False,
                        showticklabels=True,
                        title='X coordinate',
                        showline=True,
                        linecolor='rgba(0, 0, 0, 0.3)'
                    ),
                    # y axis reversed to match image coordinates; aspect
                    # ratio locked to x so distances are not distorted.
                    yaxis=dict(
                        showgrid=False,
                        zeroline=False,
                        showticklabels=True,
                        title='Y coordinate',
                        autorange='reversed',
                        showline=True,
                        linecolor='rgba(0, 0, 0, 0.3)',
                        scaleanchor='x',
                        scaleratio=1
                    ),
                    plot_bgcolor='white',
                    paper_bgcolor='white',
                    font=dict(color='black')
                )
            )
            # Save interactive HTML plot
            html_path = output_dir / "graph_interactive.html"
            fig.write_html(str(html_path))  # type: ignore
            # Save static PNG image
            png_path = output_dir / "graph_visualization.png"
            fig.write_image(str(png_path), width=1920, height=1080, scale=2)  # type: ignore
            logger.info(f"Static graph visualization saved to {png_path}")
            logger.info(f"Interactive graph visualization saved to {html_path}")
        except Exception as e:
            # Plotting is best-effort; log and continue.
            logger.warning(f"Failed to create graph visualization: {e}")
            import traceback
            traceback.print_exc()
[docs] def _get_cells_position(self, node_features: torch.Tensor) -> torch.Tensor: """Get the positions of the cell centroid for visualization.""" # Extract centroid positions from cells centroids = np.array([cell["centroid"] for cell in self.cells]) # Convert to tensor positions = torch.from_numpy(centroids).float() # type: ignore return positions