cellmil.utils.train.evals.k_fold_cross_validation

Classes

KFoldCrossValidation([k, random_state])

class cellmil.utils.train.evals.k_fold_cross_validation.KFoldCrossValidation(k: int = 5, random_state: int = 42)[source]

Bases: object

__init__(k: int = 5, random_state: int = 42)[source]
evaluate(name: str, lit_model_creator: Callable[[int], LightningModule], dataset: Union[CellMILDataset, CellGNNMILDataset, PatchGNNMILDataset, PatchMILDataset], output_dir: Union[str, Path], wandb_project: str, transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None, balance_cell_counts: bool = False, cell_balance_bins: int = 5, **kwargs: Any) → ModelStorage[source]

Perform k-fold cross-validation with comprehensive result storage.

Parameters:
  • name – Experiment name

  • lit_model_creator – Function that creates a Lightning module given fold index

  • dataset – Dataset for cross-validation

  • output_dir – Directory to store all results

  • wandb_project – Wandb project name

  • transforms – Optional feature transforms

  • label_transforms – Optional label transforms

  • balance_cell_counts – Whether to balance by cell counts

  • cell_balance_bins – Number of bins for cell count balancing

  • max_epochs – Maximum training epochs per fold (absent from the explicit signature; presumably passed through **kwargs)

  • **kwargs – Additional arguments

Returns:

ModelStorage object with all results
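
The fold loop that evaluate drives can be pictured with the following minimal sketch. It is not the cellmil implementation: the helper names k_fold_indices and run_folds are hypothetical, the split is a plain seeded shuffle (the real method may stratify, for example when balance_cell_counts is set), and the "model" is any object returned by the creator callable. It only illustrates how k, random_state, and a lit_model_creator that receives the fold index fit together.

```python
import random
from typing import Any, Callable


def k_fold_indices(
    n_samples: int, k: int = 5, random_state: int = 42
) -> list[tuple[list[int], list[int]]]:
    """Return k (train_indices, val_indices) pairs from one seeded shuffle."""
    if not 2 <= k <= n_samples:
        raise ValueError("k must be between 2 and n_samples")
    indices = list(range(n_samples))
    random.Random(random_state).shuffle(indices)
    # Fold i takes every k-th shuffled index starting at i,
    # so fold sizes differ by at most one sample.
    folds = [indices[i::k] for i in range(k)]
    splits = []
    for i, val in enumerate(folds):
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        splits.append((sorted(train), sorted(val)))
    return splits


def run_folds(
    n_samples: int,
    model_creator: Callable[[int], Any],
    k: int = 5,
    random_state: int = 42,
) -> list[dict[str, Any]]:
    """Call the creator once per fold, mirroring a Callable[[int], ...] argument."""
    results = []
    splits = k_fold_indices(n_samples, k, random_state)
    for fold_idx, (train_idx, val_idx) in enumerate(splits):
        model = model_creator(fold_idx)  # a fresh model for every fold
        results.append(
            {
                "fold": fold_idx,
                "model": model,
                "n_train": len(train_idx),
                "n_val": len(val_idx),
            }
        )
    return results
```

Passing a creator function instead of a model instance means every fold starts from freshly initialised weights, so no fold can reuse parameters trained on another fold's validation samples.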

get_train_val_dataloaders(dataset: Union[CellMILDataset, CellGNNMILDataset, PatchGNNMILDataset, PatchMILDataset], train_indices: list[int], val_indices: list[int], transforms: Optional[Union[TransformPipeline, Transform]] = None, label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]] = None) → tuple[Union[torch.utils.data.dataloader.DataLoader[Any], torch_geometric.loader.dataloader.DataLoader], Union[torch.utils.data.dataloader.DataLoader[Any], torch_geometric.loader.dataloader.DataLoader], Union[cellmil.datamodels.datasets.cell_mil_dataset.CellMILDataset, cellmil.datamodels.datasets.cell_gnn_mil_dataset.CellGNNMILDataset, cellmil.datamodels.datasets.patch_gnn_mil_dataset.PatchGNNMILDataset, cellmil.datamodels.datasets.patch_mil_dataset.PatchMILDataset, cellmil.datamodels.datasets.cell_gnn_mil_dataset.SubsetCellGNNMILDataset, cellmil.datamodels.datasets.patch_gnn_mil_dataset.SubsetPatchGNNMILDataset], Union[cellmil.datamodels.datasets.cell_mil_dataset.CellMILDataset, cellmil.datamodels.datasets.cell_gnn_mil_dataset.CellGNNMILDataset, cellmil.datamodels.datasets.patch_gnn_mil_dataset.PatchGNNMILDataset, cellmil.datamodels.datasets.patch_mil_dataset.PatchMILDataset, cellmil.datamodels.datasets.cell_gnn_mil_dataset.SubsetCellGNNMILDataset, cellmil.datamodels.datasets.patch_gnn_mil_dataset.SubsetPatchGNNMILDataset]][source]

Create train and validation dataloaders with proper transforms.

Parameters:
  • dataset – Full dataset

  • train_indices – Indices for training set

  • val_indices – Indices for validation set

  • transforms – Optional feature transforms

  • label_transforms – Optional label transforms

Returns:

Tuple of (train_dataloader, val_dataloader, train_dataset, val_dataset)
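
One common reading of "proper transforms" in cross-validation is that any fitted transform learns its statistics from the training indices only and is then applied unchanged to the validation indices. The sketch below illustrates that pattern under this assumption. The class StandardizeSketch, its fit/transform methods, and the function split_and_transform are hypothetical stand-ins; this page does not show the interface of cellmil's Transform or TransformPipeline.

```python
from statistics import mean, pstdev
from typing import Optional


class StandardizeSketch:
    """Hypothetical transform: z-score using statistics learned in fit()."""

    def __init__(self) -> None:
        self.mean_: Optional[float] = None
        self.std_: Optional[float] = None

    def fit(self, values: list[float]) -> "StandardizeSketch":
        self.mean_ = mean(values)
        # Fall back to 1.0 so constant features do not divide by zero.
        self.std_ = pstdev(values) or 1.0
        return self

    def transform(self, values: list[float]) -> list[float]:
        if self.mean_ is None or self.std_ is None:
            raise RuntimeError("fit() must be called before transform()")
        return [(v - self.mean_) / self.std_ for v in values]


def split_and_transform(
    features: list[float],
    train_indices: list[int],
    val_indices: list[int],
    transform: Optional[StandardizeSketch] = None,
) -> tuple[list[float], list[float]]:
    """Subset by index, fit on the training subset only, apply to both."""
    train = [features[i] for i in train_indices]
    val = [features[i] for i in val_indices]
    if transform is not None:
        transform.fit(train)  # validation values never influence the fit
        train = transform.transform(train)
        val = transform.transform(val)
    return train, val
```

Fitting on the full dataset before splitting would let validation statistics leak into training, which tends to make per-fold metrics look better than they should.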

_get_predictions(trainer: Trainer, model: LightningModule, dataloader: Union[DataLoader[Any], DataLoader], is_survival: bool) → dict[str, Any][source]

Get predictions and true labels from the model.

Parameters:
  • trainer – Lightning trainer

  • model – Lightning model

  • dataloader – DataLoader

  • is_survival – Whether this is a survival model

Returns:

Dictionary with prediction data

_aggregate_reports(fold_reports: list[dict[str, Any]]) → dict[str, Any][source]

Aggregate reports across all folds.

Parameters:

fold_reports – List of fold reports

Returns:

Aggregated report dictionary
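
The structure of the fold reports and of the aggregated dictionary is not documented on this page. As a rough illustration only, the sketch below assumes each fold report is a flat dictionary of metrics and summarises every numeric key shared by all folds as a mean and a sample standard deviation. The function name aggregate_reports_sketch and the output keys (n_folds, mean, std, per_fold) are hypothetical.

```python
from statistics import mean, stdev
from typing import Any


def aggregate_reports_sketch(fold_reports: list[dict[str, Any]]) -> dict[str, Any]:
    """Summarise numeric metrics that appear in every fold report."""
    if not fold_reports:
        raise ValueError("fold_reports must contain at least one report")
    aggregated: dict[str, Any] = {"n_folds": len(fold_reports)}
    # Only keys present in every fold can be compared across folds.
    shared_keys = set(fold_reports[0])
    for report in fold_reports[1:]:
        shared_keys &= set(report)
    for key in sorted(shared_keys):
        values = [report[key] for report in fold_reports]
        is_numeric = all(
            isinstance(v, (int, float)) and not isinstance(v, bool) for v in values
        )
        if not is_numeric:
            continue  # skip labels, paths, and other non-metric entries
        aggregated[key] = {
            "mean": mean(values),
            # Sample standard deviation needs at least two folds.
            "std": stdev(values) if len(values) > 1 else 0.0,
            "per_fold": values,
        }
    return aggregated
```

Keeping the per-fold values next to the summary makes it easy to spot a single outlier fold that the mean alone would hide.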

_train_final_model(name: str, lit_model_creator: Callable[[int, bool], LightningModule], dataset: Union[CellMILDataset, CellGNNMILDataset, PatchGNNMILDataset, PatchMILDataset], transforms: Optional[Union[Transform, TransformPipeline]], label_transforms: Optional[Union[LabelTransform, LabelTransformPipeline]], target_epochs: int, project_name: str) → tuple[str, Union[cellmil.datamodels.transforms.base_transform.Transform, cellmil.datamodels.transforms.pipeline.TransformPipeline, NoneType], Union[cellmil.datamodels.transforms.base_label_transform.LabelTransform, cellmil.datamodels.transforms.label_pipeline.LabelTransformPipeline, NoneType]][source]

Train final model on full dataset with target number of epochs.

Parameters:
  • name – Experiment name

  • lit_model_creator – Function to create model

  • dataset – Full dataset

  • transforms – Feature transforms

  • label_transforms – Label transforms

  • target_epochs – Number of epochs to train

  • project_name – Wandb project name

Returns:

Tuple of (checkpoint_path, fitted_transforms, fitted_label_transforms)