# -*- coding: utf-8 -*-
# Cell Segmenter
# This module was inspired by the CellViT project and adapted for the CellMIL framework.
#
# References:
# CellViT: Vision Transformers for precise cell segmentation and classification
# Fabian Hörst et al., Medical Image Analysis, 2024
# DOI: https://doi.org/10.1016/j.media.2024.103143
import yaml
import os
import torch
import ujson
import numpy as np
import logging
import pandas as pd
from tqdm import tqdm
from collections import deque
from typing import Any, Literal
from pathlib import Path
from torch.utils.data import DataLoader
import urllib.request
import shutil
from shapely import strtree
from shapely.geometry import Polygon, MultiPolygon
from torchvision import transforms as T # type: ignore
from torch.nn import functional as F
from cellmil.interfaces import CellSegmenterConfig, PatchExtractorConfig
from cellmil.utils import logger
from cellmil.utils.tools import unflatten_dict, to_float, to_float_normalized
from cellmil.datamodels import WSI
from cellmil.models.segmentation import HoVerNet, CellViTSAM, CellposeSAM
from .datasets.patched_wsi_inference import PatchedWSIInference
from .utils.geojson import convert_geojson
[docs]class CellSegmenter:
"""Class for cell instance segmentation in whole slide images.
Features reliable multiprocessing with Pool.imap and automatic fallback to
sequential processing if multiprocessing fails.
"""
def __init__(self, config: CellSegmenterConfig) -> None:
    """Store the configuration and prepare the segmenter for inference.

    Sets the currently hard-coded runtime options, selects the torch device
    and then runs the setup steps in a fixed order: WSI, patch configuration,
    checkpoint download, model loading and inference transforms.

    Args:
        config (CellSegmenterConfig): Segmentation settings. ``config.gpu`` is
            used as the CUDA device index when CUDA is available.
    """
    self.config = config

    # TODO: Make this configurable from config
    self.geojson = True
    self.batch_size = 64
    self.mixed_precision = False
    # TODO: ---

    # Fall back to the CPU when no CUDA device is available
    if torch.cuda.is_available():
        self.device = f"cuda:{self.config.gpu}"
    else:
        self.device = "cpu"

    # The order matters: the model loader reads the WSI metadata and the
    # checkpoint directory that the earlier steps provide
    self._setup_wsi()
    self._setup_patch_config()
    self._download_models()
    self._load_model()
    self._load_inference_transforms()
def set_wsi(self, wsi_path: str | Path, patched_slide_path: str | Path) -> None:
    """Point the segmenter at another slide and its patch directory.

    Both paths are stored on the configuration as ``Path`` objects, then the
    WSI object and the patch configuration are rebuilt. The model itself is
    not reloaded by this method.

    Args:
        wsi_path (str | Path): Path of the whole slide image.
        patched_slide_path (str | Path): Directory with the extracted patches.
    """
    cfg = self.config
    cfg.wsi_path = Path(wsi_path)
    cfg.patched_slide_path = Path(patched_slide_path)
    logger.info("WSI and patched slide paths set")

    # Rebuild everything that is derived from the two paths
    self._setup_wsi()
    self._setup_patch_config()
def _setup_wsi(self) -> None:
    """Create the ``WSI`` object for the slide referenced by the config.

    The file stem of ``config.wsi_path`` is used both as slide name and as
    patient identifier. Slide dimensions cached for a previously configured
    slide are discarded, so the bounds check in
    ``_are_contour_coordinates_valid`` re-reads them for the new slide.
    """
    logger.info("Processing single WSI file")
    slide_path = Path(self.config.wsi_path)
    slide_name = slide_path.stem

    self.wsi = WSI(
        name=slide_name,
        patient=slide_name,
        slide_path=slide_path,
        patched_slide_path=self.config.patched_slide_path,
    )

    # The cached width/height belong to the slide configured before; keeping
    # them would validate the new slide against stale bounds
    for cached_attr in ("_wsi_width", "_wsi_height"):
        if hasattr(self, cached_attr):
            delattr(self, cached_attr)
def _setup_patch_config(self) -> None:
    """Build ``self.patch_config`` from the patch extraction metadata.

    Reads ``metadata.yaml`` from the patched slide directory and uses its
    ``patch_size``, ``patch_overlap`` and ``magnification`` entries, falling
    back to 1024, 64 and 40.0 when an entry (or the whole content) is missing.

    Raises:
        FileNotFoundError: If ``metadata.yaml`` does not exist.
        ValueError: If the metadata is not a YAML mapping, or if the
            ``cellpose_sam`` model is configured with a patch size other
            than 256.
    """
    metadata_path = Path(self.config.patched_slide_path) / "metadata.yaml"
    if not metadata_path.exists():
        raise FileNotFoundError(f"Metadata file not found: {metadata_path}")

    # Load and parse the YAML file
    with open(metadata_path, "r", encoding="utf-8") as file:
        metadata = yaml.safe_load(file)

    # An empty file parses to None; treat it like a file without entries so
    # that the documented defaults apply instead of crashing on ``.get``
    if metadata is None:
        metadata = {}
    elif not isinstance(metadata, dict):
        raise ValueError(
            f"Metadata file must contain a YAML mapping: {metadata_path}"
        )

    # Extract the patch configuration values
    # Use defaults if the values are not found in the metadata
    patch_size = metadata.get("patch_size", 1024)
    patch_overlap = metadata.get("patch_overlap", 64)
    target_mag = metadata.get("magnification", 40.0)

    # Create a PatchExtractorConfig with the extracted values
    self.patch_config = PatchExtractorConfig(
        wsi_path=self.config.wsi_path,
        patch_size=patch_size,
        patch_overlap=patch_overlap,
        target_mag=target_mag,
        output_path=self.config.patched_slide_path,
    )

    # Validate patch size for CellposeSAM model
    if self.config.model == "cellpose_sam" and patch_size != 256:
        logger.error(
            f"CellposeSAM model requires patches of size 256, but found patch size {patch_size}. "
            f"Please re-extract patches with patch_size=256."
        )
        raise ValueError(
            f"CellposeSAM model requires patches of size 256, but found patch size {patch_size}"
        )

    logger.info(f"Patch configuration: size={patch_size}, overlap={patch_overlap}")
def _download_models(self) -> None:
    """Download model checkpoints if they don't exist.

    The checkpoints live in a ``checkpoints`` directory two levels above this
    file; the directory is created if needed and stored in
    ``self.checkpoints_dir``. Every known checkpoint whose name contains
    ``config.model`` is fetched unless the file already exists. Nothing is
    downloaded for ``cellpose_sam``.

    Downloads are written to a temporary ``.part`` file and only renamed to
    the final name once complete, so an interrupted transfer never leaves a
    truncated ``.pth`` that a later run would mistake for a valid checkpoint.
    Download errors are logged, not raised; a missing checkpoint is reported
    by ``_load_model``.
    """
    # Checkpoints are kept in the parent of the parent directory of this script
    checkpoints_dir = Path(__file__).resolve().parent.parent / "checkpoints"
    checkpoints_dir.mkdir(exist_ok=True, parents=True)
    self.checkpoints_dir = checkpoints_dir

    # Skip downloading for CellposeSAM as it handles its own model downloads
    if self.config.model == "cellpose_sam":
        logger.info("CellposeSAM handles its own model downloads")
        return

    # Define model URLs and their local paths
    model_urls = {
        "cellvit_20": "https://drive.usercontent.google.com/download?id=1wP4WhHLNwyJv97AK42pWK8kPoWlrqi30&export=download&authuser=0&confirm=t&uuid=f3ea9433-f877-47d3-8a8c-b50234c4b085&at=AN8xHopWwiqwCFH-QzM8xrPx4JKY:1750773930681",
        "cellvit_40": "https://drive.usercontent.google.com/download?id=1MvRKNzDW2eHbQb5rAgTEp6s2zAXHixRV&export=download&authuser=0&confirm=t&uuid=549a273b-ba4b-43d4-b113-bdcb311b8f5f&at=AN8xHoqzbI-y5SjVdwrtaHnFpnEd:1750773931839",
        "hovernet": "https://drive.usercontent.google.com/download?id=1vHOsIASmAmOmmmFXeE4C2CA6-6LOMiCh&export=download&confirm=t&uuid=a9470991-3391-4c1d-bfb3-17c95fb3fbc6",
    }

    # Check and download each model if not present
    for model_name, url in model_urls.items():
        # Substring match: "cellvit" selects both the 20x and the 40x checkpoint
        if self.config.model not in model_name:
            continue

        model_path = checkpoints_dir / f"{model_name}.pth"
        if model_path.exists():
            logger.info(f"{model_name} checkpoint already exists at {model_path}")
            continue

        logger.info(f"Downloading {model_name} checkpoint...")
        partial_path = model_path.with_name(f"{model_path.name}.part")
        try:
            # Create a request with a User-Agent to avoid some download restrictions
            req = urllib.request.Request(url, headers={"User-Agent": "Mozilla/5.0"})
            with (
                urllib.request.urlopen(req) as response,
                open(partial_path, "wb") as out_file,
            ):
                shutil.copyfileobj(response, out_file)
            # Publish the checkpoint only after the transfer finished
            partial_path.replace(model_path)
            logger.info(f"Downloaded {model_name} checkpoint to {model_path}")
        except Exception as e:
            # Best effort by design: log and continue, but never keep a
            # partially written file around
            partial_path.unlink(missing_ok=True)
            logger.error(f"Error downloading {model_name} checkpoint: {e}")
def _load_model(self) -> None:
    """Instantiate the configured model, load its weights and move it to the device.

    ``cellvit`` uses the 20x checkpoint when the WSI magnification equals 20
    and the 40x checkpoint otherwise; ``cellpose_sam`` is created without a
    checkpoint file; every other model name maps to ``<model>.pth``. Besides
    ``self.model`` this sets ``self.nuclei_types`` and, for ``cellvit``,
    ``self.run_conf``.

    Raises:
        FileNotFoundError: If the required checkpoint file is missing.
        NotImplementedError: If ``config.model`` is not a supported model.
    """
    # Make sure the checkpoint directory is known
    if not hasattr(self, "checkpoints_dir"):
        self._download_models()

    model_name = self.config.model

    # Resolve the checkpoint file for the configured model
    if model_name == "cellpose_sam":
        # CellposeSAM doesn't use a checkpoint file, it downloads models automatically
        model_path = None
    elif model_name == "cellvit":
        is_20x = self.wsi.metadata["magnification"] == 20
        checkpoint_file = "cellvit_20.pth" if is_20x else "cellvit_40.pth"
        model_path = self.checkpoints_dir / checkpoint_file
    else:
        model_path = self.checkpoints_dir / f"{model_name}.pth"

    # Load the checkpoint when one is expected
    if model_path is None:
        model_checkpoint = None
        logger.info("Initializing CellposeSAM model (no checkpoint needed)")
    else:
        if not model_path.exists():
            raise FileNotFoundError(f"Model checkpoint not found: {model_path}")
        logger.info(f"Loading model: {model_path}")
        # NOTE(review): torch.load unpickles the downloaded file - only use
        # checkpoints from a trusted source
        model_checkpoint = torch.load(model_path, map_location="cpu")

    # Class-id mapping of the predicted nuclei types
    if model_name in ("cellvit", "hovernet"):
        type_names = (
            "Background",
            "Neoplastic",
            "Inflammatory",
            "Connective",
            "Dead",
            "Epithelial",
        )
    else:
        type_names = ("Background", "Cell")
    self.nuclei_types = {name: class_id for class_id, name in enumerate(type_names)}

    # Build the network and restore its weights
    if model_name == "cellvit":
        assert model_checkpoint is not None, (
            "Model checkpoint is required for cellvit"
        )
        self.run_conf = unflatten_dict(model_checkpoint["config"], ".")
        data_conf = self.run_conf["data"]
        model_conf = self.run_conf["model"]
        self.model = CellViTSAM(
            model_path=None,
            num_nuclei_classes=data_conf["num_nuclei_classes"],
            num_tissue_classes=data_conf["num_tissue_classes"],
            vit_structure=model_conf["backbone"],
            regression_loss=model_conf.get("regression_loss", False),
        )
        load_report = self.model.load_state_dict(model_checkpoint["model_state_dict"])
        logger.info(load_report)
    elif model_name == "hovernet":
        assert model_checkpoint is not None, (
            "Model checkpoint is required for hovernet"
        )
        self.model = HoVerNet()
        load_report = self.model.load_state_dict(model_checkpoint)
        logger.info(load_report)
    elif model_name == "cellpose_sam":
        # Initialize CellposeSAM with device parameter
        self.model = CellposeSAM(
            pretrained_model="cpsam", device=torch.device(self.device)
        )
        logger.info("CellposeSAM model initialized successfully")
    else:
        raise NotImplementedError(
            f"Model {self.config.model} is not implemented or not supported."
        )

    self.model.eval()
    self.model.to(self.device)
def _get_wsi_dimensions(self) -> tuple[int, int]:
    """Get the dimensions (width, height) of the WSI at level 0.

    The slide is opened with OpenSlide just for this query and is closed
    again even if reading the dimensions fails.

    Returns:
        tuple[int, int]: WSI dimensions as (width, height)

    Raises:
        ImportError: If OpenSlide is not installed.
        Exception: Any error raised while opening or reading the slide is
            logged and re-raised.
    """
    try:
        # Import openslide here to avoid dependency issues if not installed
        import openslide

        # Open the WSI to get its dimensions
        slide = openslide.OpenSlide(str(self.config.wsi_path))
        try:
            # dimensions returns (width, height) at level 0
            wsi_width, wsi_height = slide.dimensions
        finally:
            # Release the file handle on the error path as well
            slide.close()
        return wsi_width, wsi_height
    except ImportError:
        logger.error(
            "OpenSlide is required to validate contour coordinates. Install python-openslide."
        )
        raise
    except Exception as e:
        logger.error(f"Failed to get WSI dimensions: {e}")
        raise
def _are_contour_coordinates_valid(
    self, contour_global: np.ndarray[Any, Any]
) -> bool:
    """Tell whether every contour point lies inside the WSI.

    The WSI width and height are looked up once via ``_get_wsi_dimensions``
    and kept on the instance as ``_wsi_width`` / ``_wsi_height``.

    Args:
        contour_global (np.ndarray): Global contour coordinates, shape (N, 2)
            where each point is [x, y] (x horizontal, y vertical)

    Returns:
        bool: True if all points satisfy ``0 <= x < width`` and
        ``0 <= y < height``, False otherwise
    """
    dimensions_cached = hasattr(self, "_wsi_width") and hasattr(self, "_wsi_height")
    if not dimensions_cached:
        # Cache WSI dimensions for performance
        self._wsi_width, self._wsi_height = self._get_wsi_dimensions()

    xs, ys = contour_global[:, 0], contour_global[:, 1]
    inside_x = (xs >= 0) & (xs < self._wsi_width)
    inside_y = (ys >= 0) & (ys < self._wsi_height)
    return bool(np.all(inside_x) and np.all(inside_y))
def process(self) -> None:
    """Segment all patches of the configured WSI and store the detected cells.

    Workflow:
        1. Run the model batch-wise over every patch of the patched slide.
        2. Shift each non-background cell from patch to global coordinates
           and skip cells whose contour leaves the WSI bounds.
        3. Remove cells detected multiple times in overlapping patch margins.
        4. Write ``cells.json`` and ``cell_detection.json`` (plus the matching
           ``.geojson`` files when ``self.geojson`` is set) to
           ``<patched_slide_path>/cell_detection/<model>`` and log per-type
           cell counts.

    Raises:
        ValueError: If the WSI has no patched slide path.
    """
    logger.info(f"Processing WSI: {self.wsi.name}")

    # Fail fast: without the patched slide there is neither input nor an
    # output location
    if self.wsi.patched_slide_path is None:
        raise ValueError(
            "Patched slide path is not set. Cannot process WSI without patched slide path."
        )

    wsi_inference_dataset = PatchedWSIInference(
        self.wsi, transform=self.inference_transforms
    )
    # Patches are loaded in the main process; DataLoader workers are not used
    wsi_inference_dataloader = DataLoader(
        dataset=wsi_inference_dataset,
        batch_size=self.batch_size,
        shuffle=False,
        collate_fn=wsi_inference_dataset.collate_batch,
        pin_memory=False,
    )

    outdir = Path(self.wsi.patched_slide_path) / "cell_detection" / self.config.model
    outdir.mkdir(exist_ok=True, parents=True)

    # Log dataset info for debugging
    logger.info(f"Total patches in dataset: {len(wsi_inference_dataset)}")
    logger.info(f"WSI dimensions: {self._get_wsi_dimensions()}")
    logger.info(f"Batch size: {self.batch_size}")

    cell_dict_wsi: list[dict[str, Any]] = []  # for storing all cell information
    cell_dict_detection: list[dict[str, Any]] = []  # for storing only the centroids
    processed_patches: list[str] = []
    cell_count: int = 0

    # The progress bar counts patches (not batches) and is closed on exit
    with torch.no_grad(), tqdm(total=len(wsi_inference_dataset)) as pbar:
        for batch in wsi_inference_dataloader:
            patches = batch[0].to(self.device)
            metadata = batch[1]

            if self.mixed_precision:
                with torch.autocast(device_type="cuda", dtype=torch.float16):
                    predictions = self.model.forward(patches)
            else:
                predictions = self.model.forward(patches)

            # TODO: This might become custom for every model
            instance_types = self.get_cell_predictions(
                predictions, magnification=self.wsi.metadata["magnification"]
            )

            # unpack each patch from batch
            for patch_instance_types, patch_metadata in zip(instance_types, metadata):
                pbar.update(1)
                patch_row = patch_metadata["row"]
                patch_col = patch_metadata["col"]

                # Log patch being processed
                logger.debug(f"Processing patch [{patch_row}, {patch_col}]")
                processed_patches.append(f"{patch_row}_{patch_col}")

                # calculate coordinate on highest magnifications
                wsi_scaling_factor = self.wsi.metadata["downsampling"]
                patch_size = self.wsi.metadata["patch_size"]
                x_global = int(
                    patch_row * patch_size * wsi_scaling_factor
                    - (patch_row + 0.5) * self.patch_config.patch_overlap
                )
                y_global = int(
                    patch_col * patch_size * wsi_scaling_factor
                    - (patch_col + 0.5) * self.patch_config.patch_overlap
                )
                # The offset is the same for every cell of the patch.
                # Centroids and contours get the flipped offset, bounding
                # boxes the unflipped one.
                offset_global = np.array([x_global, y_global])
                flipped_offset = np.flip(offset_global)

                # extract cell information
                for cell in patch_instance_types.values():
                    if cell["type"] == self.nuclei_types["Background"]:
                        continue

                    centroid_global = cell["centroid"] + flipped_offset
                    contour_global = cell["contour"] + flipped_offset
                    bbox_global = cell["bbox"] + offset_global

                    # Check if contour coordinates are within WSI limits
                    if not self._are_contour_coordinates_valid(contour_global):
                        # Debug coordinate calculation for problematic cases
                        logger.info(
                            f"COORDINATE DEBUG - Patch [{patch_metadata['row']}, {patch_metadata['col']}]: "
                            f"x_global={x_global}, y_global={y_global}, "
                            f"offset_global={offset_global}, "
                            f"flipped_offset={np.flip(offset_global)}, "
                            f"cell_contour_range=x[{cell['contour'][:, 0].min():.1f}, {cell['contour'][:, 0].max():.1f}], "
                            f"y[{cell['contour'][:, 1].min():.1f}, {cell['contour'][:, 1].max():.1f}], "
                            f"contour_global_range=x[{contour_global[:, 0].min():.1f}, {contour_global[:, 0].max():.1f}], "
                            f"y[{contour_global[:, 1].min():.1f}, {contour_global[:, 1].max():.1f}]"
                        )
                        logger.warning(
                            f"Cell contour coordinates exceed WSI bounds for patch "
                            f"[{patch_metadata['row']}, {patch_metadata['col']}]. Skipping cell. "
                            f"WSI dims: [{self._wsi_width}, {self._wsi_height}], "
                            f"Contour range: x=[{contour_global[:, 0].min():.1f}, {contour_global[:, 0].max():.1f}], "
                            f"y=[{contour_global[:, 1].min():.1f}, {contour_global[:, 1].max():.1f}], "
                            f"Offset: [{x_global}, {y_global}], Patch size: {patch_size}, Scaling: {wsi_scaling_factor}"
                        )
                        continue

                    cell_dict: dict[str, Any] = {
                        "bbox": bbox_global.tolist(),
                        "centroid": centroid_global.tolist(),
                        "contour": contour_global.tolist(),
                        "type_prob": cell["type_prob"],
                        "type": cell["type"],
                        "patch_coordinates": [patch_row, patch_col],
                        "cell_status": get_cell_position_marging(
                            cell["bbox"],
                            self.patch_config.patch_size,
                            int(self.patch_config.patch_overlap),
                        ),
                        "offset_global": offset_global.tolist(),
                    }
                    cell_detection: dict[str, Any] = {
                        "bbox": bbox_global.tolist(),
                        "centroid": centroid_global.tolist(),
                        "type": cell["type"],
                    }

                    # Cells touching the patch border carry the neighbouring
                    # patches that may contain the same cell
                    touches_border = (
                        np.max(cell["bbox"]) == self.patch_config.patch_size
                        or np.min(cell["bbox"]) == 0
                    )
                    if touches_border:
                        position = get_cell_position(
                            cell["bbox"], self.patch_config.patch_size
                        )
                        cell_dict["edge_position"] = True
                        cell_dict["edge_information"] = {
                            "position": position,
                            "edge_patches": get_edge_patch(
                                position, patch_row, patch_col
                            ),
                        }
                    else:
                        cell_dict["edge_position"] = False

                    cell_dict_wsi.append(cell_dict)
                    cell_dict_detection.append(cell_detection)
                    cell_count += 1

                # Refresh the displayed cell count once per patch
                pbar.set_postfix(Cells=cell_count)  # type: ignore

    # post processing
    logger.info(f"Detected cells before cleaning: {len(cell_dict_wsi)}")
    if len(cell_dict_wsi) > 0:
        keep_idx = self.post_process_edge_cells(cell_list=cell_dict_wsi)
        cell_dict_wsi = [
            {"cell_id": i, **cell_dict_wsi[idx_c]} for i, idx_c in enumerate(keep_idx)
        ]
        cell_dict_detection = [
            {"cell_id": i, **cell_dict_detection[idx_c]}
            for i, idx_c in enumerate(keep_idx)
        ]
        logger.info(f"Detected cells after cleaning: {len(keep_idx)}")
    else:
        logger.warning("No cells detected.")

    logger.info(
        f"Processed all patches. Storing final results: {str(outdir / 'cells.json')} and cell_detection.json"
    )

    _cell_dict_wsi: dict[str, Any] = {
        "wsi_metadata": self.wsi.metadata,
        "processed_patches": processed_patches,
        "type_map": self.nuclei_types,
        "cells": cell_dict_wsi,
    }
    with open(outdir / "cells.json", "w", encoding="utf-8") as outfile:
        ujson.dump(_cell_dict_wsi, outfile, indent=2)
    if self.geojson:
        logger.info("Converting segmentation to geojson")
        geojson_collection = convert_geojson(
            _cell_dict_wsi["cells"], True, self.config.model
        )
        with open(outdir / "cells.geojson", "w", encoding="utf-8") as outfile:
            ujson.dump(geojson_collection, outfile, indent=2)

    _cell_dict_detection: dict[str, Any] = {
        "wsi_metadata": self.wsi.metadata,
        "processed_patches": processed_patches,
        "type_map": self.nuclei_types,
        "cells": cell_dict_detection,
    }
    with open(outdir / "cell_detection.json", "w", encoding="utf-8") as outfile:
        ujson.dump(_cell_dict_detection, outfile, indent=2)
    if self.geojson:
        logger.info("Converting detection to geojson")
        # Built from the full cell records; only the second argument differs
        # from the segmentation export (presumably it toggles contour output -
        # confirm in convert_geojson)
        geojson_collection = convert_geojson(
            _cell_dict_wsi["cells"], False, self.config.model
        )
        with open(outdir / "cell_detection.geojson", "w", encoding="utf-8") as outfile:
            ujson.dump(geojson_collection, outfile, indent=2)

    # Generate statistics only if cells were detected
    if len(_cell_dict_wsi["cells"]) > 0:
        cell_stats_df = pd.DataFrame(_cell_dict_wsi["cells"])
        cell_stats = dict(cell_stats_df.value_counts("type"))  # type: ignore
        nuclei_types_inverse = {v: k for k, v in self.nuclei_types.items()}
        verbose_stats = {nuclei_types_inverse[k]: v for k, v in cell_stats.items()}  # type: ignore
        logger.info(f"Finished with cell detection for WSI {self.wsi.name}")
        logger.info("Stats:")
        logger.info(f"{verbose_stats}")
    else:
        logger.info(f"Finished with cell detection for WSI {self.wsi.name}")
        logger.info("Stats: No cells detected")
def get_cell_predictions(
    self, predictions: dict[str, torch.Tensor], magnification: float | int = 40
) -> list[dict[np.int32, dict[str, Any]]] | list[dict[int, dict[str, Any]]]:
    """Turn raw model output into per-patch cell instances.

    For ``hovernet`` and ``cellvit`` the ``nuclei_binary_map`` and
    ``nuclei_type_map`` entries of ``predictions`` are replaced in place by
    their softmax over dimension 1 before the instance map is calculated.
    For ``cellpose_sam`` the predictions are passed on unchanged.

    Args:
        predictions (dict[str, torch.Tensor]): Model output
        magnification (float, optional): Magnification of the WSI. Defaults to 40.

    Returns:
        list: Second element returned by the model's
        ``calculate_instance_map`` (the instance types)

    Raises:
        NotImplementedError: If ``config.model`` is not a supported model.
    """
    model_name = self.config.model
    if model_name not in ("hovernet", "cellvit", "cellpose_sam"):
        raise NotImplementedError(f"Model {self.config.model} not implemented")

    if model_name != "cellpose_sam":
        # Convert logits to probabilities along the class dimension:
        # (batch_size, 2, H, W) and (batch_size, num_nuclei_classes, H, W)
        for map_key in ("nuclei_binary_map", "nuclei_type_map"):
            predictions[map_key] = F.softmax(predictions[map_key], dim=1)

    # Only the instance types are needed, the instance map is discarded
    _, instance_types = self.model.calculate_instance_map(
        predictions, magnification=magnification
    )
    return instance_types
def post_process_edge_cells(self, cell_list: list[dict[str, Any]]) -> list[Any]:
    """Clean duplicated cells with the CellPostProcessor.

    Args:
        cell_list (List[dict]): List with cell-dictionaries. Required keys:
            * bbox
            * centroid
            * contour
            * type_prob
            * type
            * patch_coordinates
            * cell_status
            * offset_global

    Returns:
        List[int]: Index values of the post-processed DataFrame, i.e. the
        cells that should be kept
    """
    kept_cells = CellPostProcessor(cell_list, logger).post_process_cells()
    return list(kept_cells.index.values)  # type: ignore
def get_cell_position(bbox: np.ndarray[Any, Any], patch_size: int = 1024) -> list[int]:
    """Report which patch borders a cell's bounding box touches.

    The bounding box is indexed as a 2x2 array: row 0 holds the upper/left
    corner, row 1 the lower/right corner; column 0 is the height axis and
    column 1 the width axis.

    Args:
        bbox (np.ndarray): Bounding-Box of cell
        patch_size (int, optional): Patch-size. Defaults to 1024.

    Returns:
        List[int]: Four flags ``[top, right, down, left]``; a flag is 1 if the
        box starts at 0 (top/left) or ends at ``patch_size`` (down/right)
    """
    touches_top = bbox[0, 0] == 0
    touches_left = bbox[0, 1] == 0
    touches_down = bbox[1, 0] == patch_size
    touches_right = bbox[1, 1] == patch_size

    # Clockwise order, starting at the top
    clockwise = (touches_top, touches_right, touches_down, touches_left)
    return [int(bool(flag)) for flag in clockwise]
def get_cell_position_marging(
    bbox: np.ndarray[Any, Any], patch_size: int = 1024, margin: int = 64
) -> Literal[0, 1, 2, 3, 4, 5, 6, 7, 8]:
    """Classify where a cell lies relative to the margin of its patch.

    Status codes are assigned clockwise, starting from top left:
    top left = 1, top = 2, top right = 3, right = 4, bottom right = 5,
    bottom = 6, bottom left = 7, left = 8. A cell that does not reach into
    the margin has status 0 (mid).

    Args:
        bbox (np.ndarray): Bounding Box of cell
        patch_size (int, optional): Patch-Size. Defaults to 1024.
        margin (int, optional): Margin-Size. Defaults to 64.

    Returns:
        int: Cell Status
    """
    far_edge = patch_size - margin

    # Completely inside the inner area: mid cell
    if not (np.max(bbox) > far_edge or np.min(bbox) < margin):
        return 0

    if bbox[0, 0] < margin:
        # top left, top or top right
        if bbox[0, 1] < margin:
            return 1
        if bbox[1, 1] > far_edge:
            return 3
        return 2

    if bbox[1, 1] > far_edge:
        # right or bottom right (top right was handled above)
        return 5 if bbox[1, 0] > far_edge else 4

    if bbox[1, 0] > far_edge:
        # bottom or bottom left
        return 7 if bbox[0, 1] < margin else 6

    if bbox[0, 1] < margin:
        # only plain left remains
        return 8

    return 0
def get_edge_patch(position: list[int], row: int, col: int):
    """List the neighbouring patches that share a touched border with a patch.

    Args:
        position (list[int]): Border flags ``[top, right, down, left]`` as
            produced by ``get_cell_position``
        row (int): Row index of the patch containing the cell
        col (int): Column index of the patch containing the cell

    Returns:
        list[list[int]] | None: ``[row, col]`` pairs of the neighbouring
        patches for a single border or for two adjacent borders. Any other
        flag combination (no border, opposite borders, three or more
        borders) yields ``None``.
    """
    # row starting on bottom or on top?
    # Each entry: border flags -> (row, col) offsets of the neighbours
    neighbour_offsets = (
        ([1, 0, 0, 0], ((-1, 0),)),  # top
        ([1, 1, 0, 0], ((-1, 0), (-1, 1), (0, 1))),  # top and right
        ([0, 1, 0, 0], ((0, 1),)),  # right
        ([0, 1, 1, 0], ((0, 1), (1, 1), (1, 0))),  # right and down
        ([0, 0, 1, 0], ((1, 0),)),  # down
        ([0, 0, 1, 1], ((1, 0), (1, -1), (0, -1))),  # down and left
        ([0, 0, 0, 1], ((0, -1),)),  # left
        ([1, 0, 0, 1], ((0, -1), (-1, -1), (-1, 0))),  # left and top
    )
    for pattern, offsets in neighbour_offsets:
        if position == pattern:
            return [[row + d_row, col + d_col] for d_row, d_col in offsets]
    return None
[docs]class CellPostProcessor:
def __init__(
    self,
    cell_list: list[dict[str, Any]],
    logger: logging.Logger,
) -> None:
    """Post-Processing a list of cells from one WSI

    The cells are loaded into a DataFrame, their ``patch_coordinates`` pair
    is turned into a ``"<first>_<second>"`` string key, and the rows are
    split by ``cell_status`` into mid cells (status 0) and margin cells
    (any other status).

    Args:
        cell_list (List[dict]): List with cell-dictionaries. Required keys:
            * bbox
            * centroid
            * contour
            * type_prob
            * type
            * patch_coordinates
            * cell_status
            * offset_global
        logger (logging.Logger): Logger
    """
    self.logger = logger
    self.logger.info("Initializing Cell-Postprocessor")

    self.cell_df: pd.DataFrame = pd.DataFrame(cell_list)
    self.logger.info(f"DataFrame has {len(self.cell_df)} rows, using vectorized processing")

    # Split the coordinate pairs into two columns and join them column-wise
    # into one string key per cell
    pair_columns = pd.DataFrame(
        self.cell_df["patch_coordinates"].tolist(),
        columns=["x", "y"],
        index=self.cell_df.index,
    )
    first_part = pair_columns["x"].astype(str)
    second_part = pair_columns["y"].astype(str)
    self.cell_df["patch_coordinates"] = first_part.str.cat(second_part, sep="_")

    is_mid = self.cell_df["cell_status"] == 0  # type: ignore
    self.mid_cells = self.cell_df[is_mid]  # cells in the mid  # type: ignore
    # cells either touching the border or lying in the margin
    self.cell_df_margin = self.cell_df[~is_mid]  # type: ignore
def post_process_cells(self) -> pd.DataFrame:
    """Run the complete post-processing and return the surviving cells.

    Margin cells are first reduced to those without a counterpart in a
    neighbouring patch, then overlapping duplicates are removed, and the
    remainder is combined with the mid cells.

    Returns:
        pd.DataFrame: DataFrame with post-processed and cleaned cells,
        sorted by index
    """
    self.logger.info("Finding edge-cells for merging")
    edge_candidates = self._clean_edge_cells()

    self.logger.info("Removal of cells detected multiple times")
    deduplicated_edges = self._remove_overlap(edge_candidates)

    # merge with mid cells
    combined = pd.concat([self.mid_cells, deduplicated_edges])  # type: ignore
    return combined.sort_index()  # type: ignore
def _clean_edge_cells(self) -> pd.DataFrame:
    """Create a DataFrame that just contains all margin cells (cells inside the margin, not touching the border)
    and border/edge cells (touching border) with no overlapping equivalent (e.g, if patch has no neighbour)

    An edge cell is kept when the first patch listed in its
    ``edge_information["edge_patches"]`` is not among the patches that
    contributed margin cells. Edge cells with missing or malformed edge
    information are skipped with a warning.

    Returns:
        pd.DataFrame: Cleaned DataFrame, sorted by index
    """
    margin_df = self.cell_df_margin
    margin_cells = margin_df[  # type: ignore
        margin_df["edge_position"] == 0  # type: ignore
    ]  # cells at the margin, but not touching the border
    edge_cells = margin_df[  # type: ignore
        margin_df["edge_position"] == 1  # type: ignore
    ]  # cells touching the border

    # A set keeps the per-cell membership test O(1)
    existing_patches = set(margin_df["patch_coordinates"].to_list())  # type: ignore

    # Indices of cells touching the border without an overlap from another
    # patch; the rows are selected in one step after the loop instead of
    # growing a DataFrame row by row
    unique_edge_indices: list[Any] = []

    # Add progress bar for processing edge cells
    edge_cells_pbar = tqdm(  # type: ignore
        edge_cells.iterrows(), total=len(edge_cells), desc="Processing edge cells"  # type: ignore
    )
    for idx, cell_info in edge_cells_pbar:  # type: ignore
        edge_information = cell_info.get("edge_information")  # type: ignore
        try:
            edge_patch = edge_information["edge_patches"][0]  # type: ignore
            edge_patch = f"{edge_patch[0]}_{edge_patch[1]}"
        except (TypeError, IndexError, KeyError) as e:
            self.logger.warning(
                f"Skipping edge cell {idx} due to invalid edge_patches data: {e}. "
                f"Edge information: {edge_information}"
            )
            continue
        if edge_patch not in existing_patches:
            unique_edge_indices.append(idx)

    edge_cells_unique = edge_cells.loc[unique_edge_indices]  # type: ignore
    cleaned_edge_cells = pd.concat([margin_cells, edge_cells_unique])  # type: ignore
    return cleaned_edge_cells.sort_index()  # type: ignore
def _remove_overlap(self, cleaned_edge_cells: pd.DataFrame) -> pd.DataFrame:
    """Remove overlapping cells from provided DataFrame

    The cell contours are turned into polygons and queried through an
    STRtree. Two cells count as overlapping when their intersection covers
    more than 1 % of either polygon's area. The procedure is repeated until
    no overlap is found, at most 20 times.

    NOTE(review): for a group of overlapping cells the survivor is the
    largest of the *neighbours*; the query cell itself is not among the
    candidates. Confirm that this is the intended merging strategy.

    Args:
        cleaned_edge_cells (pd.DataFrame): DataFrame that should be cleaned

    Returns:
        pd.DataFrame: Cleaned DataFrame, sorted by index
    """
    max_iterations = 20
    min_overlap_ratio = 0.01

    merged_cells = cleaned_edge_cells
    for iteration in range(max_iterations):
        # Create list of polygons and corresponding indices
        poly_list: list[Polygon] = []
        poly_indices: list[Any] = []
        for idx, cell_info in merged_cells.iterrows():  # type: ignore
            poly: Polygon = Polygon(cell_info["contour"])  # type: ignore
            if not poly.is_valid:
                self.logger.debug("Found invalid polygon - Fixing with buffer 0")
                repaired = poly.buffer(0)
                if isinstance(repaired, MultiPolygon):
                    # Keep the largest part of a polygon that fell apart
                    poly = max(repaired.geoms, key=lambda part: part.area)
                else:
                    poly = repaired
            poly_list.append(poly)
            poly_indices.append(idx)

        # Use STRtree with explicit indices rather than storing uid on the geometry
        tree = strtree.STRtree(poly_list)

        merged_idx: deque[Any] = deque()
        iterated_cells: set[Any] = set()
        overlaps = 0

        for query_idx, query_poly in zip(poly_indices, poly_list):
            if query_idx in iterated_cells:
                continue

            # Positions (into poly_list) of the intersecting polygons; the
            # query polygon itself is part of the result
            intersected_indices = tree.query(query_poly, predicate="intersects")

            if len(intersected_indices) > 1:
                # we have at least one intersection with another cell
                submergers: list[Polygon] = []  # all cells that overlap with query
                submerger_indices: list[Any] = []  # corresponding indices
                query_area = query_poly.area

                for j in intersected_indices:
                    inter_idx = poly_indices[j]
                    if inter_idx == query_idx or inter_idx in iterated_cells:
                        continue
                    inter_poly = poly_list[j]
                    intersection_area = query_poly.intersection(inter_poly).area
                    inter_area = inter_poly.area
                    # Degenerate (zero-area) polygons cannot be used as a
                    # denominator; their ratio is treated as "no overlap"
                    overlaps_query = (
                        query_area > 0
                        and intersection_area / query_area > min_overlap_ratio
                    )
                    overlaps_other = (
                        inter_area > 0
                        and intersection_area / inter_area > min_overlap_ratio
                    )
                    if overlaps_query or overlaps_other:
                        overlaps = overlaps + 1
                        submergers.append(inter_poly)
                        submerger_indices.append(inter_idx)
                        iterated_cells.add(inter_idx)

                if not submergers:
                    # some cells are touching, but not overlapping strongly enough
                    merged_idx.append(query_idx)
                else:
                    # merging strategy: take the biggest cell, other merging
                    # strategies needs to get implemented
                    biggest = int(np.argmax([poly.area for poly in submergers]))
                    merged_idx.append(submerger_indices[biggest])
            else:
                # no intersection, just add
                merged_idx.append(query_idx)
            iterated_cells.add(query_idx)

        self.logger.info(
            f"Iteration {iteration}: Found overlap of # cells: {overlaps}"
        )
        if overlaps == 0:
            self.logger.info("Found all overlapping cells")
            break
        if iteration == max_iterations - 1:
            # Last allowed pass and duplicates are still present
            self.logger.info(
                f"Not all doubled cells removed, still {overlaps} to remove. For perfomance issues, we stop iterations now. Please raise an issue in git or increase number of iterations."
            )
        merged_cells = cleaned_edge_cells.loc[
            cleaned_edge_cells.index.isin(merged_idx)  # type: ignore
        ].sort_index()

    return merged_cells.sort_index()  # type: ignore