# Source code for cellmil.data.patch_extractor
# -*- coding: utf-8 -*-
# Patch Extractor Module
#
# References:
# CellViT: Vision Transformers for precise cell segmentation and classification
# Fabian Hörst et al., Medical Image Analysis, 2024
# DOI: https://doi.org/10.1016/j.media.2024.103143
# stdlib
import json
import multiprocessing
import os
import random
import time
from multiprocessing import Process
from pathlib import Path
from shutil import rmtree
from typing import Any, List, Type, Union

# third-party
import numpy as np
from natsort import natsorted
from openslide import OpenSlide
from PIL import Image
from PIL.Image import Image as ImageType
from shapely.geometry import Polygon
from tqdm import tqdm

# local
from cellmil.interfaces import PatchExtractorConfig
from cellmil.utils import logger
from cellmil.utils.tools import end_timer, module_exists, start_timer
from .storage import Storage
from .utils.exceptions import WrongParameterException
from .utils.patch_util import (
    DeepZoomGeneratorOS,
    calculate_background_ratio,
    compute_interesting_patches,
    generate_thumbnails,
    get_files_from_dir,
    get_intersected_labels,
    is_power_of_two,
    pad_tile,
    patch_to_tile_size,
    target_mag_to_downsample,
)
def queue_worker(q: Any, store: Storage, processed_count: Any) -> None:
    """Queue worker that saves patches with metadata to disk.

    Consumes tuples from the queue until a ``None`` sentinel is received.
    The first element of each tuple is the patch (uint8 array) and the last
    element is the target tile size in pixels; if the patch size differs
    from the target, the patch is rescaled before saving.

    Args:
        q (Any): Queue for input (``None`` terminates the worker).
        store (Storage): Storage object used to write elements to disk.
        processed_count (Any): ``multiprocessing.Value`` counter for tqdm,
            shared between processes.
    """
    while True:
        item: Any = q.get()
        if item is None:
            break
        # check if size matches, otherwise rescale in multiprocessing
        # TODO: check for context patches and masks!
        item = list(item)
        tile = item[0]
        tile_size = tile.shape[0]
        target_tile_size = item[-1]
        if tile_size != target_tile_size:
            tile = Image.fromarray(tile)
            if tile_size > target_tile_size:
                # downscale in place (thumbnail never enlarges)
                tile.thumbnail(
                    (target_tile_size, target_tile_size),
                    getattr(Image, "Resampling", Image).LANCZOS,  # type: ignore[call-arg]
                )
            else:
                # upscale to the exact target size
                tile = tile.resize(
                    (target_tile_size, target_tile_size),
                    getattr(Image, "Resampling", Image).LANCZOS,  # type: ignore[call-arg]
                )
            tile = np.array(tile, dtype=np.uint8)
            item[0] = tile
        # drop the trailing target-size marker before storing
        item.pop()
        item = tuple(item)
        store.save_elem_to_disk(item)
        # `Value += 1` is a read-modify-write and is NOT atomic across
        # processes; guard the increment with the Value's own lock.
        with processed_count.get_lock():
            processed_count.value += 1
class PatchExtractor:
    """Class for preparing data from WSI.

    Detects whole-slide images (WSI) below a configured path, computes
    non-background ("interesting") patch coordinates per slide, and extracts
    the patches with a pool of worker processes that store them on disk.
    """

    def __init__(self, config: PatchExtractorConfig) -> None:
        """Initialize the extractor from a configuration object.

        Args:
            config (PatchExtractorConfig): Configuration holding the WSI input
                path, output path, patch size, patch overlap (percentage) and
                target magnification.
        """
        self.config = config
        self.files: list[Path] = []
        self.annotation_files: list[Path] = []
        self.num_files: int = 0
        self.rescaling_factor: float = 1.0
        # TODO: Make this configurable from config
        self.downsample: int = 1
        self.min_intersection_ratio: float = 0.01
        self.store_masks: bool = False
        self.context_scales: list[int] | None = None
        self.masked_otsu: bool = False
        self.label_map: dict[str | int, int] = {"background": 0}
        self.otsu_annotation = None
        self.tissue_annotation_intersection_ratio = 0.01
        self.apply_prefilter = False
        self.normalize_stains: bool = False
        self.processes = 24
        self.save_only_annotated_patches: bool = False
        self.adjust_brightness: bool = False
        self.save_context: bool = self.context_scales is not None
        self.overlapping_labels: bool = False
        # TODO: ---
        # paths
        self.setup_output_path(self.config.output_path)
        self._set_wsi_path(self.config.wsi_path)
        # hardware
        self._set_hardware()
        # convert overlap from percentage to pixels
        self.config.patch_overlap = int(
            np.floor(self.config.patch_size / 2 * self.config.patch_overlap / 100)
        )
        # fixed seed for reproducible sampling
        random.seed(42)
        logger.info(f"Data store directory: {str(self.config.output_path)}")
        logger.info(f"Images found: {self.num_files}")
        logger.info(f"Annotations found: {len(self.annotation_files)}")

    def _set_hardware(self, hardware_selection: str = "cucim") -> None:
        """Either load CuCIM (GPU-accelerated) or OpenSlide.

        Args:
            hardware_selection (str, optional): Specify hardware. Just for
                experiments. Must be either "openslide" or "cucim".
                Defaults to "cucim".
        """
        if (
            module_exists("cucim", error="ignore")
            and hardware_selection.lower() == "cucim"
        ):
            logger.info("Using CuCIM")
            # imported lazily so the module works without cucim installed
            from cucim import CuImage  # type: ignore[import]

            from cellmil.data.cucim_deepzoom import DeepZoomGeneratorCucim

            self.deepzoomgenerator: (
                Type[DeepZoomGeneratorCucim] | Type[DeepZoomGeneratorOS]
            ) = DeepZoomGeneratorCucim
            self.image_loader: Any = CuImage
        else:
            logger.info("Using OpenSlide")
            self.deepzoomgenerator = DeepZoomGeneratorOS
            self.image_loader = OpenSlide

    def _set_wsi_path(self, wsi_path: str | Path) -> None:
        """Set the path to the WSI file(s) and collect matching files.

        Args:
            wsi_path (str | Path): Path to the folder where all WSI are stored
                or path to a single WSI-file.
        """
        if isinstance(wsi_path, str):
            wsi_path = Path(wsi_path)
        # file suffix without the leading dot, e.g. "svs"
        wsi_extension = wsi_path.suffix.lower()[1:]
        self.files = get_files_from_dir(wsi_path, wsi_extension)
        self.num_files = len(self.files)
        # natural sort by filename for a deterministic processing order
        self.files = natsorted(self.files, key=lambda f: f.name)

    def get_patches(self) -> None:
        """Main function to create a dataset. Sample the complete dataset.

        This function acts as an entrypoint to the patch-processing.
        When this function is called, all WSI that have been detected are
        processed. Depending on the selected configuration, either already
        processed WSI will be removed or newly processed. The processed WSI
        are stored in the file `processed.json` in the output-folder.
        """
        # perform logical check
        self._check_patch_params(
            patch_size=self.config.patch_size,
            patch_overlap=int(self.config.patch_overlap),
            downsample=self.downsample,
            min_background_ratio=self.min_intersection_ratio,
        )
        # remove database or check to continue from checkpoint
        self._check_overwrite()
        total_count = 0
        start_time = start_timer()
        for i, wsi_file in enumerate(self.files):
            try:
                # visual separator sized to the terminal (best effort; fails
                # when there is no attached terminal, e.g. in batch jobs)
                logger.info(f"{(os.get_terminal_size()[0] - 33) * '*'}")
            except Exception:
                pass
            logger.info(f"{i + 1}/{len(self.files)}: {wsi_file.name}")
            # prepare wsi, especially find patches
            (
                _,
                (wsi_metadata, mask_images, mask_images_annotations, thumbnails),
                (
                    interesting_coords_wsi,
                    level_wsi,
                    polygons_downsampled_wsi,
                    region_labels_wsi,
                ),
            ) = self._prepare_wsi(wsi_file)
            # setup storage
            store = Storage(
                wsi_name=wsi_file.stem,
                output_path=self.config.output_path,
                metadata=wsi_metadata,
                mask_images=mask_images,
                mask_images_annotations=mask_images_annotations,
                thumbnails=thumbnails,
                store_masks=self.store_masks,
                save_context=self.context_scales is not None,
                context_scales=self.context_scales,
            )
            logger.info("Start extracting patches...")
            patch_count, patch_distribution, patch_result_metadata = self.process_queue(
                batch=interesting_coords_wsi,
                wsi_file=wsi_file,
                wsi_metadata=wsi_metadata,
                level=level_wsi,
                polygons=polygons_downsampled_wsi,
                region_labels=region_labels_wsi,
                store=store,
            )
            if patch_count == 0:
                logger.warning(f"No patches sampled from {wsi_file.name}")
            logger.info(f"Total patches sampled: {patch_count}")
            store.clean_up(patch_distribution, patch_result_metadata)
            total_count += patch_count
        logger.info(f"Patches saved to: {self.config.output_path.resolve()}")
        logger.info(f"Total patches sampled for all WSI: {total_count}")
        end_timer(start_time)

    def _prepare_wsi(
        self, wsi_file: Path
    ) -> tuple[
        tuple[int, int],
        tuple[
            dict[str, Any],
            dict[str, ImageType],
            dict[str, ImageType],
            dict[str, ImageType],
        ],
        tuple[list[Any], Any, list[Polygon], list[str]],
    ]:
        """Prepare a WSI for preprocessing.

        First, some sanity checks are performed and the target level for the
        DeepZoomGenerator is calculated. We are not using OpenSlide's default
        DeepZoomGenerator, but rather one based on the cupy library which is
        much faster (cf. https://github.com/rapidsai/cucim). One core element
        is to find all patches that are non-background patches. For this, a
        tissue mask is generated. At this stage, no patches are extracted!

        Args:
            wsi_file (Path): Path of the WSI file.

        Raises:
            NotImplementedError: MPP or objective power is missing from the
                WSI metadata.
            WrongParameterException: The level resulting from target
                magnification or downsampling factor must exist to extract
                patches.

        Returns:
            tuple:
                - tuple[int, int]: Number of cols, rows of the WSI at the
                  given level.
                - tuple[dict, dict, dict, dict]: WSI metadata, tissue masks,
                  annotation masks and thumbnails (PIL images keyed by name).
                - tuple[list, Any, list[Polygon], list[str]]: Interesting
                  patch coordinates, target level, downsampled annotation
                  polygons and their region labels.
        """
        logger.info(f"Computing patches for {wsi_file.name}")
        # load slide (OS and CuImage/OS)
        slide = OpenSlide(str(wsi_file))
        slide_cu = self.image_loader(str(wsi_file))
        mpp_value = slide.properties.get("openslide.mpp-x")
        if mpp_value is not None:
            slide_mpp = float(mpp_value)
        else:
            raise NotImplementedError("MPP must be in metadata of the WSI file!")
        mag_value = slide.properties.get("openslide.objective-power")
        if mag_value is not None:
            slide_mag = float(mag_value)
        else:
            raise NotImplementedError(
                "Magnification must be in metadata of the WSI file!"
            )
        slide_properties = {"mpp": slide_mpp, "magnification": slide_mag}
        # Generate thumbnails
        logger.info("Generate thumbnails")
        thumbnails = generate_thumbnails(
            slide, slide_properties["mpp"], sample_factors=[128]
        )
        # target mag has precedence before downsample!
        self.downsample = target_mag_to_downsample(
            slide_properties["magnification"],
            self.config.target_mag,
        )
        # Zoom Recap:
        # - Row and column of the tile within the Deep Zoom level (t_)
        # - Pixel coordinates within the Deep Zoom level (z_)
        # - Pixel coordinates within the slide level (l_)
        # - Pixel coordinates within slide level 0 (l0_)
        # Tile size is the amount of pixels that are taken from the image (without overlaps)
        tile_size, overlap = patch_to_tile_size(
            self.config.patch_size,
            int(self.config.patch_overlap),
            self.rescaling_factor,
        )
        tiles = self.deepzoomgenerator(
            osr=slide,
            cucim_slide=slide_cu,
            tile_size=tile_size,
            overlap=overlap,
            limit_bounds=True,
        )
        # Each level is downsampled by a factor of 2.
        # `downsample` expresses the desired downsampling; we need to count
        # how many times the downsampling is performed to find the level,
        # e.g. a downsampling of 8 means 2 * 2 * 2 = 3 times. We always need
        # to remove one more level than necessary, so 4; the bit length of
        # the number gives exactly that (8 = 0b1000, bit_length = 4).
        level = tiles.level_count - self.downsample.bit_length()
        if level >= tiles.level_count:
            raise WrongParameterException(
                "Requested level does not exist. Number of slide levels:",
                tiles.level_count,
            )
        # store level!
        self.curr_wsi_level = level
        # initialize annotation objects
        region_labels: List[str] = []
        polygons: List[Polygon] = []
        polygons_downsampled: List[Polygon] = []
        tissue_region: List[Polygon] = []
        # get the interesting coordinates: no background, filtered by annotation etc.
        # original number of tiles of total wsi
        n_cols, n_rows = tiles.level_tiles[level]
        (
            interesting_coords,
            mask_images,
            mask_images_annotations,
        ) = compute_interesting_patches(
            polygons=polygons,
            slide=slide,
            tiles=tiles,
            target_level=level,
            target_patch_size=tile_size,  # self.config.patch_size,
            target_overlap=overlap,  # self.config.patch_overlap,
            rescaling_factor=self.rescaling_factor,
            mask_otsu=self.masked_otsu,
            label_map=self.label_map,
            region_labels=region_labels,
            tissue_annotation=tissue_region,
            otsu_annotation=self.otsu_annotation,
            tissue_annotation_intersection_ratio=self.tissue_annotation_intersection_ratio,
            apply_prefilter=self.apply_prefilter,
        )
        if len(interesting_coords) == 0:
            logger.warning(f"No patches sampled from {wsi_file.name}")
        wsi_metadata: dict[str, Any] = {
            "orig_n_tiles_cols": n_cols,
            "orig_n_tiles_rows": n_rows,
            "base_magnification": slide_mag,
            "downsampling": self.downsample,
            "label_map": self.label_map,
            "patch_overlap": self.config.patch_overlap * 2,
            "patch_size": self.config.patch_size,
            "base_mpp": slide_mpp,
            "target_patch_mpp": slide_mpp * self.rescaling_factor,
            "stain_normalization": self.normalize_stains,
            "magnification": slide_mag / (self.downsample * self.rescaling_factor),
            "level": level,
        }
        logger.info(f"{wsi_file.name}: Processing {len(interesting_coords)} patches.")
        return (
            (n_cols, n_rows),
            (wsi_metadata, mask_images, mask_images_annotations, thumbnails),
            (list(interesting_coords), level, polygons_downsampled, region_labels),
        )

    def process_queue(
        self,
        batch: List[tuple[int, int, float]],
        wsi_file: Path,
        wsi_metadata: dict[str, Any],
        level: int,
        polygons: List[Polygon],
        region_labels: List[str],
        store: Storage,
    ) -> tuple[int, dict[int, int], list[dict[str, dict[str, Any]]]]:
        """Extract patches for a list of coordinates by using multiprocessing queues.

        Patches are extracted according to their coordinate with given
        patch-settings (size, overlap). Patch annotation masks can be stored,
        as well as context patches with the same shape retrieved. Optionally,
        stains can be normalized according to Macenko normalization.

        Args:
            batch (List[tuple[int, int, float]]): A batch of patch coordinates
                (row, col, background ratio).
            wsi_file (Path): Path to the WSI file from which the patches
                should be extracted from.
            wsi_metadata (dict[str, Any]): Dictionary with important WSI metadata.
            level (int): The tile level for sampling.
            polygons (List[Polygon]): Annotations of this WSI as a list of
                polygons (referenced to highest level of WSI). If no
                annotations, pass an empty list [].
            region_labels (List[str]): List of labels for the annotations
                provided as polygons parameter. If no annotations, pass an
                empty list [].
            store (Storage): Storage object passed to each worker to store the files.

        Returns:
            tuple[int, dict[int, int], list[dict[str, dict[str, Any]]]]:
                Number of processed patches, patch count per label id, and
                per-patch metadata records keyed by patch filename.
        """
        logger.debug(f"Started process {multiprocessing.current_process().name}")
        # reload image
        slide = OpenSlide(str(wsi_file))
        slide_cu = self.image_loader(str(wsi_file))
        tile_size, overlap = patch_to_tile_size(
            self.config.patch_size,
            int(self.config.patch_overlap),
            self.rescaling_factor,
        )
        tiles = self.deepzoomgenerator(
            osr=slide,
            cucim_slide=slide_cu,
            tile_size=tile_size,
            overlap=overlap,
            limit_bounds=True,
        )
        # queue setup
        queue: Any = multiprocessing.Queue()  # type: ignore[assignment]
        processes: list[Process] = []
        processed_count = multiprocessing.Value("i", 0)
        pbar = tqdm(total=len(batch), desc="Retrieving patches")
        for _ in range(self.processes):
            p = multiprocessing.Process(
                target=queue_worker, args=(queue, store, processed_count)
            )
            p.start()
            processes.append(p)
        patches_count = 0
        patch_result_list: list[dict[str, dict[str, Any]]] = []
        # per-label patch counter, initialised to zero for every label id
        patch_distribution = {v: 0 for v in self.label_map.values()}
        start_time = start_timer()
        for row, col, _ in batch:
            pbar.update()
            # set name
            patch_fname = f"{wsi_file.stem}_{row}_{col}.png"
            patch_yaml_name = f"{wsi_file.stem}_{row}_{col}.yaml"
            # OpenSlide: Address of the tile within the level as a (column, row) tuple
            new_tile = np.array(tiles.get_tile(level, (col, row)), dtype=np.uint8)
            patch = pad_tile(new_tile, tile_size + 2 * overlap, col, row)
            # calculate background ratio for every patch
            background_ratio = calculate_background_ratio(
                new_tile, self.config.patch_size
            )
            # patch_label
            if background_ratio > 1 - self.min_intersection_ratio:
                logger.debug(
                    f"Removing file {patch_fname} because of intersection ratio with background is too big"
                )
                intersected_labels: List[int] = []  # Zero means background
                ratio: dict[int, float] = {}
            else:
                intersected_labels, _ratio, _ = get_intersected_labels(
                    tile_size=tile_size,
                    patch_overlap=int(self.config.patch_overlap),
                    col=col,
                    row=row,
                    polygons=polygons,
                    label_map=self.label_map,
                    min_intersection_ratio=self.min_intersection_ratio,
                    region_labels=region_labels,
                    overlapping_labels=self.overlapping_labels,
                    store_masks=self.store_masks,
                )
                ratio = dict(zip(intersected_labels, _ratio))
            if len(intersected_labels) == 0 and self.save_only_annotated_patches:
                continue
            patch_metadata: dict[str, Any] = {
                "row": row,
                "col": col,
                "background_ratio": float(background_ratio),
                "intersected_labels": intersected_labels,
                "label_ratio": ratio,
                "wsi_metadata": wsi_metadata,
            }
            # increase patch_distribution count
            for patch_label in patch_metadata["intersected_labels"]:
                patch_distribution[patch_label] += 1
            patches_count += 1
            queue_elem: tuple[
                np.ndarray[Any, Any], dict[str, Any], None, dict[str, Any], int
            ] = (
                patch,
                patch_metadata,
                None,
                {},
                self.config.patch_size,
            )
            queue.put(queue_elem)
            # store metadata for all patches
            patch_metadata.pop("wsi_metadata")
            patch_metadata["metadata_path"] = f"./metadata/{patch_yaml_name}"
            patch_result_list.append({patch_fname: patch_metadata})
        # Add termination markers to the queue
        for _ in range(self.processes):
            queue.put(None)
        pbar.close()
        # wait for the queue to drain; sleep between polls so this loop does
        # not busy-wait and burn a full CPU core
        while not queue.empty():
            print(f"Progress: {processed_count.value}/{len(batch)}", end="\r")
            time.sleep(0.1)
        print("", end="", flush=True)
        # Wait for all workers to finish
        for p in processes:
            p.join()
            p.close()
        logger.info("Finished Processing and Storing. Took:")
        end_timer(start_time)
        return patches_count, patch_distribution, patch_result_list

    def _drop_processed_files(self, processed_files: list[str]) -> None:
        """Drop processed files (from `processed.json`) from the dataset.

        Args:
            processed_files (list[str]): List with processed filenames (stems).
        """
        self.files = [file for file in self.files if file.stem not in processed_files]

    def _check_overwrite(self, overwrite: bool = False) -> None:
        """Perform data cleanup, depending on overwrite.

        If True, overwrites the patches that have already been created in
        case they already exist. If False, skips already processed files from
        `processed.json` in the provided output path (created during class
        initialization).

        Args:
            overwrite (bool, optional): Overwrite flag. Defaults to False.
        """
        if overwrite:
            logger.info("Removing complete dataset! This may take a while.")
            subdirs = [f for f in Path(self.config.output_path).iterdir() if f.is_dir()]
            for subdir in subdirs:
                rmtree(subdir.resolve(), ignore_errors=True)
            if (Path(self.config.output_path) / "processed.json").exists():
                os.remove(Path(self.config.output_path) / "processed.json")
            self.setup_output_path(self.config.output_path)
        else:
            try:
                with open(
                    str(Path(self.config.output_path) / "processed.json"), "r"
                ) as processed_list:
                    processed_files = json.load(processed_list)[
                        "processed_files"
                    ]  # TODO: check
                logger.info(
                    f"Found {len(processed_files)} files. Continue to process {len(self.files) - len(processed_files)}/{len(self.files)} files."
                )
                self._drop_processed_files(processed_files)
            except FileNotFoundError:
                # no checkpoint file yet -> process everything
                logger.info("Empty output folder. Processing all files")

    @staticmethod
    def _check_patch_params(
        patch_size: int,
        patch_overlap: int,
        downsample: int | None = None,
        target_mag: float | None = None,
        level: int | None = None,
        min_background_ratio: float = 0.01,
    ) -> None:
        """Sanity check for patch parameters.

        See the `Raises` section for further comments about the sanity check.

        Args:
            patch_size (int): The size of the patches in pixel that will be
                retrieved from the WSI, e.g. 256 for 256px.
            patch_overlap (int): The amount of pixels that should overlap
                between two different patches.
            downsample (int, optional): Downsampling factor from the highest
                level (largest resolution). Defaults to None.
            target_mag (float, optional): If this parameter is provided, the
                output level of the WSI corresponds to the level that is at
                the target magnification of the WSI. Alternative to downsample
                and level. Defaults to None.
            level (int, optional): The tile level for sampling, alternative to
                downsample. Defaults to None.
            min_background_ratio (float, optional): Minimum background
                selection ratio. Defaults to 0.01.

        Raises:
            WrongParameterException: Either downsample, level, or
                target_magnification must have been selected.
            WrongParameterException: Downsampling must be a power of two.
            WrongParameterException: Negative overlap is not allowed.
            WrongParameterException: Overlap should not be larger than half of
                the patch size.
            WrongParameterException: Background percentage must be between 0
                and 1.
        """
        if downsample is None and level is None and target_mag is None:
            raise WrongParameterException(
                "downsample, level and target_mag are all None, "
                "please fill one of the three parameters."
            )
        if downsample is not None and not is_power_of_two(downsample):
            raise WrongParameterException("Downsample can only be a power of two.")
        if downsample is not None and level is not None:
            logger.warning(
                "Both downsample and level are set, "
                "we will use downsample and ignore level."
            )
        if patch_overlap < 0:
            raise WrongParameterException("Negative overlap not allowed.")
        if patch_overlap > patch_size // 2:
            raise WrongParameterException(
                "An overlap greater than half the patch size yields a tile size of zero."
            )
        if min_background_ratio < 0.0 or min_background_ratio > 1.0:
            raise WrongParameterException(
                "The parameter min_background_ratio should be a "
                "float value between 0 and 1 representing the "
                "maximum percentage of background allowed."
            )

    @staticmethod
    def setup_output_path(output_path: Union[str, Path]) -> None:
        """Create the output path (parents included, idempotent).

        Args:
            output_path (Union[str, Path]): Output path for WSI.
        """
        output_path = Path(output_path)
        output_path.mkdir(exist_ok=True, parents=True)