cellmil.models.segmentation

class cellmil.models.segmentation.CellViTSAM(model_path: Optional[Union[Path, str]], num_nuclei_classes: int, num_tissue_classes: int, vit_structure: Literal['SAM-B', 'SAM-L', 'SAM-H'], drop_rate: float = 0, regression_loss: bool = False)

Bases: CellViT

CellViT with SAM backbone settings

Skip connections are shared between branches, but each branch has its own decoder.

Parameters:
  • model_path (Optional[Union[Path, str]]) – Path to the pretrained SAM model

  • num_nuclei_classes (int) – Number of nuclei classes (including background)

  • num_tissue_classes (int) – Number of tissue classes

  • vit_structure (Literal["SAM-B", "SAM-L", "SAM-H"]) – SAM model type

  • drop_rate (float, optional) – Dropout rate in the MLP. Defaults to 0.

  • regression_loss (bool, optional) – Use a regression loss for predicting vector components. This adds two channels to the binary decoder, which are returned as a separate entry (regression_map) in the output dict. Defaults to False.

Raises:

NotImplementedError – Unknown SAM configuration
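With regression_loss=True the binary decoder emits four channels instead of two, and the extra pair is reported under its own key. The sketch below illustrates that bookkeeping on a single image stored as nested lists with layout [C][H][W]. The channel order (binary logits first, regression components last) is an assumption, not something this reference confirms, and split_binary_decoder_output is a hypothetical helper rather than part of cellmil.

```python
from typing import Dict, List

# One image as nested lists with layout [C][H][W].
Image = List[List[List[float]]]


def split_binary_decoder_output(channels: Image, regression_loss: bool) -> Dict[str, Image]:
    """Hypothetical helper: split binary-decoder channels into output entries.

    Assumption: channels 0-1 hold the binary segmentation logits and, when
    regression_loss is True, channels 2-3 hold the two regression components.
    """
    expected = 4 if regression_loss else 2
    if len(channels) != expected:
        raise ValueError(f"expected {expected} channels, got {len(channels)}")
    out: Dict[str, Image] = {"nuclei_binary_map": channels[:2]}
    if regression_loss:
        out["regression_map"] = channels[2:]
    return out


# Usage: four 1x1 channels, as produced when regression_loss=True.
dummy = [[[float(c)]] for c in range(4)]
parts = split_binary_decoder_output(dummy, regression_loss=True)
```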

__init__(model_path: Optional[Union[Path, str]], num_nuclei_classes: int, num_tissue_classes: int, vit_structure: Literal['SAM-B', 'SAM-L', 'SAM-H'], drop_rate: float = 0, regression_loss: bool = False)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

load_pretrained_encoder(model_path: pathlib.Path | str)

Load the pretrained SAM encoder weights from the provided path.

Parameters:

model_path (Union[Path, str]) – Path to the pretrained SAM model

forward(x: Tensor, retrieve_tokens: bool = False)

Forward pass

Parameters:
  • x (torch.Tensor) – Images in BCHW layout

  • retrieve_tokens (bool, optional) – Whether to also return the ViT tokens. Defaults to False.

Returns:

Dictionary with the outputs of all branches:
  • tissue_types: Raw tissue type predictions. Shape: (B, num_tissue_classes)

  • nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)

  • hv_map: Horizontal-vertical distance map predictions. Shape: (B, 2, H, W)

  • nuclei_type_map: Raw nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)

  • tokens [optional, only if retrieve_tokens is True]: ViT tokens

  • regression_map [optional, only if regression_loss is True]: Regression map for the binary prediction. Shape: (B, 2, H, W)

Return type:

dict
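The returned dictionary can be checked against the shapes listed above. The helpers below are a hypothetical sketch, not part of cellmil: one derives the expected shape of each documented entry from the constructor arguments, the other compares them with actual shapes. The tokens entry is left out because its shape is not documented here.

```python
from typing import Dict, Tuple

Shapes = Dict[str, Tuple[int, ...]]


def expected_cellvit_shapes(
    batch: int,
    height: int,
    width: int,
    num_nuclei_classes: int,
    num_tissue_classes: int,
    regression_loss: bool = False,
) -> Shapes:
    """Hypothetical helper: expected shapes of the documented forward outputs."""
    shapes: Shapes = {
        "tissue_types": (batch, num_tissue_classes),
        "nuclei_binary_map": (batch, 2, height, width),
        "hv_map": (batch, 2, height, width),
        "nuclei_type_map": (batch, num_nuclei_classes, height, width),
    }
    if regression_loss:
        shapes["regression_map"] = (batch, 2, height, width)
    return shapes


def check_output_shapes(outputs: Shapes, expected: Shapes) -> None:
    """Raise if an expected key is missing or has a different shape; extra keys are ignored."""
    for key, shape in expected.items():
        if key not in outputs:
            raise KeyError(f"missing output '{key}'")
        if tuple(outputs[key]) != shape:
            raise ValueError(f"{key}: expected {shape}, got {tuple(outputs[key])}")
```

With PyTorch tensors, the real output could be passed as `{k: tuple(v.shape) for k, v in out.items() if k != "tokens"}`.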

init_vit_b()
init_vit_l()
init_vit_h()
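The three init_vit_* methods are not documented here. As a hedged illustration of how vit_structure might select one of them, the sketch below maps each SAM variant to the image-encoder hyperparameters published for the original SAM checkpoints (ViT-B: 768-dim embeddings, 12 blocks, 12 heads; ViT-L: 1024, 24, 16; ViT-H: 1280, 32, 16) and raises NotImplementedError for anything else, matching the documented error. Whether cellmil sets exactly these fields is an assumption, and SAM_ENCODER_CONFIGS and select_sam_config are hypothetical names.

```python
from typing import Dict

# SAM image-encoder hyperparameters as published for the three SAM checkpoints.
# Assumption: the init_vit_* methods configure equivalent values.
SAM_ENCODER_CONFIGS: Dict[str, Dict[str, int]] = {
    "SAM-B": {"embed_dim": 768, "depth": 12, "num_heads": 12},
    "SAM-L": {"embed_dim": 1024, "depth": 24, "num_heads": 16},
    "SAM-H": {"embed_dim": 1280, "depth": 32, "num_heads": 16},
}


def select_sam_config(vit_structure: str) -> Dict[str, int]:
    """Hypothetical dispatcher mirroring the documented NotImplementedError."""
    try:
        # Return a copy so callers cannot mutate the shared table.
        return dict(SAM_ENCODER_CONFIGS[vit_structure])
    except KeyError:
        raise NotImplementedError(f"Unknown SAM configuration: {vit_structure}") from None
```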
class cellmil.models.segmentation.HoVerNet(n_classes: int | None = 6)

Bases: Module

Model for simultaneous segmentation and classification based on HoVer-Net. It can also be used for segmentation only, by setting n_classes to None. Each branch returns logits.

Parameters:

n_classes (int, optional) – Number of classes for the classification task. If None, the classification branch is not used. Defaults to 6.
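A minimal sketch of how n_classes determines which branches are active. It assumes the three HoVer-Net branches described by Graham et al. (nucleus pixel, horizontal-vertical map, nucleus classification), with two output channels for the first two; the helper name hovernet_branch_channels is hypothetical, and the dictionary keys are borrowed from the prediction keys that calculate_instance_map expects.

```python
from typing import Dict, Optional


def hovernet_branch_channels(n_classes: Optional[int] = 6) -> Dict[str, int]:
    """Hypothetical helper: output channels per HoVer-Net branch.

    Assumes 2-channel nucleus-pixel and HV branches; the classification
    branch only exists when n_classes is not None.
    """
    if n_classes is not None and n_classes < 1:
        raise ValueError("n_classes must be a positive integer or None")
    branches = {"nuclei_binary_map": 2, "hv_map": 2}
    if n_classes is not None:
        branches["nuclei_type_map"] = n_classes
    return branches
```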

References

Graham, S., Vu, Q.D., Raza, S.E.A., Azam, A., Tsang, Y.W., Kwak, J.T. and Rajpoot, N., 2019. Hover-Net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images. Medical Image Analysis, 58, p.101563.

__init__(n_classes: int | None = 6)

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(inputs: Tensor) → dict[str, torch.Tensor]

Run the forward pass and return the logits of each branch as a dictionary of tensors.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

calculate_instance_map(predictions: dict[str, torch.Tensor], magnification: int | float = 40) → Tuple[Tensor, list[dict[numpy.int32, dict[str, Any]]]]

Calculate the instance map from network predictions (after softmax).

Parameters:
  • predictions (dict) – Dictionary with the following required keys:

      • nuclei_binary_map: Binary nucleus predictions. Shape: (batch_size, H, W, 2)

      • nuclei_type_map: Type prediction of nuclei. Shape: (batch_size, H, W, 6)

      • hv_map: Horizontal-vertical nuclei mapping. Shape: (batch_size, H, W, 2)

  • magnification (Literal[20, 40], optional) – Magnification of the input data. Defaults to 40.
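calculate_instance_map expects softmax probabilities in channel-last layout (batch_size, H, W, C), whereas raw network outputs are typically channel-first logits. The dependency-free sketch below performs that conversion for one image stored as nested lists [C][H][W]; with a batched PyTorch tensor the equivalent is `torch.softmax(logits, dim=1).permute(0, 2, 3, 1)`. It applies to the binary and type maps; the HV map is a regression output and is usually not softmaxed. The helper name is hypothetical, and that forward returns channel-first logits is an assumption.

```python
import math
from typing import List


def softmax_channel_last(logits: List[List[List[float]]]) -> List[List[List[float]]]:
    """Convert one image of logits [C][H][W] into probabilities [H][W][C].

    Softmax is taken over the channel axis at every pixel; the maximum logit
    is subtracted first for numerical stability.
    """
    num_channels = len(logits)
    height, width = len(logits[0]), len(logits[0][0])
    probs: List[List[List[float]]] = []
    for y in range(height):
        row = []
        for x in range(width):
            pixel = [logits[c][y][x] for c in range(num_channels)]
            peak = max(pixel)
            exps = [math.exp(v - peak) for v in pixel]
            total = sum(exps)
            row.append([e / total for e in exps])
        probs.append(row)
    return probs
```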

Returns:

  • torch.Tensor: Instance map in which each instance has its own integer label. Shape: (batch_size, H, W)

  • List of dictionaries, one entry per image. Each dict contains one nested dict per detected nucleus, keyed by an integer ID.

    For each nucleus, the following information is returned: “bbox”, “centroid”, “contour”, “type_prob”, “type”

Return type:

Tuple[torch.Tensor, List[dict]]
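A sketch of consuming the second return value: counting detected nuclei per predicted class for each image. It assumes the "type" entry is an integer class index; the helper name and the example data are made up for illustration.

```python
from collections import Counter
from typing import Any, Dict, List


def count_nuclei_by_type(cell_dicts: List[Dict[Any, Dict[str, Any]]]) -> List[Counter]:
    """Return, for each image, a Counter mapping nucleus type to count."""
    return [Counter(int(info["type"]) for info in cells.values()) for cells in cell_dicts]


# Hypothetical output for a batch of two images (only the "type" key is used).
example = [
    {1: {"type": 2}, 2: {"type": 2}, 3: {"type": 1}},
    {},
]
counts = count_nuclei_by_type(example)
```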

class cellmil.models.segmentation.CellposeSAM(pretrained_model: str = 'cpsam', device: Optional[device] = None, use_bfloat16: bool = True)

Bases: Module

Simplified PyTorch wrapper around Cellpose that operates directly on torch tensors with true batching for improved performance. Designed for patch-based processing.

__init__(pretrained_model: str = 'cpsam', device: Optional[device] = None, use_bfloat16: bool = True)

Initialize the CellposeSAM wrapper.

Parameters:
  • pretrained_model – Path to a pretrained Cellpose model, or a model name. Defaults to 'cpsam'.

  • device – Device to run the model on

  • use_bfloat16 – Use bfloat16 precision for model weights

_init_cellpose_network()

Initialize the Cellpose network directly without the wrapper.

forward(x: Tensor, normalize: bool = True, resample: bool = True, niter: int = 200, flow_threshold: float = 0.4, cellprob_threshold: float = 0.0, min_size: int = 15, max_size_fraction: float = 0.4) → Dict[str, Tensor]

Forward pass through Cellpose model with true batched processing.

Parameters:
  • x – Input tensor of shape (B, C, H, W) where C=3 (RGB channels)

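  • normalize – Whether to normalize the input

  • resample – Whether to resize flows and cellprob back to the original image size

  • niter – Number of iterations of the flow dynamics used to compute masks

  • flow_threshold – Maximum allowed flow error; masks with a larger flow error are discarded

  • cellprob_threshold – Threshold on the cell probability map; pixels above it are treated as foreground when computing masks

  • min_size – Minimum mask size in pixels; smaller masks are removed

  • max_size_fraction – Maximum mask size as a fraction of the image area; larger masks are removed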
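The two size filters follow Cellpose's conventions: masks with fewer than min_size pixels, or covering more than max_size_fraction of the image area, are dropped. The sketch below applies both rules to a single label image (nested lists, 0 = background) and relabels the surviving masks 1..N. It illustrates the rule only and is not the wrapper's actual code path; filter_masks_by_size is a hypothetical name.

```python
from collections import Counter
from typing import List

LabelImage = List[List[int]]


def filter_masks_by_size(
    labels: LabelImage, min_size: int = 15, max_size_fraction: float = 0.4
) -> LabelImage:
    """Hypothetical sketch: drop too-small/too-large masks, then relabel 1..N.

    A mask is removed if it has fewer than min_size pixels or more than
    max_size_fraction * (H * W) pixels. Label 0 is background.
    """
    height, width = len(labels), len(labels[0])
    max_pixels = max_size_fraction * height * width
    # Pixel count (area) of every non-background label.
    areas = Counter(v for row in labels for v in row if v != 0)
    kept = sorted(lab for lab, area in areas.items() if min_size <= area <= max_pixels)
    new_id = {old: new for new, old in enumerate(kept, start=1)}
    return [[new_id.get(v, 0) for v in row] for row in labels]
```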

Returns:

Dictionary containing:
  • masks: Instance segmentation masks (B, H, W)

  • flows: Flow fields (B, H, W, 2)

  • cellprob: Cell probability maps (B, H, W)

  • styles: Style vectors (B, style_dim)

Return type:

Dict[str, Tensor]

_compute_masks(flows: ndarray[Any, Any], cellprob: ndarray[Any, Any], flow_threshold: float = 0.4, cellprob_threshold: float = 0.0, min_size: int = 15, max_size_fraction: float = 0.4, niter: int = 200) → ndarray[Any, Any]

Compute masks from flows and cell probabilities using cellpose dynamics.

_normalize_batch(x: Tensor) → Tensor

Normalize batch of images using torch operations.
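How _normalize_batch normalizes is not specified here. Cellpose's own preprocessing rescales each image so that its 1st percentile maps to 0 and its 99th percentile maps to 1; the sketch below shows that scheme for one channel of pixel values, assuming the wrapper mimics it. Percentiles use linear interpolation between sorted values (NumPy's default convention), and percentile_normalize is a local, hypothetical reimplementation rather than an import from Cellpose.

```python
from typing import List, Sequence


def _percentile(sorted_vals: Sequence[float], q: float) -> float:
    """Linear-interpolation percentile of already sorted values."""
    pos = (len(sorted_vals) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(sorted_vals) - 1)
    frac = pos - lo
    return sorted_vals[lo] + (sorted_vals[hi] - sorted_vals[lo]) * frac


def percentile_normalize(values: Sequence[float], lower: float = 1.0, upper: float = 99.0) -> List[float]:
    """Rescale so the lower/upper percentiles map to 0 and 1 (values may fall outside)."""
    ordered = sorted(values)
    p_lo, p_hi = _percentile(ordered, lower), _percentile(ordered, upper)
    scale = p_hi - p_lo
    if scale == 0:
        # Constant input: return zeros instead of dividing by zero (our choice).
        return [0.0 for _ in values]
    return [(v - p_lo) / scale for v in values]
```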

_resize_flows_batch(flows: Tensor, Ly: int, Lx: int) → Tensor

Resize flow fields to target dimensions.

_resize_cellprob_batch(cellprob: Tensor, Ly: int, Lx: int) → Tensor

Resize cell probability maps to target dimensions.

eval()

Set model to evaluation mode.

train(mode: bool = True)

Set model to training mode (cellpose doesn’t support training).

to(device: Union[device, str])

Move model to specified device.

cuda(device: Optional[Union[int, device]] = None)

Move model to CUDA device.

cpu()

Move model to CPU.

calculate_instance_map(predictions: Dict[str, Tensor], magnification: float = 40.0) → Tuple[Tensor, list[Dict[int, Dict[str, Any]]]]

Calculate instance map and extract cell information from predictions.

Parameters:
  • predictions – Dictionary containing model outputs

  • magnification – Magnification level

Returns:

Tuple containing instance map and cell information
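The structure of the returned cell information is not detailed here. Assuming it mirrors HoVerNet's per-nucleus dictionaries ("bbox", "centroid"), the sketch below derives both from a single instance map. The bounding-box convention [[row_min, col_min], [row_max, col_max]] with exclusive maxima and the (x, y) centroid order are assumptions chosen for illustration, and instance_properties is a hypothetical helper.

```python
from typing import Dict, List

InstanceMap = List[List[int]]


def instance_properties(instance_map: InstanceMap) -> Dict[int, Dict[str, object]]:
    """Hypothetical helper: bounding box and centroid for every non-zero label.

    bbox is [[row_min, col_min], [row_max, col_max]] with exclusive maxima;
    centroid is (x, y) = (mean column, mean row).
    """
    # label -> [rmin, cmin, rmax, cmax, sum_rows, sum_cols, pixel_count]
    acc: Dict[int, List[float]] = {}
    for r, row in enumerate(instance_map):
        for c, label in enumerate(row):
            if label == 0:
                continue
            if label not in acc:
                acc[label] = [r, c, r, c, 0.0, 0.0, 0]
            a = acc[label]
            a[0], a[1] = min(a[0], r), min(a[1], c)
            a[2], a[3] = max(a[2], r), max(a[3], c)
            a[4] += r
            a[5] += c
            a[6] += 1
    return {
        label: {
            "bbox": [[int(a[0]), int(a[1])], [int(a[2]) + 1, int(a[3]) + 1]],
            "centroid": (a[5] / a[6], a[4] / a[6]),
        }
        for label, a in acc.items()
    }
```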

_extract_contour(mask: ndarray[Any, Any]) → ndarray[Any, Any]

Extract contour from binary mask.
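One simple way to outline a binary mask is to keep the foreground pixels that touch the background, or the image border, through a 4-connected neighbour. The sketch below does exactly that and returns an unordered list of boundary pixels as (x, y). The wrapper may instead trace an ordered polygon (for example with OpenCV's findContours), so treat this as an illustration of the idea rather than the method's implementation; boundary_pixels is a hypothetical name.

```python
from typing import List, Tuple


def boundary_pixels(mask: List[List[int]]) -> List[Tuple[int, int]]:
    """Return (x, y) coordinates of foreground pixels on the mask boundary.

    A foreground pixel is on the boundary if any 4-connected neighbour is
    background or lies outside the image.
    """
    height, width = len(mask), len(mask[0])
    points: List[Tuple[int, int]] = []
    for r in range(height):
        for c in range(width):
            if not mask[r][c]:
                continue
            for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                nr, nc = r + dr, c + dc
                if not (0 <= nr < height and 0 <= nc < width) or not mask[nr][nc]:
                    points.append((c, r))
                    break
    return points
```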

Modules

cellmil.models.segmentation.cellpose

cellmil.models.segmentation.cellposeV2

cellmil.models.segmentation.cellvit

cellmil.models.segmentation.hovernet

Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine License: GNU GPL 2.0