Source code for glasscut.stain_normalizers.base

# encoding: utf-8

"""Base classes for stain normalization.

This module provides abstract base classes and mixins for implementing
stain normalization strategies in histopathology image processing.
"""

from abc import ABC, abstractmethod
from typing import Any
from warnings import warn

import numpy as np
from PIL import Image

from glasscut.utils import np_to_pil


class StainNormalizer(ABC):
    """Common interface for stain normalization strategies.

    Stain normalization is a preprocessing step in digital pathology that
    reduces color variation introduced by differing staining protocols and
    imaging conditions. Concrete normalizers derive from this class and
    must provide both :meth:`fit` and :meth:`transform`.
    """

    @abstractmethod
    def fit(self, target_image: Image.Image, **kwargs: Any) -> None:
        """Learn normalization parameters from a reference image.

        Parameters
        ----------
        target_image : Image.Image
            Reference image whose staining the normalizer should reproduce.
        **kwargs
            Extra, method-specific options.
        """

    @abstractmethod
    def transform(self, image: Image.Image, **kwargs: Any) -> Image.Image:
        """Return a stain-normalized version of ``image``.

        Parameters
        ----------
        image : Image.Image
            Image to be normalized.
        **kwargs
            Extra, method-specific options.

        Returns
        -------
        Image.Image
            The normalized image.
        """
class TransformerStainMatrixMixin:
    """Mixin implementing fit/transform for matrix-based stain normalizers.

    The mixin relies on the subclass providing a ``stain_matrix`` method
    that returns a 3x3 stain matrix for an image.

    After :meth:`fit` the instance carries two attributes that
    :meth:`transform` reads:

    * ``stain_matrix_target`` -- stain matrix estimated from the target image.
    * ``max_concentrations_target`` -- per-stain 99th percentile of the
      target image's stain concentrations.
    """

    def fit(self, target_image: Image.Image, **kwargs: Any) -> None:
        """Fit stain normalizer using target image.

        Parameters
        ----------
        target_image : Image.Image
            Target image for stain normalization. Can be RGB or RGBA.
        **kwargs
            Additional arguments passed to the ``stain_matrix`` method.
            ``background_intensity`` (int, default 240) is extracted here
            and forwarded to ``stain_matrix`` exactly once.
        """
        # pop (not get): leaving the key in kwargs would pass
        # background_intensity twice to stain_matrix and raise TypeError.
        background_intensity = kwargs.pop("background_intensity", 240)

        self.stain_matrix_target = self.stain_matrix(
            target_image, background_intensity=background_intensity, **kwargs
        )

        target_concentrations = self._find_concentrations(
            target_image, self.stain_matrix_target, background_intensity
        )

        # One value per stain (row): the 99th percentile over all pixels.
        self.max_concentrations_target = np.percentile(
            target_concentrations, 99, axis=1
        )

    def transform(self, image: Image.Image, **kwargs: Any) -> Image.Image:
        """Normalize staining of image.

        Requires :meth:`fit` to have been called first, since it reads
        ``stain_matrix_target`` and ``max_concentrations_target``.

        Parameters
        ----------
        image : Image.Image
            Image to normalize. Can be RGB or RGBA.
        **kwargs
            Additional arguments passed to the ``stain_matrix`` method.
            ``background_intensity`` (int, default 240) is extracted here
            and forwarded to ``stain_matrix`` exactly once.

        Returns
        -------
        Image.Image
            Image with normalized stain.
        """
        # pop (not get): leaving the key in kwargs would pass
        # background_intensity twice to stain_matrix and raise TypeError.
        background_intensity = kwargs.pop("background_intensity", 240)

        stain_matrix_source = self.stain_matrix(
            image, background_intensity=background_intensity, **kwargs
        )

        source_concentrations = self._find_concentrations(
            image, stain_matrix_source, background_intensity
        )

        # Per-stain ratio between source and target 99th-percentile
        # concentrations; dividing by it rescales the source concentrations
        # onto the target's range.
        max_concentrations_source = np.percentile(source_concentrations, 99, axis=1)
        max_concentrations_source = np.divide(
            max_concentrations_source, self.max_concentrations_target
        )
        conc_tmp = np.divide(
            source_concentrations, max_concentrations_source[:, np.newaxis]
        )

        # Rebuild intensities from the rescaled concentrations using the
        # target stain matrix: I = background * exp(-M_target @ C).
        img_norm = np.multiply(
            background_intensity, np.exp(-self.stain_matrix_target.dot(conc_tmp))
        )
        img_norm = np.clip(img_norm, a_min=None, a_max=255)

        # PIL's size is (width, height); reversing it yields the
        # (height, width, 3) array layout.
        img_norm = np.reshape(img_norm.T, (*image.size[::-1], 3))
        return np_to_pil(img_norm)

    @staticmethod
    def _find_concentrations(
        img_rgb: Image.Image,
        stain_matrix: np.ndarray,
        background_intensity: int = 240,
    ) -> np.ndarray:
        """Return concentrations of individual stains in image.

        Parameters
        ----------
        img_rgb : Image.Image
            Input image. An RGBA image has its alpha channel dropped (with a
            warning) before processing.
        stain_matrix : np.ndarray
            Stain matrix of image. Shape (3, 3).
        background_intensity : int, optional
            Background transmitted light intensity. Default is 240.

        Returns
        -------
        np.ndarray
            Concentrations of individual stains. Shape (3, n_pixels).

        Notes
        -----
        Solves ``stain_matrix @ concentrations = optical_density`` in the
        least-squares sense with ``np.linalg.lstsq``.
        """
        if img_rgb.mode == "RGBA":
            red, green, blue, _ = img_rgb.split()
            img_rgb = Image.merge("RGB", (red, green, blue))
            warn(
                "Input image is RGBA. Converting to RGB before OD conversion. "
                "Alpha channel will be discarded."
            )

        img_arr = np.array(img_rgb)
        # The +1 offset keeps the logarithm finite for zero-valued pixels.
        od = -np.log((img_arr.astype(np.float64) + 1) / background_intensity)

        # rows correspond to channels (RGB), columns to OD values
        od = np.reshape(od, (-1, 3)).T

        # determine concentrations of the individual stains
        return np.linalg.lstsq(stain_matrix, od, rcond=None)[0]

    @abstractmethod
    def stain_matrix(
        self,
        img_rgb: Image.Image,
        background_intensity: int = 240,
        **kwargs: Any,
    ) -> np.ndarray:
        """Calculate stain matrix for image.

        Parameters
        ----------
        img_rgb : Image.Image
            Input image.
        background_intensity : int, optional
            Background transmitted light intensity. Default is 240.
        **kwargs
            Additional arguments specific to the normalization method.

        Returns
        -------
        np.ndarray
            Stain matrix of image. Shape (3, 3).
        """