Source code for glasscut.stain_normalizers.reinhardt

import warnings
from typing import Any, cast

import numpy as np
from PIL import Image
from skimage import color as sk_color

from glasscut.tile import Tile
from glasscut.tissue_detectors import OtsuTissueDetector
from glasscut.utils import np_to_pil

from .base import StainNormalizer


class ReinhardtStainNormalizer(StainNormalizer):
    """Stain normalizer using E. Reinhard et al.'s color transfer method.

    This method normalizes stain appearance by matching the mean and standard
    deviation of each channel in LAB color space between source and target
    images. The normalization is performed only on tissue regions; non-tissue
    pixels keep their original LAB values.

    The algorithm is:

    1. Identify tissue using an Otsu-based tissue mask
    2. Convert to LAB color space
    3. Compute per-channel mean and std on tissue
    4. Normalize source statistics to match target statistics
    5. Convert back to RGB

    Attributes
    ----------
    target_means : np.ndarray or None
        Target mean values per LAB channel, shape (3,). ``None`` until
        :meth:`fit` has been called.
    target_stds : np.ndarray or None
        Target standard deviation values per LAB channel, shape (3,).
        ``None`` until :meth:`fit` has been called.

    Notes
    -----
    This method is computationally fast and suitable for real-time preview
    during stain normalization parameter tuning. However, it may not preserve
    color relationships as well as matrix-based methods for complex stains.

    The original method (Reinhard et al., 2001) works in the lαβ color space;
    this implementation uses CIELAB as computed by ``skimage.color.rgb2lab``.

    Examples
    --------
    >>> from PIL import Image
    >>> from glasscut.stain_normalizers import ReinhardtStainNormalizer
    >>> normalizer = ReinhardtStainNormalizer()
    >>> ref_image = Image.open("reference.png")
    >>> normalizer.fit(ref_image)
    >>> test_image = Image.open("test.png")
    >>> normalized_image = normalizer.transform(test_image)
    """

    def __init__(self) -> None:
        """Initialize an unfitted ReinhardtStainNormalizer.

        Emits a ``UserWarning`` because the normalizer is experimental.
        """
        warnings.warn(
            "ReinhardtStainNormalizer is experimental and may produce errors or "
            "unexpected results on certain images. Use with caution.",
            UserWarning,
            stacklevel=2,
        )
        # Set by ``fit``; ``transform`` refuses to run while either is None.
        self.target_means: np.ndarray | None = None
        self.target_stds: np.ndarray | None = None

    def fit(self, target_image: Image.Image, **kwargs: Any) -> None:
        """Fit stain normalizer using target image.

        Stores the per-channel LAB mean and standard deviation of the tissue
        region of ``target_image`` as the normalization target.

        Parameters
        ----------
        target_image : Image.Image
            Target image for stain normalization. Can be RGB or RGBA.
        **kwargs
            Additional arguments (unused for Reinhard method).
        """
        means, stds = self._summary_statistics(target_image)
        self.target_means = means
        self.target_stds = stds

    def transform(self, image: Image.Image, **kwargs: Any) -> Image.Image:
        """Normalize staining of image.

        Parameters
        ----------
        image : Image.Image
            Image to normalize. Can be RGB or RGBA (alpha is discarded).
        **kwargs
            Optional keyword arguments:

            eps : float, default 1e-8
                Lower bound applied to the source and target standard
                deviations to avoid division by zero.

            Any other keyword argument is ignored.

        Returns
        -------
        Image.Image
            Image with normalized stain. Only tissue pixels are normalized;
            non-tissue pixels keep their original LAB values.

        Raises
        ------
        ValueError
            If the normalizer has not been fitted, or if ``image`` is a
            grayscale image.
        """
        if self.target_means is None or self.target_stds is None:
            raise ValueError(
                "Normalizer must be fitted with a target image before transformation."
            )
        eps = float(kwargs.get("eps", 1e-8))

        # Convert first so invalid (grayscale) inputs fail with a clear error
        # before tissue detection runs. Mask and LAB array are computed once
        # and reused for both the statistics and the normalization.
        img_lab = self.rgb_to_lab(image)
        mask_2d = self._tissue_mask(image).astype(bool)
        means, stds = self._summary_statistics(
            image, mask_2d=mask_2d, img_lab=img_lab
        )
        stds = np.maximum(stds, eps)
        target_stds = np.maximum(self.target_stds, eps)

        # Per channel: (x - source_mean) * (target_std / source_std) + target_mean
        normalized = (img_lab - means) * (target_stds / stds) + self.target_means

        # Keep normalized values on tissue, original values elsewhere. The
        # trailing new axis broadcasts the H×W mask over the 3 LAB channels.
        norm_lab = np.where(mask_2d[..., np.newaxis], normalized, img_lab)

        return self.lab_to_rgb(norm_lab)

    def _summary_statistics(
        self,
        img_rgb: Image.Image,
        *,
        mask_2d: np.ndarray | None = None,
        img_lab: np.ndarray | None = None,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Compute mean and std of each LAB channel on tissue region.

        Parameters
        ----------
        img_rgb : Image.Image
            Input image.
        mask_2d : np.ndarray, optional
            Pre-computed 2-D boolean tissue mask (H×W). When provided, avoids
            a redundant tissue-detection pass.
        img_lab : np.ndarray, optional
            Pre-computed LAB array (H×W×3). When provided, avoids a redundant
            colour-space conversion.

        Returns
        -------
        np.ndarray
            Mean of each channel in LAB space. Shape (3,).
        np.ndarray
            Standard deviation of each channel in LAB space, floored at 1e-8.
            Shape (3,).

        Notes
        -----
        Statistics are only computed on pixels identified as tissue to avoid
        background artifacts. If no tissue is detected, statistics over the
        whole image are returned instead.
        """
        # Convert first so invalid (grayscale) inputs fail with a clear error
        # before tissue detection runs.
        if img_lab is None:
            img_lab = self.rgb_to_lab(img_rgb)
        if mask_2d is None:
            mask_2d = self._tissue_mask(img_rgb).astype(bool)
        mask_3d = np.dstack((mask_2d, mask_2d, mask_2d))

        if np.any(mask_3d):
            mean_per_channel = img_lab.mean(axis=(0, 1), where=mask_3d)
            std_per_channel = img_lab.std(axis=(0, 1), where=mask_3d)
        else:
            # Fallback avoids NaNs when no tissue is detected in a tile.
            mean_per_channel = img_lab.mean(axis=(0, 1))
            std_per_channel = img_lab.std(axis=(0, 1))

        std_per_channel = np.maximum(std_per_channel, 1e-8)
        return mean_per_channel, std_per_channel

    @staticmethod
    def _tissue_mask(img_rgb: Image.Image) -> np.ndarray:
        """Compute binary tissue mask for image.

        Parameters
        ----------
        img_rgb : Image.Image
            Input image in RGB or RGBA format.

        Returns
        -------
        np.ndarray
            Tissue mask of ``img_rgb`` produced by an ``OtsuTissueDetector``
            via ``Tile.tissue_mask``; nonzero marks tissue (callers cast it to
            ``bool``).
        """
        tile = Tile(
            img_rgb,
            coords=None,
            magnification=None,
            tissue_detector=OtsuTissueDetector(),
        )
        return tile.tissue_mask

    # ==== helper implementations ====

    @staticmethod
    def rgb_to_lab(img_rgb: Image.Image) -> np.ndarray:
        """Convert RGB image to LAB color space.

        Parameters
        ----------
        img_rgb : Image.Image
            Input image. RGB is used as-is. RGBA has its alpha channel
            discarded (with a warning). Other colour modes (e.g. palette
            ``P`` or ``CMYK``) are converted to RGB with
            ``Image.convert("RGB")``.

        Returns
        -------
        np.ndarray
            Array representation of the image in CIELAB space, shape (H, W, 3).

        Raises
        ------
        ValueError
            If the input image holds grayscale data (e.g. ``L``, ``LA``, ``1``
            or ``I`` mode).
        """
        # getmodebase returns "L" for every mode that stores grayscale data.
        if Image.getmodebase(img_rgb.mode) == "L":
            raise ValueError(
                "Input image must be RGB or RGBA, not grayscale "
                f"({img_rgb.mode} mode)"
            )
        if img_rgb.mode == "RGBA":
            warnings.warn(
                "Input image is RGBA. Converting to RGB before LAB conversion. "
                "Alpha channel will be discarded.",
                stacklevel=2,
            )
            img_rgb = img_rgb.convert("RGB")
        elif img_rgb.mode != "RGB":
            # Without this, e.g. a palette image becomes a 2-D index array and
            # rgb2lab fails with an unhelpful shape error.
            img_rgb = img_rgb.convert("RGB")

        img_arr = np.asarray(img_rgb)
        lab_arr = cast(np.ndarray, sk_color.rgb2lab(img_arr))  # type: ignore
        return lab_arr

    @staticmethod
    def lab_to_rgb(img_lab: np.ndarray) -> Image.Image:
        """Convert LAB image to RGB color space.

        Parameters
        ----------
        img_lab : np.ndarray
            Input image in LAB color space.

        Returns
        -------
        Image.Image
            Image in RGB color space, built from ``skimage.color.lab2rgb``'s
            output via ``np_to_pil``.
        """
        rgb_arr = cast(np.ndarray, sk_color.lab2rgb(img_lab))  # type: ignore
        return np_to_pil(rgb_arr)