MIL Training
This section guides you through building custom Multiple Instance Learning (MIL) training scripts using the CellMIL API. You’ll learn how to configure datasets, select models, apply preprocessing transforms, and train deep learning models on preprocessed whole slide image data.
Overview
Multiple Instance Learning (MIL) enables training models on whole slide images (WSIs) using only slide-level labels. Each slide is treated as a “bag” of cell instances, and the model learns to aggregate cell-level features into a single slide-level prediction, such as:
Binary classification (e.g., responder vs non-responder, histological subtype)
Survival prediction (time-to-event outcomes)
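The bag-to-slide aggregation described above can be sketched in a few lines. The example below is a minimal, dependency-free illustration of attention-style pooling: each cell gets a score, the scores are normalized with a softmax, and the slide embedding is the weighted average of the cell features. It is not the CellMIL implementation. The function names, the linear scoring vector `w`, and the toy feature values are all assumptions made for illustration, and real attention-based MIL models typically learn the scoring function with a small neural network.

```python
import math


def softmax(scores):
    """Normalize raw scores into weights that sum to 1."""
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(bag, w):
    """Aggregate a bag of cell feature vectors into one slide-level vector.

    bag: list of per-cell feature vectors (all the same length)
    w:   hypothetical scoring vector; in a trained model this would be learned
    """
    # One scalar relevance score per cell (dot product with w).
    scores = [sum(x_d * w_d for x_d, w_d in zip(x, w)) for x in bag]
    weights = softmax(scores)
    dim = len(bag[0])
    # Weighted average of the cell features, one dimension at a time.
    pooled = [sum(a * x[d] for a, x in zip(weights, bag)) for d in range(dim)]
    return pooled, weights


# Toy bag: three cells with two features each (made-up values).
bag = [[0.2, 1.0], [0.1, 0.9], [3.0, 0.1]]
w = [1.0, 0.0]  # score cells by their first feature only

slide_embedding, attention = attention_pool(bag, w)
print("attention weights:", attention)
print("slide embedding:", slide_embedding)
```

The attention weights also give a rough indication of which cells drove the slide-level prediction, which is one reason attention pooling is a common choice over plain mean or max pooling.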
Quick Start
Here’s a minimal example to get started:
from pathlib import Path
import pandas as pd
from cellmil.datamodels.datasets import MILDataset
from cellmil.utils.train.evals import KFoldCrossValidation
# Load slide-level metadata
df = pd.read_excel("./data/metadata.xlsx")
# Keep only slides with a binary DCR label (0 or 1)
df = df[df["DCR"].isin([0, 1])]
# Create dataset
dataset = MILDataset(
    root=Path("./MIL_dataset"),
    label="DCR",
    folder=Path("./dataset"),
    data=df,
    extractor="morphometrics",
    segmentation_model="cellvit",
)
# Define model creator
def lit_model_creator(input_dim: int):
    from cellmil.models.mil.attentiondeepmil import AttentionDeepMIL, LitAttentionDeepMIL
    from torch.optim import AdamW

    model = AttentionDeepMIL(embed_dim=input_dim, n_classes=2)
    lit_model = LitAttentionDeepMIL(
        model=model,
        optimizer=AdamW(model.parameters(), lr=1e-4),
    )
    return lit_model
# Initialize k-fold cross-validation
k_fold = KFoldCrossValidation(k=5, random_state=42)
# Run evaluation
model_storage = k_fold.evaluate(
    name="my_experiment",
    lit_model_creator=lit_model_creator,
    dataset=dataset,
    output_dir=Path("./results"),
    wandb_project="my_project",
)
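To make the role of `k=5` and `random_state=42` more concrete, here is a minimal sketch of the general idea behind slide-level k-fold cross-validation: slide indices are shuffled with a fixed seed, split into k disjoint folds, and each fold is used once for validation while the rest are used for training. This only illustrates the concept. The helper `k_fold_indices` is hypothetical, and how `KFoldCrossValidation` actually splits the data (for example, whether it stratifies by label) is not described here.

```python
import random


def k_fold_indices(n_slides, k, seed):
    """Return k (train, val) index pairs covering all slides exactly once in val.

    Hypothetical helper for illustration; plain (unstratified) k-fold.
    """
    indices = list(range(n_slides))
    random.Random(seed).shuffle(indices)  # fixed seed makes the split reproducible
    folds = [indices[i::k] for i in range(k)]  # k roughly equal, disjoint folds
    splits = []
    for i in range(k):
        val = sorted(folds[i])
        train = sorted(idx for j, fold in enumerate(folds) if j != i for idx in fold)
        splits.append((train, val))
    return splits


splits = k_fold_indices(n_slides=10, k=5, seed=42)
for fold_id, (train, val) in enumerate(splits):
    print(f"fold {fold_id}: train={len(train)} slides, val={len(val)} slides")
```

Because the split is made over slides rather than cells, all cells from a given slide stay on the same side of the train/validation boundary, which avoids leaking slide-specific information into validation.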
Next Steps
Overview - Understand the MIL framework and aggregation strategies
Model Architectures - Explore available model architectures
Configuration - Learn about configuration parameters
Data Preparation - Prepare datasets and apply transforms
Training - Train models with k-fold cross-validation
See Also
Dataset Creation - Creating the dataset structure
Feature Extraction - Extracting features from cells