cellmil.models.segmentation.cellvit

Classes

CellViT(num_nuclei_classes, ...[, ...])

CellViT model for cell segmentation.

CellViTSAM(model_path, num_nuclei_classes, ...)

CellViT with SAM backbone settings

class cellmil.models.segmentation.cellvit.CellViT(num_nuclei_classes: int, num_tissue_classes: int, embed_dim: int, input_channels: int, depth: int, num_heads: int, extract_layers: list[int], mlp_ratio: float = 4, qkv_bias: bool = True, drop_rate: float = 0, attn_drop_rate: float = 0, drop_path_rate: float = 0, regression_loss: bool = False)[source]

Bases: Module

CellViT model for cell segmentation: a U-Net-like network with a vision transformer as backbone encoder

Skip connections are shared between the branches, but each branch has its own decoder

The model has multiple branches:
  • tissue_types: Tissue prediction based on the global class token

  • nuclei_binary_map: Binary nuclei prediction

  • hv_map: HV prediction to separate touching instances

  • nuclei_type_map: Per-pixel nuclei type prediction

  • regression_map (only if regression_loss is set): Regression map for the binary prediction

Parameters:
  • num_nuclei_classes (int) – Number of nuclei classes (including background)

  • num_tissue_classes (int) – Number of tissue classes

  • embed_dim (int) – Embedding dimension of backbone ViT

  • input_channels (int) – Number of input channels

  • depth (int) – Depth of the backbone ViT

  • num_heads (int) – Number of heads of the backbone ViT

  • extract_layers (List[int]) – Transformer blocks whose outputs are returned in addition to the tokens and used for the skip connections. The first block has index 1 and the maximum is N=depth. At least 4 skip connections need to be returned.

  • mlp_ratio (float, optional) – MLP ratio for hidden MLP dimension of backbone ViT. Defaults to 4.

  • qkv_bias (bool, optional) – If bias should be used for query (q), key (k), and value (v) in backbone ViT. Defaults to True.

  • drop_rate (float, optional) – Dropout in MLP. Defaults to 0.

  • attn_drop_rate (float, optional) – Dropout for attention layer in backbone ViT. Defaults to 0.

  • drop_path_rate (float, optional) – Drop path rate (stochastic depth) for residual connections in the backbone ViT. Defaults to 0.

  • regression_loss (bool, optional) – Use a regression loss for predicting the vector components. Adds two additional channels to the binary decoder and returns the result as its own entry in the output dict. Defaults to False.
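The constructor contract above can be sketched with an illustrative configuration; all concrete values below (a ViT-S-sized backbone, 6 nuclei classes, 19 tissue classes) are hypothetical choices, not library defaults:

```python
# Hypothetical hyperparameters for a small CellViT; the values only
# illustrate the documented constraints, they are not library defaults.
cfg = dict(
    num_nuclei_classes=6,          # e.g. 5 nuclei types + background
    num_tissue_classes=19,
    embed_dim=384,
    input_channels=3,
    depth=12,
    num_heads=6,
    extract_layers=[3, 6, 9, 12],  # 1-based block indices used as skips
    drop_rate=0.0,
)

# Documented contract: at least 4 skip connections, indices within [1, depth].
assert len(cfg["extract_layers"]) >= 4
assert all(1 <= layer <= cfg["depth"] for layer in cfg["extract_layers"])

# model = CellViT(**cfg)  # assuming cellmil is importable
```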

__init__(num_nuclei_classes: int, num_tissue_classes: int, embed_dim: int, input_channels: int, depth: int, num_heads: int, extract_layers: list[int], mlp_ratio: float = 4, qkv_bias: bool = True, drop_rate: float = 0, attn_drop_rate: float = 0, drop_path_rate: float = 0, regression_loss: bool = False)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(x: Tensor, retrieve_tokens: bool = False) dict[str, Any][source]

Forward pass

Parameters:
  • x (torch.Tensor) – Images in BCHW style

  • retrieve_tokens (bool, optional) – If tokens of ViT should be returned as well. Defaults to False.

Returns:

Output for all branches:
  • tissue_types: Raw tissue type prediction. Shape: (B, num_tissue_classes)

  • nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)

  • hv_map: Binary HV Map predictions. Shape: (B, 2, H, W)

  • nuclei_type_map: Raw nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)

  • tokens (only if retrieve_tokens is set): ViT output tokens

  • regression_map (only if regression_loss is set): Regression map for the binary prediction. Shape: (B, 2, H, W)

Return type:

dict
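The returned dictionary can be mocked with NumPy arrays of the documented shapes (the sizes B=2, H=W=256, 6 nuclei classes, and 19 tissue classes are hypothetical; no model is run here):

```python
import numpy as np

B, H, W = 2, 256, 256
num_nuclei_classes, num_tissue_classes = 6, 19

# Shape-only mock of the forward() output described above.
out = {
    "tissue_types": np.zeros((B, num_tissue_classes), dtype=np.float32),
    "nuclei_binary_map": np.zeros((B, 2, H, W), dtype=np.float32),
    "hv_map": np.zeros((B, 2, H, W), dtype=np.float32),
    "nuclei_type_map": np.zeros((B, num_nuclei_classes, H, W), dtype=np.float32),
}

# A downstream consumer might, e.g., take the tissue argmax per image:
tissue_pred = out["tissue_types"].argmax(axis=1)  # shape (B,)
```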

_forward_upsample(z0: Tensor, z1: Tensor, z2: Tensor, z3: Tensor, z4: Tensor, branch_decoder: Sequential) Tensor[source]

Forward upsample branch

Parameters:
  • z0, z1, z2, z3, z4 (torch.Tensor) – Encoder feature maps used for the skip connections, ordered from shallowest to deepest

  • branch_decoder (nn.Sequential) – Decoder branch that performs the upsampling

Returns:

Branch Output

Return type:

torch.Tensor

create_upsampling_branch(num_classes: int) Sequential[source]

Create Upsampling branch

Parameters:

num_classes (int) – Number of output classes

Returns:

Upsampling path

Return type:

nn.Sequential

calculate_instance_map(predictions: dict[str, torch.Tensor], magnification: int | float = 40) Tuple[Tensor, list[dict[numpy.int32, dict[str, Any]]]][source]

Calculate the instance map from network predictions (after softmax has been applied)

Parameters:
  • predictions (dict) – Dictionary with the following required keys:
      - nuclei_binary_map: Binary nucleus predictions. Shape: (B, 2, H, W)
      - nuclei_type_map: Type prediction of nuclei. Shape: (B, self.num_nuclei_classes, H, W)
      - hv_map: Horizontal-vertical nuclei mapping. Shape: (B, 2, H, W)

  • magnification (Literal[20, 40], optional) – Which magnification the data has. Defaults to 40.

Returns:

  • torch.Tensor: Instance map; each instance has its own integer. Shape: (B, H, W)

  • List of dictionaries, one entry per image. Each dict contains another dict for each detected nucleus.

    For each nucleus, the following information is returned: “bbox”, “centroid”, “contour”, “type_prob”, “type”

Return type:

Tuple[torch.Tensor, List[dict]]
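Since calculate_instance_map expects softmax outputs, the preparation step can be sketched as follows. NumPy stands in for the torch tensors, and the sizes (B=1, H=W=8, 6 nuclei classes) are illustrative only:

```python
import numpy as np

def channel_softmax(x: np.ndarray, axis: int = 1) -> np.ndarray:
    """Numerically stable softmax over the class/channel axis."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
raw = {
    "nuclei_binary_map": rng.standard_normal((1, 2, 8, 8)),
    "nuclei_type_map": rng.standard_normal((1, 6, 8, 8)),
    "hv_map": rng.standard_normal((1, 2, 8, 8)),  # regression output: no softmax
}

# Softmax the classification heads; the HV map is passed through unchanged.
predictions = {
    "nuclei_binary_map": channel_softmax(raw["nuclei_binary_map"]),
    "nuclei_type_map": channel_softmax(raw["nuclei_type_map"]),
    "hv_map": raw["hv_map"],
}
# instance_map, instance_types = model.calculate_instance_map(predictions)
```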

generate_instance_nuclei_map(instance_maps: Tensor, type_preds: list[dict[numpy.int32, dict[str, Any]]]) Tensor[source]

Convert instance map (binary) to nuclei type instance map

Parameters:
  • instance_maps (torch.Tensor) – Binary instance map; each instance has its own integer. Shape: (B, H, W)

  • type_preds (List[dict]) – List (len=B) of dictionaries with instance type information (see the post_process_hovernet function for details)

Returns:

Nuclei type instance map. Shape: (B, self.num_nuclei_classes, H, W)

Return type:

torch.Tensor
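The conversion can be illustrated with a simplified NumPy re-implementation. This is a sketch of the idea only; the library method operates on torch tensors and its exact write semantics may differ:

```python
import numpy as np

def instance_to_type_map(instance_maps, type_preds, num_nuclei_classes):
    """Sketch: scatter each instance's pixels into the channel of its
    predicted type, keeping the instance integer as the pixel value."""
    B, H, W = instance_maps.shape
    out = np.zeros((B, num_nuclei_classes, H, W), dtype=instance_maps.dtype)
    for b in range(B):
        for inst_id, info in type_preds[b].items():
            mask = instance_maps[b] == inst_id
            out[b, info["type"]][mask] = inst_id
    return out

# Toy input: one 4x4 image with a single instance (id 1) of type 2.
inst = np.zeros((1, 4, 4), dtype=np.int32)
inst[0, :2, :2] = 1
types = [{1: {"type": 2}}]
type_map = instance_to_type_map(inst, types, num_nuclei_classes=6)
```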

freeze_encoder()[source]

Freeze the encoder so it is not trained

unfreeze_encoder()[source]

Unfreeze the encoder to train the whole model

class cellmil.models.segmentation.cellvit.CellViTSAM(model_path: Optional[Union[Path, str]], num_nuclei_classes: int, num_tissue_classes: int, vit_structure: Literal['SAM-B', 'SAM-L', 'SAM-H'], drop_rate: float = 0, regression_loss: bool = False)[source]

Bases: CellViT

CellViT with SAM backbone settings

Skip connections are shared between the branches, but each branch has its own decoder

Parameters:
  • model_path (Union[Path, str]) – Path to pretrained SAM model

  • num_nuclei_classes (int) – Number of nuclei classes (including background)

  • num_tissue_classes (int) – Number of tissue classes

  • vit_structure (Literal["SAM-B", "SAM-L", "SAM-H"]) – SAM model type

  • drop_rate (float, optional) – Dropout in MLP. Defaults to 0.

  • regression_loss (bool, optional) – Use regressive loss for predicting vector components. Adds two additional channels to the binary decoder, but returns it as own entry in dict. Defaults to False.

Raises:

NotImplementedError – Unknown SAM configuration
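The vit_structure dispatch can be sketched as follows. The hyperparameter values follow the published SAM image-encoder sizes and are assumptions about what init_vit_b/init_vit_l/init_vit_h configure internally:

```python
# Published SAM image-encoder sizes (assumed to match init_vit_b/l/h).
SAM_CONFIGS = {
    "SAM-B": dict(embed_dim=768, depth=12, num_heads=12),
    "SAM-L": dict(embed_dim=1024, depth=24, num_heads=16),
    "SAM-H": dict(embed_dim=1280, depth=32, num_heads=16),
}

def resolve_vit_structure(vit_structure: str) -> dict:
    """Return backbone hyperparameters or raise, mirroring the documented
    NotImplementedError for unknown SAM configurations."""
    if vit_structure not in SAM_CONFIGS:
        raise NotImplementedError(f"Unknown SAM configuration: {vit_structure}")
    return SAM_CONFIGS[vit_structure]
```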

__init__(model_path: Optional[Union[Path, str]], num_nuclei_classes: int, num_tissue_classes: int, vit_structure: Literal['SAM-B', 'SAM-L', 'SAM-H'], drop_rate: float = 0, regression_loss: bool = False)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

load_pretrained_encoder(model_path: pathlib.Path | str)[source]

Load pretrained SAM encoder from provided path

Parameters:

model_path (Union[Path, str]) – Path to the pretrained SAM model

forward(x: Tensor, retrieve_tokens: bool = False)[source]

Forward pass

Parameters:
  • x (torch.Tensor) – Images in BCHW style

  • retrieve_tokens (bool, optional) – If tokens of ViT should be returned as well. Defaults to False.

Returns:

Output for all branches:
  • tissue_types: Raw tissue type prediction. Shape: (B, num_tissue_classes)

  • nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)

  • hv_map: Binary HV Map predictions. Shape: (B, 2, H, W)

  • nuclei_type_map: Raw nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)

  • tokens (only if retrieve_tokens is set): ViT output tokens

  • regression_map (only if regression_loss is set): Regression map for the binary prediction. Shape: (B, 2, H, W)

Return type:

dict

init_vit_b()[source]

Set the backbone hyperparameters for the SAM ViT-B configuration.

init_vit_l()[source]

Set the backbone hyperparameters for the SAM ViT-L configuration.

init_vit_h()[source]

Set the backbone hyperparameters for the SAM ViT-H configuration.