"""Diagnostic script for torch.compile deadlock after ~500 steps.

F17 investigation: validates that the _compiled_core / forward split
fixes the deadlock by running forward+backward loops with compile on.

Usage:
    LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64 \
    HYDRA_TIME_BUDGET=30 HYDRA_BATCH_SIZE=8 HYDRA_TOTAL_BATCH=16384 \
    HYDRA_HTM_LEARN_EVERY=4 HYDRA_HESTIA_INTERVAL=9999 \
    .venv/bin/python -u scripts/compile_debug.py [mode]

Modes:
    eager       - no compile (baseline)
    model_only  - compile the model's _compiled_core only
    muon_only   - compile the Muon optimizer step only
    both        - compile both (default)
"""
from __future__ import annotations
import gc
import os
import signal
import sys
import threading
import time

# Set CUDA env before torch import
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")

import torch
import torch.nn as nn
import torch.nn.functional as F
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
MAX_STEPS = 800
WATCHDOG_TIMEOUT_S = 20  # kill if no progress for this many seconds
BATCH_SIZE = int(os.environ.get("HYDRA_BATCH_SIZE", "8"))
SEQ_LEN = 2048
VOCAB_SIZE = 8192
# ---------------------------------------------------------------------------
# Watchdog thread: kills the process if no progress
# ---------------------------------------------------------------------------
_last_progress = time.time()
_watchdog_armed = True


def _watchdog_fn():
    global _last_progress, _watchdog_armed
    while _watchdog_armed:
        time.sleep(1.0)
        elapsed = time.time() - _last_progress
        if elapsed > WATCHDOG_TIMEOUT_S:
            print(
                f"\n*** WATCHDOG: no progress for {elapsed:.1f}s — DEADLOCK DETECTED ***",
                flush=True,
            )
            _dump_diagnostics()
            # SIGTERM is translated into SystemExit by the handler installed
            # in main(), so the per-mode try/except can record the failure.
            os.kill(os.getpid(), signal.SIGTERM)
            return
def _dump_diagnostics():
    """Dump CUDA/dynamo state at deadlock time."""
    try:
        stats = torch.cuda.memory_stats()
        print(f"  alloc_retries:   {stats.get('num_alloc_retries', 'N/A')}")
        print(f"  allocated_bytes: {stats.get('allocated_bytes.all.current', 0) / 1e6:.1f} MB")
        print(f"  reserved_bytes:  {stats.get('reserved_bytes.all.current', 0) / 1e6:.1f} MB")
        print(f"  num_ooms:        {stats.get('num_ooms', 0)}")
    except Exception as e:
        print(f"  (memory_stats failed: {e})")
    try:
        import torch._dynamo.utils as du

        print(f"  dynamo counters: {dict(du.counters)}")
    except Exception as e:
        print(f"  (dynamo counters failed: {e})")
def tick():
    global _last_progress
    _last_progress = time.time()
# ---------------------------------------------------------------------------
# Test
# ---------------------------------------------------------------------------
def run_test(mode: str) -> dict:
    """Run a forward+backward loop with the specified compile config."""
    print(f"\n{'=' * 70}")
    print(f"TEST MODE: {mode}")
    print(f"{'=' * 70}", flush=True)

    compile_model = mode in ("model_only", "both")
    compile_muon = mode in ("muon_only", "both")
    os.environ["HYDRA_MODEL_COMPILE"] = "1" if compile_model else "0"
    os.environ["HYDRA_MUON_COMPILE"] = "1" if compile_muon else "0"
    os.environ["HYDRA_ASYNC_POSTPROCESS"] = "0"
    os.environ["HYDRA_HESTIA_INTERVAL"] = "9999"
    os.environ["HYDRA_HTM_LEARN_EVERY"] = "4"

    # Clear cached hydra modules so the env vars above are re-read on import
    for mod_name in list(sys.modules.keys()):
        if mod_name.startswith("hydra."):
            del sys.modules[mod_name]
    torch._dynamo.reset()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    gc.collect()

    from hydra.model import PostSemClawModel
    from hydra.config import PostSemClawConfig

    device = torch.device("cuda")
    config = PostSemClawConfig(
        d_model=256, n_layer=4, d_state=64, headdim=32, expand=2,
        vocab_size=VOCAB_SIZE, sequence_len=SEQ_LEN,
    )
    # Construct on the meta device, then allocate empty storage on the GPU
    # and initialize weights in place, avoiding a throwaway CPU allocation.
    with torch.device("meta"):
        model = PostSemClawModel(config)
    model.to_empty(device=device)
    model.init_weights()
    optimizer = model.setup_optimizer()
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

    result = {"mode": mode, "max_step": 0, "tps_samples": []}
    alloc_retries_prev = 0
    tick()
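    # Measurement loop: each iteration times one full step (data gen, forward,
    # backward, grad clip, optimizer step) and records tokens/sec plus the
    # allocator retry count; the log flags any step where retries increased.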
    for step in range(MAX_STEPS):
        t0 = time.time()
        x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)
        y = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)

        with autocast_ctx:
            loss = model(x, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)
        torch.cuda.synchronize()

        dt = time.time() - t0
        tps = int(BATCH_SIZE * SEQ_LEN / dt)
        tick()

        stats = torch.cuda.memory_stats()
        retries = stats.get("num_alloc_retries", 0)
        retry_delta = retries - alloc_retries_prev
        alloc_retries_prev = retries
        result["max_step"] = step

        if step % 50 == 0 or retry_delta > 0 or step < 3:
            alloc_mb = stats.get("allocated_bytes.all.current", 0) / 1e6
            print(
                f"  step={step:04d} tps={tps:6d} dt={dt * 1000:.0f}ms "
                f"alloc={alloc_mb:.0f}MB retries={retries}",
                flush=True,
            )
            result["tps_samples"].append((step, tps))

    result["completed"] = True
    print(f"\n  COMPLETED: {MAX_STEPS} steps, mode={mode}", flush=True)
    return result
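# A mode PASSes only if all MAX_STEPS complete without the watchdog firing;
# a watchdog kill surfaces in main() as SystemExit and is reported as FAIL
# along with the last step reached.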
def main():
    global _watchdog_armed

    print(f"torch: {torch.__version__}  CUDA: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Steps: {MAX_STEPS}  Watchdog: {WATCHDOG_TIMEOUT_S}s")

    # Convert the watchdog's SIGTERM into SystemExit so the per-mode
    # try/except below can record the failure before the process exits.
    signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit(1))

    wd = threading.Thread(target=_watchdog_fn, daemon=True)
    wd.start()
    modes = sys.argv[1:] if len(sys.argv) > 1 else ["both"]
    results = []
    for mode in modes:
        try:
            r = run_test(mode)
        except SystemExit:
            print(f"\n  DEADLOCK/KILLED mode={mode}", flush=True)
            r = {"mode": mode, "completed": False, "max_step": "?"}
        except Exception as e:
            print(f"\n  ERROR mode={mode}: {e}", flush=True)
            r = {"mode": mode, "completed": False, "error": str(e)}
        results.append(r)

    print(f"\n{'=' * 70}")
    print("SUMMARY")
    print(f"{'=' * 70}")
    for r in results:
        status = "PASS" if r.get("completed") else "FAIL"
        print(f"  {r['mode']:20s}: {status} (step {r.get('max_step', '?')})")

    _watchdog_armed = False


if __name__ == "__main__":
    main()