Spaces:
Runtime error
Runtime error
"""Hyena stack benchmark — measure TPS under the four knob combinations.

Produces the table requested in Task 4:

    | Config                     | TPS | BPB@500 | VRAM |
    |----------------------------|-----|---------|------|
    | B=8,  no flash, no cache   | ... | ...     | ...  |  <-- baseline
    | B=16, no flash, no cache   | ... | ...     | ...  |
    | B=16, no flash, cache on   | ... | ...     | ...  |
    | B=16, flash on, cache on   | ... | ...     | ...  |  <-- best

The four rows correspond to ``--config baseline``, ``b16``, ``cache`` and
``kernel``. Note that ``kernel`` also raises HYDRA_HTM_SUBSAMPLE to 128, so it
differs from ``cache`` in two knobs, not only in the flash-FFT kernel.

Run ONE config per invocation, then collate the results externally. Each
invocation runs train.py for the given wall-clock budget (exported as
HYDRA_TIME_BUDGET) with that config's env overrides, writes train.py's
combined stdout/stderr to ``run_bench_<config>.log`` in the repo root (or to
the path given with ``--log``), and parses that log into one summary line.

Invocation:

    cd /home/mikeb/work/feather
    # On the RTX 3060 (local validation only — these numbers will NOT hit
    # the 200k tps production floor):
    .venv/bin/python scripts/benchmark_hyena_stack.py --config baseline --time 300
    .venv/bin/python scripts/benchmark_hyena_stack.py --config b16 --time 300
    .venv/bin/python scripts/benchmark_hyena_stack.py --config cache --time 300
    # The "kernel" config requires flashfftconv to be built — see
    # kernels/cuda/flashfftconv/README.md
    .venv/bin/python scripts/benchmark_hyena_stack.py --config kernel --time 300
    # On A100/A10G (production cloud hardware), use --time 900 (15 min) for
    # stable steady-state numbers.

After a successful run the script prints:

    BENCHMARK config=<name> tps_steady=<median> bpb_at_500=<val> vram_peak=<n>MiB steps=<n>

Collate those lines into the matrix table manually, then pick the winner
for the 6-hour production run (HYDRA_TIME_BUDGET=21600).
"""
# The future import must follow the module docstring (not precede it);
# otherwise the string above would be a plain expression and __doc__ is None.
from __future__ import annotations
| import argparse | |
| import os | |
| import re | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
# Repository root: this script lives in <repo>/scripts/, so go two levels up.
REPO = Path(__file__).resolve().parents[1]

# Knob values every benchmark config starts from; each CONFIGS entry only
# spells out what differs. ``{**base, key: value}`` keeps the base key order
# (an overridden key stays in place, a new key is appended), so every printed
# override dict lists the knobs in the same order.
_DEFAULT_KNOBS = {
    "HYDRA_BATCH_SIZE": "16",
    "HYDRA_HYENA_LAYERS": "3,7",
    "HYDRA_HYENA_FLASH_FFT": "0",
    "HYDRA_HYENA_TRAIN_CACHE": "0",
    "HYDRA_HYENA_FILTER_CACHE": "0",
}
# Both Hyena caches switched on together.
_CACHES_ON = {
    "HYDRA_HYENA_TRAIN_CACHE": "1",
    "HYDRA_HYENA_FILTER_CACHE": "1",
}

CONFIGS = {
    # Baseline: B=8, no flash, no train-cache. Current reference point.
    "baseline": {**_DEFAULT_KNOBS, "HYDRA_BATCH_SIZE": "8"},
    # B=16, everything else off (a fresh copy, not an alias of the base).
    "b16": dict(_DEFAULT_KNOBS),
    # B=16 with train/filter caches enabled.
    "cache": {**_DEFAULT_KNOBS, **_CACHES_ON},
    # B=16, caches on, flash-FFT kernel on.
    "kernel": {
        **_DEFAULT_KNOBS,
        **_CACHES_ON,
        "HYDRA_HYENA_FLASH_FFT": "1",
        # Task 4 note: also bump HYDRA_HTM_SUBSAMPLE to 128 (from 64) in the
        # best config to get more aggressive reclamation.
        "HYDRA_HTM_SUBSAMPLE": "128",
    },
}
def build_env(cfg_overrides: dict) -> dict:
    """Return a copy of this process's environment with *cfg_overrides* applied.

    HYDRA_HYENA_LAYERS is guaranteed to exist in the result: when neither the
    inherited environment nor the overrides define it, it is set to the empty
    string (which the original comment describes as "off"). Override values
    take precedence over inherited ones. ``os.environ`` itself is not modified.
    """
    merged = dict(os.environ)
    # Make sure the key is present even when no override supplies it.
    if "HYDRA_HYENA_LAYERS" not in merged:
        merged["HYDRA_HYENA_LAYERS"] = ""
    merged.update(cfg_overrides)
    return merged
def parse_step_line(line: str) -> dict | None:
    """Parse a ``step=...`` log line into a dict of numeric metrics.

    Every ``key=value`` pair whose value starts with numeric-looking characters
    (digits, ``.``, ``e``/``E``, ``+``, ``-``) is collected as ``{key: float}``.
    A pair whose captured value is not a valid float (e.g. ``phase=eval``
    captures just ``"e"``) is skipped on its own; previously such a token made
    the whole line return None and silently dropped its tps/bpb/vram values.

    Args:
        line: One log line, already stripped by the caller.

    Returns:
        The metrics dict (possibly empty), or None when *line* does not
        start with ``step=``.
    """
    if not line.startswith("step="):
        return None
    metrics: dict[str, float] = {}
    for key, raw in re.findall(r"(\w+)=([0-9.eE+\-]+)", line):
        try:
            metrics[key] = float(raw)
        except ValueError:
            # Tokenizer false positive (e.g. "e", "-", "1e"); ignore just this pair.
            continue
    return metrics
def summarize(log_path: Path, warmup_steps: int = 50) -> dict:
    """Summarize a train.py log into steady-state benchmark metrics.

    Every line starting with ``step=`` is parsed with ``parse_step_line``.
    Lines whose step number is below *warmup_steps* are excluded from the TPS,
    VRAM and fallback-BPB statistics, to discard CUDA graph capture / autotune
    spikes.

    Args:
        log_path: Combined stdout/stderr capture of a train.py run.
        warmup_steps: Step lines with ``step < warmup_steps`` are ignored for
            TPS, VRAM and the fallback BPB.

    Returns:
        A dict with:

        - ``tps_steady``: median of the post-warmup ``tps`` values (for an
          even count, the mean of the two middle values).
        - ``bpb_at_500``: ``bpb`` from the first logged line with step >= 500.
          This is captured even inside the warmup window, and it tolerates log
          intervals that never land exactly on step 500. If no such line
          exists, the last post-warmup ``bpb`` is used; 0.0 if there is none.
        - ``vram_peak``: the largest post-warmup ``vram`` value (units are
          whatever train.py prints; main() labels it MiB).
        - ``steps``: the highest step number found in the log (0 if none).
          The old ``count + warmup_steps`` formula was only right when
          train.py logs exactly one line per step.

        If no post-warmup ``tps`` value is found, tps_steady, bpb_at_500 and
        vram_peak are all 0.0.
    """
    import statistics  # function-scope import keeps the file's import block unchanged

    tps_vals: list[float] = []
    last_bpb = None
    bpb_at_500 = None
    vram_peak = 0.0
    max_step = 0
    # errors="replace": a stray non-UTF-8 byte in train.py output (progress bars,
    # native tracebacks) must not crash the summary at the end of a long run.
    with log_path.open(encoding="utf-8", errors="replace") as f:
        for line in f:
            d = parse_step_line(line.strip())
            if d is None:
                continue
            step = int(d.get("step", -1))
            max_step = max(max_step, step)
            bpb = d.get("bpb")
            # Recorded before the warmup filter so a large warmup_steps cannot hide it.
            if bpb is not None and bpb_at_500 is None and step >= 500:
                bpb_at_500 = bpb
            if step < warmup_steps:
                continue
            tps = d.get("tps")
            if tps is not None:
                tps_vals.append(tps)
            if bpb is not None:
                last_bpb = bpb
            vram = d.get("vram")
            if vram is not None and vram > vram_peak:
                vram_peak = vram
    if not tps_vals:
        return {"tps_steady": 0.0, "bpb_at_500": 0.0, "vram_peak": 0.0, "steps": max_step}
    if bpb_at_500 is None:
        # `is None` (not truthiness) so a legitimate 0.0 value is never replaced.
        bpb_at_500 = last_bpb if last_bpb is not None else 0.0
    return {
        "tps_steady": statistics.median(tps_vals),
        "bpb_at_500": bpb_at_500,
        "vram_peak": vram_peak,
        "steps": max_step,
    }
def main() -> int:
    """CLI entry point: run train.py under one benchmark config and summarize it.

    Parses ``--config`` / ``--time`` / ``--log`` and builds the child
    environment from the selected CONFIGS entry plus HYDRA_TIME_BUDGET. It then
    runs train.py from the repo root, writing its combined stdout/stderr to the
    log file, and prints a BENCHMARK summary line parsed from that log.

    Returns:
        0 on success; train.py's non-zero return code on failure, after
        printing ``BENCH FAIL``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--config", required=True, choices=list(CONFIGS))
    ap.add_argument("--time", type=int, default=300, help="training seconds")
    ap.add_argument("--log", default=None, help="output log path (default: run_bench_<cfg>.log)")
    args = ap.parse_args()
    if args.time <= 0:
        ap.error("--time must be a positive number of seconds")

    cfg = CONFIGS[args.config]
    log_path = Path(args.log or (REPO / f"run_bench_{args.config}.log"))
    env = build_env(cfg)
    env["HYDRA_TIME_BUDGET"] = str(args.time)

    # Make the config visible up-front so failed runs are debuggable.
    print(f"BENCH start config={args.config} time={args.time}s log={log_path}", flush=True)
    print(f"  overrides: {cfg}", flush=True)

    # A user-supplied --log may point into a directory that does not exist yet.
    log_path.parent.mkdir(parents=True, exist_ok=True)
    with log_path.open("w") as logf:
        # Use the interpreter running this script (e.g. .venv/bin/python) rather
        # than whatever "python" PATH resolves to, which may lack the project's
        # dependencies.
        proc = subprocess.run(
            [sys.executable or "python", "-u", str(REPO / "train.py")],
            env=env,
            cwd=str(REPO),
            stdout=logf,
            stderr=subprocess.STDOUT,
            check=False,
        )
    print(f"BENCH wait_done exit={proc.returncode}", flush=True)
    if proc.returncode != 0:
        print(f"BENCH FAIL config={args.config}", flush=True)
        return proc.returncode

    summary = summarize(log_path)
    print(
        f"BENCHMARK config={args.config} "
        f"tps_steady={summary['tps_steady']:.0f} "
        f"bpb_at_500={summary['bpb_at_500']:.4f} "
        f"vram_peak={summary['vram_peak']:.0f}MiB "
        f"steps={summary['steps']}",
        flush=True,
    )
    return 0
if __name__ == "__main__":
    # Turn main()'s integer result into the process exit status.
    raise SystemExit(main())