Source code for cellmil.datamodels.model
import json
import shutil
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Any, Union, Optional
from dataclasses import dataclass, asdict
from cellmil.utils import logger
from cellmil.datamodels.transforms import TransformPipeline, LabelTransformPipeline
def convert_numpy_types(obj: Any) -> Any:
    """
    Recursively convert numpy types to Python native types for JSON serialization.

    Args:
        obj: Object to convert (scalar, ndarray, or an arbitrarily nested
            dict/list/tuple containing them)

    Returns:
        Object with numpy types converted to Python native types. Containers
        are rebuilt with the same type (dict/list/tuple); anything else is
        returned unchanged.
    """
    if isinstance(obj, np.bool_):
        # np.bool_ is NOT a subclass of np.integer/np.floating and is not
        # JSON-serializable, so it needs its own branch (checked first).
        return bool(obj)
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    elif isinstance(obj, dict):
        return {key: convert_numpy_types(value) for key, value in obj.items()}
    elif isinstance(obj, list):
        return [convert_numpy_types(item) for item in obj]
    elif isinstance(obj, tuple):
        return tuple(convert_numpy_types(item) for item in obj)
    return obj
@dataclass
class FoldMetadata:
    """Metadata for a single fold."""

    fold_idx: int  # index of this fold within the k-fold split
    train_size: int  # number of training samples in this fold
    val_size: int  # number of validation samples in this fold
    best_epoch: int  # epoch at which the monitored metric was best
    best_metric_value: float  # value of the monitored metric at best_epoch
    metric_name: str  # name of the monitored metric
    is_survival: bool  # whether this fold was trained on a survival task
    metrics: dict[str, Any]  # full metric-name -> value mapping for the fold
@dataclass
class ExperimentMetadata:
    """Metadata for the entire k-fold experiment."""

    name: str  # experiment name (may carry a "_vN" version suffix)
    k_folds: int  # number of cross-validation folds
    random_state: int  # RNG seed used for the fold split
    balance_cell_counts: bool  # whether cell-count balancing was applied
    cell_balance_bins: int  # number of bins used for cell-count balancing
    is_survival: bool  # whether the experiment is a survival task
    aggregated_metrics: dict[str, Any]  # metrics aggregated across folds
    best_fold_idx: int  # index of the best-performing fold
    avg_best_epoch: float  # mean of best_epoch across folds
    dataset_config: dict[str, Any]  # dataset configuration snapshot
    model_config: dict[str, Any]  # model configuration snapshot
class ModelStorage:
    """
    Manages storage and retrieval of k-fold cross-validation results.

    Directory structure:
    {output_dir}/
    ├── experiment_metadata.json
    ├── fold_0/
    │   ├── best_model.ckpt
    │   ├── train_indices.json
    │   ├── val_indices.json
    │   ├── predictions.csv
    │   ├── transforms/
    │   │   ├── pipeline_config.json
    │   │   ├── transform_0_*.json
    │   │   └── ...
    │   ├── label_transforms/
    │   │   ├── pipeline.json
    │   │   ├── transform_0.json
    │   │   └── ...
    │   └── metadata.json
    ├── fold_1/
    │   └── ...
    ├── ...
    └── final_model/
        ├── final_model.ckpt
        └── metadata.json
    """

    def __init__(self, output_dir: Union[str, Path], experiment_name: str, load_existing: bool = False):
        """
        Initialize ModelStorage.

        Args:
            output_dir: Base directory for storing results
            experiment_name: Name of the experiment
            load_existing: If True, load from existing directory without versioning

        Raises:
            FileNotFoundError: If load_existing is True and the experiment
                directory does not exist.
        """
        base_dir = Path(output_dir)
        proposed_dir = base_dir / experiment_name
        if load_existing:
            # Loading: use the directory as-is; it must already exist.
            if not proposed_dir.exists():
                raise FileNotFoundError(
                    f"Cannot load existing experiment: '{proposed_dir}' does not exist"
                )
            self.output_dir = proposed_dir
            self.experiment_name = experiment_name
            logger.info(f"Loading existing experiment: '{self.experiment_name}'")
        else:
            if proposed_dir.exists():
                # Avoid clobbering previous results: pick the next free "_vN" suffix.
                version = 2
                while (base_dir / f"{experiment_name}_v{version}").exists():
                    version += 1
                self.output_dir = base_dir / f"{experiment_name}_v{version}"
                self.experiment_name = f"{experiment_name}_v{version}"
                logger.warning(
                    f"Experiment '{experiment_name}' already exists. "
                    f"Creating new version: '{self.experiment_name}'"
                )
            else:
                self.output_dir = proposed_dir
                self.experiment_name = experiment_name
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Per-fold metadata keyed by fold index.
        self.fold_metadata: dict[int, FoldMetadata] = {}
        self.experiment_metadata: Optional[ExperimentMetadata] = None
        if load_existing:
            self._load_all_metadata()

    def _save_label_transforms(self, target_dir: Path, label_transforms: Any) -> None:
        """Save label transforms under target_dir.

        A LabelTransformPipeline saves as a "label_transforms" directory; a
        single transform saves as a "label_transforms.json" file.
        """
        if isinstance(label_transforms, LabelTransformPipeline):
            label_transforms.save(target_dir / "label_transforms")
        else:
            label_transforms.save(target_dir / "label_transforms.json")

    def save_fold_results(
        self,
        fold_idx: int,
        checkpoint_path: Union[str, Path],
        train_indices: list[int],
        val_indices: list[int],
        predictions: dict[str, Any],
        metadata: FoldMetadata,
        transforms: Any = None,
        label_transforms: Any = None,
    ) -> None:
        """
        Save all results for a single fold.

        Args:
            fold_idx: Fold index
            checkpoint_path: Path to the best checkpoint for this fold
            train_indices: Training indices
            val_indices: Validation indices
            predictions: Dictionary with 'y_true' and 'y_pred' arrays
            metadata: Fold metadata
            transforms: Optional feature transforms
            label_transforms: Optional label transforms
        """
        fold_dir = self.output_dir / f"fold_{fold_idx}"
        fold_dir.mkdir(parents=True, exist_ok=True)
        # Copy the checkpoint into the fold directory (skipped if missing,
        # e.g. when training was interrupted before a checkpoint was written).
        if Path(checkpoint_path).exists():
            shutil.copy2(checkpoint_path, fold_dir / "best_model.ckpt")
        # Save train/val indices.
        with open(fold_dir / "train_indices.json", "w") as f:
            json.dump(train_indices, f, indent=2)
        with open(fold_dir / "val_indices.json", "w") as f:
            json.dump(val_indices, f, indent=2)
        # Save predictions as CSV.
        df = pd.DataFrame(predictions)
        df.to_csv(fold_dir / "predictions.csv", index=False)
        # Save transforms using their native save methods (JSON format).
        if transforms is not None:
            transforms.save(fold_dir / "transforms")
        if label_transforms is not None:
            self._save_label_transforms(fold_dir, label_transforms)
        # Save metadata (numpy scalars converted so json.dump succeeds).
        with open(fold_dir / "metadata.json", "w") as f:
            metadata_dict = convert_numpy_types(asdict(metadata))
            json.dump(metadata_dict, f, indent=2)
        self.fold_metadata[fold_idx] = metadata

    def save_experiment_metadata(self, metadata: ExperimentMetadata) -> None:
        """Save overall experiment metadata."""
        self.experiment_metadata = metadata
        with open(self.output_dir / "experiment_metadata.json", "w") as f:
            metadata_dict = convert_numpy_types(asdict(metadata))
            json.dump(metadata_dict, f, indent=2)

    def save_final_model(
        self,
        checkpoint_path: Union[str, Path],
        avg_epochs: float,
        final_metrics: dict[str, Any],
        transforms: Any = None,
        label_transforms: Any = None,
    ) -> None:
        """
        Save the final model trained on average epochs.

        Args:
            checkpoint_path: Path to final model checkpoint
            avg_epochs: Average number of epochs used
            final_metrics: Metrics from final model
            transforms: Optional feature transforms
            label_transforms: Optional label transforms
        """
        final_dir = self.output_dir / "final_model"
        final_dir.mkdir(parents=True, exist_ok=True)
        if Path(checkpoint_path).exists():
            shutil.copy2(checkpoint_path, final_dir / "final_model.ckpt")
        # Save transforms using their native save methods (JSON format).
        if transforms is not None:
            transforms.save(final_dir / "transforms")
        if label_transforms is not None:
            self._save_label_transforms(final_dir, label_transforms)
        metadata: dict[str, Any] = {
            "avg_epochs": avg_epochs,
            "metrics": final_metrics,
        }
        with open(final_dir / "metadata.json", "w") as f:
            metadata_converted = convert_numpy_types(metadata)
            json.dump(metadata_converted, f, indent=2)

    def load_fold_checkpoint(self, fold_idx: int) -> Path:
        """Load checkpoint path for a specific fold."""
        checkpoint_path = self.output_dir / f"fold_{fold_idx}" / "best_model.ckpt"
        if not checkpoint_path.exists():
            raise FileNotFoundError(f"Checkpoint not found for fold {fold_idx}")
        return checkpoint_path

    def load_final_checkpoint(self) -> Path:
        """Load the final model checkpoint."""
        checkpoint_path = self.output_dir / "final_model" / "final_model.ckpt"
        if not checkpoint_path.exists():
            raise FileNotFoundError("Final model checkpoint not found")
        return checkpoint_path

    def load_fold_predictions(self, fold_idx: int) -> pd.DataFrame:
        """Load predictions for a specific fold."""
        pred_path = self.output_dir / f"fold_{fold_idx}" / "predictions.csv"
        if not pred_path.exists():
            raise FileNotFoundError(f"Predictions not found for fold {fold_idx}")
        return pd.read_csv(pred_path)  # type: ignore

    def load_all_predictions(self) -> pd.DataFrame:
        """Load and concatenate predictions from all folds.

        Adds a "fold" column identifying which fold each row came from.
        """
        all_preds: list[pd.DataFrame] = []
        for fold_idx in sorted(self.fold_metadata.keys()):
            df = self.load_fold_predictions(fold_idx)
            df["fold"] = fold_idx
            all_preds.append(df)
        return pd.concat(all_preds, ignore_index=True)

    @staticmethod
    def _load_single_label_transform(path: Path) -> Any:
        """Load a single (non-pipeline) label transform from a JSON config file.

        Returns None and logs a warning for unrecognized transform classes.
        May raise (e.g. on malformed JSON); callers handle exceptions.
        """
        from .transforms import TimeDiscretizerTransform
        with open(path, 'r') as f:
            config = json.load(f)
        transform_class_name = config.pop('transform_class', None)
        if transform_class_name == 'TimeDiscretizerTransform':
            return TimeDiscretizerTransform.from_config(config)
        logger.warning(f"Unknown transform class: {transform_class_name}")
        return None

    def _load_transforms_from(self, base_dir: Path) -> tuple[Any, Any]:
        """Load (transforms, label_transforms) stored under base_dir.

        Shared implementation for fold and final-model directories. Each item
        is None when absent or when loading fails (failures are logged, not
        raised). Label transforms may be stored as a pipeline directory, a
        single-transform JSON file, or a legacy extension-less file.
        """
        transforms = None
        label_transforms = None
        # Feature transforms: native JSON directory format.
        transforms_dir = base_dir / "transforms"
        if transforms_dir.exists():
            try:
                transforms = TransformPipeline.load(transforms_dir)
            except Exception as e:
                logger.warning(f"Failed to load transforms from JSON: {e}")
        # Label transforms: pipeline (directory) vs single transform (file).
        label_transforms_dir = base_dir / "label_transforms"
        label_transforms_file = base_dir / "label_transforms.json"
        if label_transforms_dir.exists() and label_transforms_dir.is_dir():
            try:
                label_transforms = LabelTransformPipeline.load(label_transforms_dir)
            except Exception as e:
                logger.warning(f"Failed to load label transforms pipeline from directory: {e}")
        elif label_transforms_file.exists():
            try:
                label_transforms = self._load_single_label_transform(label_transforms_file)
            except Exception as e:
                logger.warning(f"Failed to load label transforms from JSON file: {e}")
        elif label_transforms_dir.exists() and label_transforms_dir.is_file():
            # Legacy format: single transform saved as a file without the
            # .json extension (same path the pipeline directory would use).
            try:
                label_transforms = self._load_single_label_transform(label_transforms_dir)
            except Exception as e:
                logger.warning(f"Failed to load legacy label transforms: {e}")
        return transforms, label_transforms

    def load_fold_transforms(self, fold_idx: int) -> tuple[Any, Any]:
        """Load transforms for a specific fold."""
        return self._load_transforms_from(self.output_dir / f"fold_{fold_idx}")

    def load_final_transforms(self) -> tuple[Any, Any]:
        """Load transforms for the final model."""
        return self._load_transforms_from(self.output_dir / "final_model")

    def get_average_best_epoch(self) -> float:
        """Calculate average of best epochs across all folds."""
        if not self.fold_metadata:
            raise ValueError("No fold metadata available")
        epochs = [meta.best_epoch for meta in self.fold_metadata.values()]
        return float(np.mean(epochs))

    def get_experiment_summary(self) -> dict[str, Any]:
        """Get a summary of the entire experiment."""
        if self.experiment_metadata is None:
            raise ValueError("Experiment metadata not saved yet")
        summary = asdict(self.experiment_metadata)
        summary["folds"] = {
            f"fold_{idx}": asdict(meta) for idx, meta in self.fold_metadata.items()
        }
        return summary

    def get_fold_indices(self, fold_idx: int) -> tuple[list[int], list[int]]:
        """Get train and validation indices for a specific fold."""
        fold_dir = self.output_dir / f"fold_{fold_idx}"
        with open(fold_dir / "train_indices.json", "r") as f:
            train_indices = json.load(f)
        with open(fold_dir / "val_indices.json", "r") as f:
            val_indices = json.load(f)
        return train_indices, val_indices

    @classmethod
    def from_directory(cls, experiment_dir: Union[str, Path]) -> "ModelStorage":
        """
        Load an existing experiment from a directory.

        Args:
            experiment_dir: Path to the experiment directory

        Returns:
            ModelStorage instance with loaded metadata

        Example:
            >>> storage = ModelStorage.from_directory("/path/to/experiments/my_experiment")
            >>> print(storage.experiment_metadata)
            >>> predictions = storage.load_all_predictions()
        """
        experiment_dir = Path(experiment_dir)
        if not experiment_dir.exists():
            raise FileNotFoundError(f"Experiment directory not found: {experiment_dir}")
        # The directory name is the experiment name; its parent is the base dir.
        experiment_name = experiment_dir.name
        output_dir = experiment_dir.parent
        return cls(output_dir, experiment_name, load_existing=True)

    def _load_experiment_metadata(self) -> None:
        """Load experiment metadata from disk (no-op with a warning if absent)."""
        metadata_path = self.output_dir / "experiment_metadata.json"
        if not metadata_path.exists():
            logger.warning("Experiment metadata file not found")
            return
        with open(metadata_path, "r") as f:
            metadata_dict = json.load(f)
        self.experiment_metadata = ExperimentMetadata(**metadata_dict)

    def _load_fold_metadata(self, fold_idx: int) -> None:
        """Load metadata for a specific fold."""
        metadata_path = self.output_dir / f"fold_{fold_idx}" / "metadata.json"
        if not metadata_path.exists():
            raise FileNotFoundError(f"Metadata not found for fold {fold_idx}")
        with open(metadata_path, "r") as f:
            metadata_dict = json.load(f)
        self.fold_metadata[fold_idx] = FoldMetadata(**metadata_dict)

    def _load_all_metadata(self) -> None:
        """Load all experiment and fold metadata from disk."""
        self._load_experiment_metadata()
        # Discover fold directories by naming convention ("fold_<idx>").
        fold_dirs = sorted(self.output_dir.glob("fold_*"))
        for fold_dir in fold_dirs:
            try:
                fold_idx = int(fold_dir.name.split("_")[1])
                self._load_fold_metadata(fold_idx)
            except (ValueError, IndexError):
                logger.warning(f"Skipping invalid fold directory: {fold_dir.name}")
        logger.info(
            f"Loaded experiment '{self.experiment_name}' with {len(self.fold_metadata)} folds"
        )

    def list_folds(self) -> list[int]:
        """Get list of available fold indices."""
        return sorted(self.fold_metadata.keys())

    def has_final_model(self) -> bool:
        """Check if a final model exists."""
        return (self.output_dir / "final_model" / "final_model.ckpt").exists()

    def __repr__(self) -> str:
        # Note: a stray trailing apostrophe in the original f-string produced
        # "...')'" output; fixed here.
        n_folds = len(self.fold_metadata)
        return f"ModelStorage(experiment='{self.experiment_name}', folds={n_folds}, dir='{self.output_dir}')"