"""Source code for cellmil.utils.wandb."""

import wandb
import re
from typing import Any, cast
from collections import defaultdict
from dataclasses import dataclass
from cellmil.utils import logger

# Column names for tabulating experiment results.
# NOTE(review): these constants are not referenced anywhere in this module;
# presumably they are consumed by reporting/analysis code elsewhere — confirm.
COLUMN_EXPERIMENT_ID = "EXPERIMENT_ID"
COLUMN_TASK = "TASK"
COLUMN_FEATURES = "FEATURES"
COLUMN_MODEL = "MODEL"
COLUMN_REG = "REGULARIZATION"
COLUMN_STRA = "STRATIFICATION"


@dataclass
class ExperimentComponents:
    """Parsed components of an experiment ID (TASK+FEATURES+MODEL+REG+STRA)."""

    task: str
    features: str
    model: str
    regularization: str  # "REG" or "*"
    stratification: str  # "STRA" or "*"

    @property
    def has_regularization(self) -> bool:
        """Whether the regularization slot is set (i.e. not the '*' wildcard)."""
        wildcard = "*"
        return self.regularization != wildcard

    @property
    def has_stratification(self) -> bool:
        """Whether the stratification slot is set (i.e. not the '*' wildcard)."""
        wildcard = "*"
        return self.stratification != wildcard
class WandbClient:
    """Client for retrieving and processing wandb runs.

    Wraps the wandb public API for one team and a set of projects, with
    helpers for filtering runs, parsing experiment identifiers, and reading
    validation metrics from run histories.
    """
[docs] def __init__(self, team: str, projects: list[str], tasks: list[str] | None = None): """Initialize the WandB client. Args: team: Team name where the projects belong in wandb. projects: List of project names to retrieve runs from. tasks: Optional list of valid tasks for filtering runs. """ self.team = team self.projects = projects self.tasks = tasks try: self.api = wandb.Api(timeout=10000) except Exception as e: raise RuntimeError(f"Failed to initialize W&B API: {e}") from e
[docs] def get_runs(self, preprocess: bool = True) -> list[Any]: """Retrieve wandb runs for configured projects and team. Args: preprocess: Whether to preprocess the runs (default: True). Returns: List of wandb runs (preprocessed if requested). """ runs: list[Any] = [] inaccessible: list[str] = [] for project in self.projects: project_path = f"{self.team}/{project}" try: project_runs = cast(Any, self.api.runs(project_path)) runs.extend(project_runs) except Exception as e: inaccessible.append(project_path) print(f"Warning: cannot access project '{project_path}': {e}") if inaccessible and not runs: raise RuntimeError( f"No accessible projects. Inaccessible: {', '.join(inaccessible)}" ) elif inaccessible: print(f"Partial access. Inaccessible projects: {', '.join(inaccessible)}") return self._preprocess_runs(runs) if preprocess else runs
def _preprocess_runs(self, runs: list[Any]) -> list[Any]: # Start with all runs new_runs = runs logger.info(f"Starting preprocessing with {len(new_runs)} total runs") # Filter out crashed or failed runs logger.info("Filtering runs with crashed or failed state") initial_count = len(new_runs) new_runs = [run for run in new_runs if run.state not in ["crashed", "failed"]] filtered_count = initial_count - len(new_runs) logger.info( f"Preprocessing runs: filtered out {filtered_count} crashed/failed runs, remaining runs {len(new_runs)}" ) logger.info("Filtering runs, which task is ADENO or any other invalid task") filtered_runs: list[Any] = [] for run in new_runs: try: exp_id = self.get_experiment_id(run) if self.get_task(exp_id): filtered_runs.append(run) except ValueError: # Skip runs that don't have a valid task continue new_runs = filtered_runs logger.info( f"Preprocessing runs: filtered out ADENO runs, remaining runs {len(new_runs)}" ) # Check that each experiment ID has exactly 6 runs (6-fold cross-validation including FINAL) logger.info( "Checking that each experiment ID has exactly 6 runs (including FINAL)" ) experiment_counts: dict[str, list[Any]] = defaultdict(list) for run in new_runs: exp_id = self.get_experiment_id(run) experiment_counts[exp_id].append(run) expected_count = 6 total_experiments = len(experiment_counts) incorrect_experiments: dict[str, int] = {} for exp_id, exp_runs in experiment_counts.items(): if len(exp_runs) != expected_count: incorrect_experiments[exp_id] = len(exp_runs) if not incorrect_experiments: logger.info( f"All {total_experiments} experiment IDs have exactly {expected_count} runs" ) else: logger.warning( f"{len(incorrect_experiments)} experiment ID(s) do not have exactly {expected_count} runs:" ) for exp_id, count in incorrect_experiments.items(): logger.warning( f" - {exp_id}: {count} runs (expected {expected_count})" ) # Remove runs with incorrect experiment IDs incorrect_experiment_ids = set(incorrect_experiments.keys()) 
new_runs = [ run for run in new_runs if self.get_experiment_id(run) not in incorrect_experiment_ids ] logger.info( f"Removed runs with incorrect experiment IDs. New total runs: {len(new_runs)}" ) # Check FINAL runs for DataLoader worker errors # logger.info("Checking FINAL runs for DataLoader worker errors") # experiments_with_errors: set[str] = set() # for exp_id, exp_runs in experiment_counts.items(): # # Skip if already marked as incorrect # if exp_id in incorrect_experiments: # continue # # Find the FINAL run # final_runs = [run for run in exp_runs if run.name.startswith("FINAL_")] # if not final_runs: # continue # final_run = final_runs[0] # if self._has_dataloader_error(final_run, exp_id): # experiments_with_errors.add(exp_id) # logger.warning( # f" - {exp_id}: DataLoader worker error detected in FINAL run" # ) # if experiments_with_errors: # logger.warning( # f"Found {len(experiments_with_errors)} experiment(s) with DataLoader worker errors in FINAL runs" # ) # # Remove all runs from experiments with errors # new_runs = [ # run # for run in new_runs # if self.get_experiment_id(run) not in experiments_with_errors # ] # logger.info( # f"Removed runs with DataLoader errors. 
New total runs: {len(new_runs)}" # ) # else: # logger.info("No DataLoader worker errors found in FINAL runs") # Now filter out FINAL_ runs after verifying the count logger.info(f"Filtering out FINAL_ runs with total runs {len(new_runs)}") initial_count = len(new_runs) new_runs = [run for run in new_runs if not run.name.startswith("FINAL_")] filtered_final = initial_count - len(new_runs) logger.info( f"Preprocessing runs: filtered out {filtered_final} FINAL_ runs, remaining runs {len(new_runs)}" ) # Summary statistics run_counts = [ len(exp_runs) for exp_id, exp_runs in experiment_counts.items() if exp_id not in incorrect_experiments ] if run_counts: logger.info( f"Total unique experiment IDs: {len(experiment_counts) - len(incorrect_experiments)}" ) logger.info(f"Total runs after filtering: {len(new_runs)}") logger.info( f"Run counts - Min: {min(run_counts)}, Max: {max(run_counts)}, Mean: {sum(run_counts) / len(run_counts):.2f}" ) return new_runs
[docs] def get_experiment_id(self, run: Any) -> str: """Get the experiment ID from a run. Handles both formats: - FOLD_N_EXPERIMENTID_YYYY-MM-DD_HH-MM-SS - FINAL_EXPERIMENTID_YYYY-MM-DD_HH-MM-SS Args: run: The wandb run object Returns: The extracted experiment ID """ name = run.name # Remove FOLD_N_ prefix if present if name.startswith("FOLD_"): # Skip "FOLD_N_" name = name.split("_", 2)[2] if len(name.split("_", 2)) >= 3 else name # Remove FINAL_ prefix if present elif name.startswith("FINAL_"): # Skip "FINAL_" name = name[6:] # Now extract experiment ID (everything before the timestamp _YYYY-MM-DD) # Match everything until we find _YYYY (4 digits starting with 20) match = re.match(r"(.+?)_\d{4}-\d{2}-\d{2}", name) if match: experiment_id = match.group(1) else: # Fallback: use everything before the last timestamp-like pattern experiment_id = name return experiment_id
[docs] @staticmethod def parse_experiment_components(experiment_id: str) -> ExperimentComponents: """Parse an experiment ID into its components. Format: TASK+FEATURES+MODEL+REG+STRA Example: DCR+ALL+ABMIL+REG+* or OS+PYRAD+HEAD4TYPE+*+STRA Args: experiment_id: The experiment ID string Returns: ExperimentComponents with parsed values Raises: ValueError: If experiment ID doesn't match expected format """ parts = experiment_id.split("+") if len(parts) != 5: raise ValueError( f"Experiment ID '{experiment_id}' does not match expected format TASK+FEATURES+MODEL+REG+STRA" ) return ExperimentComponents( task=parts[0], features=parts[1], model=parts[2], regularization=parts[3], stratification=parts[4], )
[docs] def get_task(self, experiment_id: str) -> str | None: """Get the task associated with a given experiment ID. Args: experiment_id: The ID of the experiment Returns: The name of the task, or None if not valid Raises: ValueError: If tasks are specified and the experiment ID doesn't correspond to a known task """ task = experiment_id.split("+")[0] # If no tasks are specified, infer the task from the experiment ID if self.tasks is None: return task # If tasks are specified, validate against the list if task in self.tasks: return task else: raise ValueError( f"Experiment ID '{experiment_id}' does not correspond to a known task." )
[docs] def _has_dataloader_error(self, run: Any, exp_id: str) -> bool: """Check if a run has a DataLoader worker error in its logs. Args: run: A wandb run object exp_id: The experiment ID to look for in the error message Returns: True if the error is found, False otherwise """ try: # Get the run's logs logs = run.history(keys=["_log"]) if logs.empty or "_log" not in logs.columns: return False # Check each log entry for the error pattern for log_entry in logs["_log"].dropna(): if ( "ERROR - Error in experiment" in str(log_entry) and exp_id in str(log_entry) and "DataLoader worker" in str(log_entry) and "exited unexpectedly" in str(log_entry) ): return True return False except Exception as e: logger.warning(f"Could not check logs for run {run.name}: {e}") return False
[docs] @staticmethod def get_metric(run: Any, metric: str) -> float: """Get the highest validation metric across all epochs for a given run. Args: run: A wandb run object metric: The metric name (e.g., "f1", "c_index", "balacc") Returns: The highest validation metric score Raises: ValueError: If the metric is not found or has no valid values """ history = run.history(keys=[f"val/{metric}"]) if history.empty or f"val/{metric}" not in history.columns: raise ValueError( f"Run {run.name} has no 'val/{metric}' metric in its history" ) # Drop NaN values and get the maximum metric_values = history[f"val/{metric}"].dropna() if metric_values.empty: raise ValueError( f"Run {run.name} has no valid 'val/{metric}' values in its history" ) max_metric = float(metric_values.max()) return max_metric