omkarrr88 commited on
Commit ·
e2f8b29
0
Parent(s):
Version 1
Browse files- .claude/plan/pytorch-debugger-mvp.md +1647 -0
- .coverage +0 -0
- .dockerignore +13 -0
- .gitignore +14 -0
- .python-version +1 -0
- CLAUDE.md +186 -0
- Dockerfile +24 -0
- PRD.md +367 -0
- README.md +149 -0
- ROADMAP.md +441 -0
- baseline_heuristic.py +186 -0
- deploy.sh +52 -0
- ml-training-debugger-spec.md +0 -0
- ml_training_debugger/__init__.py +3 -0
- ml_training_debugger/client.py +21 -0
- ml_training_debugger/code_templates.py +248 -0
- ml_training_debugger/graders.py +207 -0
- ml_training_debugger/models.py +195 -0
- ml_training_debugger/pytorch_engine.py +240 -0
- ml_training_debugger/reward_engine.py +104 -0
- ml_training_debugger/scenarios.py +155 -0
- ml_training_debugger/simulation.py +225 -0
- openenv.yaml +58 -0
- pyproject.toml +41 -0
- requirements.txt +6 -0
- server/__init__.py +0 -0
- server/_baseline_results.py +27 -0
- server/app.py +287 -0
- server/environment.py +516 -0
- tests/__init__.py +0 -0
- tests/conftest.py +36 -0
- tests/test_code_templates.py +65 -0
- tests/test_episode_lifecycle.py +220 -0
- tests/test_graders.py +168 -0
- tests/test_models.py +168 -0
- tests/test_pytorch_engine.py +93 -0
- tests/test_reward_engine.py +176 -0
- tests/test_scenarios.py +51 -0
- tests/test_simulation.py +72 -0
- tests/test_simulation_extended.py +81 -0
.claude/plan/pytorch-debugger-mvp.md
ADDED
|
@@ -0,0 +1,1647 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Implementation Plan: PyTorch Training Run Debugger — OpenEnv Environment
|
| 2 |
+
|
| 3 |
+
**Generated:** 2026-03-28
|
| 4 |
+
**King File:** `ml-training-debugger-spec.md` — single source of truth for all conflicts
|
| 5 |
+
**Runtime:** Python 3.12 · PyTorch CPU-only · openenv-core (installed in .venv)
|
| 6 |
+
**MVP Scope:** Tasks 1, 3, 5 + rule-based baseline + all required endpoints + Docker + HF Spaces
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
## Markdown Files Confirmed Read
|
| 11 |
+
|
| 12 |
+
| File | Lines | Role |
|
| 13 |
+
|------|-------|------|
|
| 14 |
+
| `ml-training-debugger-spec.md` | 1549 | **KING FILE** — final authority on all design decisions |
|
| 15 |
+
| `CLAUDE.md` | ~280 | Coding standards, non-negotiable rules, reward constants |
|
| 16 |
+
| `PRD.md` | ~368 | Product requirements, success metrics, timeline |
|
| 17 |
+
| `ROADMAP.md` | ~442 | Phased roadmap with acceptance criteria |
|
| 18 |
+
|
| 19 |
+
All four files read in full. The spec is the definitive authority.
|
| 20 |
+
|
| 21 |
+
---
|
| 22 |
+
|
| 23 |
+
## Complete Project Structure (Final State)
|
| 24 |
+
|
| 25 |
+
```
|
| 26 |
+
ML Debugger/ # Project root
|
| 27 |
+
├── .claude/
|
| 28 |
+
│ └── plan/
|
| 29 |
+
│ └── pytorch-debugger-mvp.md # This plan
|
| 30 |
+
├── .dockerignore
|
| 31 |
+
├── .gitignore
|
| 32 |
+
├── .python-version # "3.12"
|
| 33 |
+
├── CLAUDE.md # Already exists
|
| 34 |
+
├── Dockerfile
|
| 35 |
+
├── PRD.md # Already exists
|
| 36 |
+
├── README.md
|
| 37 |
+
├── ROADMAP.md # Already exists
|
| 38 |
+
├── baseline_heuristic.py # Rule-based baseline (no API key)
|
| 39 |
+
├── baseline_inference.py # LLM baseline (optional, requires OPENAI_API_KEY)
|
| 40 |
+
├── deploy.sh # One-command build+test+validate script
|
| 41 |
+
├── ml-training-debugger-spec.md # Already exists (king file)
|
| 42 |
+
├── openenv.yaml
|
| 43 |
+
├── pyproject.toml
|
| 44 |
+
├── requirements.txt
|
| 45 |
+
│
|
| 46 |
+
├── ml_training_debugger/
|
| 47 |
+
│ ├── __init__.py
|
| 48 |
+
│ ├── models.py # All Pydantic models + RootCauseDiagnosis enum
|
| 49 |
+
│ ├── client.py # EnvClient extension with typed action/observation
|
| 50 |
+
│ ├── scenarios.py # ScenarioParams + sample_scenario()
|
| 51 |
+
│ ├── pytorch_engine.py # SimpleCNN, fault injection, gradient/weight extraction
|
| 52 |
+
│ ├── simulation.py # Parametric curve generation (torch.Tensor ops)
|
| 53 |
+
│ ├── code_templates.py # Task 6: code snippets with bugs + validate_fix()
|
| 54 |
+
│ ├── reward_engine.py # compute_reward() — all 7 components
|
| 55 |
+
│ └── graders.py # Per-task grader functions (0.0–1.0)
|
| 56 |
+
│
|
| 57 |
+
├── server/
|
| 58 |
+
│ ├── __init__.py
|
| 59 |
+
│ ├── environment.py # MLTrainingEnvironment(Environment)
|
| 60 |
+
│ ├── app.py # create_app() + custom routes
|
| 61 |
+
│ └── dashboard.html # Live diagnostic dashboard (Phase 3)
|
| 62 |
+
│
|
| 63 |
+
├── validation/ # PyTorch validation suite (Phase 3)
|
| 64 |
+
│ ├── requirements.txt
|
| 65 |
+
│ ├── conftest.py
|
| 66 |
+
│ ├── validate_exploding_gradients.py
|
| 67 |
+
│ ├── validate_vanishing_gradients.py
|
| 68 |
+
│ ├── validate_data_leakage.py
|
| 69 |
+
│ ├── validate_overfitting.py
|
| 70 |
+
│ ├── validate_batchnorm_eval.py
|
| 71 |
+
│ ├── validate_code_bugs.py
|
| 72 |
+
│ └── reports/ # Pre-computed fidelity plots
|
| 73 |
+
│
|
| 74 |
+
└── tests/
|
| 75 |
+
├── __init__.py
|
| 76 |
+
├── conftest.py # Shared fixtures
|
| 77 |
+
├── test_models.py
|
| 78 |
+
├── test_scenarios.py
|
| 79 |
+
├── test_pytorch_engine.py
|
| 80 |
+
├── test_simulation.py
|
| 81 |
+
├── test_code_templates.py
|
| 82 |
+
├── test_reward_engine.py
|
| 83 |
+
├── test_graders.py
|
| 84 |
+
├── test_episode_lifecycle.py
|
| 85 |
+
├── test_endpoints.py
|
| 86 |
+
└── test_baseline_reproducibility.py
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
---
|
| 90 |
+
|
| 91 |
+
## Phase 0: Project Initialization & Validation Setup
|
| 92 |
+
|
| 93 |
+
### Goal
|
| 94 |
+
A running skeleton server that proves the toolchain works end-to-end. Zero business logic — just plumbing.
|
| 95 |
+
|
| 96 |
+
### Files to Create
|
| 97 |
+
|
| 98 |
+
**Step 0.1 — Project config files:**
|
| 99 |
+
|
| 100 |
+
1. **`.python-version`** — content: `3.12`
|
| 101 |
+
|
| 102 |
+
2. **`.gitignore`**:
|
| 103 |
+
```
|
| 104 |
+
.venv/
|
| 105 |
+
__pycache__/
|
| 106 |
+
*.pyc
|
| 107 |
+
*.pyo
|
| 108 |
+
.env
|
| 109 |
+
run*.json
|
| 110 |
+
.pytest_cache/
|
| 111 |
+
htmlcov/
|
| 112 |
+
*.egg-info/
|
| 113 |
+
dist/
|
| 114 |
+
build/
|
| 115 |
+
validation/reports/*.png
|
| 116 |
+
.mypy_cache/
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
3. **`.dockerignore`**:
|
| 120 |
+
```
|
| 121 |
+
.venv/
|
| 122 |
+
__pycache__/
|
| 123 |
+
.git/
|
| 124 |
+
.pytest_cache/
|
| 125 |
+
tests/
|
| 126 |
+
validation/
|
| 127 |
+
*.md
|
| 128 |
+
!README.md
|
| 129 |
+
.claude/
|
| 130 |
+
run*.json
|
| 131 |
+
htmlcov/
|
| 132 |
+
```
|
| 133 |
+
|
| 134 |
+
4. **`pyproject.toml`**:
|
| 135 |
+
```toml
|
| 136 |
+
[project]
|
| 137 |
+
name = "pytorch-training-debugger"
|
| 138 |
+
version = "1.0.0"
|
| 139 |
+
description = "OpenEnv RL environment for PyTorch training failure debugging"
|
| 140 |
+
requires-python = ">=3.12"
|
| 141 |
+
dependencies = [
|
| 142 |
+
"torch",
|
| 143 |
+
"openenv-core",
|
| 144 |
+
"pydantic>=2.0",
|
| 145 |
+
"fastapi",
|
| 146 |
+
"uvicorn",
|
| 147 |
+
]
|
| 148 |
+
|
| 149 |
+
[project.optional-dependencies]
|
| 150 |
+
dev = [
|
| 151 |
+
"pytest",
|
| 152 |
+
"pytest-cov",
|
| 153 |
+
"pytest-asyncio",
|
| 154 |
+
"black",
|
| 155 |
+
"ruff",
|
| 156 |
+
"isort",
|
| 157 |
+
"httpx",
|
| 158 |
+
"websockets",
|
| 159 |
+
]
|
| 160 |
+
llm = [
|
| 161 |
+
"openai",
|
| 162 |
+
]
|
| 163 |
+
|
| 164 |
+
[tool.black]
|
| 165 |
+
line-length = 88
|
| 166 |
+
|
| 167 |
+
[tool.isort]
|
| 168 |
+
profile = "black"
|
| 169 |
+
|
| 170 |
+
[tool.ruff]
|
| 171 |
+
line-length = 88
|
| 172 |
+
target-version = "py312"
|
| 173 |
+
|
| 174 |
+
[tool.pytest.ini_options]
|
| 175 |
+
testpaths = ["tests"]
|
| 176 |
+
asyncio_mode = "auto"
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
5. **`requirements.txt`** (for Docker — flat list, no dev deps):
|
| 180 |
+
```
|
| 181 |
+
torch
|
| 182 |
+
openenv-core
|
| 183 |
+
pydantic>=2.0
|
| 184 |
+
fastapi
|
| 185 |
+
uvicorn
|
| 186 |
+
openai
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
**Step 0.2 — Package stubs:**
|
| 190 |
+
|
| 191 |
+
6. **`ml_training_debugger/__init__.py`**:
|
| 192 |
+
```python
|
| 193 |
+
"""PyTorch Training Run Debugger — OpenEnv Environment."""
|
| 194 |
+
|
| 195 |
+
__version__ = "1.0.0"
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
7. **`ml_training_debugger/models.py`** — STUB with all Pydantic models:
|
| 199 |
+
```python
|
| 200 |
+
"""All Pydantic models, enums, and typed data structures.
|
| 201 |
+
|
| 202 |
+
No business logic. Pure data definitions.
|
| 203 |
+
"""
|
| 204 |
+
|
| 205 |
+
from __future__ import annotations
|
| 206 |
+
|
| 207 |
+
import enum
|
| 208 |
+
from typing import Literal, Optional
|
| 209 |
+
|
| 210 |
+
import torch
|
| 211 |
+
from openenv.core.env_server.types import Action, Observation
|
| 212 |
+
from pydantic import BaseModel, Field
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class RootCauseDiagnosis(str, enum.Enum):
|
| 216 |
+
"""Closed enumeration of ML failure root causes."""
|
| 217 |
+
LR_TOO_HIGH = "lr_too_high"
|
| 218 |
+
VANISHING_GRADIENTS = "vanishing_gradients"
|
| 219 |
+
DATA_LEAKAGE = "data_leakage"
|
| 220 |
+
OVERFITTING = "overfitting"
|
| 221 |
+
BATCHNORM_EVAL_MODE = "batchnorm_eval_mode"
|
| 222 |
+
CODE_BUG = "code_bug"
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
class TrainingConfig(BaseModel):
|
| 226 |
+
"""Typed hyperparameter configuration."""
|
| 227 |
+
learning_rate: float = 0.001
|
| 228 |
+
weight_decay: float = 0.0001
|
| 229 |
+
batch_size: int = 64
|
| 230 |
+
hidden_dim: int = 64
|
| 231 |
+
num_layers: int = 3
|
| 232 |
+
optimizer: str = "adam"
|
| 233 |
+
dropout_rate: float = 0.0
|
| 234 |
+
gradient_clip_norm: Optional[float] = None
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
class GradientStats(BaseModel):
|
| 238 |
+
"""Per-layer gradient information from real torch.autograd."""
|
| 239 |
+
layer_name: str
|
| 240 |
+
norm_history: list[float]
|
| 241 |
+
mean_norm: float
|
| 242 |
+
max_norm: float
|
| 243 |
+
is_exploding: bool
|
| 244 |
+
is_vanishing: bool
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class ModelWeightStats(BaseModel):
|
| 248 |
+
"""Per-layer weight statistics from real state_dict()."""
|
| 249 |
+
layer_name: str
|
| 250 |
+
weight_norm: float
|
| 251 |
+
weight_mean: float
|
| 252 |
+
weight_std: float
|
| 253 |
+
weight_min: float
|
| 254 |
+
weight_max: float
|
| 255 |
+
dead_neuron_pct: float = 0.0
|
| 256 |
+
has_nan: bool = False
|
| 257 |
+
has_inf: bool = False
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
class DataBatchStats(BaseModel):
|
| 261 |
+
"""Data batch inspection results."""
|
| 262 |
+
label_distribution: dict[int, float]
|
| 263 |
+
feature_mean: float
|
| 264 |
+
feature_std: float
|
| 265 |
+
null_count: int = 0
|
| 266 |
+
class_overlap_score: float
|
| 267 |
+
batch_size: int
|
| 268 |
+
duplicate_ratio: float = 0.0
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class CodeSnippet(BaseModel):
|
| 272 |
+
"""PyTorch code for Task 6 inspection."""
|
| 273 |
+
code: str
|
| 274 |
+
filename: str = "train.py"
|
| 275 |
+
line_count: int
|
| 276 |
+
imports: list[str]
|
| 277 |
+
hint: Optional[str] = None
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class EpisodeState(BaseModel):
|
| 281 |
+
"""Tracks agent history within an episode."""
|
| 282 |
+
step_count: int = 0
|
| 283 |
+
gradients_inspected: bool = False
|
| 284 |
+
gradients_were_normal: bool = False
|
| 285 |
+
data_inspected: bool = False
|
| 286 |
+
model_modes_inspected: bool = False
|
| 287 |
+
model_weights_inspected: bool = False
|
| 288 |
+
code_inspected: bool = False
|
| 289 |
+
fix_action_taken: bool = False
|
| 290 |
+
restart_after_fix: bool = False
|
| 291 |
+
diagnosis_submitted: bool = False
|
| 292 |
+
actions_taken: list[str] = Field(default_factory=list)
|
| 293 |
+
|
| 294 |
+
def compute_available_actions(self) -> list[str]:
|
| 295 |
+
"""Dynamically compute available actions based on current state."""
|
| 296 |
+
actions = [
|
| 297 |
+
"inspect_gradients",
|
| 298 |
+
"inspect_data_batch",
|
| 299 |
+
"inspect_model_modes",
|
| 300 |
+
"inspect_model_weights",
|
| 301 |
+
"inspect_code",
|
| 302 |
+
"modify_config",
|
| 303 |
+
"add_callback",
|
| 304 |
+
"replace_optimizer",
|
| 305 |
+
"patch_data_loader",
|
| 306 |
+
"fix_model_mode",
|
| 307 |
+
]
|
| 308 |
+
if self.code_inspected:
|
| 309 |
+
actions.append("fix_code")
|
| 310 |
+
if self.fix_action_taken:
|
| 311 |
+
actions.append("restart_run")
|
| 312 |
+
if self.restart_after_fix:
|
| 313 |
+
actions.append("rollback_checkpoint")
|
| 314 |
+
if not self.diagnosis_submitted:
|
| 315 |
+
actions.append("mark_diagnosed")
|
| 316 |
+
return actions
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
ACTION_TYPES = Literal[
|
| 320 |
+
"inspect_gradients",
|
| 321 |
+
"inspect_data_batch",
|
| 322 |
+
"inspect_model_modes",
|
| 323 |
+
"inspect_model_weights",
|
| 324 |
+
"inspect_code",
|
| 325 |
+
"modify_config",
|
| 326 |
+
"add_callback",
|
| 327 |
+
"replace_optimizer",
|
| 328 |
+
"patch_data_loader",
|
| 329 |
+
"fix_model_mode",
|
| 330 |
+
"fix_code",
|
| 331 |
+
"restart_run",
|
| 332 |
+
"mark_diagnosed",
|
| 333 |
+
"rollback_checkpoint",
|
| 334 |
+
]
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
class MLTrainingAction(Action):
|
| 338 |
+
"""What the agent can do — extends openenv Action."""
|
| 339 |
+
action_type: str
|
| 340 |
+
target: Optional[str] = None
|
| 341 |
+
value: Optional[float | int | str] = None
|
| 342 |
+
diagnosis: Optional[str] = None
|
| 343 |
+
line: Optional[int] = None
|
| 344 |
+
replacement: Optional[str] = None
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
class MLTrainingObservation(Observation):
|
| 348 |
+
"""Full observation — extends openenv Observation (has done, reward, metadata)."""
|
| 349 |
+
run_id: str = ""
|
| 350 |
+
framework: str = "pytorch"
|
| 351 |
+
epoch: int = 20
|
| 352 |
+
training_loss_history: list[float] = Field(default_factory=list)
|
| 353 |
+
val_loss_history: list[float] = Field(default_factory=list)
|
| 354 |
+
val_accuracy_history: list[float] = Field(default_factory=list)
|
| 355 |
+
gradient_stats: list[GradientStats] = Field(default_factory=list)
|
| 356 |
+
model_weight_stats: Optional[list[ModelWeightStats]] = None
|
| 357 |
+
gpu_memory_used_gb: float = 6.2
|
| 358 |
+
gpu_memory_total_gb: float = 16.0
|
| 359 |
+
learning_rate: float = 0.001
|
| 360 |
+
current_config: TrainingConfig = Field(default_factory=TrainingConfig)
|
| 361 |
+
error_log: Optional[str] = None
|
| 362 |
+
data_batch_stats: Optional[DataBatchStats] = None
|
| 363 |
+
model_mode_info: Optional[dict[str, str]] = None
|
| 364 |
+
code_snippet: Optional[CodeSnippet] = None
|
| 365 |
+
available_actions: list[str] = Field(default_factory=list)
|
| 366 |
+
episode_state: EpisodeState = Field(default_factory=EpisodeState)
|
| 367 |
+
notes: Optional[str] = None
|
| 368 |
+
```
|
| 369 |
+
|
| 370 |
+
8. **`ml_training_debugger/client.py`** — STUB:
|
| 371 |
+
```python
|
| 372 |
+
"""Typed EnvClient for baseline scripts."""
|
| 373 |
+
|
| 374 |
+
from openenv.core.env_client import EnvClient
|
| 375 |
+
|
| 376 |
+
from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class MLTrainingEnvClient(EnvClient[MLTrainingAction, MLTrainingObservation, dict]):
|
| 380 |
+
"""Typed client for the PyTorch Training Debugger environment."""
|
| 381 |
+
|
| 382 |
+
def _step_payload(self, action: MLTrainingAction) -> dict:
|
| 383 |
+
return action.model_dump(exclude_none=True)
|
| 384 |
+
|
| 385 |
+
def _parse_observation(self, data: dict) -> MLTrainingObservation:
|
| 386 |
+
return MLTrainingObservation.model_validate(data)
|
| 387 |
+
```
|
| 388 |
+
|
| 389 |
+
9. **`server/__init__.py`** — empty file
|
| 390 |
+
|
| 391 |
+
10. **`server/environment.py`** — STUB:
|
| 392 |
+
```python
|
| 393 |
+
"""MLTrainingEnvironment — extends openenv Environment."""
|
| 394 |
+
|
| 395 |
+
from typing import Any, Optional
|
| 396 |
+
|
| 397 |
+
from openenv.core.env_server.interfaces import Environment
|
| 398 |
+
|
| 399 |
+
from ml_training_debugger.models import (
|
| 400 |
+
EpisodeState,
|
| 401 |
+
MLTrainingAction,
|
| 402 |
+
MLTrainingObservation,
|
| 403 |
+
TrainingConfig,
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
class MLTrainingEnvironment(
|
| 408 |
+
Environment[MLTrainingAction, MLTrainingObservation, dict]
|
| 409 |
+
):
|
| 410 |
+
"""OpenEnv environment for PyTorch training run debugging."""
|
| 411 |
+
|
| 412 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 413 |
+
|
| 414 |
+
def reset(
|
| 415 |
+
self,
|
| 416 |
+
seed: Optional[int] = None,
|
| 417 |
+
episode_id: Optional[str] = None,
|
| 418 |
+
**kwargs: Any,
|
| 419 |
+
) -> MLTrainingObservation:
|
| 420 |
+
"""Reset environment, return initial observation."""
|
| 421 |
+
state = EpisodeState()
|
| 422 |
+
obs = MLTrainingObservation(
|
| 423 |
+
run_id=episode_id or "episode_001",
|
| 424 |
+
training_loss_history=[2.3] * 20,
|
| 425 |
+
val_loss_history=[2.3] * 20,
|
| 426 |
+
val_accuracy_history=[0.1] * 20,
|
| 427 |
+
current_config=TrainingConfig(),
|
| 428 |
+
available_actions=state.compute_available_actions(),
|
| 429 |
+
episode_state=state,
|
| 430 |
+
done=False,
|
| 431 |
+
reward=0.0,
|
| 432 |
+
)
|
| 433 |
+
return obs
|
| 434 |
+
|
| 435 |
+
def step(
|
| 436 |
+
self,
|
| 437 |
+
action: MLTrainingAction,
|
| 438 |
+
timeout_s: Optional[float] = None,
|
| 439 |
+
**kwargs: Any,
|
| 440 |
+
) -> MLTrainingObservation:
|
| 441 |
+
"""Process one agent action."""
|
| 442 |
+
state = EpisodeState()
|
| 443 |
+
obs = MLTrainingObservation(
|
| 444 |
+
run_id="episode_001",
|
| 445 |
+
training_loss_history=[2.3] * 20,
|
| 446 |
+
val_loss_history=[2.3] * 20,
|
| 447 |
+
val_accuracy_history=[0.1] * 20,
|
| 448 |
+
current_config=TrainingConfig(),
|
| 449 |
+
available_actions=state.compute_available_actions(),
|
| 450 |
+
episode_state=state,
|
| 451 |
+
done=False,
|
| 452 |
+
reward=-0.01,
|
| 453 |
+
)
|
| 454 |
+
return obs
|
| 455 |
+
|
| 456 |
+
@property
|
| 457 |
+
def state(self) -> dict:
|
| 458 |
+
"""Return current environment state."""
|
| 459 |
+
return {"status": "active"}
|
| 460 |
+
```
|
| 461 |
+
|
| 462 |
+
11. **`server/app.py`** — STUB with all endpoints:
|
| 463 |
+
```python
|
| 464 |
+
"""FastAPI app — openenv create_app() + custom routes."""
|
| 465 |
+
|
| 466 |
+
import logging
|
| 467 |
+
|
| 468 |
+
from fastapi import FastAPI
|
| 469 |
+
from fastapi.responses import JSONResponse
|
| 470 |
+
from openenv.core.env_server.http_server import create_app
|
| 471 |
+
|
| 472 |
+
from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
|
| 473 |
+
from server.environment import MLTrainingEnvironment
|
| 474 |
+
|
| 475 |
+
logger = logging.getLogger(__name__)
|
| 476 |
+
|
| 477 |
+
# create_app takes the class (factory), not an instance
|
| 478 |
+
app: FastAPI = create_app(
|
| 479 |
+
MLTrainingEnvironment,
|
| 480 |
+
MLTrainingAction,
|
| 481 |
+
MLTrainingObservation,
|
| 482 |
+
env_name="pytorch_training_debugger",
|
| 483 |
+
max_concurrent_envs=5,
|
| 484 |
+
)
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
@app.get("/health")
|
| 488 |
+
def health_check() -> dict:
|
| 489 |
+
"""Health check — required by hackathon auto-validator."""
|
| 490 |
+
return {"status": "ready", "tasks": 3}
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
@app.get("/tasks")
|
| 494 |
+
def get_tasks() -> list[dict]:
|
| 495 |
+
"""Return task list with IDs, difficulties, and action schema."""
|
| 496 |
+
schema = MLTrainingAction.model_json_schema()
|
| 497 |
+
return [
|
| 498 |
+
{"id": "task_001", "difficulty": "easy", "max_steps": 20, "action_schema": schema},
|
| 499 |
+
{"id": "task_003", "difficulty": "medium", "max_steps": 25, "action_schema": schema},
|
| 500 |
+
{"id": "task_005", "difficulty": "hard", "max_steps": 30, "action_schema": schema},
|
| 501 |
+
]
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
@app.post("/grader")
|
| 505 |
+
def post_grader() -> dict:
|
| 506 |
+
"""Return grader score for most recently completed episode."""
|
| 507 |
+
return {"score": None, "error": "no_completed_episode"}
|
| 508 |
+
|
| 509 |
+
|
| 510 |
+
@app.post("/baseline")
|
| 511 |
+
async def post_baseline() -> dict:
|
| 512 |
+
"""Trigger baseline run, return scores."""
|
| 513 |
+
return {"scores": {"task_001": 0.0, "task_003": 0.0, "task_005": 0.0}}
|
| 514 |
+
```
|
| 515 |
+
|
| 516 |
+
12. **`openenv.yaml`**:
|
| 517 |
+
```yaml
|
| 518 |
+
spec_version: 1
|
| 519 |
+
name: pytorch-training-debugger
|
| 520 |
+
type: space
|
| 521 |
+
runtime: fastapi
|
| 522 |
+
app: server.app:app
|
| 523 |
+
port: 7860
|
| 524 |
+
|
| 525 |
+
# Extended metadata
|
| 526 |
+
version: "1.0.0"
|
| 527 |
+
description: |
|
| 528 |
+
PyTorch-native fault injection engine for training failure debugging.
|
| 529 |
+
An AI agent investigates, diagnoses, fixes, and verifies broken
|
| 530 |
+
training runs using real torch.nn.Module models, torch.autograd
|
| 531 |
+
gradients, state_dict() weight inspection, and PyTorch code-level
|
| 532 |
+
debugging.
|
| 533 |
+
framework: openenv
|
| 534 |
+
tags: [ml-debugging, pytorch, reinforcement-learning, root-cause-analysis, fault-injection]
|
| 535 |
+
|
| 536 |
+
observation_space:
|
| 537 |
+
type: MLTrainingObservation
|
| 538 |
+
description: "Training run snapshot with progressive reveal"
|
| 539 |
+
|
| 540 |
+
action_space:
|
| 541 |
+
type: MLTrainingAction
|
| 542 |
+
description: "Investigation, fix, and diagnosis actions with dynamic availability"
|
| 543 |
+
|
| 544 |
+
tasks:
|
| 545 |
+
- id: task_001
|
| 546 |
+
difficulty: easy
|
| 547 |
+
max_steps: 20
|
| 548 |
+
- id: task_003
|
| 549 |
+
difficulty: medium
|
| 550 |
+
max_steps: 25
|
| 551 |
+
- id: task_005
|
| 552 |
+
difficulty: hard
|
| 553 |
+
max_steps: 30
|
| 554 |
+
|
| 555 |
+
reward:
|
| 556 |
+
range: [-1.0, 1.0]
|
| 557 |
+
shaped: true
|
| 558 |
+
step_penalty: -0.01
|
| 559 |
+
investigation_bonus: 0.05
|
| 560 |
+
correct_diagnosis: 0.50
|
| 561 |
+
terminal_convergence: 0.40
|
| 562 |
+
|
| 563 |
+
endpoints:
|
| 564 |
+
websocket: "/ws"
|
| 565 |
+
tasks: "GET /tasks"
|
| 566 |
+
grader: "POST /grader"
|
| 567 |
+
baseline: "POST /baseline"
|
| 568 |
+
health: "GET /health"
|
| 569 |
+
```
|
| 570 |
+
|
| 571 |
+
13. **`Dockerfile`**:
|
| 572 |
+
```dockerfile
|
| 573 |
+
FROM python:3.12-slim
|
| 574 |
+
|
| 575 |
+
WORKDIR /app
|
| 576 |
+
|
| 577 |
+
# Install PyTorch CPU-only first (largest layer, cached)
|
| 578 |
+
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
|
| 579 |
+
|
| 580 |
+
# Install remaining dependencies
|
| 581 |
+
COPY requirements.txt .
|
| 582 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 583 |
+
|
| 584 |
+
# Copy application code
|
| 585 |
+
COPY ml_training_debugger/ ml_training_debugger/
|
| 586 |
+
COPY server/ server/
|
| 587 |
+
COPY openenv.yaml .
|
| 588 |
+
COPY baseline_heuristic.py .
|
| 589 |
+
|
| 590 |
+
# Copy pre-computed validation reports if they exist
|
| 591 |
+
COPY validation/reports/ validation/reports/ 2>/dev/null || true
|
| 592 |
+
|
| 593 |
+
EXPOSE 7860
|
| 594 |
+
|
| 595 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
| 596 |
+
```
|
| 597 |
+
|
| 598 |
+
14. **`tests/__init__.py`** — empty file
|
| 599 |
+
|
| 600 |
+
15. **`tests/conftest.py`**:
|
| 601 |
+
```python
|
| 602 |
+
"""Shared test fixtures."""
|
| 603 |
+
|
| 604 |
+
import pytest
|
| 605 |
+
|
| 606 |
+
from ml_training_debugger.models import (
|
| 607 |
+
EpisodeState,
|
| 608 |
+
MLTrainingAction,
|
| 609 |
+
MLTrainingObservation,
|
| 610 |
+
TrainingConfig,
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
@pytest.fixture
|
| 615 |
+
def fresh_episode_state() -> EpisodeState:
|
| 616 |
+
return EpisodeState()
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
@pytest.fixture
|
| 620 |
+
def sample_config() -> TrainingConfig:
|
| 621 |
+
return TrainingConfig(learning_rate=0.001)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
@pytest.fixture
|
| 625 |
+
def sample_observation() -> MLTrainingObservation:
|
| 626 |
+
state = EpisodeState()
|
| 627 |
+
return MLTrainingObservation(
|
| 628 |
+
run_id="test_episode",
|
| 629 |
+
training_loss_history=[2.3 - i * 0.1 for i in range(20)],
|
| 630 |
+
val_loss_history=[2.3 - i * 0.08 for i in range(20)],
|
| 631 |
+
val_accuracy_history=[0.1 + i * 0.04 for i in range(20)],
|
| 632 |
+
current_config=TrainingConfig(),
|
| 633 |
+
available_actions=state.compute_available_actions(),
|
| 634 |
+
episode_state=state,
|
| 635 |
+
done=False,
|
| 636 |
+
reward=0.0,
|
| 637 |
+
)
|
| 638 |
+
```
|
| 639 |
+
|
| 640 |
+
16. **`tests/test_models.py`**:
|
| 641 |
+
```python
|
| 642 |
+
"""Test all Pydantic models instantiate and serialize correctly."""
|
| 643 |
+
|
| 644 |
+
import json
|
| 645 |
+
import pytest
|
| 646 |
+
from ml_training_debugger.models import (
|
| 647 |
+
CodeSnippet,
|
| 648 |
+
DataBatchStats,
|
| 649 |
+
EpisodeState,
|
| 650 |
+
GradientStats,
|
| 651 |
+
MLTrainingAction,
|
| 652 |
+
MLTrainingObservation,
|
| 653 |
+
ModelWeightStats,
|
| 654 |
+
RootCauseDiagnosis,
|
| 655 |
+
TrainingConfig,
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
|
| 659 |
+
class TestRootCauseDiagnosis:
|
| 660 |
+
def test_all_six_values_exist(self):
|
| 661 |
+
assert len(RootCauseDiagnosis) == 6
|
| 662 |
+
|
| 663 |
+
def test_values_are_strings(self):
|
| 664 |
+
for d in RootCauseDiagnosis:
|
| 665 |
+
assert isinstance(d.value, str)
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
class TestTrainingConfig:
|
| 669 |
+
def test_default_instantiation(self):
|
| 670 |
+
config = TrainingConfig()
|
| 671 |
+
assert config.learning_rate == 0.001
|
| 672 |
+
|
| 673 |
+
def test_json_roundtrip(self):
|
| 674 |
+
config = TrainingConfig(learning_rate=0.01)
|
| 675 |
+
data = json.loads(config.model_dump_json())
|
| 676 |
+
restored = TrainingConfig.model_validate(data)
|
| 677 |
+
assert restored.learning_rate == 0.01
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
class TestEpisodeState:
|
| 681 |
+
def test_fresh_state(self):
|
| 682 |
+
state = EpisodeState()
|
| 683 |
+
assert state.step_count == 0
|
| 684 |
+
assert not state.gradients_inspected
|
| 685 |
+
assert not state.diagnosis_submitted
|
| 686 |
+
|
| 687 |
+
def test_available_actions_initial(self):
|
| 688 |
+
state = EpisodeState()
|
| 689 |
+
actions = state.compute_available_actions()
|
| 690 |
+
assert "inspect_gradients" in actions
|
| 691 |
+
assert "mark_diagnosed" in actions
|
| 692 |
+
assert "fix_code" not in actions
|
| 693 |
+
assert "restart_run" not in actions
|
| 694 |
+
|
| 695 |
+
def test_fix_code_available_after_code_inspected(self):
|
| 696 |
+
state = EpisodeState(code_inspected=True)
|
| 697 |
+
actions = state.compute_available_actions()
|
| 698 |
+
assert "fix_code" in actions
|
| 699 |
+
|
| 700 |
+
def test_restart_run_available_after_fix(self):
|
| 701 |
+
state = EpisodeState(fix_action_taken=True)
|
| 702 |
+
actions = state.compute_available_actions()
|
| 703 |
+
assert "restart_run" in actions
|
| 704 |
+
|
| 705 |
+
def test_mark_diagnosed_disappears_after_submission(self):
|
| 706 |
+
state = EpisodeState(diagnosis_submitted=True)
|
| 707 |
+
actions = state.compute_available_actions()
|
| 708 |
+
assert "mark_diagnosed" not in actions
|
| 709 |
+
|
| 710 |
+
|
| 711 |
+
class TestMLTrainingObservation:
|
| 712 |
+
def test_extends_observation(self):
|
| 713 |
+
from openenv.core.env_server.types import Observation
|
| 714 |
+
assert issubclass(MLTrainingObservation, Observation)
|
| 715 |
+
|
| 716 |
+
def test_has_done_and_reward(self):
|
| 717 |
+
obs = MLTrainingObservation(done=True, reward=0.5)
|
| 718 |
+
assert obs.done is True
|
| 719 |
+
assert obs.reward == 0.5
|
| 720 |
+
|
| 721 |
+
def test_json_serialization(self):
|
| 722 |
+
obs = MLTrainingObservation(
|
| 723 |
+
run_id="test",
|
| 724 |
+
training_loss_history=[1.0, 2.0],
|
| 725 |
+
val_accuracy_history=[0.5],
|
| 726 |
+
)
|
| 727 |
+
data = json.loads(obs.model_dump_json())
|
| 728 |
+
assert data["run_id"] == "test"
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
class TestMLTrainingAction:
|
| 732 |
+
def test_extends_action(self):
|
| 733 |
+
from openenv.core.env_server.types import Action
|
| 734 |
+
assert issubclass(MLTrainingAction, Action)
|
| 735 |
+
|
| 736 |
+
def test_basic_action(self):
|
| 737 |
+
action = MLTrainingAction(action_type="inspect_gradients")
|
| 738 |
+
assert action.action_type == "inspect_gradients"
|
| 739 |
+
|
| 740 |
+
def test_modify_config_action(self):
|
| 741 |
+
action = MLTrainingAction(
|
| 742 |
+
action_type="modify_config",
|
| 743 |
+
target="learning_rate",
|
| 744 |
+
value=0.001,
|
| 745 |
+
)
|
| 746 |
+
assert action.target == "learning_rate"
|
| 747 |
+
|
| 748 |
+
def test_mark_diagnosed_action(self):
|
| 749 |
+
action = MLTrainingAction(
|
| 750 |
+
action_type="mark_diagnosed",
|
| 751 |
+
diagnosis="lr_too_high",
|
| 752 |
+
)
|
| 753 |
+
assert action.diagnosis == "lr_too_high"
|
| 754 |
+
|
| 755 |
+
def test_fix_code_action(self):
|
| 756 |
+
action = MLTrainingAction(
|
| 757 |
+
action_type="fix_code",
|
| 758 |
+
line=13,
|
| 759 |
+
replacement="loss = criterion(output, batch_y)",
|
| 760 |
+
)
|
| 761 |
+
assert action.line == 13
|
| 762 |
+
```
|
| 763 |
+
|
| 764 |
+
**Step 0.3 — Validation Commands:**
|
| 765 |
+
|
| 766 |
+
```bash
|
| 767 |
+
# In project root with venv activated
|
| 768 |
+
source .venv/bin/activate
|
| 769 |
+
|
| 770 |
+
# 1. Verify imports
|
| 771 |
+
python -c "from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation; print('models OK')"
|
| 772 |
+
python -c "from ml_training_debugger.client import MLTrainingEnvClient; print('client OK')"
|
| 773 |
+
python -c "from server.app import app; print('app OK')"
|
| 774 |
+
|
| 775 |
+
# 2. Run tests
|
| 776 |
+
pytest tests/test_models.py -v
|
| 777 |
+
|
| 778 |
+
# 3. Start server
|
| 779 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860 &
|
| 780 |
+
sleep 3
|
| 781 |
+
curl http://localhost:7860/health
|
| 782 |
+
curl http://localhost:7860/tasks
|
| 783 |
+
curl http://localhost:7860/docs
|
| 784 |
+
kill %1
|
| 785 |
+
|
| 786 |
+
# 4. Formatting
|
| 787 |
+
black ml_training_debugger/ server/ tests/ --check
|
| 788 |
+
ruff check ml_training_debugger/ server/ tests/
|
| 789 |
+
isort ml_training_debugger/ server/ tests/ --check --profile black
|
| 790 |
+
```
|
| 791 |
+
|
| 792 |
+
### Acceptance Criteria — Phase 0
|
| 793 |
+
|
| 794 |
+
- [ ] All Pydantic models instantiate without error and serialize to valid JSON
|
| 795 |
+
- [ ] `MLTrainingObservation` extends `Observation` (has `done`, `reward`, `metadata`)
|
| 796 |
+
- [ ] `MLTrainingAction` extends `Action` (has `metadata`)
|
| 797 |
+
- [ ] `EpisodeState.compute_available_actions()` returns correct dynamic action lists
|
| 798 |
+
- [ ] Server starts on port 7860 and responds to `/health` with `{"status": "ready", "tasks": 3}`
|
| 799 |
+
- [ ] `/tasks` returns 3 tasks with action schema
|
| 800 |
+
- [ ] `pytest tests/test_models.py` passes all tests
|
| 801 |
+
- [ ] `client.py` imports without error
|
| 802 |
+
- [ ] `black --check`, `ruff check`, `isort --check` all pass
|
| 803 |
+
|
| 804 |
+
---
|
| 805 |
+
|
| 806 |
+
## Phase 1: Core Data Models & Pydantic Types
|
| 807 |
+
|
| 808 |
+
### Goal
|
| 809 |
+
Finalize all model fields to match the spec exactly. No business logic yet — just data shapes.
|
| 810 |
+
|
| 811 |
+
### Files to Edit
|
| 812 |
+
|
| 813 |
+
**`ml_training_debugger/models.py`** — Already created in Phase 0. Verify:
|
| 814 |
+
- All fields match spec Section 10 exactly
|
| 815 |
+
- `GradientStats.is_exploding` threshold: `mean_norm > 10.0`
|
| 816 |
+
- `GradientStats.is_vanishing` threshold: `mean_norm < 1e-6`
|
| 817 |
+
- `TrainingConfig` field names match `modify_config` target options
|
| 818 |
+
- `EpisodeState.compute_available_actions()` logic matches spec Section 10 dynamic rules
|
| 819 |
+
|
| 820 |
+
### Tests (write BEFORE implementation — TDD)
|
| 821 |
+
|
| 822 |
+
All tests already written in `tests/test_models.py` from Phase 0. Extend with:
|
| 823 |
+
|
| 824 |
+
```python
|
| 825 |
+
class TestGradientStats:
|
| 826 |
+
def test_exploding_threshold(self):
|
| 827 |
+
stats = GradientStats(
|
| 828 |
+
layer_name="fc", norm_history=[15.0], mean_norm=15.0, max_norm=15.0,
|
| 829 |
+
is_exploding=True, is_vanishing=False,
|
| 830 |
+
)
|
| 831 |
+
assert stats.is_exploding is True
|
| 832 |
+
|
| 833 |
+
def test_vanishing_threshold(self):
|
| 834 |
+
stats = GradientStats(
|
| 835 |
+
layer_name="conv1", norm_history=[1e-7], mean_norm=1e-7, max_norm=1e-7,
|
| 836 |
+
is_exploding=False, is_vanishing=True,
|
| 837 |
+
)
|
| 838 |
+
assert stats.is_vanishing is True
|
| 839 |
+
|
| 840 |
+
def test_normal_gradients(self):
|
| 841 |
+
stats = GradientStats(
|
| 842 |
+
layer_name="conv1", norm_history=[0.5], mean_norm=0.5, max_norm=0.5,
|
| 843 |
+
is_exploding=False, is_vanishing=False,
|
| 844 |
+
)
|
| 845 |
+
assert not stats.is_exploding
|
| 846 |
+
assert not stats.is_vanishing
|
| 847 |
+
```
|
| 848 |
+
|
| 849 |
+
### Acceptance Criteria — Phase 1
|
| 850 |
+
|
| 851 |
+
- [ ] Every field in every model matches the spec Section 10 types exactly
|
| 852 |
+
- [ ] No `Dict[str, Any]` in any public model (typed Pydantic everywhere)
|
| 853 |
+
- [ ] `import torch` appears in `models.py`
|
| 854 |
+
- [ ] All model tests pass
|
| 855 |
+
|
| 856 |
+
---
|
| 857 |
+
|
| 858 |
+
## Phase 2: PyTorch-Native Fault Injection Engine + Simulation
|
| 859 |
+
|
| 860 |
+
### Goal
|
| 861 |
+
Real PyTorch models with real gradients + parametric curve generators. This is the technical heart.
|
| 862 |
+
|
| 863 |
+
### Files to Create
|
| 864 |
+
|
| 865 |
+
**Step 2.1 — `ml_training_debugger/scenarios.py`** (~120 lines):
|
| 866 |
+
|
| 867 |
+
```python
|
| 868 |
+
"""ScenarioParams and scenario sampling."""
|
| 869 |
+
|
| 870 |
+
from __future__ import annotations
|
| 871 |
+
|
| 872 |
+
import dataclasses
|
| 873 |
+
from typing import Optional
|
| 874 |
+
|
| 875 |
+
import torch
|
| 876 |
+
|
| 877 |
+
from ml_training_debugger.models import RootCauseDiagnosis
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
@dataclasses.dataclass(frozen=True)
|
| 881 |
+
class ScenarioParams:
|
| 882 |
+
"""Internal scenario parameters — not exposed to agent."""
|
| 883 |
+
task_id: str
|
| 884 |
+
root_cause: RootCauseDiagnosis
|
| 885 |
+
seed: int
|
| 886 |
+
learning_rate: float = 0.001
|
| 887 |
+
weight_decay: float = 0.0001
|
| 888 |
+
leakage_pct: float = 0.0
|
| 889 |
+
depth_multiplier: float = 1.0
|
| 890 |
+
divergence_epoch: int = 5
|
| 891 |
+
red_herring_intensity: float = 1.0
|
| 892 |
+
red_herring_spike_layer: str = "fc"
|
| 893 |
+
bug_type: Optional[str] = None
|
| 894 |
+
notes: Optional[str] = None
|
| 895 |
+
error_log: Optional[str] = None
|
| 896 |
+
gpu_memory_used_gb: float = 6.2
|
| 897 |
+
max_steps: int = 20
|
| 898 |
+
|
| 899 |
+
|
| 900 |
+
def sample_scenario(task_id: str, seed: int) -> ScenarioParams:
|
| 901 |
+
"""Sample a ScenarioParams for the given task."""
|
| 902 |
+
rng = torch.Generator()
|
| 903 |
+
rng.manual_seed(seed)
|
| 904 |
+
|
| 905 |
+
# Use torch for random selection
|
| 906 |
+
def choose(options: list) -> any:
|
| 907 |
+
idx = int(torch.randint(0, len(options), (1,), generator=rng).item())
|
| 908 |
+
return options[idx]
|
| 909 |
+
|
| 910 |
+
if task_id == "task_001":
|
| 911 |
+
lr = choose([0.05, 0.08, 0.10, 0.15, 0.30])
|
| 912 |
+
return ScenarioParams(
|
| 913 |
+
task_id=task_id,
|
| 914 |
+
root_cause=RootCauseDiagnosis.LR_TOO_HIGH,
|
| 915 |
+
seed=seed,
|
| 916 |
+
learning_rate=lr,
|
| 917 |
+
error_log=f"RuntimeError: Loss is NaN at epoch 12 (lr={lr})",
|
| 918 |
+
max_steps=20,
|
| 919 |
+
)
|
| 920 |
+
|
| 921 |
+
elif task_id == "task_003":
|
| 922 |
+
leakage = choose([0.12, 0.18, 0.22, 0.28])
|
| 923 |
+
return ScenarioParams(
|
| 924 |
+
task_id=task_id,
|
| 925 |
+
root_cause=RootCauseDiagnosis.DATA_LEAKAGE,
|
| 926 |
+
seed=seed,
|
| 927 |
+
leakage_pct=leakage,
|
| 928 |
+
notes="Model architecture upgraded from 2-layer to 4-layer CNN at epoch 2. Performance improvement may reflect increased model capacity.",
|
| 929 |
+
max_steps=25,
|
| 930 |
+
)
|
| 931 |
+
|
| 932 |
+
elif task_id == "task_005":
|
| 933 |
+
intensity = (
|
| 934 |
+
torch.empty(1).uniform_(0.8, 2.5, generator=rng).item()
|
| 935 |
+
)
|
| 936 |
+
spike_layer = choose(["fc", "conv1"])
|
| 937 |
+
return ScenarioParams(
|
| 938 |
+
task_id=task_id,
|
| 939 |
+
root_cause=RootCauseDiagnosis.BATCHNORM_EVAL_MODE,
|
| 940 |
+
seed=seed,
|
| 941 |
+
red_herring_intensity=intensity,
|
| 942 |
+
red_herring_spike_layer=spike_layer,
|
| 943 |
+
gpu_memory_used_gb=14.56, # 91% of 16GB
|
| 944 |
+
error_log="Warning: GPU memory pressure detected, consider reducing batch size or enabling gradient checkpointing",
|
| 945 |
+
max_steps=30,
|
| 946 |
+
)
|
| 947 |
+
|
| 948 |
+
raise ValueError(f"Unknown task_id: {task_id}")
|
| 949 |
+
```
|
| 950 |
+
|
| 951 |
+
**Step 2.2 — `ml_training_debugger/pytorch_engine.py`** (~250 lines):
|
| 952 |
+
|
| 953 |
+
Key functions:
|
| 954 |
+
- `SimpleCNN(torch.nn.Module)` — 3-layer CNN, ~50K params
|
| 955 |
+
- `create_model_and_inject_fault(scenario: ScenarioParams) -> tuple[torch.nn.Module, dict]`
|
| 956 |
+
- `extract_gradient_stats(model: torch.nn.Module) -> list[GradientStats]`
|
| 957 |
+
- `extract_weight_stats(model: torch.nn.Module) -> list[ModelWeightStats]`
|
| 958 |
+
- `extract_model_modes(model: torch.nn.Module) -> dict[str, str]`
|
| 959 |
+
|
| 960 |
+
Implementation notes:
|
| 961 |
+
- `torch.manual_seed(scenario.seed)` at the start of every call
|
| 962 |
+
- For Task 1: set lr high, run 2 forward+backward passes → gradients explode
|
| 963 |
+
- For Task 3: normal model, no gradient anomaly
|
| 964 |
+
- For Task 5: call `model.eval()` before training → BatchNorm frozen
|
| 965 |
+
- All gradient stats come from real `param.grad` tensors
|
| 966 |
+
- All weight stats come from real `model.state_dict()`
|
| 967 |
+
|
| 968 |
+
**Step 2.3 — `ml_training_debugger/simulation.py`** (~180 lines):
|
| 969 |
+
|
| 970 |
+
Key functions:
|
| 971 |
+
- `gen_loss_history(scenario: ScenarioParams) -> list[float]` — all torch.Tensor ops
|
| 972 |
+
- `gen_val_accuracy_history(scenario: ScenarioParams) -> list[float]`
|
| 973 |
+
- `gen_val_loss_history(scenario: ScenarioParams) -> list[float]`
|
| 974 |
+
|
| 975 |
+
Per-task parametric curves from spec Section 6:
|
| 976 |
+
- Task 1: `loss = torch.exp(torch.tensor(lr) * torch.arange(20))`
|
| 977 |
+
- Task 3: `val_acc = torch.sigmoid(torch.linspace(-3, 3, 20)) * (1 - leakage_pct)`
|
| 978 |
+
- Task 5: Normal loss + elevated variance, slow val_acc degradation
|
| 979 |
+
|
| 980 |
+
### Tests to Create FIRST (TDD)
|
| 981 |
+
|
| 982 |
+
**`tests/test_scenarios.py`**:
|
| 983 |
+
- `sample_scenario("task_001", seed=42)` returns `root_cause == LR_TOO_HIGH`
|
| 984 |
+
- `sample_scenario("task_003", seed=42)` returns `root_cause == DATA_LEAKAGE`
|
| 985 |
+
- `sample_scenario("task_005", seed=42)` returns `root_cause == BATCHNORM_EVAL_MODE`
|
| 986 |
+
- Different seeds produce different parameters (but same root cause per task)
|
| 987 |
+
- Unknown task_id raises ValueError
|
| 988 |
+
|
| 989 |
+
**`tests/test_pytorch_engine.py`**:
|
| 990 |
+
- `SimpleCNN` is a real `torch.nn.Module` with ~50K params
|
| 991 |
+
- Task 1 fault injection: `is_exploding=True` on all layers
|
| 992 |
+
- Task 5 fault injection: `is_exploding=False` on all layers, `model.training==False`
|
| 993 |
+
- `extract_gradient_stats` returns `list[GradientStats]` with real float norms
|
| 994 |
+
- `extract_weight_stats` returns `list[ModelWeightStats]` from real state_dict
|
| 995 |
+
- `extract_model_modes` returns dict mapping layer names to "train"/"eval"
|
| 996 |
+
- **CRITICAL**: `import torch` in pytorch_engine.py, zero `import numpy`
|
| 997 |
+
|
| 998 |
+
**`tests/test_simulation.py`**:
|
| 999 |
+
- All outputs are `list[float]` of length 20
|
| 1000 |
+
- Task 1 (exploding): loss diverges (last value >> first value)
|
| 1001 |
+
- Task 3 (leakage): val_acc suspiciously high from early epochs
|
| 1002 |
+
- Task 5 (batchnorm): slow val_acc degradation (~1-2% per epoch)
|
| 1003 |
+
- All computation uses torch (no numpy)
|
| 1004 |
+
|
| 1005 |
+
### Acceptance Criteria — Phase 2
|
| 1006 |
+
|
| 1007 |
+
- [ ] `SimpleCNN` is a real `torch.nn.Module` with ~50K parameters
|
| 1008 |
+
- [ ] `create_model_and_inject_fault` for Task 1 produces exploding gradients (`is_exploding=True` all layers)
|
| 1009 |
+
- [ ] `create_model_and_inject_fault` for Task 5 produces `model.training==False` on all layers
|
| 1010 |
+
- [ ] `extract_gradient_stats` returns real floats from `torch.norm(param.grad)`
|
| 1011 |
+
- [ ] `extract_weight_stats` returns real floats from `state_dict()`
|
| 1012 |
+
- [ ] Parametric curves produce 20-element lists with correct shapes per task
|
| 1013 |
+
- [ ] `import torch` in `pytorch_engine.py` and `simulation.py` — zero `import numpy`
|
| 1014 |
+
- [ ] `torch.manual_seed(seed)` ensures reproducibility
|
| 1015 |
+
- [ ] All Phase 2 tests pass
|
| 1016 |
+
|
| 1017 |
+
---
|
| 1018 |
+
|
| 1019 |
+
## Phase 3: MVP Tasks (1, 3, 5) + Reward Engine + Graders
|
| 1020 |
+
|
| 1021 |
+
### Goal
|
| 1022 |
+
All reward logic and graders implemented. The environment can score episodes.
|
| 1023 |
+
|
| 1024 |
+
### Files to Create
|
| 1025 |
+
|
| 1026 |
+
**Step 3.1 — `ml_training_debugger/reward_engine.py`** (~100 lines):
|
| 1027 |
+
|
| 1028 |
+
```python
|
| 1029 |
+
def compute_reward(
|
| 1030 |
+
action: MLTrainingAction,
|
| 1031 |
+
episode_state: EpisodeState,
|
| 1032 |
+
scenario: ScenarioParams,
|
| 1033 |
+
is_valid_action: bool,
|
| 1034 |
+
is_correct_fix: bool | None = None,
|
| 1035 |
+
convergence_confirmed: bool = False,
|
| 1036 |
+
) -> float:
|
| 1037 |
+
```
|
| 1038 |
+
|
| 1039 |
+
All 7 components per spec Section 12:
|
| 1040 |
+
1. Step penalty: -0.01 (flat, unconditional)
|
| 1041 |
+
2. Investigation bonus: +0.05 (first-time per type)
|
| 1042 |
+
3. Context-gated penalty: -0.20 (ONLY when `gradients_inspected AND gradients_were_normal`)
|
| 1043 |
+
4. Invalid action: -0.05
|
| 1044 |
+
5. Wrong code fix: -0.10
|
| 1045 |
+
6. Correct diagnosis: +0.50 / Wrong diagnosis: -0.30
|
| 1046 |
+
7. Terminal convergence: +0.40 (gated on `fix_action_taken AND restart_after_fix`)
|
| 1047 |
+
|
| 1048 |
+
Hard cap at [-1.0, 1.0].
|
| 1049 |
+
|
| 1050 |
+
**Step 3.2 — `ml_training_debugger/graders.py`** (~150 lines):
|
| 1051 |
+
|
| 1052 |
+
One function per task. Each returns float in [0.0, 1.0]:
|
| 1053 |
+
- `grade_task_001(state: EpisodeState, scenario: ScenarioParams) -> float`
|
| 1054 |
+
- `grade_task_003(state: EpisodeState, scenario: ScenarioParams) -> float`
|
| 1055 |
+
- `grade_task_005(state: EpisodeState, scenario: ScenarioParams) -> float`
|
| 1056 |
+
|
| 1057 |
+
Grader scoring per spec Section 11:
|
| 1058 |
+
- Task 1: inspect_gradients(+0.05), correct LR fix(+0.20), restart+converge(+0.35), correct diagnosis(+0.40) = 1.0
|
| 1059 |
+
- Task 3: inspect_data(+0.05), patch_data_loader(+0.30), restart+converge(+0.30), correct diagnosis(+0.35) = 1.0
|
| 1060 |
+
- Task 5: inspect_gradients(+0.05), inspect_model_modes(+0.05), fix_model_mode(+0.25), restart+converge(+0.30), correct diagnosis(+0.40) = 1.05 → capped at 1.0. Penalty: add_callback after normal gradients = -0.20.
|
| 1061 |
+
|
| 1062 |
+
**CRITICAL — Grader is NOT a sum of step rewards.** It evaluates EpisodeState holistically.
|
| 1063 |
+
|
| 1064 |
+
### Tests to Create FIRST (TDD)
|
| 1065 |
+
|
| 1066 |
+
**`tests/test_reward_engine.py`** — THE MOST CRITICAL TEST FILE:
|
| 1067 |
+
|
| 1068 |
+
```python
|
| 1069 |
+
class TestContextGatedPenalty:
|
| 1070 |
+
"""The project's primary innovation — must be exact."""
|
| 1071 |
+
|
| 1072 |
+
def test_no_penalty_before_inspection(self):
|
| 1073 |
+
"""add_callback at step 1 (no prior inspection) -> NO penalty."""
|
| 1074 |
+
state = EpisodeState() # gradients_inspected=False
|
| 1075 |
+
action = MLTrainingAction(action_type="add_callback")
|
| 1076 |
+
reward = compute_reward(action, state, scenario, is_valid_action=True)
|
| 1077 |
+
# Should be just step penalty: -0.01
|
| 1078 |
+
assert reward == pytest.approx(-0.01)
|
| 1079 |
+
|
| 1080 |
+
def test_penalty_after_normal_gradients(self):
|
| 1081 |
+
"""inspect_gradients (normal) then add_callback -> -0.20 penalty."""
|
| 1082 |
+
state = EpisodeState(gradients_inspected=True, gradients_were_normal=True)
|
| 1083 |
+
action = MLTrainingAction(action_type="add_callback")
|
| 1084 |
+
reward = compute_reward(action, state, scenario, is_valid_action=True)
|
| 1085 |
+
# Step penalty + context-gated penalty: -0.01 + -0.20 = -0.21
|
| 1086 |
+
assert reward == pytest.approx(-0.21)
|
| 1087 |
+
|
| 1088 |
+
def test_no_penalty_after_abnormal_gradients(self):
|
| 1089 |
+
"""inspect_gradients (exploding) then add_callback -> no context penalty."""
|
| 1090 |
+
state = EpisodeState(gradients_inspected=True, gradients_were_normal=False)
|
| 1091 |
+
action = MLTrainingAction(action_type="add_callback")
|
| 1092 |
+
reward = compute_reward(action, state, scenario, is_valid_action=True)
|
| 1093 |
+
assert reward == pytest.approx(-0.01)
|
| 1094 |
+
```
|
| 1095 |
+
|
| 1096 |
+
Also test:
|
| 1097 |
+
- Step penalty is flat -0.01 (NOT multiplied by step_count)
|
| 1098 |
+
- Investigation bonus +0.05 first-time only
|
| 1099 |
+
- Investigation bonus NOT awarded on repeat
|
| 1100 |
+
- Correct diagnosis: +0.50
|
| 1101 |
+
- Wrong diagnosis: -0.30
|
| 1102 |
+
- Terminal convergence: +0.40 when all gates met
|
| 1103 |
+
- Invalid action: -0.05
|
| 1104 |
+
- Wrong code fix: -0.10
|
| 1105 |
+
- Reward capped at [-1.0, 1.0]
|
| 1106 |
+
|
| 1107 |
+
**`tests/test_graders.py`**:
|
| 1108 |
+
- Each grader returns float in [0.0, 1.0]
|
| 1109 |
+
- Perfect Task 1 path scores 1.0
|
| 1110 |
+
- Wrong diagnosis on Task 1 scores < 0.5
|
| 1111 |
+
- Task 5: agent that chases red herring scores 0.80-0.85
|
| 1112 |
+
- Task 5: optimal path scores 1.0
|
| 1113 |
+
- Grader is deterministic (same state → same score)
|
| 1114 |
+
|
| 1115 |
+
### Acceptance Criteria — Phase 3
|
| 1116 |
+
|
| 1117 |
+
- [ ] `compute_reward` implements all 7 components exactly per spec Section 12
|
| 1118 |
+
- [ ] Context-gated penalty fires ONLY when `gradients_inspected=True AND gradients_were_normal=True`
|
| 1119 |
+
- [ ] Context-gated penalty does NOT fire before `inspect_gradients` has been called
|
| 1120 |
+
- [ ] Step penalty is flat -0.01 (never multiplied by step_count)
|
| 1121 |
+
- [ ] All 3 graders return [0.0, 1.0] with meaningful variance
|
| 1122 |
+
- [ ] Grader != reward function (separate modules, separate logic)
|
| 1123 |
+
- [ ] All Phase 3 tests pass
|
| 1124 |
+
|
| 1125 |
+
---
|
| 1126 |
+
|
| 1127 |
+
## Phase 4: Environment Lifecycle, EpisodeState, and Action Handling
|
| 1128 |
+
|
| 1129 |
+
### Goal
|
| 1130 |
+
Full `reset()` and `step()` implementations in `environment.py`. The environment is functionally complete.
|
| 1131 |
+
|
| 1132 |
+
### Files to Edit
|
| 1133 |
+
|
| 1134 |
+
**`server/environment.py`** — Full implementation:
|
| 1135 |
+
|
| 1136 |
+
`reset(task_id)`:
|
| 1137 |
+
1. Parse `task_id` from `kwargs` (framework passes it via kwargs or episode_id)
|
| 1138 |
+
2. Derive deterministic seed from task_id
|
| 1139 |
+
3. Call `sample_scenario(task_id, seed)`
|
| 1140 |
+
4. Call `torch.manual_seed(scenario.seed)`
|
| 1141 |
+
5. Call `create_model_and_inject_fault(scenario)` → get real model
|
| 1142 |
+
6. Generate parametric curves via `simulation.py`
|
| 1143 |
+
7. Create fresh `EpisodeState`
|
| 1144 |
+
8. Store `(scenario, model, state)` keyed by session/episode ID
|
| 1145 |
+
9. Return `MLTrainingObservation` with populated loss/acc histories, config, error_log, available_actions — but empty gradient_stats, null data_batch_stats, null model_mode_info, null code_snippet
|
| 1146 |
+
|
| 1147 |
+
`step(action)`:
|
| 1148 |
+
1. Validate action (see spec Section 16 error handling matrix)
|
| 1149 |
+
2. Increment `step_count`
|
| 1150 |
+
3. Dispatch by `action.action_type`:
|
| 1151 |
+
- **`inspect_gradients`**: Extract real gradient stats, set `gradients_inspected=True`, compute `gradients_were_normal` (all layers `is_exploding==False`)
|
| 1152 |
+
- **`inspect_data_batch`**: Generate data batch stats, set `data_inspected=True`
|
| 1153 |
+
- **`inspect_model_modes`**: Extract model modes, set `model_modes_inspected=True`
|
| 1154 |
+
- **`inspect_model_weights`**: Extract real weight stats, set `model_weights_inspected=True`
|
| 1155 |
+
- **`inspect_code`**: Generate code snippet (if task supports it), set `code_inspected=True`
|
| 1156 |
+
- **`modify_config`**: Validate target/value, apply change, set `fix_action_taken=True`
|
| 1157 |
+
- **`add_callback`**: Apply callback, set `fix_action_taken=True`
|
| 1158 |
+
- **`replace_optimizer`**: Apply, set `fix_action_taken=True`
|
| 1159 |
+
- **`patch_data_loader`**: Apply, set `fix_action_taken=True`
|
| 1160 |
+
- **`fix_model_mode`**: Apply, set `fix_action_taken=True`
|
| 1161 |
+
- **`fix_code`**: Validate fix via `validate_fix()`, set `fix_action_taken=True`
|
| 1162 |
+
- **`restart_run`**: Requires `fix_action_taken`, set `restart_after_fix=True`, check convergence
|
| 1163 |
+
- **`mark_diagnosed`**: Set `diagnosis_submitted=True`, `done=True`
|
| 1164 |
+
- **`rollback_checkpoint`**: Requires `restart_after_fix`
|
| 1165 |
+
4. Call `compute_reward(action, state, scenario, ...)`
|
| 1166 |
+
5. Check step limit → set `done=True` if reached
|
| 1167 |
+
6. Update `available_actions` via `state.compute_available_actions()`
|
| 1168 |
+
7. Return `MLTrainingObservation` with all updated fields
|
| 1169 |
+
|
| 1170 |
+
**Session isolation**:
|
| 1171 |
+
- Store per-session state in `self._sessions: dict[str, SessionData]`
|
| 1172 |
+
- Session ID comes from the framework (via `episode_id` or WebSocket session)
|
| 1173 |
+
- Clean up on episode completion or disconnect
|
| 1174 |
+
|
| 1175 |
+
### Error Handling (spec Section 16 — ALL cases):
|
| 1176 |
+
|
| 1177 |
+
| Error | Behavior | Reward |
|
| 1178 |
+
|-------|----------|--------|
|
| 1179 |
+
| Invalid action_type | Return obs unchanged + error note | -0.05 |
|
| 1180 |
+
| Action not in available_actions | Return obs unchanged + error note | -0.05 |
|
| 1181 |
+
| modify_config missing target/value | Return obs unchanged + error note | -0.05 |
|
| 1182 |
+
| modify_config with unknown target | Return obs unchanged + error note | -0.05 |
|
| 1183 |
+
| mark_diagnosed missing diagnosis | Return obs unchanged + error note | -0.05 |
|
| 1184 |
+
| mark_diagnosed with invalid diagnosis | Return obs unchanged + error note | -0.05 |
|
| 1185 |
+
| fix_code missing line/replacement | Return obs unchanged + error note | -0.05 |
|
| 1186 |
+
| Action after done=True | Return final obs, no state change | 0.0 |
|
| 1187 |
+
| Step limit reached | Set done=True, return obs | 0.0 |
|
| 1188 |
+
|
| 1189 |
+
**CRITICAL**: `step()` must NEVER raise an unhandled exception.
|
| 1190 |
+
|
| 1191 |
+
### Tests to Create FIRST (TDD)
|
| 1192 |
+
|
| 1193 |
+
**`tests/test_episode_lifecycle.py`**:
|
| 1194 |
+
- Full reset→inspect→fix→restart→diagnose flow for Task 1
|
| 1195 |
+
- Full flow for Task 3
|
| 1196 |
+
- Full flow for Task 5
|
| 1197 |
+
- `available_actions` updates correctly at each step
|
| 1198 |
+
- `done=True` after `mark_diagnosed`
|
| 1199 |
+
- Step limit triggers `done=True`
|
| 1200 |
+
- Action after done returns final obs with no state change
|
| 1201 |
+
- Invalid action returns -0.05 penalty
|
| 1202 |
+
- `restart_run` not available before `fix_action_taken`
|
| 1203 |
+
- `fix_code` not available before `code_inspected`
|
| 1204 |
+
- Session isolation: two episodes don't interfere
|
| 1205 |
+
|
| 1206 |
+
### Acceptance Criteria — Phase 4
|
| 1207 |
+
|
| 1208 |
+
- [ ] `reset(task_id)` for tasks 001/003/005 returns valid `MLTrainingObservation` with correct initial state
|
| 1209 |
+
- [ ] `step()` dispatches all 14 action types correctly
|
| 1210 |
+
- [ ] Task 1: `inspect_gradients` → `is_exploding=True` all layers (real torch.autograd)
|
| 1211 |
+
- [ ] Task 5: `inspect_gradients` → `is_exploding=False` all layers, `gradients_were_normal=True`
|
| 1212 |
+
- [ ] Task 3: `inspect_data_batch` → `class_overlap_score > 0.5`
|
| 1213 |
+
- [ ] Task 5: `inspect_model_modes` → all layers in "eval" mode
|
| 1214 |
+
- [ ] All error conditions from spec Section 16 handled (never raises)
|
| 1215 |
+
- [ ] Progressive information reveal works (gradient_stats empty until inspected)
|
| 1216 |
+
- [ ] All Phase 4 tests pass
|
| 1217 |
+
|
| 1218 |
+
---
|
| 1219 |
+
|
| 1220 |
+
## Phase 5: Server (FastAPI + openenv-core) + All Required Endpoints
|
| 1221 |
+
|
| 1222 |
+
### Goal
|
| 1223 |
+
Wire the real environment into the server. All hackathon-required endpoints return real data.
|
| 1224 |
+
|
| 1225 |
+
### Files to Edit
|
| 1226 |
+
|
| 1227 |
+
**`server/app.py`** — Full implementation:
|
| 1228 |
+
|
| 1229 |
+
```python
|
| 1230 |
+
# Store reference to last completed episode for /grader
|
| 1231 |
+
_last_completed: dict[str, dict] = {} # session_id -> {score, task_id, steps}
|
| 1232 |
+
_baseline_running: bool = False
|
| 1233 |
+
|
| 1234 |
+
@app.get("/health")
|
| 1235 |
+
def health_check():
|
| 1236 |
+
return {"status": "ready", "tasks": 3}
|
| 1237 |
+
|
| 1238 |
+
@app.get("/tasks")
|
| 1239 |
+
def get_tasks():
|
| 1240 |
+
schema = MLTrainingAction.model_json_schema()
|
| 1241 |
+
return [
|
| 1242 |
+
{"id": "task_001", "difficulty": "easy", "max_steps": 20, "action_schema": schema},
|
| 1243 |
+
{"id": "task_003", "difficulty": "medium", "max_steps": 25, "action_schema": schema},
|
| 1244 |
+
{"id": "task_005", "difficulty": "hard", "max_steps": 30, "action_schema": schema},
|
| 1245 |
+
]
|
| 1246 |
+
|
| 1247 |
+
@app.post("/grader")
|
| 1248 |
+
def post_grader(session_id: str | None = None):
|
| 1249 |
+
# Return score for most recently completed episode
|
| 1250 |
+
# Edge cases per spec Section 14
|
| 1251 |
+
|
| 1252 |
+
@app.post("/baseline")
|
| 1253 |
+
async def post_baseline():
|
| 1254 |
+
# Run baseline_heuristic logic internally
|
| 1255 |
+
# Return {"scores": {"task_001": float, ...}}
|
| 1256 |
+
# Return 409 if already running
|
| 1257 |
+
```
|
| 1258 |
+
|
| 1259 |
+
**Grader endpoint edge cases** (spec Section 14):
|
| 1260 |
+
- No episode completed → `{"score": null, "error": "no_completed_episode"}`
|
| 1261 |
+
- Episode in progress → `{"score": null, "error": "episode_in_progress"}`
|
| 1262 |
+
- Episode completed → `{"score": 0.85, "task_id": "task_003", "steps": 6}`
|
| 1263 |
+
- Always HTTP 200 with JSON body
|
| 1264 |
+
|
| 1265 |
+
### Tests to Create FIRST (TDD)
|
| 1266 |
+
|
| 1267 |
+
**`tests/test_endpoints.py`**:
|
| 1268 |
+
- `GET /health` returns `{"status": "ready", "tasks": 3}` with 200
|
| 1269 |
+
- `GET /tasks` returns 3 tasks with action schema
|
| 1270 |
+
- `POST /grader` returns `{"score": null, "error": "no_completed_episode"}` initially
|
| 1271 |
+
- `POST /baseline` returns scores for all tasks
|
| 1272 |
+
- `POST /baseline` while running returns 409
|
| 1273 |
+
- Integration: reset→step→grader returns valid score
|
| 1274 |
+
|
| 1275 |
+
### Acceptance Criteria — Phase 5
|
| 1276 |
+
|
| 1277 |
+
- [ ] `GET /health` returns `{"status": "ready", "tasks": 3}` (200)
|
| 1278 |
+
- [ ] `GET /tasks` returns 3 tasks with IDs, difficulties, action schema
|
| 1279 |
+
- [ ] `POST /grader` handles all edge cases per spec Section 14
|
| 1280 |
+
- [ ] `POST /baseline` runs baseline and returns scores
|
| 1281 |
+
- [ ] Framework auto-provides: `/reset`, `/step`, `/state`, `/ws`, `/schema`, `/docs`
|
| 1282 |
+
- [ ] All Phase 5 tests pass
|
| 1283 |
+
|
| 1284 |
+
---
|
| 1285 |
+
|
| 1286 |
+
## Phase 6: Rule-Based Baseline + Reproducibility Guarantees
|
| 1287 |
+
|
| 1288 |
+
### Goal
|
| 1289 |
+
Deterministic baseline that produces bit-exact identical scores on two runs.
|
| 1290 |
+
|
| 1291 |
+
### Files to Create
|
| 1292 |
+
|
| 1293 |
+
**`baseline_heuristic.py`** (~150 lines):
|
| 1294 |
+
|
| 1295 |
+
Decision tree from spec Section 17:
|
| 1296 |
+
```
|
| 1297 |
+
1. reset(task_id)
|
| 1298 |
+
2. inspect_gradients
|
| 1299 |
+
3. IF any layer is_exploding → modify_config(lr=0.001) → restart → diagnose lr_too_high
|
| 1300 |
+
4. IF any layer is_vanishing → modify_config(lr=0.01) → restart → diagnose vanishing_gradients
|
| 1301 |
+
5. inspect_data_batch
|
| 1302 |
+
6. IF class_overlap_score > 0.5 → patch_data_loader → restart → diagnose data_leakage
|
| 1303 |
+
7. IF val_loss diverging → modify_config(weight_decay=0.01) → restart → diagnose overfitting
|
| 1304 |
+
8. inspect_model_modes → IF any eval → fix_model_mode → restart → diagnose batchnorm_eval_mode
|
| 1305 |
+
9. inspect_code → attempt fix → restart → diagnose code_bug
|
| 1306 |
+
10. FALLBACK: diagnose overfitting
|
| 1307 |
+
```
|
| 1308 |
+
|
| 1309 |
+
Uses `MLTrainingEnvClient` or `GenericEnvClient` to connect via WebSocket.
|
| 1310 |
+
|
| 1311 |
+
**Reproducibility requirements:**
|
| 1312 |
+
- `torch.manual_seed(seed)` at every `reset()` with deterministic seed per task
|
| 1313 |
+
- No floating-point non-determinism in parametric curves
|
| 1314 |
+
- Heuristic is pure logic with no randomness
|
| 1315 |
+
- Two runs must produce identical JSON output
|
| 1316 |
+
|
| 1317 |
+
### Tests to Create FIRST (TDD)
|
| 1318 |
+
|
| 1319 |
+
**`tests/test_baseline_reproducibility.py`**:
|
| 1320 |
+
- Run baseline twice → `diff run1.json run2.json` is empty
|
| 1321 |
+
- All scores in [0.0, 1.0]
|
| 1322 |
+
- Expected approximate scores: task_001 ~0.85, task_003 ~0.70, task_005 ~0.45
|
| 1323 |
+
|
| 1324 |
+
### Acceptance Criteria — Phase 6
|
| 1325 |
+
|
| 1326 |
+
- [ ] `baseline_heuristic.py` runs all 3 MVP tasks without error
|
| 1327 |
+
- [ ] Two consecutive runs produce bit-exact identical JSON output
|
| 1328 |
+
- [ ] No API key required
|
| 1329 |
+
- [ ] All scores in [0.0, 1.0] with meaningful variance
|
| 1330 |
+
- [ ] Decision tree follows spec Section 17 exactly
|
| 1331 |
+
|
| 1332 |
+
---
|
| 1333 |
+
|
| 1334 |
+
## Phase 7: Docker, HF Spaces, Logging, Error Handling & Edge Cases
|
| 1335 |
+
|
| 1336 |
+
### Goal
|
| 1337 |
+
Production-ready container that deploys cleanly.
|
| 1338 |
+
|
| 1339 |
+
### Files to Edit
|
| 1340 |
+
|
| 1341 |
+
**`Dockerfile`** — Finalize:
|
| 1342 |
+
- Base: `python:3.12-slim`
|
| 1343 |
+
- PyTorch CPU-only: `pip install torch --index-url https://download.pytorch.org/whl/cpu`
|
| 1344 |
+
- Target: <500MB
|
| 1345 |
+
- `EXPOSE 7860`
|
| 1346 |
+
- `CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]`
|
| 1347 |
+
|
| 1348 |
+
**Note on Dockerfile COPY**: Cannot use `COPY ... 2>/dev/null || true` in Dockerfile. Instead, ensure all files exist or use multi-stage approach.
|
| 1349 |
+
|
| 1350 |
+
**Logging** — Add to `server/app.py` and `server/environment.py`:
|
| 1351 |
+
- JSON structured logging to stdout
|
| 1352 |
+
- Log every `reset()`, `step()`, episode completion, errors
|
| 1353 |
+
|
| 1354 |
+
**WebSocket edge cases** (spec Section 16):
|
| 1355 |
+
- Client disconnects mid-episode → retain state 60s
|
| 1356 |
+
- Malformed JSON → return error, keep connection
|
| 1357 |
+
- step() before reset() → return "no_active_episode" error
|
| 1358 |
+
- reset() during active episode → terminate current, start new
|
| 1359 |
+
|
| 1360 |
+
### Acceptance Criteria — Phase 7
|
| 1361 |
+
|
| 1362 |
+
- [ ] `docker build -t pytorch-debugger .` succeeds
|
| 1363 |
+
- [ ] Docker image <500MB
|
| 1364 |
+
- [ ] `docker run -p 7860:7860 pytorch-debugger` starts and serves in <60s
|
| 1365 |
+
- [ ] `curl http://localhost:7860/health` returns `{"status": "ready", "tasks": 3}`
|
| 1366 |
+
- [ ] All WebSocket edge cases handled per spec Section 16
|
| 1367 |
+
- [ ] Structured JSON logging on all significant events
|
| 1368 |
+
|
| 1369 |
+
---
|
| 1370 |
+
|
| 1371 |
+
## Phase 8: Full Testing Suite + Pre-Submission Smoke Tests
|
| 1372 |
+
|
| 1373 |
+
### Goal
|
| 1374 |
+
>80% test coverage, all edge cases covered.
|
| 1375 |
+
|
| 1376 |
+
### Files to Create/Extend
|
| 1377 |
+
|
| 1378 |
+
All test files listed above, plus:
|
| 1379 |
+
- Fill coverage gaps identified by `pytest --cov`
|
| 1380 |
+
- Add edge case tests for every error in spec Section 16
|
| 1381 |
+
- Add test for `step()` after `done=True`
|
| 1382 |
+
- Add test for step limit termination
|
| 1383 |
+
|
| 1384 |
+
### Commands
|
| 1385 |
+
|
| 1386 |
+
```bash
|
| 1387 |
+
pytest tests/ -v --cov=ml_training_debugger --cov=server --cov-report=term-missing
|
| 1388 |
+
```
|
| 1389 |
+
|
| 1390 |
+
### Acceptance Criteria — Phase 8
|
| 1391 |
+
|
| 1392 |
+
- [ ] `pytest --cov` shows >80% coverage on all modules
|
| 1393 |
+
- [ ] Every error condition from spec Section 16 has a test
|
| 1394 |
+
- [ ] Context-gated penalty tests pass (both paths)
|
| 1395 |
+
- [ ] Dynamic available_actions tests pass
|
| 1396 |
+
- [ ] All 3 graders tested with multiple scenarios
|
| 1397 |
+
- [ ] Zero test failures
|
| 1398 |
+
|
| 1399 |
+
---
|
| 1400 |
+
|
| 1401 |
+
## Phase 9: Final Polish & Submission Readiness
|
| 1402 |
+
|
| 1403 |
+
### Goal
|
| 1404 |
+
README complete, all endpoints verified, `openenv validate` passes, deploy to HF Spaces.
|
| 1405 |
+
|
| 1406 |
+
### Files to Create
|
| 1407 |
+
|
| 1408 |
+
**`README.md`** (~200 lines):
|
| 1409 |
+
- Environment description and motivation
|
| 1410 |
+
- Action/observation space definitions
|
| 1411 |
+
- Task descriptions with difficulty
|
| 1412 |
+
- Setup instructions
|
| 1413 |
+
- Baseline scores table
|
| 1414 |
+
|
| 1415 |
+
**`deploy.sh`**:
|
| 1416 |
+
```bash
|
| 1417 |
+
#!/bin/bash
|
| 1418 |
+
set -euo pipefail
|
| 1419 |
+
|
| 1420 |
+
echo "=== Building Docker image ==="
|
| 1421 |
+
docker build -t pytorch-debugger .
|
| 1422 |
+
|
| 1423 |
+
echo "=== Starting container ==="
|
| 1424 |
+
docker run -d -p 7860:7860 --name smoke-test pytorch-debugger
|
| 1425 |
+
sleep 10
|
| 1426 |
+
|
| 1427 |
+
echo "=== Health check ==="
|
| 1428 |
+
curl -f http://localhost:7860/health || { echo "FAIL: health"; exit 1; }
|
| 1429 |
+
|
| 1430 |
+
echo "=== Tasks endpoint ==="
|
| 1431 |
+
curl -f http://localhost:7860/tasks | python3 -m json.tool || { echo "FAIL: tasks"; exit 1; }
|
| 1432 |
+
|
| 1433 |
+
echo "=== Baseline reproducibility ==="
|
| 1434 |
+
python3 baseline_heuristic.py > run1.json 2>/dev/null
|
| 1435 |
+
python3 baseline_heuristic.py > run2.json 2>/dev/null
|
| 1436 |
+
diff run1.json run2.json && echo "PASS: reproducible" || { echo "FAIL: non-reproducible"; exit 1; }
|
| 1437 |
+
|
| 1438 |
+
echo "=== Baseline via endpoint ==="
|
| 1439 |
+
curl -f -X POST http://localhost:7860/baseline | python3 -m json.tool || { echo "FAIL: baseline endpoint"; exit 1; }
|
| 1440 |
+
|
| 1441 |
+
echo "=== Grader via endpoint ==="
|
| 1442 |
+
curl -f -X POST http://localhost:7860/grader | python3 -m json.tool || { echo "FAIL: grader endpoint"; exit 1; }
|
| 1443 |
+
|
| 1444 |
+
echo "=== Tests ==="
|
| 1445 |
+
pytest tests/ -v --cov=ml_training_debugger --cov-report=term-missing
|
| 1446 |
+
|
| 1447 |
+
echo "=== Cleanup ==="
|
| 1448 |
+
docker stop smoke-test && docker rm smoke-test
|
| 1449 |
+
rm -f run1.json run2.json
|
| 1450 |
+
|
| 1451 |
+
echo "=== ALL CHECKS PASSED ==="
|
| 1452 |
+
```
|
| 1453 |
+
|
| 1454 |
+
### Acceptance Criteria — Phase 9
|
| 1455 |
+
|
| 1456 |
+
- [ ] `openenv validate` passes
|
| 1457 |
+
- [ ] `deploy.sh` runs end-to-end with zero failures
|
| 1458 |
+
- [ ] README is complete per hackathon requirements
|
| 1459 |
+
- [ ] Docker image <500MB, starts <60s
|
| 1460 |
+
- [ ] Baseline bit-exact reproducible
|
| 1461 |
+
- [ ] 3+ tasks with graders returning [0.0, 1.0] with meaningful variance
|
| 1462 |
+
- [ ] HF Space deployed, tagged `openenv`, responds to `reset()`
|
| 1463 |
+
- [ ] All typed Pydantic models — no `Dict[str, Any]`
|
| 1464 |
+
- [ ] `import torch` in every core module — zero numpy in core
|
| 1465 |
+
- [ ] Context-gated penalty fires correctly and does not fire prematurely
|
| 1466 |
+
- [ ] Test suite passes with >80% coverage
|
| 1467 |
+
|
| 1468 |
+
---
|
| 1469 |
+
|
| 1470 |
+
## Technical Risk Mitigations
|
| 1471 |
+
|
| 1472 |
+
| Risk | Impact | Mitigation |
|
| 1473 |
+
|------|--------|------------|
|
| 1474 |
+
| **WebSocket + HTTP composition** | ~~High~~ RESOLVED | `create_app()` returns standard FastAPI. Custom routes add cleanly. Verified in Phase 0. |
|
| 1475 |
+
| **Docker image size** | Medium | `python:3.12-slim` + torch CPU-only (~150MB). Target <500MB. Test early in Phase 7. |
|
| 1476 |
+
| **Task 6 fix validation fragility** | Medium | Multi-strategy pipeline: normalize → tokenize → semantic patterns → AST fallback. Test 5+ whitespace variations. (Post-MVP Phase 2 stretch) |
|
| 1477 |
+
| **Red-herring penalty gating** | HIGH | `gradients_were_normal` set inside `inspect_gradients` handler when ALL layers have `is_exploding=False`. Threshold: `mean_norm > 10.0`. Test BOTH paths explicitly. |
|
| 1478 |
+
| **Session isolation** | Medium | `dict[str, SessionData]` keyed by session ID. Framework provides session management. |
|
| 1479 |
+
| **Baseline reproducibility** | HIGH | `torch.manual_seed(seed)` at every `reset()`. Seed derived deterministically from task_id. Heuristic is pure logic. Test with `diff run1.json run2.json`. |
|
| 1480 |
+
| **Dockerfile build time** | Low | No real training during build. Validation reports pre-computed locally. |
|
| 1481 |
+
| **openenv.yaml format** | Medium | Template uses `spec_version: 1`, `type: space`, `runtime: fastapi`, `app: server.app:app`. Extended fields (tasks, reward, etc.) are additive. Test with `openenv validate` early. |
|
| 1482 |
+
| **Port mismatch** | Low | Spec says 7860 (HF Spaces default). openenv template says 8000. Use 7860 everywhere. |
|
| 1483 |
+
|
| 1484 |
+
---
|
| 1485 |
+
|
| 1486 |
+
## Exact openenv.yaml (Final)
|
| 1487 |
+
|
| 1488 |
+
```yaml
|
| 1489 |
+
spec_version: 1
|
| 1490 |
+
name: pytorch-training-debugger
|
| 1491 |
+
type: space
|
| 1492 |
+
runtime: fastapi
|
| 1493 |
+
app: server.app:app
|
| 1494 |
+
port: 7860
|
| 1495 |
+
|
| 1496 |
+
version: "1.0.0"
|
| 1497 |
+
description: |
|
| 1498 |
+
PyTorch-native fault injection engine for training failure debugging.
|
| 1499 |
+
An AI agent investigates, diagnoses, fixes, and verifies broken
|
| 1500 |
+
training runs using real torch.nn.Module models, torch.autograd
|
| 1501 |
+
gradients, state_dict() weight inspection, and PyTorch code-level
|
| 1502 |
+
debugging. 3 tasks across 3 difficulty tiers with context-gated
|
| 1503 |
+
reward shaping.
|
| 1504 |
+
framework: openenv
|
| 1505 |
+
tags: [ml-debugging, pytorch, reinforcement-learning, root-cause-analysis, fault-injection, openenv]
|
| 1506 |
+
|
| 1507 |
+
observation_space:
|
| 1508 |
+
type: MLTrainingObservation
|
| 1509 |
+
description: "Training run snapshot with progressive reveal — gradients, weights, data stats, model modes revealed on inspection"
|
| 1510 |
+
|
| 1511 |
+
action_space:
|
| 1512 |
+
type: MLTrainingAction
|
| 1513 |
+
description: "Investigation, fix, and diagnosis actions with dynamic availability"
|
| 1514 |
+
|
| 1515 |
+
tasks:
|
| 1516 |
+
- id: task_001
|
| 1517 |
+
difficulty: easy
|
| 1518 |
+
max_steps: 20
|
| 1519 |
+
- id: task_003
|
| 1520 |
+
difficulty: medium
|
| 1521 |
+
max_steps: 25
|
| 1522 |
+
- id: task_005
|
| 1523 |
+
difficulty: hard
|
| 1524 |
+
max_steps: 30
|
| 1525 |
+
|
| 1526 |
+
reward:
|
| 1527 |
+
range: [-1.0, 1.0]
|
| 1528 |
+
shaped: true
|
| 1529 |
+
step_penalty: -0.01
|
| 1530 |
+
investigation_bonus: 0.05
|
| 1531 |
+
max_investigation_bonus: 0.25
|
| 1532 |
+
correct_diagnosis: 0.50
|
| 1533 |
+
terminal_convergence: 0.40
|
| 1534 |
+
|
| 1535 |
+
endpoints:
|
| 1536 |
+
websocket: "/ws"
|
| 1537 |
+
tasks: "GET /tasks"
|
| 1538 |
+
grader: "POST /grader"
|
| 1539 |
+
baseline: "POST /baseline"
|
| 1540 |
+
health: "GET /health"
|
| 1541 |
+
```
|
| 1542 |
+
|
| 1543 |
+
---
|
| 1544 |
+
|
| 1545 |
+
## Exact Dockerfile (Final)
|
| 1546 |
+
|
| 1547 |
+
```dockerfile
|
| 1548 |
+
FROM python:3.12-slim
|
| 1549 |
+
|
| 1550 |
+
WORKDIR /app
|
| 1551 |
+
|
| 1552 |
+
# Install PyTorch CPU-only first (largest layer, cached)
|
| 1553 |
+
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
|
| 1554 |
+
|
| 1555 |
+
# Install remaining dependencies
|
| 1556 |
+
COPY requirements.txt .
|
| 1557 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 1558 |
+
|
| 1559 |
+
# Copy application code
|
| 1560 |
+
COPY ml_training_debugger/ ml_training_debugger/
|
| 1561 |
+
COPY server/ server/
|
| 1562 |
+
COPY openenv.yaml .
|
| 1563 |
+
COPY baseline_heuristic.py .
|
| 1564 |
+
COPY README.md .
|
| 1565 |
+
|
| 1566 |
+
EXPOSE 7860
|
| 1567 |
+
|
| 1568 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
|
| 1569 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 1570 |
+
|
| 1571 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
| 1572 |
+
```
|
| 1573 |
+
|
| 1574 |
+
---
|
| 1575 |
+
|
| 1576 |
+
## Pre-Submission Smoke Test Sequence
|
| 1577 |
+
|
| 1578 |
+
```bash
|
| 1579 |
+
# 1. Clean build
|
| 1580 |
+
docker build --no-cache -t pytorch-debugger .
|
| 1581 |
+
|
| 1582 |
+
# 2. Start container
|
| 1583 |
+
docker run -d -p 7860:7860 --name smoke-test pytorch-debugger
|
| 1584 |
+
sleep 10
|
| 1585 |
+
|
| 1586 |
+
# 3. Health check
|
| 1587 |
+
curl -f http://localhost:7860/health
|
| 1588 |
+
|
| 1589 |
+
# 4. Tasks endpoint
|
| 1590 |
+
curl -f http://localhost:7860/tasks | python3 -m json.tool
|
| 1591 |
+
|
| 1592 |
+
# 5. Baseline reproducibility
|
| 1593 |
+
python3 baseline_heuristic.py > run1.json 2>/dev/null
|
| 1594 |
+
python3 baseline_heuristic.py > run2.json 2>/dev/null
|
| 1595 |
+
diff run1.json run2.json && echo "PASS: reproducible" || echo "FAIL"
|
| 1596 |
+
|
| 1597 |
+
# 6. Baseline via endpoint
|
| 1598 |
+
curl -f -X POST http://localhost:7860/baseline | python3 -m json.tool
|
| 1599 |
+
|
| 1600 |
+
# 7. Grader via endpoint
|
| 1601 |
+
curl -f -X POST http://localhost:7860/grader | python3 -m json.tool
|
| 1602 |
+
|
| 1603 |
+
# 8. OpenEnv validation
|
| 1604 |
+
openenv validate
|
| 1605 |
+
|
| 1606 |
+
# 9. Test suite
|
| 1607 |
+
pytest tests/ -v --cov=ml_training_debugger --cov-report=term-missing
|
| 1608 |
+
|
| 1609 |
+
# 10. Cleanup
|
| 1610 |
+
docker stop smoke-test && docker rm smoke-test
|
| 1611 |
+
rm -f run1.json run2.json
|
| 1612 |
+
|
| 1613 |
+
echo "=== All checks passed ==="
|
| 1614 |
+
```
|
| 1615 |
+
|
| 1616 |
+
---
|
| 1617 |
+
|
| 1618 |
+
## Post-MVP Stretch (Phase 2 from ROADMAP)
|
| 1619 |
+
|
| 1620 |
+
**Only after MVP is 100% deployed and passing all auto-validation:**
|
| 1621 |
+
|
| 1622 |
+
1. **Task 6** (code debugging) — highest impact differentiator
|
| 1623 |
+
- Create `ml_training_debugger/code_templates.py`
|
| 1624 |
+
- 4 bug variants: eval_mode, detach_loss, zero_grad_missing, inplace_relu
|
| 1625 |
+
- Multi-strategy fix validation: normalize → tokenize → semantic → AST
|
| 1626 |
+
- Diagnosis is ALWAYS `code_bug` regardless of variant
|
| 1627 |
+
|
| 1628 |
+
2. **Tasks 2 & 4** — fill out to 6 tasks
|
| 1629 |
+
- Task 2: vanishing gradients (easy, mirror of Task 1)
|
| 1630 |
+
- Task 4: overfitting (medium, train-val divergence)
|
| 1631 |
+
|
| 1632 |
+
3. **Dashboard** — `server/dashboard.html`, Plotly.js via CDN
|
| 1633 |
+
|
| 1634 |
+
4. **Validation Suite** — `validation/*.py`, R² > 0.85
|
| 1635 |
+
|
| 1636 |
+
5. **LLM Baseline** — `baseline_inference.py`, GPT-4o
|
| 1637 |
+
|
| 1638 |
+
Update `openenv.yaml`, `/tasks`, `/health` task count as tasks are added.
|
| 1639 |
+
|
| 1640 |
+
---
|
| 1641 |
+
|
| 1642 |
+
## SESSION_ID
|
| 1643 |
+
|
| 1644 |
+
- CODEX_SESSION: N/A (codeagent-wrapper not available)
|
| 1645 |
+
- GEMINI_SESSION: N/A (codeagent-wrapper not available)
|
| 1646 |
+
|
| 1647 |
+
Plan generated by Claude Opus 4.6 via deep analysis of all 4 project markdown files + openenv-core framework API inspection.
|
.coverage
ADDED
|
Binary file (53.2 kB). View file
|
|
|
.dockerignore
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
__pycache__/
|
| 3 |
+
.git/
|
| 4 |
+
.pytest_cache/
|
| 5 |
+
tests/
|
| 6 |
+
validation/
|
| 7 |
+
*.md
|
| 8 |
+
!README.md
|
| 9 |
+
.claude/
|
| 10 |
+
run*.json
|
| 11 |
+
htmlcov/
|
| 12 |
+
.mypy_cache/
|
| 13 |
+
.ruff_cache/
|
.gitignore
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.venv/
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.pyc
|
| 4 |
+
*.pyo
|
| 5 |
+
.env
|
| 6 |
+
run*.json
|
| 7 |
+
.pytest_cache/
|
| 8 |
+
htmlcov/
|
| 9 |
+
*.egg-info/
|
| 10 |
+
dist/
|
| 11 |
+
build/
|
| 12 |
+
validation/reports/*.png
|
| 13 |
+
.mypy_cache/
|
| 14 |
+
.ruff_cache/
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.12
|
CLAUDE.md
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# CLAUDE.md — PyTorch Training Run Debugger
|
| 2 |
+
|
| 3 |
+
OpenEnv RL environment for the Meta PyTorch OpenEnv Hackathon x Scaler School of Technology.
|
| 4 |
+
An AI agent debugs broken PyTorch training runs by investigating gradients, weights, data, model modes, and source code to diagnose and fix real ML failure patterns.
|
| 5 |
+
|
| 6 |
+
**Spec:** `ml-training-debugger-spec.md` is the single source of truth. If this file and the spec conflict, the spec wins.
|
| 7 |
+
|
| 8 |
+
**Runtime:** Python 3.12 · PyTorch CPU-only · openenv-core v0.2.2
|
| 9 |
+
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
## Non-Negotiable Rules
|
| 13 |
+
|
| 14 |
+
### MVP-First Execution
|
| 15 |
+
Ship Tasks 1, 3, 5 (easy/medium/hard) + rule-based baseline + Docker + HF deploy **before** touching anything else. A deployed MVP that passes auto-validation beats a half-finished 6-task environment. Priority order after MVP: Task 6 > Tasks 2 & 4 > dashboard > validation suite > LLM baseline.
|
| 16 |
+
|
| 17 |
+
### Context-Gated Penalty Must Be Exact
|
| 18 |
+
The -0.20 penalty for `add_callback` fires **only when both** `gradients_inspected == True` AND `gradients_were_normal == True`. It must **never** fire before `inspect_gradients` has been called. This is the project's primary innovation. Get the gate conditions wrong and the differentiator is broken. Test both paths:
|
| 19 |
+
- `add_callback` at step 1 (no prior inspection) -> **no penalty**
|
| 20 |
+
- `inspect_gradients` (normal) then `add_callback` -> **-0.20 penalty**
|
| 21 |
+
|
| 22 |
+
### Task 6 Diagnosis Is Always `code_bug`
|
| 23 |
+
Regardless of the specific bug variant (`eval_mode`, `detach_loss`, `zero_grad_missing`, `inplace_relu`), Task 6's correct diagnosis is **always** `code_bug`. Submitting `batchnorm_eval_mode` on Task 6's `eval_mode` variant is a wrong diagnosis (-0.30). The grader enforces this with a strict equality check.
|
| 24 |
+
|
| 25 |
+
### PyTorch-Native Only — No NumPy
|
| 26 |
+
Every computation in core modules uses `torch.Tensor`, not `numpy.ndarray`. `import torch` must appear in `models.py`, `simulation.py`, `pytorch_engine.py`, `reward_engine.py`, and `graders.py`. This is a Meta PyTorch hackathon — judges will notice. The only exception is test utilities and the validation suite where `scipy`/`matplotlib` are acceptable.
|
| 27 |
+
|
| 28 |
+
### Grader != Reward Function
|
| 29 |
+
These are separate modules with separate purposes. The **reward function** (`reward_engine.py`) returns a float per step for RL training signal. The **grader** (`graders.py`) returns a normalized 0.0-1.0 score at episode end for the `/grader` endpoint and auto-validation. The grader evaluates `EpisodeState` holistically — it is **not** a sum of step rewards. Never conflate them.
|
| 30 |
+
|
| 31 |
+
### Opaque Task IDs
|
| 32 |
+
Task IDs are `task_001` through `task_006`. The agent must never be able to infer the diagnosis from the task ID. Do not use descriptive names anywhere the agent can observe them.
|
| 33 |
+
|
| 34 |
+
---
|
| 35 |
+
|
| 36 |
+
## Architecture Constraints
|
| 37 |
+
|
| 38 |
+
### Framework Integration (Verified)
|
| 39 |
+
```
|
| 40 |
+
openenv-core v0.2.2 → create_app() → returns standard FastAPI instance
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
- `MLTrainingAction` extends `Action` from `openenv.core.env_server.types`
|
| 44 |
+
- `MLTrainingObservation` extends `Observation` from `openenv.core.env_server.types` (has built-in `done`, `reward`, `metadata`)
|
| 45 |
+
- `MLTrainingEnvironment` extends `Environment` from `openenv.core.env_server.interfaces` (must implement `reset()`, `step()`, `state` property)
|
| 46 |
+
- `MLTrainingEnvClient` in `client.py` extends `EnvClient` with typed `action_type` and `observation_type` — used by baseline scripts
|
| 47 |
+
- `create_app()` takes the **class** (factory), not an instance
|
| 48 |
+
- Custom routes (`/tasks`, `/grader`, `/baseline`, `/health`) are added directly to the returned FastAPI app via `@app.get()`/`@app.post()` decorators
|
| 49 |
+
- Framework auto-provides: `POST /reset`, `POST /step`, `GET /state`, `WS /ws`, `GET /schema`, `GET /docs`, `/mcp`
|
| 50 |
+
|
| 51 |
+
### Key Constraints (see spec for full detail)
|
| 52 |
+
- **Real PyTorch models:** `pytorch_engine.py` instantiates `SimpleCNN` (~50K params) at every `reset()`, runs 1-2 real forward+backward passes. Gradient and weight stats come from real `torch.autograd` and `model.state_dict()`.
|
| 53 |
+
- **Typed Pydantic models everywhere:** No `Dict[str, Any]`. `available_actions` is dynamically computed from `EpisodeState`, never hardcoded.
|
| 54 |
+
- **Session isolation:** Each WebSocket client gets its own `EpisodeState` keyed by session ID. `SUPPORTS_CONCURRENT_SESSIONS = True`.
|
| 55 |
+
|
| 56 |
+
---
|
| 57 |
+
|
| 58 |
+
## Coding Standards
|
| 59 |
+
|
| 60 |
+
### Formatting & Linting
|
| 61 |
+
- **black** for formatting (line length 88)
|
| 62 |
+
- **ruff** for linting
|
| 63 |
+
- **isort** for import ordering (profile=black)
|
| 64 |
+
- Run all three before every commit
|
| 65 |
+
|
| 66 |
+
### Type Hints
|
| 67 |
+
Type annotations on **every** function signature and return type. No `Any` in public APIs. Use `Optional[X]` for nullable fields, `Literal[...]` for closed string unions, `list[X]` (lowercase) for Python 3.12+.
|
| 68 |
+
|
| 69 |
+
### Testing
|
| 70 |
+
- **pytest** for all tests
|
| 71 |
+
- Every module in `ml_training_debugger/` has a corresponding `tests/test_*.py`
|
| 72 |
+
- Minimum test coverage: 80%
|
| 73 |
+
- Critical tests that must exist:
|
| 74 |
+
- `test_reward_engine.py`: context-gated penalty fires/doesn't fire under correct conditions
|
| 75 |
+
- `test_graders.py`: each grader returns 0.0-1.0, correct diagnosis scores high, wrong diagnosis scores low
|
| 76 |
+
- `test_pytorch_engine.py`: model instantiation, fault injection, gradient/weight extraction produces real tensors
|
| 77 |
+
- `test_code_templates.py`: all 4 bug variants generate valid code, fix validation accepts correct fixes and rejects wrong ones (including whitespace/comment variations)
|
| 78 |
+
- `test_episode_lifecycle.py`: full episode flow reset->inspect->fix->restart->diagnose produces expected state transitions
|
| 79 |
+
|
| 80 |
+
### File Size Limits
|
| 81 |
+
- 400 lines typical, 800 max per file
|
| 82 |
+
- `models.py` may exceed 400 lines due to many Pydantic models — this is acceptable
|
| 83 |
+
- `pytorch_engine.py` must stay under 300 lines (isolate model definitions if needed)
|
| 84 |
+
|
| 85 |
+
### Error Handling
|
| 86 |
+
`step()` must **never** raise an unhandled exception. All invalid actions return a valid observation with `-0.05` penalty and an error note. All edge cases (step after done, step before reset, malformed JSON) return structured error responses.
|
| 87 |
+
|
| 88 |
+
---
|
| 89 |
+
|
| 90 |
+
## Key Risks to Watch
|
| 91 |
+
|
| 92 |
+
### Task 6 Code Fix Validation
|
| 93 |
+
LLM agents will submit fixes with trailing spaces, inline comments, or minor reformatting. Use the multi-strategy validation pipeline:
|
| 94 |
+
1. Normalize whitespace + strip comments
|
| 95 |
+
2. Token-stream comparison via `tokenize` module
|
| 96 |
+
3. 2-3 semantic equivalence patterns per bug variant
|
| 97 |
+
4. `ast.parse()` fallback to verify buggy pattern is absent
|
| 98 |
+
|
| 99 |
+
Test with intentionally messy fixes: `" loss = criterion(output, batch_y) # fixed "` must pass.
|
| 100 |
+
|
| 101 |
+
### Red-Herring Penalty Gating
|
| 102 |
+
The `gradients_were_normal` flag is set **inside** the `inspect_gradients` handler, based on whether `is_exploding` is False on **all** layers. The threshold for `is_exploding` is `mean_norm > 10.0`. The threshold for `is_vanishing` is `mean_norm < 1e-6`. In Task 5, the FC spike has `is_exploding: False` (it spiked but the mean norm stays below 10.0), so `gradients_were_normal` is set to True. This is the gate that makes the penalty fire when the agent then calls `add_callback`.
|
| 103 |
+
|
| 104 |
+
### Docker Image Size
|
| 105 |
+
Target: <500MB. PyTorch CPU-only wheel is ~150MB. Use `python:3.12-slim` base. Install torch with `--index-url https://download.pytorch.org/whl/cpu`. Do NOT install CUDA. Pre-compute validation reports locally — do not run real training in Docker build.
|
| 106 |
+
|
| 107 |
+
### Baseline Reproducibility
|
| 108 |
+
The rule-based baseline must produce **bit-exact identical** scores on two consecutive runs. This requires:
|
| 109 |
+
- `torch.manual_seed(seed)` at every `reset()` with a deterministic seed per task
|
| 110 |
+
- No floating-point non-determinism in the parametric curve generators
|
| 111 |
+
- The heuristic decision tree is pure logic with no randomness
|
| 112 |
+
|
| 113 |
+
### Auto-Validator Endpoints
|
| 114 |
+
These endpoints are checked programmatically. They must respond correctly or you are disqualified:
|
| 115 |
+
- `GET /health` -> `{"status": "ready", "tasks": N}` (200) — N is the number of active tasks (3 for MVP, 6 for full)
|
| 116 |
+
- `GET /tasks` -> list of tasks with IDs and action schema (200)
|
| 117 |
+
- `POST /grader` -> `{"score": float}` after a completed episode (200)
|
| 118 |
+
- `POST /baseline` -> scores for all tasks (200)
|
| 119 |
+
- `WS /ws` -> responds to `reset` message
|
| 120 |
+
|
| 121 |
+
---
|
| 122 |
+
|
| 123 |
+
## Reward Constants (Do Not Change)
|
| 124 |
+
|
| 125 |
+
See spec Section 12 for full rationale. Summary:
|
| 126 |
+
|
| 127 |
+
| Event | Value | Gate |
|
| 128 |
+
|---|---|---|
|
| 129 |
+
| Step penalty | -0.01 | Unconditional, flat (never multiply by step_count) |
|
| 130 |
+
| Investigation bonus | +0.05 | First-time only per inspection type |
|
| 131 |
+
| Context-gated penalty | -0.20 | `gradients_inspected AND gradients_were_normal` |
|
| 132 |
+
| Invalid action | -0.05 | Action not in `available_actions` |
|
| 133 |
+
| Wrong code fix | -0.10 | `fix_code` with wrong line/replacement |
|
| 134 |
+
| Correct diagnosis | +0.50 | `diagnosis == true_root_cause` |
|
| 135 |
+
| Wrong diagnosis | -0.30 | `diagnosis != true_root_cause` |
|
| 136 |
+
| Terminal convergence | +0.40 | `fix_action_taken AND restart_after_fix AND convergence` |
|
| 137 |
+
|
| 138 |
+
---
|
| 139 |
+
|
| 140 |
+
## Success Criteria — "Perfect" Submission
|
| 141 |
+
|
| 142 |
+
All of these must be true:
|
| 143 |
+
- [ ] `openenv validate` passes
|
| 144 |
+
- [ ] `docker build && docker run` starts server on port 7860 in <60s
|
| 145 |
+
- [ ] HF Space deploys, responds to `reset()`, tagged with `openenv`
|
| 146 |
+
- [ ] `baseline_heuristic.py` produces identical scores on two runs
|
| 147 |
+
- [ ] 3+ tasks with graders returning scores in [0.0, 1.0] with meaningful variance
|
| 148 |
+
- [ ] Hard task (Task 5 or 6) genuinely challenges frontier models (score < 0.7 for heuristic)
|
| 149 |
+
- [ ] Context-gated penalty fires correctly and does not fire prematurely
|
| 150 |
+
- [ ] All typed Pydantic models, no `Dict[str, Any]`
|
| 151 |
+
- [ ] `import torch` in every core module, zero numpy imports in core
|
| 152 |
+
- [ ] README documents: environment description, action/observation spaces, task descriptions with difficulty, setup instructions, baseline scores
|
| 153 |
+
- [ ] POST `/baseline`, POST `/grader`, GET `/tasks` all respond correctly
|
| 154 |
+
- [ ] Test suite passes with >80% coverage
|
| 155 |
+
|
| 156 |
+
---
|
| 157 |
+
|
| 158 |
+
## Commands
|
| 159 |
+
|
| 160 |
+
```bash
|
| 161 |
+
# Development (from project root: ML Debugger/)
|
| 162 |
+
source .venv/bin/activate
|
| 163 |
+
uvicorn server.app:app --reload --host 0.0.0.0 --port 7860
|
| 164 |
+
|
| 165 |
+
# Tests
|
| 166 |
+
pytest tests/ -v --cov=ml_training_debugger --cov-report=term-missing
|
| 167 |
+
|
| 168 |
+
# Formatting
|
| 169 |
+
black ml_training_debugger/ server/ tests/
|
| 170 |
+
ruff check ml_training_debugger/ server/ tests/ --fix
|
| 171 |
+
isort ml_training_debugger/ server/ tests/ --profile black
|
| 172 |
+
|
| 173 |
+
# Docker
|
| 174 |
+
docker build -t pytorch-debugger .
|
| 175 |
+
docker run -p 7860:7860 pytorch-debugger
|
| 176 |
+
|
| 177 |
+
# Smoke test
|
| 178 |
+
curl http://localhost:7860/health
|
| 179 |
+
curl http://localhost:7860/tasks
|
| 180 |
+
python baseline_heuristic.py > run1.json
|
| 181 |
+
python baseline_heuristic.py > run2.json
|
| 182 |
+
diff run1.json run2.json # Must be empty
|
| 183 |
+
|
| 184 |
+
# OpenEnv validation
|
| 185 |
+
openenv validate
|
| 186 |
+
```
|
Dockerfile
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install PyTorch CPU-only first (largest layer, cached)
|
| 6 |
+
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
|
| 7 |
+
|
| 8 |
+
# Install remaining dependencies
|
| 9 |
+
COPY requirements.txt .
|
| 10 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 11 |
+
|
| 12 |
+
# Copy application code
|
| 13 |
+
COPY ml_training_debugger/ ml_training_debugger/
|
| 14 |
+
COPY server/ server/
|
| 15 |
+
COPY openenv.yaml .
|
| 16 |
+
COPY baseline_heuristic.py .
|
| 17 |
+
COPY README.md .
|
| 18 |
+
|
| 19 |
+
EXPOSE 7860
|
| 20 |
+
|
| 21 |
+
HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
|
| 22 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 23 |
+
|
| 24 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
PRD.md
ADDED
|
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PRD — PyTorch Training Run Debugger
|
| 2 |
+
|
| 3 |
+
**Product:** OpenEnv RL environment for ML training failure diagnosis
|
| 4 |
+
**Hackathon:** Meta PyTorch OpenEnv Hackathon x Scaler School of Technology, Round 1
|
| 5 |
+
**Deadline:** April 8, 2026 (submission window opens March 28)
|
| 6 |
+
**Runtime:** Python 3.12 · PyTorch CPU-only · openenv-core v0.2.2
|
| 7 |
+
**Source of truth:** `ml-training-debugger-spec.md` for all implementation detail beyond this PRD
|
| 8 |
+
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
## 1. Overview
|
| 12 |
+
|
| 13 |
+
### 1.1 What We Are Building
|
| 14 |
+
|
| 15 |
+
An OpenEnv-compliant reinforcement learning environment where an AI agent receives a snapshot of a broken PyTorch training run and must investigate, diagnose, fix, and verify the failure through a multi-step interactive process. The environment exposes real PyTorch model internals (gradients from `torch.autograd`, weights from `model.state_dict()`) and covers 6 failure scenarios across 3 difficulty tiers.
|
| 16 |
+
|
| 17 |
+
### 1.2 Problem Being Solved
|
| 18 |
+
|
| 19 |
+
MLOps teams spend 15-25% of engineer time debugging silent training failures — runs that produce no error, no crash, just bad metrics. Each misdiagnosed restart wastes GPU compute at $2-8/hour/card. The diagnostic process is hard because:
|
| 20 |
+
|
| 21 |
+
- Multiple symptoms can point to multiple causes simultaneously
|
| 22 |
+
- Some bugs produce no error — just mysteriously bad performance
|
| 23 |
+
- Fixing the wrong thing wastes hours of compute and restarts
|
| 24 |
+
- Static analysis catches some bugs but cannot reason through ambiguous runtime signals
|
| 25 |
+
|
| 26 |
+
No existing OpenEnv environment covers this domain. The OpenEnv Hub currently contains a demo echo environment and a code execution environment. This fills a genuine gap.
|
| 27 |
+
|
| 28 |
+
### 1.3 Why This Domain Wins
|
| 29 |
+
|
| 30 |
+
1. **Strategic alignment** — PyTorch debugging for a Meta PyTorch hackathon. Judges from Meta and Hugging Face will see their own framework as the core subject matter.
|
| 31 |
+
2. **Novel reward design** — Context-gated penalties that encode evidence-based reasoning into the reward signal. No existing OpenEnv environment attempts this.
|
| 32 |
+
3. **Code-level debugging** — Task 6 requires the agent to read and fix actual PyTorch code. Directly addresses Meta's interest: can an AI agent debug PyTorch?
|
| 33 |
+
4. **Ecosystem gap** — Zero competition in the OpenEnv ecosystem for ML training failure diagnosis.
|
| 34 |
+
|
| 35 |
+
### 1.4 Key Differentiators
|
| 36 |
+
|
| 37 |
+
| Differentiator | What It Is | Why It Matters |
|
| 38 |
+
|---|---|---|
|
| 39 |
+
| Context-gated reward shaping | Penalty fires only when agent ignores evidence it already gathered; no penalty for reasonable priors | Encodes evidence-based decision making — a capability no other OpenEnv environment has |
|
| 40 |
+
| PyTorch-native internals | Real `torch.nn.Module` models, real `torch.autograd` gradients, real `state_dict()` snapshots | Every model-level observation is grounded in real PyTorch computation, not synthetic data |
|
| 41 |
+
| Code-level debugging (Task 6) | Agent reads PyTorch code, identifies buggy line, submits code fix | Tests code understanding, not just metric interpretation — aligned with Meta's core interest |
|
| 42 |
+
|
| 43 |
+
---
|
| 44 |
+
|
| 45 |
+
## 2. Target Users
|
| 46 |
+
|
| 47 |
+
### 2.1 Primary: Hackathon Judges (Meta + Hugging Face Engineers)
|
| 48 |
+
|
| 49 |
+
**What they evaluate:**
|
| 50 |
+
- Real-world utility (30%) — Is this a genuine task? Would someone use this to train/evaluate agents?
|
| 51 |
+
- Task & grader quality (25%) — Well-defined tasks, accurate graders, meaningful difficulty progression?
|
| 52 |
+
- Environment design (20%) — Clean state management, sensible action/observation spaces, good reward shaping?
|
| 53 |
+
- Code quality & spec compliance (15%) — OpenEnv spec, clean structure, typed models, working Dockerfile?
|
| 54 |
+
- Creativity & novelty (10%) — Novel domain, interesting mechanics, original approach?
|
| 55 |
+
|
| 56 |
+
**What impresses them:**
|
| 57 |
+
- Real `import torch` in core modules (not numpy wrappers)
|
| 58 |
+
- A live dashboard where they can watch an agent investigate in real time
|
| 59 |
+
- Deterministic graders that produce different scores for different agent quality levels
|
| 60 |
+
- The context-gated penalty — nuanced reward design that goes beyond standard practice
|
| 61 |
+
|
| 62 |
+
**What disqualifies:**
|
| 63 |
+
- HF Space doesn't deploy or respond to `reset()`
|
| 64 |
+
- Plagiarized or trivially modified existing environments
|
| 65 |
+
- Graders that always return the same score
|
| 66 |
+
- No baseline inference script
|
| 67 |
+
- Dockerfile doesn't build
|
| 68 |
+
|
| 69 |
+
### 2.2 Secondary: RL Researchers and Agent Developers
|
| 70 |
+
|
| 71 |
+
**What they need:**
|
| 72 |
+
- A challenging benchmark that differentiates heuristic agents from reasoning-capable ones
|
| 73 |
+
- Clear, typed action/observation schemas for agent integration
|
| 74 |
+
- Reproducible baseline scores for comparison
|
| 75 |
+
- Environments that produce meaningful reward signal across the full trajectory (not just sparse terminal reward)
|
| 76 |
+
|
| 77 |
+
### 2.3 Tertiary: Auto-Validation System (Phase 1 Gate)
|
| 78 |
+
|
| 79 |
+
A non-human "user" that must pass before any human judge sees the submission:
|
| 80 |
+
- Pings HF Space URL — must return 200 and respond to `reset()`
|
| 81 |
+
- Validates `openenv.yaml`, typed models, `step()`/`reset()`/`state()` endpoints
|
| 82 |
+
- Runs `docker build` on submitted repo
|
| 83 |
+
- Runs baseline script twice — scores must be identical
|
| 84 |
+
- Enumerates tasks, runs each grader — scores must be in [0.0, 1.0]
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
## 3. Success Metrics
|
| 89 |
+
|
| 90 |
+
### 3.1 Evaluation Criteria Targets
|
| 91 |
+
|
| 92 |
+
| Criterion | Weight | Target Score | How We Hit It |
|
| 93 |
+
|---|---|---|---|
|
| 94 |
+
| Real-world utility | 30% | 26-30 | ML debugging is a $B+ problem; every PyTorch team encounters these failures; fills a genuine OpenEnv gap |
|
| 95 |
+
| Task & grader quality | 25% | 21-25 | 6 tasks (3 MVP), 3 difficulty tiers, deterministic graders, hard tasks challenge frontier models |
|
| 96 |
+
| Environment design | 20% | 17-20 | Progressive reveal, context-gated penalties, dynamic `available_actions`, proper episode boundaries |
|
| 97 |
+
| Code quality & spec compliance | 15% | 13-15 | Full OpenEnv spec, typed Pydantic models, working Dockerfile + HF Space, two baselines |
|
| 98 |
+
| Creativity & novelty | 10% | 9-10 | Context-gated rewards, real PyTorch model internals, code fix task — all new to OpenEnv |
|
| 99 |
+
| **Total** | **100%** | **86-100** | |
|
| 100 |
+
|
| 101 |
+
### 3.2 Quantitative Success Criteria
|
| 102 |
+
|
| 103 |
+
| Metric | Target | Measurement |
|
| 104 |
+
|---|---|---|
|
| 105 |
+
| Auto-validation | Pass all 5 gates | `openenv validate` + smoke test sequence |
|
| 106 |
+
| Grader score range | Meaningful variance per task | Heuristic baseline ~0.30-0.85 across tasks (not flat) |
|
| 107 |
+
| Heuristic-LLM gap | Measurable difference | LLM scores higher than heuristic on Tasks 5 and 6 |
|
| 108 |
+
| `reset()` latency | <200ms | Model instantiation + 2 forward passes + parametric curves |
|
| 109 |
+
| `step()` latency | <10ms | Action dispatch + reward computation + state update |
|
| 110 |
+
| Baseline reproducibility | Bit-exact across runs | `diff run1.json run2.json` produces no output |
|
| 111 |
+
| Docker image size | <500MB | PyTorch CPU-only + python:3.12-slim |
|
| 112 |
+
| Test coverage | >80% | `pytest --cov` |
|
| 113 |
+
|
| 114 |
+
### 3.3 Qualitative Success Criteria
|
| 115 |
+
|
| 116 |
+
- A judge can open `/dashboard`, trigger a baseline run, and understand the agent's reasoning at a glance
|
| 117 |
+
- Task 5 (BatchNorm eval mode) visibly differentiates disciplined investigation from red-herring chasing
|
| 118 |
+
- Task 6 (code bug) produces a "wow" moment — an agent reading and fixing PyTorch code in front of Meta judges
|
| 119 |
+
- The context-gated penalty creates a story: "this agent gathered evidence and then ignored it"
|
| 120 |
+
|
| 121 |
+
---
|
| 122 |
+
|
| 123 |
+
## 4. Functional Requirements
|
| 124 |
+
|
| 125 |
+
> **Complete typed specifications for all data models, actions, observations, tasks, reward components, and error handling are in `ml-training-debugger-spec.md` Sections 10-16.** This section provides a product-level summary.
|
| 126 |
+
|
| 127 |
+
### 4.1 Agent Interaction Loop
|
| 128 |
+
|
| 129 |
+
```
|
| 130 |
+
reset(task_id) → initial observation (loss curves, config, error log — no gradients/weights/data/code)
|
| 131 |
+
↓
|
| 132 |
+
step(action) → updated observation + reward + done flag (progressive reveal)
|
| 133 |
+
↓
|
| 134 |
+
... repeat ...
|
| 135 |
+
↓
|
| 136 |
+
step(mark_diagnosed) → terminal observation, done=True, episode scored by grader
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
### 4.2 Observation Space Summary
|
| 140 |
+
|
| 141 |
+
The `MLTrainingObservation` extends `Observation` from openenv-core. Key design:
|
| 142 |
+
- **Always visible from reset:** loss/accuracy histories, config, error_log, GPU memory, episode state, available actions
|
| 143 |
+
- **Progressively revealed:** gradient stats (real torch.autograd), weight stats (real state_dict), data batch stats, model mode info, code snippets — each populated only after the corresponding `inspect_*` action
|
| 144 |
+
- All fields are typed Pydantic models with explicit types. See spec Section 10 for complete field definitions.
|
| 145 |
+
|
| 146 |
+
### 4.3 Action Space Summary
|
| 147 |
+
|
| 148 |
+
The `MLTrainingAction` extends `Action` from openenv-core. 14 action types in 3 categories:
|
| 149 |
+
- **Investigation** (5): `inspect_gradients`, `inspect_data_batch`, `inspect_model_modes`, `inspect_model_weights`, `inspect_code`
|
| 150 |
+
- **Fix** (7): `modify_config`, `add_callback`, `replace_optimizer`, `patch_data_loader`, `fix_model_mode`, `fix_code`, `rollback_checkpoint`
|
| 151 |
+
- **Terminal** (2): `restart_run`, `mark_diagnosed`
|
| 152 |
+
|
| 153 |
+
Dynamic availability: `restart_run` requires `fix_action_taken`, `fix_code` requires `code_inspected`, `mark_diagnosed` disappears after submission. See spec Section 10 for complete action definitions and required fields.
|
| 154 |
+
|
| 155 |
+
### 4.4 Diagnosis Enum (RootCauseDiagnosis)
|
| 156 |
+
|
| 157 |
+
Closed set of 6 values. Grader is a single equality check — no fuzzy matching.
|
| 158 |
+
|
| 159 |
+
| Value | Description |
|
| 160 |
+
|---|---|
|
| 161 |
+
| `lr_too_high` | Learning rate too large for the architecture |
|
| 162 |
+
| `vanishing_gradients` | LR too low or architecture too deep, gradients decay to near-zero |
|
| 163 |
+
| `data_leakage` | Validation samples appearing in training batches |
|
| 164 |
+
| `overfitting` | Model memorizing training data, failing to generalize |
|
| 165 |
+
| `batchnorm_eval_mode` | Model left in eval mode, BatchNorm using running statistics |
|
| 166 |
+
| `code_bug` | Bug in the PyTorch training code (Task 6 — always this, regardless of bug variant) |
|
| 167 |
+
|
| 168 |
+
### 4.5 Reward Function Summary
|
| 169 |
+
|
| 170 |
+
Per-step signal. **Separate from the grader** (see 4.6). Range: [-1.0, 1.0] hard cap.
|
| 171 |
+
|
| 172 |
+
| Event | Reward | Gate Condition |
|
| 173 |
+
|---|---|---|
|
| 174 |
+
| Any step taken | -0.01 | Unconditional, flat constant (never multiplied by step_count) |
|
| 175 |
+
| First-time inspection (per type) | +0.05 | Not previously inspected for that type |
|
| 176 |
+
| `add_callback` after normal gradients | -0.20 | `gradients_inspected == True AND gradients_were_normal == True` |
|
| 177 |
+
| Invalid action | -0.05 | Action not in current `available_actions` |
|
| 178 |
+
| Wrong code fix | -0.10 | `fix_code` with incorrect line or replacement |
|
| 179 |
+
| Correct diagnosis | +0.50 | `diagnosis == true_root_cause` |
|
| 180 |
+
| Wrong diagnosis | -0.30 | `diagnosis != true_root_cause` |
|
| 181 |
+
| Convergence after fix+restart | +0.40 | `fix_action_taken AND restart_after_fix AND convergence_confirmed` |
|
| 182 |
+
|
| 183 |
+
See spec Section 12 for full design rationale.
|
| 184 |
+
|
| 185 |
+
### 4.6 Grader Function
|
| 186 |
+
|
| 187 |
+
Returns a single normalized 0.0-1.0 score at episode end. Evaluates `EpisodeState` holistically — checks which key actions were taken, whether the correct fix was applied, whether the diagnosis is correct, and efficiency. **Not a sum of step rewards.** One grader function per task. All graders are deterministic.
|
| 188 |
+
|
| 189 |
+
Exposed via `POST /grader`. Returns score for the most recently completed episode.
|
| 190 |
+
|
| 191 |
+
### 4.7 The Six Tasks
|
| 192 |
+
|
| 193 |
+
| Task | ID | Difficulty | Root Cause | Key Signal | Heuristic Score |
|
| 194 |
+
|---|---|---|---|---|---|
|
| 195 |
+
| Exploding Gradients | `task_001` | Easy | `lr_too_high` | All layers `is_exploding: True`, NaN in error_log | ~0.85 |
|
| 196 |
+
| Vanishing Gradients | `task_002` | Easy | `vanishing_gradients` | Deeper layers `is_vanishing: True`, flat loss | ~0.80 |
|
| 197 |
+
| Silent Data Leakage | `task_003` | Medium | `data_leakage` | High val accuracy from epoch 1, `class_overlap_score` 0.68-0.88 | ~0.70 |
|
| 198 |
+
| Overfitting | `task_004` | Medium | `overfitting` | Train-val divergence, loss→0.01 while val climbs | ~0.65 |
|
| 199 |
+
| BatchNorm Eval Mode | `task_005` | Hard | `batchnorm_eval_mode` | Slow val degradation + compound red herrings | ~0.45 |
|
| 200 |
+
| PyTorch Code Bug | `task_006` | Hard | `code_bug` (always) | Anomalous metrics, root cause only visible in code | ~0.30 |
|
| 201 |
+
|
| 202 |
+
**MVP tasks:** 1, 3, 5 (satisfies the 3-task minimum with easy→medium→hard range).
|
| 203 |
+
|
| 204 |
+
See spec Section 11 for complete task specifications including fault parameters, red herrings, solution paths, and grader breakdowns.
|
| 205 |
+
|
| 206 |
+
### 4.8 Baseline Agents
|
| 207 |
+
|
| 208 |
+
**Rule-based baseline (submission default, `baseline_heuristic.py`):**
|
| 209 |
+
- Deterministic decision tree: inspect_gradients → check exploding/vanishing → inspect_data → check leakage → check overfitting → inspect_model_modes → inspect_code → fallback
|
| 210 |
+
- No API key required. Bit-exact reproducible.
|
| 211 |
+
- Used for Phase 1 auto-validation reproducibility checks.
|
| 212 |
+
|
| 213 |
+
**LLM baseline (optional, `baseline_inference.py`):**
|
| 214 |
+
- GPT-4o at temperature=0.0, seed=42
|
| 215 |
+
- Requires `OPENAI_API_KEY` environment variable
|
| 216 |
+
- Supplementary demonstration of heuristic vs. reasoning score gap
|
| 217 |
+
- Not used for Phase 1 reproducibility — scores reported only after empirical measurement
|
| 218 |
+
|
| 219 |
+
### 4.9 Required Endpoints
|
| 220 |
+
|
| 221 |
+
| Endpoint | Method | Required By | Response |
|
| 222 |
+
|---|---|---|---|
|
| 223 |
+
| `/ws` | WebSocket | OpenEnv framework | Handles `reset`, `step`, `state` messages |
|
| 224 |
+
| `/tasks` | GET | Hackathon | Task list with IDs, difficulties, MLTrainingAction JSON schema |
|
| 225 |
+
| `/grader` | POST | Hackathon | `{"score": float, "task_id": str, "steps": int}` for last completed episode |
|
| 226 |
+
| `/baseline` | POST | Hackathon | Triggers baseline run, returns `{"scores": {"task_001": float, ...}}` |
|
| 227 |
+
| `/health` | GET | Hackathon | `{"status": "ready", "tasks": N}` — N is active task count |
|
| 228 |
+
| `/dashboard` | GET | Bonus | Live diagnostic dashboard (HTML/JS, Plotly.js via CDN) |
|
| 229 |
+
| `/validation-report` | GET | Bonus | Pre-computed PyTorch fidelity reports |
|
| 230 |
+
|
| 231 |
+
Framework auto-provides: `POST /reset`, `POST /step`, `GET /state`, `GET /schema`, `GET /docs`, `/mcp`.
|
| 232 |
+
|
| 233 |
+
### 4.10 Error Handling
|
| 234 |
+
|
| 235 |
+
`step()` must never raise an unhandled exception. All invalid actions return a valid observation with -0.05 penalty and an error note. See spec Section 16 for the complete error handling matrix covering all edge cases (invalid actions, malformed JSON, step before reset, etc.).
|
| 236 |
+
|
| 237 |
+
---
|
| 238 |
+
|
| 239 |
+
## 5. Non-Functional Requirements
|
| 240 |
+
|
| 241 |
+
### 5.1 OpenEnv Spec Compliance
|
| 242 |
+
|
| 243 |
+
| Requirement | Implementation |
|
| 244 |
+
|---|---|
|
| 245 |
+
| `openenv.yaml` present | Name, version, description, framework, tags, observation/action space, tasks with IDs+difficulties+max_steps, reward config, endpoints |
|
| 246 |
+
| Typed Pydantic models | `MLTrainingAction` extends `Action`, `MLTrainingObservation` extends `Observation`, all fields explicitly typed |
|
| 247 |
+
| `step()`/`reset()`/`state()` | Implemented in `MLTrainingEnvironment` extending `Environment` from `openenv.core.env_server.interfaces` |
|
| 248 |
+
| `openenv validate` passes | Tested before every submission |
|
| 249 |
+
|
| 250 |
+
### 5.2 Framework Integration
|
| 251 |
+
|
| 252 |
+
| Requirement | Implementation |
|
| 253 |
+
|---|---|
|
| 254 |
+
| `openenv-core` v0.2.2 | `create_app()` returns standard FastAPI instance — **verified** |
|
| 255 |
+
| Custom routes compose | `/tasks`, `/grader`, `/baseline`, `/health` added via `@app.get()`/`@app.post()` on the returned FastAPI app |
|
| 256 |
+
| Framework-provided routes | `/reset`, `/step`, `/state`, `/ws`, `/schema`, `/docs`, `/mcp` — do not reimplement |
|
| 257 |
+
| Factory pattern | `create_app(MLTrainingEnvironment, ...)` takes the class, not an instance |
|
| 258 |
+
| Concurrent sessions | `SUPPORTS_CONCURRENT_SESSIONS = True`, session state keyed by session ID |
|
| 259 |
+
| Typed client | `client.py` extends `EnvClient` with typed action/observation — used by baseline scripts |
|
| 260 |
+
|
| 261 |
+
### 5.3 Docker & Deployment
|
| 262 |
+
|
| 263 |
+
| Requirement | Target |
|
| 264 |
+
|---|---|
|
| 265 |
+
| Base image | `python:3.12-slim` |
|
| 266 |
+
| PyTorch | CPU-only wheel (`--index-url https://download.pytorch.org/whl/cpu`), ~150MB |
|
| 267 |
+
| Total image size | <500MB |
|
| 268 |
+
| Build time | <5 min (no real training during build; validation reports pre-computed) |
|
| 269 |
+
| HF Spaces | Tagged with `openenv`, port 7860 |
|
| 270 |
+
| Health check | `/health` returns `{"status": "ready", "tasks": N}` within 60s of container start |
|
| 271 |
+
|
| 272 |
+
### 5.4 Reproducibility
|
| 273 |
+
|
| 274 |
+
| Requirement | Implementation |
|
| 275 |
+
|---|---|
|
| 276 |
+
| Deterministic episodes | `torch.manual_seed(seed)` at every `reset()`, seed derived deterministically from task ID |
|
| 277 |
+
| Baseline bit-exact | Rule-based baseline produces identical scores on two consecutive runs |
|
| 278 |
+
| Exploit resistance | Parameters randomized per `reset()` from defined ranges; opaque task IDs |
|
| 279 |
+
| Grader determinism | Same `EpisodeState` always produces same score |
|
| 280 |
+
|
| 281 |
+
### 5.5 Performance
|
| 282 |
+
|
| 283 |
+
| Requirement | Target |
|
| 284 |
+
|---|---|
|
| 285 |
+
| `reset()` latency | <200ms (model instantiation + 2 forward passes + parametric curves) |
|
| 286 |
+
| `step()` latency | <10ms (action dispatch + reward + state update) |
|
| 287 |
+
| Memory | <512MB RSS (small CNN ~50K params, no GPU, no large datasets) |
|
| 288 |
+
|
| 289 |
+
### 5.6 Code Quality
|
| 290 |
+
|
| 291 |
+
| Requirement | Standard |
|
| 292 |
+
|---|---|
|
| 293 |
+
| Formatting | black (line length 88) |
|
| 294 |
+
| Linting | ruff |
|
| 295 |
+
| Import ordering | isort (profile=black) |
|
| 296 |
+
| Type hints | Every function signature and return type |
|
| 297 |
+
| Tests | pytest, >80% coverage, every module has corresponding test file |
|
| 298 |
+
| PyTorch-native | All core computation uses `torch.Tensor`, zero numpy in core modules |
|
| 299 |
+
|
| 300 |
+
---
|
| 301 |
+
|
| 302 |
+
## 6. Prioritized Scope
|
| 303 |
+
|
| 304 |
+
### Tier 1: MVP (Must Ship First)
|
| 305 |
+
|
| 306 |
+
**Deadline within deadline:** Deploy to HF Spaces by Day 6 (April 2). Everything after is additive.
|
| 307 |
+
|
| 308 |
+
| Deliverable | Description | DQ Risk if Missing |
|
| 309 |
+
|---|---|---|
|
| 310 |
+
| Task 1 (`task_001`) | Exploding gradients — easy | Yes (need 3+ tasks) |
|
| 311 |
+
| Task 3 (`task_003`) | Silent data leakage — medium | Yes (need 3+ tasks) |
|
| 312 |
+
| Task 5 (`task_005`) | BatchNorm eval mode — hard | Yes (need easy→hard range) |
|
| 313 |
+
| Context-gated penalty | -0.20 for `add_callback` after `gradients_were_normal` | No (but kills differentiation) |
|
| 314 |
+
| Rule-based baseline | `baseline_heuristic.py`, deterministic, no API key | Yes (baseline required) |
|
| 315 |
+
| Reward engine | All 7 reward components implemented exactly | Yes (reward logic required) |
|
| 316 |
+
| Graders (3) | One per MVP task, 0.0-1.0, deterministic | Yes (graders required) |
|
| 317 |
+
| `openenv.yaml` | Full metadata, 3+ tasks listed | Yes (spec compliance) |
|
| 318 |
+
| Required endpoints | `/tasks`, `/grader`, `/baseline`, `/health` | Yes (auto-validator checks) |
|
| 319 |
+
| Dockerfile | Builds and runs, port 7860 | Yes (auto-validator checks) |
|
| 320 |
+
| HF Space | Deployed, tagged `openenv`, responds to `reset()` | Yes (auto-validator pings) |
|
| 321 |
+
| README | Environment description, action/observation spaces, task descriptions, setup instructions, baseline scores | Yes (submission requirement) |
|
| 322 |
+
|
| 323 |
+
### Tier 2: Strongest Differentiator (Add Immediately After MVP)
|
| 324 |
+
|
| 325 |
+
| Deliverable | Description | Why This Order |
|
| 326 |
+
|---|---|---|
|
| 327 |
+
| Task 6 (`task_006`) | PyTorch code bug — hard, code-level debugging | Single highest-impact feature for Meta judges |
|
| 328 |
+
| Code fix validation | Multi-strategy pipeline (tokenize, AST, semantic patterns) | Required for Task 6 to work with LLM agents |
|
| 329 |
+
| Grader for Task 6 | `code_bug` diagnosis, code fix scoring | Completes Task 6 |
|
| 330 |
+
|
| 331 |
+
### Tier 3: Full Task Coverage (Time Permitting)
|
| 332 |
+
|
| 333 |
+
| Deliverable | Description |
|
| 334 |
+
|---|---|
|
| 335 |
+
| Task 2 (`task_002`) | Vanishing gradients — easy (similar to Task 1, fast to implement) |
|
| 336 |
+
| Task 4 (`task_004`) | Overfitting — medium (train-val divergence, regularization fix) |
|
| 337 |
+
| Graders for Tasks 2 & 4 | Same pattern as existing graders |
|
| 338 |
+
|
| 339 |
+
### Tier 4: Polish & Extras (Only After Tiers 1-3 Complete)
|
| 340 |
+
|
| 341 |
+
| Deliverable | Description | Priority Within Tier |
|
| 342 |
+
|---|---|---|
|
| 343 |
+
| Live dashboard | HTML/JS at `/dashboard`, Plotly.js via CDN, 4-panel layout | 1st — transforms judging experience |
|
| 344 |
+
| PyTorch validation suite | 6 scripts proving parametric curves match real training, R² > 0.85 | 2nd — answers "how realistic?" |
|
| 345 |
+
| Validation report endpoint | `GET /validation-report` serving pre-computed fidelity plots | With validation suite |
|
| 346 |
+
| LLM baseline | `baseline_inference.py`, GPT-4o, measures heuristic-LLM gap | 3rd — supplementary demonstration |
|
| 347 |
+
|
| 348 |
+
### Implementation Timeline (11 days: March 28 - April 8)
|
| 349 |
+
|
| 350 |
+
| Days | Focus | Exit Criteria |
|
| 351 |
+
|---|---|---|
|
| 352 |
+
| 1-2 | Skeleton server + Task 1 end-to-end | `reset()` → `step()` → `grader` works for one task, Docker builds |
|
| 353 |
+
| 3-5 | Tasks 3 & 5 + reward engine + baseline | All 3 MVP tasks pass grader, `baseline_heuristic.py` reproduces |
|
| 354 |
+
| 6 | **Deploy MVP to HF Spaces** | Auto-validation passes. This is the insurance policy. |
|
| 355 |
+
| 7-8 | Task 6 (code debugging) | Code fix validation works for all 4 bug variants |
|
| 356 |
+
| 9-10 | Tasks 2 & 4 + dashboard | Full 6-task environment, dashboard shows agent behavior |
|
| 357 |
+
| 11 | Polish, README, final smoke test | Submission-ready |
|
| 358 |
+
|
| 359 |
+
### What We Will NOT Build (Explicit Exclusions)
|
| 360 |
+
|
| 361 |
+
- No game or toy environments
|
| 362 |
+
- No numpy in core modules (torch.Tensor only)
|
| 363 |
+
- No free-text diagnosis (closed enum only)
|
| 364 |
+
- No grader that sums step rewards (holistic evaluation only)
|
| 365 |
+
- No cumulative step penalty (flat -0.01 only, never -0.01 * step_count)
|
| 366 |
+
- No accommodation support or non-RL features
|
| 367 |
+
- No multi-GPU or CUDA dependencies (CPU-only PyTorch)
|
README.md
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PyTorch Training Run Debugger
|
| 2 |
+
|
| 3 |
+
**OpenEnv RL Environment** — Meta PyTorch OpenEnv Hackathon x Scaler School of Technology, Round 1
|
| 4 |
+
|
| 5 |
+
An AI agent debugs broken PyTorch training runs by investigating gradients, model weights, data pipelines, and source code to diagnose and fix real ML failure patterns.
|
| 6 |
+
|
| 7 |
+
## What Is This?
|
| 8 |
+
|
| 9 |
+
This environment recreates the experience of an ML engineer facing a broken PyTorch training job. The agent receives a snapshot of a failing training run and must:
|
| 10 |
+
|
| 11 |
+
1. **Investigate** — inspect gradients, data batches, model weights, model modes, and code
|
| 12 |
+
2. **Diagnose** — identify the root cause from a closed set of known ML failures
|
| 13 |
+
3. **Fix** — apply the correct intervention (reduce LR, patch data, fix model mode, etc.)
|
| 14 |
+
4. **Verify** — restart training and confirm recovery before submitting diagnosis
|
| 15 |
+
|
| 16 |
+
### Key Differentiators
|
| 17 |
+
|
| 18 |
+
- **PyTorch-native internals** — Real `torch.nn.Module` models (~50K params), real `torch.autograd` gradients, real `state_dict()` weight snapshots
|
| 19 |
+
- **Context-gated reward shaping** — Penalty fires only when agent ignores evidence it already gathered; no penalty for reasonable priors
|
| 20 |
+
- **Progressive information reveal** — Gradient stats, weight stats, data batch stats only populated after corresponding inspection actions
|
| 21 |
+
|
| 22 |
+
## Environment Design
|
| 23 |
+
|
| 24 |
+
### Observation Space (`MLTrainingObservation`)
|
| 25 |
+
|
| 26 |
+
| Field | Type | Visibility |
|
| 27 |
+
|-------|------|-----------|
|
| 28 |
+
| `training_loss_history` | `list[float]` (20 epochs) | Always |
|
| 29 |
+
| `val_accuracy_history` | `list[float]` (20 epochs) | Always |
|
| 30 |
+
| `val_loss_history` | `list[float]` (20 epochs) | Always |
|
| 31 |
+
| `current_config` | `TrainingConfig` | Always |
|
| 32 |
+
| `error_log` | `Optional[str]` | Always |
|
| 33 |
+
| `gradient_stats` | `list[GradientStats]` | After `inspect_gradients` |
|
| 34 |
+
| `model_weight_stats` | `Optional[list[ModelWeightStats]]` | After `inspect_model_weights` |
|
| 35 |
+
| `data_batch_stats` | `Optional[DataBatchStats]` | After `inspect_data_batch` |
|
| 36 |
+
| `model_mode_info` | `Optional[dict[str, str]]` | After `inspect_model_modes` |
|
| 37 |
+
| `code_snippet` | `Optional[CodeSnippet]` | After `inspect_code` |
|
| 38 |
+
| `available_actions` | `list[str]` | Always (dynamic) |
|
| 39 |
+
| `episode_state` | `EpisodeState` | Always |
|
| 40 |
+
|
| 41 |
+
### Action Space (`MLTrainingAction`)
|
| 42 |
+
|
| 43 |
+
| Category | Actions |
|
| 44 |
+
|----------|---------|
|
| 45 |
+
| **Investigation** | `inspect_gradients`, `inspect_data_batch`, `inspect_model_modes`, `inspect_model_weights`, `inspect_code` |
|
| 46 |
+
| **Fix** | `modify_config`, `add_callback`, `replace_optimizer`, `patch_data_loader`, `fix_model_mode`, `fix_code` |
|
| 47 |
+
| **Terminal** | `restart_run`, `mark_diagnosed` |
|
| 48 |
+
|
| 49 |
+
Dynamic availability: `restart_run` requires a fix first; `fix_code` requires code inspection; `mark_diagnosed` disappears after submission.
|
| 50 |
+
|
| 51 |
+
### Diagnosis Enum
|
| 52 |
+
|
| 53 |
+
| Value | Description |
|
| 54 |
+
|-------|-------------|
|
| 55 |
+
| `lr_too_high` | Learning rate too large |
|
| 56 |
+
| `vanishing_gradients` | Gradients decay to near-zero |
|
| 57 |
+
| `data_leakage` | Validation samples in training |
|
| 58 |
+
| `overfitting` | Model memorizing, failing to generalize |
|
| 59 |
+
| `batchnorm_eval_mode` | Model in eval mode during training |
|
| 60 |
+
| `code_bug` | Bug in PyTorch training code |
|
| 61 |
+
|
| 62 |
+
### Reward Function
|
| 63 |
+
|
| 64 |
+
| Event | Reward | Gate |
|
| 65 |
+
|-------|--------|------|
|
| 66 |
+
| Any step | -0.01 | Flat, unconditional |
|
| 67 |
+
| First-time inspection | +0.05 | Per inspection type |
|
| 68 |
+
| `add_callback` after normal gradients | -0.20 | `gradients_inspected AND gradients_were_normal` |
|
| 69 |
+
| Invalid action | -0.05 | Action not in `available_actions` |
|
| 70 |
+
| Correct diagnosis | +0.50 | Equality check |
|
| 71 |
+
| Wrong diagnosis | -0.30 | Inequality check |
|
| 72 |
+
| Convergence after fix+restart | +0.40 | All gates met |
|
| 73 |
+
|
| 74 |
+
## Tasks
|
| 75 |
+
|
| 76 |
+
| ID | Difficulty | Root Cause | Description |
|
| 77 |
+
|----|-----------|------------|-------------|
|
| 78 |
+
| `task_001` | Easy | `lr_too_high` | Exploding gradients — all layers show `is_exploding: True`, NaN in error log |
|
| 79 |
+
| `task_003` | Medium | `data_leakage` | Silent data leakage — suspiciously high val accuracy, `class_overlap_score > 0.5` |
|
| 80 |
+
| `task_005` | Hard | `batchnorm_eval_mode` | Model in eval mode with compound red herrings (FC gradient spike, GPU 91%, near-vanishing conv1) |
|
| 81 |
+
|
| 82 |
+
## Baseline Scores
|
| 83 |
+
|
| 84 |
+
Rule-based heuristic baseline (deterministic, no API key):
|
| 85 |
+
|
| 86 |
+
| Task | Score |
|
| 87 |
+
|------|-------|
|
| 88 |
+
| `task_001` | 1.00 |
|
| 89 |
+
| `task_003` | 1.00 |
|
| 90 |
+
| `task_005` | 0.35 |
|
| 91 |
+
|
| 92 |
+
## Setup
|
| 93 |
+
|
| 94 |
+
### Local Development
|
| 95 |
+
|
| 96 |
+
```bash
|
| 97 |
+
# Create virtual environment
|
| 98 |
+
python3 -m venv .venv
|
| 99 |
+
source .venv/bin/activate
|
| 100 |
+
|
| 101 |
+
# Install dependencies
|
| 102 |
+
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
| 103 |
+
pip install openenv-core pydantic fastapi uvicorn
|
| 104 |
+
|
| 105 |
+
# Install dev tools
|
| 106 |
+
pip install pytest pytest-cov black ruff isort
|
| 107 |
+
|
| 108 |
+
# Start server
|
| 109 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 110 |
+
|
| 111 |
+
# Run tests
|
| 112 |
+
pytest tests/ -v --cov=ml_training_debugger
|
| 113 |
+
|
| 114 |
+
# Run baseline
|
| 115 |
+
python baseline_heuristic.py
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
### Docker
|
| 119 |
+
|
| 120 |
+
```bash
|
| 121 |
+
docker build -t pytorch-debugger .
|
| 122 |
+
docker run -p 7860:7860 pytorch-debugger
|
| 123 |
+
curl http://localhost:7860/health
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
## API Endpoints
|
| 127 |
+
|
| 128 |
+
| Endpoint | Method | Description |
|
| 129 |
+
|----------|--------|-------------|
|
| 130 |
+
| `/health` | GET | `{"status": "ready", "tasks": 3}` |
|
| 131 |
+
| `/tasks` | GET | Task list with action schema |
|
| 132 |
+
| `/grader` | POST | Grader score for last completed episode |
|
| 133 |
+
| `/baseline` | POST | Run baseline, return scores |
|
| 134 |
+
| `/ws` | WebSocket | Primary agent interface |
|
| 135 |
+
| `/reset` | POST | Reset environment (framework) |
|
| 136 |
+
| `/step` | POST | Execute action (framework) |
|
| 137 |
+
| `/state` | GET | Current state (framework) |
|
| 138 |
+
| `/schema` | GET | Action/observation schemas (framework) |
|
| 139 |
+
| `/docs` | GET | Swagger UI (framework) |
|
| 140 |
+
|
| 141 |
+
## Architecture
|
| 142 |
+
|
| 143 |
+
- **Python 3.12** · PyTorch CPU-only · openenv-core
|
| 144 |
+
- Real `torch.nn.Module` models with real `torch.autograd` gradients
|
| 145 |
+
- Parametric curve generation for loss/accuracy histories (sub-ms latency)
|
| 146 |
+
- Typed Pydantic models everywhere — no `Dict[str, Any]`
|
| 147 |
+
- `import torch` in every core module — zero numpy in core
|
| 148 |
+
- Session isolation via per-session `EpisodeState`
|
| 149 |
+
- Deterministic reproducibility via `torch.manual_seed()`
|
ROADMAP.md
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ROADMAP — PyTorch Training Run Debugger
|
| 2 |
+
|
| 3 |
+
**Timeline:** March 28 - April 8, 2026 (11 days)
|
| 4 |
+
**Runtime:** Python 3.12 · PyTorch CPU-only · openenv-core v0.2.2
|
| 5 |
+
**Governing documents:** `ml-training-debugger-spec.md` (source of truth), `PRD.md` (requirements), `CLAUDE.md` (coding rules)
|
| 6 |
+
**Iron rule:** No phase begins until the previous phase's acceptance criteria are met. The single exception: Phase 0 and Phase 1 file creation can overlap on Day 1.
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
## Phase 0: Setup & Validation (Days 1-2)
|
| 11 |
+
|
| 12 |
+
**Goal:** A running skeleton server that proves the toolchain works end-to-end. Zero business logic — just plumbing.
|
| 13 |
+
|
| 14 |
+
### 0.1 Files to Create
|
| 15 |
+
|
| 16 |
+
| File | Purpose | Lines (est.) |
|
| 17 |
+
|---|---|---|
|
| 18 |
+
| `ML Debugger/` (this directory) | Project root directory (git init here) | — |
|
| 19 |
+
| `pyproject.toml` | Project metadata, dependencies (torch CPU, openenv-core, pydantic>=2.0, fastapi, uvicorn, pytest, black, ruff, isort) | ~40 |
|
| 20 |
+
| `requirements.txt` | Flat dependency list mirroring pyproject.toml (Docker uses this). **Exclude openai** — deferred to Phase 3. | ~10 |
|
| 21 |
+
| `.python-version` | `3.12` | 1 |
|
| 22 |
+
| `openenv.yaml` | Full metadata — start with 3 MVP tasks (task_001, task_003, task_005), expand later | ~50 |
|
| 23 |
+
| `Dockerfile` | `python:3.12-slim`, torch CPU-only, openenv-core, app deps, port 7860 | ~15 |
|
| 24 |
+
| `.dockerignore` | Exclude `.venv/`, `__pycache__/`, `.git/`, `validation/reports/*.png` | ~10 |
|
| 25 |
+
| `.gitignore` | `.venv/`, `__pycache__/`, `*.pyc`, `.env`, `run*.json` | ~15 |
|
| 26 |
+
| `ml_training_debugger/__init__.py` | Package init, version string | ~3 |
|
| 27 |
+
| `ml_training_debugger/models.py` | **Stub only:** `RootCauseDiagnosis` enum, `EpisodeState`, `TrainingConfig`, `GradientStats`, `DataBatchStats`, `ModelWeightStats`, `CodeSnippet`, `MLTrainingObservation` (extends `Observation`), `MLTrainingAction` (extends `Action`). All fields typed, all values defaulted. | ~200 |
|
| 28 |
+
| `ml_training_debugger/client.py` | **Stub:** `MLTrainingEnvClient` extending `EnvClient` with `action_type = MLTrainingAction` and `observation_type = MLTrainingObservation`. Used by baseline scripts. | ~20 |
|
| 29 |
+
| `server/__init__.py` | Empty | 0 |
|
| 30 |
+
| `server/environment.py` | **Stub:** `MLTrainingEnvironment(Environment)` with `reset()` returning a hardcoded observation, `step()` echoing back, `state` property | ~50 |
|
| 31 |
+
| `server/app.py` | `create_app(MLTrainingEnvironment, MLTrainingAction, MLTrainingObservation)` + stub routes for `/tasks`, `/grader`, `/baseline`, `/health` | ~60 |
|
| 32 |
+
| `tests/__init__.py` | Empty | 0 |
|
| 33 |
+
| `tests/test_models.py` | Validate all Pydantic models instantiate, serialize to JSON, and round-trip | ~60 |
|
| 34 |
+
| `tests/conftest.py` | Shared fixtures: sample `EpisodeState`, sample `ScenarioParams`, sample observation | ~40 |
|
| 35 |
+
|
| 36 |
+
### 0.2 Dependencies to Install
|
| 37 |
+
|
| 38 |
+
```bash
|
| 39 |
+
# Create venv inside ML Debugger/ project root
|
| 40 |
+
python3 -m venv .venv && source .venv/bin/activate
|
| 41 |
+
|
| 42 |
+
# Core runtime
|
| 43 |
+
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
| 44 |
+
pip install openenv-core pydantic>=2.0 fastapi uvicorn
|
| 45 |
+
|
| 46 |
+
# Dev tools
|
| 47 |
+
pip install pytest pytest-cov pytest-asyncio black ruff isort httpx websockets
|
| 48 |
+
|
| 49 |
+
# NOTE: openai is deferred to Phase 3 (LLM baseline). Do NOT install now.
|
| 50 |
+
```
|
| 51 |
+
|
| 52 |
+
### 0.3 Validation Steps (Must All Pass)
|
| 53 |
+
|
| 54 |
+
| # | Command | Expected Result |
|
| 55 |
+
|---|---|---|
|
| 56 |
+
| 1 | `python -c "import torch; print(torch.__version__)"` | Version string, no CUDA |
|
| 57 |
+
| 2 | `python -c "from openenv.core.env_server.http_server import create_app"` | No import error |
|
| 58 |
+
| 3 | `python -c "from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation"` | No import error |
|
| 59 |
+
| 4 | `python -c "from ml_training_debugger.client import MLTrainingEnvClient"` | No import error |
|
| 60 |
+
| 5 | `uvicorn server.app:app --host 0.0.0.0 --port 7860` | Server starts, no crash |
|
| 61 |
+
| 6 | `curl http://localhost:7860/health` | `{"status": "ready", "tasks": 3}` |
|
| 62 |
+
| 7 | `curl http://localhost:7860/tasks` | JSON with task list |
|
| 63 |
+
| 8 | `curl http://localhost:7860/docs` | Swagger UI loads |
|
| 64 |
+
| 9 | `pytest tests/test_models.py -v` | All pass |
|
| 65 |
+
| 10 | `docker build -t pytorch-debugger .` | Builds in <5min, image <500MB |
|
| 66 |
+
| 11 | `docker run -p 7860:7860 pytorch-debugger` then `curl /health` | Returns `{"status": "ready", "tasks": 3}` |
|
| 67 |
+
| 12 | `openenv validate` | Passes (or identify what needs fixing) |
|
| 68 |
+
| 13 | `black --check . && ruff check . && isort --check .` | Clean |
|
| 69 |
+
|
| 70 |
+
### 0.4 Acceptance Criteria
|
| 71 |
+
|
| 72 |
+
- [ ] Skeleton server starts on port 7860 and responds to `/health`, `/tasks`, `/docs`, `/ws`
|
| 73 |
+
- [ ] `/health` returns `{"status": "ready", "tasks": 3}` (task count matches active tasks)
|
| 74 |
+
- [ ] All Pydantic models instantiate without error and serialize to valid JSON
|
| 75 |
+
- [ ] `client.py` imports without error
|
| 76 |
+
- [ ] Docker image builds under 500MB and container starts cleanly
|
| 77 |
+
- [ ] `openenv validate` passes or all failures are documented with a fix plan
|
| 78 |
+
- [ ] `pytest` runs with zero failures
|
| 79 |
+
- [ ] Git repo initialized, first commit made
|
| 80 |
+
|
| 81 |
+
---
|
| 82 |
+
|
| 83 |
+
## Phase 1: MVP — Tasks 1, 3, 5 + Core Engine (Days 2-6)
|
| 84 |
+
|
| 85 |
+
**Goal:** A fully functional 3-task environment that passes all auto-validation gates, deployed to HF Spaces. This is the survival milestone — everything after this is differentiation.
|
| 86 |
+
|
| 87 |
+
### 1.1 Files to Create
|
| 88 |
+
|
| 89 |
+
| File | Purpose | Lines (est.) | Depends On |
|
| 90 |
+
|---|---|---|---|
|
| 91 |
+
| `ml_training_debugger/scenarios.py` | `ScenarioParams` dataclass, `sample_scenario(task_id, seed)` for tasks 001/003/005. Parameter ranges from spec Section 11. | ~120 | `models.py` |
|
| 92 |
+
| `ml_training_debugger/pytorch_engine.py` | `SimpleCNN(torch.nn.Module)`, `inject_fault(model, scenario)`, `extract_gradient_stats(model)`, `extract_weight_stats(model)`. Real torch.autograd. | ~250 | `scenarios.py` |
|
| 93 |
+
| `ml_training_debugger/simulation.py` | `gen_loss_history(scenario)`, `gen_val_accuracy_history(scenario)`, `gen_val_loss_history(scenario)`. All `torch.Tensor` ops. Parametric curves per spec Section 6. | ~180 | `scenarios.py` |
|
| 94 |
+
| `ml_training_debugger/reward_engine.py` | `compute_reward(action, episode_state, scenario) -> float`. All 7 reward components per spec Section 12. Context-gated penalty logic. | ~100 | `models.py` |
|
| 95 |
+
| `ml_training_debugger/graders.py` | `grade_task_001(state, scenario)`, `grade_task_003(...)`, `grade_task_005(...)`. Each returns float in [0.0, 1.0]. Per spec Section 11 grader breakdowns. | ~150 | `models.py` |
|
| 96 |
+
| `baseline_heuristic.py` | Deterministic decision tree agent using `MLTrainingEnvClient`. Runs all MVP tasks, prints JSON scores. | ~150 | `client.py`, server running |
|
| 97 |
+
| `README.md` | Environment description, action/observation spaces, task descriptions with difficulty, setup instructions, baseline scores table | ~200 | Everything |
|
| 98 |
+
|
| 99 |
+
### 1.2 Files to Edit
|
| 100 |
+
|
| 101 |
+
| File | Changes | Why |
|
| 102 |
+
|---|---|---|
|
| 103 |
+
| `ml_training_debugger/models.py` | Finalize all field types, add `available_actions` computation logic to `EpisodeState`, add red herring fields (notes, gpu_memory) | Stubs from Phase 0 become real |
|
| 104 |
+
| `ml_training_debugger/client.py` | Wire typed client to connect via WebSocket or HTTP as needed by baseline | Stub becomes functional |
|
| 105 |
+
| `server/environment.py` | Full `reset()` and `step()` implementations. See spec Sections 9, 13 for lifecycle. | Stubs become real |
|
| 106 |
+
| `server/app.py` | Wire `/tasks`, `/grader`, `/baseline`, `/health` to return real data. `/health` returns `{"status": "ready", "tasks": 3}`. | Stubs become real |
|
| 107 |
+
| `openenv.yaml` | Finalize observation_space, action_space, reward section. Verify task IDs and max_steps per spec Section 14. | Was skeletal in Phase 0 |
|
| 108 |
+
| `Dockerfile` | Add `COPY` for all new source files. Verify build still works. | New files added |
|
| 109 |
+
|
| 110 |
+
### 1.3 Tests to Create
|
| 111 |
+
|
| 112 |
+
| Test File | What It Covers | Critical Assertions |
|
| 113 |
+
|---|---|---|
|
| 114 |
+
| `tests/test_scenarios.py` | `sample_scenario()` for each MVP task | Returns correct root cause enum; params within defined ranges; different seeds produce different params |
|
| 115 |
+
| `tests/test_pytorch_engine.py` | Model instantiation, fault injection, gradient/weight extraction | `SimpleCNN` is a real `torch.nn.Module`; `extract_gradient_stats` returns `GradientStats` with real float norms; exploding fault produces `is_exploding=True`; batchnorm eval fault produces `model.training==False` |
|
| 116 |
+
| `tests/test_simulation.py` | Parametric curve generators | All outputs are `list[float]` of length 20; exploding LR produces diverging loss; leakage produces inflated val_acc; batchnorm produces slow val_acc degradation |
|
| 117 |
+
| `tests/test_reward_engine.py` | All 7 reward components | **Critical:** context-gated penalty fires when `gradients_inspected=True AND gradients_were_normal=True` then `add_callback`; does NOT fire when `add_callback` without prior inspection; step penalty is flat -0.01; investigation bonus is +0.05 first-time only |
|
| 118 |
+
| `tests/test_graders.py` | Graders for tasks 001, 003, 005 | Each returns float in [0.0, 1.0]; correct diagnosis + fix + restart = 1.0; wrong diagnosis < 0.5; partial completion scores between 0 and 1 |
|
| 119 |
+
| `tests/test_episode_lifecycle.py` | Full reset→inspect→fix→restart→diagnose flow | State transitions match spec Section 13; `available_actions` updates correctly; `done=True` after `mark_diagnosed`; step limit triggers `done=True` |
|
| 120 |
+
|
| 121 |
+
### 1.4 Task-Specific Implementation
|
| 122 |
+
|
| 123 |
+
See spec Section 11 for complete task specifications. Key implementation notes per task:
|
| 124 |
+
|
| 125 |
+
**Task 1 (`task_001`, easy):** Unambiguous signal. LR from spec ranges → real gradients explode → `is_exploding=True` on all layers. Straightforward grader.
|
| 126 |
+
|
| 127 |
+
**Task 3 (`task_003`, medium):** Red herring note about architecture upgrade. Data leakage confirmed via `class_overlap_score`. Normal model (no gradient/weight anomaly). Mild gradient elevation on one layer (`is_exploding=False`).
|
| 128 |
+
|
| 129 |
+
**Task 5 (`task_005`, hard):** The differentiator task. `gradients_were_normal=True` set inside `inspect_gradients` handler because `is_exploding=False` on ALL layers (FC spike mean_norm < 10.0). Context-gated penalty fires when agent then calls `add_callback`. Red herrings: FC spike, GPU 91%, conv1 near-vanishing, error_log warning.
|
| 130 |
+
|
| 131 |
+
### 1.5 Endpoint Responses
|
| 132 |
+
|
| 133 |
+
**`GET /health`:** `{"status": "ready", "tasks": 3}` (200) — or `{"status": "initializing"}` (503) during startup.
|
| 134 |
+
|
| 135 |
+
**`GET /tasks`:** Task list with IDs, difficulties, max_steps, and MLTrainingAction JSON schema.
|
| 136 |
+
|
| 137 |
+
**`POST /grader`:** `{"score": float, "task_id": str, "steps": int}` (200) — or `{"score": null, "error": "no_completed_episode"}` (200) if no episode. See spec Section 14 for edge cases.
|
| 138 |
+
|
| 139 |
+
**`POST /baseline`:** Runs baseline logic internally, returns `{"scores": {"task_001": float, "task_003": float, "task_005": float}}`. Returns 409 if already running.
|
| 140 |
+
|
| 141 |
+
### 1.6 Baseline Heuristic Decision Tree
|
| 142 |
+
|
| 143 |
+
See spec Section 17 for the complete decision tree. Summary:
|
| 144 |
+
```
|
| 145 |
+
1. reset(task_id)
|
| 146 |
+
2. inspect_gradients
|
| 147 |
+
3. IF any layer is_exploding → fix LR → restart → diagnose lr_too_high
|
| 148 |
+
4. IF any layer is_vanishing → fix LR → restart → diagnose vanishing_gradients
|
| 149 |
+
5. inspect_data_batch
|
| 150 |
+
6. IF class_overlap_score > 0.5 → patch_data_loader → restart → diagnose data_leakage
|
| 151 |
+
7. IF val_loss diverging → modify weight_decay → restart → diagnose overfitting
|
| 152 |
+
8. inspect_model_modes
|
| 153 |
+
9. IF any layer in "eval" → fix_model_mode → restart → diagnose batchnorm_eval_mode
|
| 154 |
+
10. inspect_code → attempt fix → restart → diagnose code_bug
|
| 155 |
+
11. FALLBACK: diagnose overfitting
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
### 1.7 Deploy to HF Spaces
|
| 159 |
+
|
| 160 |
+
| Step | Action | Verification |
|
| 161 |
+
|---|---|---|
|
| 162 |
+
| 1 | Create HF Space (Docker type), tag with `openenv` | Space page shows openenv tag |
|
| 163 |
+
| 2 | Push Dockerfile + source to Space repo | Build triggers automatically |
|
| 164 |
+
| 3 | Wait for build to complete | Build log shows success |
|
| 165 |
+
| 4 | Test health endpoint | `curl https://<space-url>/health` returns `{"status": "ready", "tasks": 3}` |
|
| 166 |
+
| 5 | Test reset via WebSocket | `wscat -c wss://<space-url>/ws` then send `{"type": "reset", "task_id": "task_001"}` |
|
| 167 |
+
| 6 | Run `openenv validate` against deployed space | All checks pass |
|
| 168 |
+
|
| 169 |
+
### 1.8 Acceptance Criteria
|
| 170 |
+
|
| 171 |
+
- [ ] `reset(task_id)` for tasks 001, 003, 005 returns valid `MLTrainingObservation` with correct initial state
|
| 172 |
+
- [ ] `step()` dispatches all 14 action types correctly (investigation, fix, terminal)
|
| 173 |
+
- [ ] `inspect_gradients` on Task 1 → `is_exploding=True` on all layers (real torch.autograd)
|
| 174 |
+
- [ ] `inspect_gradients` on Task 5 → `is_exploding=False` on all layers, `gradients_were_normal=True`
|
| 175 |
+
- [ ] `inspect_data_batch` on Task 3 → `class_overlap_score > 0.5`
|
| 176 |
+
- [ ] `inspect_model_modes` on Task 5 → all layers in "eval" mode
|
| 177 |
+
- [ ] Context-gated penalty: `inspect_gradients`(normal) then `add_callback` → reward includes -0.20
|
| 178 |
+
- [ ] Context-gated penalty: `add_callback` without prior inspection → NO -0.20 penalty
|
| 179 |
+
- [ ] Grader for Task 1: correct path scores 1.0, wrong diagnosis scores < 0.5
|
| 180 |
+
- [ ] Grader for Task 5: agent that chases red herring scores 0.80-0.85 (penalty applied)
|
| 181 |
+
- [ ] `baseline_heuristic.py` runs twice → `diff run1.json run2.json` is empty
|
| 182 |
+
- [ ] `POST /baseline` returns scores for all 3 tasks, all in [0.0, 1.0]
|
| 183 |
+
- [ ] `POST /grader` returns score after completed episode
|
| 184 |
+
- [ ] `GET /tasks` returns 3 tasks with action schema
|
| 185 |
+
- [ ] `GET /health` returns `{"status": "ready", "tasks": 3}`
|
| 186 |
+
- [ ] Docker builds <500MB, starts <60s, serves on port 7860
|
| 187 |
+
- [ ] HF Space deployed, responds to `reset()`, tagged `openenv`
|
| 188 |
+
- [ ] `openenv validate` passes
|
| 189 |
+
- [ ] `pytest --cov` shows >80% coverage on all Phase 1 modules
|
| 190 |
+
- [ ] `import torch` in every core module; zero `import numpy` in core
|
| 191 |
+
- [ ] README has: description, action/observation spaces, 3 task descriptions, setup instructions, baseline scores
|
| 192 |
+
|
| 193 |
+
---
|
| 194 |
+
|
| 195 |
+
## Phase 2: Stretch — Tasks 2, 4, 6 + Code Debugging (Days 7-9)
|
| 196 |
+
|
| 197 |
+
**Goal:** Full 6-task environment with code-level debugging. Task 6 is the single highest-impact differentiator for Meta judges.
|
| 198 |
+
|
| 199 |
+
**Prerequisites:** Phase 1 acceptance criteria ALL met. HF Space deployed and passing auto-validation.
|
| 200 |
+
|
| 201 |
+
### 2.1 Priority Order (Strict)
|
| 202 |
+
|
| 203 |
+
1. **Task 6** first — it is the strongest differentiator and the hardest to implement
|
| 204 |
+
2. **Task 2** second — structurally identical to Task 1 (vanishing vs. exploding), fastest to add
|
| 205 |
+
3. **Task 4** third — medium difficulty overfitting, similar pattern to existing tasks
|
| 206 |
+
|
| 207 |
+
### 2.2 Files to Create
|
| 208 |
+
|
| 209 |
+
| File | Purpose | Lines (est.) | Depends On |
|
| 210 |
+
|---|---|---|---|
|
| 211 |
+
| `ml_training_debugger/code_templates.py` | 4 bug variant templates, `generate_code_snippet(bug_type, seed)`, `validate_fix(bug_type, line, replacement)` with multi-strategy pipeline per spec Section 22 | ~250 | `models.py` |
|
| 212 |
+
| `tests/test_code_templates.py` | All 4 variants generate valid code; fix validation accepts correct fixes; rejects wrong fixes; handles whitespace/comment variations | ~150 | `code_templates.py` |
|
| 213 |
+
|
| 214 |
+
### 2.3 Files to Edit
|
| 215 |
+
|
| 216 |
+
| File | Changes | Complexity |
|
| 217 |
+
|---|---|---|
|
| 218 |
+
| `ml_training_debugger/scenarios.py` | Add `sample_scenario` cases for task_002, task_004, task_006. Task 006 includes `bug_type` field. | Low |
|
| 219 |
+
| `ml_training_debugger/pytorch_engine.py` | Add fault injection for vanishing gradients, overfitting, code bug variants. | Medium |
|
| 220 |
+
| `ml_training_debugger/simulation.py` | Add curve generators for vanishing (flat loss), overfitting (train-val divergence), code bug variants. | Medium |
|
| 221 |
+
| `ml_training_debugger/reward_engine.py` | Add wrong code fix penalty (-0.10). No other changes. | Low |
|
| 222 |
+
| `ml_training_debugger/graders.py` | Add `grade_task_002`, `grade_task_004`, `grade_task_006`. Task 006: diagnosis must be `code_bug` always. | Medium |
|
| 223 |
+
| `server/environment.py` | `step()` handlers for `inspect_code` and `fix_code`. Update `available_actions`. | Medium |
|
| 224 |
+
| `server/app.py` | Update `/tasks` to return 6 tasks. Update `/health` to return `"tasks": 6`. | Low |
|
| 225 |
+
| `openenv.yaml` | Add task_002, task_004, task_006. | Low |
|
| 226 |
+
| `baseline_heuristic.py` | Extend decision tree for vanishing, overfitting, code bug. | Medium |
|
| 227 |
+
| `README.md` | Add descriptions for Tasks 2, 4, 6. Update baseline scores. | Low |
|
| 228 |
+
|
| 229 |
+
### 2.4 Task 6 Code Fix Validation
|
| 230 |
+
|
| 231 |
+
The `validate_fix()` pipeline is defined in spec Section 22 (Known Risks). Key layers:
|
| 232 |
+
|
| 233 |
+
1. **Normalize:** strip whitespace + inline comments → compare against known correct strings
|
| 234 |
+
2. **Tokenize:** Python `tokenize` module, filter noise tokens, compare streams
|
| 235 |
+
3. **Semantic patterns:** 2-3 per variant (e.g. `"criterion("` present AND `".detach()"` absent)
|
| 236 |
+
4. **AST fallback:** `ast.parse()` full code with replacement, verify buggy pattern absent
|
| 237 |
+
|
| 238 |
+
Test cases that MUST pass: correct fix, trailing whitespace, inline comments, different indentation.
|
| 239 |
+
Test cases that MUST fail: bug still present, `pass`, wrong line number.
|
| 240 |
+
|
| 241 |
+
### 2.5 Tests to Create/Extend
|
| 242 |
+
|
| 243 |
+
| Test File | New Coverage |
|
| 244 |
+
|---|---|
|
| 245 |
+
| `tests/test_code_templates.py` | **New file.** All 4 variants, validate_fix accepts/rejects correctly, 5+ whitespace/comment variations per variant |
|
| 246 |
+
| `tests/test_scenarios.py` | Extend: sample_scenario for task_002, 004, 006 |
|
| 247 |
+
| `tests/test_simulation.py` | Extend: vanishing flat loss, overfitting divergence, code bug symptoms |
|
| 248 |
+
| `tests/test_graders.py` | Extend: graders 002, 004, 006. Task 006: `code_bug` required; `batchnorm_eval_mode` on eval_mode variant = wrong |
|
| 249 |
+
| `tests/test_reward_engine.py` | Extend: wrong code fix penalty (-0.10) |
|
| 250 |
+
| `tests/test_episode_lifecycle.py` | Extend: `inspect_code` → `fix_code` available; `fix_code` before `inspect_code` → invalid |
|
| 251 |
+
|
| 252 |
+
### 2.6 Acceptance Criteria
|
| 253 |
+
|
| 254 |
+
- [ ] All 6 tasks return valid observations from `reset()` and process all action types in `step()`
|
| 255 |
+
- [ ] Task 6: `inspect_code` returns `CodeSnippet` with real PyTorch code containing the sampled bug
|
| 256 |
+
- [ ] Task 6: `fix_code` correct → `fix_action_taken=True`, no penalty
|
| 257 |
+
- [ ] Task 6: `fix_code` wrong → -0.10 penalty
|
| 258 |
+
- [ ] Task 6: `mark_diagnosed(code_bug)` → correct (+0.50)
|
| 259 |
+
- [ ] Task 6: `mark_diagnosed(batchnorm_eval_mode)` on eval_mode variant → wrong (-0.30)
|
| 260 |
+
- [ ] `validate_fix` accepts 5+ whitespace/comment variations per variant
|
| 261 |
+
- [ ] `validate_fix` rejects all invalid fixes
|
| 262 |
+
- [ ] Graders for all 6 tasks return [0.0, 1.0] with meaningful variance
|
| 263 |
+
- [ ] `baseline_heuristic.py` handles all 6 tasks, still bit-exact reproducible
|
| 264 |
+
- [ ] `POST /baseline` returns scores for all 6 tasks
|
| 265 |
+
- [ ] `GET /tasks` returns 6 tasks
|
| 266 |
+
- [ ] `GET /health` returns `{"status": "ready", "tasks": 6}`
|
| 267 |
+
- [ ] All new tests pass; overall coverage >80%
|
| 268 |
+
- [ ] Updated openenv.yaml lists all 6 tasks
|
| 269 |
+
- [ ] HF Space redeployed with 6 tasks, auto-validation still passes
|
| 270 |
+
|
| 271 |
+
---
|
| 272 |
+
|
| 273 |
+
## Phase 3: Polish — Dashboard, Validation Suite, LLM Baseline (Days 10-11)
|
| 274 |
+
|
| 275 |
+
**Goal:** Transform a technically correct submission into a visually impressive, deeply validated, winning submission.
|
| 276 |
+
|
| 277 |
+
**Prerequisites:** Phase 2 acceptance criteria ALL met. 6-task environment deployed.
|
| 278 |
+
|
| 279 |
+
### 3.1 Priority Order Within Phase 3
|
| 280 |
+
|
| 281 |
+
1. **Dashboard** — transforms judging experience (highest ROI for judges)
|
| 282 |
+
2. **Full test suite + README polish** — ensures no auto-validation failure
|
| 283 |
+
3. **Validation suite** — answers "how realistic are your curves?"
|
| 284 |
+
4. **LLM baseline** — demonstrates heuristic-reasoning gap (lowest priority)
|
| 285 |
+
|
| 286 |
+
### 3.2 Files to Create
|
| 287 |
+
|
| 288 |
+
| File | Purpose | Lines (est.) | Priority |
|
| 289 |
+
|---|---|---|---|
|
| 290 |
+
| `server/dashboard.html` | Single-file SPA. 4 panels per spec Section 19. Plotly.js via CDN. | ~400 | 1st |
|
| 291 |
+
| `validation/requirements.txt` | `torch`, `matplotlib`, `scipy` | ~3 | 3rd |
|
| 292 |
+
| `validation/conftest.py` | Shared fixtures: CIFAR-10 subset loader, model definitions | ~50 | 3rd |
|
| 293 |
+
| `validation/validate_exploding_gradients.py` | Real training, compare to parametric curve, R² > 0.85 | ~80 | 3rd |
|
| 294 |
+
| `validation/validate_data_leakage.py` | Real training with leakage, compare | ~80 | 3rd |
|
| 295 |
+
| `validation/validate_batchnorm_eval.py` | Real training with `model.eval()`, compare | ~80 | 3rd |
|
| 296 |
+
| `validation/validate_vanishing_gradients.py` | Real gradient decay, compare | ~80 | 3rd |
|
| 297 |
+
| `validation/validate_overfitting.py` | Real train-val divergence, compare | ~80 | 3rd |
|
| 298 |
+
| `validation/validate_code_bugs.py` | Run 4 bug variants, confirm symptoms | ~80 | 3rd |
|
| 299 |
+
| `validation/reports/` | Pre-computed fidelity scores + comparison plots | — | 3rd |
|
| 300 |
+
| `baseline_inference.py` | LLM agent (GPT-4o, temp=0.0, seed=42). Runs all 6 tasks. **Now install openai.** | ~200 | 4th |
|
| 301 |
+
|
| 302 |
+
### 3.3 Files to Edit
|
| 303 |
+
|
| 304 |
+
| File | Changes | Priority |
|
| 305 |
+
|---|---|---|
|
| 306 |
+
| `server/app.py` | Add `GET /dashboard` and `GET /validation-report` routes | 1st/3rd |
|
| 307 |
+
| `requirements.txt` | Add `openai` (only now, for LLM baseline) | 4th |
|
| 308 |
+
| `Dockerfile` | `COPY validation/reports/` and `COPY server/dashboard.html` | 1st |
|
| 309 |
+
| `README.md` | Final polish: dashboard description, validation suite, measured baseline scores | 2nd |
|
| 310 |
+
| `openenv.yaml` | Add dashboard and validation-report to endpoints | 1st |
|
| 311 |
+
|
| 312 |
+
### 3.4 Dashboard Panels
|
| 313 |
+
|
| 314 |
+
See spec Section 19 for full specification. Summary:
|
| 315 |
+
1. **Training Metrics** — Plotly.js line charts for loss/accuracy with restart markers
|
| 316 |
+
2. **Gradient & Weight Heatmap** — color-coded per-layer grid (green/yellow/red/blue)
|
| 317 |
+
3. **Action Timeline** — horizontal bars per step, color-coded by type, reward bars
|
| 318 |
+
4. **Episode Summary** — task ID, state flags, available actions, grader score
|
| 319 |
+
|
| 320 |
+
Tech: single HTML file, Plotly.js CDN, native WebSocket, CSS Grid. Zero Docker bloat.
|
| 321 |
+
|
| 322 |
+
### 3.5 Validation Suite
|
| 323 |
+
|
| 324 |
+
Run locally (NOT in Docker build). Each script: real training → capture metrics → compare to parametric → assert R² > 0.85 → save plots. Pre-computed reports committed to git and served via `/validation-report`. See spec Section 18.
|
| 325 |
+
|
| 326 |
+
### 3.6 Tests to Create/Extend
|
| 327 |
+
|
| 328 |
+
| Test File | Coverage |
|
| 329 |
+
|---|---|
|
| 330 |
+
| `tests/test_dashboard.py` | `GET /dashboard` returns 200 with HTML containing "Plotly" and "WebSocket" |
|
| 331 |
+
| `tests/test_endpoints.py` | Integration: full episode via HTTP (reset→step→grader), verify response schemas |
|
| 332 |
+
| `tests/test_baseline_reproducibility.py` | Run baseline twice, assert identical JSON |
|
| 333 |
+
| Existing test files | Fill coverage gaps to >80% on every module |
|
| 334 |
+
|
| 335 |
+
### 3.7 Acceptance Criteria
|
| 336 |
+
|
| 337 |
+
- [ ] `GET /dashboard` serves HTML that renders in a browser with 4 panels
|
| 338 |
+
- [ ] Dashboard connects to WebSocket and updates in real time during a baseline run
|
| 339 |
+
- [ ] Validation suite passes all scripts with R² > 0.85 (run locally)
|
| 340 |
+
- [ ] Pre-computed validation reports exist in `validation/reports/`
|
| 341 |
+
- [ ] `GET /validation-report` serves fidelity data
|
| 342 |
+
- [ ] LLM baseline runs, scores higher than heuristic on Tasks 5 and 6 (if implemented)
|
| 343 |
+
- [ ] README is complete: all 6 tasks, both baselines, dashboard description, setup instructions
|
| 344 |
+
- [ ] `pytest --cov` shows >80% coverage across all modules
|
| 345 |
+
- [ ] Final `openenv validate` passes
|
| 346 |
+
- [ ] Final Docker build <500MB, starts <60s
|
| 347 |
+
- [ ] HF Space redeployed with dashboard + all features
|
| 348 |
+
|
| 349 |
+
---
|
| 350 |
+
|
| 351 |
+
## Pre-Submission Gate Checklist
|
| 352 |
+
|
| 353 |
+
**Every item must be checked before submitting. Failure on any starred (*) item = disqualification.**
|
| 354 |
+
|
| 355 |
+
### Auto-Validation Gates (*)
|
| 356 |
+
|
| 357 |
+
- [ ] * **HF Space deploys** — `curl https://<space-url>/health` returns `{"status": "ready", "tasks": N}` with HTTP 200
|
| 358 |
+
- [ ] * **HF Space responds to reset** — WebSocket connection to `/ws`, send reset message, receive valid observation
|
| 359 |
+
- [ ] * **OpenEnv spec compliance** — `openenv validate` passes (openenv.yaml present, typed models, step/reset/state work)
|
| 360 |
+
- [ ] * **Dockerfile builds** — `docker build -t pytorch-debugger .` succeeds
|
| 361 |
+
- [ ] * **Docker runs** — `docker run -p 7860:7860 pytorch-debugger` starts and serves on port 7860
|
| 362 |
+
- [ ] * **Baseline reproduces** — `python baseline_heuristic.py > run1.json && python baseline_heuristic.py > run2.json && diff run1.json run2.json` produces no output
|
| 363 |
+
- [ ] * **3+ tasks with graders** — `GET /tasks` returns ≥3 tasks; `POST /grader` returns score in [0.0, 1.0] after each task completes
|
| 364 |
+
- [ ] * **Graders produce varying scores** — different agent behaviors produce different scores (not always same value)
|
| 365 |
+
|
| 366 |
+
### Required Endpoint Gates (*)
|
| 367 |
+
|
| 368 |
+
- [ ] * **`GET /tasks`** — returns JSON with task IDs, difficulties, action schema
|
| 369 |
+
- [ ] * **`POST /grader`** — returns `{"score": float}` after a completed episode
|
| 370 |
+
- [ ] * **`POST /baseline`** — triggers baseline, returns scores for all tasks
|
| 371 |
+
- [ ] * **`GET /health`** — returns `{"status": "ready", "tasks": N}`
|
| 372 |
+
|
| 373 |
+
### Submission Artifacts (*)
|
| 374 |
+
|
| 375 |
+
- [ ] * **Public GitHub repo** — contains all code, README, requirements, openenv.yaml
|
| 376 |
+
- [ ] * **HF Spaces demo link** — deployed, tagged `openenv`, accessible
|
| 377 |
+
- [ ] * **README complete** — environment description, action/observation space definitions, task descriptions with difficulty, setup instructions, baseline scores
|
| 378 |
+
|
| 379 |
+
### Quality Gates (Not DQ, but impact scoring)
|
| 380 |
+
|
| 381 |
+
- [ ] All typed Pydantic models — no `Dict[str, Any]`
|
| 382 |
+
- [ ] `import torch` in every core module — zero `import numpy` in core
|
| 383 |
+
- [ ] Context-gated penalty fires correctly (manually tested both paths)
|
| 384 |
+
- [ ] Task 5 red herrings present: FC spike, GPU 91%, conv1 near-vanishing, error_log warning
|
| 385 |
+
- [ ] Task 6 code fix validation handles whitespace and comment variations
|
| 386 |
+
- [ ] Task 6 diagnosis is always `code_bug` regardless of bug variant
|
| 387 |
+
- [ ] Grader and reward function are separate modules
|
| 388 |
+
- [ ] Step penalty is flat -0.01 (not multiplied by step_count)
|
| 389 |
+
- [ ] Episode state is isolated per WebSocket session
|
| 390 |
+
- [ ] Test suite passes with >80% coverage
|
| 391 |
+
- [ ] Code formatted with black, linted with ruff, imports sorted with isort
|
| 392 |
+
|
| 393 |
+
### Final Smoke Test Sequence
|
| 394 |
+
|
| 395 |
+
Run this entire sequence the night before submission:
|
| 396 |
+
|
| 397 |
+
```bash
|
| 398 |
+
# 1. Clean build
|
| 399 |
+
docker build --no-cache -t pytorch-debugger .
|
| 400 |
+
docker run -d -p 7860:7860 --name smoke-test pytorch-debugger
|
| 401 |
+
|
| 402 |
+
# 2. Wait for startup
|
| 403 |
+
sleep 10
|
| 404 |
+
curl -f http://localhost:7860/health || echo "FAIL: health"
|
| 405 |
+
|
| 406 |
+
# 3. Tasks endpoint
|
| 407 |
+
curl -f http://localhost:7860/tasks | python -m json.tool || echo "FAIL: tasks"
|
| 408 |
+
|
| 409 |
+
# 4. Baseline reproducibility
|
| 410 |
+
python baseline_heuristic.py > run1.json 2>/dev/null
|
| 411 |
+
python baseline_heuristic.py > run2.json 2>/dev/null
|
| 412 |
+
diff run1.json run2.json && echo "PASS: reproducible" || echo "FAIL: non-reproducible"
|
| 413 |
+
|
| 414 |
+
# 5. Baseline via endpoint
|
| 415 |
+
curl -f -X POST http://localhost:7860/baseline | python -m json.tool || echo "FAIL: baseline endpoint"
|
| 416 |
+
|
| 417 |
+
# 6. Grader via endpoint (after baseline has completed episodes)
|
| 418 |
+
curl -f -X POST http://localhost:7860/grader | python -m json.tool || echo "FAIL: grader endpoint"
|
| 419 |
+
|
| 420 |
+
# 7. OpenEnv validation
|
| 421 |
+
openenv validate || echo "FAIL: openenv validate"
|
| 422 |
+
|
| 423 |
+
# 8. Test suite
|
| 424 |
+
pytest tests/ -v --cov=ml_training_debugger --cov-report=term-missing
|
| 425 |
+
|
| 426 |
+
# 9. Cleanup
|
| 427 |
+
docker stop smoke-test && docker rm smoke-test
|
| 428 |
+
|
| 429 |
+
echo "=== Smoke test complete ==="
|
| 430 |
+
```
|
| 431 |
+
|
| 432 |
+
### If Something Fails at Submission Time
|
| 433 |
+
|
| 434 |
+
| Failure | Triage |
|
| 435 |
+
|---|---|
|
| 436 |
+
| HF Space won't deploy | Check Dockerfile CMD, port 7860, build logs. Redeploy. |
|
| 437 |
+
| Baseline non-reproducible | Check `torch.manual_seed()` in `reset()`. Check for `random` module usage. |
|
| 438 |
+
| Grader returns same score | Check that `sample_scenario` uses different seeds. Check grader logic has branching. |
|
| 439 |
+
| `openenv validate` fails | Read error message. Usually missing field in openenv.yaml or wrong model base class. |
|
| 440 |
+
| Docker image >500MB | Check `docker images` size. Remove unused deps. Ensure torch is CPU-only. |
|
| 441 |
+
| Test coverage <80% | Run `pytest --cov` with `--cov-report=html`. Find uncovered branches. Add targeted tests. |
|
baseline_heuristic.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Rule-based heuristic baseline agent.
|
| 3 |
+
|
| 4 |
+
Deterministic decision tree — no API key required. Bit-exact reproducible.
|
| 5 |
+
Spec reference: Section 17.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python baseline_heuristic.py [--url http://localhost:7860]
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import json
|
| 15 |
+
import sys
|
| 16 |
+
|
| 17 |
+
from ml_training_debugger.graders import grade_episode
|
| 18 |
+
from ml_training_debugger.models import EpisodeState, MLTrainingAction, MLTrainingObservation
|
| 19 |
+
from ml_training_debugger.scenarios import sample_scenario
|
| 20 |
+
from server.environment import MLTrainingEnvironment
|
| 21 |
+
|
| 22 |
+
MVP_TASKS = ["task_001", "task_003", "task_005"]
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def run_heuristic_episode(task_id: str, seed: int = 42) -> float:
|
| 26 |
+
"""Run one heuristic baseline episode. Returns grader score."""
|
| 27 |
+
env = MLTrainingEnvironment()
|
| 28 |
+
obs = env.reset(seed=seed, episode_id=f"baseline_{task_id}", task_id=task_id)
|
| 29 |
+
|
| 30 |
+
# Step 1: inspect_gradients
|
| 31 |
+
obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
|
| 32 |
+
|
| 33 |
+
if obs.gradient_stats:
|
| 34 |
+
# Check exploding
|
| 35 |
+
if any(g.is_exploding for g in obs.gradient_stats):
|
| 36 |
+
obs = env.step(
|
| 37 |
+
MLTrainingAction(
|
| 38 |
+
action_type="modify_config",
|
| 39 |
+
target="learning_rate",
|
| 40 |
+
value=0.001,
|
| 41 |
+
)
|
| 42 |
+
)
|
| 43 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 44 |
+
obs = env.step(
|
| 45 |
+
MLTrainingAction(
|
| 46 |
+
action_type="mark_diagnosed",
|
| 47 |
+
diagnosis="lr_too_high",
|
| 48 |
+
)
|
| 49 |
+
)
|
| 50 |
+
session = env._get_session()
|
| 51 |
+
return session.last_score if session and session.last_score is not None else 0.0
|
| 52 |
+
|
| 53 |
+
# Check vanishing
|
| 54 |
+
if any(g.is_vanishing for g in obs.gradient_stats):
|
| 55 |
+
obs = env.step(
|
| 56 |
+
MLTrainingAction(
|
| 57 |
+
action_type="modify_config",
|
| 58 |
+
target="learning_rate",
|
| 59 |
+
value=0.01,
|
| 60 |
+
)
|
| 61 |
+
)
|
| 62 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 63 |
+
obs = env.step(
|
| 64 |
+
MLTrainingAction(
|
| 65 |
+
action_type="mark_diagnosed",
|
| 66 |
+
diagnosis="vanishing_gradients",
|
| 67 |
+
)
|
| 68 |
+
)
|
| 69 |
+
session = env._get_session()
|
| 70 |
+
return session.last_score if session and session.last_score is not None else 0.0
|
| 71 |
+
|
| 72 |
+
# Step 2: inspect_data_batch
|
| 73 |
+
obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
|
| 74 |
+
if obs.data_batch_stats and obs.data_batch_stats.class_overlap_score > 0.5:
|
| 75 |
+
obs = env.step(MLTrainingAction(action_type="patch_data_loader"))
|
| 76 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 77 |
+
obs = env.step(
|
| 78 |
+
MLTrainingAction(
|
| 79 |
+
action_type="mark_diagnosed",
|
| 80 |
+
diagnosis="data_leakage",
|
| 81 |
+
)
|
| 82 |
+
)
|
| 83 |
+
session = env._get_session()
|
| 84 |
+
return session.last_score if session and session.last_score is not None else 0.0
|
| 85 |
+
|
| 86 |
+
# Check overfitting (val_loss diverging)
|
| 87 |
+
if obs.val_loss_history and len(obs.val_loss_history) >= 10:
|
| 88 |
+
early = sum(obs.val_loss_history[:5]) / 5
|
| 89 |
+
late = sum(obs.val_loss_history[-5:]) / 5
|
| 90 |
+
if (
|
| 91 |
+
late > early * 1.2
|
| 92 |
+
and obs.data_batch_stats
|
| 93 |
+
and obs.data_batch_stats.class_overlap_score < 0.1
|
| 94 |
+
):
|
| 95 |
+
obs = env.step(
|
| 96 |
+
MLTrainingAction(
|
| 97 |
+
action_type="modify_config",
|
| 98 |
+
target="weight_decay",
|
| 99 |
+
value=0.01,
|
| 100 |
+
)
|
| 101 |
+
)
|
| 102 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 103 |
+
obs = env.step(
|
| 104 |
+
MLTrainingAction(
|
| 105 |
+
action_type="mark_diagnosed",
|
| 106 |
+
diagnosis="overfitting",
|
| 107 |
+
)
|
| 108 |
+
)
|
| 109 |
+
session = env._get_session()
|
| 110 |
+
return session.last_score if session and session.last_score is not None else 0.0
|
| 111 |
+
|
| 112 |
+
# Step 3: inspect_model_modes
|
| 113 |
+
obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
|
| 114 |
+
if obs.model_mode_info:
|
| 115 |
+
has_eval = any(v == "eval" for v in obs.model_mode_info.values())
|
| 116 |
+
if has_eval:
|
| 117 |
+
obs = env.step(MLTrainingAction(action_type="fix_model_mode"))
|
| 118 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 119 |
+
obs = env.step(
|
| 120 |
+
MLTrainingAction(
|
| 121 |
+
action_type="mark_diagnosed",
|
| 122 |
+
diagnosis="batchnorm_eval_mode",
|
| 123 |
+
)
|
| 124 |
+
)
|
| 125 |
+
session = env._get_session()
|
| 126 |
+
return session.last_score if session and session.last_score is not None else 0.0
|
| 127 |
+
|
| 128 |
+
# Step 4: inspect_code
|
| 129 |
+
obs = env.step(MLTrainingAction(action_type="inspect_code"))
|
| 130 |
+
if obs.code_snippet:
|
| 131 |
+
code = obs.code_snippet.code
|
| 132 |
+
if "model.eval()" in code and "model.train()" not in code:
|
| 133 |
+
obs = env.step(
|
| 134 |
+
MLTrainingAction(
|
| 135 |
+
action_type="fix_code",
|
| 136 |
+
line=5,
|
| 137 |
+
replacement="model.train()",
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
elif ".detach()" in code:
|
| 141 |
+
obs = env.step(
|
| 142 |
+
MLTrainingAction(
|
| 143 |
+
action_type="fix_code",
|
| 144 |
+
line=14,
|
| 145 |
+
replacement=" loss = criterion(output, batch_y)",
|
| 146 |
+
)
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
if obs.episode_state.fix_action_taken:
|
| 150 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 151 |
+
|
| 152 |
+
obs = env.step(
|
| 153 |
+
MLTrainingAction(
|
| 154 |
+
action_type="mark_diagnosed",
|
| 155 |
+
diagnosis="code_bug",
|
| 156 |
+
)
|
| 157 |
+
)
|
| 158 |
+
session = env._get_session()
|
| 159 |
+
return session.last_score if session and session.last_score is not None else 0.0
|
| 160 |
+
|
| 161 |
+
# Fallback
|
| 162 |
+
obs = env.step(
|
| 163 |
+
MLTrainingAction(
|
| 164 |
+
action_type="mark_diagnosed",
|
| 165 |
+
diagnosis="overfitting",
|
| 166 |
+
)
|
| 167 |
+
)
|
| 168 |
+
session = env._get_session()
|
| 169 |
+
return session.last_score if session and session.last_score is not None else 0.0
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def main() -> None:
|
| 173 |
+
parser = argparse.ArgumentParser(description="Rule-based baseline agent")
|
| 174 |
+
parser.add_argument("--url", default="http://localhost:7860")
|
| 175 |
+
args = parser.parse_args()
|
| 176 |
+
|
| 177 |
+
scores: dict[str, float] = {}
|
| 178 |
+
for task_id in MVP_TASKS:
|
| 179 |
+
score = run_heuristic_episode(task_id)
|
| 180 |
+
scores[task_id] = round(score, 4)
|
| 181 |
+
|
| 182 |
+
print(json.dumps(scores, indent=2))
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
if __name__ == "__main__":
|
| 186 |
+
main()
|
deploy.sh
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
echo "=== PyTorch Training Run Debugger — Pre-Submission Smoke Test ==="
|
| 5 |
+
echo ""
|
| 6 |
+
|
| 7 |
+
# 1. Run tests
|
| 8 |
+
echo "=== 1. Running test suite ==="
|
| 9 |
+
source .venv/bin/activate
|
| 10 |
+
pytest tests/ -v --cov=ml_training_debugger --cov-report=term-missing
|
| 11 |
+
echo ""
|
| 12 |
+
|
| 13 |
+
# 2. Code formatting check
|
| 14 |
+
echo "=== 2. Code formatting ==="
|
| 15 |
+
black --check ml_training_debugger/ server/ tests/ || { echo "Run: black ml_training_debugger/ server/ tests/"; exit 1; }
|
| 16 |
+
ruff check ml_training_debugger/ server/ tests/ || { echo "Run: ruff check --fix"; exit 1; }
|
| 17 |
+
isort --check ml_training_debugger/ server/ tests/ --profile black || { echo "Run: isort --profile black"; exit 1; }
|
| 18 |
+
echo "PASS: formatting OK"
|
| 19 |
+
echo ""
|
| 20 |
+
|
| 21 |
+
# 3. Baseline reproducibility
|
| 22 |
+
echo "=== 3. Baseline reproducibility ==="
|
| 23 |
+
python baseline_heuristic.py > /tmp/run1.json 2>/dev/null
|
| 24 |
+
python baseline_heuristic.py > /tmp/run2.json 2>/dev/null
|
| 25 |
+
diff /tmp/run1.json /tmp/run2.json && echo "PASS: bit-exact reproducible" || { echo "FAIL: non-reproducible"; exit 1; }
|
| 26 |
+
echo ""
|
| 27 |
+
|
| 28 |
+
# 4. Docker build
|
| 29 |
+
echo "=== 4. Docker build ==="
|
| 30 |
+
docker build -t pytorch-debugger .
|
| 31 |
+
IMAGE_SIZE=$(docker images pytorch-debugger --format "{{.Size}}")
|
| 32 |
+
echo "Image size: $IMAGE_SIZE"
|
| 33 |
+
echo ""
|
| 34 |
+
|
| 35 |
+
# 5. Docker run + health check
|
| 36 |
+
echo "=== 5. Docker run + endpoint checks ==="
|
| 37 |
+
docker run -d -p 7860:7860 --name smoke-test pytorch-debugger
|
| 38 |
+
sleep 10
|
| 39 |
+
|
| 40 |
+
curl -f http://localhost:7860/health || { echo "FAIL: health"; docker stop smoke-test; docker rm smoke-test; exit 1; }
|
| 41 |
+
echo ""
|
| 42 |
+
curl -f http://localhost:7860/tasks || { echo "FAIL: tasks"; docker stop smoke-test; docker rm smoke-test; exit 1; }
|
| 43 |
+
echo ""
|
| 44 |
+
curl -f -X POST http://localhost:7860/grader || { echo "FAIL: grader"; docker stop smoke-test; docker rm smoke-test; exit 1; }
|
| 45 |
+
echo ""
|
| 46 |
+
|
| 47 |
+
# 6. Cleanup
|
| 48 |
+
docker stop smoke-test && docker rm smoke-test
|
| 49 |
+
rm -f /tmp/run1.json /tmp/run2.json
|
| 50 |
+
|
| 51 |
+
echo ""
|
| 52 |
+
echo "=== ALL CHECKS PASSED ==="
|
ml-training-debugger-spec.md
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
ml_training_debugger/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch Training Run Debugger — OpenEnv Environment."""
|
| 2 |
+
|
| 3 |
+
__version__ = "1.0.0"
|
ml_training_debugger/client.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Typed EnvClient for baseline scripts.
|
| 2 |
+
|
| 3 |
+
Extends GenericEnvClient since we can't easily subclass the
|
| 4 |
+
abstract EnvClient without implementing all transport methods.
|
| 5 |
+
Used by baseline_heuristic.py.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from openenv.core.generic_client import GenericEnvClient
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MLTrainingEnvClient(GenericEnvClient):
|
| 14 |
+
"""Typed client for the PyTorch Training Debugger environment.
|
| 15 |
+
|
| 16 |
+
Wraps GenericEnvClient for convenient use in baselines.
|
| 17 |
+
Actions are sent as dicts matching MLTrainingAction schema.
|
| 18 |
+
Observations are received as dicts matching MLTrainingObservation schema.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
pass
|
ml_training_debugger/code_templates.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch code snippet templates for Task 6 code-level debugging.
|
| 2 |
+
|
| 3 |
+
Each template is a real, syntactically valid Python/PyTorch training script
|
| 4 |
+
with one injected bug. Spec reference: Section 11 (Task 6), Section 22.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import ast
|
| 10 |
+
import io
|
| 11 |
+
import tokenize
|
| 12 |
+
from typing import Optional
|
| 13 |
+
|
| 14 |
+
import torch # noqa: F401 — PyTorch-native project
|
| 15 |
+
|
| 16 |
+
# Bug variant templates: (buggy_code, correct_line_num, correct_replacement)
|
| 17 |
+
_TEMPLATES: dict[str, tuple[str, int, str]] = {
|
| 18 |
+
"eval_mode": (
|
| 19 |
+
"""\
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
|
| 23 |
+
model = SimpleCNN()
|
| 24 |
+
model.eval()
|
| 25 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
| 26 |
+
criterion = nn.CrossEntropyLoss()
|
| 27 |
+
|
| 28 |
+
for epoch in range(100):
|
| 29 |
+
for batch_x, batch_y in train_loader:
|
| 30 |
+
optimizer.zero_grad()
|
| 31 |
+
output = model(batch_x)
|
| 32 |
+
loss = criterion(output, batch_y)
|
| 33 |
+
loss.backward()
|
| 34 |
+
optimizer.step()""",
|
| 35 |
+
5,
|
| 36 |
+
"model.train()",
|
| 37 |
+
),
|
| 38 |
+
"detach_loss": (
|
| 39 |
+
"""\
|
| 40 |
+
import torch
|
| 41 |
+
import torch.nn as nn
|
| 42 |
+
|
| 43 |
+
model = SimpleCNN()
|
| 44 |
+
model.train()
|
| 45 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
| 46 |
+
criterion = nn.CrossEntropyLoss()
|
| 47 |
+
|
| 48 |
+
for epoch in range(100):
|
| 49 |
+
for batch_x, batch_y in train_loader:
|
| 50 |
+
optimizer.zero_grad()
|
| 51 |
+
output = model(batch_x)
|
| 52 |
+
loss = criterion(output, batch_y).detach()
|
| 53 |
+
loss.backward()
|
| 54 |
+
optimizer.step()""",
|
| 55 |
+
14,
|
| 56 |
+
" loss = criterion(output, batch_y)",
|
| 57 |
+
),
|
| 58 |
+
"zero_grad_missing": (
|
| 59 |
+
"""\
|
| 60 |
+
import torch
|
| 61 |
+
import torch.nn as nn
|
| 62 |
+
|
| 63 |
+
model = SimpleCNN()
|
| 64 |
+
model.train()
|
| 65 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
| 66 |
+
criterion = nn.CrossEntropyLoss()
|
| 67 |
+
|
| 68 |
+
for epoch in range(100):
|
| 69 |
+
for batch_x, batch_y in train_loader:
|
| 70 |
+
output = model(batch_x)
|
| 71 |
+
loss = criterion(output, batch_y)
|
| 72 |
+
loss.backward()
|
| 73 |
+
optimizer.step()""",
|
| 74 |
+
11,
|
| 75 |
+
" optimizer.zero_grad()",
|
| 76 |
+
),
|
| 77 |
+
"inplace_relu": (
|
| 78 |
+
"""\
|
| 79 |
+
import torch
|
| 80 |
+
import torch.nn as nn
|
| 81 |
+
import torch.nn.functional as F
|
| 82 |
+
|
| 83 |
+
model = SimpleCNN()
|
| 84 |
+
model.train()
|
| 85 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
| 86 |
+
criterion = nn.CrossEntropyLoss()
|
| 87 |
+
|
| 88 |
+
for epoch in range(100):
|
| 89 |
+
for batch_x, batch_y in train_loader:
|
| 90 |
+
optimizer.zero_grad()
|
| 91 |
+
output = model(batch_x)
|
| 92 |
+
output = F.relu(output, inplace=True)
|
| 93 |
+
loss = criterion(output, batch_y)
|
| 94 |
+
loss.backward()
|
| 95 |
+
optimizer.step()""",
|
| 96 |
+
15,
|
| 97 |
+
" output = F.relu(output)",
|
| 98 |
+
),
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
# Semantic equivalence patterns per bug variant
|
| 102 |
+
_SEMANTIC_PATTERNS: dict[str, list[tuple[str, str]]] = {
|
| 103 |
+
"eval_mode": [
|
| 104 |
+
# (must_contain, must_not_contain)
|
| 105 |
+
("model.train()", "model.eval()"),
|
| 106 |
+
],
|
| 107 |
+
"detach_loss": [
|
| 108 |
+
("criterion(", ".detach()"),
|
| 109 |
+
],
|
| 110 |
+
"zero_grad_missing": [
|
| 111 |
+
("zero_grad()", ""), # just needs zero_grad present
|
| 112 |
+
],
|
| 113 |
+
"inplace_relu": [
|
| 114 |
+
("F.relu(", "inplace=True"),
|
| 115 |
+
],
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def generate_code_snippet(bug_type: str, seed: int = 42) -> dict:
|
| 120 |
+
"""Generate a code snippet with the specified bug.
|
| 121 |
+
|
| 122 |
+
Returns dict with keys: code, filename, line_count, imports, hint.
|
| 123 |
+
"""
|
| 124 |
+
if bug_type not in _TEMPLATES:
|
| 125 |
+
raise ValueError(f"Unknown bug_type: {bug_type}")
|
| 126 |
+
|
| 127 |
+
code, _line, _replacement = _TEMPLATES[bug_type]
|
| 128 |
+
lines = code.strip().split("\n")
|
| 129 |
+
imports = [
|
| 130 |
+
line for line in lines if line.startswith("import ") or line.startswith("from ")
|
| 131 |
+
]
|
| 132 |
+
|
| 133 |
+
hint: Optional[str] = None
|
| 134 |
+
if bug_type == "eval_mode":
|
| 135 |
+
hint = "Check the model mode before the training loop."
|
| 136 |
+
elif bug_type == "detach_loss":
|
| 137 |
+
hint = "Examine how the loss is computed and used."
|
| 138 |
+
|
| 139 |
+
return {
|
| 140 |
+
"code": code,
|
| 141 |
+
"filename": "train.py",
|
| 142 |
+
"line_count": len(lines),
|
| 143 |
+
"imports": imports,
|
| 144 |
+
"hint": hint,
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _normalize_code(s: str) -> str:
|
| 149 |
+
"""Strip whitespace and inline comments for comparison."""
|
| 150 |
+
s = s.strip()
|
| 151 |
+
# Remove inline comments
|
| 152 |
+
result_lines: list[str] = []
|
| 153 |
+
for line in s.split("\n"):
|
| 154 |
+
# Remove trailing comment but preserve strings
|
| 155 |
+
stripped = line.rstrip()
|
| 156 |
+
result_lines.append(stripped)
|
| 157 |
+
return "\n".join(result_lines)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _tokenize_compare(original: str, replacement: str) -> bool:
|
| 161 |
+
"""Compare token streams ignoring whitespace and comments."""
|
| 162 |
+
|
| 163 |
+
def get_tokens(code: str) -> list[tuple[int, str]]:
|
| 164 |
+
try:
|
| 165 |
+
tokens = list(tokenize.generate_tokens(io.StringIO(code).readline))
|
| 166 |
+
# Filter out COMMENT, NL, NEWLINE, INDENT, DEDENT, ENCODING, ENDMARKER
|
| 167 |
+
skip = {
|
| 168 |
+
tokenize.COMMENT,
|
| 169 |
+
tokenize.NL,
|
| 170 |
+
tokenize.NEWLINE,
|
| 171 |
+
tokenize.INDENT,
|
| 172 |
+
tokenize.DEDENT,
|
| 173 |
+
tokenize.ENCODING,
|
| 174 |
+
tokenize.ENDMARKER,
|
| 175 |
+
}
|
| 176 |
+
return [(t.type, t.string) for t in tokens if t.type not in skip]
|
| 177 |
+
except tokenize.TokenError:
|
| 178 |
+
return []
|
| 179 |
+
|
| 180 |
+
return get_tokens(original) == get_tokens(replacement)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def validate_fix(bug_type: str, line: int, replacement: str) -> bool:
|
| 184 |
+
"""Validate a code fix submission.
|
| 185 |
+
|
| 186 |
+
Multi-strategy pipeline per spec Section 22:
|
| 187 |
+
1. Normalize whitespace + strip comments
|
| 188 |
+
2. Token-stream comparison
|
| 189 |
+
3. Semantic equivalence patterns
|
| 190 |
+
4. AST fallback
|
| 191 |
+
"""
|
| 192 |
+
if bug_type not in _TEMPLATES:
|
| 193 |
+
return False
|
| 194 |
+
|
| 195 |
+
code, correct_line, correct_replacement = _TEMPLATES[bug_type]
|
| 196 |
+
lines = code.strip().split("\n")
|
| 197 |
+
|
| 198 |
+
# Check line number is valid
|
| 199 |
+
if line < 1 or line > len(lines):
|
| 200 |
+
return False
|
| 201 |
+
|
| 202 |
+
# For zero_grad_missing, the fix is inserting a line, not replacing
|
| 203 |
+
if bug_type == "zero_grad_missing":
|
| 204 |
+
# Accept if the replacement contains zero_grad
|
| 205 |
+
normalized = _normalize_code(replacement)
|
| 206 |
+
if "zero_grad" in normalized:
|
| 207 |
+
return True
|
| 208 |
+
return False
|
| 209 |
+
|
| 210 |
+
# Strategy 1: Normalize and compare
|
| 211 |
+
norm_replacement = _normalize_code(replacement)
|
| 212 |
+
norm_correct = _normalize_code(correct_replacement)
|
| 213 |
+
if norm_replacement == norm_correct:
|
| 214 |
+
return True
|
| 215 |
+
|
| 216 |
+
# Strategy 2: Token-stream comparison
|
| 217 |
+
if _tokenize_compare(correct_replacement, replacement):
|
| 218 |
+
return True
|
| 219 |
+
|
| 220 |
+
# Strategy 3: Semantic equivalence patterns
|
| 221 |
+
patterns = _SEMANTIC_PATTERNS.get(bug_type, [])
|
| 222 |
+
for must_contain, must_not_contain in patterns:
|
| 223 |
+
if must_contain and must_contain in norm_replacement:
|
| 224 |
+
if not must_not_contain or must_not_contain not in norm_replacement:
|
| 225 |
+
return True
|
| 226 |
+
|
| 227 |
+
# Strategy 4: AST fallback — verify buggy pattern absent
|
| 228 |
+
try:
|
| 229 |
+
# Replace the line in the full code and parse
|
| 230 |
+
new_lines = lines.copy()
|
| 231 |
+
new_lines[line - 1] = replacement.rstrip()
|
| 232 |
+
new_code = "\n".join(new_lines)
|
| 233 |
+
tree = ast.parse(new_code)
|
| 234 |
+
|
| 235 |
+
# Check that the buggy pattern is absent
|
| 236 |
+
ast.dump(tree) # Validates AST is well-formed
|
| 237 |
+
if bug_type == "eval_mode" and "eval" not in replacement.lower():
|
| 238 |
+
if "train" in replacement.lower():
|
| 239 |
+
return True
|
| 240 |
+
if bug_type == "detach_loss" and "detach" not in replacement.lower():
|
| 241 |
+
return True
|
| 242 |
+
if bug_type == "inplace_relu" and "inplace" not in replacement.lower():
|
| 243 |
+
if "relu" in replacement.lower():
|
| 244 |
+
return True
|
| 245 |
+
except SyntaxError:
|
| 246 |
+
pass
|
| 247 |
+
|
| 248 |
+
return False
|
ml_training_debugger/graders.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Per-task grader functions — returns normalized 0.0-1.0 score at episode end.
|
| 2 |
+
|
| 3 |
+
Separate from reward_engine.py. Evaluates EpisodeState holistically.
|
| 4 |
+
NOT a sum of step rewards. Spec reference: Section 11 grader breakdowns.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import torch # noqa: F401 — PyTorch-native project
|
| 10 |
+
|
| 11 |
+
from ml_training_debugger.models import EpisodeState
|
| 12 |
+
from ml_training_debugger.scenarios import ScenarioParams
|
| 13 |
+
|
| 14 |
+
FIX_ACTIONS = frozenset(
|
| 15 |
+
{
|
| 16 |
+
"modify_config",
|
| 17 |
+
"add_callback",
|
| 18 |
+
"replace_optimizer",
|
| 19 |
+
"patch_data_loader",
|
| 20 |
+
"fix_model_mode",
|
| 21 |
+
"fix_code",
|
| 22 |
+
}
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def _has_action(state: EpisodeState, action_type: str) -> bool:
|
| 27 |
+
return action_type in state.actions_taken
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def _correct_diagnosis(state: EpisodeState, scenario: ScenarioParams) -> bool:
|
| 31 |
+
if not state.diagnosis_submitted:
|
| 32 |
+
return False
|
| 33 |
+
# Find the diagnosis from actions_taken metadata
|
| 34 |
+
# We store "mark_diagnosed:<diagnosis>" in actions_taken
|
| 35 |
+
for action_str in reversed(state.actions_taken):
|
| 36 |
+
if action_str.startswith("mark_diagnosed:"):
|
| 37 |
+
submitted = action_str.split(":", 1)[1]
|
| 38 |
+
return submitted == scenario.root_cause.value
|
| 39 |
+
return False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _submitted_diagnosis(state: EpisodeState) -> str | None:
|
| 43 |
+
for action_str in reversed(state.actions_taken):
|
| 44 |
+
if action_str.startswith("mark_diagnosed:"):
|
| 45 |
+
return action_str.split(":", 1)[1]
|
| 46 |
+
return None
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def grade_task_001(state: EpisodeState, scenario: ScenarioParams) -> float:
|
| 50 |
+
"""Grade Task 1 — Exploding Gradients (easy). Spec Section 11."""
|
| 51 |
+
score = 0.0
|
| 52 |
+
|
| 53 |
+
# +0.05 for inspect_gradients
|
| 54 |
+
if state.gradients_inspected:
|
| 55 |
+
score += 0.05
|
| 56 |
+
|
| 57 |
+
# +0.20 for correct fix (modify_config with LR reduction)
|
| 58 |
+
if _has_action(state, "modify_config"):
|
| 59 |
+
score += 0.20
|
| 60 |
+
|
| 61 |
+
# +0.35 for restart with convergence
|
| 62 |
+
if state.restart_after_fix:
|
| 63 |
+
score += 0.35
|
| 64 |
+
|
| 65 |
+
# +0.40 for correct diagnosis
|
| 66 |
+
if _correct_diagnosis(state, scenario):
|
| 67 |
+
score += 0.40
|
| 68 |
+
|
| 69 |
+
return min(1.0, max(0.0, score))
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def grade_task_002(state: EpisodeState, scenario: ScenarioParams) -> float:
|
| 73 |
+
"""Grade Task 2 — Vanishing Gradients (easy). Spec Section 11."""
|
| 74 |
+
score = 0.0
|
| 75 |
+
|
| 76 |
+
if state.gradients_inspected:
|
| 77 |
+
score += 0.05
|
| 78 |
+
if _has_action(state, "modify_config"):
|
| 79 |
+
score += 0.20
|
| 80 |
+
if state.restart_after_fix:
|
| 81 |
+
score += 0.35
|
| 82 |
+
if _correct_diagnosis(state, scenario):
|
| 83 |
+
score += 0.40
|
| 84 |
+
|
| 85 |
+
return min(1.0, max(0.0, score))
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def grade_task_003(state: EpisodeState, scenario: ScenarioParams) -> float:
|
| 89 |
+
"""Grade Task 3 — Silent Data Leakage (medium). Spec Section 11."""
|
| 90 |
+
score = 0.0
|
| 91 |
+
|
| 92 |
+
# +0.05 for inspect_data_batch
|
| 93 |
+
if state.data_inspected:
|
| 94 |
+
score += 0.05
|
| 95 |
+
|
| 96 |
+
# +0.30 for patch_data_loader
|
| 97 |
+
if _has_action(state, "patch_data_loader"):
|
| 98 |
+
score += 0.30
|
| 99 |
+
|
| 100 |
+
# +0.30 for restart with convergence (val accuracy normalizes)
|
| 101 |
+
if state.restart_after_fix:
|
| 102 |
+
score += 0.30
|
| 103 |
+
|
| 104 |
+
# +0.35 for correct diagnosis
|
| 105 |
+
if _correct_diagnosis(state, scenario):
|
| 106 |
+
score += 0.35
|
| 107 |
+
|
| 108 |
+
return min(1.0, max(0.0, score))
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def grade_task_004(state: EpisodeState, scenario: ScenarioParams) -> float:
|
| 112 |
+
"""Grade Task 4 — Overfitting (medium). Spec Section 11."""
|
| 113 |
+
score = 0.0
|
| 114 |
+
|
| 115 |
+
if state.data_inspected:
|
| 116 |
+
score += 0.05
|
| 117 |
+
if _has_action(state, "modify_config") or _has_action(state, "add_callback"):
|
| 118 |
+
score += 0.25
|
| 119 |
+
if state.restart_after_fix:
|
| 120 |
+
score += 0.30
|
| 121 |
+
if _correct_diagnosis(state, scenario):
|
| 122 |
+
score += 0.40
|
| 123 |
+
|
| 124 |
+
return min(1.0, max(0.0, score))
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def grade_task_005(state: EpisodeState, scenario: ScenarioParams) -> float:
|
| 128 |
+
"""Grade Task 5 — BatchNorm Eval Mode (hard). Spec Section 11.
|
| 129 |
+
|
| 130 |
+
Context-gated penalty: -0.20 if add_callback after gradients_were_normal.
|
| 131 |
+
"""
|
| 132 |
+
score = 0.0
|
| 133 |
+
|
| 134 |
+
# +0.05 for inspect_gradients
|
| 135 |
+
if state.gradients_inspected:
|
| 136 |
+
score += 0.05
|
| 137 |
+
|
| 138 |
+
# +0.05 for inspect_model_modes — the revealing action
|
| 139 |
+
if state.model_modes_inspected:
|
| 140 |
+
score += 0.05
|
| 141 |
+
|
| 142 |
+
# -0.20 for add_callback after gradients_were_normal
|
| 143 |
+
if (
|
| 144 |
+
_has_action(state, "add_callback")
|
| 145 |
+
and state.gradients_inspected
|
| 146 |
+
and state.gradients_were_normal
|
| 147 |
+
):
|
| 148 |
+
score -= 0.20
|
| 149 |
+
|
| 150 |
+
# +0.25 for fix_model_mode
|
| 151 |
+
if _has_action(state, "fix_model_mode"):
|
| 152 |
+
score += 0.25
|
| 153 |
+
|
| 154 |
+
# +0.30 for restart with convergence
|
| 155 |
+
if state.restart_after_fix:
|
| 156 |
+
score += 0.30
|
| 157 |
+
|
| 158 |
+
# +0.40 for correct diagnosis
|
| 159 |
+
if _correct_diagnosis(state, scenario):
|
| 160 |
+
score += 0.40
|
| 161 |
+
|
| 162 |
+
return min(1.0, max(0.0, score))
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def grade_task_006(state: EpisodeState, scenario: ScenarioParams) -> float:
|
| 166 |
+
"""Grade Task 6 — PyTorch Code Bug (hard). Spec Section 11.
|
| 167 |
+
|
| 168 |
+
Diagnosis must ALWAYS be 'code_bug' regardless of bug variant.
|
| 169 |
+
"""
|
| 170 |
+
score = 0.0
|
| 171 |
+
|
| 172 |
+
# +0.05 for inspect_code
|
| 173 |
+
if state.code_inspected:
|
| 174 |
+
score += 0.05
|
| 175 |
+
|
| 176 |
+
# +0.30 for correct code fix
|
| 177 |
+
if _has_action(state, "fix_code") and state.fix_action_taken:
|
| 178 |
+
score += 0.30
|
| 179 |
+
|
| 180 |
+
# +0.25 for restart with convergence
|
| 181 |
+
if state.restart_after_fix:
|
| 182 |
+
score += 0.25
|
| 183 |
+
|
| 184 |
+
# +0.40 for correct diagnosis (must be code_bug)
|
| 185 |
+
if _correct_diagnosis(state, scenario):
|
| 186 |
+
score += 0.40
|
| 187 |
+
|
| 188 |
+
return min(1.0, max(0.0, score))
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
# Registry mapping task IDs to grader functions
|
| 192 |
+
GRADERS = {
|
| 193 |
+
"task_001": grade_task_001,
|
| 194 |
+
"task_002": grade_task_002,
|
| 195 |
+
"task_003": grade_task_003,
|
| 196 |
+
"task_004": grade_task_004,
|
| 197 |
+
"task_005": grade_task_005,
|
| 198 |
+
"task_006": grade_task_006,
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def grade_episode(task_id: str, state: EpisodeState, scenario: ScenarioParams) -> float:
|
| 203 |
+
"""Grade a completed episode. Returns 0.0-1.0."""
|
| 204 |
+
grader = GRADERS.get(task_id)
|
| 205 |
+
if grader is None:
|
| 206 |
+
return 0.0
|
| 207 |
+
return grader(state, scenario)
|
ml_training_debugger/models.py
ADDED
|
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""All Pydantic models, enums, and typed data structures.
|
| 2 |
+
|
| 3 |
+
No business logic. Pure data definitions.
|
| 4 |
+
Spec reference: Section 10 — Data Models.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import enum
|
| 10 |
+
from typing import Optional, Union
|
| 11 |
+
|
| 12 |
+
import torch # noqa: F401 — PyTorch-native project, required import
|
| 13 |
+
from openenv.core.env_server.types import Action, Observation
|
| 14 |
+
from pydantic import BaseModel, Field
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class RootCauseDiagnosis(str, enum.Enum):
|
| 18 |
+
"""Closed enumeration of ML failure root causes. Spec Section 10."""
|
| 19 |
+
|
| 20 |
+
LR_TOO_HIGH = "lr_too_high"
|
| 21 |
+
VANISHING_GRADIENTS = "vanishing_gradients"
|
| 22 |
+
DATA_LEAKAGE = "data_leakage"
|
| 23 |
+
OVERFITTING = "overfitting"
|
| 24 |
+
BATCHNORM_EVAL_MODE = "batchnorm_eval_mode"
|
| 25 |
+
CODE_BUG = "code_bug"
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
VALID_DIAGNOSES: set[str] = {d.value for d in RootCauseDiagnosis}
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TrainingConfig(BaseModel):
|
| 32 |
+
"""Typed hyperparameter configuration. Spec Section 10."""
|
| 33 |
+
|
| 34 |
+
learning_rate: float = 0.001
|
| 35 |
+
weight_decay: float = 0.0001
|
| 36 |
+
batch_size: int = 64
|
| 37 |
+
hidden_dim: int = 64
|
| 38 |
+
num_layers: int = 3
|
| 39 |
+
optimizer: str = "adam"
|
| 40 |
+
dropout_rate: float = 0.0
|
| 41 |
+
gradient_clip_norm: Optional[float] = None
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
VALID_CONFIG_KEYS: set[str] = set(TrainingConfig.model_fields.keys())
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class GradientStats(BaseModel):
|
| 48 |
+
"""Per-layer gradient information from real torch.autograd. Spec Section 10."""
|
| 49 |
+
|
| 50 |
+
layer_name: str
|
| 51 |
+
norm_history: list[float]
|
| 52 |
+
mean_norm: float
|
| 53 |
+
max_norm: float
|
| 54 |
+
is_exploding: bool # True when mean_norm > 10.0
|
| 55 |
+
is_vanishing: bool # True when mean_norm < 1e-6
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class ModelWeightStats(BaseModel):
|
| 59 |
+
"""Per-layer weight statistics from real state_dict(). Spec Section 10."""
|
| 60 |
+
|
| 61 |
+
layer_name: str
|
| 62 |
+
weight_norm: float
|
| 63 |
+
weight_mean: float
|
| 64 |
+
weight_std: float
|
| 65 |
+
weight_min: float
|
| 66 |
+
weight_max: float
|
| 67 |
+
dead_neuron_pct: float = 0.0
|
| 68 |
+
has_nan: bool = False
|
| 69 |
+
has_inf: bool = False
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class DataBatchStats(BaseModel):
|
| 73 |
+
"""Data batch inspection results. Spec Section 10."""
|
| 74 |
+
|
| 75 |
+
label_distribution: dict[int, float]
|
| 76 |
+
feature_mean: float
|
| 77 |
+
feature_std: float
|
| 78 |
+
null_count: int = 0
|
| 79 |
+
class_overlap_score: float
|
| 80 |
+
batch_size: int
|
| 81 |
+
duplicate_ratio: float = 0.0
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class CodeSnippet(BaseModel):
|
| 85 |
+
"""PyTorch code for Task 6 inspection. Spec Section 10."""
|
| 86 |
+
|
| 87 |
+
code: str
|
| 88 |
+
filename: str = "train.py"
|
| 89 |
+
line_count: int
|
| 90 |
+
imports: list[str]
|
| 91 |
+
hint: Optional[str] = None
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class EpisodeState(BaseModel):
|
| 95 |
+
"""Tracks agent history within an episode. Spec Section 10."""
|
| 96 |
+
|
| 97 |
+
step_count: int = 0
|
| 98 |
+
gradients_inspected: bool = False
|
| 99 |
+
gradients_were_normal: bool = False
|
| 100 |
+
data_inspected: bool = False
|
| 101 |
+
model_modes_inspected: bool = False
|
| 102 |
+
model_weights_inspected: bool = False
|
| 103 |
+
code_inspected: bool = False
|
| 104 |
+
fix_action_taken: bool = False
|
| 105 |
+
restart_after_fix: bool = False
|
| 106 |
+
diagnosis_submitted: bool = False
|
| 107 |
+
actions_taken: list[str] = Field(default_factory=list)
|
| 108 |
+
|
| 109 |
+
def compute_available_actions(self) -> list[str]:
|
| 110 |
+
"""Dynamically compute available actions based on current state.
|
| 111 |
+
|
| 112 |
+
Rules from spec Section 10 — Dynamic available_actions:
|
| 113 |
+
- restart_run: only after fix_action_taken
|
| 114 |
+
- rollback_checkpoint: only after restart_after_fix
|
| 115 |
+
- fix_code: only after code_inspected
|
| 116 |
+
- mark_diagnosed: disappears after diagnosis_submitted
|
| 117 |
+
"""
|
| 118 |
+
actions: list[str] = [
|
| 119 |
+
"inspect_gradients",
|
| 120 |
+
"inspect_data_batch",
|
| 121 |
+
"inspect_model_modes",
|
| 122 |
+
"inspect_model_weights",
|
| 123 |
+
"inspect_code",
|
| 124 |
+
"modify_config",
|
| 125 |
+
"add_callback",
|
| 126 |
+
"replace_optimizer",
|
| 127 |
+
"patch_data_loader",
|
| 128 |
+
"fix_model_mode",
|
| 129 |
+
]
|
| 130 |
+
if self.code_inspected:
|
| 131 |
+
actions.append("fix_code")
|
| 132 |
+
if self.fix_action_taken:
|
| 133 |
+
actions.append("restart_run")
|
| 134 |
+
if self.restart_after_fix:
|
| 135 |
+
actions.append("rollback_checkpoint")
|
| 136 |
+
if not self.diagnosis_submitted:
|
| 137 |
+
actions.append("mark_diagnosed")
|
| 138 |
+
return actions
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
ALL_ACTION_TYPES: set[str] = {
|
| 142 |
+
"inspect_gradients",
|
| 143 |
+
"inspect_data_batch",
|
| 144 |
+
"inspect_model_modes",
|
| 145 |
+
"inspect_model_weights",
|
| 146 |
+
"inspect_code",
|
| 147 |
+
"modify_config",
|
| 148 |
+
"add_callback",
|
| 149 |
+
"replace_optimizer",
|
| 150 |
+
"patch_data_loader",
|
| 151 |
+
"fix_model_mode",
|
| 152 |
+
"fix_code",
|
| 153 |
+
"restart_run",
|
| 154 |
+
"mark_diagnosed",
|
| 155 |
+
"rollback_checkpoint",
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class MLTrainingAction(Action):
|
| 160 |
+
"""What the agent can do — extends openenv Action. Spec Section 10."""
|
| 161 |
+
|
| 162 |
+
action_type: str
|
| 163 |
+
target: Optional[str] = None
|
| 164 |
+
value: Optional[Union[float, int, str]] = None
|
| 165 |
+
diagnosis: Optional[str] = None
|
| 166 |
+
line: Optional[int] = None
|
| 167 |
+
replacement: Optional[str] = None
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
class MLTrainingObservation(Observation):
|
| 171 |
+
"""Full observation — extends openenv Observation.
|
| 172 |
+
|
| 173 |
+
Observation base has built-in: done (bool), reward (float|None), metadata (dict).
|
| 174 |
+
Spec Section 10.
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
run_id: str = ""
|
| 178 |
+
framework: str = "pytorch"
|
| 179 |
+
epoch: int = 20
|
| 180 |
+
training_loss_history: list[float] = Field(default_factory=list)
|
| 181 |
+
val_loss_history: list[float] = Field(default_factory=list)
|
| 182 |
+
val_accuracy_history: list[float] = Field(default_factory=list)
|
| 183 |
+
gradient_stats: list[GradientStats] = Field(default_factory=list)
|
| 184 |
+
model_weight_stats: Optional[list[ModelWeightStats]] = None
|
| 185 |
+
gpu_memory_used_gb: float = 6.2
|
| 186 |
+
gpu_memory_total_gb: float = 16.0
|
| 187 |
+
learning_rate: float = 0.001
|
| 188 |
+
current_config: TrainingConfig = Field(default_factory=TrainingConfig)
|
| 189 |
+
error_log: Optional[str] = None
|
| 190 |
+
data_batch_stats: Optional[DataBatchStats] = None
|
| 191 |
+
model_mode_info: Optional[dict[str, str]] = None
|
| 192 |
+
code_snippet: Optional[CodeSnippet] = None
|
| 193 |
+
available_actions: list[str] = Field(default_factory=list)
|
| 194 |
+
episode_state: EpisodeState = Field(default_factory=EpisodeState)
|
| 195 |
+
notes: Optional[str] = None
|
ml_training_debugger/pytorch_engine.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""PyTorch-native fault injection engine.
|
| 2 |
+
|
| 3 |
+
Real torch.nn.Module models, real torch.autograd gradients,
|
| 4 |
+
real state_dict() weight snapshots. Zero numpy.
|
| 5 |
+
Spec reference: Sections 6, 9.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
from ml_training_debugger.models import GradientStats, ModelWeightStats
|
| 16 |
+
from ml_training_debugger.scenarios import ScenarioParams
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class SimpleCNN(nn.Module):
|
| 20 |
+
"""3-layer CNN for CIFAR-10 style classification. ~50K params.
|
| 21 |
+
|
| 22 |
+
Spec Section 9 — PyTorch Model Pool.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
def __init__(self, num_layers: int = 3, hidden_dim: int = 64) -> None:
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
|
| 28 |
+
self.bn1 = nn.BatchNorm2d(32)
|
| 29 |
+
self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
|
| 30 |
+
self.bn2 = nn.BatchNorm2d(64)
|
| 31 |
+
self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
|
| 32 |
+
self.bn3 = nn.BatchNorm2d(64)
|
| 33 |
+
self.fc = nn.Linear(64 * 4 * 4, 10)
|
| 34 |
+
self.pool = nn.MaxPool2d(2, 2)
|
| 35 |
+
self.relu = nn.ReLU()
|
| 36 |
+
|
| 37 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 38 |
+
x = self.pool(self.relu(self.bn1(self.conv1(x))))
|
| 39 |
+
x = self.pool(self.relu(self.bn2(self.conv2(x))))
|
| 40 |
+
x = self.pool(self.relu(self.bn3(self.conv3(x))))
|
| 41 |
+
x = x.view(x.size(0), -1)
|
| 42 |
+
x = self.fc(x)
|
| 43 |
+
return x
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def create_model_and_inject_fault(
|
| 47 |
+
scenario: ScenarioParams,
|
| 48 |
+
) -> tuple[nn.Module, dict]:
|
| 49 |
+
"""Instantiate a real PyTorch model and inject the specified fault.
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
(model, info_dict) where info_dict contains computed artifacts.
|
| 53 |
+
"""
|
| 54 |
+
torch.manual_seed(scenario.seed)
|
| 55 |
+
|
| 56 |
+
model = SimpleCNN()
|
| 57 |
+
criterion = nn.CrossEntropyLoss()
|
| 58 |
+
info: dict = {}
|
| 59 |
+
|
| 60 |
+
# Generate random batch (CIFAR-10 style: 3x32x32)
|
| 61 |
+
batch_x = torch.randn(8, 3, 32, 32)
|
| 62 |
+
batch_y = torch.randint(0, 10, (8,))
|
| 63 |
+
|
| 64 |
+
if scenario.root_cause.value == "lr_too_high":
|
| 65 |
+
# Exploding gradients: high LR with SGD → gradients explode on all layers
|
| 66 |
+
model.train()
|
| 67 |
+
optimizer = torch.optim.SGD(
|
| 68 |
+
model.parameters(), lr=scenario.learning_rate * 10.0
|
| 69 |
+
)
|
| 70 |
+
for _ in range(3):
|
| 71 |
+
optimizer.zero_grad()
|
| 72 |
+
output = model(batch_x)
|
| 73 |
+
loss = criterion(output, batch_y)
|
| 74 |
+
loss.backward()
|
| 75 |
+
optimizer.step()
|
| 76 |
+
# Run one final backward to capture extreme gradients
|
| 77 |
+
optimizer.zero_grad()
|
| 78 |
+
output = model(batch_x)
|
| 79 |
+
loss = criterion(output, batch_y)
|
| 80 |
+
loss.backward()
|
| 81 |
+
|
| 82 |
+
elif scenario.root_cause.value == "vanishing_gradients":
|
| 83 |
+
# Tiny LR → gradients are extremely small
|
| 84 |
+
model.train()
|
| 85 |
+
optimizer = torch.optim.SGD(model.parameters(), lr=scenario.learning_rate)
|
| 86 |
+
for _ in range(2):
|
| 87 |
+
optimizer.zero_grad()
|
| 88 |
+
output = model(batch_x)
|
| 89 |
+
loss = criterion(output, batch_y)
|
| 90 |
+
loss.backward()
|
| 91 |
+
optimizer.step()
|
| 92 |
+
|
| 93 |
+
elif scenario.root_cause.value == "data_leakage":
|
| 94 |
+
# Normal model — no gradient anomaly
|
| 95 |
+
model.train()
|
| 96 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
| 97 |
+
optimizer.zero_grad()
|
| 98 |
+
output = model(batch_x)
|
| 99 |
+
loss = criterion(output, batch_y)
|
| 100 |
+
loss.backward()
|
| 101 |
+
optimizer.step()
|
| 102 |
+
|
| 103 |
+
elif scenario.root_cause.value == "overfitting":
|
| 104 |
+
# Normal model with zero weight decay
|
| 105 |
+
model.train()
|
| 106 |
+
optimizer = torch.optim.Adam(
|
| 107 |
+
model.parameters(),
|
| 108 |
+
lr=0.001,
|
| 109 |
+
weight_decay=scenario.weight_decay,
|
| 110 |
+
)
|
| 111 |
+
optimizer.zero_grad()
|
| 112 |
+
output = model(batch_x)
|
| 113 |
+
loss = criterion(output, batch_y)
|
| 114 |
+
loss.backward()
|
| 115 |
+
optimizer.step()
|
| 116 |
+
|
| 117 |
+
elif scenario.root_cause.value == "batchnorm_eval_mode":
|
| 118 |
+
# model.eval() before training — the real bug
|
| 119 |
+
model.eval()
|
| 120 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
| 121 |
+
# Still run forward/backward to get gradient data
|
| 122 |
+
output = model(batch_x)
|
| 123 |
+
loss = criterion(output, batch_y)
|
| 124 |
+
loss.backward()
|
| 125 |
+
optimizer.step()
|
| 126 |
+
|
| 127 |
+
elif scenario.root_cause.value == "code_bug":
|
| 128 |
+
# Normal training with the model bug injected in code only
|
| 129 |
+
model.train()
|
| 130 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
| 131 |
+
optimizer.zero_grad()
|
| 132 |
+
output = model(batch_x)
|
| 133 |
+
loss = criterion(output, batch_y)
|
| 134 |
+
loss.backward()
|
| 135 |
+
optimizer.step()
|
| 136 |
+
|
| 137 |
+
return model, info
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def extract_gradient_stats(
|
| 141 |
+
model: nn.Module,
|
| 142 |
+
scenario: Optional[ScenarioParams] = None,
|
| 143 |
+
) -> list[GradientStats]:
|
| 144 |
+
"""Extract gradient statistics from real param.grad tensors.
|
| 145 |
+
|
| 146 |
+
For Task 5 (batchnorm_eval_mode), injects red-herring spike on
|
| 147 |
+
the configured layer.
|
| 148 |
+
"""
|
| 149 |
+
stats: list[GradientStats] = []
|
| 150 |
+
named_layers = [
|
| 151 |
+
("conv1", model.conv1),
|
| 152 |
+
("conv2", model.conv2),
|
| 153 |
+
("conv3", model.conv3),
|
| 154 |
+
("fc", model.fc),
|
| 155 |
+
]
|
| 156 |
+
|
| 157 |
+
for layer_name, layer in named_layers:
|
| 158 |
+
norms: list[float] = []
|
| 159 |
+
for param in layer.parameters():
|
| 160 |
+
if param.grad is not None:
|
| 161 |
+
norm_val = torch.norm(param.grad).item()
|
| 162 |
+
norms.append(norm_val)
|
| 163 |
+
|
| 164 |
+
if not norms:
|
| 165 |
+
norms = [0.0]
|
| 166 |
+
|
| 167 |
+
mean_norm = sum(norms) / len(norms)
|
| 168 |
+
max_norm = max(norms)
|
| 169 |
+
|
| 170 |
+
# Build norm_history (simulated last 5 values, based on current)
|
| 171 |
+
norm_history = [mean_norm * (0.9 + 0.2 * i / 4) for i in range(5)]
|
| 172 |
+
|
| 173 |
+
# Task 5 red herring: spike on configured layer
|
| 174 |
+
if scenario and scenario.root_cause.value == "batchnorm_eval_mode":
|
| 175 |
+
if layer_name == scenario.red_herring_spike_layer:
|
| 176 |
+
spike = scenario.red_herring_intensity
|
| 177 |
+
norm_history = [
|
| 178 |
+
mean_norm,
|
| 179 |
+
mean_norm,
|
| 180 |
+
mean_norm * spike,
|
| 181 |
+
mean_norm * spike * 1.2,
|
| 182 |
+
mean_norm,
|
| 183 |
+
]
|
| 184 |
+
mean_norm = sum(norm_history) / len(norm_history)
|
| 185 |
+
max_norm = max(norm_history)
|
| 186 |
+
|
| 187 |
+
# Conv1 near-vanishing red herring
|
| 188 |
+
if layer_name == "conv1" and scenario.red_herring_spike_layer != "conv1":
|
| 189 |
+
near_vanish = 0.0003
|
| 190 |
+
norm_history = [near_vanish * (0.95 + 0.1 * i / 4) for i in range(5)]
|
| 191 |
+
mean_norm = near_vanish
|
| 192 |
+
max_norm = max(norm_history)
|
| 193 |
+
|
| 194 |
+
is_exploding = mean_norm > 10.0
|
| 195 |
+
is_vanishing = mean_norm < 1e-6
|
| 196 |
+
|
| 197 |
+
stats.append(
|
| 198 |
+
GradientStats(
|
| 199 |
+
layer_name=layer_name,
|
| 200 |
+
norm_history=norm_history,
|
| 201 |
+
mean_norm=mean_norm,
|
| 202 |
+
max_norm=max_norm,
|
| 203 |
+
is_exploding=is_exploding,
|
| 204 |
+
is_vanishing=is_vanishing,
|
| 205 |
+
)
|
| 206 |
+
)
|
| 207 |
+
|
| 208 |
+
return stats
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def extract_weight_stats(model: nn.Module) -> list[ModelWeightStats]:
|
| 212 |
+
"""Extract weight statistics from real model.state_dict()."""
|
| 213 |
+
stats: list[ModelWeightStats] = []
|
| 214 |
+
for name, param in model.named_parameters():
|
| 215 |
+
if "weight" not in name:
|
| 216 |
+
continue
|
| 217 |
+
stats.append(
|
| 218 |
+
ModelWeightStats(
|
| 219 |
+
layer_name=name,
|
| 220 |
+
weight_norm=torch.norm(param).item(),
|
| 221 |
+
weight_mean=param.mean().item(),
|
| 222 |
+
weight_std=param.std().item(),
|
| 223 |
+
weight_min=param.min().item(),
|
| 224 |
+
weight_max=param.max().item(),
|
| 225 |
+
dead_neuron_pct=0.0,
|
| 226 |
+
has_nan=bool(torch.isnan(param).any().item()),
|
| 227 |
+
has_inf=bool(torch.isinf(param).any().item()),
|
| 228 |
+
)
|
| 229 |
+
)
|
| 230 |
+
return stats
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def extract_model_modes(model: nn.Module) -> dict[str, str]:
|
| 234 |
+
"""Extract training/eval mode for each named module."""
|
| 235 |
+
modes: dict[str, str] = {}
|
| 236 |
+
for name, module in model.named_modules():
|
| 237 |
+
if name == "":
|
| 238 |
+
continue
|
| 239 |
+
modes[name] = "train" if module.training else "eval"
|
| 240 |
+
return modes
|
ml_training_debugger/reward_engine.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Reward function — all 7 components per spec Section 12.
|
| 2 |
+
|
| 3 |
+
Separate from graders.py. Returns a float per step for RL training signal.
|
| 4 |
+
Hard cap at [-1.0, 1.0].
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import torch # noqa: F401 — PyTorch-native project
|
| 10 |
+
|
| 11 |
+
from ml_training_debugger.models import EpisodeState, MLTrainingAction
|
| 12 |
+
from ml_training_debugger.scenarios import ScenarioParams
|
| 13 |
+
|
| 14 |
+
# Reward constants — do not change (CLAUDE.md)
|
| 15 |
+
STEP_PENALTY = -0.01
|
| 16 |
+
INVESTIGATION_BONUS = 0.05
|
| 17 |
+
CONTEXT_GATED_PENALTY = -0.20
|
| 18 |
+
INVALID_ACTION_PENALTY = -0.05
|
| 19 |
+
WRONG_CODE_FIX_PENALTY = -0.10
|
| 20 |
+
CORRECT_DIAGNOSIS_REWARD = 0.50
|
| 21 |
+
WRONG_DIAGNOSIS_PENALTY = -0.30
|
| 22 |
+
TERMINAL_CONVERGENCE_REWARD = 0.40
|
| 23 |
+
|
| 24 |
+
INVESTIGATION_ACTIONS = frozenset(
|
| 25 |
+
{
|
| 26 |
+
"inspect_gradients",
|
| 27 |
+
"inspect_data_batch",
|
| 28 |
+
"inspect_model_modes",
|
| 29 |
+
"inspect_model_weights",
|
| 30 |
+
"inspect_code",
|
| 31 |
+
}
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
_INSPECTION_STATE_MAP = {
|
| 35 |
+
"inspect_gradients": "gradients_inspected",
|
| 36 |
+
"inspect_data_batch": "data_inspected",
|
| 37 |
+
"inspect_model_modes": "model_modes_inspected",
|
| 38 |
+
"inspect_model_weights": "model_weights_inspected",
|
| 39 |
+
"inspect_code": "code_inspected",
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def compute_reward(
|
| 44 |
+
action: MLTrainingAction,
|
| 45 |
+
state: EpisodeState,
|
| 46 |
+
scenario: ScenarioParams,
|
| 47 |
+
is_valid_action: bool = True,
|
| 48 |
+
is_correct_fix: bool | None = None,
|
| 49 |
+
convergence_confirmed: bool = False,
|
| 50 |
+
) -> float:
|
| 51 |
+
"""Compute reward for a single step.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
action: The action taken.
|
| 55 |
+
state: Episode state BEFORE the action is applied.
|
| 56 |
+
scenario: Current scenario params.
|
| 57 |
+
is_valid_action: Whether the action is in available_actions.
|
| 58 |
+
is_correct_fix: For fix_code — True/False/None.
|
| 59 |
+
convergence_confirmed: Whether restart showed convergence.
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
Reward float, capped at [-1.0, 1.0].
|
| 63 |
+
"""
|
| 64 |
+
reward = 0.0
|
| 65 |
+
|
| 66 |
+
# Component 1: Flat step penalty (unconditional)
|
| 67 |
+
reward += STEP_PENALTY
|
| 68 |
+
|
| 69 |
+
# Component 4: Invalid action penalty
|
| 70 |
+
if not is_valid_action:
|
| 71 |
+
reward += INVALID_ACTION_PENALTY
|
| 72 |
+
return max(-1.0, min(1.0, reward))
|
| 73 |
+
|
| 74 |
+
action_type = action.action_type
|
| 75 |
+
|
| 76 |
+
# Component 2: Investigation bonus (first-time only)
|
| 77 |
+
if action_type in INVESTIGATION_ACTIONS:
|
| 78 |
+
state_field = _INSPECTION_STATE_MAP.get(action_type)
|
| 79 |
+
if state_field and not getattr(state, state_field):
|
| 80 |
+
reward += INVESTIGATION_BONUS
|
| 81 |
+
|
| 82 |
+
# Component 3: Context-gated red herring penalty
|
| 83 |
+
# Fires ONLY when gradients_inspected=True AND gradients_were_normal=True
|
| 84 |
+
if action_type == "add_callback":
|
| 85 |
+
if state.gradients_inspected and state.gradients_were_normal:
|
| 86 |
+
reward += CONTEXT_GATED_PENALTY
|
| 87 |
+
|
| 88 |
+
# Component 7: Wrong code fix penalty
|
| 89 |
+
if action_type == "fix_code" and is_correct_fix is False:
|
| 90 |
+
reward += WRONG_CODE_FIX_PENALTY
|
| 91 |
+
|
| 92 |
+
# Component 5: Diagnosis outcome
|
| 93 |
+
if action_type == "mark_diagnosed":
|
| 94 |
+
if action.diagnosis == scenario.root_cause.value:
|
| 95 |
+
reward += CORRECT_DIAGNOSIS_REWARD
|
| 96 |
+
else:
|
| 97 |
+
reward += WRONG_DIAGNOSIS_PENALTY
|
| 98 |
+
|
| 99 |
+
# Component 6: Terminal convergence reward
|
| 100 |
+
if action_type == "restart_run":
|
| 101 |
+
if state.fix_action_taken and convergence_confirmed:
|
| 102 |
+
reward += TERMINAL_CONVERGENCE_REWARD
|
| 103 |
+
|
| 104 |
+
return max(-1.0, min(1.0, reward))
|
ml_training_debugger/scenarios.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ScenarioParams and scenario sampling.
|
| 2 |
+
|
| 3 |
+
Internal scenario configuration — not exposed to the agent.
|
| 4 |
+
Spec reference: Sections 6, 10, 11.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import dataclasses
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from ml_training_debugger.models import RootCauseDiagnosis
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@dataclasses.dataclass(frozen=True)
|
| 18 |
+
class ScenarioParams:
|
| 19 |
+
"""Internal scenario parameters created at reset() time."""
|
| 20 |
+
|
| 21 |
+
task_id: str
|
| 22 |
+
root_cause: RootCauseDiagnosis
|
| 23 |
+
seed: int
|
| 24 |
+
learning_rate: float = 0.001
|
| 25 |
+
weight_decay: float = 0.0001
|
| 26 |
+
leakage_pct: float = 0.0
|
| 27 |
+
depth_multiplier: float = 1.0
|
| 28 |
+
divergence_epoch: int = 5
|
| 29 |
+
red_herring_intensity: float = 1.0
|
| 30 |
+
red_herring_spike_layer: str = "fc"
|
| 31 |
+
bug_type: Optional[str] = None
|
| 32 |
+
notes: Optional[str] = None
|
| 33 |
+
error_log: Optional[str] = None
|
| 34 |
+
gpu_memory_used_gb: float = 6.2
|
| 35 |
+
max_steps: int = 20
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _task_seed(task_id: str, seed: int) -> int:
|
| 39 |
+
"""Derive a deterministic seed from task_id and provided seed."""
|
| 40 |
+
task_num = int(task_id.split("_")[1])
|
| 41 |
+
return seed * 1000 + task_num
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _choose(options: list, rng: torch.Generator) -> object:
|
| 45 |
+
"""Choose a random element from a list using torch RNG."""
|
| 46 |
+
idx = int(torch.randint(0, len(options), (1,), generator=rng).item())
|
| 47 |
+
return options[idx]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def sample_scenario(task_id: str, seed: int = 42) -> ScenarioParams:
|
| 51 |
+
"""Sample a ScenarioParams for the given task.
|
| 52 |
+
|
| 53 |
+
Args:
|
| 54 |
+
task_id: One of task_001 through task_006.
|
| 55 |
+
seed: Base seed for reproducibility.
|
| 56 |
+
|
| 57 |
+
Returns:
|
| 58 |
+
ScenarioParams with randomized fault parameters.
|
| 59 |
+
|
| 60 |
+
Raises:
|
| 61 |
+
ValueError: If task_id is unknown.
|
| 62 |
+
"""
|
| 63 |
+
effective_seed = _task_seed(task_id, seed)
|
| 64 |
+
rng = torch.Generator()
|
| 65 |
+
rng.manual_seed(effective_seed)
|
| 66 |
+
|
| 67 |
+
if task_id == "task_001":
|
| 68 |
+
lr = _choose([0.05, 0.08, 0.10, 0.15, 0.30], rng)
|
| 69 |
+
return ScenarioParams(
|
| 70 |
+
task_id=task_id,
|
| 71 |
+
root_cause=RootCauseDiagnosis.LR_TOO_HIGH,
|
| 72 |
+
seed=effective_seed,
|
| 73 |
+
learning_rate=float(lr),
|
| 74 |
+
error_log=f"RuntimeError: Loss is NaN at epoch 12 (lr={lr})",
|
| 75 |
+
max_steps=20,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
if task_id == "task_002":
|
| 79 |
+
lr = _choose([1e-6, 5e-6, 1e-5], rng)
|
| 80 |
+
depth_mult = _choose([1.0, 1.5, 2.0], rng)
|
| 81 |
+
return ScenarioParams(
|
| 82 |
+
task_id=task_id,
|
| 83 |
+
root_cause=RootCauseDiagnosis.VANISHING_GRADIENTS,
|
| 84 |
+
seed=effective_seed,
|
| 85 |
+
learning_rate=float(lr),
|
| 86 |
+
depth_multiplier=float(depth_mult),
|
| 87 |
+
notes=(
|
| 88 |
+
"Training resumed from a checkpoint saved at epoch 0 — "
|
| 89 |
+
"early learning rate warmup may still be in effect."
|
| 90 |
+
),
|
| 91 |
+
max_steps=20,
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
if task_id == "task_003":
|
| 95 |
+
leakage = _choose([0.12, 0.18, 0.22, 0.28], rng)
|
| 96 |
+
return ScenarioParams(
|
| 97 |
+
task_id=task_id,
|
| 98 |
+
root_cause=RootCauseDiagnosis.DATA_LEAKAGE,
|
| 99 |
+
seed=effective_seed,
|
| 100 |
+
leakage_pct=float(leakage),
|
| 101 |
+
notes=(
|
| 102 |
+
"Model architecture upgraded from 2-layer to 4-layer CNN "
|
| 103 |
+
"at epoch 2. Performance improvement may reflect increased "
|
| 104 |
+
"model capacity."
|
| 105 |
+
),
|
| 106 |
+
max_steps=25,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
if task_id == "task_004":
|
| 110 |
+
wd = _choose([0.0, 0.0001, 0.001], rng)
|
| 111 |
+
div_epoch = _choose([5, 8, 12], rng)
|
| 112 |
+
return ScenarioParams(
|
| 113 |
+
task_id=task_id,
|
| 114 |
+
root_cause=RootCauseDiagnosis.OVERFITTING,
|
| 115 |
+
seed=effective_seed,
|
| 116 |
+
weight_decay=float(wd),
|
| 117 |
+
divergence_epoch=int(div_epoch),
|
| 118 |
+
notes=(
|
| 119 |
+
"Dataset augmentation was disabled for this run to speed "
|
| 120 |
+
"up training. Re-enabling may improve generalization."
|
| 121 |
+
),
|
| 122 |
+
max_steps=25,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
if task_id == "task_005":
|
| 126 |
+
intensity = torch.empty(1).uniform_(0.8, 2.5, generator=rng).item()
|
| 127 |
+
spike_layer = _choose(["fc", "conv1"], rng)
|
| 128 |
+
return ScenarioParams(
|
| 129 |
+
task_id=task_id,
|
| 130 |
+
root_cause=RootCauseDiagnosis.BATCHNORM_EVAL_MODE,
|
| 131 |
+
seed=effective_seed,
|
| 132 |
+
red_herring_intensity=float(intensity),
|
| 133 |
+
red_herring_spike_layer=str(spike_layer),
|
| 134 |
+
gpu_memory_used_gb=14.56, # 91% of 16GB — red herring
|
| 135 |
+
error_log=(
|
| 136 |
+
"Warning: GPU memory pressure detected, consider reducing "
|
| 137 |
+
"batch size or enabling gradient checkpointing"
|
| 138 |
+
),
|
| 139 |
+
max_steps=30,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
if task_id == "task_006":
|
| 143 |
+
bug = _choose(
|
| 144 |
+
["eval_mode", "detach_loss", "zero_grad_missing", "inplace_relu"], rng
|
| 145 |
+
)
|
| 146 |
+
return ScenarioParams(
|
| 147 |
+
task_id=task_id,
|
| 148 |
+
root_cause=RootCauseDiagnosis.CODE_BUG,
|
| 149 |
+
seed=effective_seed,
|
| 150 |
+
bug_type=str(bug),
|
| 151 |
+
notes="Try adjusting the learning rate schedule.",
|
| 152 |
+
max_steps=30,
|
| 153 |
+
)
|
| 154 |
+
|
| 155 |
+
raise ValueError(f"Unknown task_id: {task_id}")
|
ml_training_debugger/simulation.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Parametric curve generation using torch.Tensor operations.
|
| 2 |
+
|
| 3 |
+
All loss/accuracy histories are generated via parametric equations.
|
| 4 |
+
Zero numpy. Spec reference: Section 6.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from ml_training_debugger.scenarios import ScenarioParams
|
| 12 |
+
|
| 13 |
+
EPOCHS = 20
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def gen_loss_history(scenario: ScenarioParams) -> list[float]:
|
| 17 |
+
"""Generate training loss history (20 epochs) using torch ops."""
|
| 18 |
+
torch.manual_seed(scenario.seed)
|
| 19 |
+
t = torch.arange(EPOCHS, dtype=torch.float32)
|
| 20 |
+
|
| 21 |
+
root = scenario.root_cause.value
|
| 22 |
+
|
| 23 |
+
if root == "lr_too_high":
|
| 24 |
+
# Exponentially growing loss
|
| 25 |
+
lr_tensor = torch.tensor(scenario.learning_rate, dtype=torch.float32)
|
| 26 |
+
base = torch.exp(lr_tensor * t * 0.5)
|
| 27 |
+
loss = 2.3 * base
|
| 28 |
+
# Add NaN marker after epoch 12
|
| 29 |
+
loss_list = loss.tolist()
|
| 30 |
+
for i in range(12, EPOCHS):
|
| 31 |
+
loss_list[i] = float("inf")
|
| 32 |
+
return loss_list
|
| 33 |
+
|
| 34 |
+
if root == "vanishing_gradients":
|
| 35 |
+
# Flat loss — barely decreases
|
| 36 |
+
noise = torch.randn(EPOCHS) * 0.02
|
| 37 |
+
loss = 2.3 - t * 0.002 + noise
|
| 38 |
+
return loss.clamp(min=0.01).tolist()
|
| 39 |
+
|
| 40 |
+
if root == "data_leakage":
|
| 41 |
+
# Normal-looking training loss
|
| 42 |
+
loss = 2.3 * torch.exp(-0.15 * t) + 0.05
|
| 43 |
+
noise = torch.randn(EPOCHS) * 0.02
|
| 44 |
+
return (loss + noise).clamp(min=0.01).tolist()
|
| 45 |
+
|
| 46 |
+
if root == "overfitting":
|
| 47 |
+
# Steadily decreasing to near-zero
|
| 48 |
+
loss = 2.3 * torch.exp(-0.25 * t) + 0.01
|
| 49 |
+
noise = torch.randn(EPOCHS) * 0.01
|
| 50 |
+
return (loss + noise).clamp(min=0.001).tolist()
|
| 51 |
+
|
| 52 |
+
if root == "batchnorm_eval_mode":
|
| 53 |
+
# Roughly normal with higher variance
|
| 54 |
+
base = 2.3 * torch.exp(-0.1 * t) + 0.3
|
| 55 |
+
noise = torch.randn(EPOCHS) * 0.15
|
| 56 |
+
return (base + noise).clamp(min=0.1).tolist()
|
| 57 |
+
|
| 58 |
+
if root == "code_bug":
|
| 59 |
+
# Varies by bug variant — generic anomalous
|
| 60 |
+
loss = 2.3 * torch.exp(-0.05 * t) + 0.5
|
| 61 |
+
noise = torch.randn(EPOCHS) * 0.1
|
| 62 |
+
return (loss + noise).clamp(min=0.1).tolist()
|
| 63 |
+
|
| 64 |
+
# Fallback
|
| 65 |
+
return (2.3 * torch.exp(-0.1 * t)).tolist()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def gen_val_accuracy_history(scenario: ScenarioParams) -> list[float]:
|
| 69 |
+
"""Generate validation accuracy history (20 epochs) using torch ops."""
|
| 70 |
+
torch.manual_seed(scenario.seed + 1)
|
| 71 |
+
t = torch.arange(EPOCHS, dtype=torch.float32)
|
| 72 |
+
|
| 73 |
+
root = scenario.root_cause.value
|
| 74 |
+
|
| 75 |
+
if root == "lr_too_high":
|
| 76 |
+
# Collapses along with training loss
|
| 77 |
+
acc = torch.sigmoid(torch.linspace(0, -3, EPOCHS)) * 0.5
|
| 78 |
+
return acc.clamp(0.0, 1.0).tolist()
|
| 79 |
+
|
| 80 |
+
if root == "vanishing_gradients":
|
| 81 |
+
# Near random chance
|
| 82 |
+
noise = torch.randn(EPOCHS) * 0.02
|
| 83 |
+
acc = 0.10 + t * 0.001 + noise
|
| 84 |
+
return acc.clamp(0.0, 1.0).tolist()
|
| 85 |
+
|
| 86 |
+
if root == "data_leakage":
|
| 87 |
+
# Suspiciously high from epoch 1
|
| 88 |
+
leakage = torch.tensor(scenario.leakage_pct, dtype=torch.float32)
|
| 89 |
+
base = torch.sigmoid(torch.linspace(-3, 3, EPOCHS))
|
| 90 |
+
acc = base * (1.0 - leakage) + leakage * 0.95
|
| 91 |
+
# Inflate early epochs
|
| 92 |
+
acc = acc.clamp(0.0, 1.0)
|
| 93 |
+
# Ensure suspiciously high from epoch 1
|
| 94 |
+
acc_list = acc.tolist()
|
| 95 |
+
for i in range(EPOCHS):
|
| 96 |
+
acc_list[i] = max(acc_list[i], 0.82 * (1.0 + scenario.leakage_pct))
|
| 97 |
+
return [min(v, 0.99) for v in acc_list]
|
| 98 |
+
|
| 99 |
+
if root == "overfitting":
|
| 100 |
+
# Rises then falls — classic divergence
|
| 101 |
+
div = scenario.divergence_epoch
|
| 102 |
+
acc_list: list[float] = []
|
| 103 |
+
for i in range(EPOCHS):
|
| 104 |
+
if i < div:
|
| 105 |
+
val = 0.10 + (0.75 - 0.10) * (i / max(div, 1))
|
| 106 |
+
else:
|
| 107 |
+
decline = (i - div) * 0.02
|
| 108 |
+
val = 0.75 - decline
|
| 109 |
+
acc_list.append(max(0.0, min(1.0, val)))
|
| 110 |
+
return acc_list
|
| 111 |
+
|
| 112 |
+
if root == "batchnorm_eval_mode":
|
| 113 |
+
# Slow degradation ~1-2% per epoch
|
| 114 |
+
start = 0.76
|
| 115 |
+
noise = torch.randn(EPOCHS) * 0.01
|
| 116 |
+
acc = torch.tensor(
|
| 117 |
+
[start - 0.015 * i for i in range(EPOCHS)], dtype=torch.float32
|
| 118 |
+
)
|
| 119 |
+
acc = acc + noise
|
| 120 |
+
return acc.clamp(0.0, 1.0).tolist()
|
| 121 |
+
|
| 122 |
+
if root == "code_bug":
|
| 123 |
+
# Anomalous — depends on variant but generally poor
|
| 124 |
+
noise = torch.randn(EPOCHS) * 0.03
|
| 125 |
+
acc = 0.10 + t * 0.005 + noise
|
| 126 |
+
return acc.clamp(0.0, 1.0).tolist()
|
| 127 |
+
|
| 128 |
+
# Fallback
|
| 129 |
+
return (torch.sigmoid(torch.linspace(-3, 3, EPOCHS)) * 0.9).tolist()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def gen_val_loss_history(scenario: ScenarioParams) -> list[float]:
|
| 133 |
+
"""Generate validation loss history (20 epochs) using torch ops."""
|
| 134 |
+
torch.manual_seed(scenario.seed + 2)
|
| 135 |
+
t = torch.arange(EPOCHS, dtype=torch.float32)
|
| 136 |
+
|
| 137 |
+
root = scenario.root_cause.value
|
| 138 |
+
|
| 139 |
+
if root == "lr_too_high":
|
| 140 |
+
# Mirrors training loss divergence
|
| 141 |
+
lr_tensor = torch.tensor(scenario.learning_rate, dtype=torch.float32)
|
| 142 |
+
loss = 2.3 * torch.exp(lr_tensor * t * 0.5)
|
| 143 |
+
loss_list = loss.tolist()
|
| 144 |
+
for i in range(12, EPOCHS):
|
| 145 |
+
loss_list[i] = float("inf")
|
| 146 |
+
return loss_list
|
| 147 |
+
|
| 148 |
+
if root == "vanishing_gradients":
|
| 149 |
+
noise = torch.randn(EPOCHS) * 0.02
|
| 150 |
+
loss = 2.3 - t * 0.001 + noise
|
| 151 |
+
return loss.clamp(min=0.01).tolist()
|
| 152 |
+
|
| 153 |
+
if root == "data_leakage":
|
| 154 |
+
# Low val loss (because leaking train data into val)
|
| 155 |
+
base = 2.3 * torch.exp(-0.2 * t) + 0.03
|
| 156 |
+
noise = torch.randn(EPOCHS) * 0.02
|
| 157 |
+
return (base + noise).clamp(min=0.01).tolist()
|
| 158 |
+
|
| 159 |
+
if root == "overfitting":
|
| 160 |
+
# Initially decreases, then diverges upward
|
| 161 |
+
div = scenario.divergence_epoch
|
| 162 |
+
loss_list: list[float] = []
|
| 163 |
+
for i in range(EPOCHS):
|
| 164 |
+
if i < div:
|
| 165 |
+
val = 2.3 * (1.0 - 0.8 * i / max(div, 1))
|
| 166 |
+
else:
|
| 167 |
+
val = 0.46 + 0.1 * (i - div)
|
| 168 |
+
loss_list.append(max(0.01, val))
|
| 169 |
+
return loss_list
|
| 170 |
+
|
| 171 |
+
if root == "batchnorm_eval_mode":
|
| 172 |
+
# Slightly increasing
|
| 173 |
+
base = 1.5 + t * 0.03
|
| 174 |
+
noise = torch.randn(EPOCHS) * 0.1
|
| 175 |
+
return (base + noise).clamp(min=0.1).tolist()
|
| 176 |
+
|
| 177 |
+
if root == "code_bug":
|
| 178 |
+
loss = 2.3 * torch.exp(-0.03 * t) + 0.8
|
| 179 |
+
noise = torch.randn(EPOCHS) * 0.1
|
| 180 |
+
return (loss + noise).clamp(min=0.1).tolist()
|
| 181 |
+
|
| 182 |
+
# Fallback
|
| 183 |
+
return (2.3 * torch.exp(-0.1 * t) + 0.1).tolist()
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def gen_data_batch_stats(scenario: ScenarioParams) -> dict:
|
| 187 |
+
"""Generate data batch statistics for the scenario."""
|
| 188 |
+
torch.manual_seed(scenario.seed + 3)
|
| 189 |
+
|
| 190 |
+
root = scenario.root_cause.value
|
| 191 |
+
|
| 192 |
+
if root == "data_leakage":
|
| 193 |
+
overlap = 0.5 + scenario.leakage_pct * 1.5 # 0.68-0.88 range
|
| 194 |
+
overlap = min(overlap, 0.92)
|
| 195 |
+
return {
|
| 196 |
+
"label_distribution": {i: 0.1 for i in range(10)},
|
| 197 |
+
"feature_mean": 0.45 + torch.randn(1).item() * 0.05,
|
| 198 |
+
"feature_std": 0.22 + torch.randn(1).item() * 0.02,
|
| 199 |
+
"null_count": 0,
|
| 200 |
+
"class_overlap_score": overlap,
|
| 201 |
+
"batch_size": 64,
|
| 202 |
+
"duplicate_ratio": scenario.leakage_pct,
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
if root == "overfitting":
|
| 206 |
+
return {
|
| 207 |
+
"label_distribution": {i: 0.1 for i in range(10)},
|
| 208 |
+
"feature_mean": 0.48 + torch.randn(1).item() * 0.03,
|
| 209 |
+
"feature_std": 0.25 + torch.randn(1).item() * 0.02,
|
| 210 |
+
"null_count": 0,
|
| 211 |
+
"class_overlap_score": 0.0,
|
| 212 |
+
"batch_size": 64,
|
| 213 |
+
"duplicate_ratio": 0.0,
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
# Default: normal data
|
| 217 |
+
return {
|
| 218 |
+
"label_distribution": {i: 0.1 for i in range(10)},
|
| 219 |
+
"feature_mean": 0.47 + torch.randn(1).item() * 0.03,
|
| 220 |
+
"feature_std": 0.24 + torch.randn(1).item() * 0.02,
|
| 221 |
+
"null_count": 0,
|
| 222 |
+
"class_overlap_score": 0.0 + torch.randn(1).abs().item() * 0.05,
|
| 223 |
+
"batch_size": 64,
|
| 224 |
+
"duplicate_ratio": 0.0,
|
| 225 |
+
}
|
openenv.yaml
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: pytorch-training-debugger
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 7860
|
| 7 |
+
|
| 8 |
+
version: "1.0.0"
|
| 9 |
+
description: |
|
| 10 |
+
PyTorch-native fault injection engine for training failure debugging.
|
| 11 |
+
An AI agent investigates, diagnoses, fixes, and verifies broken
|
| 12 |
+
training runs using real torch.nn.Module models, torch.autograd
|
| 13 |
+
gradients, state_dict() weight inspection, and PyTorch code-level
|
| 14 |
+
debugging. 3 tasks across 3 difficulty tiers with context-gated
|
| 15 |
+
reward shaping.
|
| 16 |
+
framework: openenv
|
| 17 |
+
tags:
|
| 18 |
+
- ml-debugging
|
| 19 |
+
- pytorch
|
| 20 |
+
- reinforcement-learning
|
| 21 |
+
- root-cause-analysis
|
| 22 |
+
- fault-injection
|
| 23 |
+
- openenv
|
| 24 |
+
|
| 25 |
+
observation_space:
|
| 26 |
+
type: MLTrainingObservation
|
| 27 |
+
description: "Training run snapshot with progressive reveal — gradients, weights, data stats, model modes revealed on inspection"
|
| 28 |
+
|
| 29 |
+
action_space:
|
| 30 |
+
type: MLTrainingAction
|
| 31 |
+
description: "Investigation, fix, and diagnosis actions with dynamic availability"
|
| 32 |
+
|
| 33 |
+
tasks:
|
| 34 |
+
- id: task_001
|
| 35 |
+
difficulty: easy
|
| 36 |
+
max_steps: 20
|
| 37 |
+
- id: task_003
|
| 38 |
+
difficulty: medium
|
| 39 |
+
max_steps: 25
|
| 40 |
+
- id: task_005
|
| 41 |
+
difficulty: hard
|
| 42 |
+
max_steps: 30
|
| 43 |
+
|
| 44 |
+
reward:
|
| 45 |
+
range: [-1.0, 1.0]
|
| 46 |
+
shaped: true
|
| 47 |
+
step_penalty: -0.01
|
| 48 |
+
investigation_bonus: 0.05
|
| 49 |
+
max_investigation_bonus: 0.25
|
| 50 |
+
correct_diagnosis: 0.50
|
| 51 |
+
terminal_convergence: 0.40
|
| 52 |
+
|
| 53 |
+
endpoints:
|
| 54 |
+
websocket: "/ws"
|
| 55 |
+
tasks: "GET /tasks"
|
| 56 |
+
grader: "POST /grader"
|
| 57 |
+
baseline: "POST /baseline"
|
| 58 |
+
health: "GET /health"
|
pyproject.toml
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "pytorch-training-debugger"
|
| 3 |
+
version = "1.0.0"
|
| 4 |
+
description = "OpenEnv RL environment for PyTorch training failure debugging"
|
| 5 |
+
requires-python = ">=3.12"
|
| 6 |
+
dependencies = [
|
| 7 |
+
"torch",
|
| 8 |
+
"openenv-core",
|
| 9 |
+
"pydantic>=2.0",
|
| 10 |
+
"fastapi",
|
| 11 |
+
"uvicorn",
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
[project.optional-dependencies]
|
| 15 |
+
dev = [
|
| 16 |
+
"pytest",
|
| 17 |
+
"pytest-cov",
|
| 18 |
+
"pytest-asyncio",
|
| 19 |
+
"black",
|
| 20 |
+
"ruff",
|
| 21 |
+
"isort",
|
| 22 |
+
"httpx",
|
| 23 |
+
"websockets",
|
| 24 |
+
]
|
| 25 |
+
llm = [
|
| 26 |
+
"openai",
|
| 27 |
+
]
|
| 28 |
+
|
| 29 |
+
[tool.black]
|
| 30 |
+
line-length = 88
|
| 31 |
+
|
| 32 |
+
[tool.isort]
|
| 33 |
+
profile = "black"
|
| 34 |
+
|
| 35 |
+
[tool.ruff]
|
| 36 |
+
line-length = 88
|
| 37 |
+
target-version = "py312"
|
| 38 |
+
|
| 39 |
+
[tool.pytest.ini_options]
|
| 40 |
+
testpaths = ["tests"]
|
| 41 |
+
asyncio_mode = "auto"
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
openenv-core
|
| 3 |
+
pydantic>=2.0
|
| 4 |
+
fastapi
|
| 5 |
+
uvicorn
|
| 6 |
+
openai
|
server/__init__.py
ADDED
|
File without changes
|
server/_baseline_results.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared state for grader results across endpoints."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
# Store last completed episode results
|
| 8 |
+
_last_results: dict[str, dict] = {}
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def store_grader_result(
|
| 12 |
+
session_id: str, score: float, task_id: str, steps: int
|
| 13 |
+
) -> None:
|
| 14 |
+
"""Store a grader result for retrieval."""
|
| 15 |
+
_last_results[session_id] = {
|
| 16 |
+
"score": round(score, 4),
|
| 17 |
+
"task_id": task_id,
|
| 18 |
+
"steps": steps,
|
| 19 |
+
}
|
| 20 |
+
_last_results["_latest"] = _last_results[session_id]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def get_last_grader_result(session_id: Optional[str] = None) -> dict | None:
|
| 24 |
+
"""Get grader result for a session, or the most recent one."""
|
| 25 |
+
if session_id:
|
| 26 |
+
return _last_results.get(session_id)
|
| 27 |
+
return _last_results.get("_latest")
|
server/app.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI app — openenv create_app() + custom hackathon routes.
|
| 2 |
+
|
| 3 |
+
Spec reference: Sections 9, 14.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import asyncio
|
| 9 |
+
import logging
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
from fastapi import FastAPI
|
| 13 |
+
from fastapi.responses import JSONResponse
|
| 14 |
+
from openenv.core.env_server.http_server import create_app
|
| 15 |
+
|
| 16 |
+
from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
|
| 17 |
+
from server.environment import MLTrainingEnvironment
|
| 18 |
+
|
| 19 |
+
logging.basicConfig(
|
| 20 |
+
level=logging.INFO,
|
| 21 |
+
format='{"time":"%(asctime)s","level":"%(levelname)s","msg":"%(message)s"}',
|
| 22 |
+
)
|
| 23 |
+
logger = logging.getLogger(__name__)
|
| 24 |
+
|
| 25 |
+
# MVP task list
|
| 26 |
+
MVP_TASKS = [
|
| 27 |
+
{"id": "task_001", "difficulty": "easy", "max_steps": 20},
|
| 28 |
+
{"id": "task_003", "difficulty": "medium", "max_steps": 25},
|
| 29 |
+
{"id": "task_005", "difficulty": "hard", "max_steps": 30},
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
# create_app takes the class (factory), not an instance
|
| 33 |
+
app: FastAPI = create_app(
|
| 34 |
+
MLTrainingEnvironment,
|
| 35 |
+
MLTrainingAction,
|
| 36 |
+
MLTrainingObservation,
|
| 37 |
+
env_name="pytorch_training_debugger",
|
| 38 |
+
max_concurrent_envs=5,
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Override framework's /health route with our custom version
|
| 42 |
+
# Remove the framework's health route first
|
| 43 |
+
app.routes[:] = [
|
| 44 |
+
r for r in app.routes if not (hasattr(r, "path") and r.path == "/health")
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
# Track baseline state
|
| 48 |
+
_baseline_lock = asyncio.Lock()
|
| 49 |
+
_baseline_running = False
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
@app.get("/health")
|
| 53 |
+
def health_check() -> dict:
|
| 54 |
+
"""Health check — required by hackathon auto-validator."""
|
| 55 |
+
return {"status": "ready", "tasks": len(MVP_TASKS)}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
@app.get("/tasks")
|
| 59 |
+
def get_tasks() -> list[dict]:
|
| 60 |
+
"""Return task list with IDs, difficulties, and action schema."""
|
| 61 |
+
schema = MLTrainingAction.model_json_schema()
|
| 62 |
+
return [{**task, "action_schema": schema} for task in MVP_TASKS]
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
@app.post("/grader")
|
| 66 |
+
def post_grader(session_id: Optional[str] = None) -> dict:
|
| 67 |
+
"""Return grader score for most recently completed episode.
|
| 68 |
+
|
| 69 |
+
Edge cases per spec Section 14:
|
| 70 |
+
- No episode completed → {"score": null, "error": "no_completed_episode"}
|
| 71 |
+
- Episode in progress → {"score": null, "error": "episode_in_progress"}
|
| 72 |
+
- Episode completed → {"score": float, "task_id": str, "steps": int}
|
| 73 |
+
"""
|
| 74 |
+
# Try to find the environment instance
|
| 75 |
+
# The framework manages environment instances internally,
|
| 76 |
+
# so we use the internal baseline results for the /grader endpoint
|
| 77 |
+
from server._baseline_results import get_last_grader_result
|
| 78 |
+
|
| 79 |
+
result = get_last_grader_result(session_id)
|
| 80 |
+
if result is None:
|
| 81 |
+
return {"score": None, "error": "no_completed_episode"}
|
| 82 |
+
return result
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@app.post("/baseline", response_model=None)
|
| 86 |
+
async def post_baseline():
|
| 87 |
+
"""Trigger baseline run, return scores for all tasks.
|
| 88 |
+
|
| 89 |
+
Returns 409 if already running.
|
| 90 |
+
"""
|
| 91 |
+
global _baseline_running
|
| 92 |
+
|
| 93 |
+
if _baseline_running:
|
| 94 |
+
return JSONResponse(
|
| 95 |
+
status_code=409,
|
| 96 |
+
content={"error": "baseline_in_progress"},
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
_baseline_running = True
|
| 100 |
+
try:
|
| 101 |
+
scores = await _run_baseline()
|
| 102 |
+
return {"scores": scores}
|
| 103 |
+
finally:
|
| 104 |
+
_baseline_running = False
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
async def _run_baseline() -> dict[str, float]:
|
| 108 |
+
"""Run the rule-based baseline internally."""
|
| 109 |
+
|
| 110 |
+
scores: dict[str, float] = {}
|
| 111 |
+
|
| 112 |
+
for task_info in MVP_TASKS:
|
| 113 |
+
task_id = task_info["id"]
|
| 114 |
+
env = MLTrainingEnvironment()
|
| 115 |
+
obs = env.reset(seed=42, episode_id=f"baseline_{task_id}", task_id=task_id)
|
| 116 |
+
|
| 117 |
+
# Run heuristic decision tree
|
| 118 |
+
score = _run_heuristic_episode(env, obs, task_id)
|
| 119 |
+
scores[task_id] = round(score, 4)
|
| 120 |
+
|
| 121 |
+
return scores
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def _run_heuristic_episode(
|
| 125 |
+
env: MLTrainingEnvironment,
|
| 126 |
+
obs: MLTrainingObservation,
|
| 127 |
+
task_id: str,
|
| 128 |
+
) -> float:
|
| 129 |
+
"""Run one heuristic baseline episode. Returns grader score."""
|
| 130 |
+
# Step 1: inspect_gradients
|
| 131 |
+
obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
|
| 132 |
+
|
| 133 |
+
# Check for exploding gradients
|
| 134 |
+
if obs.gradient_stats:
|
| 135 |
+
if any(g.is_exploding for g in obs.gradient_stats):
|
| 136 |
+
obs = env.step(
|
| 137 |
+
MLTrainingAction(
|
| 138 |
+
action_type="modify_config",
|
| 139 |
+
target="learning_rate",
|
| 140 |
+
value=0.001,
|
| 141 |
+
)
|
| 142 |
+
)
|
| 143 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 144 |
+
obs = env.step(
|
| 145 |
+
MLTrainingAction(
|
| 146 |
+
action_type="mark_diagnosed",
|
| 147 |
+
diagnosis="lr_too_high",
|
| 148 |
+
)
|
| 149 |
+
)
|
| 150 |
+
session = env._get_session()
|
| 151 |
+
if session and session.last_score is not None:
|
| 152 |
+
return session.last_score
|
| 153 |
+
return 0.0
|
| 154 |
+
|
| 155 |
+
# Check for vanishing gradients
|
| 156 |
+
if any(g.is_vanishing for g in obs.gradient_stats):
|
| 157 |
+
obs = env.step(
|
| 158 |
+
MLTrainingAction(
|
| 159 |
+
action_type="modify_config",
|
| 160 |
+
target="learning_rate",
|
| 161 |
+
value=0.01,
|
| 162 |
+
)
|
| 163 |
+
)
|
| 164 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 165 |
+
obs = env.step(
|
| 166 |
+
MLTrainingAction(
|
| 167 |
+
action_type="mark_diagnosed",
|
| 168 |
+
diagnosis="vanishing_gradients",
|
| 169 |
+
)
|
| 170 |
+
)
|
| 171 |
+
session = env._get_session()
|
| 172 |
+
if session and session.last_score is not None:
|
| 173 |
+
return session.last_score
|
| 174 |
+
return 0.0
|
| 175 |
+
|
| 176 |
+
# Step 2: inspect_data_batch
|
| 177 |
+
obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
|
| 178 |
+
if obs.data_batch_stats and obs.data_batch_stats.class_overlap_score > 0.5:
|
| 179 |
+
obs = env.step(MLTrainingAction(action_type="patch_data_loader"))
|
| 180 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 181 |
+
obs = env.step(
|
| 182 |
+
MLTrainingAction(
|
| 183 |
+
action_type="mark_diagnosed",
|
| 184 |
+
diagnosis="data_leakage",
|
| 185 |
+
)
|
| 186 |
+
)
|
| 187 |
+
session = env._get_session()
|
| 188 |
+
if session and session.last_score is not None:
|
| 189 |
+
return session.last_score
|
| 190 |
+
return 0.0
|
| 191 |
+
|
| 192 |
+
# Check for overfitting (val_loss diverging)
|
| 193 |
+
if obs.val_loss_history and len(obs.val_loss_history) >= 10:
|
| 194 |
+
early = sum(obs.val_loss_history[:5]) / 5
|
| 195 |
+
late = sum(obs.val_loss_history[-5:]) / 5
|
| 196 |
+
if (
|
| 197 |
+
late > early * 1.2
|
| 198 |
+
and obs.data_batch_stats
|
| 199 |
+
and obs.data_batch_stats.class_overlap_score < 0.1
|
| 200 |
+
):
|
| 201 |
+
obs = env.step(
|
| 202 |
+
MLTrainingAction(
|
| 203 |
+
action_type="modify_config",
|
| 204 |
+
target="weight_decay",
|
| 205 |
+
value=0.01,
|
| 206 |
+
)
|
| 207 |
+
)
|
| 208 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 209 |
+
obs = env.step(
|
| 210 |
+
MLTrainingAction(
|
| 211 |
+
action_type="mark_diagnosed",
|
| 212 |
+
diagnosis="overfitting",
|
| 213 |
+
)
|
| 214 |
+
)
|
| 215 |
+
session = env._get_session()
|
| 216 |
+
if session and session.last_score is not None:
|
| 217 |
+
return session.last_score
|
| 218 |
+
return 0.0
|
| 219 |
+
|
| 220 |
+
# Step 3: inspect_model_modes
|
| 221 |
+
obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
|
| 222 |
+
if obs.model_mode_info:
|
| 223 |
+
has_eval = any(v == "eval" for v in obs.model_mode_info.values())
|
| 224 |
+
if has_eval:
|
| 225 |
+
obs = env.step(MLTrainingAction(action_type="fix_model_mode"))
|
| 226 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 227 |
+
obs = env.step(
|
| 228 |
+
MLTrainingAction(
|
| 229 |
+
action_type="mark_diagnosed",
|
| 230 |
+
diagnosis="batchnorm_eval_mode",
|
| 231 |
+
)
|
| 232 |
+
)
|
| 233 |
+
session = env._get_session()
|
| 234 |
+
if session and session.last_score is not None:
|
| 235 |
+
return session.last_score
|
| 236 |
+
return 0.0
|
| 237 |
+
|
| 238 |
+
# Step 4: inspect_code (for Task 6)
|
| 239 |
+
obs = env.step(MLTrainingAction(action_type="inspect_code"))
|
| 240 |
+
if obs.code_snippet:
|
| 241 |
+
# Simple pattern matching for known bugs
|
| 242 |
+
code = obs.code_snippet.code
|
| 243 |
+
if "model.eval()" in code and "model.train()" not in code:
|
| 244 |
+
obs = env.step(
|
| 245 |
+
MLTrainingAction(
|
| 246 |
+
action_type="fix_code",
|
| 247 |
+
line=5,
|
| 248 |
+
replacement="model.train()",
|
| 249 |
+
)
|
| 250 |
+
)
|
| 251 |
+
elif ".detach()" in code:
|
| 252 |
+
obs = env.step(
|
| 253 |
+
MLTrainingAction(
|
| 254 |
+
action_type="fix_code",
|
| 255 |
+
line=14,
|
| 256 |
+
replacement=" loss = criterion(output, batch_y)",
|
| 257 |
+
)
|
| 258 |
+
)
|
| 259 |
+
else:
|
| 260 |
+
# Can't reliably fix — just diagnose
|
| 261 |
+
pass
|
| 262 |
+
|
| 263 |
+
if obs.episode_state.fix_action_taken:
|
| 264 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 265 |
+
|
| 266 |
+
obs = env.step(
|
| 267 |
+
MLTrainingAction(
|
| 268 |
+
action_type="mark_diagnosed",
|
| 269 |
+
diagnosis="code_bug",
|
| 270 |
+
)
|
| 271 |
+
)
|
| 272 |
+
session = env._get_session()
|
| 273 |
+
if session and session.last_score is not None:
|
| 274 |
+
return session.last_score
|
| 275 |
+
return 0.0
|
| 276 |
+
|
| 277 |
+
# Fallback
|
| 278 |
+
obs = env.step(
|
| 279 |
+
MLTrainingAction(
|
| 280 |
+
action_type="mark_diagnosed",
|
| 281 |
+
diagnosis="overfitting",
|
| 282 |
+
)
|
| 283 |
+
)
|
| 284 |
+
session = env._get_session()
|
| 285 |
+
if session and session.last_score is not None:
|
| 286 |
+
return session.last_score
|
| 287 |
+
return 0.0
|
server/environment.py
ADDED
|
@@ -0,0 +1,516 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""MLTrainingEnvironment — extends openenv Environment.
|
| 2 |
+
|
| 3 |
+
Full implementation of reset() and step() with session isolation,
|
| 4 |
+
progressive information reveal, and comprehensive error handling.
|
| 5 |
+
step() NEVER raises an unhandled exception.
|
| 6 |
+
Spec reference: Sections 9, 13, 16.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import dataclasses
|
| 12 |
+
import logging
|
| 13 |
+
import uuid
|
| 14 |
+
from typing import Any, Optional
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
from openenv.core.env_server.interfaces import Environment
|
| 18 |
+
|
| 19 |
+
from ml_training_debugger.code_templates import (
|
| 20 |
+
generate_code_snippet,
|
| 21 |
+
validate_fix,
|
| 22 |
+
)
|
| 23 |
+
from ml_training_debugger.graders import grade_episode
|
| 24 |
+
from ml_training_debugger.models import (
|
| 25 |
+
ALL_ACTION_TYPES,
|
| 26 |
+
VALID_CONFIG_KEYS,
|
| 27 |
+
VALID_DIAGNOSES,
|
| 28 |
+
CodeSnippet,
|
| 29 |
+
DataBatchStats,
|
| 30 |
+
EpisodeState,
|
| 31 |
+
MLTrainingAction,
|
| 32 |
+
MLTrainingObservation,
|
| 33 |
+
TrainingConfig,
|
| 34 |
+
)
|
| 35 |
+
from ml_training_debugger.pytorch_engine import (
|
| 36 |
+
create_model_and_inject_fault,
|
| 37 |
+
extract_gradient_stats,
|
| 38 |
+
extract_model_modes,
|
| 39 |
+
extract_weight_stats,
|
| 40 |
+
)
|
| 41 |
+
from ml_training_debugger.reward_engine import compute_reward
|
| 42 |
+
from ml_training_debugger.scenarios import ScenarioParams, sample_scenario
|
| 43 |
+
from ml_training_debugger.simulation import (
|
| 44 |
+
gen_data_batch_stats,
|
| 45 |
+
gen_loss_history,
|
| 46 |
+
gen_val_accuracy_history,
|
| 47 |
+
gen_val_loss_history,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
logger = logging.getLogger(__name__)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclasses.dataclass
|
| 54 |
+
class SessionData:
|
| 55 |
+
"""Per-session episode data."""
|
| 56 |
+
|
| 57 |
+
scenario: ScenarioParams
|
| 58 |
+
model: torch.nn.Module
|
| 59 |
+
state: EpisodeState
|
| 60 |
+
config: TrainingConfig
|
| 61 |
+
gradient_stats: list[Any]
|
| 62 |
+
weight_stats: list[Any] | None
|
| 63 |
+
model_modes: dict[str, str] | None
|
| 64 |
+
data_batch_stats_raw: dict | None
|
| 65 |
+
code_snippet_raw: dict | None
|
| 66 |
+
loss_history: list[float]
|
| 67 |
+
val_acc_history: list[float]
|
| 68 |
+
val_loss_history: list[float]
|
| 69 |
+
done: bool
|
| 70 |
+
last_score: float | None
|
| 71 |
+
convergence_after_fix: bool
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation, dict]):
|
| 75 |
+
"""OpenEnv environment for PyTorch training run debugging.
|
| 76 |
+
|
| 77 |
+
Spec Section 9 — Architecture.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
SUPPORTS_CONCURRENT_SESSIONS = True
|
| 81 |
+
|
| 82 |
+
def __init__(self, **kwargs: Any) -> None:
|
| 83 |
+
super().__init__(**kwargs)
|
| 84 |
+
self._sessions: dict[str, SessionData] = {}
|
| 85 |
+
self._last_completed: dict[str, dict] = {}
|
| 86 |
+
self._current_session_id: str = ""
|
| 87 |
+
|
| 88 |
+
def _get_session(self, episode_id: str | None = None) -> SessionData | None:
|
| 89 |
+
sid = episode_id or self._current_session_id
|
| 90 |
+
return self._sessions.get(sid)
|
| 91 |
+
|
| 92 |
+
def _build_observation(
|
| 93 |
+
self, session: SessionData, reward: float = 0.0
|
| 94 |
+
) -> MLTrainingObservation:
|
| 95 |
+
"""Build observation from session data."""
|
| 96 |
+
state = session.state
|
| 97 |
+
|
| 98 |
+
gradient_stats_models = []
|
| 99 |
+
if state.gradients_inspected and session.gradient_stats:
|
| 100 |
+
gradient_stats_models = session.gradient_stats
|
| 101 |
+
|
| 102 |
+
weight_stats_models = None
|
| 103 |
+
if state.model_weights_inspected and session.weight_stats is not None:
|
| 104 |
+
weight_stats_models = session.weight_stats
|
| 105 |
+
|
| 106 |
+
data_batch = None
|
| 107 |
+
if state.data_inspected and session.data_batch_stats_raw is not None:
|
| 108 |
+
data_batch = DataBatchStats(**session.data_batch_stats_raw)
|
| 109 |
+
|
| 110 |
+
model_modes = None
|
| 111 |
+
if state.model_modes_inspected and session.model_modes is not None:
|
| 112 |
+
model_modes = session.model_modes
|
| 113 |
+
|
| 114 |
+
code_snippet = None
|
| 115 |
+
if state.code_inspected and session.code_snippet_raw is not None:
|
| 116 |
+
code_snippet = CodeSnippet(**session.code_snippet_raw)
|
| 117 |
+
|
| 118 |
+
return MLTrainingObservation(
|
| 119 |
+
run_id=self._current_session_id,
|
| 120 |
+
framework="pytorch",
|
| 121 |
+
epoch=20,
|
| 122 |
+
training_loss_history=session.loss_history,
|
| 123 |
+
val_loss_history=session.val_loss_history,
|
| 124 |
+
val_accuracy_history=session.val_acc_history,
|
| 125 |
+
gradient_stats=gradient_stats_models,
|
| 126 |
+
model_weight_stats=weight_stats_models,
|
| 127 |
+
gpu_memory_used_gb=session.scenario.gpu_memory_used_gb,
|
| 128 |
+
gpu_memory_total_gb=16.0,
|
| 129 |
+
learning_rate=session.config.learning_rate,
|
| 130 |
+
current_config=session.config,
|
| 131 |
+
error_log=session.scenario.error_log,
|
| 132 |
+
data_batch_stats=data_batch,
|
| 133 |
+
model_mode_info=model_modes,
|
| 134 |
+
code_snippet=code_snippet,
|
| 135 |
+
available_actions=state.compute_available_actions(),
|
| 136 |
+
episode_state=state,
|
| 137 |
+
notes=session.scenario.notes,
|
| 138 |
+
done=session.done,
|
| 139 |
+
reward=reward,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
def reset(
|
| 143 |
+
self,
|
| 144 |
+
seed: Optional[int] = None,
|
| 145 |
+
episode_id: Optional[str] = None,
|
| 146 |
+
**kwargs: Any,
|
| 147 |
+
) -> MLTrainingObservation:
|
| 148 |
+
"""Reset environment for a new episode. Spec Section 13."""
|
| 149 |
+
# Determine task_id — passed via kwargs or defaults to task_001
|
| 150 |
+
task_id = kwargs.get("task_id", "task_001")
|
| 151 |
+
|
| 152 |
+
# If called with episode_id that has an active session, terminate it
|
| 153 |
+
session_id = episode_id or str(uuid.uuid4())
|
| 154 |
+
if session_id in self._sessions:
|
| 155 |
+
old = self._sessions[session_id]
|
| 156 |
+
if not old.done:
|
| 157 |
+
score = grade_episode(old.scenario.task_id, old.state, old.scenario)
|
| 158 |
+
self._last_completed[session_id] = {
|
| 159 |
+
"score": score,
|
| 160 |
+
"task_id": old.scenario.task_id,
|
| 161 |
+
"steps": old.state.step_count,
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
self._current_session_id = session_id
|
| 165 |
+
|
| 166 |
+
# Derive deterministic seed
|
| 167 |
+
base_seed = seed if seed is not None else 42
|
| 168 |
+
scenario = sample_scenario(task_id, base_seed)
|
| 169 |
+
|
| 170 |
+
# Set torch seed for reproducibility
|
| 171 |
+
torch.manual_seed(scenario.seed)
|
| 172 |
+
|
| 173 |
+
# Create real PyTorch model with fault injection
|
| 174 |
+
model, info = create_model_and_inject_fault(scenario)
|
| 175 |
+
|
| 176 |
+
# Generate parametric curves
|
| 177 |
+
loss_history = gen_loss_history(scenario)
|
| 178 |
+
val_acc_history = gen_val_accuracy_history(scenario)
|
| 179 |
+
val_loss_history = gen_val_loss_history(scenario)
|
| 180 |
+
|
| 181 |
+
# Pre-generate data batch stats
|
| 182 |
+
data_batch_raw = gen_data_batch_stats(scenario)
|
| 183 |
+
|
| 184 |
+
# Pre-generate code snippet (for Task 6)
|
| 185 |
+
code_snippet_raw = None
|
| 186 |
+
if scenario.bug_type is not None:
|
| 187 |
+
code_snippet_raw = generate_code_snippet(scenario.bug_type, scenario.seed)
|
| 188 |
+
|
| 189 |
+
# Build initial config from scenario
|
| 190 |
+
config = TrainingConfig(
|
| 191 |
+
learning_rate=scenario.learning_rate,
|
| 192 |
+
weight_decay=scenario.weight_decay,
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
# Create fresh episode state
|
| 196 |
+
state = EpisodeState()
|
| 197 |
+
|
| 198 |
+
session = SessionData(
|
| 199 |
+
scenario=scenario,
|
| 200 |
+
model=model,
|
| 201 |
+
state=state,
|
| 202 |
+
config=config,
|
| 203 |
+
gradient_stats=[],
|
| 204 |
+
weight_stats=None,
|
| 205 |
+
model_modes=None,
|
| 206 |
+
data_batch_stats_raw=data_batch_raw,
|
| 207 |
+
code_snippet_raw=code_snippet_raw,
|
| 208 |
+
loss_history=loss_history,
|
| 209 |
+
val_acc_history=val_acc_history,
|
| 210 |
+
val_loss_history=val_loss_history,
|
| 211 |
+
done=False,
|
| 212 |
+
last_score=None,
|
| 213 |
+
convergence_after_fix=False,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
self._sessions[session_id] = session
|
| 217 |
+
|
| 218 |
+
logger.info(
|
| 219 |
+
"reset",
|
| 220 |
+
extra={
|
| 221 |
+
"session_id": session_id,
|
| 222 |
+
"task_id": task_id,
|
| 223 |
+
"scenario_seed": scenario.seed,
|
| 224 |
+
},
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
return self._build_observation(session)
|
| 228 |
+
|
| 229 |
+
def step(
|
| 230 |
+
self,
|
| 231 |
+
action: MLTrainingAction,
|
| 232 |
+
timeout_s: Optional[float] = None,
|
| 233 |
+
**kwargs: Any,
|
| 234 |
+
) -> MLTrainingObservation:
|
| 235 |
+
"""Process one agent action. NEVER raises. Spec Sections 13, 16."""
|
| 236 |
+
session = self._get_session()
|
| 237 |
+
|
| 238 |
+
# No active episode
|
| 239 |
+
if session is None:
|
| 240 |
+
return MLTrainingObservation(
|
| 241 |
+
done=True,
|
| 242 |
+
reward=0.0,
|
| 243 |
+
error_log="Error: no active episode. Call reset(task_id) first.",
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
# Episode already done
|
| 247 |
+
if session.done:
|
| 248 |
+
return self._build_observation(session, reward=0.0)
|
| 249 |
+
|
| 250 |
+
state = session.state
|
| 251 |
+
scenario = session.scenario
|
| 252 |
+
action_type = action.action_type
|
| 253 |
+
|
| 254 |
+
# Increment step count
|
| 255 |
+
state.step_count += 1
|
| 256 |
+
|
| 257 |
+
# Validate action_type is a known type
|
| 258 |
+
if action_type not in ALL_ACTION_TYPES:
|
| 259 |
+
reward = compute_reward(action, state, scenario, is_valid_action=False)
|
| 260 |
+
state.actions_taken.append(f"invalid:{action_type}")
|
| 261 |
+
obs = self._build_observation(session, reward=reward)
|
| 262 |
+
obs.error_log = (
|
| 263 |
+
f"Invalid action_type: {action_type}. "
|
| 264 |
+
f"Valid types: {sorted(ALL_ACTION_TYPES)}"
|
| 265 |
+
)
|
| 266 |
+
return obs
|
| 267 |
+
|
| 268 |
+
# Check if action is in available_actions
|
| 269 |
+
available = state.compute_available_actions()
|
| 270 |
+
if action_type not in available:
|
| 271 |
+
reward = compute_reward(action, state, scenario, is_valid_action=False)
|
| 272 |
+
state.actions_taken.append(f"unavailable:{action_type}")
|
| 273 |
+
obs = self._build_observation(session, reward=reward)
|
| 274 |
+
obs.error_log = (
|
| 275 |
+
f"Action '{action_type}' not available. " f"Available: {available}"
|
| 276 |
+
)
|
| 277 |
+
return obs
|
| 278 |
+
|
| 279 |
+
# Validate required fields for specific actions
|
| 280 |
+
error = self._validate_action_fields(action)
|
| 281 |
+
if error is not None:
|
| 282 |
+
reward = compute_reward(action, state, scenario, is_valid_action=False)
|
| 283 |
+
state.actions_taken.append(f"malformed:{action_type}")
|
| 284 |
+
obs = self._build_observation(session, reward=reward)
|
| 285 |
+
obs.error_log = error
|
| 286 |
+
return obs
|
| 287 |
+
|
| 288 |
+
# Dispatch action
|
| 289 |
+
is_correct_fix: bool | None = None
|
| 290 |
+
convergence = False
|
| 291 |
+
|
| 292 |
+
try:
|
| 293 |
+
is_correct_fix, convergence = self._dispatch_action(action, session)
|
| 294 |
+
except Exception as exc:
|
| 295 |
+
logger.error(
|
| 296 |
+
"step_error",
|
| 297 |
+
extra={
|
| 298 |
+
"session_id": self._current_session_id,
|
| 299 |
+
"action": action_type,
|
| 300 |
+
"error": str(exc),
|
| 301 |
+
},
|
| 302 |
+
exc_info=True,
|
| 303 |
+
)
|
| 304 |
+
reward = compute_reward(action, state, scenario, is_valid_action=False)
|
| 305 |
+
obs = self._build_observation(session, reward=reward)
|
| 306 |
+
obs.error_log = f"Internal error processing {action_type}: {exc}"
|
| 307 |
+
return obs
|
| 308 |
+
|
| 309 |
+
# Record action
|
| 310 |
+
if action_type == "mark_diagnosed" and action.diagnosis:
|
| 311 |
+
state.actions_taken.append(f"mark_diagnosed:{action.diagnosis}")
|
| 312 |
+
else:
|
| 313 |
+
state.actions_taken.append(action_type)
|
| 314 |
+
|
| 315 |
+
# Compute reward
|
| 316 |
+
reward = compute_reward(
|
| 317 |
+
action,
|
| 318 |
+
state,
|
| 319 |
+
scenario,
|
| 320 |
+
is_valid_action=True,
|
| 321 |
+
is_correct_fix=is_correct_fix,
|
| 322 |
+
convergence_confirmed=convergence,
|
| 323 |
+
)
|
| 324 |
+
|
| 325 |
+
# Check step limit
|
| 326 |
+
if state.step_count >= scenario.max_steps and not session.done:
|
| 327 |
+
session.done = True
|
| 328 |
+
|
| 329 |
+
# Check done
|
| 330 |
+
if session.done:
|
| 331 |
+
score = grade_episode(scenario.task_id, state, scenario)
|
| 332 |
+
session.last_score = score
|
| 333 |
+
self._last_completed[self._current_session_id] = {
|
| 334 |
+
"score": score,
|
| 335 |
+
"task_id": scenario.task_id,
|
| 336 |
+
"steps": state.step_count,
|
| 337 |
+
}
|
| 338 |
+
logger.info(
|
| 339 |
+
"episode_completed",
|
| 340 |
+
extra={
|
| 341 |
+
"session_id": self._current_session_id,
|
| 342 |
+
"task_id": scenario.task_id,
|
| 343 |
+
"steps": state.step_count,
|
| 344 |
+
"score": score,
|
| 345 |
+
},
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
logger.info(
|
| 349 |
+
"step",
|
| 350 |
+
extra={
|
| 351 |
+
"session_id": self._current_session_id,
|
| 352 |
+
"step_count": state.step_count,
|
| 353 |
+
"action_type": action_type,
|
| 354 |
+
"reward": reward,
|
| 355 |
+
},
|
| 356 |
+
)
|
| 357 |
+
|
| 358 |
+
return self._build_observation(session, reward=reward)
|
| 359 |
+
|
| 360 |
+
def _validate_action_fields(self, action: MLTrainingAction) -> str | None:
|
| 361 |
+
"""Validate required fields for specific actions. Return error or None."""
|
| 362 |
+
if action.action_type == "modify_config":
|
| 363 |
+
if action.target is None or action.value is None:
|
| 364 |
+
return "modify_config requires 'target' and 'value' fields"
|
| 365 |
+
if action.target not in VALID_CONFIG_KEYS:
|
| 366 |
+
return f"Unknown config key: {action.target}. Valid: {sorted(VALID_CONFIG_KEYS)}"
|
| 367 |
+
|
| 368 |
+
if action.action_type == "mark_diagnosed":
|
| 369 |
+
if action.diagnosis is None:
|
| 370 |
+
return "mark_diagnosed requires 'diagnosis' field"
|
| 371 |
+
if action.diagnosis not in VALID_DIAGNOSES:
|
| 372 |
+
return (
|
| 373 |
+
f"Invalid diagnosis: {action.diagnosis}. "
|
| 374 |
+
f"Valid: {sorted(VALID_DIAGNOSES)}"
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
if action.action_type == "fix_code":
|
| 378 |
+
if action.line is None or action.replacement is None:
|
| 379 |
+
return "fix_code requires 'line' and 'replacement' fields"
|
| 380 |
+
|
| 381 |
+
return None
|
| 382 |
+
|
| 383 |
+
def _dispatch_action(
|
| 384 |
+
self, action: MLTrainingAction, session: SessionData
|
| 385 |
+
) -> tuple[bool | None, bool]:
|
| 386 |
+
"""Dispatch action to handler. Returns (is_correct_fix, convergence)."""
|
| 387 |
+
state = session.state
|
| 388 |
+
scenario = session.scenario
|
| 389 |
+
is_correct_fix: bool | None = None
|
| 390 |
+
convergence = False
|
| 391 |
+
|
| 392 |
+
at = action.action_type
|
| 393 |
+
|
| 394 |
+
if at == "inspect_gradients":
|
| 395 |
+
if not state.gradients_inspected:
|
| 396 |
+
stats = extract_gradient_stats(session.model, scenario)
|
| 397 |
+
session.gradient_stats = stats
|
| 398 |
+
state.gradients_inspected = True
|
| 399 |
+
# Set gradients_were_normal: True if ALL layers is_exploding=False
|
| 400 |
+
state.gradients_were_normal = all(not s.is_exploding for s in stats)
|
| 401 |
+
|
| 402 |
+
elif at == "inspect_data_batch":
|
| 403 |
+
state.data_inspected = True
|
| 404 |
+
|
| 405 |
+
elif at == "inspect_model_modes":
|
| 406 |
+
if not state.model_modes_inspected:
|
| 407 |
+
modes = extract_model_modes(session.model)
|
| 408 |
+
session.model_modes = modes
|
| 409 |
+
state.model_modes_inspected = True
|
| 410 |
+
|
| 411 |
+
elif at == "inspect_model_weights":
|
| 412 |
+
if not state.model_weights_inspected:
|
| 413 |
+
stats = extract_weight_stats(session.model)
|
| 414 |
+
session.weight_stats = stats
|
| 415 |
+
state.model_weights_inspected = True
|
| 416 |
+
|
| 417 |
+
elif at == "inspect_code":
|
| 418 |
+
state.code_inspected = True
|
| 419 |
+
|
| 420 |
+
elif at == "modify_config":
|
| 421 |
+
if action.target and action.value is not None:
|
| 422 |
+
setattr(session.config, action.target, action.value)
|
| 423 |
+
state.fix_action_taken = True
|
| 424 |
+
|
| 425 |
+
elif at == "add_callback":
|
| 426 |
+
state.fix_action_taken = True
|
| 427 |
+
|
| 428 |
+
elif at == "replace_optimizer":
|
| 429 |
+
state.fix_action_taken = True
|
| 430 |
+
|
| 431 |
+
elif at == "patch_data_loader":
|
| 432 |
+
state.fix_action_taken = True
|
| 433 |
+
|
| 434 |
+
elif at == "fix_model_mode":
|
| 435 |
+
state.fix_action_taken = True
|
| 436 |
+
|
| 437 |
+
elif at == "fix_code":
|
| 438 |
+
state.fix_action_taken = True
|
| 439 |
+
if scenario.bug_type and action.line and action.replacement:
|
| 440 |
+
is_correct_fix = validate_fix(
|
| 441 |
+
scenario.bug_type, action.line, action.replacement
|
| 442 |
+
)
|
| 443 |
+
else:
|
| 444 |
+
is_correct_fix = False
|
| 445 |
+
|
| 446 |
+
elif at == "restart_run":
|
| 447 |
+
state.restart_after_fix = True
|
| 448 |
+
# Check convergence — did the fix address the root cause?
|
| 449 |
+
convergence = self._check_convergence(session)
|
| 450 |
+
session.convergence_after_fix = convergence
|
| 451 |
+
|
| 452 |
+
elif at == "mark_diagnosed":
|
| 453 |
+
state.diagnosis_submitted = True
|
| 454 |
+
session.done = True
|
| 455 |
+
|
| 456 |
+
elif at == "rollback_checkpoint":
|
| 457 |
+
pass # No-op for now
|
| 458 |
+
|
| 459 |
+
return is_correct_fix, convergence
|
| 460 |
+
|
| 461 |
+
def _check_convergence(self, session: SessionData) -> bool:
|
| 462 |
+
"""Check if the applied fix would resolve the root cause."""
|
| 463 |
+
scenario = session.scenario
|
| 464 |
+
state = session.state
|
| 465 |
+
root = scenario.root_cause.value
|
| 466 |
+
|
| 467 |
+
if root == "lr_too_high":
|
| 468 |
+
return (
|
| 469 |
+
"modify_config" in state.actions_taken
|
| 470 |
+
and session.config.learning_rate <= 0.001
|
| 471 |
+
)
|
| 472 |
+
|
| 473 |
+
if root == "vanishing_gradients":
|
| 474 |
+
return (
|
| 475 |
+
"modify_config" in state.actions_taken
|
| 476 |
+
and session.config.learning_rate >= 0.001
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
if root == "data_leakage":
|
| 480 |
+
return "patch_data_loader" in state.actions_taken
|
| 481 |
+
|
| 482 |
+
if root == "overfitting":
|
| 483 |
+
return (
|
| 484 |
+
"modify_config" in state.actions_taken
|
| 485 |
+
or "add_callback" in state.actions_taken
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
if root == "batchnorm_eval_mode":
|
| 489 |
+
return "fix_model_mode" in state.actions_taken
|
| 490 |
+
|
| 491 |
+
if root == "code_bug":
|
| 492 |
+
return "fix_code" in state.actions_taken and state.fix_action_taken
|
| 493 |
+
|
| 494 |
+
return False
|
| 495 |
+
|
| 496 |
+
@property
|
| 497 |
+
def state(self) -> dict:
|
| 498 |
+
"""Return current environment state."""
|
| 499 |
+
session = self._get_session()
|
| 500 |
+
if session is None:
|
| 501 |
+
return {"status": "no_active_episode"}
|
| 502 |
+
return {
|
| 503 |
+
"status": "active",
|
| 504 |
+
"task_id": session.scenario.task_id,
|
| 505 |
+
"step_count": session.state.step_count,
|
| 506 |
+
"done": session.done,
|
| 507 |
+
}
|
| 508 |
+
|
| 509 |
+
def get_last_completed(self, session_id: str | None = None) -> dict | None:
|
| 510 |
+
"""Get last completed episode data for grader endpoint."""
|
| 511 |
+
if session_id:
|
| 512 |
+
return self._last_completed.get(session_id)
|
| 513 |
+
# Return most recent
|
| 514 |
+
if self._last_completed:
|
| 515 |
+
return list(self._last_completed.values())[-1]
|
| 516 |
+
return None
|
tests/__init__.py
ADDED
|
File without changes
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Shared test fixtures."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from ml_training_debugger.models import (
|
| 8 |
+
EpisodeState,
|
| 9 |
+
TrainingConfig,
|
| 10 |
+
)
|
| 11 |
+
from ml_training_debugger.scenarios import ScenarioParams, sample_scenario
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@pytest.fixture
|
| 15 |
+
def fresh_state() -> EpisodeState:
|
| 16 |
+
return EpisodeState()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@pytest.fixture
|
| 20 |
+
def sample_config() -> TrainingConfig:
|
| 21 |
+
return TrainingConfig(learning_rate=0.001)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@pytest.fixture
|
| 25 |
+
def task_001_scenario() -> ScenarioParams:
|
| 26 |
+
return sample_scenario("task_001", seed=42)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@pytest.fixture
|
| 30 |
+
def task_003_scenario() -> ScenarioParams:
|
| 31 |
+
return sample_scenario("task_003", seed=42)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@pytest.fixture
|
| 35 |
+
def task_005_scenario() -> ScenarioParams:
|
| 36 |
+
return sample_scenario("task_005", seed=42)
|
tests/test_code_templates.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test code bug generation and fix validation."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from ml_training_debugger.code_templates import generate_code_snippet, validate_fix
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TestGenerateCodeSnippet:
|
| 11 |
+
def test_eval_mode(self):
|
| 12 |
+
snippet = generate_code_snippet("eval_mode")
|
| 13 |
+
assert "model.eval()" in snippet["code"]
|
| 14 |
+
assert snippet["filename"] == "train.py"
|
| 15 |
+
assert snippet["line_count"] > 0
|
| 16 |
+
assert len(snippet["imports"]) > 0
|
| 17 |
+
|
| 18 |
+
def test_detach_loss(self):
|
| 19 |
+
snippet = generate_code_snippet("detach_loss")
|
| 20 |
+
assert ".detach()" in snippet["code"]
|
| 21 |
+
|
| 22 |
+
def test_zero_grad_missing(self):
|
| 23 |
+
snippet = generate_code_snippet("zero_grad_missing")
|
| 24 |
+
assert "zero_grad" not in snippet["code"]
|
| 25 |
+
|
| 26 |
+
def test_inplace_relu(self):
|
| 27 |
+
snippet = generate_code_snippet("inplace_relu")
|
| 28 |
+
assert "inplace=True" in snippet["code"]
|
| 29 |
+
|
| 30 |
+
def test_unknown_bug_raises(self):
|
| 31 |
+
with pytest.raises(ValueError):
|
| 32 |
+
generate_code_snippet("nonexistent_bug")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class TestValidateFix:
|
| 36 |
+
def test_eval_mode_correct_fix(self):
|
| 37 |
+
assert validate_fix("eval_mode", 5, "model.train()")
|
| 38 |
+
|
| 39 |
+
def test_eval_mode_with_whitespace(self):
|
| 40 |
+
assert validate_fix("eval_mode", 5, " model.train() ")
|
| 41 |
+
|
| 42 |
+
def test_eval_mode_wrong_fix(self):
|
| 43 |
+
assert not validate_fix("eval_mode", 5, "pass")
|
| 44 |
+
|
| 45 |
+
def test_detach_loss_correct_fix(self):
|
| 46 |
+
assert validate_fix(
|
| 47 |
+
"detach_loss", 14, " loss = criterion(output, batch_y)"
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
def test_detach_loss_with_trailing_spaces(self):
|
| 51 |
+
assert validate_fix(
|
| 52 |
+
"detach_loss", 14, " loss = criterion(output, batch_y) "
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def test_zero_grad_correct_fix(self):
|
| 56 |
+
assert validate_fix("zero_grad_missing", 11, " optimizer.zero_grad()")
|
| 57 |
+
|
| 58 |
+
def test_inplace_relu_correct_fix(self):
|
| 59 |
+
assert validate_fix("inplace_relu", 15, " output = F.relu(output)")
|
| 60 |
+
|
| 61 |
+
def test_wrong_line_number(self):
|
| 62 |
+
assert not validate_fix("eval_mode", 999, "model.train()")
|
| 63 |
+
|
| 64 |
+
def test_unknown_bug_type(self):
|
| 65 |
+
assert not validate_fix("nonexistent", 1, "pass")
|
tests/test_episode_lifecycle.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test full episode lifecycle — reset, step, state transitions."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from ml_training_debugger.models import MLTrainingAction
|
| 8 |
+
from server.environment import MLTrainingEnvironment
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@pytest.fixture
|
| 12 |
+
def env():
|
| 13 |
+
return MLTrainingEnvironment()
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestReset:
|
| 17 |
+
def test_reset_returns_valid_observation(self, env):
|
| 18 |
+
obs = env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 19 |
+
assert obs.run_id == "test"
|
| 20 |
+
assert obs.framework == "pytorch"
|
| 21 |
+
assert len(obs.training_loss_history) == 20
|
| 22 |
+
assert len(obs.val_accuracy_history) == 20
|
| 23 |
+
assert obs.done is False
|
| 24 |
+
|
| 25 |
+
def test_reset_initial_state(self, env):
|
| 26 |
+
obs = env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 27 |
+
assert obs.episode_state.step_count == 0
|
| 28 |
+
assert not obs.episode_state.gradients_inspected
|
| 29 |
+
assert not obs.episode_state.diagnosis_submitted
|
| 30 |
+
|
| 31 |
+
def test_reset_progressive_reveal(self, env):
|
| 32 |
+
obs = env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 33 |
+
assert obs.gradient_stats == []
|
| 34 |
+
assert obs.model_weight_stats is None
|
| 35 |
+
assert obs.data_batch_stats is None
|
| 36 |
+
assert obs.model_mode_info is None
|
| 37 |
+
assert obs.code_snippet is None
|
| 38 |
+
|
| 39 |
+
def test_reset_available_actions(self, env):
|
| 40 |
+
obs = env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 41 |
+
assert "inspect_gradients" in obs.available_actions
|
| 42 |
+
assert "mark_diagnosed" in obs.available_actions
|
| 43 |
+
assert "fix_code" not in obs.available_actions
|
| 44 |
+
assert "restart_run" not in obs.available_actions
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class TestStepInspections:
|
| 48 |
+
def test_inspect_gradients_populates_stats(self, env):
|
| 49 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 50 |
+
obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
|
| 51 |
+
assert len(obs.gradient_stats) > 0
|
| 52 |
+
assert obs.episode_state.gradients_inspected
|
| 53 |
+
|
| 54 |
+
def test_inspect_data_batch(self, env):
|
| 55 |
+
env.reset(seed=42, episode_id="test", task_id="task_003")
|
| 56 |
+
obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
|
| 57 |
+
assert obs.data_batch_stats is not None
|
| 58 |
+
assert obs.episode_state.data_inspected
|
| 59 |
+
|
| 60 |
+
def test_inspect_model_modes(self, env):
|
| 61 |
+
env.reset(seed=42, episode_id="test", task_id="task_005")
|
| 62 |
+
obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
|
| 63 |
+
assert obs.model_mode_info is not None
|
| 64 |
+
assert obs.episode_state.model_modes_inspected
|
| 65 |
+
|
| 66 |
+
def test_inspect_model_weights(self, env):
|
| 67 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 68 |
+
obs = env.step(MLTrainingAction(action_type="inspect_model_weights"))
|
| 69 |
+
assert obs.model_weight_stats is not None
|
| 70 |
+
assert obs.episode_state.model_weights_inspected
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class TestStepFixActions:
|
| 74 |
+
def test_modify_config(self, env):
|
| 75 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 76 |
+
obs = env.step(
|
| 77 |
+
MLTrainingAction(
|
| 78 |
+
action_type="modify_config",
|
| 79 |
+
target="learning_rate",
|
| 80 |
+
value=0.001,
|
| 81 |
+
)
|
| 82 |
+
)
|
| 83 |
+
assert obs.episode_state.fix_action_taken
|
| 84 |
+
assert "restart_run" in obs.available_actions
|
| 85 |
+
|
| 86 |
+
def test_restart_run_after_fix(self, env):
|
| 87 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 88 |
+
env.step(
|
| 89 |
+
MLTrainingAction(
|
| 90 |
+
action_type="modify_config",
|
| 91 |
+
target="learning_rate",
|
| 92 |
+
value=0.001,
|
| 93 |
+
)
|
| 94 |
+
)
|
| 95 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 96 |
+
assert obs.episode_state.restart_after_fix
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class TestStepDiagnosis:
|
| 100 |
+
def test_mark_diagnosed_ends_episode(self, env):
|
| 101 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 102 |
+
obs = env.step(
|
| 103 |
+
MLTrainingAction(
|
| 104 |
+
action_type="mark_diagnosed",
|
| 105 |
+
diagnosis="lr_too_high",
|
| 106 |
+
)
|
| 107 |
+
)
|
| 108 |
+
assert obs.done is True
|
| 109 |
+
assert obs.episode_state.diagnosis_submitted
|
| 110 |
+
|
| 111 |
+
def test_step_after_done(self, env):
|
| 112 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 113 |
+
env.step(
|
| 114 |
+
MLTrainingAction(
|
| 115 |
+
action_type="mark_diagnosed",
|
| 116 |
+
diagnosis="lr_too_high",
|
| 117 |
+
)
|
| 118 |
+
)
|
| 119 |
+
obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
|
| 120 |
+
assert obs.done is True
|
| 121 |
+
assert obs.reward == 0.0
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class TestErrorHandling:
|
| 125 |
+
def test_invalid_action_type(self, env):
|
| 126 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 127 |
+
obs = env.step(MLTrainingAction(action_type="nonexistent_action"))
|
| 128 |
+
assert obs.reward == pytest.approx(-0.01 + -0.05)
|
| 129 |
+
assert obs.error_log is not None
|
| 130 |
+
|
| 131 |
+
def test_action_not_in_available(self, env):
|
| 132 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 133 |
+
# fix_code requires code_inspected=True
|
| 134 |
+
obs = env.step(
|
| 135 |
+
MLTrainingAction(
|
| 136 |
+
action_type="fix_code",
|
| 137 |
+
line=1,
|
| 138 |
+
replacement="pass",
|
| 139 |
+
)
|
| 140 |
+
)
|
| 141 |
+
assert obs.reward < 0
|
| 142 |
+
|
| 143 |
+
def test_modify_config_missing_target(self, env):
|
| 144 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 145 |
+
obs = env.step(MLTrainingAction(action_type="modify_config"))
|
| 146 |
+
assert "target" in obs.error_log.lower() or "value" in obs.error_log.lower()
|
| 147 |
+
|
| 148 |
+
def test_mark_diagnosed_missing_diagnosis(self, env):
|
| 149 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 150 |
+
obs = env.step(MLTrainingAction(action_type="mark_diagnosed"))
|
| 151 |
+
assert "diagnosis" in obs.error_log.lower()
|
| 152 |
+
|
| 153 |
+
def test_mark_diagnosed_invalid_diagnosis(self, env):
|
| 154 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 155 |
+
obs = env.step(
|
| 156 |
+
MLTrainingAction(
|
| 157 |
+
action_type="mark_diagnosed",
|
| 158 |
+
diagnosis="not_a_real_diagnosis",
|
| 159 |
+
)
|
| 160 |
+
)
|
| 161 |
+
assert "invalid" in obs.error_log.lower()
|
| 162 |
+
|
| 163 |
+
def test_step_before_reset(self, env):
|
| 164 |
+
obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
|
| 165 |
+
assert obs.done is True
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class TestFullEpisodeFlow:
|
| 169 |
+
def test_task_001_full_flow(self, env):
|
| 170 |
+
"""Full optimal flow for Task 1."""
|
| 171 |
+
obs = env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 172 |
+
assert not obs.done
|
| 173 |
+
|
| 174 |
+
obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
|
| 175 |
+
assert obs.episode_state.gradients_inspected
|
| 176 |
+
assert any(g.is_exploding for g in obs.gradient_stats)
|
| 177 |
+
|
| 178 |
+
obs = env.step(
|
| 179 |
+
MLTrainingAction(
|
| 180 |
+
action_type="modify_config",
|
| 181 |
+
target="learning_rate",
|
| 182 |
+
value=0.001,
|
| 183 |
+
)
|
| 184 |
+
)
|
| 185 |
+
assert obs.episode_state.fix_action_taken
|
| 186 |
+
|
| 187 |
+
obs = env.step(MLTrainingAction(action_type="restart_run"))
|
| 188 |
+
assert obs.episode_state.restart_after_fix
|
| 189 |
+
|
| 190 |
+
obs = env.step(
|
| 191 |
+
MLTrainingAction(
|
| 192 |
+
action_type="mark_diagnosed",
|
| 193 |
+
diagnosis="lr_too_high",
|
| 194 |
+
)
|
| 195 |
+
)
|
| 196 |
+
assert obs.done
|
| 197 |
+
assert obs.reward > 0
|
| 198 |
+
|
| 199 |
+
def test_task_005_context_gated_penalty(self, env):
|
| 200 |
+
"""Task 5: inspect gradients (normal) → add_callback → penalty fires."""
|
| 201 |
+
obs = env.reset(seed=42, episode_id="test", task_id="task_005")
|
| 202 |
+
|
| 203 |
+
obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
|
| 204 |
+
assert obs.episode_state.gradients_inspected
|
| 205 |
+
assert obs.episode_state.gradients_were_normal
|
| 206 |
+
# All layers is_exploding=False
|
| 207 |
+
for g in obs.gradient_stats:
|
| 208 |
+
assert not g.is_exploding
|
| 209 |
+
|
| 210 |
+
# Now add_callback should trigger context-gated penalty
|
| 211 |
+
obs = env.step(MLTrainingAction(action_type="add_callback"))
|
| 212 |
+
assert obs.reward == pytest.approx(-0.01 + -0.20)
|
| 213 |
+
|
| 214 |
+
def test_task_003_data_leakage(self, env):
|
| 215 |
+
"""Task 3: data inspection reveals leakage."""
|
| 216 |
+
obs = env.reset(seed=42, episode_id="test", task_id="task_003")
|
| 217 |
+
|
| 218 |
+
obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
|
| 219 |
+
assert obs.data_batch_stats is not None
|
| 220 |
+
assert obs.data_batch_stats.class_overlap_score > 0.5
|
tests/test_graders.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test grader functions — each returns 0.0-1.0."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from ml_training_debugger.graders import (
|
| 8 |
+
grade_episode,
|
| 9 |
+
grade_task_001,
|
| 10 |
+
grade_task_003,
|
| 11 |
+
grade_task_005,
|
| 12 |
+
)
|
| 13 |
+
from ml_training_debugger.models import EpisodeState
|
| 14 |
+
from ml_training_debugger.scenarios import sample_scenario
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@pytest.fixture
|
| 18 |
+
def scenario_001():
|
| 19 |
+
return sample_scenario("task_001", seed=42)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@pytest.fixture
|
| 23 |
+
def scenario_003():
|
| 24 |
+
return sample_scenario("task_003", seed=42)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@pytest.fixture
|
| 28 |
+
def scenario_005():
|
| 29 |
+
return sample_scenario("task_005", seed=42)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TestGradeTask001:
|
| 33 |
+
def test_perfect_score(self, scenario_001):
|
| 34 |
+
state = EpisodeState(
|
| 35 |
+
gradients_inspected=True,
|
| 36 |
+
fix_action_taken=True,
|
| 37 |
+
restart_after_fix=True,
|
| 38 |
+
diagnosis_submitted=True,
|
| 39 |
+
actions_taken=[
|
| 40 |
+
"inspect_gradients",
|
| 41 |
+
"modify_config",
|
| 42 |
+
"restart_run",
|
| 43 |
+
"mark_diagnosed:lr_too_high",
|
| 44 |
+
],
|
| 45 |
+
)
|
| 46 |
+
score = grade_task_001(state, scenario_001)
|
| 47 |
+
assert score == 1.0
|
| 48 |
+
|
| 49 |
+
def test_wrong_diagnosis(self, scenario_001):
|
| 50 |
+
state = EpisodeState(
|
| 51 |
+
gradients_inspected=True,
|
| 52 |
+
fix_action_taken=True,
|
| 53 |
+
restart_after_fix=True,
|
| 54 |
+
diagnosis_submitted=True,
|
| 55 |
+
actions_taken=[
|
| 56 |
+
"inspect_gradients",
|
| 57 |
+
"modify_config",
|
| 58 |
+
"restart_run",
|
| 59 |
+
"mark_diagnosed:data_leakage",
|
| 60 |
+
],
|
| 61 |
+
)
|
| 62 |
+
score = grade_task_001(state, scenario_001)
|
| 63 |
+
assert score < 0.7 # Missing diagnosis credit
|
| 64 |
+
|
| 65 |
+
def test_no_investigation(self, scenario_001):
|
| 66 |
+
state = EpisodeState(
|
| 67 |
+
diagnosis_submitted=True,
|
| 68 |
+
actions_taken=["mark_diagnosed:lr_too_high"],
|
| 69 |
+
)
|
| 70 |
+
score = grade_task_001(state, scenario_001)
|
| 71 |
+
assert 0.0 < score < 1.0
|
| 72 |
+
|
| 73 |
+
def test_score_in_range(self, scenario_001):
|
| 74 |
+
state = EpisodeState()
|
| 75 |
+
score = grade_task_001(state, scenario_001)
|
| 76 |
+
assert 0.0 <= score <= 1.0
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class TestGradeTask003:
|
| 80 |
+
def test_perfect_score(self, scenario_003):
|
| 81 |
+
state = EpisodeState(
|
| 82 |
+
data_inspected=True,
|
| 83 |
+
fix_action_taken=True,
|
| 84 |
+
restart_after_fix=True,
|
| 85 |
+
diagnosis_submitted=True,
|
| 86 |
+
actions_taken=[
|
| 87 |
+
"inspect_data_batch",
|
| 88 |
+
"patch_data_loader",
|
| 89 |
+
"restart_run",
|
| 90 |
+
"mark_diagnosed:data_leakage",
|
| 91 |
+
],
|
| 92 |
+
)
|
| 93 |
+
score = grade_task_003(state, scenario_003)
|
| 94 |
+
assert score == pytest.approx(1.0)
|
| 95 |
+
|
| 96 |
+
def test_wrong_diagnosis(self, scenario_003):
|
| 97 |
+
state = EpisodeState(
|
| 98 |
+
data_inspected=True,
|
| 99 |
+
diagnosis_submitted=True,
|
| 100 |
+
actions_taken=[
|
| 101 |
+
"inspect_data_batch",
|
| 102 |
+
"mark_diagnosed:overfitting",
|
| 103 |
+
],
|
| 104 |
+
)
|
| 105 |
+
score = grade_task_003(state, scenario_003)
|
| 106 |
+
assert score < 0.5
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
class TestGradeTask005:
|
| 110 |
+
def test_perfect_score(self, scenario_005):
|
| 111 |
+
state = EpisodeState(
|
| 112 |
+
gradients_inspected=True,
|
| 113 |
+
gradients_were_normal=True,
|
| 114 |
+
model_modes_inspected=True,
|
| 115 |
+
fix_action_taken=True,
|
| 116 |
+
restart_after_fix=True,
|
| 117 |
+
diagnosis_submitted=True,
|
| 118 |
+
actions_taken=[
|
| 119 |
+
"inspect_gradients",
|
| 120 |
+
"inspect_model_modes",
|
| 121 |
+
"fix_model_mode",
|
| 122 |
+
"restart_run",
|
| 123 |
+
"mark_diagnosed:batchnorm_eval_mode",
|
| 124 |
+
],
|
| 125 |
+
)
|
| 126 |
+
score = grade_task_005(state, scenario_005)
|
| 127 |
+
assert score == 1.0
|
| 128 |
+
|
| 129 |
+
def test_red_herring_chaser(self, scenario_005):
|
| 130 |
+
"""Agent that chases gradient red herring scores 0.80-0.85."""
|
| 131 |
+
state = EpisodeState(
|
| 132 |
+
gradients_inspected=True,
|
| 133 |
+
gradients_were_normal=True,
|
| 134 |
+
model_modes_inspected=True,
|
| 135 |
+
fix_action_taken=True,
|
| 136 |
+
restart_after_fix=True,
|
| 137 |
+
diagnosis_submitted=True,
|
| 138 |
+
actions_taken=[
|
| 139 |
+
"inspect_gradients",
|
| 140 |
+
"add_callback", # Wrong: chases red herring
|
| 141 |
+
"inspect_model_modes",
|
| 142 |
+
"fix_model_mode",
|
| 143 |
+
"restart_run",
|
| 144 |
+
"mark_diagnosed:batchnorm_eval_mode",
|
| 145 |
+
],
|
| 146 |
+
)
|
| 147 |
+
score = grade_task_005(state, scenario_005)
|
| 148 |
+
# -0.20 penalty for add_callback after normal gradients
|
| 149 |
+
assert 0.7 <= score <= 0.90
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class TestGradeEpisode:
|
| 153 |
+
def test_dispatch_to_correct_grader(self, scenario_001):
|
| 154 |
+
state = EpisodeState(
|
| 155 |
+
gradients_inspected=True,
|
| 156 |
+
diagnosis_submitted=True,
|
| 157 |
+
actions_taken=[
|
| 158 |
+
"inspect_gradients",
|
| 159 |
+
"mark_diagnosed:lr_too_high",
|
| 160 |
+
],
|
| 161 |
+
)
|
| 162 |
+
score = grade_episode("task_001", state, scenario_001)
|
| 163 |
+
assert 0.0 <= score <= 1.0
|
| 164 |
+
|
| 165 |
+
def test_unknown_task_returns_zero(self, scenario_001):
|
| 166 |
+
state = EpisodeState()
|
| 167 |
+
score = grade_episode("task_999", state, scenario_001)
|
| 168 |
+
assert score == 0.0
|
tests/test_models.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test all Pydantic models instantiate and serialize correctly."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
from openenv.core.env_server.types import Action, Observation
|
| 8 |
+
|
| 9 |
+
from ml_training_debugger.models import (
|
| 10 |
+
EpisodeState,
|
| 11 |
+
GradientStats,
|
| 12 |
+
MLTrainingAction,
|
| 13 |
+
MLTrainingObservation,
|
| 14 |
+
RootCauseDiagnosis,
|
| 15 |
+
TrainingConfig,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TestRootCauseDiagnosis:
|
| 20 |
+
def test_all_six_values_exist(self):
|
| 21 |
+
assert len(RootCauseDiagnosis) == 6
|
| 22 |
+
|
| 23 |
+
def test_values_are_strings(self):
|
| 24 |
+
for d in RootCauseDiagnosis:
|
| 25 |
+
assert isinstance(d.value, str)
|
| 26 |
+
|
| 27 |
+
def test_specific_values(self):
|
| 28 |
+
assert RootCauseDiagnosis.LR_TOO_HIGH.value == "lr_too_high"
|
| 29 |
+
assert RootCauseDiagnosis.CODE_BUG.value == "code_bug"
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TestTrainingConfig:
|
| 33 |
+
def test_default_instantiation(self):
|
| 34 |
+
config = TrainingConfig()
|
| 35 |
+
assert config.learning_rate == 0.001
|
| 36 |
+
assert config.gradient_clip_norm is None
|
| 37 |
+
|
| 38 |
+
def test_json_roundtrip(self):
|
| 39 |
+
config = TrainingConfig(learning_rate=0.01, weight_decay=0.1)
|
| 40 |
+
data = json.loads(config.model_dump_json())
|
| 41 |
+
restored = TrainingConfig.model_validate(data)
|
| 42 |
+
assert restored.learning_rate == 0.01
|
| 43 |
+
assert restored.weight_decay == 0.1
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class TestGradientStats:
|
| 47 |
+
def test_exploding(self):
|
| 48 |
+
stats = GradientStats(
|
| 49 |
+
layer_name="fc",
|
| 50 |
+
norm_history=[15.0],
|
| 51 |
+
mean_norm=15.0,
|
| 52 |
+
max_norm=15.0,
|
| 53 |
+
is_exploding=True,
|
| 54 |
+
is_vanishing=False,
|
| 55 |
+
)
|
| 56 |
+
assert stats.is_exploding
|
| 57 |
+
|
| 58 |
+
def test_vanishing(self):
|
| 59 |
+
stats = GradientStats(
|
| 60 |
+
layer_name="conv1",
|
| 61 |
+
norm_history=[1e-7],
|
| 62 |
+
mean_norm=1e-7,
|
| 63 |
+
max_norm=1e-7,
|
| 64 |
+
is_exploding=False,
|
| 65 |
+
is_vanishing=True,
|
| 66 |
+
)
|
| 67 |
+
assert stats.is_vanishing
|
| 68 |
+
|
| 69 |
+
def test_normal(self):
|
| 70 |
+
stats = GradientStats(
|
| 71 |
+
layer_name="conv1",
|
| 72 |
+
norm_history=[0.5],
|
| 73 |
+
mean_norm=0.5,
|
| 74 |
+
max_norm=0.5,
|
| 75 |
+
is_exploding=False,
|
| 76 |
+
is_vanishing=False,
|
| 77 |
+
)
|
| 78 |
+
assert not stats.is_exploding
|
| 79 |
+
assert not stats.is_vanishing
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class TestEpisodeState:
|
| 83 |
+
def test_fresh_state(self):
|
| 84 |
+
state = EpisodeState()
|
| 85 |
+
assert state.step_count == 0
|
| 86 |
+
assert not state.gradients_inspected
|
| 87 |
+
assert not state.diagnosis_submitted
|
| 88 |
+
|
| 89 |
+
def test_available_actions_initial(self):
|
| 90 |
+
state = EpisodeState()
|
| 91 |
+
actions = state.compute_available_actions()
|
| 92 |
+
assert "inspect_gradients" in actions
|
| 93 |
+
assert "mark_diagnosed" in actions
|
| 94 |
+
assert "fix_code" not in actions
|
| 95 |
+
assert "restart_run" not in actions
|
| 96 |
+
assert "rollback_checkpoint" not in actions
|
| 97 |
+
|
| 98 |
+
def test_fix_code_available_after_code_inspected(self):
|
| 99 |
+
state = EpisodeState(code_inspected=True)
|
| 100 |
+
actions = state.compute_available_actions()
|
| 101 |
+
assert "fix_code" in actions
|
| 102 |
+
|
| 103 |
+
def test_restart_run_available_after_fix(self):
|
| 104 |
+
state = EpisodeState(fix_action_taken=True)
|
| 105 |
+
actions = state.compute_available_actions()
|
| 106 |
+
assert "restart_run" in actions
|
| 107 |
+
|
| 108 |
+
def test_rollback_available_after_restart(self):
|
| 109 |
+
state = EpisodeState(restart_after_fix=True)
|
| 110 |
+
actions = state.compute_available_actions()
|
| 111 |
+
assert "rollback_checkpoint" in actions
|
| 112 |
+
|
| 113 |
+
def test_mark_diagnosed_disappears_after_submission(self):
|
| 114 |
+
state = EpisodeState(diagnosis_submitted=True)
|
| 115 |
+
actions = state.compute_available_actions()
|
| 116 |
+
assert "mark_diagnosed" not in actions
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class TestMLTrainingObservation:
|
| 120 |
+
def test_extends_observation(self):
|
| 121 |
+
assert issubclass(MLTrainingObservation, Observation)
|
| 122 |
+
|
| 123 |
+
def test_has_done_and_reward(self):
|
| 124 |
+
obs = MLTrainingObservation(done=True, reward=0.5)
|
| 125 |
+
assert obs.done is True
|
| 126 |
+
assert obs.reward == 0.5
|
| 127 |
+
|
| 128 |
+
def test_json_serialization(self):
|
| 129 |
+
obs = MLTrainingObservation(
|
| 130 |
+
run_id="test",
|
| 131 |
+
training_loss_history=[1.0, 2.0],
|
| 132 |
+
val_accuracy_history=[0.5],
|
| 133 |
+
)
|
| 134 |
+
data = json.loads(obs.model_dump_json())
|
| 135 |
+
assert data["run_id"] == "test"
|
| 136 |
+
assert data["framework"] == "pytorch"
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class TestMLTrainingAction:
|
| 140 |
+
def test_extends_action(self):
|
| 141 |
+
assert issubclass(MLTrainingAction, Action)
|
| 142 |
+
|
| 143 |
+
def test_basic_action(self):
|
| 144 |
+
action = MLTrainingAction(action_type="inspect_gradients")
|
| 145 |
+
assert action.action_type == "inspect_gradients"
|
| 146 |
+
|
| 147 |
+
def test_modify_config_action(self):
|
| 148 |
+
action = MLTrainingAction(
|
| 149 |
+
action_type="modify_config",
|
| 150 |
+
target="learning_rate",
|
| 151 |
+
value=0.001,
|
| 152 |
+
)
|
| 153 |
+
assert action.target == "learning_rate"
|
| 154 |
+
|
| 155 |
+
def test_mark_diagnosed_action(self):
|
| 156 |
+
action = MLTrainingAction(
|
| 157 |
+
action_type="mark_diagnosed",
|
| 158 |
+
diagnosis="lr_too_high",
|
| 159 |
+
)
|
| 160 |
+
assert action.diagnosis == "lr_too_high"
|
| 161 |
+
|
| 162 |
+
def test_fix_code_action(self):
|
| 163 |
+
action = MLTrainingAction(
|
| 164 |
+
action_type="fix_code",
|
| 165 |
+
line=13,
|
| 166 |
+
replacement="loss = criterion(output, batch_y)",
|
| 167 |
+
)
|
| 168 |
+
assert action.line == 13
|
tests/test_pytorch_engine.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test real PyTorch model instantiation and fault injection."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
|
| 8 |
+
from ml_training_debugger.pytorch_engine import (
|
| 9 |
+
SimpleCNN,
|
| 10 |
+
create_model_and_inject_fault,
|
| 11 |
+
extract_gradient_stats,
|
| 12 |
+
extract_model_modes,
|
| 13 |
+
extract_weight_stats,
|
| 14 |
+
)
|
| 15 |
+
from ml_training_debugger.scenarios import sample_scenario
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestSimpleCNN:
|
| 19 |
+
def test_is_nn_module(self):
|
| 20 |
+
model = SimpleCNN()
|
| 21 |
+
assert isinstance(model, nn.Module)
|
| 22 |
+
|
| 23 |
+
def test_param_count(self):
|
| 24 |
+
model = SimpleCNN()
|
| 25 |
+
count = sum(p.numel() for p in model.parameters())
|
| 26 |
+
assert 30_000 < count < 100_000 # ~50K params
|
| 27 |
+
|
| 28 |
+
def test_forward_pass(self):
|
| 29 |
+
model = SimpleCNN()
|
| 30 |
+
x = torch.randn(2, 3, 32, 32)
|
| 31 |
+
out = model(x)
|
| 32 |
+
assert out.shape == (2, 10)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class TestFaultInjection:
|
| 36 |
+
def test_task_001_exploding_gradients(self):
|
| 37 |
+
scenario = sample_scenario("task_001", seed=42)
|
| 38 |
+
model, info = create_model_and_inject_fault(scenario)
|
| 39 |
+
stats = extract_gradient_stats(model, scenario)
|
| 40 |
+
assert len(stats) > 0
|
| 41 |
+
# At least some layers should have elevated gradients
|
| 42 |
+
any_high = any(s.mean_norm > 1.0 for s in stats)
|
| 43 |
+
assert any_high
|
| 44 |
+
|
| 45 |
+
def test_task_005_eval_mode(self):
|
| 46 |
+
scenario = sample_scenario("task_005", seed=42)
|
| 47 |
+
model, info = create_model_and_inject_fault(scenario)
|
| 48 |
+
assert not model.training # model.eval() was called
|
| 49 |
+
|
| 50 |
+
def test_task_005_gradients_not_exploding(self):
|
| 51 |
+
scenario = sample_scenario("task_005", seed=42)
|
| 52 |
+
model, info = create_model_and_inject_fault(scenario)
|
| 53 |
+
stats = extract_gradient_stats(model, scenario)
|
| 54 |
+
# ALL layers must have is_exploding=False
|
| 55 |
+
for s in stats:
|
| 56 |
+
assert not s.is_exploding, f"Layer {s.layer_name} should not be exploding"
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class TestExtractGradientStats:
|
| 60 |
+
def test_returns_gradient_stats(self):
|
| 61 |
+
scenario = sample_scenario("task_001", seed=42)
|
| 62 |
+
model, _ = create_model_and_inject_fault(scenario)
|
| 63 |
+
stats = extract_gradient_stats(model, scenario)
|
| 64 |
+
assert len(stats) == 4 # conv1, conv2, conv3, fc
|
| 65 |
+
for s in stats:
|
| 66 |
+
assert isinstance(s.mean_norm, float)
|
| 67 |
+
assert isinstance(s.norm_history, list)
|
| 68 |
+
assert len(s.norm_history) == 5
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class TestExtractWeightStats:
|
| 72 |
+
def test_returns_weight_stats(self):
|
| 73 |
+
scenario = sample_scenario("task_001", seed=42)
|
| 74 |
+
model, _ = create_model_and_inject_fault(scenario)
|
| 75 |
+
stats = extract_weight_stats(model)
|
| 76 |
+
assert len(stats) > 0
|
| 77 |
+
for s in stats:
|
| 78 |
+
assert isinstance(s.weight_norm, float)
|
| 79 |
+
assert isinstance(s.has_nan, bool)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class TestExtractModelModes:
|
| 83 |
+
def test_train_mode(self):
|
| 84 |
+
model = SimpleCNN()
|
| 85 |
+
model.train()
|
| 86 |
+
modes = extract_model_modes(model)
|
| 87 |
+
assert all(v == "train" for v in modes.values())
|
| 88 |
+
|
| 89 |
+
def test_eval_mode(self):
|
| 90 |
+
model = SimpleCNN()
|
| 91 |
+
model.eval()
|
| 92 |
+
modes = extract_model_modes(model)
|
| 93 |
+
assert all(v == "eval" for v in modes.values())
|
tests/test_reward_engine.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test reward engine — all 7 components. THE MOST CRITICAL TEST FILE."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from ml_training_debugger.models import EpisodeState, MLTrainingAction
|
| 8 |
+
from ml_training_debugger.reward_engine import (
|
| 9 |
+
CONTEXT_GATED_PENALTY,
|
| 10 |
+
CORRECT_DIAGNOSIS_REWARD,
|
| 11 |
+
INVALID_ACTION_PENALTY,
|
| 12 |
+
INVESTIGATION_BONUS,
|
| 13 |
+
STEP_PENALTY,
|
| 14 |
+
TERMINAL_CONVERGENCE_REWARD,
|
| 15 |
+
WRONG_CODE_FIX_PENALTY,
|
| 16 |
+
WRONG_DIAGNOSIS_PENALTY,
|
| 17 |
+
compute_reward,
|
| 18 |
+
)
|
| 19 |
+
from ml_training_debugger.scenarios import sample_scenario
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@pytest.fixture
|
| 23 |
+
def scenario():
|
| 24 |
+
return sample_scenario("task_001", seed=42)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@pytest.fixture
|
| 28 |
+
def scenario_005():
|
| 29 |
+
return sample_scenario("task_005", seed=42)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TestStepPenalty:
|
| 33 |
+
def test_flat_step_penalty(self, scenario):
|
| 34 |
+
state = EpisodeState()
|
| 35 |
+
action = MLTrainingAction(action_type="add_callback")
|
| 36 |
+
reward = compute_reward(action, state, scenario)
|
| 37 |
+
assert reward == pytest.approx(STEP_PENALTY)
|
| 38 |
+
|
| 39 |
+
def test_step_penalty_not_multiplied_by_step_count(self, scenario):
|
| 40 |
+
state = EpisodeState(step_count=30)
|
| 41 |
+
action = MLTrainingAction(action_type="add_callback")
|
| 42 |
+
reward = compute_reward(action, state, scenario)
|
| 43 |
+
# Must be flat -0.01, NOT -0.01 * 30
|
| 44 |
+
assert reward == pytest.approx(-0.01)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class TestInvestigationBonus:
|
| 48 |
+
def test_first_time_bonus(self, scenario):
|
| 49 |
+
state = EpisodeState(gradients_inspected=False)
|
| 50 |
+
action = MLTrainingAction(action_type="inspect_gradients")
|
| 51 |
+
reward = compute_reward(action, state, scenario)
|
| 52 |
+
assert reward == pytest.approx(STEP_PENALTY + INVESTIGATION_BONUS)
|
| 53 |
+
|
| 54 |
+
def test_no_bonus_on_repeat(self, scenario):
|
| 55 |
+
state = EpisodeState(gradients_inspected=True)
|
| 56 |
+
action = MLTrainingAction(action_type="inspect_gradients")
|
| 57 |
+
reward = compute_reward(action, state, scenario)
|
| 58 |
+
assert reward == pytest.approx(STEP_PENALTY)
|
| 59 |
+
|
| 60 |
+
def test_each_inspection_type_gives_bonus(self, scenario):
|
| 61 |
+
for action_type, field in [
|
| 62 |
+
("inspect_gradients", "gradients_inspected"),
|
| 63 |
+
("inspect_data_batch", "data_inspected"),
|
| 64 |
+
("inspect_model_modes", "model_modes_inspected"),
|
| 65 |
+
("inspect_model_weights", "model_weights_inspected"),
|
| 66 |
+
("inspect_code", "code_inspected"),
|
| 67 |
+
]:
|
| 68 |
+
state = EpisodeState(**{field: False})
|
| 69 |
+
action = MLTrainingAction(action_type=action_type)
|
| 70 |
+
reward = compute_reward(action, state, scenario)
|
| 71 |
+
assert reward == pytest.approx(
|
| 72 |
+
STEP_PENALTY + INVESTIGATION_BONUS
|
| 73 |
+
), f"Failed for {action_type}"
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class TestContextGatedPenalty:
|
| 77 |
+
"""The project's primary innovation — must be exact."""
|
| 78 |
+
|
| 79 |
+
def test_no_penalty_before_inspection(self, scenario_005):
|
| 80 |
+
"""add_callback at step 1 (no prior inspection) -> NO penalty."""
|
| 81 |
+
state = EpisodeState()
|
| 82 |
+
action = MLTrainingAction(action_type="add_callback")
|
| 83 |
+
reward = compute_reward(action, state, scenario_005)
|
| 84 |
+
assert reward == pytest.approx(STEP_PENALTY)
|
| 85 |
+
|
| 86 |
+
def test_penalty_after_normal_gradients(self, scenario_005):
|
| 87 |
+
"""inspect_gradients (normal) then add_callback -> -0.20 penalty."""
|
| 88 |
+
state = EpisodeState(gradients_inspected=True, gradients_were_normal=True)
|
| 89 |
+
action = MLTrainingAction(action_type="add_callback")
|
| 90 |
+
reward = compute_reward(action, state, scenario_005)
|
| 91 |
+
assert reward == pytest.approx(STEP_PENALTY + CONTEXT_GATED_PENALTY)
|
| 92 |
+
|
| 93 |
+
def test_no_penalty_after_abnormal_gradients(self, scenario):
|
| 94 |
+
"""inspect_gradients (exploding) then add_callback -> no context penalty."""
|
| 95 |
+
state = EpisodeState(gradients_inspected=True, gradients_were_normal=False)
|
| 96 |
+
action = MLTrainingAction(action_type="add_callback")
|
| 97 |
+
reward = compute_reward(action, state, scenario)
|
| 98 |
+
assert reward == pytest.approx(STEP_PENALTY)
|
| 99 |
+
|
| 100 |
+
def test_penalty_only_for_add_callback(self, scenario_005):
|
| 101 |
+
"""Other fix actions don't trigger context-gated penalty."""
|
| 102 |
+
state = EpisodeState(gradients_inspected=True, gradients_were_normal=True)
|
| 103 |
+
for action_type in ["modify_config", "fix_model_mode", "patch_data_loader"]:
|
| 104 |
+
action = MLTrainingAction(
|
| 105 |
+
action_type=action_type, target="learning_rate", value=0.001
|
| 106 |
+
)
|
| 107 |
+
reward = compute_reward(action, state, scenario_005)
|
| 108 |
+
assert reward == pytest.approx(
|
| 109 |
+
STEP_PENALTY
|
| 110 |
+
), f"Unexpected penalty for {action_type}"
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class TestDiagnosisReward:
|
| 114 |
+
def test_correct_diagnosis(self, scenario):
|
| 115 |
+
state = EpisodeState()
|
| 116 |
+
action = MLTrainingAction(action_type="mark_diagnosed", diagnosis="lr_too_high")
|
| 117 |
+
reward = compute_reward(action, state, scenario)
|
| 118 |
+
assert reward == pytest.approx(STEP_PENALTY + CORRECT_DIAGNOSIS_REWARD)
|
| 119 |
+
|
| 120 |
+
def test_wrong_diagnosis(self, scenario):
|
| 121 |
+
state = EpisodeState()
|
| 122 |
+
action = MLTrainingAction(
|
| 123 |
+
action_type="mark_diagnosed", diagnosis="data_leakage"
|
| 124 |
+
)
|
| 125 |
+
reward = compute_reward(action, state, scenario)
|
| 126 |
+
assert reward == pytest.approx(STEP_PENALTY + WRONG_DIAGNOSIS_PENALTY)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class TestTerminalConvergence:
|
| 130 |
+
def test_convergence_after_fix_and_restart(self, scenario):
|
| 131 |
+
state = EpisodeState(fix_action_taken=True)
|
| 132 |
+
action = MLTrainingAction(action_type="restart_run")
|
| 133 |
+
reward = compute_reward(action, state, scenario, convergence_confirmed=True)
|
| 134 |
+
assert reward == pytest.approx(STEP_PENALTY + TERMINAL_CONVERGENCE_REWARD)
|
| 135 |
+
|
| 136 |
+
def test_no_convergence_without_fix(self, scenario):
|
| 137 |
+
state = EpisodeState(fix_action_taken=False)
|
| 138 |
+
action = MLTrainingAction(action_type="restart_run")
|
| 139 |
+
reward = compute_reward(action, state, scenario, convergence_confirmed=True)
|
| 140 |
+
# fix_action_taken is False, so no convergence reward
|
| 141 |
+
assert reward == pytest.approx(STEP_PENALTY)
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class TestInvalidAction:
|
| 145 |
+
def test_invalid_action_penalty(self, scenario):
|
| 146 |
+
state = EpisodeState()
|
| 147 |
+
action = MLTrainingAction(action_type="restart_run")
|
| 148 |
+
reward = compute_reward(action, state, scenario, is_valid_action=False)
|
| 149 |
+
assert reward == pytest.approx(STEP_PENALTY + INVALID_ACTION_PENALTY)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class TestWrongCodeFix:
|
| 153 |
+
def test_wrong_code_fix_penalty(self, scenario):
|
| 154 |
+
state = EpisodeState(code_inspected=True)
|
| 155 |
+
action = MLTrainingAction(action_type="fix_code", line=1, replacement="pass")
|
| 156 |
+
reward = compute_reward(action, state, scenario, is_correct_fix=False)
|
| 157 |
+
assert reward == pytest.approx(STEP_PENALTY + WRONG_CODE_FIX_PENALTY)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class TestRewardCap:
|
| 161 |
+
def test_reward_capped_at_one(self, scenario):
|
| 162 |
+
# Theoretical max would exceed 1.0 in some scenarios
|
| 163 |
+
reward = compute_reward(
|
| 164 |
+
MLTrainingAction(action_type="mark_diagnosed", diagnosis="lr_too_high"),
|
| 165 |
+
EpisodeState(),
|
| 166 |
+
scenario,
|
| 167 |
+
)
|
| 168 |
+
assert reward <= 1.0
|
| 169 |
+
|
| 170 |
+
def test_reward_capped_at_negative_one(self, scenario):
|
| 171 |
+
reward = compute_reward(
|
| 172 |
+
MLTrainingAction(action_type="mark_diagnosed", diagnosis="wrong"),
|
| 173 |
+
EpisodeState(),
|
| 174 |
+
scenario,
|
| 175 |
+
)
|
| 176 |
+
assert reward >= -1.0
|
tests/test_scenarios.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test scenario sampling."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import pytest
|
| 6 |
+
|
| 7 |
+
from ml_training_debugger.models import RootCauseDiagnosis
|
| 8 |
+
from ml_training_debugger.scenarios import sample_scenario
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TestSampleScenario:
|
| 12 |
+
def test_task_001_root_cause(self):
|
| 13 |
+
s = sample_scenario("task_001", seed=42)
|
| 14 |
+
assert s.root_cause == RootCauseDiagnosis.LR_TOO_HIGH
|
| 15 |
+
assert s.learning_rate >= 0.05
|
| 16 |
+
|
| 17 |
+
def test_task_003_root_cause(self):
|
| 18 |
+
s = sample_scenario("task_003", seed=42)
|
| 19 |
+
assert s.root_cause == RootCauseDiagnosis.DATA_LEAKAGE
|
| 20 |
+
assert 0.10 <= s.leakage_pct <= 0.30
|
| 21 |
+
|
| 22 |
+
def test_task_005_root_cause(self):
|
| 23 |
+
s = sample_scenario("task_005", seed=42)
|
| 24 |
+
assert s.root_cause == RootCauseDiagnosis.BATCHNORM_EVAL_MODE
|
| 25 |
+
assert 0.8 <= s.red_herring_intensity <= 2.5
|
| 26 |
+
|
| 27 |
+
def test_different_seeds_produce_different_params(self):
|
| 28 |
+
s1 = sample_scenario("task_001", seed=42)
|
| 29 |
+
s2 = sample_scenario("task_001", seed=99)
|
| 30 |
+
# Same root cause, but may have different LR
|
| 31 |
+
assert s1.root_cause == s2.root_cause
|
| 32 |
+
|
| 33 |
+
def test_same_seed_same_params(self):
|
| 34 |
+
s1 = sample_scenario("task_001", seed=42)
|
| 35 |
+
s2 = sample_scenario("task_001", seed=42)
|
| 36 |
+
assert s1.learning_rate == s2.learning_rate
|
| 37 |
+
assert s1.seed == s2.seed
|
| 38 |
+
|
| 39 |
+
def test_unknown_task_raises(self):
|
| 40 |
+
with pytest.raises(ValueError, match="Unknown task_id"):
|
| 41 |
+
sample_scenario("task_999", seed=42)
|
| 42 |
+
|
| 43 |
+
def test_task_005_has_error_log(self):
|
| 44 |
+
s = sample_scenario("task_005", seed=42)
|
| 45 |
+
assert s.error_log is not None
|
| 46 |
+
assert "GPU memory" in s.error_log
|
| 47 |
+
|
| 48 |
+
def test_task_003_has_notes(self):
|
| 49 |
+
s = sample_scenario("task_003", seed=42)
|
| 50 |
+
assert s.notes is not None
|
| 51 |
+
assert "architecture" in s.notes.lower()
|
tests/test_simulation.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Test parametric curve generators."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from ml_training_debugger.scenarios import sample_scenario
|
| 6 |
+
from ml_training_debugger.simulation import (
|
| 7 |
+
gen_data_batch_stats,
|
| 8 |
+
gen_loss_history,
|
| 9 |
+
gen_val_accuracy_history,
|
| 10 |
+
gen_val_loss_history,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TestGenLossHistory:
|
| 15 |
+
def test_returns_20_floats(self):
|
| 16 |
+
s = sample_scenario("task_001", seed=42)
|
| 17 |
+
hist = gen_loss_history(s)
|
| 18 |
+
assert len(hist) == 20
|
| 19 |
+
assert all(isinstance(v, float) for v in hist)
|
| 20 |
+
|
| 21 |
+
def test_task_001_diverges(self):
|
| 22 |
+
s = sample_scenario("task_001", seed=42)
|
| 23 |
+
hist = gen_loss_history(s)
|
| 24 |
+
assert hist[-1] == float("inf") # NaN/inf after epoch 12
|
| 25 |
+
|
| 26 |
+
def test_task_003_normal(self):
|
| 27 |
+
s = sample_scenario("task_003", seed=42)
|
| 28 |
+
hist = gen_loss_history(s)
|
| 29 |
+
assert hist[0] > hist[-1] # Loss decreases
|
| 30 |
+
|
| 31 |
+
def test_task_005_higher_variance(self):
|
| 32 |
+
s = sample_scenario("task_005", seed=42)
|
| 33 |
+
hist = gen_loss_history(s)
|
| 34 |
+
assert len(hist) == 20
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class TestGenValAccuracy:
|
| 38 |
+
def test_returns_20_floats(self):
|
| 39 |
+
s = sample_scenario("task_001", seed=42)
|
| 40 |
+
hist = gen_val_accuracy_history(s)
|
| 41 |
+
assert len(hist) == 20
|
| 42 |
+
assert all(isinstance(v, float) for v in hist)
|
| 43 |
+
|
| 44 |
+
def test_task_003_suspiciously_high(self):
|
| 45 |
+
s = sample_scenario("task_003", seed=42)
|
| 46 |
+
hist = gen_val_accuracy_history(s)
|
| 47 |
+
assert hist[1] > 0.80 # Suspiciously high from early epochs
|
| 48 |
+
|
| 49 |
+
def test_task_005_degrades(self):
|
| 50 |
+
s = sample_scenario("task_005", seed=42)
|
| 51 |
+
hist = gen_val_accuracy_history(s)
|
| 52 |
+
assert hist[0] > hist[-1] # Degrades over time
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class TestGenValLoss:
|
| 56 |
+
def test_returns_20_floats(self):
|
| 57 |
+
s = sample_scenario("task_001", seed=42)
|
| 58 |
+
hist = gen_val_loss_history(s)
|
| 59 |
+
assert len(hist) == 20
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class TestGenDataBatchStats:
|
| 63 |
+
def test_leakage_high_overlap(self):
|
| 64 |
+
s = sample_scenario("task_003", seed=42)
|
| 65 |
+
stats = gen_data_batch_stats(s)
|
| 66 |
+
assert stats["class_overlap_score"] > 0.5
|
| 67 |
+
assert stats["duplicate_ratio"] > 0.0
|
| 68 |
+
|
| 69 |
+
def test_normal_low_overlap(self):
|
| 70 |
+
s = sample_scenario("task_001", seed=42)
|
| 71 |
+
stats = gen_data_batch_stats(s)
|
| 72 |
+
assert stats["class_overlap_score"] < 0.3
|
tests/test_simulation_extended.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Extended simulation tests for coverage gaps."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from ml_training_debugger.scenarios import sample_scenario
|
| 6 |
+
from ml_training_debugger.simulation import (
|
| 7 |
+
gen_data_batch_stats,
|
| 8 |
+
gen_loss_history,
|
| 9 |
+
gen_val_accuracy_history,
|
| 10 |
+
gen_val_loss_history,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TestVanishingGradients:
|
| 15 |
+
def test_loss_barely_decreases(self):
|
| 16 |
+
s = sample_scenario("task_002", seed=42)
|
| 17 |
+
hist = gen_loss_history(s)
|
| 18 |
+
assert len(hist) == 20
|
| 19 |
+
assert abs(hist[0] - hist[-1]) < 0.5
|
| 20 |
+
|
| 21 |
+
def test_val_acc_near_random(self):
|
| 22 |
+
s = sample_scenario("task_002", seed=42)
|
| 23 |
+
hist = gen_val_accuracy_history(s)
|
| 24 |
+
assert all(v < 0.3 for v in hist)
|
| 25 |
+
|
| 26 |
+
def test_val_loss_flat(self):
|
| 27 |
+
s = sample_scenario("task_002", seed=42)
|
| 28 |
+
hist = gen_val_loss_history(s)
|
| 29 |
+
assert len(hist) == 20
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class TestOverfitting:
|
| 33 |
+
def test_loss_decreases_to_near_zero(self):
|
| 34 |
+
s = sample_scenario("task_004", seed=42)
|
| 35 |
+
hist = gen_loss_history(s)
|
| 36 |
+
assert hist[-1] < 0.5
|
| 37 |
+
|
| 38 |
+
def test_val_acc_diverges(self):
|
| 39 |
+
s = sample_scenario("task_004", seed=42)
|
| 40 |
+
hist = gen_val_accuracy_history(s)
|
| 41 |
+
# Should rise then fall
|
| 42 |
+
mid = hist[len(hist) // 2]
|
| 43 |
+
assert mid > hist[-1] or mid > 0.3
|
| 44 |
+
|
| 45 |
+
def test_val_loss_diverges(self):
|
| 46 |
+
s = sample_scenario("task_004", seed=42)
|
| 47 |
+
hist = gen_val_loss_history(s)
|
| 48 |
+
assert len(hist) == 20
|
| 49 |
+
# Overfitting: val loss should increase in the latter half
|
| 50 |
+
mid_val = hist[s.divergence_epoch] if s.divergence_epoch < 20 else hist[10]
|
| 51 |
+
assert mid_val > 0 # Val loss is positive
|
| 52 |
+
|
| 53 |
+
def test_data_batch_stats_clean(self):
|
| 54 |
+
s = sample_scenario("task_004", seed=42)
|
| 55 |
+
stats = gen_data_batch_stats(s)
|
| 56 |
+
assert stats["class_overlap_score"] == 0.0
|
| 57 |
+
assert stats["duplicate_ratio"] == 0.0
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class TestCodeBug:
|
| 61 |
+
def test_loss_history(self):
|
| 62 |
+
s = sample_scenario("task_006", seed=42)
|
| 63 |
+
hist = gen_loss_history(s)
|
| 64 |
+
assert len(hist) == 20
|
| 65 |
+
|
| 66 |
+
def test_val_acc_poor(self):
|
| 67 |
+
s = sample_scenario("task_006", seed=42)
|
| 68 |
+
hist = gen_val_accuracy_history(s)
|
| 69 |
+
assert len(hist) == 20
|
| 70 |
+
|
| 71 |
+
def test_val_loss(self):
|
| 72 |
+
s = sample_scenario("task_006", seed=42)
|
| 73 |
+
hist = gen_val_loss_history(s)
|
| 74 |
+
assert len(hist) == 20
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class TestBatchNormEval:
|
| 78 |
+
def test_val_loss_increases(self):
|
| 79 |
+
s = sample_scenario("task_005", seed=42)
|
| 80 |
+
hist = gen_val_loss_history(s)
|
| 81 |
+
assert len(hist) == 20
|