Source code for cellmil.models.segmentation.hovernet

# -*- coding: utf-8 -*-
# HoVerNet Model Implementation
#
# References:
# Building tools for machine learning and artificial intelligence in cancer research: best practices and a case study with the PathML toolkit for computational pathology
# Rosenthal, J. et al., Molecular Cancer Research, 2022

"""
Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine
License: GNU GPL 2.0
"""

import cv2
import numpy as np
import torch
from scipy.ndimage import binary_fill_holes  # type: ignore
from skimage.segmentation import watershed  # type: ignore
from torch import nn
from torch.nn import functional as F
from typing import List, Tuple, Any

from cellmil.utils import logger
from .utils.pathml import center_crop_im_batch, dice_loss, get_sobel_kernels
from .utils.post_proc_hv import DetectionCellPostProcessorHV


class _BatchNormRelu(nn.Module):
    """BatchNorm + ReLU layer"""

    def __init__(self, n_channels: int):
        super(_BatchNormRelu, self).__init__()  # type: ignore
        self.batch_norm = nn.BatchNorm2d(n_channels)
        self.relu = nn.ReLU()

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.relu(self.batch_norm(inputs))
class _HoVerNetResidualUnit(nn.Module):
    """
    Residual unit.
    See: Fig. 2(a) from Graham et al. 2019 HoVer-Net paper.
    This unit is not preactivated! That's handled when assembling units into blocks.
    output_channels corresponds to m in the figure
    """

    def __init__(self, input_channels: int, output_channels: int, stride: int):
        super(_HoVerNetResidualUnit, self).__init__()  # type: ignore
        internal_channels = output_channels // 4
        if stride != 1 or input_channels != output_channels:
            self.convshortcut = nn.Conv2d(
                input_channels,
                output_channels,
                kernel_size=1,
                stride=stride,
                padding=0,
                dilation=1,
                bias=False,
            )
        else:
            self.convshortcut = None
        self.conv1 = nn.Conv2d(
            input_channels, internal_channels, kernel_size=1, bias=False
        )
        self.bnrelu1 = _BatchNormRelu(internal_channels)
        self.conv2 = nn.Conv2d(
            internal_channels,
            internal_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
        )
        self.bnrelu2 = _BatchNormRelu(internal_channels)
        self.conv3 = nn.Conv2d(
            internal_channels, output_channels, kernel_size=1, bias=False
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        skip = self.convshortcut(inputs) if self.convshortcut else inputs
        out = self.conv1(inputs)
        out = self.bnrelu1(out)
        out = self.conv2(out)
        out = self.bnrelu2(out)
        out = self.conv3(out)
        out = out + skip
        return out
def _make_HoVerNet_residual_block(
    input_channels: int, output_channels: int, stride: int, n_units: int
):
    """
    Stack multiple residual units into a block.
    output_channels is given as m in Fig. 2 from Graham et al. 2019 paper
    """
    units: list[_HoVerNetResidualUnit | _BatchNormRelu] = []
    # first unit in block is different
    units.append(_HoVerNetResidualUnit(input_channels, output_channels, stride))
    for _ in range(n_units - 1):
        units.append(_HoVerNetResidualUnit(output_channels, output_channels, stride=1))
    # add a final activation ('preact' for the next unit).
    # This differs from the authors' implementation: they added a BNRelu before all units
    # except the first, plus a final one at the end. Appending a single BNRelu at the end
    # of each block should be equivalent.
    units.append(_BatchNormRelu(output_channels))
    return nn.Sequential(*units)
class _HoVerNetEncoder(nn.Module):
    """
    Encoder for HoVer-Net.
    7x7 conv, then four residual blocks, then 1x1 conv.
    BatchNormRelu after first convolution, based on code from authors, see:
    (https://github.com/vqdang/hover_net/blob/5d1560315a3de8e7d4c8122b97b1fe9b9513910b/src/model/graph.py#L67)

    Returns a list of the outputs from each residual block, for later skip connections
    """

    def __init__(self):
        super(_HoVerNetEncoder, self).__init__()  # type: ignore
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, padding=3)
        self.bnrelu1 = _BatchNormRelu(64)
        self.block1 = _make_HoVerNet_residual_block(
            input_channels=64, output_channels=256, stride=1, n_units=3
        )
        self.block2 = _make_HoVerNet_residual_block(
            input_channels=256, output_channels=512, stride=2, n_units=4
        )
        self.block3 = _make_HoVerNet_residual_block(
            input_channels=512, output_channels=1024, stride=2, n_units=6
        )
        self.block4 = _make_HoVerNet_residual_block(
            input_channels=1024, output_channels=2048, stride=2, n_units=3
        )
        self.conv2 = nn.Conv2d(
            in_channels=2048, out_channels=1024, kernel_size=1, padding=0
        )

    def forward(self, inputs: torch.Tensor) -> List[torch.Tensor]:
        out1 = self.conv1(inputs)
        out1 = self.bnrelu1(out1)
        out1 = self.block1(out1)
        out2 = self.block2(out1)
        out3 = self.block3(out2)
        out4 = self.block4(out3)
        out4 = self.conv2(out4)
        return [out1, out2, out3, out4]
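# A minimal shape sketch for the encoder (illustrative, not part of the model). For a
# 256x256 RGB patch, the four skip outputs have strides 1, 2, 4, and 8 relative to the
# input:
#
#     encoder = _HoVerNetEncoder()
#     with torch.no_grad():
#         out1, out2, out3, out4 = encoder(torch.randn(1, 3, 256, 256))
#     # out1: (1, 256, 256, 256)   out2: (1, 512, 128, 128)
#     # out3: (1, 1024, 64, 64)    out4: (1, 1024, 32, 32)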
class _HoVerNetDenseUnit(nn.Module):
    """
    Dense unit.
    See: Fig. 2(b) from Graham et al. 2019 HoVer-Net paper.
    """

    def __init__(self, input_channels: int):
        super(_HoVerNetDenseUnit, self).__init__()  # type: ignore
        self.bnrelu1 = _BatchNormRelu(input_channels)
        self.conv1 = nn.Conv2d(
            in_channels=input_channels, out_channels=128, kernel_size=1
        )
        self.bnrelu2 = _BatchNormRelu(128)
        self.conv2 = nn.Conv2d(
            in_channels=128, out_channels=32, kernel_size=5, padding=2
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        out = self.bnrelu1(inputs)
        out = self.conv1(out)
        out = self.bnrelu2(out)
        out = self.conv2(out)
        # crop inputs to the same spatial shape as out, so that we can concatenate
        cropdims = (inputs.size(2) - out.size(2), inputs.size(3) - out.size(3))
        inputs_cropped = center_crop_im_batch(inputs, dims=cropdims)
        out = torch.cat((inputs_cropped, out), dim=1)
        return out
def _make_HoVerNet_dense_block(input_channels: int, n_units: int):
    """
    Stack multiple dense units into a block.
    """
    units: list[_HoVerNetDenseUnit | _BatchNormRelu] = []
    in_dim = input_channels
    for _ in range(n_units):
        units.append(_HoVerNetDenseUnit(in_dim))
        in_dim += 32
    units.append(_BatchNormRelu(in_dim))
    return nn.Sequential(*units)
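# Channel bookkeeping for dense blocks: each unit concatenates 32 new channels onto its
# input, so a block with input_channels=256 and n_units=8 outputs 256 + 8 * 32 = 512
# channels. This is why conv2 in _HoverNetDecoder below expects 512 input channels after
# dense1, and conv4 expects 256 after dense2 (128 + 4 * 32).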
class _HoverNetDecoder(nn.Module):
    """
    One of the three identical decoder branches.
    """

    def __init__(self):
        super(_HoverNetDecoder, self).__init__()  # type: ignore
        self.upsample1 = nn.Upsample(scale_factor=2)
        self.conv1 = nn.Conv2d(
            in_channels=1024,
            out_channels=256,
            kernel_size=5,
            padding=2,
            stride=1,
            bias=False,
        )
        self.dense1 = _make_HoVerNet_dense_block(input_channels=256, n_units=8)
        self.conv2 = nn.Conv2d(
            in_channels=512, out_channels=512, kernel_size=1, stride=1, bias=False
        )
        self.upsample2 = nn.Upsample(scale_factor=2)
        self.conv3 = nn.Conv2d(
            in_channels=512,
            out_channels=128,
            kernel_size=5,
            padding=2,
            stride=1,
            bias=False,
        )
        self.dense2 = _make_HoVerNet_dense_block(input_channels=128, n_units=4)
        self.conv4 = nn.Conv2d(
            in_channels=256, out_channels=256, kernel_size=1, stride=1, bias=False
        )
        self.upsample3 = nn.Upsample(scale_factor=2)
        self.conv5 = nn.Conv2d(
            in_channels=256,
            out_channels=64,
            kernel_size=5,
            stride=1,
            bias=False,
            padding=2,
        )

    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """
        Inputs should be a list of the outputs from each residual block, so that we can
        use skip connections
        """
        block1_out, block2_out, block3_out, block4_out = inputs
        out = self.upsample1(block4_out)
        # skip connection addition
        out = out + block3_out
        out = self.conv1(out)
        out = self.dense1(out)
        out = self.conv2(out)
        out = self.upsample2(out)
        # skip connection
        out = out + block2_out
        out = self.conv3(out)
        out = self.dense2(out)
        out = self.conv4(out)
        out = self.upsample3(out)
        # last skip connection
        out = out + block1_out
        out = self.conv5(out)
        return out
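# The decoder consumes the full list of encoder outputs, adding out3, out2, and out1 as
# skip connections at successive upsampling stages. A rough sketch, reusing the encoder
# from the sketch above (illustrative):
#
#     decoder = _HoverNetDecoder()
#     with torch.no_grad():
#         features = decoder(encoder(torch.randn(1, 3, 256, 256)))
#     # features: (1, 64, 256, 256), ready for a 1x1 prediction head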
class HoVerNet(nn.Module):
    """
    Model for simultaneous segmentation and classification based on HoVer-Net.
    Can also be used for segmentation only, if class labels are not supplied.
    Each branch returns logits.

    Args:
        n_classes (int): Number of classes for classification task. If ``None`` then the
            classification branch is not used.

    References:
        Graham, S., Vu, Q.D., Raza, S.E.A., Azam, A., Tsang, Y.W., Kwak, J.T. and Rajpoot, N.,
        2019. Hover-Net: Simultaneous segmentation and classification of nuclei in multi-tissue
        histology images. Medical Image Analysis, 58, p.101563.
    """

    def __init__(self, n_classes: int | None = 6):
        super().__init__()  # type: ignore
        self.n_classes = n_classes
        self.encoder = _HoVerNetEncoder()

        # NP branch (nuclear pixel)
        self.np_branch = _HoverNetDecoder()
        # classification head
        self.np_head = nn.Sequential(
            # two channels in output - background prob and pixel prob
            nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)
        )

        # HV branch (horizontal vertical)
        self.hv_branch = _HoverNetDecoder()
        # classification head
        self.hv_head = nn.Sequential(
            # two channels in output - horizontal and vertical
            nn.Conv2d(in_channels=64, out_channels=2, kernel_size=1)
        )

        # NC branch (nuclear classification)
        # If n_classes is None, we are doing nucleus detection only, not classification,
        # so this branch is not used
        if self.n_classes is not None:
            self.nc_branch = _HoverNetDecoder()
            # classification head
            self.nc_head = nn.Sequential(
                # one channel in output for each class
                nn.Conv2d(in_channels=64, out_channels=self.n_classes, kernel_size=1)
            )

    def forward(self, inputs: torch.Tensor) -> dict[str, torch.Tensor]:
        encoded = self.encoder(inputs)

        out_np = self.np_branch(encoded)
        out_np = self.np_head(out_np)

        out_hv = self.hv_branch(encoded)
        out_hv = self.hv_head(out_hv)

        outputs = [out_np, out_hv]

        if self.n_classes is not None:
            out_nc = self.nc_branch(encoded)
            out_nc = self.nc_head(out_nc)
            outputs.append(out_nc)

        out_dict: dict[str, torch.Tensor] = {}
        # tissue-type logits are not predicted by this model; a zero placeholder is
        # returned for interface compatibility
        num_tissue_classes = 19
        out_dict["tissue_types"] = torch.zeros(
            inputs.shape[0],
            num_tissue_classes,
            device=inputs.device,
            dtype=inputs.dtype,
        )
        out_dict["nuclei_binary_map"] = outputs[0]
        out_dict["hv_map"] = outputs[1]
        # the type map only exists when the classification branch is enabled
        if self.n_classes is not None:
            out_dict["nuclei_type_map"] = outputs[2]
        return out_dict
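    # Hedged usage sketch for the full model (variable names are illustrative). Each
    # branch returns logits at the input resolution:
    #
    #     model = HoVerNet(n_classes=6)
    #     model.eval()
    #     with torch.no_grad():
    #         out = model(torch.randn(1, 3, 256, 256))
    #     # out["nuclei_binary_map"]: (1, 2, 256, 256) -- NP branch logits
    #     # out["hv_map"]:            (1, 2, 256, 256) -- HV regression maps
    #     # out["nuclei_type_map"]:   (1, 6, 256, 256) -- NC branch logits
    #     # out["tissue_types"]:      (1, 19) zero placeholder (see forward())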
    # TODO: apply the same instance-map construction to the CellPose model as well
    def calculate_instance_map(
        self, predictions: dict[str, torch.Tensor], magnification: int | float = 40
    ) -> Tuple[torch.Tensor, list[dict[np.int32, dict[str, Any]]]]:
        """Calculate instance map from network predictions (after softmax output).

        Args:
            predictions (dict): Dictionary with the following required keys:

                * nuclei_binary_map: Binary nucleus predictions. Shape: (batch_size, 2, H, W)
                * nuclei_type_map: Type predictions for nuclei. Shape: (batch_size, n_classes, H, W)
                * hv_map: Horizontal-vertical nuclei map. Shape: (batch_size, 2, H, W)

            magnification (int | float, optional): Magnification of the data, typically 20 or 40.
                Defaults to 40.

        Returns:
            Tuple[torch.Tensor, List[dict]]:

                * torch.Tensor: Instance map. Each instance has its own integer.
                  Shape: (batch_size, H, W)
                * List of dictionaries, one per image. Each dict contains another dict for each
                  detected nucleus, with the keys "bbox", "centroid", "contour", "type_prob",
                  and "type".
        """
        # reshape from (B, C, H, W) to (B, H, W, C)
        predictions_ = predictions.copy()
        predictions_["nuclei_type_map"] = predictions_["nuclei_type_map"].permute(
            0, 2, 3, 1
        )
        predictions_["nuclei_binary_map"] = predictions_["nuclei_binary_map"].permute(
            0, 2, 3, 1
        )
        predictions_["hv_map"] = predictions_["hv_map"].permute(0, 2, 3, 1)
        predictions = predictions_

        cell_post_processor = DetectionCellPostProcessorHV(
            nr_types=self.n_classes, magnification=magnification, gt=False
        )
        instance_preds: list[np.ndarray[Any, Any]] = []
        type_preds: list[dict[np.int32, dict[str, Any]]] = []
        for i in range(predictions["nuclei_binary_map"].shape[0]):
            # stack (type, binary, hv) channels into a single (H, W, 4) prediction map
            pred_map: np.ndarray[Any, Any] = np.concatenate(  # type: ignore
                [
                    torch.argmax(predictions["nuclei_type_map"], dim=-1)[i]
                    .detach()
                    .cpu()[..., None],
                    torch.argmax(predictions["nuclei_binary_map"], dim=-1)[i]
                    .detach()
                    .cpu()[..., None],
                    predictions["hv_map"][i].detach().cpu(),
                ],
                axis=-1,
            )
            instance_pred = cell_post_processor.post_process_cell_segmentation(pred_map)
            instance_preds.append(instance_pred[0])
            type_preds.append(instance_pred[1])

        return torch.Tensor(np.stack(instance_preds)), type_preds
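# Hedged sketch of instance-map extraction (illustrative; assumes softmax has already been
# applied to the type and binary maps, as the docstring above requires):
#
#     preds = model(images)
#     preds["nuclei_binary_map"] = F.softmax(preds["nuclei_binary_map"], dim=1)
#     preds["nuclei_type_map"] = F.softmax(preds["nuclei_type_map"], dim=1)
#     instance_map, nuclei_info = model.calculate_instance_map(preds, magnification=40)
#     # instance_map: (B, H, W); nuclei_info[i][nucleus_id] holds "centroid", "type", etc.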
# loss functions and associated utils
def _convert_multiclass_mask_to_binary(mask: torch.Tensor) -> torch.Tensor:
    """
    Convert an input mask of shape (B, n_classes, H, W) to a binary mask of shape (B, 1, H, W).
    The last channel is assumed to be background, so the binary mask is computed by taking
    its inverse.
    """
    m = torch.tensor(1) - mask[:, -1, :, :]
    m = m.unsqueeze(dim=1)
    return m
def _dice_loss_np_head(
    np_out: torch.Tensor, true_mask: torch.Tensor, epsilon: float = 1e-3
):
    """
    Dice loss term for nuclear pixel branch.
    This computes dice loss for the entire batch (not the same as computing dice loss for
    each image and then averaging!)

    Args:
        np_out: logit outputs of np branch. Tensor of shape (B, 2, H, W)
        true_mask: True mask. Tensor of shape (B, n_classes, H, W)
        epsilon (float): Epsilon passed to ``dice_loss()``
    """
    # get logits for only the channel corresponding to prediction of 1
    # unsqueeze to keep the dimensions the same
    preds = np_out[:, 1, :, :].unsqueeze(dim=1)

    true_mask = _convert_multiclass_mask_to_binary(true_mask)
    true_mask = true_mask.type(torch.long)
    loss = dice_loss(logits=preds, true=true_mask, eps=epsilon)
    return loss
def _dice_loss_nc_head(
    nc_out: torch.Tensor, true_mask: torch.Tensor, epsilon: float = 1e-3
):
    """
    Dice loss term for nuclear classification branch.
    Computes dice loss for each channel, and sums up.
    This computes dice loss for the entire batch (not the same as computing dice loss for
    each image and then averaging!)

    Args:
        nc_out: logit outputs of nc branch. Tensor of shape (B, n_classes, H, W)
        true_mask: True mask. Tensor of shape (B, n_classes, H, W)
        epsilon (float): Epsilon passed to ``dice_loss()``
    """
    truth = torch.argmax(true_mask, dim=1, keepdim=True).type(torch.long)
    loss = dice_loss(logits=nc_out, true=truth, eps=epsilon)
    return loss
def _ce_loss_nc_head(nc_out: torch.Tensor, true_mask: torch.Tensor):
    """
    Cross-entropy loss term for nc branch.

    Args:
        nc_out: logit outputs of nc branch. Tensor of shape (B, n_classes, H, W)
        true_mask: True mask. Tensor of shape (B, n_classes, H, W)
    """
    truth = torch.argmax(true_mask, dim=1).type(torch.long)
    ce = nn.CrossEntropyLoss()
    loss = ce(nc_out, truth)
    return loss
def _ce_loss_np_head(np_out: torch.Tensor, true_mask: torch.Tensor):
    """
    Cross-entropy loss term for np branch.

    Args:
        np_out: logit outputs of np branch. Tensor of shape (B, 2, H, W)
        true_mask: True mask. Tensor of shape (B, n_classes, H, W)
    """
    truth = (
        _convert_multiclass_mask_to_binary(true_mask).type(torch.long).squeeze(dim=1)
    )
    ce = nn.CrossEntropyLoss()
    loss = ce(np_out, truth)
    return loss
def compute_hv_map(mask: np.ndarray[Any, Any]) -> np.ndarray[Any, Any]:
    """
    Preprocessing step for HoVer-Net architecture.
    Compute the center of mass for each nucleus, then compute the distance of each nuclear
    pixel to its corresponding center of mass.
    Nuclear pixel distances are normalized to (-1, 1). Background pixels are left as 0.
    Operates on a single mask.
    Can be used in a Dataset object to make a Dataloader compatible with HoVer-Net.
    Based on https://github.com/vqdang/hover_net/blob/195ed9b6cc67b12f908285492796fb5c6c15a000/src/loader/augs.py#L192

    Args:
        mask (np.ndarray): Mask indicating individual nuclei. Array of shape (H, W),
            where each pixel is in {0, ..., n} with 0 indicating background pixels and
            {1, ..., n} indicating n unique nuclei.

    Returns:
        np.ndarray: array of hv maps of shape (2, H, W). First channel corresponds to
        horizontal and second to vertical.
    """
    assert mask.ndim == 2, (
        f"Input mask has shape {mask.shape}. Expecting a mask with 2 dimensions (H, W)"
    )

    out = np.zeros((2, mask.shape[0], mask.shape[1]))
    # each individual nucleus is indexed with a different number
    inst_list: list[int] = list(np.unique(mask))  # type: ignore

    try:
        inst_list.remove(0)  # 0 is background
    except ValueError:
        logger.warning(
            "No pixels with 0 label. This means that there are no background pixels. "
            "This may indicate a problem. Ignore this warning if this is expected/intended."
        )

    for inst_id in inst_list:
        # get the mask for the nucleus
        inst_map = mask == inst_id
        inst_map = inst_map.astype(np.uint8)
        contours, _ = cv2.findContours(  # type: ignore
            inst_map, mode=cv2.RETR_LIST, method=cv2.CHAIN_APPROX_NONE
        )

        # get center of mass coords
        mom = cv2.moments(contours[0])  # type: ignore
        com_x = mom["m10"] / (mom["m00"] + 1e-6)
        com_y = mom["m01"] / (mom["m00"] + 1e-6)
        inst_com = (int(com_y), int(com_x))

        inst_x_range = np.arange(1, inst_map.shape[1] + 1)  # type: ignore
        inst_y_range = np.arange(1, inst_map.shape[0] + 1)  # type: ignore
        # shifting center of pixels grid to instance center of mass
        inst_x_range -= inst_com[1]  # type: ignore
        inst_y_range -= inst_com[0]  # type: ignore

        inst_x, inst_y = np.meshgrid(inst_x_range, inst_y_range)  # type: ignore

        # remove coords outside of instance
        inst_x[inst_map == 0] = 0
        inst_y[inst_map == 0] = 0
        inst_x = inst_x.astype("float32")
        inst_y = inst_y.astype("float32")

        # normalize min into -1 scale
        if np.min(inst_x) < 0:
            inst_x[inst_x < 0] /= -np.amin(inst_x[inst_x < 0])
        if np.min(inst_y) < 0:
            inst_y[inst_y < 0] /= -np.amin(inst_y[inst_y < 0])
        # normalize max into +1 scale
        if np.max(inst_x) > 0:
            inst_x[inst_x > 0] /= np.amax(inst_x[inst_x > 0])
        if np.max(inst_y) > 0:
            inst_y[inst_y > 0] /= np.amax(inst_y[inst_y > 0])

        # add to output mask
        # this works assuming background is 0, and each pixel is assigned to only one nucleus
        out[0, :, :] += inst_x
        out[1, :, :] += inst_y
    return out
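# Toy example (illustrative): a single 3x3 square nucleus in a 5x5 mask.
#
#     mask = np.zeros((5, 5), dtype=np.uint8)
#     mask[1:4, 1:4] = 1
#     hv = compute_hv_map(mask)
#     # hv.shape == (2, 5, 5); background pixels stay 0, and within the nucleus
#     # hv[0] increases left-to-right and hv[1] increases top-to-bottom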
def _get_gradient_hv(hv_batch: torch.Tensor, kernel_size: int = 5):
    """
    Calculate the horizontal partial differentiation for the horizontal channel and the
    vertical partial differentiation for the vertical channel.
    The partial differentiation is approximated by calculating the central difference,
    obtained by convolving with a Sobel kernel of size 5x5. The boundary is zero-padded
    when the channel is convolved with the Sobel kernel.

    Args:
        hv_batch: tensor of shape (B, 2, H, W). Channel index 0 for horizontal maps and 1 for
            vertical maps. These maps are distances from each nuclear pixel to the center of
            mass of the corresponding nucleus.
        kernel_size (int): width of kernel to use for gradient approximation.

    Returns:
        Tuple of (h_grad, v_grad) where each is a Tensor giving horizontal and vertical
        gradients respectively
    """
    assert hv_batch.shape[1] == 2, (
        f"inputs have shape {hv_batch.shape}. Expecting tensor of shape (B, 2, H, W)"
    )
    h_kernel, v_kernel = get_sobel_kernels(kernel_size, dt=hv_batch.dtype)

    # move kernels to same device as batch
    h_kernel = h_kernel.to(hv_batch.device)
    v_kernel = v_kernel.to(hv_batch.device)

    # add extra dims so we can convolve with a batch
    h_kernel = h_kernel.unsqueeze(0).unsqueeze(0)
    v_kernel = v_kernel.unsqueeze(0).unsqueeze(0)

    # get the inputs for the h and v channels
    h_inputs = hv_batch[:, 0, :, :].unsqueeze(dim=1)
    v_inputs = hv_batch[:, 1, :, :].unsqueeze(dim=1)

    # pad by half the kernel width so output spatial dims match the input
    # (the original hardcoded padding=2, which is only correct for kernel_size=5)
    padding = kernel_size // 2
    h_grad = F.conv2d(h_inputs, h_kernel, stride=1, padding=padding)
    v_grad = F.conv2d(v_inputs, v_kernel, stride=1, padding=padding)

    del h_kernel
    del v_kernel

    return h_grad, v_grad
def _loss_hv_grad(
    hv_out: torch.Tensor, true_hv: torch.Tensor, nucleus_pixel_mask: torch.Tensor
):
    """
    Equation 3 from HoVer-Net paper for calculating loss for HV predictions.
    Mask is used to compute the hv loss ONLY for nuclear pixels

    Args:
        hv_out: Output of hv branch. Tensor of shape (B, 2, H, W)
        true_hv: Ground truth hv maps. Tensor of shape (B, 2, H, W)
        nucleus_pixel_mask: Boolean mask indicating nuclear pixels. Tensor of shape (B, H, W)
    """
    pred_grad_h, pred_grad_v = _get_gradient_hv(hv_out)
    true_grad_h, true_grad_v = _get_gradient_hv(true_hv)

    # add a channel dim so the (B, H, W) mask broadcasts correctly against the
    # (B, 1, H, W) gradient tensors
    nucleus_pixel_mask = nucleus_pixel_mask.unsqueeze(dim=1)

    # pull out only the values from nuclear pixels
    pred_h = torch.masked_select(pred_grad_h, mask=nucleus_pixel_mask)
    true_h = torch.masked_select(true_grad_h, mask=nucleus_pixel_mask)
    pred_v = torch.masked_select(pred_grad_v, mask=nucleus_pixel_mask)
    true_v = torch.masked_select(true_grad_v, mask=nucleus_pixel_mask)

    loss_h = F.mse_loss(pred_h, true_h)
    loss_v = F.mse_loss(pred_v, true_v)

    loss = loss_h + loss_v
    return loss
def _loss_hv_mse(hv_out: torch.Tensor, true_hv: torch.Tensor):
    """
    Equation 2 from HoVer-Net paper for calculating loss for HV predictions.

    Args:
        hv_out: Output of hv branch. Tensor of shape (B, 2, H, W)
        true_hv: Ground truth hv maps. Tensor of shape (B, 2, H, W)
    """
    loss = F.mse_loss(hv_out, true_hv)
    return loss
def loss_hovernet(
    outputs: list[torch.Tensor],
    ground_truth: list[torch.Tensor],
    n_classes: int | None = None,
):
    """
    Compute loss for HoVer-Net.
    Equation (1) in Graham et al.

    Args:
        outputs: Output of HoVer-Net. Should be a list of [np, hv] if n_classes is None, or a
            list of [np, hv, nc] if n_classes is not None.
            Shapes of each should be:

            - np: (B, 2, H, W)
            - hv: (B, 2, H, W)
            - nc: (B, n_classes, H, W)

        ground_truth: True labels. Should be a list of [mask, hv], where mask is a Tensor of
            shape (B, 1, H, W) if n_classes is ``None`` or (B, n_classes, H, W) if n_classes is
            not ``None``. hv is a tensor of precomputed horizontal and vertical distances of
            nuclear pixels to their corresponding centers of mass, and is of shape (B, 2, H, W).
        n_classes (int): Number of classes for classification task. If ``None`` then the
            classification branch is not used.

    References:
        Graham, S., Vu, Q.D., Raza, S.E.A., Azam, A., Tsang, Y.W., Kwak, J.T. and Rajpoot, N.,
        2019. Hover-Net: Simultaneous segmentation and classification of nuclei in multi-tissue
        histology images. Medical Image Analysis, 58, p.101563.
    """
    true_mask, true_hv = ground_truth
    # unpack outputs, and also calculate nucleus masks
    if n_classes is None:
        np_out, hv = outputs
        nc = None
        nucleus_mask = true_mask[:, 0, :, :] == 1
    else:
        np_out, hv, nc = outputs
        # in the multiclass setting, the last channel of masks indicates background, so
        # invert it to get a nucleus mask (based on convention from the PanNuke dataset)
        nucleus_mask = true_mask[:, -1, :, :] == 0

    # from Eq. 1 in HoVer-Net paper, the loss function is composed of two terms for each branch
    np_loss_dice = _dice_loss_np_head(np_out, true_mask)
    np_loss_ce = _ce_loss_np_head(np_out, true_mask)

    hv_loss_grad = _loss_hv_grad(hv, true_hv, nucleus_mask)
    hv_loss_mse = _loss_hv_mse(hv, true_hv)

    # authors suggest using a coefficient of 2 for the hv gradient loss term
    hv_loss_grad = 2 * hv_loss_grad

    if n_classes is not None and nc is not None:
        nc_loss_dice = _dice_loss_nc_head(nc, true_mask)
        nc_loss_ce = _ce_loss_nc_head(nc, true_mask)
    else:
        nc_loss_dice = 0
        nc_loss_ce = 0

    loss = (
        np_loss_dice
        + np_loss_ce
        + hv_loss_mse
        + hv_loss_grad
        + nc_loss_dice
        + nc_loss_ce
    )
    return loss
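# Hedged training-step sketch (variable names are illustrative). true_mask follows the
# PanNuke convention (last channel is background) and true_hv comes from compute_hv_map():
#
#     out = model(images)  # images: (B, 3, H, W)
#     loss = loss_hovernet(
#         outputs=[out["nuclei_binary_map"], out["hv_map"], out["nuclei_type_map"]],
#         ground_truth=[true_mask, true_hv],
#         n_classes=6,
#     )
#     loss.backward()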
# Post-processing of HoVer-Net outputs
def remove_small_objs(array_in: np.ndarray[Any, Any], min_size: int):
    """
    Removes small foreground regions from a binary array, leaving only the contiguous regions
    which are above the size threshold. Pixels in regions below the size threshold are zeroed
    out.

    Args:
        array_in (np.ndarray): Input array. Must be binary array with dtype=np.uint8.
        min_size (int): Minimum size of each region.

    Returns:
        np.ndarray: Array of labels for regions above the threshold. Each separate contiguous
        region is labelled with a different integer from 1 to n, where n is the number of total
        distinct contiguous regions
    """
    assert array_in.dtype == np.uint8, (
        f"Input dtype is {array_in.dtype}. Must be np.uint8"
    )
    # remove elements below size threshold
    # each contiguous nucleus region gets a unique id
    n_labels, labels = cv2.connectedComponents(array_in)
    # each integer is a different nucleus, so bincount gives nucleus sizes
    sizes = np.bincount(labels.flatten())
    for nucleus_ix, size_ix in zip(range(n_labels), sizes):
        if size_ix < min_size:
            # below size threshold - set all to zero
            labels[labels == nucleus_ix] = 0
    return labels
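# Tiny illustrative example: two foreground regions, one below the size threshold.
#
#     arr = np.zeros((10, 10), dtype=np.uint8)
#     arr[0:4, 0:4] = 1   # 16-pixel region: kept
#     arr[8, 8] = 1       # 1-pixel region: zeroed out
#     labels = remove_small_objs(arr, min_size=10)
#     # labels is 0 everywhere except the surviving region, which keeps its integer label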
def _post_process_single_hovernet(
    np_out: torch.Tensor,
    hv_out: torch.Tensor,
    small_obj_size_thresh: int = 10,
    kernel_size: int = 21,
    h: float = 0.5,
    k: float = 0.5,
) -> np.ndarray[Any, Any]:
    """
    Combine predictions of np channel and hv channel to create final predictions.
    Works by creating an energy landscape from the gradients and then applying watershed
    segmentation. This function works on a single image and is wrapped in
    ``post_process_batch_hovernet()`` to apply across a batch.
    See: Section B of HoVer-Net article and
    https://github.com/vqdang/hover_net/blob/14c5996fa61ede4691e87905775e8f4243da6a62/models/hovernet/post_proc.py#L27

    Args:
        np_out (torch.Tensor): Output of NP branch. Tensor of shape (2, H, W) of logit
            predictions for binary classification
        hv_out (torch.Tensor): Output of HV branch. Tensor of shape (2, H, W) of predictions
            for horizontal/vertical maps
        small_obj_size_thresh (int): Minimum number of pixels in regions. Defaults to 10.
        kernel_size (int): Width of Sobel kernel used to compute horizontal and vertical
            gradients.
        h (float): Hyperparameter for thresholding nucleus probabilities. Defaults to 0.5.
        k (float): Hyperparameter for thresholding energy landscape to create markers for
            watershed segmentation. Defaults to 0.5.
    """
    # compute pixel probabilities from logits, apply threshold, and get into np array
    np_preds = F.softmax(np_out, dim=0)[1, :, :]
    np_preds: Any = np_preds.numpy()  # type: ignore
    np_preds[np_preds >= h] = 1
    np_preds[np_preds < h] = 0
    np_preds = np_preds.astype(np.uint8)
    np_preds = remove_small_objs(np_preds, min_size=small_obj_size_thresh)
    # back to binary. np_preds now corresponds to tau(q, h) from the HoVer-Net paper
    np_preds[np_preds > 0] = 1
    tau_q_h = np_preds

    # normalize hv predictions, compute horizontal and vertical gradients, and normalize again
    hv_out = hv_out.numpy().astype(np.float32)  # type: ignore
    h_out = hv_out[0, ...]
    v_out = hv_out[1, ...]
    # https://stackoverflow.com/a/39037135
    h_normed = cv2.normalize(  # type: ignore
        h_out, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F  # type: ignore
    )
    v_normed = cv2.normalize(  # type: ignore
        v_out, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F  # type: ignore
    )
    h_grad = cv2.Sobel(h_normed, cv2.CV_64F, dx=1, dy=0, ksize=kernel_size)  # type: ignore
    v_grad = cv2.Sobel(v_normed, cv2.CV_64F, dx=0, dy=1, ksize=kernel_size)  # type: ignore
    h_grad = cv2.normalize(  # type: ignore
        h_grad, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F  # type: ignore
    )
    v_grad = cv2.normalize(  # type: ignore
        v_grad, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F  # type: ignore
    )

    # flip the gradient direction so that highest values are steepest gradient
    h_grad = 1 - h_grad  # type: ignore
    v_grad = 1 - v_grad  # type: ignore

    S_m = np.maximum(h_grad, v_grad)  # type: ignore
    S_m[tau_q_h == 0] = 0
    # energy landscape
    # the paper defines E = (1 - tau(S_m, k)) * tau(q, h), but the authors' code actually uses
    # E = (1 - S_m) * tau(q, h), which makes more sense: there is no need to threshold the
    # energy surface
    energy = (1.0 - S_m) * tau_q_h

    # get markers
    # the paper defines M = sigma(tau(q, h) - tau(S_m, k)), but thresholding the energy
    # landscape to get the peaks of the hills makes more sense; the use of sigma in the paper
    # suggests that this is what the authors intended
    m = np.array(energy >= k, dtype=np.uint8)

    m = binary_fill_holes(m).astype(np.uint8)  # type: ignore
    m = remove_small_objs(m, min_size=small_obj_size_thresh)  # type: ignore

    # nuclei values form mountains, so invert to get basins for watershed
    energy = -cv2.GaussianBlur(energy, (3, 3), 0)  # type: ignore
    out = watershed(image=energy, markers=m, mask=tau_q_h)  # type: ignore

    return out  # type: ignore
def post_process_batch_hovernet(
    outputs: list[torch.Tensor],
    n_classes: int | None,
    small_obj_size_thresh: int = 10,
    kernel_size: int = 21,
    h: float = 0.5,
    k: float = 0.5,
    return_nc_out_preds: bool = False,
) -> (
    np.ndarray[Any, Any]
    | Tuple[np.ndarray[Any, Any], np.ndarray[Any, Any]]
    | Tuple[np.ndarray[Any, Any], np.ndarray[Any, Any], np.ndarray[Any, Any]]
):
    """
    Post-process HoVer-Net outputs to get a final predicted mask.
    See: Section B of HoVer-Net article and
    https://github.com/vqdang/hover_net/blob/14c5996fa61ede4691e87905775e8f4243da6a62/models/hovernet/post_proc.py#L27

    Args:
        outputs (list): Outputs of HoVer-Net model. List of [np_out, hv_out], or
            [np_out, hv_out, nc_out] depending on whether the model is predicting
            classification or not.

            - np_out is a Tensor of shape (B, 2, H, W) of logit predictions for binary
              classification
            - hv_out is a Tensor of shape (B, 2, H, W) of predictions for horizontal/vertical
              maps
            - nc_out is a Tensor of shape (B, n_classes, H, W) of logits for classification

        n_classes (int): Number of classes for classification task. If ``None`` then only
            segmentation is performed.
        small_obj_size_thresh (int): Minimum number of pixels in regions. Defaults to 10.
        kernel_size (int): Width of Sobel kernel used to compute horizontal and vertical
            gradients.
        h (float): Hyperparameter for thresholding nucleus probabilities. Defaults to 0.5.
        k (float): Hyperparameter for thresholding energy landscape to create markers for
            watershed segmentation. Defaults to 0.5.
        return_nc_out_preds (bool): If True, also return the pixel-level class predictions
            from the nc branch.

    Returns:
        np.ndarray: If n_classes is None, returns det_out. In the classification setting,
        returns (det_out, class_out), or (det_out, class_out, nc_out_preds) if
        return_nc_out_preds is True.

        - det_out is np.ndarray of shape (B, H, W)
        - class_out is np.ndarray of shape (B, n_classes, H, W)

        Each pixel is labelled from 0 to n, where n is the number of individual nuclei
        detected. 0 pixels indicate background. Pixel value i indicates that the pixel
        belongs to the ith nucleus.
    """
    # detach outputs from the computation graph and move them to CPU
    outputs = [
        output.detach().cpu() if output.requires_grad else output.cpu()
        for output in outputs
    ]

    assert len(outputs) in {
        2,
        3,
    }, "outputs must have size 2 (for segmentation) or 3 (for classification)"

    np_out, hv_out = outputs[:2]
    # check if classification is needed
    classification = n_classes is not None and len(outputs) == 3

    batchsize = hv_out.shape[0]
    # first get the nucleus detection preds
    out_detection_list: list[np.ndarray[Any, Any]] = []
    for i in range(batchsize):
        preds = _post_process_single_hovernet(
            np_out[i, ...], hv_out[i, ...], small_obj_size_thresh, kernel_size, h, k
        )
        out_detection_list.append(preds)
    out_detection = np.stack(out_detection_list)

    if classification:
        nc_out = outputs[2]
        # last step: majority vote within each detected nucleus
        # get the pixel-level class predictions from the logits
        nc_out_preds: np.ndarray[Any, Any] = (
            F.softmax(nc_out, dim=1).argmax(dim=1).numpy()
        )  # type: ignore

        out_classification = np.zeros_like(nc_out, dtype=np.uint8)

        for batch_ix, nuc_preds in enumerate(out_detection_list):
            # get labels of nuclei from nucleus detection
            nucleus_labels = list(np.unique(nuc_preds))  # type: ignore
            if 0 in nucleus_labels:
                nucleus_labels.remove(0)
            nucleus_class_preds = nc_out_preds[batch_ix, ...]

            out_class_preds_single = out_classification[batch_ix, ...]

            # for each nucleus, get the class predictions for its pixels and take a vote
            for nucleus_ix in nucleus_labels:
                # get mask for the specific nucleus
                ix_mask = nuc_preds == nucleus_ix
                votes = nucleus_class_preds[ix_mask]
                majority_class = np.argmax(np.bincount(votes))
                out_class_preds_single[majority_class][ix_mask] = nucleus_ix

            out_classification[batch_ix, ...] = out_class_preds_single

        if return_nc_out_preds:
            return out_detection, out_classification, nc_out_preds
        else:
            return out_detection, out_classification
    else:
        return out_detection
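# Hedged usage sketch (illustrative names). The outputs passed in are raw logits from the
# model branches; softmax is applied internally:
#
#     det_out, class_out = post_process_batch_hovernet(
#         [out["nuclei_binary_map"], out["hv_map"], out["nuclei_type_map"]], n_classes=6
#     )
#     # det_out: (B, H, W) instance labels; class_out: (B, n_classes, H, W), where channel
#     # c holds the instance labels of nuclei voted into class c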
def extract_nuclei_info(
    pred_inst: np.ndarray[Any, Any], pred_type: np.ndarray[Any, Any] | None = None
) -> dict[int, dict[str, Any]]:
    """
    Extract the centroid, type, and type probability for each cell in the nuclei mask.

    This function is adapted from the HoVer-Net post-processing implementation.
    Original source:
    https://github.com/vqdang/hover_net/blob/14c5996fa61ede4691e87905775e8f4243da6a62/models/hovernet/post_proc.py#L94

    Args:
        pred_inst (np.ndarray): Instance mask of nuclei.
        pred_type (np.ndarray): Optional type mask for nuclei.

    Returns:
        dict: Information about each cell, including centroid, type, and probability.
    """
    inst_id_list = np.unique(pred_inst)[1:]  # exclude background
    inst_info_dict: dict[int, dict[str, Any]] = {}
    for inst_id in inst_id_list:
        inst_map = pred_inst == inst_id
        inst_moment = cv2.moments(inst_map.astype(np.uint8))
        if inst_moment["m00"] == 0:
            continue
        inst_centroid = [
            inst_moment["m10"] / inst_moment["m00"],
            inst_moment["m01"] / inst_moment["m00"],
        ]
        inst_info: dict[str, Any] = {
            "centroid": inst_centroid,
            "type": None,
            "prob": None,
        }
        inst_info_dict[inst_id] = inst_info

        if pred_type is not None:
            inst_type = pred_type[inst_map]
            type_list, type_pixels = np.unique(inst_type, return_counts=True)
            type_list = sorted(
                zip(type_list, type_pixels), key=lambda x: x[1], reverse=True
            )
            inst_type = type_list[0][0]
            # if the majority vote is background, fall back to the next most common type
            if inst_type == 0 and len(type_list) > 1:
                inst_type = type_list[1][0]
            type_dict = {v[0]: v[1] for v in type_list}
            type_prob = type_dict[inst_type] / (np.sum(inst_map) + 1.0e-6)
            inst_info_dict[inst_id]["type"] = int(inst_type)
            inst_info_dict[inst_id]["prob"] = float(type_prob)
    return inst_info_dict
def group_centroids_by_type(
    cell_dict: dict[int, dict[str, Any]], prob_threshold: float
):
    """
    Group centroids by cell type for cells above a certain probability threshold.

    Args:
        cell_dict (dict): Dictionary containing cell information, as returned by
            ``extract_nuclei_info()``.
        prob_threshold (float): Minimum probability threshold for a cell to be considered.

    Returns:
        dict: A dictionary with cell types as keys and lists of centroids as values.
    """
    grouped_centroids: dict[Any, Any] = {}
    for cell_info in cell_dict.values():
        if cell_info["prob"] >= prob_threshold:
            cell_type = cell_info["type"]
            if cell_type not in grouped_centroids:
                grouped_centroids[cell_type] = []
            grouped_centroids[cell_type].append(cell_info["centroid"])
    return grouped_centroids
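# Hedged end-to-end sketch combining the helpers above (illustrative):
#
#     det_out, class_out, nc_preds = post_process_batch_hovernet(
#         [out["nuclei_binary_map"], out["hv_map"], out["nuclei_type_map"]],
#         n_classes=6,
#         return_nc_out_preds=True,
#     )
#     info = extract_nuclei_info(det_out[0], pred_type=nc_preds[0])
#     by_type = group_centroids_by_type(info, prob_threshold=0.5)
#     # by_type maps each cell-type id to a list of [x, y] centroids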