import warnings
from typing import Any, cast
import numpy as np
from PIL import Image
from skimage import color as sk_color
from glasscut.tile import Tile
from glasscut.tissue_detectors import OtsuTissueDetector
from glasscut.utils import np_to_pil
from .base import StainNormalizer
[docs]
class ReinhardtStainNormalizer(StainNormalizer):
"""Stain normalizer using E. Reinhard et al.'s color transfer method.
This method normalizes stain appearance by matching the mean and standard
deviation of each channel in LAB color space between source and target images.
The normalization is performed only on tissue regions.
The algorithm is:
1. Identify tissue using tissue masking
2. Convert to LAB color space
3. Compute per-channel mean and std on tissue
4. Normalize source statistics to match target statistics
5. Convert back to RGB
Attributes
----------
target_means : np.ndarray or None
Target mean values per LAB channel.
target_stds : np.ndarray or None
Target standard deviation values per LAB channel.
Notes
-----
This method is computationally fast and suitable for real-time preview
during stain normalization parameter tuning. However, it may not preserve
color relationships as well as matrix-based methods for complex stains.
Examples
--------
>>> from PIL import Image
>>> from glasscut.stain_normalizers import ReinhardtStainNormalizer
>>> normalizer = ReinhardtStainNormalizer()
>>> ref_image = Image.open("reference.png")
>>> normalizer.fit(ref_image)
>>> test_image = Image.open("test.png")
>>> normalized_image = normalizer.transform(test_image)
"""
[docs]
def __init__(self):
    """Create an unfitted normalizer.

    Emits a ``UserWarning`` (attributed to the caller's frame) flagging the
    class as experimental. Both target statistics start as ``None`` and are
    populated by ``fit``.
    """
    message = (
        "ReinhardtStainNormalizer is experimental and may produce errors or "
        "unexpected results on certain images. Use with caution."
    )
    warnings.warn(message, UserWarning, stacklevel=2)
    # Placeholders until ``fit`` stores the target per-channel LAB statistics.
    self.target_means = None
    self.target_stds = None
[docs]
def fit(self, target_image: Image.Image, **kwargs: Any) -> None:
    """Learn the target LAB statistics from a reference image.

    Parameters
    ----------
    target_image : Image.Image
        Reference image whose stain appearance is to be matched.
        Can be RGB or RGBA.
    **kwargs
        Additional arguments; not used by this method.
    """
    # Per-channel LAB mean and std over the tissue detected in the reference;
    # stored as ``target_means`` / ``target_stds``.
    self.target_means, self.target_stds = self._summary_statistics(target_image)
def _summary_statistics(
    self,
    img_rgb: Image.Image,
    *,
    mask_2d: np.ndarray | None = None,
    img_lab: np.ndarray | None = None,
) -> tuple[np.ndarray, np.ndarray]:
    """Compute mean and std of each LAB channel on the tissue region.

    Parameters
    ----------
    img_rgb : Image.Image
        Input image. Only read when ``mask_2d`` or ``img_lab`` is omitted.
    mask_2d : np.ndarray, optional
        Pre-computed 2-D tissue mask (H×W). Any dtype is accepted; non-zero
        entries are treated as tissue. When provided, avoids a redundant
        tissue-detection pass.
    img_lab : np.ndarray, optional
        Pre-computed LAB array (H×W×3). When provided, avoids a redundant
        colour-space conversion.

    Returns
    -------
    np.ndarray
        Mean of each channel in LAB space. Shape (3,).
    np.ndarray
        Standard deviation of each channel in LAB space, floored at 1e-8 so
        it is safe to divide by. Shape (3,).

    Notes
    -----
    Statistics are only computed on pixels identified as tissue to avoid
    background artifacts. If no pixel is marked as tissue, statistics over
    the whole image are returned instead, so the result never contains NaNs.
    """
    if mask_2d is None:
        mask_2d = self._tissue_mask(img_rgb)
    # Coerce BOTH the computed and a caller-supplied mask to bool: NumPy
    # reductions reject integer ``where=`` arrays (e.g. a 0/1 uint8 mask)
    # with a 'safe'-casting TypeError.
    mask_2d = np.asarray(mask_2d, dtype=bool)
    if img_lab is None:
        img_lab = self.rgb_to_lab(img_rgb)

    if mask_2d.any():
        # Repeat the spatial mask for each of the three LAB channels.
        mask_3d = np.dstack((mask_2d, mask_2d, mask_2d))
        mean_per_channel = img_lab.mean(axis=(0, 1), where=mask_3d)
        std_per_channel = img_lab.std(axis=(0, 1), where=mask_3d)
    else:
        # Fallback avoids NaNs when no tissue is detected in a tile.
        mean_per_channel = img_lab.mean(axis=(0, 1))
        std_per_channel = img_lab.std(axis=(0, 1))
    # Floor the std so later division by it cannot blow up on flat channels.
    std_per_channel = np.maximum(std_per_channel, 1e-8)
    return mean_per_channel, std_per_channel
@staticmethod
def _tissue_mask(img_rgb: Image.Image) -> np.ndarray:
    """Detect tissue in an image using Otsu thresholding.

    Parameters
    ----------
    img_rgb : Image.Image
        Image in RGB or RGBA format.

    Returns
    -------
    np.ndarray
        The ``tissue_mask`` of a ``Tile`` wrapping ``img_rgb``, built with an
        ``OtsuTissueDetector`` and no coordinates or magnification. It has the
        same spatial dimensions as the image: 1 = tissue, 0 = background.
    """
    detector = OtsuTissueDetector()
    wrapped = Tile(
        img_rgb,
        tissue_detector=detector,
        coords=None,
        magnification=None,
    )
    return wrapped.tissue_mask
# ==== helper implementations ====
[docs]
@staticmethod
def rgb_to_lab(img_rgb: Image.Image) -> np.ndarray:
    """Convert a colour image to LAB color space.

    Parameters
    ----------
    img_rgb : Image.Image
        Input image. Handling by mode:

        - RGB is used as-is.
        - RGBA is converted to RGB with a warning; the alpha channel is
          dropped, not composited.
        - Any other colour mode (e.g. palette "P", "YCbCr", "CMYK") is first
          converted to RGB with ``Image.convert``.

    Returns
    -------
    np.ndarray
        Array representation of the image in LAB space, as returned by
        ``skimage.color.rgb2lab``.

    Raises
    ------
    ValueError
        If the input image is grayscale ("1", "L", "LA", "La", "I", "F" or
        an "I;16*" mode), or if Pillow cannot convert the mode to RGB.
    """
    mode = img_rgb.mode
    is_grayscale = mode in {"1", "L", "LA", "La", "I", "F"} or mode.startswith(
        "I;16"
    )
    if is_grayscale:
        raise ValueError(
            f"Input image must be RGB or RGBA, not grayscale ({mode} mode)"
        )
    if mode == "RGBA":
        # For RGBA, convert("RGB") simply drops the alpha band.
        img_rgb = img_rgb.convert("RGB")
        warnings.warn(
            "Input image is RGBA. Converting to RGB before LAB conversion. "
            "Alpha channel will be discarded.",
            stacklevel=2,
        )
    elif mode != "RGB":
        # np.array() on e.g. a "P" image yields palette indices, and on a
        # "YCbCr" image yields non-RGB triplets that rgb2lab would silently
        # misinterpret -- convert to real RGB values first.
        img_rgb = img_rgb.convert("RGB")
    img_arr = np.array(img_rgb)
    lab_arr = cast(np.ndarray, sk_color.rgb2lab(img_arr))  # type: ignore
    return lab_arr
[docs]
@staticmethod
def lab_to_rgb(img_lab: np.ndarray) -> Image.Image:
    """Convert a LAB array back into an RGB image.

    Parameters
    ----------
    img_lab : np.ndarray
        Image data in LAB color space.

    Returns
    -------
    Image.Image
        Output of ``skimage.color.lab2rgb`` wrapped into a PIL image by
        ``np_to_pil``.
    """
    return np_to_pil(
        cast(np.ndarray, sk_color.lab2rgb(img_lab))  # type: ignore
    )