# -*- coding: utf-8 -*-
# CellViT Model Implementation
#
# References:
# CellViT: Vision Transformers for precise cell segmentation and classification
# Fabian Hörst et al., Medical Image Analysis, 2024
# DOI: https://doi.org/10.1016/j.media.2024.103143
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
from functools import partial
from typing import Any, Tuple, Union, Literal
from pathlib import Path
from .utils.post_proc_hv import DetectionCellPostProcessorHV
from .utils.cellvit_blocks import Conv2DBlock, Deconv2DBlock, ViTCellViT, ViTCellViTDeit
class CellViT(nn.Module):
    """CellViT model for cell segmentation. U-Net-like network with a vision transformer as backbone encoder.

    Skip connections are shared between branches, but each branch has a distinct decoder.
    The model has multiple branches:
        * tissue_types: Tissue prediction based on global class token
        * nuclei_binary_map: Binary nuclei prediction
        * hv_map: HV-prediction to separate isolated instances
        * nuclei_type_map: Nuclei instance-prediction
        * [Optional, if regression loss]:
            * regression_map: Regression map for binary prediction

    Args:
        num_nuclei_classes (int): Number of nuclei classes (including background)
        num_tissue_classes (int): Number of tissue classes
        embed_dim (int): Embedding dimension of backbone ViT
        input_channels (int): Number of input channels
        depth (int): Depth of the backbone ViT
        num_heads (int): Number of heads of the backbone ViT
        extract_layers (list[int]): List of transformer blocks whose outputs should be returned
            in addition to the tokens. First block starts with 1, and maximum is N=depth.
            Used for skip connections. Exactly 4 skip connections need to be returned.
        mlp_ratio (float, optional): MLP ratio for hidden MLP dimension of backbone ViT. Defaults to 4.
        qkv_bias (bool, optional): If bias should be used for query (q), key (k), and value (v) in backbone ViT. Defaults to True.
        drop_rate (float, optional): Dropout in MLP. Defaults to 0.
        attn_drop_rate (float, optional): Dropout for attention layer in backbone ViT. Defaults to 0.
        drop_path_rate (float, optional): Dropout for skip connection. Defaults to 0.
        regression_loss (bool, optional): Use regressive loss for predicting vector components.
            Adds two additional channels to the binary decoder, but returns it as its own entry
            in the output dict. Defaults to False.
    """

    def __init__(
        self,
        num_nuclei_classes: int,
        num_tissue_classes: int,
        embed_dim: int,
        input_channels: int,
        depth: int,
        num_heads: int,
        extract_layers: list[int],
        mlp_ratio: float = 4,
        qkv_bias: bool = True,
        drop_rate: float = 0,
        attn_drop_rate: float = 0,
        drop_path_rate: float = 0,
        regression_loss: bool = False,
    ):
        # For simplicity, we will assume that extract layers must have a length of 4
        super().__init__()  # type: ignore[no-redef]
        assert len(extract_layers) == 4, "Please provide 4 layers for skip connections"

        # ViT tokenization patch size; forward() requires H and W divisible by this.
        self.patch_size = 16
        self.num_tissue_classes = num_tissue_classes
        self.num_nuclei_classes = num_nuclei_classes
        self.embed_dim = embed_dim
        self.input_channels = input_channels
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.extract_layers = extract_layers
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate

        # Shared ViT backbone; also produces the tissue classification logits
        # (num_classes=num_tissue_classes) from the class token.
        self.encoder = ViTCellViT(
            patch_size=self.patch_size,
            num_classes=self.num_tissue_classes,
            embed_dim=self.embed_dim,
            depth=self.depth,
            num_heads=self.num_heads,
            mlp_ratio=self.mlp_ratio,
            qkv_bias=self.qkv_bias,
            norm_layer=partial(nn.LayerNorm, eps=1e-6),  # type: ignore[no-redef]
            extract_layers=self.extract_layers,
            drop_rate=drop_rate,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
        )

        # Narrower skip/bottleneck widths for small backbones (embed_dim < 512).
        if self.embed_dim < 512:
            self.skip_dim_11 = 256
            self.skip_dim_12 = 128
            self.bottleneck_dim = 312
        else:
            self.skip_dim_11 = 512
            self.skip_dim_12 = 256
            self.bottleneck_dim = 512

        # version with shared skip_connections: decoder0..decoder3 process the
        # skip inputs once and their outputs are consumed by every branch.
        self.decoder0 = nn.Sequential(
            Conv2DBlock(3, 32, 3, dropout=self.drop_rate),
            Conv2DBlock(32, 64, 3, dropout=self.drop_rate),
        )  # skip connection after positional encoding, shape should be H, W, 64
        self.decoder1 = nn.Sequential(
            Deconv2DBlock(self.embed_dim, self.skip_dim_11, dropout=self.drop_rate),
            Deconv2DBlock(self.skip_dim_11, self.skip_dim_12, dropout=self.drop_rate),
            Deconv2DBlock(self.skip_dim_12, 128, dropout=self.drop_rate),
        )  # skip connection 1
        self.decoder2 = nn.Sequential(
            Deconv2DBlock(self.embed_dim, self.skip_dim_11, dropout=self.drop_rate),
            Deconv2DBlock(self.skip_dim_11, 256, dropout=self.drop_rate),
        )  # skip connection 2
        self.decoder3 = nn.Sequential(
            Deconv2DBlock(self.embed_dim, self.bottleneck_dim, dropout=self.drop_rate)
        )  # skip connection 3

        # With regression loss, the binary branch emits 2 extra channels that
        # forward() splits off into "regression_map".
        self.regression_loss = regression_loss
        offset_branches = 0
        if self.regression_loss:
            offset_branches = 2
        # NOTE(review): key "nuclei_type_maps" differs from the output-dict key
        # "nuclei_type_map" used in forward() — confirm which spelling consumers expect.
        self.branches_output = {
            "nuclei_binary_map": 2 + offset_branches,
            "hv_map": 2,
            "nuclei_type_maps": self.num_nuclei_classes,
        }

        self.nuclei_binary_map_decoder = self.create_upsampling_branch(
            2 + offset_branches
        )  # todo: adapt for helper loss
        self.hv_map_decoder = self.create_upsampling_branch(
            2
        )  # todo: adapt for helper loss
        self.nuclei_type_maps_decoder = self.create_upsampling_branch(
            self.num_nuclei_classes
        )

    def forward(self, x: torch.Tensor, retrieve_tokens: bool = False) -> dict[str, Any]:
        """Forward pass.

        Args:
            x (torch.Tensor): Images in BCHW style
            retrieve_tokens (bool, optional): If tokens of ViT should be returned as well. Defaults to False.

        Returns:
            dict: Output for all branches:
                * tissue_types: Raw tissue type prediction. Shape: (B, num_tissue_classes)
                * nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)
                * hv_map: Binary HV Map predictions. Shape: (B, 2, H, W)
                * nuclei_type_map: Raw binary nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)
                * [Optional, if retrieve tokens]: tokens
                * [Optional, if regression loss]:
                    * regression_map: Regression map for binary prediction. Shape: (B, 2, H, W)
        """
        assert (
            x.shape[-2] % self.patch_size == 0
        ), "Img must have a shape of that is divisible by patch_size (token_size)"
        assert (
            x.shape[-1] % self.patch_size == 0
        ), "Img must have a shape of that is divisible by patch_size (token_size)"

        out_dict: dict[str, Any] = {}

        classifier_logits, _, z = self.encoder(x)
        out_dict["tissue_types"] = classifier_logits

        # z0 is the raw input image; z1..z4 are the extracted transformer outputs.
        z0, z1, z2, z3, z4 = x, *z

        # performing reshape for the convolutional layers and upsampling (restore spatial dimension)
        patch_dim = [int(d / self.patch_size) for d in [x.shape[-2], x.shape[-1]]]
        # Drop the class token ([:, 1:, :]), move channels first, and restore
        # the (H/patch, W/patch) spatial grid.
        z4 = z4[:, 1:, :].transpose(-1, -2).view(-1, self.embed_dim, *patch_dim)
        z3 = z3[:, 1:, :].transpose(-1, -2).view(-1, self.embed_dim, *patch_dim)
        z2 = z2[:, 1:, :].transpose(-1, -2).view(-1, self.embed_dim, *patch_dim)
        z1 = z1[:, 1:, :].transpose(-1, -2).view(-1, self.embed_dim, *patch_dim)

        if self.regression_loss:
            nb_map = self._forward_upsample(
                z0, z1, z2, z3, z4, self.nuclei_binary_map_decoder
            )
            # First 2 channels: binary prediction; remaining channels: regression map.
            out_dict["nuclei_binary_map"] = nb_map[:, :2, :, :]
            out_dict["regression_map"] = nb_map[:, 2:, :, :]
        else:
            out_dict["nuclei_binary_map"] = self._forward_upsample(
                z0, z1, z2, z3, z4, self.nuclei_binary_map_decoder
            )
        out_dict["hv_map"] = self._forward_upsample(
            z0, z1, z2, z3, z4, self.hv_map_decoder
        )
        out_dict["nuclei_type_map"] = self._forward_upsample(
            z0, z1, z2, z3, z4, self.nuclei_type_maps_decoder
        )

        if retrieve_tokens:
            out_dict["tokens"] = z4

        return out_dict

    def _forward_upsample(
        self,
        z0: torch.Tensor,
        z1: torch.Tensor,
        z2: torch.Tensor,
        z3: torch.Tensor,
        z4: torch.Tensor,
        branch_decoder: nn.Sequential,
    ) -> torch.Tensor:
        """Forward upsample branch.

        Runs one branch decoder over the bottleneck, fusing the shared skip
        features (decoder0..decoder3) at each resolution via channel concatenation.

        Args:
            z0 (torch.Tensor): Highest skip
            z1 (torch.Tensor): 1. Skip
            z2 (torch.Tensor): 2. Skip
            z3 (torch.Tensor): 3. Skip
            z4 (torch.Tensor): Bottleneck
            branch_decoder (nn.Sequential): Branch decoder network

        Returns:
            torch.Tensor: Branch Output
        """
        # Access modules by index in the sequential container
        b4 = branch_decoder[0](z4)  # bottleneck_upsampler
        b3 = self.decoder3(z3)
        b3 = branch_decoder[1](torch.cat([b3, b4], dim=1))  # decoder3_upsampler
        b2 = self.decoder2(z2)
        b2 = branch_decoder[2](torch.cat([b2, b3], dim=1))  # decoder2_upsampler
        b1 = self.decoder1(z1)
        b1 = branch_decoder[3](torch.cat([b1, b2], dim=1))  # decoder1_upsampler
        b0 = self.decoder0(z0)
        branch_output = branch_decoder[4](torch.cat([b0, b1], dim=1))  # decoder0_header
        return branch_output

    def create_upsampling_branch(self, num_classes: int) -> nn.Sequential:
        """Create Upsampling branch.

        Builds a 5-stage decoder (bottleneck + 3 upsampling stages + 1x1-conv
        header). Each stage after the bottleneck expects its input concatenated
        with the corresponding shared skip features (hence the `* 2` input widths).

        Args:
            num_classes (int): Number of output classes

        Returns:
            nn.Sequential: Upsampling path
        """
        # 2x upsampling from the ViT bottleneck to the coarsest decoder resolution.
        bottleneck_upsampler = nn.ConvTranspose2d(
            in_channels=self.embed_dim,
            out_channels=self.bottleneck_dim,
            kernel_size=2,
            stride=2,
            padding=0,
            output_padding=0,
        )
        decoder3_upsampler = nn.Sequential(
            Conv2DBlock(
                self.bottleneck_dim * 2, self.bottleneck_dim, dropout=self.drop_rate
            ),
            Conv2DBlock(
                self.bottleneck_dim, self.bottleneck_dim, dropout=self.drop_rate
            ),
            Conv2DBlock(
                self.bottleneck_dim, self.bottleneck_dim, dropout=self.drop_rate
            ),
            nn.ConvTranspose2d(
                in_channels=self.bottleneck_dim,
                out_channels=256,
                kernel_size=2,
                stride=2,
                padding=0,
                output_padding=0,
            ),
        )
        decoder2_upsampler = nn.Sequential(
            Conv2DBlock(256 * 2, 256, dropout=self.drop_rate),
            Conv2DBlock(256, 256, dropout=self.drop_rate),
            nn.ConvTranspose2d(
                in_channels=256,
                out_channels=128,
                kernel_size=2,
                stride=2,
                padding=0,
                output_padding=0,
            ),
        )
        decoder1_upsampler = nn.Sequential(
            Conv2DBlock(128 * 2, 128, dropout=self.drop_rate),
            Conv2DBlock(128, 128, dropout=self.drop_rate),
            nn.ConvTranspose2d(
                in_channels=128,
                out_channels=64,
                kernel_size=2,
                stride=2,
                padding=0,
                output_padding=0,
            ),
        )
        # Final 1x1 convolution projects to the requested number of classes.
        decoder0_header = nn.Sequential(
            Conv2DBlock(64 * 2, 64, dropout=self.drop_rate),
            Conv2DBlock(64, 64, dropout=self.drop_rate),
            nn.Conv2d(
                in_channels=64,
                out_channels=num_classes,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        )
        # Named stages; _forward_upsample indexes them positionally 0..4.
        decoder = nn.Sequential(
            OrderedDict(
                [
                    ("bottleneck_upsampler", bottleneck_upsampler),
                    ("decoder3_upsampler", decoder3_upsampler),
                    ("decoder2_upsampler", decoder2_upsampler),
                    ("decoder1_upsampler", decoder1_upsampler),
                    ("decoder0_header", decoder0_header),
                ]
            )
        )
        return decoder

    def calculate_instance_map(
        self, predictions: dict[str, torch.Tensor], magnification: int | float = 40
    ) -> Tuple[torch.Tensor, list[dict[np.int32, dict[str, Any]]]]:
        """Calculate Instance Map from network predictions (after Softmax output).

        Args:
            predictions (dict): Dictionary with the following required keys:
                * nuclei_binary_map: Binary Nucleus Predictions. Shape: (B, 2, H, W)
                * nuclei_type_map: Type prediction of nuclei. Shape: (B, self.num_nuclei_classes, H, W)
                * hv_map: Horizontal-Vertical nuclei mapping. Shape: (B, 2, H, W)
            magnification (int | float, optional): Which magnification the data has
                (typically 20 or 40). Defaults to 40.

        Returns:
            Tuple[torch.Tensor, List[dict]]:
                * torch.Tensor: Instance map. Each Instance has own integer. Shape: (B, H, W)
                * List of dictionaries. Each List entry is one image. Each dict contains another dict for each detected nucleus.
                  For each nucleus, the following information are returned: "bbox", "centroid", "contour", "type_prob", "type"
        """
        # reshape to B, H, W, C (shallow copy so the caller's dict values are untouched)
        predictions_ = predictions.copy()
        predictions_["nuclei_type_map"] = predictions_["nuclei_type_map"].permute(
            0, 2, 3, 1
        )
        predictions_["nuclei_binary_map"] = predictions_["nuclei_binary_map"].permute(
            0, 2, 3, 1
        )
        predictions_["hv_map"] = predictions_["hv_map"].permute(0, 2, 3, 1)

        cell_post_processor = DetectionCellPostProcessorHV(
            nr_types=self.num_nuclei_classes, magnification=magnification, gt=False
        )
        instance_preds: list[np.ndarray[Any, Any]] = []
        type_preds: list[dict[np.int32, dict[str, Any]]] = []

        # Post-process per image: stack type argmax, binary argmax, and the
        # 2-channel HV map into an (H, W, 4) array for the HoVer-Net-style
        # post-processor.
        for i in range(predictions_["nuclei_binary_map"].shape[0]):
            pred_map: np.ndarray[Any, Any] = np.concatenate(  # type: ignore[no-redef]
                [
                    torch.argmax(predictions_["nuclei_type_map"], dim=-1)[i]
                    .detach()
                    .cpu()[..., None],
                    torch.argmax(predictions_["nuclei_binary_map"], dim=-1)[i]
                    .detach()
                    .cpu()[..., None],
                    predictions_["hv_map"][i].detach().cpu(),
                ],
                axis=-1,
            )
            instance_pred = cell_post_processor.post_process_cell_segmentation(pred_map)
            instance_preds.append(instance_pred[0])
            type_preds.append(instance_pred[1])

        return torch.Tensor(np.stack(instance_preds)), type_preds

    def generate_instance_nuclei_map(
        self, instance_maps: torch.Tensor, type_preds: list[dict[np.int32, dict[str, Any]]]
    ) -> torch.Tensor:
        """Convert instance map (binary) to nuclei type instance map.

        Args:
            instance_maps (torch.Tensor): Binary instance map, each instance has own integer. Shape: (B, H, W)
            type_preds (List[dict]): List (len=B) of dictionary with instance type information (compare post_process_hovernet function for more details)

        Returns:
            torch.Tensor: Nuclei type instance map. Shape: (B, self.num_nuclei_classes, H, W)
        """
        batch_size, h, w = instance_maps.shape
        # Build channels-last first, permute to channels-first at the end.
        instance_type_nuclei_maps = torch.zeros(
            (batch_size, h, w, self.num_nuclei_classes)
        )
        for i in range(batch_size):
            instance_type_nuclei_map = torch.zeros((h, w, self.num_nuclei_classes))
            instance_map = instance_maps[i]
            type_pred = type_preds[i]
            for nuclei, spec in type_pred.items():
                nuclei_type = spec["type"]
                # Write the instance id into the channel of its predicted type,
                # on the pixels belonging to this instance.
                instance_type_nuclei_map[:, :, nuclei_type][
                    instance_map == nuclei
                ] = int(nuclei)
            instance_type_nuclei_maps[i, :, :, :] = instance_type_nuclei_map

        instance_type_nuclei_maps = instance_type_nuclei_maps.permute(0, 3, 1, 2)
        # NOTE(review): instance_type_nuclei_maps is already a Tensor here; the
        # torch.Tensor(...) wrapper looks redundant — confirm before removing.
        return torch.Tensor(instance_type_nuclei_maps)

    def freeze_encoder(self):
        """Freeze encoder to not train it (the classification head stays trainable)."""
        for layer_name, p in self.encoder.named_parameters():
            if layer_name.split(".")[0] != "head":  # do not freeze head
                p.requires_grad = False

    def unfreeze_encoder(self):
        """Unfreeze encoder to train the whole model."""
        for p in self.encoder.parameters():
            p.requires_grad = True
class CellViTSAM(CellViT):
    """CellViT with SAM backbone settings.

    Skip connections are shared between branches, but each branch has a distinct decoder.

    Args:
        model_path (Union[Path, str] | None): Path to pretrained SAM model. May be None;
            weights are loaded explicitly via :meth:`load_pretrained_encoder`.
        num_nuclei_classes (int): Number of nuclei classes (including background)
        num_tissue_classes (int): Number of tissue classes
        vit_structure (Literal["SAM-B", "SAM-L", "SAM-H"]): SAM model type
        drop_rate (float, optional): Dropout in MLP. Defaults to 0.
        regression_loss (bool, optional): Use regressive loss for predicting vector components.
            Adds two additional channels to the binary decoder, but returns it as its own entry
            in the output dict. Defaults to False.

    Raises:
        NotImplementedError: Unknown SAM configuration
    """

    def __init__(
        self,
        model_path: Union[Path, str] | None,
        num_nuclei_classes: int,
        num_tissue_classes: int,
        vit_structure: Literal["SAM-B", "SAM-L", "SAM-H"],
        drop_rate: float = 0,
        regression_loss: bool = False,
    ):
        # Set backbone hyperparameters (embed_dim, depth, num_heads, attention
        # indices, extract layers) BEFORE the parent constructor runs, since it
        # reads them from self.
        if vit_structure.upper() == "SAM-B":
            self.init_vit_b()
        elif vit_structure.upper() == "SAM-L":
            self.init_vit_l()
        elif vit_structure.upper() == "SAM-H":
            self.init_vit_h()
        else:
            raise NotImplementedError("Unknown ViT-SAM backbone structure")

        self.input_channels = 3  # RGB
        self.mlp_ratio = 4
        self.qkv_bias = True
        self.num_nuclei_classes = num_nuclei_classes
        self.model_path = model_path

        super().__init__(
            num_nuclei_classes=num_nuclei_classes,
            num_tissue_classes=num_tissue_classes,
            embed_dim=self.embed_dim,
            input_channels=self.input_channels,
            depth=self.depth,
            num_heads=self.num_heads,
            extract_layers=self.extract_layers,
            mlp_ratio=self.mlp_ratio,
            qkv_bias=self.qkv_bias,
            drop_rate=drop_rate,
            regression_loss=regression_loss,
        )

        self.prompt_embed_dim = 256

        # Replace the default ViT encoder built by the parent constructor with
        # the SAM image encoder. Fix: use the configured self.mlp_ratio /
        # self.qkv_bias (set above) instead of re-hard-coding 4 / True, so the
        # encoder always matches the attributes stored on the instance.
        self.encoder = ViTCellViTDeit(
            extract_layers=self.extract_layers,
            depth=self.depth,
            embed_dim=self.embed_dim,
            mlp_ratio=self.mlp_ratio,
            norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),  # type: ignore[no-redef]
            num_heads=self.num_heads,
            qkv_bias=self.qkv_bias,
            use_rel_pos=True,
            global_attn_indexes=tuple(self.encoder_global_attn_indexes),
            window_size=14,
            out_chans=self.prompt_embed_dim,
        )

        # SAM's encoder emits prompt_embed_dim features for the class logits,
        # so a dedicated head maps them to tissue classes (identity if none).
        self.classifier_head = (
            nn.Linear(self.prompt_embed_dim, num_tissue_classes)
            if num_tissue_classes > 0
            else nn.Identity()
        )

    def load_pretrained_encoder(self, model_path: Path | str):
        """Load pretrained SAM encoder from provided path.

        Args:
            model_path (Path | str): Path to SAM model checkpoint (state dict).
        """
        state_dict = torch.load(str(model_path), map_location="cpu")
        image_encoder = self.encoder
        # strict=False: the SAM checkpoint does not cover CellViT-specific keys.
        msg = image_encoder.load_state_dict(state_dict, strict=False)
        print(f"Loading checkpoint: {msg}")
        self.encoder = image_encoder

    def forward(self, x: torch.Tensor, retrieve_tokens: bool = False):
        """Forward pass.

        Args:
            x (torch.Tensor): Images in BCHW style
            retrieve_tokens (bool, optional): If tokens of ViT should be returned as well. Defaults to False.

        Returns:
            dict: Output for all branches:
                * tissue_types: Raw tissue type prediction. Shape: (B, num_tissue_classes)
                * nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)
                * hv_map: Binary HV Map predictions. Shape: (B, 2, H, W)
                * nuclei_type_map: Raw binary nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)
                * [Optional, if retrieve tokens]: tokens
                * [Optional, if regression loss]:
                    * regression_map: Regression map for binary prediction. Shape: (B, 2, H, W)
        """
        # Fix: corrected typos in the assertion messages ("divisble by
        # patch_soze" -> "divisible by patch_size"), matching the parent class.
        assert (
            x.shape[-2] % self.patch_size == 0
        ), "Img must have a shape of that is divisible by patch_size (token_size)"
        assert (
            x.shape[-1] % self.patch_size == 0
        ), "Img must have a shape of that is divisible by patch_size (token_size)"

        out_dict: dict[str, torch.Tensor] = {}

        classifier_logits, _, z = self.encoder(x)
        out_dict["tissue_types"] = self.classifier_head(classifier_logits)

        # z0 is the raw input image; z1..z4 are the extracted encoder outputs.
        z0, z1, z2, z3, z4 = x, *z

        # performing reshape for the convolutional layers and upsampling
        # (restore spatial dimension): SAM encoder outputs are channels-last,
        # so permute BHWC -> BCHW (no class token to strip, unlike CellViT).
        z4 = z4.permute(0, 3, 1, 2)
        z3 = z3.permute(0, 3, 1, 2)
        z2 = z2.permute(0, 3, 1, 2)
        z1 = z1.permute(0, 3, 1, 2)

        if self.regression_loss:
            nb_map = self._forward_upsample(
                z0, z1, z2, z3, z4, self.nuclei_binary_map_decoder
            )
            # First 2 channels: binary prediction; remaining channels: regression map.
            out_dict["nuclei_binary_map"] = nb_map[:, :2, :, :]
            out_dict["regression_map"] = nb_map[:, 2:, :, :]
        else:
            out_dict["nuclei_binary_map"] = self._forward_upsample(
                z0, z1, z2, z3, z4, self.nuclei_binary_map_decoder
            )
        out_dict["hv_map"] = self._forward_upsample(
            z0, z1, z2, z3, z4, self.hv_map_decoder
        )
        out_dict["nuclei_type_map"] = self._forward_upsample(
            z0, z1, z2, z3, z4, self.nuclei_type_maps_decoder
        )

        if retrieve_tokens:
            out_dict["tokens"] = z4

        return out_dict

    def init_vit_b(self):
        """Configure backbone hyperparameters for SAM-B (ViT-Base)."""
        self.embed_dim = 768
        self.depth = 12
        self.num_heads = 12
        self.encoder_global_attn_indexes = [2, 5, 8, 11]
        self.extract_layers = [3, 6, 9, 12]

    def init_vit_l(self):
        """Configure backbone hyperparameters for SAM-L (ViT-Large)."""
        self.embed_dim = 1024
        self.depth = 24
        self.num_heads = 16
        self.encoder_global_attn_indexes = [5, 11, 17, 23]
        self.extract_layers = [6, 12, 18, 24]

    def init_vit_h(self):
        """Configure backbone hyperparameters for SAM-H (ViT-Huge)."""
        self.embed_dim = 1280
        self.depth = 32
        self.num_heads = 16
        self.encoder_global_attn_indexes = [7, 15, 23, 31]
        self.extract_layers = [8, 16, 24, 32]