# Source code for cellmil.features.feature_extractor

import json
import cv2
import traceback
import torch
import radiomics  # type: ignore
import numpy as np
import yaml
from typing import Any, Optional, Generator, cast
from pathlib import Path
from tqdm import tqdm
from multiprocessing import Pool
from torch.utils.data import Dataset, DataLoader

from .extractor import Extractor
from cellmil.utils import logger
from cellmil.utils.stain_normalization import macenko_normalization
from cellmil.interfaces import FeatureExtractorConfig
from cellmil.interfaces.FeatureExtractorConfig import FeatureExtractionType, ExtractorType
from .extractor import MorphologicalExtractor, TopologicalExtractor, EmbeddingExtractor
from cellmil.utils.tools import get_cpu_count

# Silence pyradiomics' chatty default logging; only errors are relevant here.
pyradiomics_logger = radiomics.logger
pyradiomics_logger.setLevel("ERROR")

# Per-process extractor for multiprocessing workers to avoid recreating it per task.
# Populated by _pool_init() inside each worker process; stays None in the parent.
_PROCESS_EXTRACTOR: Optional[Extractor] = None

def _pool_init(extractor_type: ExtractorType) -> None:
    """Initializer for a multiprocessing Pool: create one Extractor per worker process.

    The instance is stored in the module-global ``_PROCESS_EXTRACTOR`` so that
    worker tasks reuse a single extractor instead of re-creating it per task.
    """
    global _PROCESS_EXTRACTOR
    _PROCESS_EXTRACTOR = Extractor.create(extractor_type)  # type: ignore
class FeatureExtractor:
    """Feature extractor for whole slide images.

    Dispatches to morphological, topological, or embedding feature extraction
    depending on the configured extractor type, and persists results as a
    ``features.pt`` tensor file.
    """
[docs] def __init__(self, config: FeatureExtractorConfig): self.config = config self.config.patched_slide_path = Path(self.config.patched_slide_path) self.config.wsi_path = Path(self.config.wsi_path) if self.config.wsi_path else None # TODO: Make configurable self.parallel = True self.stain_normalization = True self.test = False # TODO: ----- if self.test or self.config.extractor in FeatureExtractionType.Embedding : self.n_workers = 1 elif self.config.extractor in FeatureExtractionType.Morphological: self.n_workers = max(1, min(get_cpu_count() - 1, 61)) elif self.config.extractor in FeatureExtractionType.Topological: self.n_workers = max(1, min(get_cpu_count() - 1, 8)) logger.info(f"Feature extractor initialized with config: {self.config}") logger.info(f"Number of workers: {self.n_workers}") self._load_metadata() if self.config.extractor in FeatureExtractionType.Morphological or self.config.extractor in FeatureExtractionType.Topological: self._load_cells() if self.config.extractor in FeatureExtractionType.Topological: self._load_graph() if self.config.extractor in FeatureExtractionType.Embedding: self._load_patches()
[docs] def _load_cells(self): """Load cell data from the specified path.""" if self.config.segmentation_model is None: raise ValueError("Segmentation model must be specified.") json_path = ( self.config.patched_slide_path / "cell_detection" / str(self.config.segmentation_model) / "cells.json" ) if not json_path.exists(): raise FileNotFoundError(f"Cells path {json_path} does not exist.") with open(json_path, "r") as f: self.json = json.load(f) self.cells: list[dict[str, Any]] = self.json["cells"] logger.info(f"Loaded {len(self.cells)} cells from {json_path}")
[docs] def _load_metadata(self): """Load metadata from the patched slide path to get patch size and overlap.""" metadata_path = self.config.patched_slide_path / "metadata.yaml" if not metadata_path.exists(): raise FileNotFoundError(f"Metadata file not found: {metadata_path}") with open(metadata_path, "r") as file: self.metadata = yaml.safe_load(file) # Extract key parameters self.patch_size = self.metadata.get("patch_size", 256) self.patch_overlap = self.metadata.get("patch_overlap", 0) self.target_mag = self.metadata.get("magnification", 40.0) logger.info(f"Loaded metadata: patch_size={self.patch_size}, patch_overlap={self.patch_overlap}, target_mag={self.target_mag}")
[docs] def _load_graph(self): """Load graph from the specified path""" if self.config.segmentation_model is None: raise ValueError("Segmentation model must be specified.") if self.config.graph_method is None: raise ValueError("Graph method must be specified.") graph_path = ( self.config.patched_slide_path / "graphs" / str(self.config.graph_method) / str(self.config.segmentation_model) / "graph.pt" ) if not graph_path.exists(): raise FileNotFoundError(f"Graph path {graph_path} does not exist.") _graph = torch.load(graph_path, map_location="cpu", weights_only=False) self.graph = { "node_features": _graph["node_features"], "edge_indices": _graph["edge_indices"], "edge_features": _graph["edge_features"], } logger.info(f"Loaded graph from {graph_path}")
def _load_patches(self): folder = ( self.config.patched_slide_path / "patches" ) if not folder.exists() or not folder.is_dir(): raise FileNotFoundError(f"Patches folder {folder} does not exist or is not a directory.") # Get all PNG files in the patches folder self.patches = list(folder.glob("*.png")) logger.info(f"Loaded {len(self.patches)} patch paths from {folder}")
[docs] def get_features(self): """Extract features from the whole slide image using parallel processing.""" logger.info( f"Extracting features for {self.config.patched_slide_path}..." ) output_path = None if self.config.extractor in FeatureExtractionType.Morphological: def _arg_iter() -> Generator[dict[str, Any], None, None]: for cell in self.cells: yield { "cell_data": cell, "patch_size": self.patch_size, "patch_overlap": self.patch_overlap, "extractor_name": str(self.config.extractor), "patched_slide_path": str(self.config.patched_slide_path), "stain_normalization": self.stain_normalization, } total_tasks = len(self.cells) chunksize = max(1, min(1000, total_tasks // max(1, self.n_workers * 8))) logger.info(f"Multiprocessing chunksize: {chunksize}") results: list[dict[str, Any]] = [] with Pool( self.n_workers, initializer=_pool_init, initargs=(str(self.config.extractor),), maxtasksperchild=500, ) as pool: for res in tqdm( pool.imap_unordered( FeatureExtractor._get_morphological_features, _arg_iter(), chunksize=chunksize ), total=total_tasks, desc="Extracting features", mininterval=0.5, ): if res is not None: results.append(res) logger.info( f"Successfully extracted features for {len(results)} cells out of {len(self.cells)} total cells" ) output_path = self._save_pt(results) elif self.config.extractor in FeatureExtractionType.Topological: # if self.config.extractor in [ExtractorType.connectivity, ExtractorType.structure]: # Initialize the extractor _pool_init(self.config.extractor) results: list[dict[str, Any] ] = [] for node in tqdm(self.graph["node_features"], desc="Extracting topological features"): res = self._get_topological_features({ "cell_id": node, "graph": self.graph, "cells": self.cells, "extractor_name": str(self.config.extractor) }) if res is not None: results.append(res) output_path = self._save_pt(results) # else: # def _arg_iter() -> Generator[dict[str, Any], None, None]: # for node in self.graph["node_features"]: # yield { # "cell_id": node, # 
"graph": self.graph.copy(), # "cells": self.cells.copy(), # "extractor_name": str(self.config.extractor), # } # total_tasks = len(self.graph["node_features"]) # chunksize = max(1, min(1000, total_tasks // max(1, self.n_workers * 8))) # logger.info(f"Multiprocessing chunksize: {chunksize}") # results: list[dict[str, Any]] = [] # with Pool( # self.n_workers, # initializer=_pool_init, # initargs=(str(self.config.extractor),), # ) as pool: # for res in tqdm( # pool.imap_unordered( # FeatureExtractor._get_topological_features, _arg_iter(), chunksize=chunksize # ), # total=total_tasks, # desc="Extracting features", # mininterval=0.5, # ): # if res is not None: # results.append(res) # logger.info( # f"Successfully extracted features for {len(results)} cells out of {len(self.cells)} total cells" # ) # output_path = self._save_pt(results) elif self.config.extractor in FeatureExtractionType.Embedding: _pool_init(self.config.extractor) results = self._get_embedding_features({ "patches": self.patches, "extractor_name": str(self.config.extractor) }) output_path = self._save_pt(results) logger.info( f"Results saved to {output_path}" if output_path else "No results to save" )
[docs] @staticmethod def _load_patch(patched_slide_path: Path, row: int, col: int) -> np.ndarray[Any, Any]: """Load a patch image from the file system. Args: patched_slide_path (Path): Path to the patched slide directory row (int): Row coordinate of the patch col (int): Column coordinate of the patch Returns: np.ndarray: The patch image as a numpy array Raises: FileNotFoundError: If the patch file doesn't exist """ if isinstance(patched_slide_path, str): patched_slide_path = Path(patched_slide_path) patch_path = FeatureExtractor._get_patch_path(patched_slide_path, row, col) if patch_path.exists(): patch = cv2.imread(str(patch_path)) # Convert BGR to RGB patch = cv2.cvtColor(patch, cv2.COLOR_BGR2RGB) return patch else: logger.error( f"Patch file {patch_path} does not exist." ) raise FileNotFoundError( f"Patch file {patch_path} does not exist. Please ensure the patch extraction step was completed." )
[docs] @staticmethod def _get_patch_path(patched_slide_path: Path, row: int, col: int) -> Path: """Get the path to a patch file based on row and column coordinates. Args: patched_slide_path (Path): Path to the patched slide directory row (int): Row coordinate of the patch col (int): Column coordinate of the patch Returns: Path: Full path to the patch file """ wsi_name = patched_slide_path.name patch_filename = f"{wsi_name}_{row}_{col}.png" patch_path = patched_slide_path / "patches" / patch_filename return patch_path
[docs] def _save_pt(self, results: list[dict[str, Any]]): """Save extracted features to a PyTorch file. (Faster than JSON)""" logger.info("Saving features to PyTorch file") # Create output directory if it doesn't exist if self.config.extractor in FeatureExtractionType.Morphological: output_dir = ( self.config.patched_slide_path / "feature_extraction" / str(self.config.extractor) / str(self.config.segmentation_model) ) elif self.config.extractor in FeatureExtractionType.Topological: output_dir = ( self.config.patched_slide_path / "feature_extraction" / str(self.config.extractor) / str(self.config.graph_method) / str(self.config.segmentation_model) ) else: output_dir = ( self.config.patched_slide_path / "feature_extraction" / str(self.config.extractor) ) output_dir.mkdir(parents=True, exist_ok=True) filename = "features.pt" output_path = output_dir / filename # Save to PyTorch file try: if not results or len(results) == 0: if self.config.extractor in FeatureExtractionType.Embedding: output: dict[str, Any] = { "features": torch.empty((0, 0), dtype=torch.float32), "patch_indices": {}, # Empty dictionary "feature_names": [], # Empty list } else: output = { "features": torch.empty((0, 0), dtype=torch.float32), "cell_indices": {}, # Empty dictionary "feature_names": [], # Empty list } torch.save(output, output_path) logger.info(f"Successfully saved empty features file to {output_path}") logger.info("Feature tensor shape: (0, 0)") else: # Create tensor from feature dictionaries features_tensor = self._results_to_tensor(results) feature_names = list(results[0]["features"].keys()) feature_names.sort() # Sort to ensure consistent order with tensor # Create identifier to index mapping dictionary if self.config.extractor in FeatureExtractionType.Embedding: # For embeddings, use patch_name as identifier patch_indices = {item["patch_name"]: i for i, item in enumerate(results)} output: dict[str, Any] = { "features": features_tensor, "patch_indices": patch_indices, # Dictionary 
mapping identifier to tensor index "feature_names": feature_names, # List of feature names in tensor order } else: # For morphological/topological, use cell_id as identifier cell_indices = {item["cell_id"]: i for i, item in enumerate(results)} output = { "features": features_tensor, "cell_indices": cell_indices, # Dictionary mapping identifier to tensor index "feature_names": feature_names, # List of feature names in tensor order } torch.save(output, output_path) logger.info(f"Successfully saved {len(results)} results to {output_path}") logger.info(f"Feature tensor shape: {features_tensor.shape}") return output_path except Exception as e: logger.info(f"Error saving results to {output_path}: {e}") return None
[docs] def _results_to_tensor(self, data: list[dict[str, Any]]) -> torch.Tensor: """Convert a dictionary of features to a PyTorch tensor. Args: data: List of dictionaries containing cell_id and features Returns: torch.Tensor: Tensor of shape [N, D] where N is the number of cells and D is the number of features """ if not data or len(data) == 0: logger.warning("No data to convert to tensor") return torch.empty((0, 0), dtype=torch.float32) # Get feature names from the first element feature_names = list(data[0]["features"].keys()) feature_names.sort() # Sort to ensure consistent order # Create a tensor of the appropriate size n_cells = len(data) n_features = len(feature_names) features_tensor = torch.zeros((n_cells, n_features), dtype=torch.float32) # Fill in the tensor for i, item in enumerate(tqdm(data, desc="Converting features to tensor")): # Get identifier for logging - handle both cell_id and patch_name identifier = item.get("cell_id") or item.get("patch_name", f"item_{i}") for j, feature_name in enumerate(feature_names): # Get feature value, defaulting to 0.0 if not present value = item["features"].get(feature_name, 0.0) # Convert to float if it's a scalar, otherwise use first element try: if isinstance(value, (int, float, np.ndarray)): features_tensor[i, j] = float(value) # type: ignore elif hasattr(value, "__len__") and len(value) > 0: features_tensor[i, j] = float(value[0]) else: features_tensor[i, j] = 0.0 logger.warning( f"Using default value 0.0 for feature {feature_name} (identifier {identifier}): incompatible type {type(value)}" ) except (ValueError, TypeError) as e: features_tensor[i, j] = 0.0 logger.warning( f"Error converting feature {feature_name} for identifier {identifier}: {e}" ) logger.info(f"Created features tensor with shape {features_tensor.shape}") return features_tensor
# ---- MORPHOLOGICAL EXTRACTION ----
[docs] @staticmethod def _get_morphological_features(item: dict[str, Any]) -> Optional[dict[str, Any]]: """Process a single cell and extract features.""" i = int(item["cell_data"]["cell_id"]) # type: ignore cell = dict(item["cell_data"]) # type: ignore patch_size = item["patch_size"] # type: ignore extractor_name = item["extractor_name"] # type: ignore patched_slide_path = Path(item.get("patched_slide_path", "")) # type: ignore patch_overlap = item.get("patch_overlap", 0) # type: ignore stain_normalization = item.get("stain_normalization", False) # type: ignore # Use per-process extractor if available to avoid re-initialization overhead global _PROCESS_EXTRACTOR extractor = _PROCESS_EXTRACTOR if _PROCESS_EXTRACTOR is not None else Extractor.create(extractor_name) # type: ignore try: # Check if contour is valid if not cell["contour"] or len(cell["contour"]) < 3: logger.warning( f"Skipping cell {i}: Invalid contour (need at least 3 points)" ) return None # Check if we can use pre-extracted patch if patched_slide_path and patched_slide_path.exists() and "patch_coordinates" in cell: try: row, col = cell["patch_coordinates"] except (ValueError, TypeError, KeyError): raise ValueError(f"Invalid patch_coordinates for cell {i}") else: raise ValueError(f"Missing patch_coordinates for cell {i}") # Get bounding box from cell contour to determine patch region contour_points = np.array(cell["contour"]) x_coords = contour_points[:, 0] y_coords = contour_points[:, 1] y_min, y_max = y_coords.min(), y_coords.max() x_min, x_max = x_coords.min(), x_coords.max() # Add some padding pad = patch_size // 2 patch = None mask = None if x_max - x_min <= 0 or y_max - y_min <= 0: logger.warning( f"Invalid patch size for cell {i}: x_min={x_min}, x_max={x_max}, y_min={y_min}, y_max={y_max}" ) return None # Prepare patch and mask patch_shape = (patch_size, patch_size) # Use pre-extracted patch row, col = cell["patch_coordinates"] patch = FeatureExtractor._load_patch(patched_slide_path, row, col) # 
Apply stain normalization if enabled if stain_normalization: normalized_patches, _, _ = macenko_normalization([patch]) if normalized_patches is not None: # type: ignore patch = normalized_patches[0] x_global = int( row * patch_size - (row + 0.5) * patch_overlap ) y_global = int( col * patch_size - (col + 0.5) * patch_overlap ) # Create mask mask, mask_area = FeatureExtractor.create_cell_mask( cell["contour"], (x_global, y_global, x_global + patch_shape[1], y_global + patch_shape[0]), patch_shape, ) if mask_area == 0: logger.warning(f"Skipping cell {i}: Empty mask created") return None # feature extraction if not isinstance(extractor, MorphologicalExtractor): raise TypeError(f"Expected MorphologicalExtractor, got {type(extractor)}") features = extractor.extract_features(patch, mask) # Time result preparation # Clean features (remove metadata) clean_features = { k: v for k, v in features.items() if not k.startswith("diagnostics_") } # type: ignore # Add cell metadata result: dict[str, Any] = { "cell_id": i, "features": clean_features, } # Clean up memory del patch, mask, features, clean_features del contour_points, x_coords, y_coords del x_min, y_min, x_max, y_max, pad return result except Exception as e: logger.warning(f"Error processing cell {i}: {e}") traceback.print_exc() return None
    @staticmethod
    def create_cell_mask(
        cell_contour: list[list[int]],
        patch_bounds: tuple[int, int, int, int],
        patch_shape: tuple[int, int],
    ) -> tuple[np.ndarray[Any, Any], int]:
        """
        Create a binary mask from cell contour for a specific patch region only.

        Args:
            cell_contour: List of [x, y] coordinates defining the cell boundary
            patch_bounds: Tuple of (x_min, y_min, x_max, y_max) defining the patch region
            patch_shape: Tuple of (height, width) for the output mask

        Returns:
            Tuple of (binary mask array with shape patch_shape, mask pixel area);
            the area is 0 when no usable mask could be built.
        """
        x_min, y_min, _, _ = patch_bounds
        mask = np.zeros(patch_shape, dtype=np.uint8)

        # Need at least 3 points to form a polygon.
        if not cell_contour or len(cell_contour) < 3:
            logger.warning(
                f"Invalid contour with {len(cell_contour) if cell_contour else 0} points"
            )
            return mask, 0

        # Translate contour coordinates to patch-relative coordinates (vectorized).
        # NOTE(review): the offset is [y_min, x_min] here, but the invalid-point
        # logging below subtracts [x_min, y_min] — the two orders disagree.
        # Presumably one of them encodes an intentional axis swap for the
        # contour's coordinate convention; confirm against the cell-detection output.
        contour_array = np.array(cell_contour)
        translated_contour_array = contour_array - np.array([y_min, x_min])

        # Filter points that are within patch bounds (vectorized).
        x_valid = (translated_contour_array[:, 0] >= 0) & (translated_contour_array[:, 0] < patch_shape[1])
        y_valid = (translated_contour_array[:, 1] >= 0) & (translated_contour_array[:, 1] < patch_shape[0])
        valid_mask = x_valid & y_valid

        # Get valid translated points.
        translated_contour_array = translated_contour_array[valid_mask]

        # Log warnings for invalid points if any.
        if not np.all(valid_mask):
            invalid_points = contour_array[~valid_mask]
            invalid_translated = (contour_array - np.array([x_min, y_min]))[~valid_mask]
            for orig_point, trans_point in zip(invalid_points, invalid_translated):
                logger.warning(
                    f"Point {orig_point} translated to ({trans_point[0]}, {trans_point[1]}) is outside patch bounds"
                )

        # If we have enough points after translation, create the mask.
        if len(translated_contour_array) >= 3:
            contour_array = translated_contour_array.astype(np.int32)
            cv2.fillPoly(mask, [contour_array], 1)  # type: ignore
            area = np.sum(mask)
            if area == 0:
                logger.warning(
                    f"Empty mask created for contour with {len(translated_contour_array)} translated points"
                )
            return mask, int(area)
        else:
            logger.warning(
                f"Not enough points ({len(translated_contour_array)}) in patch region for contour"
            )
            return mask, 0
# ---- TOPOLOGICAL EXTRACTION ----
[docs] @staticmethod def _get_topological_features(item: dict[str, Any]) -> Optional[dict[str, Any]]: """Process a single cell and extract features.""" cell_id = item["cell_id"] graph = item["graph"] cells = item["cells"] extractor_name = item["extractor_name"] global _PROCESS_EXTRACTOR extractor = _PROCESS_EXTRACTOR if _PROCESS_EXTRACTOR is not None else Extractor.create(extractor_name) try: # feature extraction if not isinstance(extractor, TopologicalExtractor): raise TypeError(f"Expected TopologicalExtractor, got {type(extractor)}") features = extractor.extract_features(cell_id, graph, cells) result: dict[str, Any] = { "cell_id": cell_id.item(), "features": features, } return result except Exception as e: logger.warning(f"Error processing node/cell {cell_id.item()}: {e}") traceback.print_exc() return None
# ---- EMBEDDING EXTRACTION ---- @staticmethod def _get_embedding_features(item: dict[str, Any]) -> list[dict[str, Any]]: patches = cast(list[str], item["patches"]) extractor_name = item["extractor_name"] global _PROCESS_EXTRACTOR extractor = _PROCESS_EXTRACTOR if _PROCESS_EXTRACTOR is not None else Extractor.create(extractor_name) if not isinstance(extractor, EmbeddingExtractor): raise TypeError(f"Expected EmbeddingExtractor, got {type(extractor)}") class PatchesDataset(Dataset[torch.Tensor]): def __init__(self, patches: list[str]): self.patches = patches def __len__(self): return len(self.patches) def __getitem__(self, idx: int) -> torch.Tensor: patch = self.patches[idx] image = cv2.imread(str(patch)) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert to tensor for consistency with extractor expectations image_tensor = torch.from_numpy(image).permute(2, 0, 1).float() # type: ignore return image_tensor dataset = PatchesDataset(patches) dataloader = DataLoader( dataset, batch_size=32, shuffle=False, num_workers=4 ) features: list[torch.Tensor] = [] with torch.no_grad(): for batch in tqdm(dataloader, desc="Extracting features"): batch_features = extractor.extract_features(batch) features.append(batch_features) # Concatenate all features from batches all_features = torch.cat(features, dim=0) # Build list of dictionaries mapping patch name to features results:list[dict[str, Any]] = [] for i, patch_path in tqdm(enumerate(patches), desc="Mapping patches to features"): patch_name = Path(patch_path).stem # Extract just the row_col coordinates from the patch name # Assuming format: slidename_row_col.png -> extract "row_col" parts = patch_name.split('_') if len(parts) >= 2: # Take the last two parts as row_col patch_coordinates = f"{parts[-2]}_{parts[-1]}" else: # Fallback to full name if format is unexpected patch_coordinates = patch_name feature_dict = {str(j): all_features[i, j].item() for j in range(all_features.shape[1])} results.append({ "patch_name": 
patch_coordinates, "features": feature_dict }) return results