cellmil.utils.train.evals.utils

Functions

compute_slide_cell_counts(dataset, slides)

Estimate per-slide cell counts for balancing folds.

is_survival_model(lit_model)

Check if the model is a survival analysis model.

Classes

Report(trainer, lit_model)

Base class for evaluation reports.

cellmil.utils.train.evals.utils.is_survival_model(lit_model: LightningModule) → bool

Check if the model is a survival analysis model.
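The page does not say how `is_survival_model` decides. As a minimal sketch, assuming a hypothetical string attribute `task` on the module (not confirmed by this documentation), a duck-typed check could look like this. `is_survival_model_sketch` is an illustrative name, not part of cellmil.

```python
from types import SimpleNamespace
from typing import Any


def is_survival_model_sketch(lit_model: Any) -> bool:
    """Return True if the module declares a survival task.

    Assumption: the module exposes a string attribute ``task``; this
    attribute is hypothetical and not part of the documented API.
    """
    task = getattr(lit_model, "task", None)
    return isinstance(task, str) and task.lower() == "survival"


# Usage with lightweight stand-ins instead of real LightningModules.
print(is_survival_model_sketch(SimpleNamespace(task="survival")))        # True
print(is_survival_model_sketch(SimpleNamespace(task="classification")))  # False
```

Using `getattr` with a default keeps the check safe for models that define no such attribute.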

cellmil.utils.train.evals.utils.compute_slide_cell_counts(dataset: Union[CellMILDataset, CellGNNMILDataset, PatchGNNMILDataset, PatchMILDataset], slides: ndarray[Any, Any])

Estimate per-slide cell counts for balancing folds.
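The docstring only says the counts are used for balancing folds. One common way to use such counts, sketched here under that assumption, is the greedy longest-processing-time heuristic: sort slides by cell count and always give the next slide to the fold with the fewest cells so far. `balance_folds_by_cells` and its plain-dict input are hypothetical stand-ins for the real `dataset`/`slides` arguments.

```python
import heapq
from typing import Dict, List


def balance_folds_by_cells(cell_counts: Dict[str, int], n_folds: int) -> List[List[str]]:
    """Greedily assign slides to folds so per-fold cell totals stay close.

    Hypothetical helper: ``cell_counts`` maps slide IDs to (estimated)
    cell counts, standing in for the output of compute_slide_cell_counts.
    """
    if n_folds < 1:
        raise ValueError("n_folds must be >= 1")
    # Min-heap of (total_cells, fold_index); ties go to the lower fold index.
    heap = [(0, i) for i in range(n_folds)]
    folds: List[List[str]] = [[] for _ in range(n_folds)]
    # Largest slides first; slide ID breaks ties so the result is deterministic.
    for slide in sorted(cell_counts, key=lambda s: (-cell_counts[s], s)):
        total, idx = heapq.heappop(heap)
        folds[idx].append(slide)
        heapq.heappush(heap, (total + cell_counts[slide], idx))
    return folds


counts = {"s1": 900, "s2": 500, "s3": 400, "s4": 300, "s5": 200}
print(balance_folds_by_cells(counts, 2))  # [['s1', 's4'], ['s2', 's3', 's5']]
```

Here the two folds end up with 1200 and 1100 cells. Balancing on cells rather than slide count matters because a few very large slides can otherwise dominate one fold.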

class cellmil.utils.train.evals.utils.Report(trainer: Trainer, lit_model: LightningModule)

Bases: object

Base class for evaluation reports.

__init__(trainer: Trainer, lit_model: LightningModule) → None

_load_best_model() → None

Load the best model checkpoint if available.
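A sketch of how "if available" might be resolved, assuming the best checkpoint is tracked by PyTorch Lightning's `ModelCheckpoint` callback (`trainer.checkpoint_callback.best_model_path`). The helper name `find_best_checkpoint` is hypothetical, and the actual loading step (for example via `LightningModule.load_from_checkpoint`) is left out so the snippet runs without Lightning installed.

```python
import os
from typing import Any, Optional


def find_best_checkpoint(trainer: Any) -> Optional[str]:
    """Return the best checkpoint path tracked by the trainer, or None.

    Assumption: ``trainer`` behaves like a Lightning Trainer, whose
    ``checkpoint_callback`` is the first ModelCheckpoint callback (or None).
    """
    callback = getattr(trainer, "checkpoint_callback", None)
    if callback is None:
        return None
    # best_model_path is an empty string until a checkpoint has been saved.
    path = getattr(callback, "best_model_path", "") or ""
    # Fall back to None (keep current weights) if the file is missing.
    return path if os.path.isfile(path) else None
```

Returning `None` instead of raising lets the caller keep evaluating the in-memory weights when no checkpoint was written.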

generate(dataloader: Union[DataLoader[Any], DataLoader]) → dict[str, Any]

Generate the report for the given dataloader and return it as a dictionary.
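How `generate` chooses between report types is not documented. Given the two private helpers and `is_survival_model`, one plausible structure is a simple dispatch on the model type. `ReportSketch` below is a hypothetical stand-in with placeholder helper bodies, not the cellmil implementation.

```python
from typing import Any, Dict, Iterable


class ReportSketch:
    """Hypothetical stand-in mirroring the documented Report layout."""

    def __init__(self, survival: bool) -> None:
        # Assumption: in the real class this would come from is_survival_model(lit_model).
        self.survival = survival

    def generate(self, dataloader: Iterable[Any]) -> Dict[str, Any]:
        # Dispatch on model type; a real implementation might restore the best checkpoint first.
        if self.survival:
            return self._get_survival_report(dataloader)
        return self._get_classification_report(dataloader)

    def _get_classification_report(self, dataloader: Iterable[Any]) -> Dict[str, Any]:
        # Placeholder body: only counts batches.
        return {"kind": "classification", "n_batches": sum(1 for _ in dataloader)}

    def _get_survival_report(self, dataloader: Iterable[Any]) -> Dict[str, Any]:
        # Placeholder body: only counts batches.
        return {"kind": "survival", "n_batches": sum(1 for _ in dataloader)}


print(ReportSketch(survival=True).generate([1, 2, 3]))  # {'kind': 'survival', 'n_batches': 3}
```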

_get_classification_report(dataloader: Union[DataLoader[Any], DataLoader]) → dict[str, Any]

Get a classification report with precision, recall, and F1-score.
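The page lists the metrics but not how they are computed. Below is a dependency-free sketch of per-class precision, recall, and F1. Its per-class entries mimic the layout of scikit-learn's `classification_report(..., output_dict=True)`, but whether `Report` actually uses scikit-learn is an assumption, and `classification_metrics` is a hypothetical name.

```python
from collections import Counter
from typing import Any, Dict, Sequence


def classification_metrics(y_true: Sequence[Any], y_pred: Sequence[Any]) -> Dict[str, Dict[str, Any]]:
    """Per-class precision, recall, F1 and support (labels must be sortable)."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    labels = sorted(set(y_true) | set(y_pred))
    true_pos = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    pred_counts = Counter(y_pred)   # TP + FP per label
    true_counts = Counter(y_true)   # TP + FN per label (the support)
    report: Dict[str, Dict[str, Any]] = {}
    for label in labels:
        precision = true_pos[label] / pred_counts[label] if pred_counts[label] else 0.0
        recall = true_pos[label] / true_counts[label] if true_counts[label] else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        report[str(label)] = {
            "precision": precision,
            "recall": recall,
            "f1-score": f1,
            "support": true_counts[label],
        }
    return report


print(classification_metrics([0, 0, 1, 1, 1], [0, 1, 1, 1, 0])["0"])
# {'precision': 0.5, 'recall': 0.5, 'f1-score': 0.5, 'support': 2}
```

Undefined ratios (a class never predicted or never present) are reported as 0.0, which matches scikit-learn's default `zero_division` behaviour apart from the warning it emits.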

_get_survival_report(dataloader: Union[DataLoader[Any], DataLoader]) → dict[str, Any]

Get a survival analysis report with the concordance index (C-index).
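For reference, Harrell's C-index is the fraction of comparable pairs (the subject with the shorter time had an observed event) in which the model assigns higher risk to the subject who failed first, with risk ties counted as one half. A minimal O(n²) sketch follows. Pairs tied in time are simply skipped, which is one of several conventions, and `harrell_c_index` is a hypothetical name rather than a cellmil function.

```python
from typing import Sequence


def harrell_c_index(times: Sequence[float], events: Sequence[int], risks: Sequence[float]) -> float:
    """Harrell's C-index; a higher risk score should mean an earlier event."""
    n = len(times)
    if not (n == len(events) == len(risks)):
        raise ValueError("times, events and risks must have equal length")
    concordant = 0.0
    comparable = 0
    for i in range(n):
        if not events[i]:
            continue  # a censored subject cannot anchor a comparable pair
        for j in range(n):
            if times[i] < times[j]:  # i failed strictly before j was last seen
                comparable += 1
                if risks[i] > risks[j]:
                    concordant += 1.0
                elif risks[i] == risks[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


print(harrell_c_index([1, 2, 3, 4], [1, 1, 0, 1], [0.9, 0.7, 0.8, 0.1]))  # 0.8
```

A value of 0.5 corresponds to random ordering and 1.0 to perfect ranking. When cross-checking against lifelines' `concordance_index`, note that it treats higher scores as longer survival, so risk scores are usually negated before being passed in.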