cellmil.utils.train.evals
- class cellmil.utils.train.evals.KFoldCrossValidation(k: int = 5, random_state: int = 42)
Bases: object
- evaluate(name: str, lit_model_creator: Callable[[int], LightningModule], dataset: Union[CellMILDataset, CellGNNMILDataset, PatchGNNMILDataset, PatchMILDataset], output_dir: Union[str, Path], wandb_project: str, transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None, balance_cell_counts: bool = False, cell_balance_bins: int = 5, **kwargs: Any) → ModelStorage
Run k-fold cross-validation on the dataset and store all results under output_dir.
- Parameters:
name – Experiment name
lit_model_creator – Function that creates a Lightning module given fold index
dataset – Dataset for cross-validation
output_dir – Directory to store all results
wandb_project – Wandb project name
transforms – Optional feature transforms
label_transforms – Optional label transforms
balance_cell_counts – Whether to balance by cell counts
cell_balance_bins – Number of bins for cell count balancing
max_epochs – Maximum training epochs per fold (not part of the explicit signature, so presumably passed through **kwargs)
**kwargs – Additional keyword arguments
- Returns:
ModelStorage object with all results
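The fold loop behind a call like evaluate can be sketched in plain Python. This is a minimal illustration under stated assumptions, not cellmil's implementation: k_fold_indices and run_folds are hypothetical helpers, the splitting scheme is a simple seeded shuffle, and model_creator stands in for lit_model_creator, which the signature above says receives the fold index.

```python
import random
from typing import Callable, Iterator


def k_fold_indices(
    n_samples: int, k: int = 5, random_state: int = 42
) -> Iterator[tuple[list[int], list[int]]]:
    """Yield (train_indices, val_indices) for each of k shuffled folds."""
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    indices = list(range(n_samples))
    # A seeded generator makes the split reproducible across runs.
    random.Random(random_state).shuffle(indices)
    # Fold i takes every k-th shuffled index starting at offset i.
    folds = [indices[i::k] for i in range(k)]
    for i, val in enumerate(folds):
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        yield sorted(train), sorted(val)


def run_folds(
    n_samples: int, model_creator: Callable[[int], object], k: int = 5
) -> list[dict]:
    """Call the factory once per fold so every fold starts from a fresh model."""
    results = []
    for fold, (train_idx, val_idx) in enumerate(k_fold_indices(n_samples, k)):
        model = model_creator(fold)  # hypothetical stand-in for lit_model_creator
        results.append(
            {"fold": fold, "model": model, "n_train": len(train_idx), "n_val": len(val_idx)}
        )
    return results


splits = list(k_fold_indices(10, k=5))
results = run_folds(10, lambda fold: f"model-for-fold-{fold}")
```

Passing a factory rather than a model instance means no weights or optimizer state carry over from one fold to the next.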
- get_train_val_dataloaders(dataset: Union[CellMILDataset, CellGNNMILDataset, PatchGNNMILDataset, PatchMILDataset], train_indices: list[int], val_indices: list[int], transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None) → tuple[Union[torch.utils.data.dataloader.DataLoader[Any], torch_geometric.loader.dataloader.DataLoader], Union[torch.utils.data.dataloader.DataLoader[Any], torch_geometric.loader.dataloader.DataLoader], Union[cellmil.datamodels.datasets.cell_mil_dataset.CellMILDataset, cellmil.datamodels.datasets.cell_gnn_mil_dataset.CellGNNMILDataset, cellmil.datamodels.datasets.patch_gnn_mil_dataset.PatchGNNMILDataset, cellmil.datamodels.datasets.patch_mil_dataset.PatchMILDataset, cellmil.datamodels.datasets.cell_gnn_mil_dataset.SubsetCellGNNMILDataset, cellmil.datamodels.datasets.patch_gnn_mil_dataset.SubsetPatchGNNMILDataset], Union[cellmil.datamodels.datasets.cell_mil_dataset.CellMILDataset, cellmil.datamodels.datasets.cell_gnn_mil_dataset.CellGNNMILDataset, cellmil.datamodels.datasets.patch_gnn_mil_dataset.PatchGNNMILDataset, cellmil.datamodels.datasets.patch_mil_dataset.PatchMILDataset, cellmil.datamodels.datasets.cell_gnn_mil_dataset.SubsetCellGNNMILDataset, cellmil.datamodels.datasets.patch_gnn_mil_dataset.SubsetPatchGNNMILDataset]]
Create train and validation dataloaders with proper transforms.
- Parameters:
dataset – Full dataset
train_indices – Indices for training set
val_indices – Indices for validation set
transforms – Optional feature transforms
label_transforms – Optional label transforms
- Returns:
Tuple of (train_dataloader, val_dataloader, train_dataset, val_dataset)
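The docstring does not say what "proper transforms" means. A common reading is that transform statistics are fitted on the training indices only and then applied unchanged to the validation indices, so no validation information leaks into training. The sketch below illustrates that pattern under this assumption. Standardizer, its fit/transform methods, and split_and_transform are hypothetical names, not part of cellmil.

```python
from statistics import mean, pstdev


class Standardizer:
    """Hypothetical stand-in for a feature transform with fit/transform steps."""

    def fit(self, values: list[float]) -> "Standardizer":
        self.mean_ = mean(values)
        self.std_ = pstdev(values) or 1.0  # fall back to 1.0 for constant features
        return self

    def transform(self, values: list[float]) -> list[float]:
        return [(v - self.mean_) / self.std_ for v in values]


def split_and_transform(
    features: list[float],
    train_indices: list[int],
    val_indices: list[int],
    transform: Standardizer,
) -> tuple[list[float], list[float]]:
    """Fit the transform on the training subset, then apply it to both subsets."""
    train = [features[i] for i in train_indices]
    val = [features[i] for i in val_indices]
    transform.fit(train)  # statistics come from the training fold only
    return transform.transform(train), transform.transform(val)


features = [1.0, 2.0, 3.0, 4.0, 100.0]
train_x, val_x = split_and_transform(features, [0, 1, 2, 3], [4], Standardizer())
```

In this toy input the outlier sits in the validation subset, so it has no effect on the fitted mean and standard deviation.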
- _get_predictions(trainer: Trainer, model: LightningModule, dataloader: Union[DataLoader[Any], DataLoader], is_survival: bool) → dict[str, Any]
Get predictions and true labels from the model.
- Parameters:
trainer – Lightning trainer
model – Lightning model
dataloader – DataLoader supplying the batches to predict on
is_survival – Whether this is a survival model
- Returns:
Dictionary with prediction data
- _aggregate_reports(fold_reports: list[dict[str, Any]]) → dict[str, Any]
Aggregate reports across all folds.
- Parameters:
fold_reports – List of fold reports
- Returns:
Aggregated report dictionary
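The layout of the aggregated report is not documented here. One plausible shape, sketched below as an assumption rather than as cellmil's actual output, summarises every numeric metric present in all fold reports by its mean, sample standard deviation, and per-fold values. The function name aggregate_reports and the metric keys are illustrative only.

```python
from statistics import mean, stdev
from typing import Any


def aggregate_reports(fold_reports: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarise each numeric metric shared by all folds."""
    if not fold_reports:
        return {}
    # Only keys present in every fold report can be aggregated.
    shared_keys = set(fold_reports[0]).intersection(*(set(r) for r in fold_reports[1:]))
    aggregated: dict[str, Any] = {}
    for key in sorted(shared_keys):
        values = [report[key] for report in fold_reports]
        # Skip non-numeric entries such as names (bool is excluded on purpose).
        if not all(isinstance(v, (int, float)) and not isinstance(v, bool) for v in values):
            continue
        aggregated[key] = {
            "mean": mean(values),
            "std": stdev(values) if len(values) > 1 else 0.0,
            "values": values,
        }
    return aggregated


reports = [
    {"name": "fold_0", "accuracy": 0.8, "f1": 0.7},
    {"name": "fold_1", "accuracy": 0.9, "f1": 0.6},
    {"name": "fold_2", "accuracy": 0.7, "f1": 0.8},
]
summary = aggregate_reports(reports)
```

Keeping the per-fold values next to the mean and standard deviation makes it easy to spot a single fold that behaves very differently from the rest.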
- _train_final_model(name: str, lit_model_creator: Callable[[int, bool], LightningModule], dataset: Union[CellMILDataset, CellGNNMILDataset, PatchGNNMILDataset, PatchMILDataset], transforms: Optional[Union[Transform, TransformPipeline]], label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]], target_epochs: int, project_name: str) → tuple[str, Union[cellmil.datamodels.transforms.base_transform.Transform, cellmil.datamodels.transforms.pipeline.TransformPipeline, NoneType], Union[cellmil.datamodels.transforms.base_label_transform.LabelTransform, cellmil.datamodels.transforms.label_pipeline.LabelTransformPipeline, NoneType]]
Train final model on full dataset with target number of epochs.
- Parameters:
name – Experiment name
lit_model_creator – Function to create model
dataset – Full dataset
transforms – Feature transforms
label_transforms – Label transforms
target_epochs – Number of epochs to train
project_name – Wandb project name
- Returns:
Tuple of (checkpoint_path, fitted_transforms, fitted_label_transforms)