import lightning as Pl
import numpy as np
import pandas as pd
import time
import wandb
import torch
from pathlib import Path
from typing import Callable, Union, Any, cast
from cellmil.datamodels.datasets.cell_mil_dataset import CellMILDataset
from sklearn.model_selection import StratifiedKFold # type: ignore
from cellmil.datamodels.datasets.cell_gnn_mil_dataset import (
CellGNNMILDataset,
SubsetCellGNNMILDataset,
)
from cellmil.datamodels.datasets.patch_gnn_mil_dataset import (
PatchGNNMILDataset,
SubsetPatchGNNMILDataset,
)
from cellmil.datamodels.datasets.patch_mil_dataset import PatchMILDataset
from cellmil.datamodels.transforms import (
Transform,
TransformPipeline,
LabelTransform,
LabelTransformPipeline,
)
from cellmil.datamodels.model import ModelStorage, FoldMetadata, ExperimentMetadata
from cellmil.utils import logger
from cellmil.utils.train.evals.utils import (
compute_slide_cell_counts,
is_survival_model,
Report,
)
from lightning import Trainer, seed_everything
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from torch.utils.data import DataLoader as DataLoaderTorch
from torch_geometric.loader import DataLoader as DataLoaderPyG # type: ignore
class KFoldCrossValidation:
    """Stratified k-fold cross-validation driver for MIL models."""

    def __init__(self, k: int = 5, random_state: int = 42):
        """Set up the splitter.

        Args:
            k: Number of folds.
            random_state: Seed used both for global reproducibility and for
                the fold shuffling.
        """
        self.k = k
        self.random_state = random_state
        # Fix all RNG seeds (python/numpy/torch) so splits and training are
        # reproducible across runs.
        seed_everything(self.random_state)
        # Stratified splitter: preserves label proportions in every fold.
        self.skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=random_state)
[docs] def evaluate(
self,
name: str,
lit_model_creator: Callable[[int], Pl.LightningModule],
dataset: Union[
CellMILDataset, CellGNNMILDataset, PatchGNNMILDataset, PatchMILDataset
],
output_dir: Union[str, Path],
wandb_project: str,
transforms: Union[Transform, TransformPipeline, None] = None,
label_transforms: Union[LabelTransform, LabelTransformPipeline, None] = None,
balance_cell_counts: bool = False,
cell_balance_bins: int = 5,
**kwargs: Any,
) -> ModelStorage:
"""
Perform k-fold cross-validation with comprehensive result storage.
Args:
name: Experiment name
lit_model_creator: Function that creates a Lightning module given fold index
dataset: Dataset for cross-validation
output_dir: Directory to store all results
transforms: Optional feature transforms
label_transforms: Optional label transforms
balance_cell_counts: Whether to balance by cell counts
cell_balance_bins: Number of bins for cell count balancing
max_epochs: Maximum training epochs per fold
**kwargs: Additional arguments
Returns:
ModelStorage object with all results
"""
# Initialize model storage
storage = ModelStorage(output_dir, name)
# Get stratification targets
indices = np.arange(len(dataset))
targets = self._get_targets(
dataset=dataset,
balance_cell_counts=balance_cell_counts,
cell_balance_bins=cell_balance_bins,
)
# Store fold reports
fold_reports: list[dict[str, Any]] = []
# Login to wandb
wandb.login()
logger.info(f"Starting {self.k}-fold cross-validation for '{name}'...")
for fold_idx, (train_idx, val_idx) in enumerate(
self.skf.split(indices, targets) # type: ignore
):
logger.info(f"Training fold {fold_idx + 1}/{self.k}")
logger.info(
f"Train indices: {len(train_idx)}, Test indices: {len(val_idx)}"
)
# Create dataloaders for this fold
train_loader, val_loader, train_dataset, _ = self.get_train_val_dataloaders(
dataset=dataset,
train_indices=train_idx.tolist(),
val_indices=val_idx.tolist(),
transforms=transforms,
label_transforms=label_transforms,
)
# Create fresh model instance for this fold
# Use transformed dataset to get correct feature dimension after correlation filter
if isinstance(
dataset,
(
CellGNNMILDataset,
SubsetCellGNNMILDataset,
PatchGNNMILDataset,
SubsetPatchGNNMILDataset,
),
):
sample_data = train_dataset[0]
model = lit_model_creator(sample_data.x.shape[1]) # type: ignore
else:
sample_data = train_dataset[0]
model = lit_model_creator(sample_data[0].shape[1])
# Detect if this is a survival model
is_surv = is_survival_model(model)
# Setup trainer with checkpoint callback
if is_surv:
monitor_metric = "val/c_index"
mode = "max"
else:
monitor_metric = "val/f1"
mode = "max"
checkpoint_callback = ModelCheckpoint(
monitor=monitor_metric,
mode=mode,
save_top_k=1,
dirpath=f"./temp_checkpoints/{name}/fold_{fold_idx}",
filename="best",
)
early_stopping = EarlyStopping(
monitor=monitor_metric,
patience=cast(int, kwargs.get("early_stopping_patience", 30)),
mode=mode,
verbose=True,
)
wandb_logger = WandbLogger(
project=wandb_project,
name=f"FOLD_{fold_idx + 1}_{name}_{time.strftime('%Y-%m-%d_%H-%M-%S')}",
tags=["fold", f"fold-{fold_idx + 1}"],
)
trainer = Trainer(
max_epochs=cast(int, kwargs.get("max_epochs", 100)),
accelerator="gpu",
devices=[0],
log_every_n_steps=1,
logger=wandb_logger,
callbacks=[checkpoint_callback, early_stopping],
enable_progress_bar=False,
)
# Train the model
trainer.fit(model, train_loader, val_loader)
# Get evaluation report for this fold
report_gen = Report(trainer, model)
fold_report = report_gen.generate(val_loader)
wandb.log(fold_report)
fold_reports.append(fold_report)
# Get predictions for validation set
predictions = self._get_predictions(trainer, model, val_loader, is_surv)
best_epoch = trainer.current_epoch if trainer.current_epoch == kwargs.get("max_epochs", 100) else trainer.current_epoch - early_stopping.patience
best_metric_value = checkpoint_callback.best_model_score
if best_metric_value is None:
best_metric_value = fold_report.get(monitor_metric.replace("val_", ""), 0.0)
# Create fold metadata
fold_meta = FoldMetadata(
fold_idx=fold_idx,
train_size=len(train_idx),
val_size=len(val_idx),
best_epoch=int(best_epoch)
if isinstance(best_epoch, (int, float))
else trainer.current_epoch,
best_metric_value=float(best_metric_value),
metric_name=monitor_metric,
is_survival=is_surv,
metrics=fold_report,
)
# Get the fitted transforms from the datasets
fitted_transforms = getattr(train_dataset, "transforms", None)
fitted_label_transforms = getattr(train_dataset, "label_transforms", None)
# Save fold results
storage.save_fold_results(
fold_idx=fold_idx,
checkpoint_path=checkpoint_callback.best_model_path, # type: ignore
train_indices=train_idx.tolist(),
val_indices=val_idx.tolist(),
predictions=predictions,
metadata=fold_meta,
transforms=fitted_transforms,
label_transforms=fitted_label_transforms,
)
logger.info(
f"Fold {fold_idx + 1} completed with {monitor_metric}: {best_metric_value:.4f}"
)
wandb.finish()
# Aggregate reports across folds
logger.info("Aggregating results across folds...")
aggregated_report = self._aggregate_reports(fold_reports)
# Determine best fold
if "c_index" in aggregated_report:
metric_values = [r.get("c_index", 0) for r in fold_reports]
else:
metric_values = [r.get("macro avg", {}).get("f1-score", 0) for r in fold_reports]
best_fold_idx = int(np.argmax(metric_values))
# Calculate average best epoch
avg_best_epoch = storage.get_average_best_epoch()
# Extract dataset configuration using the dataset's get_config method
dataset_config = dataset.get_config()
# Extract model name from first fold model
model = lit_model_creator(0)
model_config: dict[str, Any] = {
"model_class": model.__class__.__name__,
}
# Save experiment metadata
exp_meta = ExperimentMetadata(
name=name,
k_folds=self.k,
random_state=self.random_state,
balance_cell_counts=balance_cell_counts,
cell_balance_bins=cell_balance_bins,
is_survival="c_index" in aggregated_report,
aggregated_metrics=aggregated_report,
best_fold_idx=best_fold_idx,
avg_best_epoch=avg_best_epoch,
dataset_config=dataset_config,
model_config=model_config,
)
storage.save_experiment_metadata(exp_meta)
# Train final model with average epochs
logger.info(f"Training final model with {avg_best_epoch:.1f} average epochs...")
final_checkpoint, final_transforms, final_label_transforms = self._train_final_model(
name=name,
lit_model_creator=lit_model_creator,
dataset=dataset,
transforms=transforms,
label_transforms=label_transforms,
target_epochs=int(np.round(avg_best_epoch)),
project_name=wandb_project,
)
# Save final model
storage.save_final_model(
checkpoint_path=final_checkpoint,
avg_epochs=avg_best_epoch,
final_metrics=aggregated_report,
transforms=final_transforms,
label_transforms=final_label_transforms,
)
logger.info("K-fold cross-validation completed successfully!")
logger.info(f"Results saved to: {storage.output_dir}")
return storage
[docs] def get_train_val_dataloaders(
self,
dataset: Union[
CellMILDataset, CellGNNMILDataset, PatchGNNMILDataset, PatchMILDataset
],
train_indices: list[int],
val_indices: list[int],
transforms: Union[Transform, TransformPipeline, None] = None,
label_transforms: Union[LabelTransform, LabelTransformPipeline, None] = None,
) -> tuple[
Union[DataLoaderTorch[Any], DataLoaderPyG],
Union[DataLoaderTorch[Any], DataLoaderPyG],
Union[
CellMILDataset,
CellGNNMILDataset,
PatchGNNMILDataset,
PatchMILDataset,
SubsetCellGNNMILDataset,
SubsetPatchGNNMILDataset,
],
Union[
CellMILDataset,
CellGNNMILDataset,
PatchGNNMILDataset,
PatchMILDataset,
SubsetCellGNNMILDataset,
SubsetPatchGNNMILDataset,
],
]:
"""
Create train and validation dataloaders with proper transforms.
Args:
dataset: Full dataset
train_indices: Indices for training set
val_indices: Indices for validation set
transforms: Optional feature transforms
label_transforms: Optional label transforms
Returns:
Tuple of (train_dataloader, val_dataloader, train_dataset, val_dataset)
"""
train_dataset, val_dataset = dataset.create_train_val_datasets(
train_indices=train_indices,
val_indices=val_indices,
transforms=transforms,
label_transforms=label_transforms,
)
if isinstance(dataset, (CellGNNMILDataset, PatchGNNMILDataset)):
train_loader = DataLoaderPyG(
train_dataset, # type: ignore
batch_size=1,
shuffle=True,
num_workers=8
)
val_loader = DataLoaderPyG(
val_dataset, # type: ignore
batch_size=1,
shuffle=False,
num_workers=8
)
else:
train_loader = DataLoaderTorch(
train_dataset, # type: ignore
batch_size=1,
shuffle=True,
num_workers=8
)
val_loader = DataLoaderTorch(
val_dataset, # type: ignore
batch_size=1,
shuffle=False,
num_workers=8
)
return train_loader, val_loader, train_dataset, val_dataset
    def _get_targets(
        self,
        dataset: Union[
            CellMILDataset, CellGNNMILDataset, PatchGNNMILDataset, PatchMILDataset
        ],
        balance_cell_counts: bool = False,
        cell_balance_bins: int = 5,
    ) -> np.ndarray[Any, Any]:
        """Build the stratification target vector for StratifiedKFold.

        Args:
            dataset: Dataset whose labels (and optionally per-slide cell
                counts) drive the stratification.
            balance_cell_counts: If True, additionally stratify on quantile
                bins of cell counts (joint label-x-bin classes). Best-effort:
                any failure falls back to label/event-only targets.
            cell_balance_bins: Maximum number of quantile bins for counts.

        Returns:
            1-D array with one stratification target per slide/sample.
        """
        # Prefer slide identifiers when the dataset exposes them; otherwise
        # fall back to positional indices.
        slides = (
            np.array(dataset.slides)
            if hasattr(dataset, "slides")
            else np.arange(len(dataset))
        )
        # Handle both classification labels (single value) and survival labels (tuple)
        raw_labels = [dataset.labels[slide] for slide in slides]
        # Check if this is survival data (labels are tuples of (duration, event))
        is_survival_data = isinstance(raw_labels[0], tuple)
        if is_survival_data:
            # For survival analysis, stratify by event status only
            y = cast(
                np.ndarray[Any, Any],
                np.array([label[1] for label in raw_labels]),  # type: ignore
            )  # Extract event indicators
            logger.info(
                f"Detected survival data - stratifying by event status: {np.bincount(y.astype(int))}"
            )
        else:
            # For classification, use labels directly
            y = np.array(raw_labels)
            logger.info(
                f"Detected classification data - stratifying by class labels: {np.bincount(y.astype(int))}"
            )
        cell_counts = compute_slide_cell_counts(dataset, slides)
        # Default: stratify on labels/events alone; balancing may override below.
        split_targets = y
        if balance_cell_counts:
            try:
                unique_counts = np.unique(cell_counts)
                if unique_counts.size <= 1:
                    logger.warning(
                        "Cell-count balancing disabled: insufficient variability in counts."
                    )
                else:
                    # Cap the bin count by the number of distinct counts so
                    # qcut is never asked for more bins than the data supports.
                    quantile_bins = min(cell_balance_bins, unique_counts.size)
                    bin_array = np.asarray(
                        pd.qcut(  # type: ignore
                            cell_counts,
                            q=quantile_bins,
                            labels=False,
                            duplicates="drop",
                        ),
                        dtype=float,
                    )
                    # qcut yields NaN for unassignable entries; treat that as
                    # a hard failure and fall back via the except below.
                    if np.isnan(bin_array).any():
                        raise ValueError("Quantile binning produced NaNs")
                    cell_bins = bin_array.astype(int)
                    # For both classification and survival, combine event/label with cell bins
                    combined_labels = pd.Series(
                        [f"{label}_{bin_idx}" for label, bin_idx in zip(y, cell_bins)]
                    )
                    combined_codes, _ = pd.factorize(combined_labels, sort=True)  # type: ignore
                    valid_mask = combined_codes >= 0
                    if not valid_mask.any():
                        raise ValueError("Factorization produced all invalid codes")
                    class_counts = np.bincount(combined_codes[valid_mask])
                    # StratifiedKFold needs every class to appear at least k
                    # times; otherwise keep the plain label/event targets.
                    if class_counts.size == 0 or class_counts.min() < self.k:
                        logger.warning(
                            "Cell-count balancing fallback: at least one label/bin combo has fewer samples than folds."
                        )
                    else:
                        split_targets = combined_codes
                        if is_survival_data:
                            logger.info(
                                "Enabled joint stratification on event status and %d cell-count bins",
                                len(np.unique(cell_bins)),
                            )
                        else:
                            logger.info(
                                "Enabled joint stratification on labels and %d cell-count bins",
                                len(np.unique(cell_bins)),
                            )
            except Exception as exc:
                # Balancing is best-effort: any failure reverts to the
                # label/event-only targets computed above.
                logger.warning(
                    f"Failed to apply cell-count-balanced stratification: {exc}. Using event/label-only stratification."
                )
        return split_targets
    def _get_predictions(
        self,
        trainer: Trainer,
        model: Pl.LightningModule,
        dataloader: Union[DataLoaderTorch[Any], DataLoaderPyG],
        is_survival: bool,
    ) -> dict[str, Any]:
        """
        Get predictions and true labels from the model.

        Labels are collected by a second pass over the dataloader, so the
        loader must not shuffle for predictions and labels to line up (the
        validation loaders built by this class use shuffle=False).

        Args:
            trainer: Lightning trainer
            model: Lightning model
            dataloader: DataLoader
            is_survival: Whether this is a survival model

        Returns:
            Dictionary with prediction data: keys "y_true_duration",
            "y_true_event", "y_pred_risk" for survival models; "y_true",
            "y_pred" for classifiers. All values are numpy arrays.
        """
        y_pred = trainer.predict(model, dataloader)
        # Extract true labels: PyG batches carry them in `.y`, torch batches
        # as the last tuple element.
        if isinstance(dataloader, DataLoaderPyG):
            y_true = [data.y for data in dataloader]
        else:
            y_true = [batch[-1] for batch in dataloader]
        if is_survival:
            # For survival: extract (duration, event) pairs, accepting either
            # a 2-element tuple/list or a 2-element tensor per sample.
            durations: list[float] = []
            events: list[float] = []
            for label in y_true:
                if isinstance(label, (tuple, list)) and len(label) == 2:  # type: ignore
                    label = cast(tuple[float, float], label)
                    durations.append(float(label[0]))
                    events.append(float(label[1]))
                elif torch.is_tensor(label) and label.numel() == 2:
                    durations.append(float(label[0]))
                    events.append(float(label[1]))
                else:
                    # Unknown layout: log it and pad with zeros so the three
                    # output arrays keep matching lengths.
                    logger.warning(f"Unexpected label format: {label}")
                    durations.append(0.0)
                    events.append(0.0)
            # Get risk scores (predictions)
            if y_pred is not None and isinstance(y_pred[0], torch.Tensor):
                # For survival, predictions are risk from logits — assumed to
                # be per-interval hazard logits (TODO confirm against model).
                risk_scores: list[float] = []
                for pred in y_pred:  # type: ignore
                    logits = pred.cpu() if pred.is_cuda else pred  # type: ignore
                    # Apply sigmoid to get hazards
                    hazards = torch.sigmoid(logits)  # type: ignore
                    # Calculate cumulative survival probability
                    survival = torch.cumprod(1 - hazards, dim=0)
                    # Risk is negative sum of survival probabilities
                    # (higher risk = shorter expected survival)
                    risk = -float(torch.sum(survival))
                    risk_scores.append(risk)
                y_pred_flat = risk_scores
            else:
                # No usable predictions: emit zero risks of matching length.
                y_pred_flat = [0.0] * len(durations)
            return {
                "y_true_duration": np.array(durations),
                "y_true_event": np.array(events),
                "y_pred_risk": np.array(y_pred_flat),
            }
        else:
            # For classification: flatten each per-sample prediction/label to
            # a scalar (loaders in this class use batch_size=1).
            if y_pred is not None and isinstance(y_pred[0], torch.Tensor):
                y_pred_flat = [pred.cpu().numpy().flatten()[0] for pred in y_pred]  # type: ignore
            else:
                y_pred_flat = [  # type: ignore
                    pred.flatten()[0] if hasattr(pred, "flatten") else pred  # type: ignore
                    for pred in y_pred  # type: ignore
                ]
            if y_true and isinstance(y_true[0], torch.Tensor):
                y_true_flat = [true.cpu().numpy().flatten()[0] for true in y_true]
            else:
                y_true_flat = [
                    true.flatten()[0] if hasattr(true, "flatten") else true
                    for true in y_true
                ]
            return {
                "y_true": np.array(y_true_flat),
                "y_pred": np.array(y_pred_flat),  # type: ignore
            }
[docs] def _aggregate_reports(
self,
fold_reports: list[dict[str, Any]],
) -> dict[str, Any]:
"""
Aggregate reports across all folds.
Args:
fold_reports: List of fold reports
Returns:
Aggregated report dictionary
"""
if not fold_reports:
return {}
# Check if this is survival or classification
if "c_index" in fold_reports[0]:
# Survival analysis aggregation
return {
"c_index": np.mean(
[report.get("c_index", 0) for report in fold_reports]
),
"n_samples": np.sum(
[report.get("n_samples", 0) for report in fold_reports]
),
"n_events": np.sum(
[report.get("n_events", 0) for report in fold_reports]
),
"c_index_std": np.std(
[report.get("c_index", 0) for report in fold_reports]
)
}
else:
# Classification aggregation
class_labels: set[str] = set()
for report in fold_reports:
for key in report.keys():
if key not in ["accuracy", "macro avg", "weighted avg"]:
class_labels.add(key)
aggregated_report: dict[str, Any] = {}
# Aggregate per-class metrics
for label in class_labels:
aggregated_report[label] = {
"precision": np.mean(
[
report.get(label, {}).get("precision", 0)
for report in fold_reports
]
),
"recall": np.mean(
[
report.get(label, {}).get("recall", 0)
for report in fold_reports
]
),
"f1-score": np.mean(
[
report.get(label, {}).get("f1-score", 0)
for report in fold_reports
]
),
"support": np.sum(
[
report.get(label, {}).get("support", 0)
for report in fold_reports
]
),
}
# Aggregate overall metrics
aggregated_report["accuracy"] = np.mean(
[report.get("accuracy", 0) for report in fold_reports]
)
aggregated_report["macro avg"] = {
"precision": np.mean(
[
report.get("macro avg", {}).get("precision", 0)
for report in fold_reports
]
),
"recall": np.mean(
[
report.get("macro avg", {}).get("recall", 0)
for report in fold_reports
]
),
"f1-score": np.mean(
[
report.get("macro avg", {}).get("f1-score", 0)
for report in fold_reports
]
),
"support": np.sum(
[
report.get("macro avg", {}).get("support", 0)
for report in fold_reports
]
),
}
aggregated_report["weighted avg"] = {
"precision": np.mean(
[
report.get("weighted avg", {}).get("precision", 0)
for report in fold_reports
]
),
"recall": np.mean(
[
report.get("weighted avg", {}).get("recall", 0)
for report in fold_reports
]
),
"f1-score": np.mean(
[
report.get("weighted avg", {}).get("f1-score", 0)
for report in fold_reports
]
),
"support": np.sum(
[
report.get("weighted avg", {}).get("support", 0)
for report in fold_reports
]
),
}
return aggregated_report
    def _train_final_model(
        self,
        name: str,
        lit_model_creator: Callable[[int, bool], Pl.LightningModule],
        dataset: Union[
            CellMILDataset, CellGNNMILDataset, PatchGNNMILDataset, PatchMILDataset
        ],
        transforms: Union[Transform, TransformPipeline, None],
        label_transforms: Union[LabelTransform, LabelTransformPipeline, None],
        target_epochs: int,
        project_name: str,
    ) -> tuple[str, Union[Transform, TransformPipeline, None], Union[LabelTransform, LabelTransformPipeline, None]]:
        """
        Train final model on full dataset with target number of epochs.

        Args:
            name: Experiment name
            lit_model_creator: Function to create model. NOTE(review): called
                here with a ``use_lr_scheduler`` keyword, unlike in
                ``evaluate`` — confirm the creator accepts that keyword.
            dataset: Full dataset
            transforms: Feature transforms
            label_transforms: Label transforms
            target_epochs: Number of epochs to train
            project_name: Wandb project name

        Returns:
            Tuple of (checkpoint_path, fitted_transforms, fitted_label_transforms)
        """
        logger.info("Training final model on full dataset...")
        # Use all indices for training the final model
        all_indices = list(range(len(dataset)))
        # Use dataset's create_train_val_datasets to properly fit transforms on full data;
        # all indices go to the train split and the validation split is empty.
        if transforms is not None or label_transforms is not None:
            train_dataset, _ = dataset.create_train_val_datasets(
                train_indices=all_indices,
                val_indices=[],  # Empty validation set
                transforms=transforms,
                label_transforms=label_transforms,
            )
        else:
            # No transforms, use dataset directly
            train_dataset = dataset
        # Create dataloader; graph datasets need the PyG loader, and the
        # model input dim comes from a sample (`.x` vs first tuple element).
        if isinstance(dataset, (CellGNNMILDataset, PatchGNNMILDataset)):
            dataloader = DataLoaderPyG(
                train_dataset,  # type: ignore
                batch_size=1,
                shuffle=True,
                num_workers=8,
            )
            sample_data = train_dataset[0]
            model = lit_model_creator(sample_data.x.shape[1], use_lr_scheduler=False)  # type: ignore
        else:
            dataloader = DataLoaderTorch(
                train_dataset,  # type: ignore
                batch_size=1,
                shuffle=True,
                num_workers=8,
            )
            sample_data = train_dataset[0]
            model = lit_model_creator(sample_data[0].shape[1], use_lr_scheduler=False)
        # Checkpoint on a train-split metric: there is no validation set here.
        is_surv = is_survival_model(model)
        monitor_metric = "train/c_index" if is_surv else "train/f1"
        mode = "max"
        checkpoint_callback = ModelCheckpoint(
            monitor=monitor_metric,
            mode=mode,
            save_top_k=1,
            dirpath=f"./temp_checkpoints/{name}/final_model",
            filename="final",
        )
        wandb_logger = WandbLogger(
            project=project_name,
            name=f"FINAL_{name}_{time.strftime('%Y-%m-%d_%H-%M-%S')}",
            tags=["final"],
        )
        # No EarlyStopping: the epoch budget was fixed by cross-validation.
        trainer = Trainer(
            max_epochs=target_epochs,
            accelerator="gpu",
            devices=[0],
            log_every_n_steps=1,
            logger=wandb_logger,
            callbacks=[checkpoint_callback],
            enable_progress_bar=False,
        )
        # Train without validation (full dataset)
        trainer.fit(model, dataloader)
        wandb.finish()
        # Extract fitted transforms from the training dataset so they can be
        # stored alongside the final checkpoint for inference-time reuse.
        fitted_transforms = getattr(train_dataset, "transforms", None)
        fitted_label_transforms = getattr(train_dataset, "label_transforms", None)
        return str(checkpoint_callback.best_model_path), fitted_transforms, fitted_label_transforms  # type: ignore
# NOTE(review): dead code — `_plot_sample_features` below is entirely
# commented out and depends on names (plt, sns, random, traceback) that are
# not imported in this file. Delete it or restore it deliberately with its
# imports; keeping it commented out only rots.
# def _plot_sample_features(
# train_dataset: Union[
# CellMILDataset,
# CellGNNMILDataset,
# PatchGNNMILDataset,
# PatchMILDataset,
# SubsetCellGNNMILDataset,
# SubsetPatchGNNMILDataset,
# ],
# test_dataset: Union[
# CellMILDataset,
# CellGNNMILDataset,
# PatchGNNMILDataset,
# PatchMILDataset,
# SubsetCellGNNMILDataset,
# SubsetPatchGNNMILDataset,
# ],
# name: str,
# ) -> None:
# """
# Plot random samples from train and test datasets with their labels and feature heatmaps.
# Args:
# train_dataset: Training dataset for the current fold
# test_dataset: Test dataset for the current fold
# fold_idx: Current fold index
# name: Experiment name for saving plots
# """
# try:
# # Create figure with subplots
# fig, axes = plt.subplots(2, 2, figsize=(15, 12)) # type: ignore
# fig.suptitle(f"Feature Visualization - {name}", fontsize=16) # type: ignore
# # Get random samples from train and test
# train_idx = random.randint(0, len(train_dataset) - 1)
# test_idx = random.randint(0, len(test_dataset) - 1)
# train_sample = cast(tuple[torch.Tensor, int], train_dataset[train_idx])
# test_sample = cast(tuple[torch.Tensor, int], test_dataset[test_idx])
# # Extract features and labels based on dataset type
# if isinstance(
# train_dataset,
# (
# CellGNNMILDataset,
# PatchGNNMILDataset,
# SubsetCellGNNMILDataset,
# SubsetPatchGNNMILDataset,
# ),
# ):
# # For graph datasets
# train_features = cast(np.ndarray[Any, Any], train_sample.x.cpu().numpy()) # type: ignore
# train_label = cast(np.ndarray[Any, Any], train_sample.y.cpu().numpy()) # type: ignore
# test_features = cast(np.ndarray[Any, Any], test_sample.x.cpu().numpy()) # type: ignore
# test_label = cast(np.ndarray[Any, Any], test_sample.y.cpu().numpy()) # type: ignore
# else:
# # For regular MIL datasets
# train_features, train_label = train_sample
# test_features, test_label = test_sample
# train_features = cast(np.ndarray[Any, Any], train_features.cpu().numpy()) # type: ignore
# train_label = cast(int, train_label) # type: ignore
# test_features = cast(np.ndarray[Any, Any], test_features.cpu().numpy()) # type: ignore
# test_label = cast(int, test_label) # type: ignore
# # Ensure features are 2D for heatmap
# if hasattr(train_features, "shape") and len(train_features.shape) == 1:
# raise Exception("Train features have invalid shape.")
# if hasattr(test_features, "shape") and len(test_features.shape) == 1:
# raise Exception("Test features have invalid shape.")
# # Plot train sample
# sns.heatmap( # type: ignore
# train_features[:50], # Limit to first 50 instances for readability
# ax=axes[0, 0],
# cmap="viridis",
# cbar=True,
# xticklabels=False,
# yticklabels=False,
# )
# axes[0, 0].set_title(
# f"Train Sample {train_idx}\nLabel: {train_label}\nShape: {train_features.shape}"
# )
# # Plot feature distribution for train sample
# feature_means = np.mean(train_features, axis=0)
# axes[0, 1].hist(feature_means, bins=30, alpha=0.7, color="blue")
# axes[0, 1].set_title(
# f"Train Sample - Feature Mean Distribution\nMean: {np.mean(feature_means):.3f}"
# )
# axes[0, 1].set_xlabel("Feature Value")
# axes[0, 1].set_ylabel("Frequency")
# # Plot test sample
# sns.heatmap( # type: ignore
# test_features[:50], # Limit to first 50 instances for readability
# ax=axes[1, 0],
# cmap="viridis",
# cbar=True,
# xticklabels=False,
# yticklabels=False,
# )
# axes[1, 0].set_title(
# f"Test Sample {test_idx}\nLabel: {test_label}\nShape: {test_features.shape}"
# )
# # Plot feature distribution for test sample
# feature_means = np.mean(test_features, axis=0)
# axes[1, 1].hist(feature_means, bins=30, alpha=0.7, color="red")
# axes[1, 1].set_title(
# f"Test Sample - Feature Mean Distribution\nMean: {np.mean(feature_means):.3f}"
# )
# axes[1, 1].set_xlabel("Feature Value")
# axes[1, 1].set_ylabel("Frequency")
# plt.tight_layout()
# # Save plot
# plot_dir = Path(f"./plots/{name}")
# plot_dir.mkdir(parents=True, exist_ok=True)
# plot_path = plot_dir / f"{name}_sample_visualization.png"
# plt.savefig(plot_path, dpi=300, bbox_inches="tight") # type: ignore
# plt.close(fig)
# logger.info(f"Sample visualization saved to: {plot_path}")
# except Exception as e:
# logger.error(f"Exception occurred: {e}\n{traceback.format_exc()}")
# raise RuntimeError(f"Error during sample feature plotting: {e}")