Training¶
This section describes how to train MIL models using k-fold cross-validation, the recommended approach for robust model evaluation.
K-Fold Cross-Validation¶
Overview¶
K-fold cross-validation is the recommended way to train and evaluate MIL models: it provides more robust performance estimates than a single train/validation split and automatically handles data splitting and transform fitting.
Model Creator Function¶
The `lit_model_creator` function is called for each fold to create a fresh model instance:
```python
import pytorch_lightning as pl
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

from cellmil.utils.train import FocalLoss


def lit_model_creator(input_dim: int, use_lr_scheduler: bool = True) -> pl.LightningModule:
    """Create a LightningModule for training.

    Args:
        input_dim: Input feature dimension (determined automatically).
        use_lr_scheduler: Whether to use a learning rate scheduler.

    Returns:
        LightningModule instance ready for training.
    """
    # Create the model architecture
    model = AttentionDeepMIL(
        embed_dim=input_dim,
        size_arg=[256, 128],
        n_classes=2,
    )

    # Define the optimizer
    optimizer = AdamW(model.parameters(), lr=1e-4)

    # Optional learning rate scheduler
    lr_scheduler = None
    if use_lr_scheduler:
        lr_scheduler = ReduceLROnPlateau(
            optimizer, mode="min", patience=5, factor=0.8
        )

    # Wrap everything in a LightningModule
    lit_model = LitAttentionDeepMIL(
        model=model,
        optimizer=optimizer,
        loss=FocalLoss(alpha=0.5, gamma=2.0),
        lr_scheduler=lr_scheduler,
    )
    return lit_model
```
Key points:
- Takes `input_dim` as a parameter (automatically determined from the data)
- Takes `use_lr_scheduler` to optionally enable learning rate scheduling
- Returns a `LightningModule` instance
- Creates a fresh model for each fold (no weight sharing)
- Defines the optimizer, loss, and optional scheduler
K-Fold Parameters¶
Constructor Parameters (KFoldCrossValidation):
| Parameter | Default | Description |
|---|---|---|
| `k` | — | Number of cross-validation folds |
| `random_state` | — | Random seed for reproducibility |
Evaluate Method Parameters:
| Parameter | Default | Description |
|---|---|---|
| `name` | Required | Experiment name for logging and checkpoints |
| `lit_model_creator` | Required | Function that takes `input_dim` and returns a `LightningModule` |
| `dataset` | Required | The full MIL dataset (no split) |
| `output_dir` | Required | Directory to store all results and checkpoints |
| `wandb_project` | Required | Weights & Biases project name for logging |
| `transforms` | — | Transform pipeline (fit on each fold's training data) |
| — | — | Label transforms (fit on each fold's training labels) |
| `balance_cell_counts` | — | Jointly stratify by label and cell-count quantiles |
| `cell_balance_bins` | — | Number of quantile bins when balancing cell counts |
| — | — | Epochs without improvement before early stopping |
Advanced Features¶
Cell-Count Balanced Stratification¶
When slides have varying numbers of cells, you can balance folds by both label and cell count:
```python
from pathlib import Path

k_fold = KFoldCrossValidation(k=5, random_state=42)
model_storage = k_fold.evaluate(
    name="balanced_experiment",
    lit_model_creator=lit_model_creator,
    dataset=dataset,
    output_dir=Path("./results"),
    wandb_project="my_project",
    transforms=transforms,
    balance_cell_counts=True,  # Enable cell-count balancing
    cell_balance_bins=5,       # Number of quantile bins
)
```
This creates a combined stratification target from labels and cell-count quantile bins, ensuring each fold has similar distributions of both class labels and cell counts.
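For intuition, the combined target can be sketched in plain Python (a stand-alone illustration, not the library's internal code; `quantile_bins` and `combined_strata` are hypothetical names):

```python
def quantile_bins(values, n_bins):
    """Assign each value a quantile-bin index in [0, n_bins) by rank."""
    ranked = sorted(range(len(values)), key=lambda i: values[i])
    bins = [0] * len(values)
    for rank, idx in enumerate(ranked):
        bins[idx] = min(rank * n_bins // len(values), n_bins - 1)
    return bins


def combined_strata(labels, cell_counts, n_bins=5):
    """Combine class label and cell-count bin into one stratification key."""
    count_bins = quantile_bins(cell_counts, n_bins)
    return [f"{label}_{b}" for label, b in zip(labels, count_bins)]


# Slides with the same label but very different cell counts land in
# different strata, so folds stay balanced on both axes.
strata = combined_strata([0, 0, 1, 1, 0, 1], [120, 4500, 300, 9000, 800, 50], n_bins=2)
```

Passing a target like `strata` to a stratified splitter then balances each fold on both label and cell count.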
Class Imbalance Handling¶
For imbalanced datasets, use FocalLoss with computed class weights:
```python
from cellmil.utils.train import FocalLoss
from cellmil.utils.train.dataset import complementary_frequencies

# Get class frequencies from your data
_, alpha = complementary_frequencies(df, "label")


def lit_model_creator(input_dim: int) -> pl.LightningModule:
    model = AttentionDeepMIL(embed_dim=input_dim, ...)
    optimizer = AdamW(model.parameters(), lr=1e-4)
    lit_model = LitAttentionDeepMIL(
        model=model,
        optimizer=optimizer,
        loss=FocalLoss(
            alpha=alpha,  # Class weight (higher for the minority class)
            gamma=2.0,    # Focusing parameter
        ),
    )
    return lit_model
```
`complementary_frequencies` returns `(frequencies, alpha)`, where `alpha` is the weight for the positive class.
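As a rough illustration of the idea (a sketch, not the `cellmil` implementation), complementary weighting for binary labels could look like:

```python
from collections import Counter


def complementary_frequencies_sketch(labels):
    """Sketch: weight each class by the frequency of the *other* class,
    so the minority class receives the larger weight."""
    counts = Counter(labels)
    total = sum(counts.values())
    freqs = {cls: n / total for cls, n in counts.items()}
    # For binary labels {0, 1}: alpha weights the positive class and
    # equals the negative-class frequency.
    alpha = freqs[0]
    return freqs, alpha
```

With 80 negatives and 20 positives this yields `alpha = 0.8`, up-weighting the rarer positive class.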
Learning Rate Scheduling¶
Use learning rate schedulers to adjust learning rate during training:
```python
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR


def lit_model_creator(input_dim: int, use_lr_scheduler: bool = True) -> pl.LightningModule:
    model = AttentionDeepMIL(embed_dim=input_dim, ...)
    optimizer = AdamW(model.parameters(), lr=1e-4)

    # Option 1: reduce the LR when the validation loss plateaus
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode="min",    # Minimize validation loss
        patience=5,    # Wait 5 epochs before reducing
        factor=0.8,    # Multiply LR by 0.8
    )

    # Option 2: cosine annealing (use instead of option 1)
    scheduler = CosineAnnealingLR(
        optimizer,
        T_max=50,      # Maximum epochs
        eta_min=1e-6,  # Minimum LR
    )

    lit_model = LitAttentionDeepMIL(
        model=model,
        optimizer=optimizer,
        lr_scheduler=scheduler if use_lr_scheduler else None,
    )
    return lit_model
```
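To build intuition for the plateau schedule, here is a pure-Python simulation of the reduce-on-plateau rule (an approximation; PyTorch's `ReduceLROnPlateau` adds threshold and cooldown handling):

```python
def simulate_plateau(losses, lr=1e-4, patience=5, factor=0.8):
    """Simulate reduce-on-plateau: if the monitored loss fails to improve
    for more than `patience` consecutive epochs, multiply the LR by `factor`."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs > patience:
                lr *= factor
                bad_epochs = 0
        history.append(lr)
    return history
```

With `patience=1`, a loss sequence that stalls for two epochs triggers one reduction, so the LR drops by `factor` exactly once.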
What K-Fold Does Internally¶
The `KFoldCrossValidation.evaluate()` method:

1. Splits the data using `StratifiedKFold` to maintain class balance
2. Creates train/val subsets using `dataset.create_train_val_datasets()`
3. Fits transforms only on the training subset
4. Applies transforms to both train and validation
5. Fits label transforms (if provided) on training labels
6. Creates a fresh model using `lit_model_creator`
7. Trains with early stopping and checkpointing
8. Evaluates and logs metrics to wandb
9. Aggregates results across all folds
10. Trains a final model on the full dataset using the average of the per-fold best epoch counts
11. Returns a `ModelStorage` object with all results
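The splitting step can be approximated with a stdlib-only sketch; `stratified_folds` is a hypothetical stand-in for scikit-learn's `StratifiedKFold`, not the library's code:

```python
from collections import defaultdict


def stratified_folds(labels, k):
    """Deal each class's indices round-robin across k folds so every fold
    keeps roughly the overall class balance; yield (train, val) index lists."""
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    folds = [[] for _ in range(k)]
    for idxs in by_class.values():
        for j, i in enumerate(idxs):
            folds[j % k].append(i)
    # Each fold serves once as validation; the rest form the training set.
    for held_out in range(k):
        val = sorted(folds[held_out])
        train = sorted(i for f, fold in enumerate(folds) if f != held_out for i in fold)
        yield train, val
```

Inside `evaluate()`, each (train, val) pair then gets freshly fitted transforms, a fresh model from `lit_model_creator`, and its own training run.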
Next Steps¶
Model Architectures - Explore different model architectures
Explainability - Interpret model predictions