"""Patch-level embedding extractors for cellmil.features.extractor.embedding."""
import torch
import timm
import torch.nn as nn
import torchvision.transforms as transforms # type: ignore
from timm.layers.mlp import SwiGLUPacked
from typing import cast
from torchvision.models import resnet50 # type: ignore
from cellmil.interfaces.FeatureExtractorConfig import ExtractorType
from cellmil.utils import logger
from typing import Any
class EmbeddingExtractor:
    """Facade that dispatches feature extraction to a concrete backbone.

    The backbone (ResNet50, Gigapath, or UNI) is chosen once at construction
    time from the given ``ExtractorType``.
    """

    def __init__(self, extractor_name: ExtractorType):
        """Instantiate the backbone extractor matching ``extractor_name``.

        Raises:
            ValueError: If ``extractor_name`` is not a supported extractor type.
        """
        self.extractor_name = extractor_name
        if self.extractor_name == ExtractorType.resnet50:
            self.extractor = ResNet50Extractor()
        elif self.extractor_name == ExtractorType.gigapath:
            self.extractor = GigapathExtractor()
        elif self.extractor_name == ExtractorType.uni:
            self.extractor = UNIExtractor()
        else:
            raise ValueError(f"Unknown extractor type: {self.extractor_name}")

    def extract_features(self, batch: torch.Tensor) -> torch.Tensor:
        """Extract embeddings for a batch of image patches.

        Args:
            batch: Image batch tensor, forwarded unchanged to the backbone.

        Returns:
            The feature tensor produced by the selected backbone.

        Raises:
            RuntimeError: If the backbone fails; the original exception is
                chained (``from e``) so the root-cause traceback is preserved.
        """
        try:
            return self.extractor.extract_features(batch)
        except Exception as e:
            # Report only the batch shape: interpolating the full tensor into
            # the message could build an enormous string for large batches.
            shape = getattr(batch, "shape", None)
            raise RuntimeError(
                f"Failed to extract features from batch with shape {shape}: {e}"
            ) from e
class ResNet50Extractor:
    """ResNet50 feature extractor with adaptive mean pooling after the 3rd residual block.

    With the pretrained backbone, produces 1024-dimensional embeddings
    (the channel count after ``layer3``).
    """

    def __init__(self):
        # Run on GPU when available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Load pretrained ResNet50 and keep only the stem plus the first three
        # residual stages; layer4 and the classifier head are discarded.
        full_model = resnet50(weights="DEFAULT")
        self.features = nn.Sequential(
            full_model.conv1,    # 7x7 conv, 64 channels
            full_model.bn1,
            full_model.relu,
            full_model.maxpool,  # 3x3 max pool
            full_model.layer1,   # 1st residual block (64 -> 256 channels)
            full_model.layer2,   # 2nd residual block (256 -> 512 channels)
            full_model.layer3,   # 3rd residual block (512 -> 1024 channels)
        )
        # Adaptive pooling to 1x1 collapses the spatial dimensions, yielding
        # one value per channel (1024 features) regardless of input size.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        # Move model to device and freeze batch-norm statistics via eval mode.
        self.features = self.features.to(self.device)
        self.adaptive_pool = self.adaptive_pool.to(self.device)
        self.features.eval()
        # Resize to 256x256 and apply ImageNet normalization; expects a float
        # tensor already scaled to [0, 1].
        self.transform = transforms.Compose(
            [
                transforms.Resize((256, 256)),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def extract_features(self, batch: torch.Tensor) -> torch.Tensor:
        """Extract pooled features for a batch of patches.

        Args:
            batch: Image batch tensor; values above 1.0 are assumed to be
                [0, 255]-scaled and are rescaled to [0, 1].

        Returns:
            A CPU tensor of shape ``(batch, C)`` (1024 with the pretrained
            backbone), or a zero tensor of shape ``(batch, 1024)`` if
            extraction fails (deliberate best-effort behaviour).
        """
        try:
            # Normalize input to [0, 1] range if needed.
            if batch.max() > 1.0:
                batch = batch.float() / 255.0
            # Apply transforms (resize and normalize).
            _batch = cast(torch.Tensor, self.transform(batch))
            with torch.no_grad():
                _batch = _batch.to(self.device)
                # Feature maps up to the 3rd residual block.
                features = self.features(_batch)
                pooled_features = self.adaptive_pool(features)
                # Flatten (B, C, 1, 1) -> (B, C), preserving the batch dim.
                feature_vector = pooled_features.view(pooled_features.size(0), -1)
                # Move back to CPU for downstream processing.
                return feature_vector.cpu()
        except Exception as e:
            # Best-effort fallback: keep the pipeline alive with zeros, but log
            # the actual exception (previously discarded, which made failures
            # impossible to diagnose) — consistent with the sibling extractors.
            logger.error(
                f"Error in ResNet50 feature extraction, returning zero vector: {e}"
            )
            batch_size = batch.shape[0] if hasattr(batch, "shape") else 1
            return torch.zeros(batch_size, 1024)
class GigapathExtractor:
    """Prov-GigaPath tile-level feature extractor, loaded from the HuggingFace hub via timm."""

    def __init__(self):
        # Run on GPU when available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Pretrained tile encoder from the HuggingFace hub.
        self.model = timm.create_model(
            "hf_hub:prov-gigapath/prov-gigapath", pretrained=True
        )
        # Move model to device and set to eval mode.
        self.model = self.model.to(self.device)
        self.model.eval()
        # Resize shorter side to 256 (bicubic), center-crop to 224, then apply
        # ImageNet normalization; expects a float tensor scaled to [0, 1].
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    256, interpolation=transforms.InterpolationMode.BICUBIC
                ),
                transforms.CenterCrop(224),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
                ),
            ]
        )

    def extract_features(self, batch: torch.Tensor) -> torch.Tensor:
        """Extract Gigapath embeddings for a batch of patches.

        Args:
            batch: Image batch tensor; values above 1.0 are assumed to be
                [0, 255]-scaled and are rescaled to [0, 1].

        Returns:
            The model's output features, moved to CPU.

        Raises:
            RuntimeError: If extraction fails; the original exception is
                chained so the root-cause traceback is preserved.
        """
        try:
            # Normalize input to [0, 1] range if needed.
            if batch.max() > 1.0:
                batch = batch.float() / 255.0
            # Apply transforms (resize, crop, normalize).
            _batch = cast(torch.Tensor, self.transform(batch))
            with torch.no_grad():
                _batch = _batch.to(self.device)
                features = self.model(_batch)
                # Move back to CPU for downstream processing.
                return features.cpu()
        except Exception as e:
            logger.error(f"Error in Gigapath feature extraction: {e}")
            # Chain the original exception instead of discarding it.
            raise RuntimeError("Error in Gigapath feature extraction.") from e
class UNIExtractor:
    """UNI2-h feature extractor using a pretrained model from timm.

    Note: Requires huggingface_hub login with access token before first use.
    Run: from huggingface_hub import login; login()
    """

    def __init__(self):
        # Run on GPU when available.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Architecture configuration for UNI2-h (ViT-style kwargs: 24 layers,
        # 1536-dim embeddings, SwiGLU MLP, 8 register tokens).
        timm_kwargs: dict[str, Any] = {
            "img_size": 224,
            "patch_size": 14,
            "depth": 24,
            "num_heads": 24,
            "init_values": 1e-5,
            "embed_dim": 1536,
            "mlp_ratio": 2.66667 * 2,
            "num_classes": 0,  # no classifier head; output raw features
            "no_embed_class": True,
            "mlp_layer": SwiGLUPacked,
            "act_layer": torch.nn.SiLU,
            "reg_tokens": 8,
            "dynamic_img_size": True,
        }
        # Load pretrained UNI2-h weights from the HuggingFace hub.
        self.model = timm.create_model(
            "hf-hub:MahmoodLab/UNI2-h", pretrained=True, **timm_kwargs
        )
        # Move model to device and set to eval mode.
        self.model = self.model.to(self.device)
        self.model.eval()
        # Resize to 224 plus ImageNet normalization; expects a float tensor
        # scaled to [0, 1].
        self.transform = transforms.Compose(
            [
                transforms.Resize(224),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
                ),
            ]
        )

    def extract_features(self, batch: torch.Tensor) -> torch.Tensor:
        """Extract UNI embeddings for a batch of patches.

        Args:
            batch: Image batch tensor; values above 1.0 are assumed to be
                [0, 255]-scaled and are rescaled to [0, 1].

        Returns:
            The model's output features, moved to CPU.

        Raises:
            RuntimeError: If extraction fails; the original exception is
                chained so the root-cause traceback is preserved.
        """
        try:
            # Normalize input to [0, 1] range if needed.
            if batch.max() > 1.0:
                batch = batch.float() / 255.0
            # Apply transforms (resize and normalize).
            _batch = cast(torch.Tensor, self.transform(batch))
            with torch.no_grad():
                _batch = _batch.to(self.device)
                features = self.model(_batch)
                # Move back to CPU for downstream processing.
                return features.cpu()
        except Exception as e:
            logger.error(f"Error in UNI feature extraction: {e}")
            # Chain the original exception instead of discarding it.
            raise RuntimeError("Error in UNI feature extraction.") from e
# class TITANExtractor:
# """TITAN feature extractor using CONCH v1.5 for patch-level embeddings.
# Note: Requires huggingface_hub login with access token before first use.
# TITAN uses CONCH v1.5 for patch-level feature extraction at 512x512 pixels.
# """
# def __init__(self):
# # Check if GPU is available
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # Load TITAN model to get CONCH v1.5
# logger.info("Loading TITAN model and CONCH v1.5 encoder...")
# try:
# titan = AutoModel.from_pretrained(
# 'MahmoodLab/TITAN',
# trust_remote_code=True
# )
# # Get CONCH v1.5 model and preprocessing transform
# self.model, self.transform = titan.return_conch()
# except Exception as e:
# raise RuntimeError(
# f"Failed to load TITAN/CONCH v1.5. Make sure you have access and are logged in to HuggingFace: {e}"
# )
# # Move model to GPU and set to eval mode
# self.model = self.model.to(self.device)
# self.model.eval()
# logger.info("TITAN/CONCH v1.5 model loaded successfully")
# def extract_features(self, batch: torch.Tensor) -> torch.Tensor:
# try:
# # Normalize input to [0, 1] range if needed
# if batch.max() > 1.0:
# batch = batch.float() / 255.0
# # Apply CONCH v1.5 preprocessing transform
# _batch = cast(torch.Tensor, self.transform(batch))
# with torch.no_grad():
# # Move input tensor to GPU
# _batch = _batch.to(self.device)
# # Extract patch-level features using CONCH v1.5
# # Use encode_image without projection for MIL tasks
# features = self.model.encode_image(_batch, proj_contrast=False, normalize=False)
# # Move back to CPU for further processing
# features = features.cpu()
# return features
# except Exception as e:
# logger.error(f"Error in TITAN/CONCH v1.5 feature extraction: {e}")
# raise RuntimeError("Error in TITAN/CONCH v1.5 feature extraction.")
# class VirchowExtractor:
# """Virchow feature extractor using a custom model from timm."""
# def __init__(self):
# # Check if GPU is available
# self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# # Load custom model from timm
# self.model = timm.create_model(
# "hf-hub:paige-ai/Virchow2",
# pretrained=True,
# mlp_layer=SwiGLUPacked,
# act_layer=torch.nn.SiLU,
# )
# # Move model to GPU and set to eval mode
# self.model = self.model.to(self.device)
# self.model.eval()
# self.transform = cast(
# type[transforms.Compose],
# create_transform(
# **resolve_data_config( # type: ignore
# self.model.pretrained_cfg, model=self.model
# )
# ),
# )
# def extract_features(self, batch: torch.Tensor) -> torch.Tensor:
# try:
# # Normalize input to [0, 1] range if needed
# if batch.max() > 1.0:
# batch = batch.float() / 255.0
# # Apply transforms (resize and normalize)
# _batch = cast(torch.Tensor, self.transform(batch))
# with torch.no_grad():
# # Move input tensor to GPU
# _batch = _batch.to(self.device)
# # Extract features using the custom model
# output = self.model(_batch)
# class_token = output[:, 0] # size: 1 x 1280
# patch_tokens = output[
# :, 5:
# ] # size: 1 x 256 x 1280, tokens 1-4 are register tokens so we ignore those
# # concatenate class token and average pool of patch tokens
# embedding = torch.cat(
# [class_token, patch_tokens.mean(1)], dim=-1
# ) # size: 1 x 2560
# features = embedding.cpu()
# return features
# except Exception as e:
# logger.error(f"Error in Virchow feature extraction: {e}")
# raise RuntimeError("Error in Virchow feature extraction.")