Model Architectures

This section describes available MIL models and when to use them.

Set-Based Models

Attention-Based MIL (ABMIL)

Ilse, M., Tomczak, J. M., & Welling, M. (2018). Attention-based Deep Multiple Instance Learning.

Classification

from cellmil.models.mil.attentiondeepmil import AttentionDeepMIL, LitAttentionDeepMIL
from torch.optim import AdamW

def lit_model_creator(input_dim: int):
    model = AttentionDeepMIL(
        embed_dim=input_dim,
        size_arg=[256, 128],
        n_classes=2,
        attention_branches=8,
        temperature=1.5,
    )

    lit_model = LitAttentionDeepMIL(
        model=model,
        optimizer=AdamW(model.parameters(), lr=1e-4),
    )
    return lit_model
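
The `temperature` argument is not explained on this page. The sketch below assumes it divides the attention logits before the softmax, which is a common convention but is not confirmed for `AttentionDeepMIL`. Under that assumption, values above 1 flatten the attention weights and values below 1 sharpen them. The helper names `softmax` and `attention_pool` are illustrative and are not part of cellmil; plain Python lists stand in for tensors.

```python
import math

def softmax(logits, temperature=1.0):
    """Temperature-scaled softmax over a list of attention logits."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def attention_pool(instances, logits, temperature=1.0):
    """Weighted sum of instance feature vectors using softmax attention."""
    weights = softmax(logits, temperature)
    dim = len(instances[0])
    pooled = [sum(w * inst[d] for w, inst in zip(weights, instances)) for d in range(dim)]
    return pooled, weights

# Three cells with 2-D features and hypothetical attention logits.
cells = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
scores = [2.0, 0.0, 1.0]
pooled, sharp = attention_pool(cells, scores, temperature=1.0)
_, flat = attention_pool(cells, scores, temperature=1.5)
```

With the same logits, the largest weight in `flat` is smaller than in `sharp`, so a higher temperature spreads attention over more cells.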

Survival (SurvAttentionDeepMIL)

from cellmil.models.mil.attentiondeepmil import AttentionDeepMIL, LitSurvAttentionDeepMIL
from cellmil.datamodels.transforms import TimeDiscretizerTransform
from torch.optim import AdamW

label = ("OS_MONTHS", "OS_EVENT")
label_transforms = TimeDiscretizerTransform(n_bins=4)

def lit_model_creator(input_dim: int):
    model = AttentionDeepMIL(
        embed_dim=input_dim,
        size_arg=[256, 128],
        n_classes=4,  # n_bins
        attention_branches=8,
    )

    lit_model = LitSurvAttentionDeepMIL(
        model=model,
        optimizer=AdamW(model.parameters(), lr=1e-4),
    )
    return lit_model
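
The survival variants predict over discrete time bins, which is why `n_classes` must match `n_bins`. How `TimeDiscretizerTransform` chooses its bin edges is not documented here. The sketch below assumes quantile-based edges, and the functions `fit_bin_edges` and `discretize` are hypothetical stand-ins, not the library's implementation.

```python
from bisect import bisect_right

def fit_bin_edges(times, n_bins):
    """Interior cut points at the empirical quantiles of `times` (assumed scheme)."""
    ordered = sorted(times)
    return [ordered[(len(ordered) * i) // n_bins] for i in range(1, n_bins)]

def discretize(time, edges):
    """Map a continuous time to a bin index in [0, len(edges)]."""
    return bisect_right(edges, time)

# Hypothetical overall-survival times in months.
os_months = [3.0, 8.0, 12.0, 20.0, 27.0, 35.0, 48.0, 60.0]
edges = fit_bin_edges(os_months, n_bins=4)
bins = [discretize(t, edges) for t in os_months]
```

Each time maps to one of four bin indices (0 to 3), matching `n_classes=4` in the model definitions on this page.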

CLAM

Data-efficient and weakly supervised computational pathology on whole-slide images. Lu, Ming Y et al., Nature Biomedical Engineering, 2021. DOI: https://doi.org/10.1038/s41551-021-00707-9

Figure: CLAM model architecture.

Classification

from cellmil.models.mil.clam import CLAM_SB, LitCLAM
from torch.optim import AdamW

def lit_model_creator(input_dim: int):
    model = CLAM_SB(
        embed_dim=input_dim,
        size_arg="small",
        n_classes=2,
        k_sample=8,
    )

    lit_model = LitCLAM(
        model=model,
        optimizer=AdamW(model.parameters(), lr=1e-4),
    )
    return lit_model

Survival (SurvCLAM)

from cellmil.models.mil.clam import CLAM_SB, LitSurvCLAM
from cellmil.datamodels.transforms import TimeDiscretizerTransform
from torch.optim import AdamW

label = ("OS_MONTHS", "OS_EVENT")
label_transforms = TimeDiscretizerTransform(n_bins=4)

def lit_model_creator(input_dim: int):
    model = CLAM_SB(
        embed_dim=input_dim,
        size_arg="small",
        n_classes=4,  # n_bins
        k_sample=8,
    )

    lit_model = LitSurvCLAM(
        model=model,
        optimizer=AdamW(model.parameters(), lr=1e-4),
    )
    return lit_model

Head4Type

Uses separate attention heads for each cell type (neoplastic, inflammatory, connective, etc.). Each cell type gets its own attention mechanism.

Warning

Requirements: This model requires cell type information from the segmentation model. Use CellViT or HoVerNet for segmentation. Cellpose+SAM does not provide cell types and cannot be used with this model.

Note

Do not combine this model with deep learning feature extractors (ResNet50, GigaPath, UNI): their features do not preserve cell-level type information.
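
The per-type attention described for Head4Type can be sketched as follows. This is a simplified illustration, not the library's code: the attention logits are given directly, whereas in a trained model each cell type would presumably have its own attention head producing them. The function `pool_by_cell_type` and the zero vector used for absent types are assumptions.

```python
import math

def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def pool_by_cell_type(features, logits, cell_types, n_types):
    """Attention-pool each cell type separately; returns one vector per type."""
    dim = len(features[0])
    pooled = []
    for t in range(n_types):
        idx = [i for i, c in enumerate(cell_types) if c == t]
        if not idx:
            pooled.append([0.0] * dim)  # no cells of this type in the bag
            continue
        weights = softmax([logits[i] for i in idx])
        pooled.append(
            [sum(w * features[i][d] for w, i in zip(weights, idx)) for d in range(dim)]
        )
    return pooled

features = [[1.0, 0.0], [3.0, 0.0], [0.0, 2.0]]
logits = [0.0, 0.0, 5.0]
cell_types = [0, 0, 2]  # integer type labels from the segmentation model
per_type = pool_by_cell_type(features, logits, cell_types, n_types=3)
```

Because the softmax is computed within each type, a high-scoring cell of one type cannot suppress the cells of another type.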

Classification

from cellmil.models.mil.head4type import Head4Type, LitHead4Type
from torch.optim import AdamW

def lit_model_creator(input_dim: int):
    model = Head4Type(
        embed_dim=input_dim,
        size_arg=[256, 128],
        n_classes=2,
        cell_types=5,
        temperature=1.5,
    )

    lit_model = LitHead4Type(
        model=model,
        optimizer=AdamW(model.parameters(), lr=1e-4),
    )
    return lit_model

Survival (SurvHead4Type)

from cellmil.models.mil.head4type import Head4Type, LitSurvHead4Type
from cellmil.datamodels.transforms import TimeDiscretizerTransform
from torch.optim import AdamW

label = ("OS_MONTHS", "OS_EVENT")
label_transforms = TimeDiscretizerTransform(n_bins=4)

def lit_model_creator(input_dim: int):
    model = Head4Type(
        embed_dim=input_dim,
        size_arg=[256, 128],
        n_classes=4,  # n_bins
        cell_types=5,
        temperature=1.5,
    )

    lit_model = LitSurvHead4Type(
        model=model,
        optimizer=AdamW(model.parameters(), lr=1e-4),
    )
    return lit_model

CellConv

Applies 1D convolutions over the cell sequence before attention pooling. Tries to capture local patterns in spatially-ordered cell features.

For experimental use only. This model was used solely for technical validation while developing CellMIL.
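
A minimal sketch of the 1D convolution step on a single feature channel, assuming zero padding so the output keeps one value per cell. The padding scheme and the helper `conv1d_same` are assumptions, and the averaging kernel is only for illustration; a trained model learns its kernel weights.

```python
def conv1d_same(sequence, kernel):
    """1D cross-correlation with zero padding; expects an odd kernel length."""
    k = len(kernel)
    pad = k // 2
    padded = [0.0] * pad + list(sequence) + [0.0] * pad
    return [
        sum(kernel[j] * padded[i + j] for j in range(k))
        for i in range(len(sequence))
    ]

# One feature channel across five spatially ordered cells.
channel = [1.0, 2.0, 3.0, 4.0, 5.0]
smoothed = conv1d_same(channel, kernel=[1 / 3, 1 / 3, 1 / 3])
```

Each output value mixes a cell with its immediate neighbors in the sequence, so the result depends on how the cells are ordered.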

Classification

from cellmil.models.mil.cellconv import CellConv, LitCellConv
from torch.optim import AdamW

def lit_model_creator(input_dim: int):
    model = CellConv(
        embed_dim=input_dim,
        n_classes=2,
        convolution_depth=3,
        size_arg=[512, 128],
        attention_branches=1,
        temperature=1.0,
        dropout=0.0,
        kernel_size=3,
    )

    lit_model = LitCellConv(
        model=model,
        optimizer=AdamW(model.parameters(), lr=1e-4),
    )
    return lit_model

Survival (SurvCellConv)

from cellmil.models.mil.cellconv import CellConv, LitSurvCellConv
from cellmil.datamodels.transforms import TimeDiscretizerTransform
from torch.optim import AdamW

label = ("OS_MONTHS", "OS_EVENT")
label_transforms = TimeDiscretizerTransform(n_bins=4)

def lit_model_creator(input_dim: int):
    model = CellConv(
        embed_dim=input_dim,
        n_classes=4,  # n_bins
        convolution_depth=3,
        size_arg=[512, 128],
        attention_branches=1,
    )

    lit_model = LitSurvCellConv(
        model=model,
        optimizer=AdamW(model.parameters(), lr=1e-4),
    )
    return lit_model

Multifocus

Uses multi-head attention where each head focuses on different feature dimensions. Aggregates using diagonal elements of the attention-feature product matrix.

For experimental use only. The diagonal aggregation approach used by this model has not been extensively validated.
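
The diagonal aggregation can be illustrated with a small example. The sketch assumes one attention head per feature dimension, so the attention-feature product is square and entry `d` of the bag vector is feature `d` pooled with the weights of head `d`. The function `diagonal_aggregate` is illustrative and not part of cellmil.

```python
def diagonal_aggregate(attention, features):
    """Return diag(attention @ features) without forming the full product.

    attention: D rows (one per head), each with N weights.
    features:  N rows (one per cell), each with D values.
    """
    n_heads = len(attention)
    return [
        sum(attention[d][n] * features[n][d] for n in range(len(features)))
        for d in range(n_heads)
    ]

features = [[1.0, 10.0], [3.0, 30.0]]  # 2 cells, 2 feature dimensions
attention = [[0.5, 0.5], [0.0, 1.0]]   # head 0 averages, head 1 picks cell 1
bag_vector = diagonal_aggregate(attention, features)
```

Here head 0 averages the first feature over both cells, while head 1 takes the second feature from the second cell only.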

from cellmil.models.mil.multifocus import MultiFocus, LitMultiFocus
from torch.optim import AdamW

def lit_model_creator(input_dim: int):
    model = MultiFocus(
        embed_dim=input_dim,
        n_classes=2,
        size_arg=[32],
        temperature=1.0,
        dropout=0.0,
    )

    lit_model = LitMultiFocus(
        model=model,
        optimizer=AdamW(model.parameters(), lr=1e-4),
    )
    return lit_model

Standard MIL

Data-efficient and weakly supervised computational pathology on whole-slide images. Lu, Ming Y et al., Nature Biomedical Engineering, 2021. DOI: https://doi.org/10.1038/s41551-021-00707-9

Classification

from cellmil.models.mil.standard import MIL_fc, MIL_fc_mc, LitStandard
from torch.optim import AdamW

def lit_model_creator(input_dim: int):
    # Binary classification
    model = MIL_fc(
        embed_dim=input_dim,
        size_arg="small",  # or list like [256]
        n_classes=2,
        top_k=1,  # Number of top instances to select
        dropout=0.25,
    )

    # For multi-class (n_classes > 2), use MIL_fc_mc instead:
    # model = MIL_fc_mc(embed_dim=input_dim, n_classes=3, top_k=1)

    lit_model = LitStandard(
        model=model,
        optimizer=AdamW(model.parameters(), lr=1e-4),
    )
    return lit_model
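
The `top_k` selection can be sketched as picking the highest-scoring instances in a bag. The example below is an illustration under assumptions: the per-instance scores are made up, and averaging the selected scores into a bag score is one possible aggregation, not necessarily what `MIL_fc` does.

```python
import heapq

def top_k_instances(scores, k=1):
    """Indices of the k highest-scoring instances, best first."""
    return heapq.nlargest(k, range(len(scores)), key=scores.__getitem__)

# Hypothetical per-instance probabilities for the positive class.
positive_prob = [0.10, 0.85, 0.40, 0.70]
top1 = top_k_instances(positive_prob, k=1)
top2 = top_k_instances(positive_prob, k=2)
bag_score = sum(positive_prob[i] for i in top2) / len(top2)
```

With `k=1` the bag prediction rests on a single instance, which matches the classic max-pooling view of MIL.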

HistoBistro

Transformer-based biomarker prediction from colorectal cancer histology: A large-scale multicentric study. Wagner, Sophia J et al., Cancer Cell, Elsevier. DOI: https://doi.org/10.1016/j.ccell.2023.02.002

TransMIL

TransMIL: Transformer based correlated multiple instance learning for whole slide image classification. Shao, Zhuchen et al., Advances in Neural Information Processing Systems, 2021. URL: https://proceedings.neurips.cc/paper/2021/hash/10c272d06794d3e5785d5e7c5356e9ff-Abstract.html

Graph-Based Models

These models build a graph connecting nearby cells and use graph neural networks to let cells exchange information with their neighbors.

General setup:

from cellmil.datamodels.datasets import GNNMILDataset

# Use GNNMILDataset instead of MILDataset
dataset = GNNMILDataset(
    root=config.root,
    label=config.label,
    folder=config.folder,
    data=df,
    extractor=config.extractor,
    segmentation_model=config.segmentation_model,
    graph_creator="delaunay_radius",  # Required!
)
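
To make "connecting nearby cells" concrete, the sketch below builds edges between cells whose centroids lie within a fixed radius. The name `delaunay_radius` suggests that the library combines a Delaunay triangulation with a distance cutoff; that is an assumption, and only the distance filter is shown here. The function `radius_edges` is hypothetical, and the all-pairs loop is quadratic, so it is meant for illustration rather than for whole slides.

```python
import math
from itertools import combinations

def radius_edges(centroids, radius):
    """Undirected edges between cells whose centroids are within `radius`."""
    return [
        (i, j)
        for (i, p), (j, q) in combinations(enumerate(centroids), 2)
        if math.dist(p, q) <= radius
    ]

# Hypothetical cell centroids in pixel coordinates.
centroids = [(0.0, 0.0), (3.0, 4.0), (30.0, 0.0)]
edges = radius_edges(centroids, radius=10.0)
```

Only the first two cells are linked, since they are 5 units apart and the third cell is more than 10 units from both.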

Graph Attention Networks (GAT)

Veličković, P., Cucurull, G., Casanova, A., Romero, A., Lio, P., & Bengio, Y. (2017). Graph attention networks. arXiv preprint arXiv:1710.10903.

Classification

from pathlib import Path

from cellmil.models.mil.graphmil import GAT, Attention, LitGraphMIL
from cellmil.datamodels.datasets import GNNMILDataset
from cellmil.utils.train.evals import KFoldCrossValidation
from torch.optim import AdamW

# Use GNNMILDataset for graph-based models
dataset = GNNMILDataset(
    root=config.root,
    label=config.label,
    folder=config.folder,
    data=df,
    extractor=config.extractor,
    segmentation_model=config.segmentation_model,
    graph_creator="delaunay_radius",
)

def lit_model_creator(input_dim: int):
    # GNN backbone
    gnn = GAT(
        input_dim=input_dim,
        hidden_dim=256,
        n_layers=4,
        dropout=0.25,
        heads=3,
    )

    # Attention pooling head
    pooling = Attention(
        input_dim=gnn.hidden_dim,
        size_arg=[256, 128],
        n_classes=2,
        attention_branches=8,
        temperature=1.5,
    )

    lit_model = LitGraphMIL(
        gnn=gnn,
        pooling_classifier=pooling,
        optimizer_cls=AdamW,
        optimizer_kwargs={"lr": 1e-4},
    )
    return lit_model

# Run k-fold cross-validation
k_fold = KFoldCrossValidation(k=5)
model_storage = k_fold.evaluate(
    name="gat_experiment",
    lit_model_creator=lit_model_creator,
    dataset=dataset,
    output_dir=Path("./results"),
    wandb_project="my_project",
)
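
For readers unfamiliar with k-fold cross-validation, the index split can be sketched as below. This is a generic illustration: whether `KFoldCrossValidation` shuffles or stratifies the slides is not stated on this page, and `k_fold_indices` is a hypothetical helper using contiguous, unshuffled folds.

```python
def k_fold_indices(n_samples, k):
    """Yield (train, val) index lists for k contiguous folds."""
    # Spread the remainder over the first folds so sizes differ by at most one.
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    start = 0
    for size in fold_sizes:
        val = list(range(start, start + size))
        train = [i for i in range(n_samples) if i < start or i >= start + size]
        yield train, val
        start += size

folds = list(k_fold_indices(n_samples=10, k=5))
```

Every sample appears in exactly one validation fold, and `lit_model_creator` would be called once per fold to build a fresh model.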

Survival (SurvGraphMIL)

from cellmil.models.mil.graphmil import GAT, Attention, LitSurvGraphMIL
from cellmil.datamodels.datasets import GNNMILDataset
from cellmil.datamodels.transforms import TimeDiscretizerTransform
from torch.optim import AdamW

label = ("OS_MONTHS", "OS_EVENT")
label_transforms = TimeDiscretizerTransform(n_bins=4)

dataset = GNNMILDataset(
    label=label,
    graph_creator="delaunay_radius",
    # ... other parameters
)

def lit_model_creator(input_dim: int):
    gnn = GAT(
        input_dim=input_dim,
        hidden_dim=256,
        n_layers=4,
        dropout=0.25,
        heads=3,
    )

    pooling = Attention(
        input_dim=gnn.hidden_dim,
        size_arg=[256, 128],
        n_classes=4,  # n_bins
        attention_branches=8,
    )

    lit_model = LitSurvGraphMIL(
        gnn=gnn,
        pooling_classifier=pooling,
        optimizer_cls=AdamW,
        optimizer_kwargs={"lr": 1e-4},
    )
    return lit_model

Other GNN Backbones

You can swap GAT with other GNN architectures. All work with both classification (LitGraphMIL) and survival (LitSurvGraphMIL).

GraphSAGE: Hamilton, W., Ying, Z., & Leskovec, J. (2017). Inductive representation learning on large graphs. Advances in neural information processing systems, 30.

from cellmil.models.mil.graphmil import SAGE

gnn = SAGE(
    input_dim=input_dim,
    hidden_dim=256,
    n_layers=3,
    dropout=0.25,
)

EGNN (Equivariant GNN): Satorras, V. G., Hoogeboom, E., & Welling, M. (2021, July). E(n) equivariant graph neural networks. In International conference on machine learning (pp. 9323-9332). PMLR.

from cellmil.models.mil.graphmil import EGNN

gnn = EGNN(
    input_dim=input_dim,
    hidden_dim=256,
    n_layers=3,
)

SGFormer: Wu, Q., Zhao, W., Yang, C., Zhang, H., Nie, F., Jiang, H., … & Yan, J. (2023). Sgformer: Simplifying and empowering transformers for large-graph representations. Advances in Neural Information Processing Systems, 36, 64753-64773.

from cellmil.models.mil.graphmil import SGFormer

gnn = SGFormer(
    input_dim=input_dim,
    hidden_dim=256,
    n_layers=3,
)

SmallWorld:

Dynamically creates shortcuts between important cells. Experimental.

from cellmil.models.mil.graphmil import SmallWorld

gnn = SmallWorld(
    input_dim=input_dim,
    hidden_dim=256,
    gamma=0.5,
)

Figure: SmallWorld architecture.

Pooling Methods

After the GNN processes the graph, you need to pool node embeddings into a slide-level representation:

Attention Pooling (Recommended):

from cellmil.models.mil.graphmil import Attention

pooling = Attention(
    input_dim=gnn.hidden_dim,
    size_arg=[256, 128],
    n_classes=2,
    attention_branches=8,
)

Mean Pooling:

Simple averaging. Use as baseline.
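
No code is given for mean pooling on this page. As a library-independent sketch, the baseline reduces to a column-wise average of the node embeddings; the function `mean_pool` is illustrative and not the cellmil class.

```python
def mean_pool(node_embeddings):
    """Average node embeddings column-wise into one slide-level vector."""
    n = len(node_embeddings)
    return [sum(col) / n for col in zip(*node_embeddings)]

# Two nodes with 2-D embeddings.
slide_vector = mean_pool([[1.0, 4.0], [3.0, 8.0]])
```

Every cell contributes equally, which is what makes this a useful reference point for the attention-based pooling heads.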

CLAM Pooling:

from cellmil.models.mil.graphmil import CLAM

pooling = CLAM(
    input_dim=gnn.hidden_dim,
    size_arg=[256, 128],
    n_classes=2,
)

Next Steps

  • Training - Train your chosen model

  • Configuration - Configure feature extractors