cellmil.models.segmentation.cellvit¶
Classes
CellViT: CellViT model for cell segmentation.

CellViTSAM: CellViT with SAM backbone settings.
- class cellmil.models.segmentation.cellvit.CellViT(num_nuclei_classes: int, num_tissue_classes: int, embed_dim: int, input_channels: int, depth: int, num_heads: int, extract_layers: list[int], mlp_ratio: float = 4, qkv_bias: bool = True, drop_rate: float = 0, attn_drop_rate: float = 0, drop_path_rate: float = 0, regression_loss: bool = False)[source]¶
Bases:
Module

CellViT model for cell segmentation. U-Net-like network with a vision transformer as backbone encoder.
Skip connections are shared between branches, but each network has a distinct encoder
- The model has multiple branches:
tissue_types: Tissue prediction based on global class token
nuclei_binary_map: Binary nuclei prediction
hv_map: HV-prediction to separate isolated instances
nuclei_type_map: Nuclei instance-prediction
[Optional, if regression loss]:
regression_map: Regression map for binary prediction
- Parameters:
num_nuclei_classes (int) – Number of nuclei classes (including background)
num_tissue_classes (int) – Number of tissue classes
embed_dim (int) – Embedding dimension of backbone ViT
input_channels (int) – Number of input channels
depth (int) – Depth of the backbone ViT
num_heads (int) – Number of heads of the backbone ViT
extract_layers (List[int]) – List of transformer blocks whose outputs should be returned in addition to the tokens. The first block starts at 1, and the maximum is N=depth. Used for skip connections; at least 4 skip connections need to be returned.
mlp_ratio (float, optional) – MLP ratio for hidden MLP dimension of backbone ViT. Defaults to 4.
qkv_bias (bool, optional) – If bias should be used for query (q), key (k), and value (v) in backbone ViT. Defaults to True.
drop_rate (float, optional) – Dropout in MLP. Defaults to 0.
attn_drop_rate (float, optional) – Dropout for attention layer in backbone ViT. Defaults to 0.
drop_path_rate (float, optional) – Dropout for skip connections (stochastic depth). Defaults to 0.
regression_loss (bool, optional) – Use a regression loss for predicting vector components. Adds two additional channels to the binary decoder, whose output is returned as its own entry in the output dict. Defaults to False.
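The `extract_layers` mechanism can be illustrated with a small sketch (plain Python, not the actual CellViT implementation; `run_block` stands in for a transformer block). Blocks are 1-indexed, so with depth=12, `extract_layers=[3, 6, 9, 12]` yields four evenly spaced skip connections:

```python
# Illustrative sketch of how extract_layers selects intermediate
# transformer-block outputs for the decoder's skip connections.
def collect_skip_outputs(depth, extract_layers, run_block, tokens):
    assert len(extract_layers) >= 4, "at least 4 skip connections are required"
    assert all(1 <= i <= depth for i in extract_layers), "blocks are 1-indexed up to depth"
    skips = []
    for block_idx in range(1, depth + 1):
        tokens = run_block(block_idx, tokens)
        if block_idx in extract_layers:
            skips.append(tokens)  # stash this block's output for the decoder
    return tokens, skips
```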
- __init__(num_nuclei_classes: int, num_tissue_classes: int, embed_dim: int, input_channels: int, depth: int, num_heads: int, extract_layers: list[int], mlp_ratio: float = 4, qkv_bias: bool = True, drop_rate: float = 0, attn_drop_rate: float = 0, drop_path_rate: float = 0, regression_loss: bool = False)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(x: Tensor, retrieve_tokens: bool = False) dict[str, Any][source]¶
Forward pass
- Parameters:
x (torch.Tensor) – Images in BCHW style
retrieve_tokens (bool, optional) – If tokens of ViT should be returned as well. Defaults to False.
- Returns:
- Output for all branches:
tissue_types: Raw tissue type prediction. Shape: (B, num_tissue_classes)
nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)
hv_map: Raw horizontal-vertical (HV) map predictions. Shape: (B, 2, H, W)
nuclei_type_map: Raw nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)
[Optional, if retrieve tokens]: tokens
[Optional, if regression loss]:
regression_map: Regression map for binary prediction. Shape: (B, 2, H, W)
- Return type:
dict[str, Any]
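The shape conventions above can be spelled out with a small helper (hypothetical, not part of cellmil) that maps the constructor arguments to the expected shape of each entry in the forward-pass output dict:

```python
# Hypothetical helper documenting the tensor shapes of CellViT.forward output.
def cellvit_output_shapes(B, H, W, num_tissue_classes, num_nuclei_classes,
                          regression_loss=False):
    shapes = {
        "tissue_types": (B, num_tissue_classes),        # global class-token logits
        "nuclei_binary_map": (B, 2, H, W),              # background/foreground logits
        "hv_map": (B, 2, H, W),                         # horizontal + vertical channel
        "nuclei_type_map": (B, num_nuclei_classes, H, W),
    }
    if regression_loss:
        shapes["regression_map"] = (B, 2, H, W)         # extra branch, own dict entry
    return shapes
```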
- _forward_upsample(z0: Tensor, z1: Tensor, z2: Tensor, z3: Tensor, z4: Tensor, branch_decoder: Sequential) Tensor[source]¶
Forward upsample branch
- Parameters:
z0 (torch.Tensor) – Highest skip
z1 (torch.Tensor) – Skip
z2 (torch.Tensor) – Skip
z3 (torch.Tensor) – Skip
z4 (torch.Tensor) – Bottleneck
branch_decoder (nn.Sequential) – Branch decoder network
- Returns:
Branch Output
- Return type:
torch.Tensor
- create_upsampling_branch(num_classes: int) Sequential[source]¶
Create Upsampling branch
- Parameters:
num_classes (int) – Number of output classes
- Returns:
Upsampling path
- Return type:
nn.Sequential
- calculate_instance_map(predictions: dict[str, torch.Tensor], magnification: int | float = 40) Tuple[Tensor, list[dict[numpy.int32, dict[str, Any]]]][source]¶
Calculate Instance Map from network predictions (after Softmax output)
- Parameters:
predictions (dict) – Dictionary with the following required keys:
nuclei_binary_map: Binary nucleus predictions. Shape: (B, 2, H, W)
nuclei_type_map: Type prediction of nuclei. Shape: (B, self.num_nuclei_classes, H, W)
hv_map: Horizontal-vertical nuclei mapping. Shape: (B, 2, H, W)
magnification (Literal[20, 40], optional) – Which magnification the data has. Defaults to 40.
- Returns:
torch.Tensor: Instance map. Each instance has its own integer id. Shape: (B, H, W)
- List of dictionaries, one entry per image. Each dict contains another dict for each detected nucleus.
For each nucleus, the following information is returned: “bbox”, “centroid”, “contour”, “type_prob”, “type”
- Return type:
Tuple[torch.Tensor, List[dict]]
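An illustrative consumer of the second return value (not part of cellmil): one dict per image, keyed by instance id, where each entry documents “bbox”, “centroid”, “contour”, “type_prob”, and “type”. For example, counting detected nuclei per type:

```python
# Count how many nuclei of each type were detected in every image,
# given the per-instance dicts returned by calculate_instance_map.
def count_nuclei_per_type(type_preds):
    per_image_counts = []
    for image_dict in type_preds:          # one dict per image
        counts = {}
        for inst_id, info in image_dict.items():
            counts[info["type"]] = counts.get(info["type"], 0) + 1
        per_image_counts.append(counts)
    return per_image_counts
```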
- generate_instance_nuclei_map(instance_maps: Tensor, type_preds: list[dict[numpy.int32, dict[str, Any]]]) Tensor[source]¶
Convert instance map (binary) to nuclei type instance map
- Parameters:
instance_maps (torch.Tensor) – Binary instance map, where each instance has its own integer id. Shape: (B, H, W)
type_preds (List[dict]) – List (len=B) of dictionary with instance type information (compare post_process_hovernet function for more details)
- Returns:
Nuclei type instance map. Shape: (B, self.num_nuclei_classes, H, W)
- Return type:
torch.Tensor
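The conversion can be sketched for a single image in plain Python (the real method operates on torch tensors and batches; this toy version only shows the indexing logic):

```python
# Toy sketch: turn an (H, W) instance map plus per-instance type info into a
# one-hot nuclei type map of shape (num_nuclei_classes, H, W).
def instance_to_type_map(instance_map, type_pred, num_nuclei_classes):
    H, W = len(instance_map), len(instance_map[0])
    out = [[[0.0] * W for _ in range(H)] for _ in range(num_nuclei_classes)]
    for y in range(H):
        for x in range(W):
            inst_id = instance_map[y][x]
            if inst_id != 0:  # 0 marks background
                out[type_pred[inst_id]["type"]][y][x] = 1.0
    return out
```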
- class cellmil.models.segmentation.cellvit.CellViTSAM(model_path: Optional[Union[Path, str]], num_nuclei_classes: int, num_tissue_classes: int, vit_structure: Literal['SAM-B', 'SAM-L', 'SAM-H'], drop_rate: float = 0, regression_loss: bool = False)[source]¶
Bases:
CellViT

CellViT with SAM backbone settings.
Skip connections are shared between branches, but each network has a distinct encoder
- Parameters:
model_path (Optional[Union[Path, str]]) – Path to pretrained SAM model
num_nuclei_classes (int) – Number of nuclei classes (including background)
num_tissue_classes (int) – Number of tissue classes
vit_structure (Literal["SAM-B", "SAM-L", "SAM-H"]) – SAM model type
drop_rate (float, optional) – Dropout in MLP. Defaults to 0.
regression_loss (bool, optional) – Use a regression loss for predicting vector components. Adds two additional channels to the binary decoder, whose output is returned as its own entry in the output dict. Defaults to False.
- Raises:
NotImplementedError – Unknown SAM configuration
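A sketch of how `vit_structure` could resolve to backbone hyperparameters. The values follow the public SAM ViT-B/L/H configurations; the exact mapping inside CellViTSAM may differ, and the raised error mirrors the documented NotImplementedError for unknown configurations:

```python
# Assumed mapping from vit_structure to SAM backbone hyperparameters
# (values taken from the public SAM ViT configurations; illustrative only).
SAM_VIT_CONFIGS = {
    "SAM-B": {"embed_dim": 768, "depth": 12, "num_heads": 12},
    "SAM-L": {"embed_dim": 1024, "depth": 24, "num_heads": 16},
    "SAM-H": {"embed_dim": 1280, "depth": 32, "num_heads": 16},
}

def resolve_sam_config(vit_structure):
    if vit_structure not in SAM_VIT_CONFIGS:
        raise NotImplementedError(f"Unknown SAM configuration: {vit_structure}")
    return SAM_VIT_CONFIGS[vit_structure]
```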
- __init__(model_path: Optional[Union[Path, str]], num_nuclei_classes: int, num_tissue_classes: int, vit_structure: Literal['SAM-B', 'SAM-L', 'SAM-H'], drop_rate: float = 0, regression_loss: bool = False)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- load_pretrained_encoder(model_path: pathlib.Path | str)[source]¶
Load pretrained SAM encoder from provided path
- Parameters:
model_path (Union[Path, str]) – Path to SAM model
- forward(x: Tensor, retrieve_tokens: bool = False)[source]¶
Forward pass
- Parameters:
x (torch.Tensor) – Images in BCHW style
retrieve_tokens (bool, optional) – If tokens of ViT should be returned as well. Defaults to False.
- Returns:
- Output for all branches:
tissue_types: Raw tissue type prediction. Shape: (B, num_tissue_classes)
nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)
hv_map: Raw horizontal-vertical (HV) map predictions. Shape: (B, 2, H, W)
nuclei_type_map: Raw nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)
[Optional, if retrieve tokens]: tokens
[Optional, if regression loss]:
regression_map: Regression map for binary prediction. Shape: (B, 2, H, W)
- Return type:
dict[str, Any]