# Source code for cellmil.models.mil.cellconv

from typing_extensions import Self
import torch
import torchmetrics
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import LRScheduler
from .utils import LitGeneral, AEM
from typing import IO, Any, Callable
from pathlib import Path
from cellmil.utils.train.losses import NegativeLogLikelihoodSurvLoss
from cellmil.utils.train.metrics import ConcordanceIndex, BrierScore


def _groupnorm_groups(channels: int) -> int:
    for candidate in (32, 16, 8, 4, 2):
        if channels % candidate == 0:
            return candidate
    return 1


class _ResidualConvBlock(nn.Module):
    """Depthwise-separable 1D convolution block with a residual connection.

    Pipeline: depthwise conv -> GroupNorm -> GELU -> dropout -> pointwise
    conv -> GroupNorm -> dropout -> add input -> GELU. "Same" padding for
    the given dilation keeps the sequence length unchanged, so the residual
    add is always shape-compatible.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        dilation: int,
        dropout: float,
    ) -> None:
        super().__init__()  # type: ignore[misc]
        # Padding that preserves length for an (odd) kernel at this dilation.
        same_pad = dilation * ((kernel_size - 1) // 2)
        self.depthwise = nn.Conv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=kernel_size,
            padding=same_pad,
            dilation=dilation,
            groups=channels,  # one filter per channel: depthwise
            bias=False,
        )
        self.pointwise = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        num_groups = _groupnorm_groups(channels)
        self.norm1 = nn.GroupNorm(num_groups, channels)
        self.norm2 = nn.GroupNorm(num_groups, channels)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` of shape (batch, channels, length)."""
        y = self.dropout(self.act(self.norm1(self.depthwise(x))))
        y = self.dropout(self.norm2(self.pointwise(y)))
        return self.act(x + y)


class CellConv(nn.Module):
    """Attention-based MIL model with a dilated convolutional backbone.

    A bag of per-cell feature vectors is projected to ``embed_dim``, mixed
    along the cell axis by a stack of residual depthwise-separable conv
    blocks (dilation doubles per block, capped at 16), pooled with tanh
    attention, and classified into ``n_classes`` logits.

    Args:
        embed_dim: Width of the projected per-cell embedding.
        n_classes: Number of output classes (or hazard bins for survival use).
        convolution_depth: Number of residual conv blocks (minimum 1).
        size_arg: Two ints ``[latent_dim, attention_dim]``. Default is an
            immutable tuple to avoid the shared-mutable-default pitfall;
            lists are still accepted.
        attention_branches: Number of parallel attention heads.
        temperature: Softmax temperature for attention weights.
        dropout: Dropout probability used throughout.
        kernel_size: Depthwise conv kernel size (odd keeps length exact).

    Raises:
        ValueError: If ``size_arg`` does not contain exactly two entries.
    """

    def __init__(
        self,
        embed_dim: int,
        n_classes: int = 2,
        convolution_depth: int = 3,
        size_arg: list[int] | tuple[int, int] = (512, 128),
        attention_branches: int = 1,
        temperature: float = 1.0,
        dropout: float = 0.0,
        kernel_size: int = 3,
    ) -> None:
        super().__init__()  # type: ignore
        if len(size_arg) != 2:
            raise ValueError("size_arg must contain [latent_dim, attention_dim]")
        self.convolution_depth = convolution_depth
        self.latent_dim = size_arg[0]
        self.attention_dim = size_arg[1]
        self.embed_dim = embed_dim
        self.temperature = temperature
        self.n_classes = n_classes
        self.dropout = dropout
        self.kernel_size = kernel_size
        self.ATTENTION_BRANCHES = attention_branches
        # LazyLinear infers the raw input feature dim on first forward.
        self.input_proj = nn.Sequential(
            nn.LazyLinear(self.embed_dim),
            nn.LayerNorm(self.embed_dim),
            nn.Dropout(self.dropout),
        )
        # Exponentially growing dilations 1, 2, 4, ... capped at 16.
        dilations: list[int] = []
        dilation = 1
        for _ in range(max(1, self.convolution_depth)):
            dilations.append(dilation)
            if dilation < 16:
                dilation *= 2
        self.conv_backbone = nn.ModuleList(
            [
                _ResidualConvBlock(
                    channels=self.embed_dim,
                    kernel_size=self.kernel_size,
                    dilation=d,
                    dropout=self.dropout,
                )
                for d in dilations
            ]
        )
        self.post_conv_norm = nn.LayerNorm(self.embed_dim)
        self.feature_head = nn.Sequential(
            nn.Linear(self.embed_dim, self.latent_dim),
            nn.GELU(),
            nn.Dropout(self.dropout),
        )
        # Tanh attention scorer: one score per instance per branch.
        self.attention = nn.Sequential(
            nn.Linear(self.latent_dim, self.attention_dim),
            nn.Tanh(),
            nn.Linear(self.attention_dim, self.ATTENTION_BRANCHES),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(self.dropout),
            nn.Linear(self.latent_dim * self.ATTENTION_BRANCHES, self.n_classes),
        )

    def forward(
        self, x: torch.Tensor
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Run one bag of cells through the model.

        Args:
            x: 2D tensor of shape (num_cells, feature_dim) — a single bag.

        Returns:
            ``(logits, output_dict)`` where ``logits`` has shape
            ``(1, n_classes)`` and ``output_dict`` holds ``y_prob``,
            ``y_hat``, ``attention`` (branches x num_cells) and
            ``bag_representation`` (branches x latent_dim).

        Raises:
            ValueError: If ``x`` is not 2D.
        """
        if x.dim() != 2:
            raise ValueError("Input tensor must be 2D (num_cells x feature_dim)")
        h = self.input_proj(x)
        # Conv1d expects (batch, channels, length): treat cells as the length axis.
        conv_input = h.transpose(0, 1).unsqueeze(0)
        for block in self.conv_backbone:
            conv_input = block(conv_input)
        h = conv_input.squeeze(0).transpose(0, 1).contiguous()
        h = self.post_conv_norm(h)
        features = self.feature_head(h)
        attention_scores = self.attention(features)
        attention_scores = attention_scores.transpose(0, 1)
        # Temperature-scaled softmax over instances within the bag.
        attention_weights = F.softmax(attention_scores / self.temperature, dim=1)
        bag_representation = torch.mm(attention_weights, features)
        logits = self.classifier(bag_representation.unsqueeze(0))
        y_prob = F.softmax(logits, dim=1)
        y_hat = torch.topk(y_prob, 1, dim=1)[1]
        output_dict = {
            "y_prob": y_prob,
            "y_hat": y_hat,
            "attention": attention_weights,
            "bag_representation": bag_representation,
        }
        return logits, output_dict
class LitCellConv(LitGeneral):
    """Lightning module wrapping :class:`CellConv` for MIL classification.

    Handles loss/AEM regularization, per-stage logging, bag subsampling on
    device transfer, and checkpoint (de)serialization.

    Fix: ``n_classes`` is now included in the saved hyperparameters —
    ``load_from_checkpoint`` reads ``hparams["n_classes"]`` to rebuild the
    model, and previously any checkpoint with ``n_classes != 2`` silently
    rebuilt a 2-class head and failed on ``load_state_dict``.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        loss: nn.Module = nn.CrossEntropyLoss(),
        lr_scheduler: LRScheduler | None = None,
        subsampling: float = 1.0,
        use_aem: bool = False,
        aem_weight_initial: float = 0.0001,
        aem_weight_final: float = 0.0,
        aem_annealing_epochs: int = 25,
    ) -> None:
        super().__init__(model, optimizer, loss, lr_scheduler)
        self.n_classes = getattr(model, "n_classes", 2)
        self.subsampling = subsampling
        self.use_aem = use_aem
        if self.use_aem:
            self.aem = AEM(
                weight_initial=aem_weight_initial,
                weight_final=aem_weight_final,
                annealing_epochs=aem_annealing_epochs,
            )
        model_config: dict[str, Any] = {
            "model_class": model.__class__.__name__,
            "embed_dim": getattr(model, "embed_dim", None),
            # Persist n_classes so load_from_checkpoint rebuilds the same head.
            "n_classes": getattr(model, "n_classes", 2),
            "latent_dim": getattr(model, "latent_dim", None),
            "attention_dim": getattr(model, "attention_dim", None),
            "attention_branches": getattr(model, "ATTENTION_BRANCHES", 1),
            "convolution_depth": getattr(model, "convolution_depth", None),
            "kernel_size": getattr(model, "kernel_size", None),
            "temperature": getattr(model, "temperature", 1.0),
            "dropout": getattr(model, "dropout", 0.0),
        }
        self.save_hyperparameters(
            {
                **model_config,
                "optimizer_class": optimizer.__class__.__name__,
                "optimizer_lr": optimizer.param_groups[0]["lr"],
                "loss": loss,
                "lr_scheduler_class": lr_scheduler.__class__.__name__
                if lr_scheduler
                else None,
                "subsampling": subsampling,
                "use_aem": use_aem,
                "aem_weight_initial": aem_weight_initial,
                "aem_weight_final": aem_weight_final,
                "aem_annealing_epochs": aem_annealing_epochs,
            }
        )

    @classmethod
    def load_from_checkpoint(
        cls,
        checkpoint_path: str | Path | IO[bytes],
        map_location: torch.device
        | str
        | int
        | Callable[[torch.UntypedStorage, str], torch.UntypedStorage | None]
        | dict[torch.device | str | int, torch.device | str | int]
        | None = None,
        hparams_file: str | Path | None = None,
        strict: bool | None = None,
        **kwargs: Any,
    ) -> Self:
        """Rebuild a :class:`LitCellConv` (model, optimizer, loss) from a checkpoint.

        The model architecture is reconstructed from the saved hyperparameters,
        then the checkpoint weights are loaded into it.
        """
        checkpoint = torch.load(
            checkpoint_path,
            map_location=map_location,  # type: ignore[arg-type]
            weights_only=False,
        )
        hparams = checkpoint.get("hyper_parameters", {})
        model = CellConv(
            embed_dim=hparams.get("embed_dim", 256),
            n_classes=hparams.get("n_classes", 2),
            convolution_depth=hparams.get("convolution_depth", 3),
            size_arg=[
                hparams.get("latent_dim", 512),
                hparams.get("attention_dim", 128),
            ],
            attention_branches=hparams.get("attention_branches", 1),
            temperature=hparams.get("temperature", 1.0),
            dropout=hparams.get("dropout", 0.1),
            kernel_size=hparams.get("kernel_size", 3),
        )
        optimizer_cls = getattr(torch.optim, hparams.get("optimizer_class", "Adam"))
        optimizer = optimizer_cls(
            model.parameters(), lr=hparams.get("optimizer_lr", 1e-3)
        )
        # The saved loss may be either a module instance or a class name string.
        loss_param = hparams.get("loss", "CrossEntropyLoss")
        if isinstance(loss_param, str):
            loss_fn = getattr(nn, loss_param)()
        else:
            loss_fn = loss_param
        lit_model = cls(
            model=model,
            optimizer=optimizer,
            loss=loss_fn,
            lr_scheduler=None,  # type: ignore[arg-type]
            subsampling=hparams.get("subsampling", 1.0),
            use_aem=hparams.get("use_aem", False),
            aem_weight_initial=hparams.get("aem_weight_initial", 0.0001),
            aem_weight_final=hparams.get("aem_weight_final", 0.0),
            aem_annealing_epochs=hparams.get("aem_annealing_epochs", 25),
        )
        lit_model.load_state_dict(
            checkpoint["state_dict"], strict=strict if strict is not None else True
        )
        return lit_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return bag-level logits for a 2D (num_cells, feat_dim) input."""
        logits, _ = self.model(x)
        return logits

    def _shared_step(
        self, batch: tuple[torch.Tensor, torch.Tensor], stage: str, log: bool = True
    ):
        """Common train/val/test step: forward, loss, optional AEM, logging.

        Returns ``(loss, logits, y)`` for downstream metric updates.
        """
        x, y = batch
        assert x.size(0) == 1, "Batch size must be 1 for MIL"
        x = x.squeeze(0)
        logits, output_dict = self.model(x)
        loss = self.loss(logits, y)
        current_epoch = self.current_epoch if hasattr(self, "current_epoch") else 0
        aem_term: torch.Tensor | None = None
        # AEM regularizes attention entropy during training only.
        if self.use_aem and stage == "train":
            attention_weights = output_dict["attention"]
            aem_term = self.aem.get_aem(current_epoch, attention_weights)
            loss = loss + aem_term
        if log:
            self.log(
                f"{stage}/total_loss",
                loss,
                prog_bar=(stage != "train"),
                on_step=(stage == "train"),
                on_epoch=True,
            )
            # Bag sizes are only informative once; log them in the first epoch.
            if current_epoch == 0 and stage in ["train", "val"]:
                self.log(
                    f"{stage}/num_instances",
                    batch[0].squeeze(0).shape[0],
                    prog_bar=False,
                    on_step=True,
                    on_epoch=False,
                )
            if self.use_aem and stage == "train" and aem_term is not None:
                self.log(
                    f"{stage}/aem",
                    aem_term,
                    prog_bar=True,
                    on_step=False,
                    on_epoch=True,
                )
        return loss, logits, y

    def get_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
        """Return attention weights for one bag (model is put in eval mode).

        Raises:
            ValueError: If ``x`` is not of shape [n_instances, feat_dim].
        """
        self.model.eval()
        if x.dim() != 2:
            raise ValueError("Input tensor must be of shape [n_instances, feat_dim]")
        _, output_dict = self.model(x)
        return output_dict["attention"]

    def transfer_batch_to_device(
        self,
        batch: tuple[torch.Tensor, torch.Tensor],
        device: torch.device,
        dataloader_idx: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Optionally subsample the bag before moving it to ``device``.

        ``subsampling`` in (0, 1) is a fraction of instances; >= 1 is an
        absolute instance count. Subsampling is applied only in training.
        """
        x, y = batch
        if self.training and self.subsampling != 1.0:
            if x.size(0) != 1:
                raise ValueError("Batch size must be 1 for MIL")
            bag = x.squeeze(0)
            if 0 < self.subsampling < 1.0:
                num_samples = max(int(self.subsampling * bag.shape[0]), 1)
            elif self.subsampling >= 1.0:
                num_samples = min(int(self.subsampling), bag.shape[0])
            else:
                raise ValueError(f"Invalid subsampling value: {self.subsampling}")
            if num_samples < bag.shape[0]:
                sampled_indices = torch.randperm(bag.shape[0])[:num_samples]
                bag = bag.index_select(0, sampled_indices)
            x = bag.unsqueeze(0)
        return super().transfer_batch_to_device((x, y), device, dataloader_idx)
class LitSurvCellConv(LitCellConv):
    """Survival-analysis variant of :class:`LitCellConv`.

    Uses a discrete-time (logistic-hazard) formulation: the wrapped model's
    ``n_classes`` is interpreted as the number of time bins, and survival
    metrics (C-index, Brier score) are tracked per split.
    """

    def __init__(
        self,
        model: CellConv,
        optimizer: torch.optim.Optimizer,
        loss: nn.Module = NegativeLogLikelihoodSurvLoss(),
        lr_scheduler: LRScheduler | None = None,
        subsampling: float = 1.0,
        use_aem: bool = False,
        aem_weight_initial: float = 0.0001,
        aem_weight_final: float = 0.0,
        aem_annealing_epochs: int = 25,
    ):
        super().__init__(
            model,
            optimizer,
            loss,
            lr_scheduler,
            subsampling,
            use_aem,
            aem_weight_initial,
            aem_weight_final,
            aem_annealing_epochs,
        )
        # For logistic hazard the classifier width doubles as the number of
        # time bins; kept for converting logits back to continuous risk.
        self.num_bins = model.n_classes
        self._setup_metrics()

    def _setup_metrics(self):
        """Setup C-index and Brier score metrics for survival analysis."""
        base_collection = torchmetrics.MetricCollection(
            {"c_index": ConcordanceIndex(), "brier_score": BrierScore()}
        )
        self.train_metrics = base_collection.clone(prefix="train/")
        self.val_metrics = base_collection.clone(prefix="val/")
        self.test_metrics = base_collection.clone(prefix="test/")

    def predict_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        """Prediction step returns logits for discrete-time hazard intervals."""
        x, _ = batch
        # MIL contract: exactly one bag per batch.
        assert x.size(0) == 1, "Batch size must be 1 for MIL"
        bag = x.squeeze(0)  # [n_instances, feat_dim]
        logits, _ = self.model(bag)
        # Return raw logits, not hazards.
        return logits