#!/usr/bin/env python3
from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path
from typing import Any, Dict, List

import numpy as np

# Make the repository root importable so the `scripts` package resolves
# when this file is executed directly rather than as a module.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

from scripts.run_random_baseline import run_random_baseline
from scripts.run_surrogate_baseline import run_surrogate_baseline
def _average_metric_dict(records: List[Dict[str, float]]) -> Dict[str, float]:
    """Average per-budget metric dicts across runs, keyed by budget step."""
    if not records:
        return {}
    # Keys are budget indices stored as strings (e.g. "1", "5"); sort numerically.
    keys = sorted({key for record in records for key in record}, key=lambda value: int(value))
    return {
        key: float(np.mean(np.asarray([record[key] for record in records if key in record], dtype=np.float32)))
        for key in keys
    }
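# A minimal sketch of the averaging behaviour, assuming the per-run dicts map
# string budget indices to regret values (the concrete keys depend on what the
# baseline runners emit; values here are chosen to be exact in float32):
#
#     _average_metric_dict([{"1": 0.5, "5": 0.25}, {"1": 0.25, "5": 0.75}])
#     # -> {"1": 0.375, "5": 0.5}
#
# Keys missing from a record are skipped for that record, so each key's mean
# is taken over however many runs actually report it.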
def _summarize_runs(runs: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate per-task baseline runs into a single summary dict."""
    mean_regret_records = [run["aggregate_metrics"].get("mean_regret_at", {}) for run in runs]
    median_regret_records = [run["aggregate_metrics"].get("median_regret_at", {}) for run in runs]
    # Drop missing entries so np.mean never sees None values.
    auc_values = [
        value for value in (run["aggregate_metrics"].get("mean_auc_regret") for run in runs) if value is not None
    ]
    oracle_hit_values = [
        value for value in (run["aggregate_metrics"].get("oracle_hit_rate_final") for run in runs) if value is not None
    ]
    return {
        "mean_regret_at": _average_metric_dict(mean_regret_records),
        "median_regret_at": _average_metric_dict(median_regret_records),
        "mean_best_so_far_auc": float(np.mean(np.asarray(auc_values, dtype=np.float32))) if auc_values else None,
        "mean_oracle_hit_rate_final": float(np.mean(np.asarray(oracle_hit_values, dtype=np.float32))) if oracle_hit_values else None,
    }
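# For reference, each run dict above is assumed to carry an "aggregate_metrics"
# mapping shaped roughly like the following (the exact keys are defined by
# scripts/run_random_baseline.py and scripts/run_surrogate_baseline.py; the
# values here are illustrative only):
#
#     {
#         "aggregate_metrics": {
#             "mean_regret_at": {"1": 0.4, "3": 0.2, "6": 0.1},
#             "median_regret_at": {"1": 0.35, "3": 0.15, "6": 0.05},
#             "mean_auc_regret": 0.21,
#             "oracle_hit_rate_final": 0.6,
#         },
#         ...
#     }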
def _evaluate_section(
    section_name: str,
    split: Dict[str, Any],
    measurement_path: str,
    episodes: int,
    budget: int,
    seed: int,
    acquisition: str,
    beta: float,
    xi: float,
) -> Dict[str, Any]:
    """Run both baselines over one holdout split and summarize the results."""
    train_tasks = split["train_tasks"]
    test_tasks = split["test_tasks"]
    random_runs: List[Dict[str, Any]] = []
    surrogate_runs: List[Dict[str, Any]] = []
    for idx, task in enumerate(test_tasks):
        # Offset the seed per task so episodes on different tasks do not share RNG streams.
        task_seed = seed + idx * 1000
        random_runs.append(
            run_random_baseline(
                task=task,
                episodes=episodes,
                budget=budget,
                seed=task_seed,
                measurement_path=measurement_path,
            )
        )
        surrogate_runs.append(
            run_surrogate_baseline(
                task=task,
                episodes=episodes,
                budget=budget,
                seed=task_seed,
                measurement_path=measurement_path,
                train_task_ids=train_tasks,
                acquisition=acquisition,
                beta=beta,
                xi=xi,
            )
        )
    return {
        "section": section_name,
        "train_tasks": train_tasks,
        "test_tasks": test_tasks,
        "random_summary": _summarize_runs(random_runs),
        "surrogate_summary": _summarize_runs(surrogate_runs),
        "task_runs": {
            "random": random_runs,
            "surrogate": surrogate_runs,
        },
    }
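# A sketch of the split mapping consumed above, inferred from the key accesses
# (the task identifiers are hypothetical placeholders):
#
#     {
#         "train_tasks": ["gemm_128", "gemm_256"],
#         "test_tasks": ["gemm_512"],
#     }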
def main() -> None:
    parser = argparse.ArgumentParser(description="Evaluate random vs surrogate on shape and family holdout splits.")
    parser.add_argument("--measurement-path", type=str, default="data/autotune_measurements.csv")
    parser.add_argument("--splits", type=Path, default=Path("data/benchmark_splits.json"))
    parser.add_argument("--episodes", type=int, default=20)
    parser.add_argument("--budget", type=int, default=6)
    parser.add_argument("--seed", type=int, default=2)
    parser.add_argument("--acquisition", choices=("mean", "ucb", "ei"), default="ucb")
    parser.add_argument("--beta", type=float, default=2.0)
    parser.add_argument("--xi", type=float, default=0.0)
    parser.add_argument("--output", type=Path, default=Path("outputs/generalization_eval.json"))
    args = parser.parse_args()
    splits = json.loads(args.splits.read_text(encoding="utf-8"))
    sections = {
        "shape_generalization": splits["shape_generalization"],
        "family_holdout": splits["family_holdout"],
    }
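    # The splits file is expected to provide both holdout sections at the top
    # level, each in the per-section shape sketched after _evaluate_section:
    #
    #     {
    #         "shape_generalization": {"train_tasks": [...], "test_tasks": [...]},
    #         "family_holdout": {"train_tasks": [...], "test_tasks": [...]}
    #     }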
    results = {
        name: _evaluate_section(
            section_name=name,
            split=section,
            measurement_path=args.measurement_path,
            episodes=args.episodes,
            budget=args.budget,
            seed=args.seed,
            acquisition=args.acquisition,
            beta=args.beta,
            xi=args.xi,
        )
        for name, section in sections.items()
    }
    summary = {
        "measurement_path": args.measurement_path,
        "splits_path": str(args.splits),
        "episodes": args.episodes,
        "budget": args.budget,
        "acquisition": args.acquisition,
        "beta": args.beta,
        "xi": args.xi,
        "results": results,
    }
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as handle:
        json.dump(summary, handle, indent=2)
    print(json.dumps(summary, indent=2))


if __name__ == "__main__":
    main()
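# Example invocation (the script's filename is an assumption; adjust to
# wherever this file lives under scripts/):
#
#     python scripts/evaluate_generalization.py \
#         --splits data/benchmark_splits.json \
#         --episodes 20 --budget 6 --acquisition ucb --beta 2.0 \
#         --output outputs/generalization_eval.json
#
# The summary is written to --output and echoed to stdout as JSON.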