Data Preparation¶
This section describes how to prepare data for MIL training, including feature extractors, transforms, filtering, and dataset creation.
Key Parameters¶
Dataset Parameters¶
| Parameter | Description |
|---|---|
| `root` | Cache directory where processed datasets are stored |
| `folder` | Path to the dataset directory created by the dataset creation tool |
| `label` | Column name in metadata containing target labels (str for classification, tuple for survival) |
| `data` | pandas DataFrame with slide metadata and labels |
| `extractor` | Feature extractor(s) |
| `segmentation_model` | Segmentation model used (e.g., `ModelType.cellvit`) |
| `graph_creator` | Graph construction method (required for graph models/features) |
| `cell_type` | Boolean - include cell type information (required for Head4Type model) |
Training Parameters¶
| Parameter | Description |
|---|---|
| `name` | Experiment name for tracking and logging |
| `output_dir` | Directory where results and checkpoints will be saved |
| `wandb_project` | Weights & Biases project name (required for experiment tracking) |
| `transforms` | TransformPipeline for feature preprocessing |
| `label_transforms` | Label transforms (e.g., TimeDiscretizerTransform for survival) |
| `balance_cell_counts` | Boolean - apply cell stratification to balance instance counts |
Normalization and Filtering¶
Transforms are applied using the TransformPipeline class.
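Conceptually, a pipeline applies each transform to the feature matrix in sequence, feeding the output of one step into the next. A minimal sketch of that idea (illustrative only; `SimplePipeline` is not part of cellmil):

```python
import numpy as np

class SimplePipeline:
    """Illustrative stand-in for a transform pipeline: applies each transform in order."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, features):
        for transform in self.transforms:
            features = transform(features)
        return features

# Two toy transforms: log1p, then per-feature mean-centering
pipeline = SimplePipeline([np.log1p, lambda x: x - x.mean(axis=0)])
out = pipeline(np.array([[0.0, 1.0], [2.0, 3.0]]))
print(out.shape)  # (2, 2)
```

After centering, each feature column has zero mean, which is the kind of invariant each pipeline stage establishes for the stages after it.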
Robust Scaling¶
Apply robust scaling normalization to features:
from cellmil.datamodels.transforms import (
TransformPipeline,
RobustScalerTransform,
)
transforms = TransformPipeline([
RobustScalerTransform(apply_log_transform=True),
])
# Pass to k-fold training
model_storage = k_fold.evaluate(
# ... other parameters
transforms=transforms,
)
This applies robust scaling, which:
- Removes the median and scales by the interquartile range (IQR)
- Is robust to outliers
- Optionally applies a log transform first (`apply_log_transform=True`)
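The arithmetic behind robust scaling can be checked by hand; in this toy example the extreme outlier leaves the median and IQR (and therefore the scaling of the other values) untouched (plain NumPy, independent of cellmil):

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 4.0, 100.0])  # 100.0 is an outlier
median = np.median(x)                       # 3.0
q1, q3 = np.percentile(x, [25, 75])         # 2.0, 4.0
iqr = q3 - q1                               # 2.0
scaled = (x - median) / iqr
print(scaled.tolist())  # [-1.0, -0.5, 0.0, 0.5, 48.5]
```

A standard mean/std scaler would let the outlier inflate the scale and squash the inlier values together; here the inliers keep a sensible spread.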
Correlation Filtering¶
Remove highly correlated features to reduce redundancy:
from cellmil.datamodels.transforms import (
TransformPipeline,
CorrelationFilterTransform,
)
transforms = TransformPipeline([
CorrelationFilterTransform(
correlation_threshold=0.95,
plot_correlation_matrix=False,
),
])
Recommended when:
- Combining multiple feature extractors
- Using large feature sets (e.g., “ALL”)
- Reducing dimensionality
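The underlying idea can be sketched with pandas: compute the absolute correlation matrix, scan its upper triangle, and drop one feature from every pair whose correlation exceeds the threshold (the column names `a`, `b`, `c` are made up for this example; the actual `CorrelationFilterTransform` may differ in detail):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
a = rng.normal(size=200)
df = pd.DataFrame({
    "a": a,
    "b": a + rng.normal(scale=0.01, size=200),  # near-duplicate of "a"
    "c": rng.normal(size=200),                  # independent feature
})

corr = df.corr().abs()
# Keep only the upper triangle so each pair is considered once
upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
to_drop = [col for col in upper.columns if (upper[col] > 0.95).any()]
filtered = df.drop(columns=to_drop)
print(to_drop)  # ['b']
```

Scanning only the upper triangle drops the later feature of each correlated pair, so one representative of every redundant group survives.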
Combined Transforms¶
Chain multiple transforms together:
transforms = TransformPipeline([
CorrelationFilterTransform(correlation_threshold=0.95),
RobustScalerTransform(apply_log_transform=True),
])
Note: Order matters! Correlation filtering should typically come before normalization.
Configuration Examples¶
Basic Classification¶
from pathlib import Path
import pandas as pd
from cellmil.interfaces.CellSegmenterConfig import ModelType
from cellmil.datamodels.datasets import MILDataset
from cellmil.utils.train import get_extractors_from_name
from cellmil.utils.train.evals import KFoldCrossValidation
# Note: ExtractorType and lit_model_creator are assumed to be imported/defined elsewhere
# Constants
ROOT = Path("./MIL_dataset")
DATASET_FOLDER = Path("./dataset")
# Load data
df = pd.read_excel("metadata.xlsx")
# Create dataset
dataset = MILDataset(
root=ROOT,
label="subtype",
folder=DATASET_FOLDER,
data=df,
extractor=ExtractorType.morphometrics,
segmentation_model=ModelType.cellvit,
)
# Train
k_fold = KFoldCrossValidation(k=5)
model_storage = k_fold.evaluate(
name="morpho_subtype",
lit_model_creator=lit_model_creator,
dataset=dataset,
output_dir=Path("./results"),
wandb_project="my_project",
)
Graph-Based Model¶
from cellmil.datamodels.datasets import GNNMILDataset
from cellmil.datamodels.transforms import TransformPipeline, RobustScalerTransform
# Use GNNMILDataset for graph models
dataset = GNNMILDataset(
root=ROOT,
label="RESPONSE",
folder=DATASET_FOLDER,
data=df,
extractor=ExtractorType.morphometrics,
segmentation_model=ModelType.cellvit,
graph_creator=GraphCreatorType.delaunay_radius, # Required!
)
transforms = TransformPipeline([
RobustScalerTransform(apply_log_transform=True),
])
model_storage = k_fold.evaluate(
name="gat_morpho_response",
lit_model_creator=lit_model_creator,
dataset=dataset,
output_dir=Path("./results"),
wandb_project="my_project",
transforms=transforms,
)
Survival Analysis¶
from cellmil.datamodels.transforms import TimeDiscretizerTransform
# Label is tuple of (duration, event)
dataset = MILDataset(
root=ROOT,
label=("duration", "event"), # Tuple for survival
folder=DATASET_FOLDER,
data=df,
extractor=ExtractorType.pyradiomics_hed,
segmentation_model=ModelType.cellvit,
)
# Create transforms
transforms = TransformPipeline([
RobustScalerTransform(apply_log_transform=True),
])
# Label transform for time discretization
label_transforms = TimeDiscretizerTransform(n_bins=4)
# Train with survival model
model_storage = k_fold.evaluate(
name="survival_pyrad",
lit_model_creator=lit_model_creator,
dataset=dataset,
output_dir=Path("./results"),
wandb_project="my_project",
transforms=transforms,
label_transforms=label_transforms, # Required for survival
)
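Time discretization maps continuous survival durations onto a small number of ordinal bins. A quantile-based sketch of what a transform like `TimeDiscretizerTransform(n_bins=4)` might compute (the exact binning strategy used by cellmil may differ):

```python
import numpy as np

durations = np.array([3.0, 8.0, 15.0, 24.0, 30.0, 41.0, 55.0, 60.0])
n_bins = 4

# Quantile-based bin edges over the observed durations
edges = np.quantile(durations, np.linspace(0, 1, n_bins + 1))

# Map each duration to a bin index in [0, n_bins - 1]
bins = np.clip(np.searchsorted(edges, durations, side="right") - 1, 0, n_bins - 1)
print(bins.tolist())  # [0, 0, 1, 1, 2, 2, 3, 3]
```

Quantile edges give roughly equal numbers of events per bin, which keeps the discrete survival targets balanced across the time axis.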
Foundation Model Embeddings¶
dataset = MILDataset(
root=ROOT,
label="subtype",
folder=DATASET_FOLDER,
data=df,
extractor=ExtractorType.gigapath,
segmentation_model=ModelType.cellvit,
)
model_storage = k_fold.evaluate(
name="gigapath_subtype",
lit_model_creator=lit_model_creator,
dataset=dataset,
output_dir=Path("./results"),
wandb_project="my_project",
transforms=transforms,
)
Head4Type¶
# Enable cell_type parameter
dataset = MILDataset(
root=ROOT,
label="subtype",
folder=DATASET_FOLDER,
data=df,
extractor=ExtractorType.morphometrics,
segmentation_model=ModelType.cellvit,
cell_type=True, # Required for Head4Type
)
transforms = TransformPipeline([
RobustScalerTransform(apply_log_transform=True),
])
# Use cell stratification for balanced sampling
model_storage = k_fold.evaluate(
name="head4type_morpho",
lit_model_creator=lit_model_creator,
dataset=dataset,
output_dir=Path("./results"),
wandb_project="my_project",
transforms=transforms,
balance_cell_counts=True, # Cell stratification
)
Data Filtering¶
Cell Type Filtering¶
Include only specific cell types in your analysis:
dataset = MILDataset(
root=ROOT,
label="DCR",
folder=DATASET_FOLDER,
data=df,
extractor=extractors,
segmentation_model=ModelType.cellvit,
cell_types_to_keep=["Neoplastic", "Inflammatory"], # Only these types
)
Available cell types (CellViT/HoVerNet):
- Neoplastic
- Inflammatory
- Connective
- Dead
- Epithelial
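Internally, keeping only the listed cell types amounts to a membership filter on the per-cell metadata. A pandas sketch (the `cells` frame and its columns are hypothetical, for illustration only):

```python
import pandas as pd

cells = pd.DataFrame({
    "cell_type": ["Neoplastic", "Dead", "Inflammatory", "Connective", "Neoplastic"],
    "area": [210.0, 35.0, 90.0, 120.0, 185.0],
})

# Keep only cells whose type is in the allow-list
keep = ["Neoplastic", "Inflammatory"]
filtered = cells[cells["cell_type"].isin(keep)]
print(len(filtered))  # 3
```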
ROI Filtering¶
Filter cells to only those within specific regions of interest:
dataset = MILDataset(
root=ROOT,
label="DCR",
folder=DATASET_FOLDER,
data=df, # Must contain 'ID', 'I3LUNG_ID', 'CENTER' columns
extractor=extractors,
segmentation_model=ModelType.cellvit,
roi_folder=Path("./data/rois"), # Folder with ROI CSV files
)
ROI file format:
Each ROI file should be a CSV with cell coordinates:
| x | y |
|---|---|
| 1024 | 2048 |
| 1030 | 2055 |
| … | … |
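Given an ROI CSV of this shape, filtering reduces to keeping only the cells whose coordinates appear in the ROI file. A pandas sketch (the `cells` frame is hypothetical; the `x`/`y` columns follow the format above):

```python
import pandas as pd

# All detected cells for one slide (hypothetical example data)
cells = pd.DataFrame({"x": [1024, 1030, 5000], "y": [2048, 2055, 5000]})

# ROI file: coordinates of the cells that fall inside the region
roi = pd.DataFrame({"x": [1024, 1030], "y": [2048, 2055]})

# Keep only cells whose coordinates appear in the ROI file
inside = cells.merge(roi, on=["x", "y"], how="inner")
print(len(inside))  # 2
```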
Next Steps¶
Training - Train models with k-fold cross-validation
Model Architectures - Explore available model architectures