|
|
| """
|
| Benchmark Circuit transformer family against standard LM tasks.
|
|
|
| Usage:
|
| # Single model
|
| python -m circuits.bench --checkpoint circuits/checkpoints/slot_local_mirrored/best.pt --gpu 0
|
|
|
| # Compare all architectures
|
| python -m circuits.bench --compare --gpu 0
|
|
|
| # Quick sanity check (100 samples per task)
|
| python -m circuits.bench --compare --gpu 0 --limit 100
|
|
|
| # Specific tasks
|
| python -m circuits.bench --checkpoint path/to/best.pt --tasks hellaswag,lambada_openai
|
| """
|
|
|
| import argparse
|
| import json
|
| import time
|
| import torch
|
| from pathlib import Path
|
|
|
| import lm_eval
|
| from lm_eval.api.registry import register_model
|
|
|
| from .lm_eval_wrapper import CircuitLM
|
|
|
|
|
# Register the Circuit wrapper with lm-eval at import time so that
# lm_eval.simple_evaluate(model="circuit", ...) resolves to CircuitLM.
register_model("circuit")(CircuitLM)

# Default benchmark suite passed to lm-eval (comma-separated task names):
# a mix of multiple-choice accuracy tasks and perplexity tasks.
DEFAULT_TASKS = "arc_challenge,arc_easy,boolq,hellaswag,lambada_openai,piqa,wikitext,winogrande"

# Known architecture variants for --compare mode, mapping a human-readable
# model name to its best-checkpoint path (relative to the repo root).
# Missing paths are skipped at runtime rather than treated as errors.
CHECKPOINTS = {
    "standard_12L": "circuits/checkpoints/flat/best.pt",
    "mirrored_9L_wide": "circuits/checkpoints/hier_wide_2/best.pt",
    "mirrored_15L_deep": "circuits/checkpoints/hier_resized/best.pt",
    "slot_local_mirrored": "circuits/checkpoints/slot_local_mirrored/best.pt",
}
|
|
|
|
|
def run_benchmark(checkpoint: str, tasks: str, device: str, limit: int | None = None, batch_size: int = 1, compile: bool = False):
    """Run lm-eval on a single checkpoint.

    Args:
        checkpoint: Path to the model checkpoint (.pt) file.
        tasks: Comma-separated lm-eval task names.
        device: Torch device string, e.g. "cuda:0".
        limit: Optional cap on samples per task (None evaluates the full task).
        batch_size: Evaluation batch size forwarded to the model wrapper.
        compile: Whether the wrapper should torch.compile the model.
            NOTE: shadows the builtin `compile`; name kept to mirror --compile.

    Returns:
        The raw results dict produced by lm_eval.simple_evaluate.
    """
    model_args = (
        f"checkpoint={checkpoint},device={device},"
        f"batch_size={batch_size},compile={'true' if compile else 'false'}"
    )

    # Strip whitespace so both "a,b" and "a, b" are accepted task lists.
    task_list = [t.strip() for t in tasks.split(",") if t.strip()]

    return lm_eval.simple_evaluate(
        model="circuit",
        model_args=model_args,
        tasks=task_list,
        limit=limit,
    )
|
|
|
|
|
def extract_scores(results: dict) -> dict:
    """Pull headline metrics from lm-eval results.

    For each task, the first metric found in priority order is used:
    normalized accuracy, raw accuracy, perplexity, word perplexity.
    """
    # Most-preferred metric first; the first hit wins per task.
    metric_priority = (
        "acc_norm,none",
        "acc,none",
        "perplexity,none",
        "word_perplexity,none",
    )
    scores: dict = {}
    for task_name, task_results in results.get("results", {}).items():
        for metric in metric_priority:
            if metric in task_results:
                scores[task_name] = task_results[metric]
                break
    return scores
|
|
|
|
|
def print_comparison(all_results: dict, tasks: list):
    """Pretty-print a model-vs-task comparison table to stdout.

    Args:
        all_results: Mapping of model name -> {task name: score}.
        tasks: Ordered task names to show as columns; a missing score
            renders as "N/A" and is excluded from the row average.
    """
    # Guard: with no results, max() over names below would raise ValueError
    # (this happens in --compare mode when no checkpoint files exist).
    if not all_results:
        print("\nNo results to compare.")
        return

    # Size columns to the longest task / model name plus padding.
    col_width = max((len(t) for t in tasks), default=0) + 2
    name_width = max(len(n) for n in all_results) + 2

    header = f"{'Model':<{name_width}}"
    for task in tasks:
        header += f"{task:>{col_width}}"
    header += f"{' avg':>8}"
    print("\n" + "=" * len(header))
    print(header)
    print("-" * len(header))

    for name, scores in all_results.items():
        row = f"{name:<{name_width}}"
        vals = []
        for task in tasks:
            val = scores.get(task, None)
            if val is not None:
                row += f"{val:>{col_width}.4f}"
                vals.append(val)
            else:
                row += f"{'N/A':>{col_width}}"
        # Average only over tasks that actually produced a score.
        avg = sum(vals) / len(vals) if vals else 0
        row += f"{avg:>8.4f}"
        print(row)

    print("=" * len(header))
|
|
|
|
|
def _compare_all(args, device: str):
    """Evaluate every known checkpoint that exists on disk and print a comparison table."""
    task_list = args.tasks.split(",")
    all_scores = {}
    all_raw = {}

    # Partition known checkpoints into present / missing; evaluate only the
    # present ones and report the rest instead of failing.
    available = {k: v for k, v in CHECKPOINTS.items() if Path(v).exists()}
    missing = {k: v for k, v in CHECKPOINTS.items() if not Path(v).exists()}
    if missing:
        print(f"Skipping (not found): {', '.join(missing.keys())}")

    for name, ckpt_path in available.items():
        print(f"\n{'='*60}")
        print(f"Evaluating: {name}")
        print(f"Checkpoint: {ckpt_path}")
        print(f"{'='*60}")

        t0 = time.time()
        results = run_benchmark(ckpt_path, args.tasks, device, args.limit, args.batch_size, args.compile)
        elapsed = time.time() - t0

        scores = extract_scores(results)
        all_scores[name] = scores
        all_raw[name] = results.get("results", {})
        print(f" Completed in {elapsed:.0f}s: {scores}")

    print_comparison(all_scores, task_list)

    if args.output:
        with open(args.output, "w") as f:
            # default=str: lm-eval results may contain non-JSON-native types.
            json.dump({"scores": all_scores, "raw": all_raw}, f, indent=2, default=str)
        print(f"\nResults saved to {args.output}")


def _eval_single(args, device: str):
    """Evaluate one checkpoint and print per-task scores."""
    print(f"Evaluating: {args.checkpoint}")
    t0 = time.time()
    results = run_benchmark(args.checkpoint, args.tasks, device, args.limit, args.batch_size, args.compile)
    elapsed = time.time() - t0

    scores = extract_scores(results)
    print(f"\nResults ({elapsed:.0f}s):")
    for task, score in scores.items():
        print(f" {task}: {score:.4f}")

    if args.output:
        with open(args.output, "w") as f:
            json.dump(results, f, indent=2, default=str)
        print(f"\nResults saved to {args.output}")


def main():
    """Parse CLI arguments and dispatch to compare / single-checkpoint mode."""
    parser = argparse.ArgumentParser(description="Benchmark Circuit transformers")
    parser.add_argument("--checkpoint", type=str, help="Path to single checkpoint")
    parser.add_argument("--compare", action="store_true", help="Compare all known architectures")
    parser.add_argument("--tasks", type=str, default=DEFAULT_TASKS, help="Comma-separated task list")
    parser.add_argument("--gpu", type=int, default=0, help="GPU index")
    parser.add_argument("--limit", type=int, default=None, help="Limit samples per task (for quick testing)")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size")
    parser.add_argument("--output", type=str, default=None, help="Save results to JSON")
    parser.add_argument("--compile", action="store_true", help="torch.compile models for faster inference")
    args = parser.parse_args()

    device = f"cuda:{args.gpu}"

    if args.compare:
        _compare_all(args, device)
    elif args.checkpoint:
        _eval_single(args, device)
    else:
        # Neither mode selected: show usage instead of doing nothing silently.
        parser.print_help()
|
|
|
|
|
# Script entry point: run the CLI when executed as a module/script.
if __name__ == "__main__":
    main()
|
|
|