cellmil.models.segmentation¶
- class cellmil.models.segmentation.CellViTSAM(model_path: Optional[Union[Path, str]], num_nuclei_classes: int, num_tissue_classes: int, vit_structure: Literal['SAM-B', 'SAM-L', 'SAM-H'], drop_rate: float = 0, regression_loss: bool = False)[source]¶
Bases:
CellViT
CellViT with SAM backbone settings.
Skip connections are shared between branches, but each branch has a distinct decoder
- Parameters:
model_path (Union[Path, str]) – Path to pretrained SAM model
num_nuclei_classes (int) – Number of nuclei classes (including background)
num_tissue_classes (int) – Number of tissue classes
vit_structure (Literal["SAM-B", "SAM-L", "SAM-H"]) – SAM model type
drop_rate (float, optional) – Dropout in MLP. Defaults to 0.
regression_loss (bool, optional) – Use a regression loss for predicting vector components. Adds two additional channels to the binary decoder, whose output is returned as its own entry in the output dict. Defaults to False.
- Raises:
NotImplementedError – Unknown SAM configuration
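The vit_structure dispatch implied by the Raises entry can be sketched as follows. The embed_dim/depth values are those of Meta's public SAM ViT-B/L/H image encoders; the helper itself is hypothetical and not part of cellmil:

```python
# Hypothetical sketch of the vit_structure dispatch. The embed_dim/depth
# values match Meta's public SAM image encoders; this function is
# illustrative only and not cellmil code.
SAM_CONFIGS = {
    "SAM-B": {"embed_dim": 768, "depth": 12},
    "SAM-L": {"embed_dim": 1024, "depth": 24},
    "SAM-H": {"embed_dim": 1280, "depth": 32},
}

def resolve_vit_structure(vit_structure: str) -> dict:
    """Return encoder settings for a known SAM variant, else raise."""
    try:
        return SAM_CONFIGS[vit_structure.upper()]
    except KeyError:
        # Mirrors the documented behaviour: unknown configs are rejected.
        raise NotImplementedError(f"Unknown SAM configuration: {vit_structure}")
```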
- __init__(model_path: Optional[Union[Path, str]], num_nuclei_classes: int, num_tissue_classes: int, vit_structure: Literal['SAM-B', 'SAM-L', 'SAM-H'], drop_rate: float = 0, regression_loss: bool = False)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- load_pretrained_encoder(model_path: pathlib.Path | str)[source]¶
Load pretrained SAM encoder from provided path
- Parameters:
model_path (Union[Path, str]) – Path to the pretrained SAM model
- forward(x: Tensor, retrieve_tokens: bool = False)[source]¶
Forward pass
- Parameters:
x (torch.Tensor) – Images in BCHW style
retrieve_tokens (bool, optional) – If tokens of ViT should be returned as well. Defaults to False.
- Returns:
- Output for all branches:
tissue_types: Raw tissue type predictions. Shape: (B, num_tissue_classes)
nuclei_binary_map: Raw binary cell segmentation predictions. Shape: (B, 2, H, W)
hv_map: Horizontal-vertical (HV) map predictions. Shape: (B, 2, H, W)
nuclei_type_map: Raw nuclei type predictions. Shape: (B, num_nuclei_classes, H, W)
[Optional, if retrieve_tokens]: tokens
[Optional, if regression_loss]:
regression_map: Regression map for the binary prediction. Shape: (B, 2, H, W)
- Return type:
dict
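As a minimal, torch-free sketch of the forward() output contract listed above, the hypothetical helper below maps each output key to its documented shape (expected_output_shapes is not part of cellmil; it only restates the shapes from this reference):

```python
def expected_output_shapes(
    batch: int,
    height: int,
    width: int,
    num_tissue_classes: int,
    num_nuclei_classes: int,
    regression_loss: bool = False,
) -> dict:
    """Documented shape of each entry in CellViTSAM.forward()'s output dict."""
    shapes = {
        "tissue_types": (batch, num_tissue_classes),
        "nuclei_binary_map": (batch, 2, height, width),
        "hv_map": (batch, 2, height, width),
        "nuclei_type_map": (batch, num_nuclei_classes, height, width),
    }
    if regression_loss:
        # Extra entry only when the model was built with regression_loss=True.
        shapes["regression_map"] = (batch, 2, height, width)
    return shapes

# Example: a batch of two 256x256 patches with illustrative class counts.
shapes = expected_output_shapes(batch=2, height=256, width=256,
                                num_tissue_classes=19, num_nuclei_classes=6)
```

Such a helper is useful for shape-checking model outputs in unit tests without instantiating the network.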
- class cellmil.models.segmentation.HoVerNet(n_classes: int | None = 6)[source]¶
Bases:
Module
Model for simultaneous segmentation and classification based on HoVer-Net. Can also be used for segmentation only, if class labels are not supplied. Each branch returns logits.
- Parameters:
n_classes (int | None) – Number of classes for the classification task. If None, the classification branch is not used.
References
Graham, S., Vu, Q.D., Raza, S.E.A., Azam, A., Tsang, Y.W., Kwak, J.T. and Rajpoot, N., 2019. Hover-Net: Simultaneous segmentation and classification of nuclei in multi-tissue histology images. Medical Image Analysis, 58, p.101563.
- __init__(n_classes: int | None = 6)[source]¶
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(inputs: Tensor) dict[str, torch.Tensor][source]¶
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- calculate_instance_map(predictions: dict[str, torch.Tensor], magnification: int | float = 40) Tuple[Tensor, list[dict[numpy.int32, dict[str, Any]]]][source]¶
Calculate Instance Map from network predictions (after Softmax output)
- Parameters:
predictions (dict) – Dictionary with the following required keys:
nuclei_binary_map: Binary nucleus predictions. Shape: (batch_size, H, W, 2)
nuclei_type_map: Type prediction of nuclei. Shape: (batch_size, H, W, 6)
hv_map: Horizontal-vertical nuclei mapping. Shape: (batch_size, H, W, 2)
magnification (Literal[20, 40], optional) – Magnification of the input data. Defaults to 40.
- Returns:
torch.Tensor: Instance map. Each instance has its own integer label. Shape: (batch_size, H, W)
- List of dictionaries, one per image. Each dict contains another dict for each detected nucleus.
For each nucleus, the following information is returned: “bbox”, “centroid”, “contour”, “type_prob”, “type”
- Return type:
Tuple[torch.Tensor, List[dict]]
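To illustrate how the per-nucleus dictionaries returned above can be consumed, here is a small self-contained example; the `nuclei` dict below is mocked in the documented format, not real model output:

```python
from collections import Counter

# Mocked per-image output in the format described above: one entry per
# detected nucleus, keyed by its integer instance label.
nuclei = {
    1: {"bbox": (10, 12, 30, 34), "centroid": (20.0, 23.0),
        "contour": [(10, 12), (30, 12), (30, 34)], "type_prob": 0.91, "type": 2},
    2: {"bbox": (40, 8, 55, 25), "centroid": (47.5, 16.5),
        "contour": [(40, 8), (55, 8), (55, 25)], "type_prob": 0.78, "type": 2},
    3: {"bbox": (5, 50, 18, 70), "centroid": (11.5, 60.0),
        "contour": [(5, 50), (18, 50), (18, 70)], "type_prob": 0.66, "type": 4},
}

# Count detected nuclei per predicted type and average the type confidence.
type_counts = Counter(info["type"] for info in nuclei.values())
mean_prob = sum(info["type_prob"] for info in nuclei.values()) / len(nuclei)
```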
- class cellmil.models.segmentation.CellposeSAM(pretrained_model: str = 'cpsam', device: Optional[device] = None, use_bfloat16: bool = True)[source]¶
Bases:
Module
Simplified PyTorch wrapper around Cellpose that operates directly on torch tensors with true batching for improved performance. Designed for patch-based processing.
- __init__(pretrained_model: str = 'cpsam', device: Optional[device] = None, use_bfloat16: bool = True)[source]¶
Initialize CellposeSAM wrapper.
- Parameters:
pretrained_model – Path to pretrained cellpose model or model name
device – Device to run the model on
use_bfloat16 – Use bfloat16 precision for model weights
- forward(x: Tensor, normalize: bool = True, resample: bool = True, niter: int = 200, flow_threshold: float = 0.4, cellprob_threshold: float = 0.0, min_size: int = 15, max_size_fraction: float = 0.4) Dict[str, Tensor][source]¶
Forward pass through Cellpose model with true batched processing.
- Parameters:
x – Input tensor of shape (B, C, H, W) where C=3 (RGB channels)
normalize – Whether to normalize input
resample – Whether to resize flows and cellprob back to original image size
niter – Number of iterations for mask refinement
flow_threshold – Threshold for flow field
cellprob_threshold – Threshold for cell probability map
min_size – Minimum size of masks to keep
max_size_fraction – Maximum size fraction of masks to keep
- Returns:
Dictionary containing:
masks: Instance segmentation masks (B, H, W)
flows: Flow fields (B, H, W, 2)
cellprob: Cell probability maps (B, H, W)
styles: Style vectors (B, style_dim)
- Return type:
Dict[str, Tensor]
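The min_size behaviour described above can be illustrated with a toy, pure-Python version of the filtering step. Cellpose's actual implementation operates on numpy arrays and differs in detail; this sketch only demonstrates what the parameter means:

```python
from collections import Counter

def filter_small_masks(masks, min_size):
    """Relabel instances smaller than min_size pixels to background (0).

    masks: 2D list of integer instance labels, 0 = background.
    Toy illustration only; not the Cellpose implementation.
    """
    counts = Counter(v for row in masks for v in row if v != 0)
    small = {label for label, n in counts.items() if n < min_size}
    return [[0 if v in small else v for v in row] for row in masks]

toy = [
    [1, 1, 0, 2],
    [1, 1, 0, 0],
    [0, 0, 0, 0],
]
# Instance 2 covers a single pixel, so min_size=2 removes it.
filtered = filter_small_masks(toy, min_size=2)
```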
- _compute_masks(flows: ndarray[Any, Any], cellprob: ndarray[Any, Any], flow_threshold: float = 0.4, cellprob_threshold: float = 0.0, min_size: int = 15, max_size_fraction: float = 0.4, niter: int = 200) ndarray[Any, Any][source]¶
Compute masks from flows and cell probabilities using cellpose dynamics.
- _resize_flows_batch(flows: Tensor, Ly: int, Lx: int) Tensor[source]¶
Resize flow fields to target dimensions.
- _resize_cellprob_batch(cellprob: Tensor, Ly: int, Lx: int) Tensor[source]¶
Resize cell probability maps to target dimensions.
- calculate_instance_map(predictions: Dict[str, Tensor], magnification: float = 40.0) Tuple[Tensor, list[Dict[int, Dict[str, Any]]]][source]¶
Calculate instance map and extract cell information from predictions.
- Parameters:
predictions – Dictionary containing model outputs
magnification – Magnification level
- Returns:
Tuple containing instance map and cell information
Copyright 2021, Dana-Farber Cancer Institute and Weill Cornell Medicine. License: GNU GPL 2.0