# -*- coding: utf-8 -*-
# Utils for CLAM Model Training
#
# References:
# Data-efficient and weakly supervised computational pathology on whole-slide images
# Lu, Ming Y et al., Nature Biomedical Engineering, 2021
# DOI: https://doi.org/10.1038/s41551-021-00707-9
import math
import numpy as np
import torch
from torch import nn
import lightning as Pl
import torchmetrics
from typing import cast, Any
from torch.optim.lr_scheduler import LRScheduler
class Accuracy_Logger(object):
    """Tracks per-class prediction counts and per-class accuracy."""

    def __init__(self, n_classes: int):
        super().__init__()
        self.n_classes = n_classes
        self.initialize()

    def initialize(self):
        """Reset every per-class counter to zero."""
        self.data = [{"count": 0, "correct": 0} for _ in range(self.n_classes)]

    def log(self, Y_hat: torch.Tensor, Y: torch.Tensor):
        """Record one prediction/label pair (scalar tensors or ints)."""
        pred = int(Y_hat)
        label = int(Y)
        bucket = self.data[label]
        bucket["count"] += 1
        bucket["correct"] += (pred == label)

    def log_batch(
        self,
        Y_hat: torch.Tensor,
        Y: torch.Tensor
    ):
        """Record a whole batch of predictions/labels at once."""
        preds = np.array(Y_hat).astype(int)
        labels = np.array(Y).astype(int)
        for cls in np.unique(labels):
            mask = labels == cls
            self.data[cls]["count"] += mask.sum()
            self.data[cls]["correct"] += (preds[mask] == labels[mask]).sum()

    def get_summary(
        self,
        c: int
    ):
        """Return ``(accuracy, correct, count)`` for class ``c``.

        Accuracy is ``None`` when no samples of class ``c`` were logged.
        """
        correct = self.data[c]["correct"]
        count = self.data[c]["count"]
        acc = None if count == 0 else float(correct) / count
        return acc, correct, count
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(
        self,
        patience: int = 20,
        stop_epoch: int = 50,
        verbose: bool = False
    ):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                Default: 20
            stop_epoch (int): Earliest epoch possible for stopping
            verbose (bool): If True, prints a message for each validation loss improvement.
                Default: False
        """
        self.patience = patience
        self.stop_epoch = stop_epoch
        self.verbose = verbose
        self.counter = 0          # epochs since the last improvement
        self.best_score = None    # best (negated) validation loss seen so far
        self.early_stop = False
        # float("inf") instead of np.Inf: the np.Inf alias was removed in
        # NumPy 2.0, which makes this line raise AttributeError there.
        self.val_loss_min = float("inf")

    def __call__(
        self,
        epoch: int,
        val_loss: float,
        model: nn.Module,
        ckpt_name: str = 'checkpoint.pt'
    ):
        """Update state with this epoch's validation loss.

        Saves a checkpoint whenever the loss improves; sets
        ``self.early_stop`` once ``patience`` non-improving epochs have
        accumulated and ``epoch`` is past ``stop_epoch``.
        """
        score = -val_loss
        if self.best_score is None:
            # First observation: treat it as the best so far.
            self.best_score = score
            self.save_checkpoint(val_loss, model, ckpt_name)
        elif score < self.best_score:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience and epoch > self.stop_epoch:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, ckpt_name)
            self.counter = 0

    def save_checkpoint(
        self,
        val_loss: float,
        model: nn.Module,
        ckpt_name: str
    ):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), ckpt_name)
        self.val_loss_min = val_loss
class LitGeneral(Pl.LightningModule):
    """Generic Lightning wrapper for a classification model.

    Handles the shared loss computation for train/val, torchmetrics
    bookkeeping, and flattening of per-class metrics for logging.
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        loss: nn.Module | None = None,
        lr_scheduler: LRScheduler | None = None,
        n_classes: int = 2
    ):
        """
        Args:
            model: Classifier producing logits of shape (batch, n_classes)
                — assumed from metric usage; TODO confirm against callers.
            optimizer: Optimizer instance, already bound to model parameters.
            loss: Loss module; defaults to ``nn.CrossEntropyLoss()`` when None.
            lr_scheduler: Optional learning-rate scheduler.
            n_classes: Number of target classes.
        """
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        # Avoid a module instance as a default argument: the previous
        # `loss=nn.CrossEntropyLoss()` default was built once at import time
        # and shared by every LitGeneral instance.
        self.loss = loss if loss is not None else nn.CrossEntropyLoss()
        self.n_classes = n_classes
        self.lr_scheduler = lr_scheduler
        self.save_hyperparameters({
            "model": model.__class__.__name__,
            "optimizer": optimizer,
            "loss": self.loss
        })
        self._setup_metrics()

    def _setup_metrics(self):
        # Accuracy is per-class (average="none"); the rest are macro scalars.
        metrics = torchmetrics.MetricCollection({
            "accuracy": torchmetrics.Accuracy(
                task="multiclass", num_classes=self.n_classes, average="none"
            ),
            "f1": torchmetrics.F1Score(
                task="multiclass", num_classes=self.n_classes, average="macro"
            ),
            "precision": torchmetrics.Precision(
                task="multiclass", num_classes=self.n_classes, average="macro"
            ),
            "recall": torchmetrics.Recall(
                task="multiclass", num_classes=self.n_classes, average="macro"
            ),
            "auroc": torchmetrics.AUROC(
                task="multiclass", num_classes=self.n_classes, average="macro"
            ),
        })
        # Separate clones so train/val/test accumulate independently.
        self.train_metrics = metrics.clone(prefix="train/")
        self.val_metrics = metrics.clone(prefix="val/")
        self.test_metrics = metrics.clone(prefix="test/")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)

    def _shared_step(
        self,
        batch: tuple[torch.Tensor, torch.Tensor],
        stage: str,
        log: bool = True
    ):
        """Compute loss and logits; return ``(loss, logits, hard_labels)``."""
        x, y = batch
        logits = self(x)
        loss = self.loss(logits, y)
        # Convert soft labels to hard labels for metrics if needed
        if y.dim() == 2:
            y = y.argmax(dim=1).long()
        else:
            y = y.long()
        if log:
            self.log(f"{stage}/total_loss", loss, prog_bar=(stage != "train"), on_step=(stage=="train"), on_epoch=True)
        return loss, logits, y

    def training_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        loss, logits, y = self._shared_step(batch, stage="train", log=True)
        self.train_metrics(logits, y)
        return loss

    def on_train_epoch_end(self) -> None:
        self._flatten_and_log_metrics(
            self.train_metrics.compute(), prefix="train"
        )
        self.train_metrics.reset()

    def validation_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        loss, logits, y = self._shared_step(batch, stage="val", log=True)
        self.val_metrics(logits, y)
        return loss

    def on_validation_epoch_end(self) -> None:
        self._flatten_and_log_metrics(
            self.val_metrics.compute(), prefix="val"
        )
        self.val_metrics.reset()

    def _flatten_and_log_metrics(self, computed: dict[str, torch.Tensor], prefix: str) -> None:
        """
        Convert metric dictionary produced by torchmetrics into a flat dict of
        scalar values and log it with `self.log_dict`.
        - Vector/tensor metrics (e.g. per-class accuracy) are expanded into
          keys like `{prefix}/class_{i}_acc`.
        - Scalar tensors are converted to floats.
        - None values are converted to NaN to satisfy loggers that expect
          numeric scalars.
        """
        flat: dict[str, float] = {}
        for key, val in computed.items():
            # Keys already carry the stage prefix (e.g. 'train/accuracy').
            try:
                if val.dim() == 0:
                    flat[key] = float(val.item())
                else:
                    # tolist() yields plain Python floats; no cast needed
                    # (the old cast to list[torch.Tensor] was incorrect).
                    for i, v in enumerate(val.cpu().tolist()):
                        # Special-case accuracy to use *_acc suffix
                        if key.endswith("/accuracy"):
                            base = key.rsplit("/", 1)[0]
                            flat[f"{base}/class_{i}_acc"] = float(v)
                        else:
                            flat[f"{key}_class_{i}"] = float(v)
            except Exception:
                # Deliberate best-effort: set NaN so logging never crashes
                # the epoch-end hook.
                flat[key] = float("nan")
        # Finally log flattened scalars
        self.log_dict(flat, prog_bar=True, batch_size=1)

    def predict_step(self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        _, logits, _ = self._shared_step(batch, stage="test", log=False)
        preds = logits.argmax(dim=1)
        return preds
class AEM():
    """
    Attention Entropy Maximization (AEM) to encourage diversity in attention weights.
    This class provides methods to compute the AEM weight based on cosine annealing
    and to calculate the attention entropy loss.
    """

    def __init__(
        self,
        weight_initial: float = 0.1,
        weight_final: float = 0.01,
        annealing_epochs: int = 25
    ):
        """
        Args:
            weight_initial (float): Initial weight for the AEM loss term
            weight_final (float): Final weight for the AEM loss term after annealing
            annealing_epochs (int): Number of epochs over which to anneal the AEM weight
        """
        self.weight_initial = weight_initial
        self.weight_final = weight_final
        self.annealing_epochs = annealing_epochs

    def get_negative_entropy(self, attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Calculate the negative entropy of attention weights to encourage diversity.
        Args:
            attention_weights (torch.Tensor): Attention weights tensor of shape (K, N) or (1, N)
                where K is number of classes/branches and N is number of instances
        Returns:
            torch.Tensor: Negative entropy (to maximize entropy, we minimize negative entropy)
        """
        # Add small epsilon to avoid log(0), choosing appropriate precision based on dtype
        if attention_weights.dtype == torch.float16:
            eps = 1e-7  # Safe for half precision
        else:
            eps = 1e-12
        attention_weights = attention_weights + eps
        # Non-blocking diagnostic: a NaN here points at an upstream problem.
        # (The previous interactive input() pause would hang unattended runs.)
        if torch.isnan(attention_weights).any():
            print("AEM: NaN detected in attention weights")
        # Calculate entropy: -sum(p * log(p))
        entropy = -torch.sum(attention_weights * torch.log(attention_weights), dim=-1)
        if torch.isnan(entropy).any():
            print("AEM: NaN detected in attention entropy")
        return -entropy.mean()

    def get_weight(self, current_epoch: int) -> float:
        """
        Calculate the current AEM weight using cosine annealing.
        Args:
            current_epoch (int): Current training epoch
        Returns:
            float: Current AEM weight
        """
        # Past the annealing window (also covers annealing_epochs == 0,
        # avoiding a division by zero below).
        if current_epoch >= self.annealing_epochs:
            return self.weight_final
        # Cosine annealing: weight decreases from initial to final over annealing_epochs.
        # progress/cosine_factor are plain float math on finite ints and cannot
        # be NaN, so the old interactive NaN checks were dead code.
        progress = current_epoch / self.annealing_epochs
        cosine_factor = 0.5 * (1 + math.cos(math.pi * progress))
        return self.weight_final + (self.weight_initial - self.weight_final) * cosine_factor

    def get_aem(self, current_epoch: int, attention_weights: torch.Tensor) -> torch.Tensor:
        """
        Get the term to sum to the loss function for AEM.
        Args:
            current_epoch (int): Current training epoch
            attention_weights (torch.Tensor): Attention weights tensor
        Returns:
            torch.Tensor: Term to add to the loss function for AEM
        """
        weight = self.get_weight(current_epoch)
        neg_entropy = self.get_negative_entropy(attention_weights)
        return weight * neg_entropy