UjjwalPardeshi commited on
Commit ·
eeb6913
1
Parent(s): f4c428c
fix: dashboard, debug logs
Browse files- .coverage +0 -0
- PROJECT_GUIDE.md +691 -0
- server/dashboard.html +166 -15
- server/environment.py +7 -3
- tests/test_episode_lifecycle.py +25 -0
.coverage
CHANGED
|
Binary files a/.coverage and b/.coverage differ
|
|
|
PROJECT_GUIDE.md
ADDED
|
@@ -0,0 +1,691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PyTorch Training Run Debugger — Complete Project Guide
|
| 2 |
+
|
| 3 |
+
## What Is This?
|
| 4 |
+
|
| 5 |
+
A game where an AI agent plays detective to fix broken PyTorch training runs. The agent sees a failing training run, investigates clues (gradients, data, code), applies a fix, and submits a diagnosis. Built as an [OpenEnv](https://github.com/openenv) RL environment for the **Meta PyTorch OpenEnv Hackathon**.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## How a Game Works
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
1. Agent receives a broken training run (loss curves, config, error log)
|
| 13 |
+
2. Agent investigates (inspect gradients, data, weights, model modes, code)
|
| 14 |
+
3. Agent applies a fix (reduce LR, patch data, fix code, etc.)
|
| 15 |
+
4. Agent restarts training and confirms recovery
|
| 16 |
+
5. Agent submits diagnosis ("the problem was lr_too_high")
|
| 17 |
+
6. Grader scores the agent 0.0 to 1.0
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
---
|
| 21 |
+
|
| 22 |
+
## The 7 Tasks
|
| 23 |
+
|
| 24 |
+
| Task | Problem | Difficulty | Root Cause | Key Clue |
|
| 25 |
+
|------|---------|-----------|------------|----------|
|
| 26 |
+
| `task_001` | Gradients explode | Easy | `lr_too_high` | All layers `is_exploding: true` |
|
| 27 |
+
| `task_002` | Gradients vanish | Easy | `vanishing_gradients` | Deep layers `is_vanishing: true` |
|
| 28 |
+
| `task_003` | Test data leaked into training | Medium | `data_leakage` | `class_overlap_score > 0.5` |
|
| 29 |
+
| `task_004` | Model memorizes, doesn't learn | Medium | `overfitting` | Train loss drops, val loss rises |
|
| 30 |
+
| `task_005` | BatchNorm stuck in eval mode | Hard | `batchnorm_eval_mode` | Model modes show "eval" + red herrings |
|
| 31 |
+
| `task_006` | Bug in Python training code | Hard | `code_bug` | Bug visible in code snippet |
|
| 32 |
+
| `task_007` | LR scheduler decays too fast | Medium-Hard | `scheduler_misconfigured` | Early progress then stagnation |
|
| 33 |
+
|
| 34 |
+
---
|
| 35 |
+
|
| 36 |
+
## Reward System
|
| 37 |
+
|
| 38 |
+
Every action earns or costs points (capped at -1.0 to 1.0):
|
| 39 |
+
|
| 40 |
+
| Event | Reward | When |
|
| 41 |
+
|-------|--------|------|
|
| 42 |
+
| Any step taken | **-0.01** | Always (encourages efficiency) |
|
| 43 |
+
| First-time inspection | **+0.05** | Once per inspection type |
|
| 44 |
+
| Correct diagnosis | **+0.50** | Diagnosis matches root cause |
|
| 45 |
+
| Wrong diagnosis | **-0.30** | Diagnosis doesn't match |
|
| 46 |
+
| Fix works + training recovers | **+0.40** | After fix + restart + convergence |
|
| 47 |
+
| Invalid action | **-0.05** | Action not available |
|
| 48 |
+
| Wrong code fix | **-0.10** | `fix_code` with wrong line/replacement |
|
| 49 |
+
| **Context-gated penalty** | **-0.20** | Inspected gradients, saw they're normal, then added gradient clipping anyway |
|
| 50 |
+
|
| 51 |
+
### The Context-Gated Penalty (Core Innovation)
|
| 52 |
+
|
| 53 |
+
- Agent checks gradients -> finds them **normal** -> adds gradient clipping = **-0.20 penalty** (ignoring evidence)
|
| 54 |
+
- Agent adds gradient clipping **before** checking gradients = **no penalty** (reasonable prior)
|
| 55 |
+
|
| 56 |
+
This teaches: *don't ignore what you've already learned*.
|
| 57 |
+
|
| 58 |
+
---
|
| 59 |
+
|
| 60 |
+
## Architecture
|
| 61 |
+
|
| 62 |
+
```
|
| 63 |
+
ml_training_debugger/ # Core logic
|
| 64 |
+
models.py # All data types (Pydantic)
|
| 65 |
+
scenarios.py # Creates the 7 tasks with random params
|
| 66 |
+
pytorch_engine.py # Real PyTorch model + fault injection
|
| 67 |
+
simulation.py # Loss/accuracy curve generation
|
| 68 |
+
reward_engine.py # Per-step reward calculation
|
| 69 |
+
graders.py # Final 0.0-1.0 scoring per task
|
| 70 |
+
code_templates.py # Buggy code for Task 6
|
| 71 |
+
client.py # Client for connecting to the environment
|
| 72 |
+
|
| 73 |
+
server/ # Web server
|
| 74 |
+
app.py # FastAPI + all endpoints
|
| 75 |
+
environment.py # Game logic (reset, step, state)
|
| 76 |
+
|
| 77 |
+
tests/ # 183 tests, 97% coverage
|
| 78 |
+
baseline_heuristic.py # Rule-based agent (deterministic)
|
| 79 |
+
baseline_inference.py # LLM agent (Llama/GPT-4o)
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
---
|
| 83 |
+
|
| 84 |
+
## API Endpoints
|
| 85 |
+
|
| 86 |
+
### GET /health
|
| 87 |
+
|
| 88 |
+
Server status check.
|
| 89 |
+
|
| 90 |
+
**Response:**
|
| 91 |
+
```json
|
| 92 |
+
{
|
| 93 |
+
"status": "ready",
|
| 94 |
+
"tasks": 7
|
| 95 |
+
}
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
---
|
| 99 |
+
|
| 100 |
+
### GET /tasks
|
| 101 |
+
|
| 102 |
+
List all available tasks with action schema.
|
| 103 |
+
|
| 104 |
+
**Response:**
|
| 105 |
+
```json
|
| 106 |
+
[
|
| 107 |
+
{
|
| 108 |
+
"id": "task_001",
|
| 109 |
+
"difficulty": "easy",
|
| 110 |
+
"max_steps": 20,
|
| 111 |
+
"action_schema": {
|
| 112 |
+
"title": "MLTrainingAction",
|
| 113 |
+
"type": "object",
|
| 114 |
+
"properties": {
|
| 115 |
+
"action_type": { "type": "string" },
|
| 116 |
+
"target": { "type": ["string", "null"] },
|
| 117 |
+
"value": { "type": ["number", "integer", "string", "null"] },
|
| 118 |
+
"diagnosis": { "type": ["string", "null"] },
|
| 119 |
+
"line": { "type": ["integer", "null"] },
|
| 120 |
+
"replacement": { "type": ["string", "null"] }
|
| 121 |
+
},
|
| 122 |
+
"required": ["action_type"]
|
| 123 |
+
}
|
| 124 |
+
}
|
| 125 |
+
]
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
---
|
| 129 |
+
|
| 130 |
+
### POST /baseline
|
| 131 |
+
|
| 132 |
+
Run the heuristic baseline agent on all 7 tasks.
|
| 133 |
+
|
| 134 |
+
**Response:**
|
| 135 |
+
```json
|
| 136 |
+
{
|
| 137 |
+
"scores": {
|
| 138 |
+
"task_001": 1.00,
|
| 139 |
+
"task_002": 1.00,
|
| 140 |
+
"task_003": 1.00,
|
| 141 |
+
"task_004": 0.45,
|
| 142 |
+
"task_005": 0.35,
|
| 143 |
+
"task_006": 1.00,
|
| 144 |
+
"task_007": 1.00
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
```
|
| 148 |
+
|
| 149 |
+
Returns `409` if baseline is already running.
|
| 150 |
+
|
| 151 |
+
---
|
| 152 |
+
|
| 153 |
+
### POST /grader
|
| 154 |
+
|
| 155 |
+
Get the grader score for the last completed episode.
|
| 156 |
+
|
| 157 |
+
**Query params:** `session_id` (optional)
|
| 158 |
+
|
| 159 |
+
**Response:**
|
| 160 |
+
```json
|
| 161 |
+
{
|
| 162 |
+
"score": 0.85,
|
| 163 |
+
"task_id": "task_001",
|
| 164 |
+
"steps": 5
|
| 165 |
+
}
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
If no episode completed:
|
| 169 |
+
```json
|
| 170 |
+
{
|
| 171 |
+
"score": null,
|
| 172 |
+
"error": "no_completed_episode"
|
| 173 |
+
}
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
---
|
| 177 |
+
|
| 178 |
+
### GET /dashboard
|
| 179 |
+
|
| 180 |
+
Live diagnostic dashboard (HTML page with Plotly.js charts). Open in a browser.
|
| 181 |
+
|
| 182 |
+
**Panels:**
|
| 183 |
+
1. Training metrics (loss/accuracy curves)
|
| 184 |
+
2. Gradient & weight heatmap
|
| 185 |
+
3. Action timeline with rewards
|
| 186 |
+
4. Episode summary with state flags
|
| 187 |
+
|
| 188 |
+
---
|
| 189 |
+
|
| 190 |
+
### GET /validation-report
|
| 191 |
+
|
| 192 |
+
Pre-computed fidelity report comparing parametric curves to real PyTorch training runs.
|
| 193 |
+
|
| 194 |
+
---
|
| 195 |
+
|
| 196 |
+
### GET /curriculum
|
| 197 |
+
|
| 198 |
+
Recommended task order for progressive training (easy to hard, 3 difficulty levels each).
|
| 199 |
+
|
| 200 |
+
**Response:**
|
| 201 |
+
```json
|
| 202 |
+
{
|
| 203 |
+
"curriculum": [
|
| 204 |
+
{ "task_id": "task_001", "difficulty": "easy", "difficulty_level": 1, "max_steps": 20 },
|
| 205 |
+
{ "task_id": "task_001", "difficulty": "easy", "difficulty_level": 3, "max_steps": 20 },
|
| 206 |
+
{ "task_id": "task_001", "difficulty": "easy", "difficulty_level": 5, "max_steps": 20 }
|
| 207 |
+
],
|
| 208 |
+
"total_episodes": 21
|
| 209 |
+
}
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
---
|
| 213 |
+
|
| 214 |
+
### GET /leaderboard
|
| 215 |
+
|
| 216 |
+
Sorted episode scores from baseline runs.
|
| 217 |
+
|
| 218 |
+
**Response:**
|
| 219 |
+
```json
|
| 220 |
+
{
|
| 221 |
+
"entries": [
|
| 222 |
+
{ "score": 1.00, "task_id": "task_001", "steps": 5, "episode_id": "baseline_task_001" }
|
| 223 |
+
],
|
| 224 |
+
"total": 7
|
| 225 |
+
}
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
---
|
| 229 |
+
|
| 230 |
+
### GET /replay/{episode_id}
|
| 231 |
+
|
| 232 |
+
Full action/observation trace for a completed episode.
|
| 233 |
+
|
| 234 |
+
**Response:**
|
| 235 |
+
```json
|
| 236 |
+
{
|
| 237 |
+
"episode_id": "baseline_task_001",
|
| 238 |
+
"score": 1.00,
|
| 239 |
+
"task_id": "task_001",
|
| 240 |
+
"steps": 5
|
| 241 |
+
}
|
| 242 |
+
```
|
| 243 |
+
|
| 244 |
+
---
|
| 245 |
+
|
| 246 |
+
## WebSocket Interface (Primary Agent Interface)
|
| 247 |
+
|
| 248 |
+
**Endpoint:** `ws://localhost:7860/ws`
|
| 249 |
+
|
| 250 |
+
This is the main way agents interact with the environment. HTTP endpoints are stateless — WebSocket maintains session state across a full episode.
|
| 251 |
+
|
| 252 |
+
### Reset (Start New Episode)
|
| 253 |
+
|
| 254 |
+
**Send:**
|
| 255 |
+
```json
|
| 256 |
+
{
|
| 257 |
+
"type": "reset",
|
| 258 |
+
"seed": 42,
|
| 259 |
+
"kwargs": {
|
| 260 |
+
"task_id": "task_003",
|
| 261 |
+
"difficulty_level": 3
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
```
|
| 265 |
+
|
| 266 |
+
Without `kwargs`, defaults to `task_001`.
|
| 267 |
+
|
| 268 |
+
**Receive:**
|
| 269 |
+
```json
|
| 270 |
+
{
|
| 271 |
+
"type": "observation",
|
| 272 |
+
"observation": {
|
| 273 |
+
"run_id": "ep_12345",
|
| 274 |
+
"framework": "pytorch",
|
| 275 |
+
"epoch": 20,
|
| 276 |
+
"training_loss_history": [2.3, 2.1, 1.9, ...],
|
| 277 |
+
"val_loss_history": [2.4, 2.2, 2.0, ...],
|
| 278 |
+
"val_accuracy_history": [0.3, 0.35, 0.4, ...],
|
| 279 |
+
"gradient_stats": [],
|
| 280 |
+
"model_weight_stats": null,
|
| 281 |
+
"data_batch_stats": null,
|
| 282 |
+
"model_mode_info": null,
|
| 283 |
+
"code_snippet": null,
|
| 284 |
+
"current_config": {
|
| 285 |
+
"learning_rate": 0.001,
|
| 286 |
+
"weight_decay": 0.0001,
|
| 287 |
+
"batch_size": 64,
|
| 288 |
+
"hidden_dim": 64,
|
| 289 |
+
"num_layers": 3,
|
| 290 |
+
"optimizer": "adam",
|
| 291 |
+
"dropout_rate": 0.0,
|
| 292 |
+
"gradient_clip_norm": null
|
| 293 |
+
},
|
| 294 |
+
"error_log": null,
|
| 295 |
+
"gpu_memory_used_gb": 6.2,
|
| 296 |
+
"gpu_memory_total_gb": 16.0,
|
| 297 |
+
"available_actions": [
|
| 298 |
+
"inspect_gradients",
|
| 299 |
+
"inspect_data_batch",
|
| 300 |
+
"inspect_model_modes",
|
| 301 |
+
"inspect_model_weights",
|
| 302 |
+
"inspect_code",
|
| 303 |
+
"modify_config",
|
| 304 |
+
"add_callback",
|
| 305 |
+
"replace_optimizer",
|
| 306 |
+
"patch_data_loader",
|
| 307 |
+
"fix_model_mode",
|
| 308 |
+
"mark_diagnosed"
|
| 309 |
+
],
|
| 310 |
+
"episode_state": {
|
| 311 |
+
"step_count": 0,
|
| 312 |
+
"gradients_inspected": false,
|
| 313 |
+
"gradients_were_normal": false,
|
| 314 |
+
"data_inspected": false,
|
| 315 |
+
"model_modes_inspected": false,
|
| 316 |
+
"model_weights_inspected": false,
|
| 317 |
+
"code_inspected": false,
|
| 318 |
+
"fix_action_taken": false,
|
| 319 |
+
"restart_after_fix": false,
|
| 320 |
+
"diagnosis_submitted": false,
|
| 321 |
+
"actions_taken": []
|
| 322 |
+
},
|
| 323 |
+
"notes": null,
|
| 324 |
+
"done": false,
|
| 325 |
+
"reward": null,
|
| 326 |
+
"metadata": {}
|
| 327 |
+
}
|
| 328 |
+
}
|
| 329 |
+
```
|
| 330 |
+
|
| 331 |
+
### Step (Take an Action)
|
| 332 |
+
|
| 333 |
+
**Investigation actions** (no extra fields needed):
|
| 334 |
+
```json
|
| 335 |
+
{"type": "step", "action": {"action_type": "inspect_gradients"}}
|
| 336 |
+
{"type": "step", "action": {"action_type": "inspect_data_batch"}}
|
| 337 |
+
{"type": "step", "action": {"action_type": "inspect_model_modes"}}
|
| 338 |
+
{"type": "step", "action": {"action_type": "inspect_model_weights"}}
|
| 339 |
+
{"type": "step", "action": {"action_type": "inspect_code"}}
|
| 340 |
+
```
|
| 341 |
+
|
| 342 |
+
**Fix actions:**
|
| 343 |
+
```json
|
| 344 |
+
{"type": "step", "action": {"action_type": "modify_config", "target": "learning_rate", "value": 0.001}}
|
| 345 |
+
{"type": "step", "action": {"action_type": "add_callback"}}
|
| 346 |
+
{"type": "step", "action": {"action_type": "replace_optimizer"}}
|
| 347 |
+
{"type": "step", "action": {"action_type": "patch_data_loader"}}
|
| 348 |
+
{"type": "step", "action": {"action_type": "fix_model_mode"}}
|
| 349 |
+
{"type": "step", "action": {"action_type": "fix_code", "line": 5, "replacement": "model.train()"}}
|
| 350 |
+
```
|
| 351 |
+
|
| 352 |
+
**Terminal actions:**
|
| 353 |
+
```json
|
| 354 |
+
{"type": "step", "action": {"action_type": "restart_run"}}
|
| 355 |
+
{"type": "step", "action": {"action_type": "mark_diagnosed", "diagnosis": "lr_too_high"}}
|
| 356 |
+
```
|
| 357 |
+
|
| 358 |
+
**Receive (after each step):**
|
| 359 |
+
```json
|
| 360 |
+
{
|
| 361 |
+
"type": "observation",
|
| 362 |
+
"observation": {
|
| 363 |
+
"...same structure as reset response...",
|
| 364 |
+
"gradient_stats": [
|
| 365 |
+
{
|
| 366 |
+
"layer_name": "conv1",
|
| 367 |
+
"norm_history": [0.5, 0.6, 0.7],
|
| 368 |
+
"mean_norm": 51.1,
|
| 369 |
+
"max_norm": 98.3,
|
| 370 |
+
"is_exploding": true,
|
| 371 |
+
"is_vanishing": false
|
| 372 |
+
}
|
| 373 |
+
],
|
| 374 |
+
"episode_state": {
|
| 375 |
+
"step_count": 1,
|
| 376 |
+
"gradients_inspected": true,
|
| 377 |
+
"actions_taken": ["inspect_gradients"]
|
| 378 |
+
},
|
| 379 |
+
"done": false,
|
| 380 |
+
"reward": 0.04
|
| 381 |
+
}
|
| 382 |
+
}
|
| 383 |
+
```
|
| 384 |
+
|
| 385 |
+
When `done: true`, the episode is over.
|
| 386 |
+
|
| 387 |
+
---
|
| 388 |
+
|
| 389 |
+
## All 14 Action Types
|
| 390 |
+
|
| 391 |
+
| Action | Required Fields | Description |
|
| 392 |
+
|--------|----------------|-------------|
|
| 393 |
+
| `inspect_gradients` | none | View per-layer gradient stats |
|
| 394 |
+
| `inspect_data_batch` | none | View data batch statistics |
|
| 395 |
+
| `inspect_model_modes` | none | View train/eval mode per layer |
|
| 396 |
+
| `inspect_model_weights` | none | View per-layer weight stats |
|
| 397 |
+
| `inspect_code` | none | View source code (Task 6) |
|
| 398 |
+
| `modify_config` | `target`, `value` | Change a hyperparameter |
|
| 399 |
+
| `add_callback` | none | Add gradient clipping callback |
|
| 400 |
+
| `replace_optimizer` | none | Switch optimizer |
|
| 401 |
+
| `patch_data_loader` | none | Fix data pipeline |
|
| 402 |
+
| `fix_model_mode` | none | Switch model to train mode |
|
| 403 |
+
| `fix_code` | `line`, `replacement` | Fix a line of code |
|
| 404 |
+
| `restart_run` | none | Restart training (requires fix first) |
|
| 405 |
+
| `mark_diagnosed` | `diagnosis` | Submit final diagnosis |
|
| 406 |
+
| `rollback_checkpoint` | none | Rollback to checkpoint |
|
| 407 |
+
|
| 408 |
+
### Valid `target` values for modify_config
|
| 409 |
+
`learning_rate`, `weight_decay`, `batch_size`, `hidden_dim`, `num_layers`, `optimizer`, `dropout_rate`, `gradient_clip_norm`
|
| 410 |
+
|
| 411 |
+
### Valid `diagnosis` values for mark_diagnosed
|
| 412 |
+
`lr_too_high`, `vanishing_gradients`, `data_leakage`, `overfitting`, `batchnorm_eval_mode`, `code_bug`, `scheduler_misconfigured`
|
| 413 |
+
|
| 414 |
+
---
|
| 415 |
+
|
| 416 |
+
## Dynamic Action Availability
|
| 417 |
+
|
| 418 |
+
Actions appear/disappear based on episode state:
|
| 419 |
+
|
| 420 |
+
| Action | Available When |
|
| 421 |
+
|--------|---------------|
|
| 422 |
+
| `fix_code` | Only after `inspect_code` (code_inspected = true) |
|
| 423 |
+
| `restart_run` | Only after a fix action (fix_action_taken = true) |
|
| 424 |
+
| `rollback_checkpoint` | Only after restart (restart_after_fix = true) |
|
| 425 |
+
| `mark_diagnosed` | Only while diagnosis_submitted = false |
|
| 426 |
+
|
| 427 |
+
---
|
| 428 |
+
|
| 429 |
+
## Observation Fields — Progressive Reveal
|
| 430 |
+
|
| 431 |
+
On reset, the agent sees loss curves, config, and error log. Everything else is `null` until inspected:
|
| 432 |
+
|
| 433 |
+
| Field | Starts As | Populated After |
|
| 434 |
+
|-------|-----------|----------------|
|
| 435 |
+
| `training_loss_history` | 20 floats | Always visible |
|
| 436 |
+
| `val_accuracy_history` | 20 floats | Always visible |
|
| 437 |
+
| `val_loss_history` | 20 floats | Always visible |
|
| 438 |
+
| `current_config` | Full config | Always visible |
|
| 439 |
+
| `error_log` | String or null | Always visible |
|
| 440 |
+
| `gradient_stats` | `[]` | `inspect_gradients` |
|
| 441 |
+
| `model_weight_stats` | `null` | `inspect_model_weights` |
|
| 442 |
+
| `data_batch_stats` | `null` | `inspect_data_batch` |
|
| 443 |
+
| `model_mode_info` | `null` | `inspect_model_modes` |
|
| 444 |
+
| `code_snippet` | `null` | `inspect_code` |
|
| 445 |
+
|
| 446 |
+
---
|
| 447 |
+
|
| 448 |
+
## Data Types
|
| 449 |
+
|
| 450 |
+
### GradientStats (per layer)
|
| 451 |
+
```json
|
| 452 |
+
{
|
| 453 |
+
"layer_name": "conv1",
|
| 454 |
+
"norm_history": [0.5, 0.6, 0.7],
|
| 455 |
+
"mean_norm": 12.5,
|
| 456 |
+
"max_norm": 25.3,
|
| 457 |
+
"is_exploding": true,
|
| 458 |
+
"is_vanishing": false
|
| 459 |
+
}
|
| 460 |
+
```
|
| 461 |
+
- Exploding: `mean_norm > 10.0`
|
| 462 |
+
- Vanishing: `mean_norm < 0.000001`
|
| 463 |
+
|
| 464 |
+
### ModelWeightStats (per layer)
|
| 465 |
+
```json
|
| 466 |
+
{
|
| 467 |
+
"layer_name": "conv1",
|
| 468 |
+
"weight_norm": 1.234,
|
| 469 |
+
"weight_mean": 0.001,
|
| 470 |
+
"weight_std": 0.05,
|
| 471 |
+
"weight_min": -0.15,
|
| 472 |
+
"weight_max": 0.16,
|
| 473 |
+
"dead_neuron_pct": 0.0,
|
| 474 |
+
"has_nan": false,
|
| 475 |
+
"has_inf": false
|
| 476 |
+
}
|
| 477 |
+
```
|
| 478 |
+
|
| 479 |
+
### DataBatchStats
|
| 480 |
+
```json
|
| 481 |
+
{
|
| 482 |
+
"label_distribution": {"0": 0.25, "1": 0.25, "2": 0.25, "3": 0.25},
|
| 483 |
+
"feature_mean": 0.5,
|
| 484 |
+
"feature_std": 0.2,
|
| 485 |
+
"null_count": 0,
|
| 486 |
+
"class_overlap_score": 0.15,
|
| 487 |
+
"batch_size": 64,
|
| 488 |
+
"duplicate_ratio": 0.0,
|
| 489 |
+
"confusion_matrix": [[10, 2, 1], [1, 9, 3], [2, 1, 11]]
|
| 490 |
+
}
|
| 491 |
+
```
|
| 492 |
+
|
| 493 |
+
### CodeSnippet (Task 6 only)
|
| 494 |
+
```json
|
| 495 |
+
{
|
| 496 |
+
"code": "import torch\nimport torch.nn as nn\n...",
|
| 497 |
+
"filename": "train.py",
|
| 498 |
+
"line_count": 50,
|
| 499 |
+
"imports": ["torch", "torch.nn", "torch.optim"],
|
| 500 |
+
"hint": "Look for .detach() preventing gradient flow"
|
| 501 |
+
}
|
| 502 |
+
```
|
| 503 |
+
|
| 504 |
+
### EpisodeState
|
| 505 |
+
```json
|
| 506 |
+
{
|
| 507 |
+
"step_count": 0,
|
| 508 |
+
"gradients_inspected": false,
|
| 509 |
+
"gradients_were_normal": false,
|
| 510 |
+
"data_inspected": false,
|
| 511 |
+
"model_modes_inspected": false,
|
| 512 |
+
"model_weights_inspected": false,
|
| 513 |
+
"code_inspected": false,
|
| 514 |
+
"fix_action_taken": false,
|
| 515 |
+
"restart_after_fix": false,
|
| 516 |
+
"diagnosis_submitted": false,
|
| 517 |
+
"actions_taken": []
|
| 518 |
+
}
|
| 519 |
+
```
|
| 520 |
+
|
| 521 |
+
---
|
| 522 |
+
|
| 523 |
+
## Grading Breakdown (per task)
|
| 524 |
+
|
| 525 |
+
Each task has its own grader that scores 0.0 to 1.0 based on what the agent did:
|
| 526 |
+
|
| 527 |
+
### Task 1 — Exploding Gradients
|
| 528 |
+
| Component | Points |
|
| 529 |
+
|-----------|--------|
|
| 530 |
+
| Inspected gradients | +0.05 |
|
| 531 |
+
| Applied config fix | +0.20 |
|
| 532 |
+
| Restarted training | +0.35 |
|
| 533 |
+
| Correct diagnosis (`lr_too_high`) | +0.40 |
|
| 534 |
+
|
| 535 |
+
### Task 2 — Vanishing Gradients
|
| 536 |
+
| Component | Points |
|
| 537 |
+
|-----------|--------|
|
| 538 |
+
| Inspected gradients | +0.05 |
|
| 539 |
+
| Applied config fix | +0.20 |
|
| 540 |
+
| Restarted training | +0.35 |
|
| 541 |
+
| Correct diagnosis (`vanishing_gradients`) | +0.40 |
|
| 542 |
+
|
| 543 |
+
### Task 3 — Data Leakage
|
| 544 |
+
| Component | Points |
|
| 545 |
+
|-----------|--------|
|
| 546 |
+
| Inspected data | +0.05 |
|
| 547 |
+
| Patched data loader | +0.30 |
|
| 548 |
+
| Restarted training | +0.30 |
|
| 549 |
+
| Correct diagnosis (`data_leakage`) | +0.35 |
|
| 550 |
+
|
| 551 |
+
### Task 4 — Overfitting
|
| 552 |
+
| Component | Points |
|
| 553 |
+
|-----------|--------|
|
| 554 |
+
| Inspected data | +0.05 |
|
| 555 |
+
| Applied fix (config or callback) | +0.25 |
|
| 556 |
+
| Restarted training | +0.30 |
|
| 557 |
+
| Correct diagnosis (`overfitting`) | +0.40 |
|
| 558 |
+
|
| 559 |
+
### Task 5 — BatchNorm Eval Mode (with red herrings)
|
| 560 |
+
| Component | Points |
|
| 561 |
+
|-----------|--------|
|
| 562 |
+
| Inspected gradients | +0.05 |
|
| 563 |
+
| Inspected model modes | +0.05 |
|
| 564 |
+
| **Fell for red herring** (add_callback after normal gradients) | **-0.20** |
|
| 565 |
+
| Fixed model mode | +0.25 |
|
| 566 |
+
| Restarted training | +0.30 |
|
| 567 |
+
| Correct diagnosis (`batchnorm_eval_mode`) | +0.40 |
|
| 568 |
+
|
| 569 |
+
### Task 6 — Code Bug
|
| 570 |
+
| Component | Points |
|
| 571 |
+
|-----------|--------|
|
| 572 |
+
| Inspected code | +0.05 |
|
| 573 |
+
| Fixed code correctly | +0.30 |
|
| 574 |
+
| Restarted training | +0.25 |
|
| 575 |
+
| Correct diagnosis (`code_bug`) | +0.40 |
|
| 576 |
+
|
| 577 |
+
### Task 7 — Scheduler Misconfigured
|
| 578 |
+
| Component | Points |
|
| 579 |
+
|-----------|--------|
|
| 580 |
+
| Inspected gradients | +0.05 |
|
| 581 |
+
| Inspected data | +0.05 |
|
| 582 |
+
| Applied config fix | +0.25 |
|
| 583 |
+
| Restarted training | +0.25 |
|
| 584 |
+
| Correct diagnosis (`scheduler_misconfigured`) | +0.40 |
|
| 585 |
+
|
| 586 |
+
---
|
| 587 |
+
|
| 588 |
+
## Baseline Scores
|
| 589 |
+
|
| 590 |
+
| Task | Heuristic | Llama 3.3 70B | Llama 3.1 8B |
|
| 591 |
+
|------|-----------|---------------|--------------|
|
| 592 |
+
| task_001 | **1.00** | 1.00 | 0.60 |
|
| 593 |
+
| task_002 | **1.00** | 1.00 | 0.05 |
|
| 594 |
+
| task_003 | **1.00** | 0.40 | 0.40 |
|
| 595 |
+
| task_004 | 0.45 | 0.45 | **0.60** |
|
| 596 |
+
| task_005 | **1.00** | 1.00 | 1.00 |
|
| 597 |
+
| task_006 | **1.00** | — | 0.60-1.00 |
|
| 598 |
+
| task_007 | **1.00** | — | 0.60 |
|
| 599 |
+
| **Average** | **0.92** | ~0.69 | 0.55 |
|
| 600 |
+
|
| 601 |
+
---
|
| 602 |
+
|
| 603 |
+
## Walkthrough: Solving Task 1 (Exploding Gradients)
|
| 604 |
+
|
| 605 |
+
```
|
| 606 |
+
Step 1: Reset
|
| 607 |
+
Send: {"type": "reset", "kwargs": {"task_id": "task_001"}}
|
| 608 |
+
See: Loss history going to infinity, error_log says "NaN at epoch 12"
|
| 609 |
+
|
| 610 |
+
Step 2: Inspect gradients
|
| 611 |
+
Send: {"type": "step", "action": {"action_type": "inspect_gradients"}}
|
| 612 |
+
See: All layers is_exploding: true, mean_norm > 10.0
|
| 613 |
+
Reward: +0.04 (-0.01 step + 0.05 investigation)
|
| 614 |
+
|
| 615 |
+
Step 3: Reduce learning rate
|
| 616 |
+
Send: {"type": "step", "action": {"action_type": "modify_config", "target": "learning_rate", "value": 0.001}}
|
| 617 |
+
Reward: -0.01 (step penalty)
|
| 618 |
+
|
| 619 |
+
Step 4: Restart training
|
| 620 |
+
Send: {"type": "step", "action": {"action_type": "restart_run"}}
|
| 621 |
+
See: Convergence detected!
|
| 622 |
+
Reward: +0.39 (-0.01 step + 0.40 convergence)
|
| 623 |
+
|
| 624 |
+
Step 5: Submit diagnosis
|
| 625 |
+
Send: {"type": "step", "action": {"action_type": "mark_diagnosed", "diagnosis": "lr_too_high"}}
|
| 626 |
+
See: done: true
|
| 627 |
+
Reward: +0.49 (-0.01 step + 0.50 correct diagnosis)
|
| 628 |
+
|
| 629 |
+
Grader score: 1.0 (perfect)
|
| 630 |
+
```
|
| 631 |
+
|
| 632 |
+
---
|
| 633 |
+
|
| 634 |
+
## Walkthrough: Task 5 Trap (Red Herring)
|
| 635 |
+
|
| 636 |
+
```
|
| 637 |
+
Step 1: Reset task_005
|
| 638 |
+
Step 2: Inspect gradients
|
| 639 |
+
-> FC layer has a spike (mean_norm=4.2, but is_exploding: false)
|
| 640 |
+
-> gradients_were_normal is set to TRUE (nothing actually exploding)
|
| 641 |
+
|
| 642 |
+
Step 3 (BAD): Add gradient clipping
|
| 643 |
+
-> Reward: -0.21 (-0.01 step - 0.20 context-gated penalty!)
|
| 644 |
+
-> Agent IGNORED the evidence that gradients were normal
|
| 645 |
+
|
| 646 |
+
Step 3 (GOOD): Inspect model modes instead
|
| 647 |
+
-> Sees all layers in "eval" mode — that's the real problem!
|
| 648 |
+
|
| 649 |
+
Step 4: Fix model mode
|
| 650 |
+
Step 5: Restart training
|
| 651 |
+
Step 6: Diagnose batchnorm_eval_mode -> correct!
|
| 652 |
+
```
|
| 653 |
+
|
| 654 |
+
---
|
| 655 |
+
|
| 656 |
+
## Quick Start
|
| 657 |
+
|
| 658 |
+
```bash
|
| 659 |
+
# Setup
|
| 660 |
+
python3 -m venv .venv && source .venv/bin/activate
|
| 661 |
+
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
| 662 |
+
pip install -r requirements.txt
|
| 663 |
+
pip install pytest pytest-cov
|
| 664 |
+
|
| 665 |
+
# Run server
|
| 666 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 667 |
+
|
| 668 |
+
# Test
|
| 669 |
+
pytest tests/ -v --cov=ml_training_debugger
|
| 670 |
+
curl http://localhost:7860/health
|
| 671 |
+
curl http://localhost:7860/tasks | python3 -m json.tool
|
| 672 |
+
curl -X POST http://localhost:7860/baseline | python3 -m json.tool
|
| 673 |
+
|
| 674 |
+
# Docker
|
| 675 |
+
docker build -t pytorch-debugger .
|
| 676 |
+
docker run -p 7860:7860 pytorch-debugger
|
| 677 |
+
```
|
| 678 |
+
|
| 679 |
+
---
|
| 680 |
+
|
| 681 |
+
## Tech Stack
|
| 682 |
+
|
| 683 |
+
| Component | Purpose |
|
| 684 |
+
|-----------|---------|
|
| 685 |
+
| Python 3.12 | Runtime |
|
| 686 |
+
| PyTorch (CPU-only) | Real neural networks, real gradients |
|
| 687 |
+
| FastAPI | Web server |
|
| 688 |
+
| OpenEnv | RL environment framework (step/reset/state API) |
|
| 689 |
+
| Pydantic v2 | Typed data models |
|
| 690 |
+
| Plotly.js | Dashboard charts |
|
| 691 |
+
| Docker | Containerized deployment |
|
server/dashboard.html
CHANGED
|
@@ -17,6 +17,7 @@ body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; b
|
|
| 17 |
.panel { background: #161b22; border: 1px solid #30363d; border-radius: 8px; overflow: hidden; display: flex; flex-direction: column; }
|
| 18 |
.panel-title { padding: 10px 16px; font-size: 14px; font-weight: 600; color: #58a6ff; border-bottom: 1px solid #30363d; background: #0d1117; }
|
| 19 |
.panel-body { flex: 1; padding: 8px; position: relative; min-height: 0; }
|
|
|
|
| 20 |
.placeholder { display: flex; align-items: center; justify-content: center; height: 100%; color: #484f58; font-style: italic; }
|
| 21 |
#controls { display: flex; gap: 8px; align-items: center; }
|
| 22 |
#controls select, #controls button { background: #21262d; color: #c9d1d9; border: 1px solid #30363d; padding: 6px 12px; border-radius: 6px; cursor: pointer; font-size: 13px; }
|
|
@@ -47,6 +48,7 @@ body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; b
|
|
| 47 |
<option value="task_004">Task 4 — Overfitting (Medium)</option>
|
| 48 |
<option value="task_005">Task 5 — BatchNorm Eval (Hard)</option>
|
| 49 |
<option value="task_006">Task 6 — Code Bug (Hard)</option>
|
|
|
|
| 50 |
</select>
|
| 51 |
<button class="primary" onclick="runBaseline()">Run Baseline</button>
|
| 52 |
</div>
|
|
@@ -220,26 +222,175 @@ function updateSummary(d) {
|
|
| 220 |
document.getElementById('summary').innerHTML = html;
|
| 221 |
}
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
async function runBaseline() {
|
| 224 |
const taskId = document.getElementById('taskSelect').value;
|
| 225 |
actions = []; rewards = []; cumRewards = [];
|
| 226 |
-
if (ws
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 242 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
}
|
| 244 |
|
| 245 |
connect();
|
|
|
|
| 17 |
.panel { background: #161b22; border: 1px solid #30363d; border-radius: 8px; overflow: hidden; display: flex; flex-direction: column; }
|
| 18 |
.panel-title { padding: 10px 16px; font-size: 14px; font-weight: 600; color: #58a6ff; border-bottom: 1px solid #30363d; background: #0d1117; }
|
| 19 |
.panel-body { flex: 1; padding: 8px; position: relative; min-height: 0; }
|
| 20 |
+
.panel-body > div:first-child { width: 100%; height: 100%; }
|
| 21 |
.placeholder { display: flex; align-items: center; justify-content: center; height: 100%; color: #484f58; font-style: italic; }
|
| 22 |
#controls { display: flex; gap: 8px; align-items: center; }
|
| 23 |
#controls select, #controls button { background: #21262d; color: #c9d1d9; border: 1px solid #30363d; padding: 6px 12px; border-radius: 6px; cursor: pointer; font-size: 13px; }
|
|
|
|
| 48 |
<option value="task_004">Task 4 — Overfitting (Medium)</option>
|
| 49 |
<option value="task_005">Task 5 — BatchNorm Eval (Hard)</option>
|
| 50 |
<option value="task_006">Task 6 — Code Bug (Hard)</option>
|
| 51 |
+
<option value="task_007">Task 7 — Scheduler Misconfigured (Med-Hard)</option>
|
| 52 |
</select>
|
| 53 |
<button class="primary" onclick="runBaseline()">Run Baseline</button>
|
| 54 |
</div>
|
|
|
|
| 222 |
document.getElementById('summary').innerHTML = html;
|
| 223 |
}
|
| 224 |
|
| 225 |
+
function sendStep(action) {
|
| 226 |
+
return new Promise(resolve => {
|
| 227 |
+
const handler = (ev) => {
|
| 228 |
+
const msg = JSON.parse(ev.data);
|
| 229 |
+
if (msg.type === 'observation') {
|
| 230 |
+
ws.removeEventListener('message', handler);
|
| 231 |
+
resolve(msg);
|
| 232 |
+
}
|
| 233 |
+
};
|
| 234 |
+
ws.addEventListener('message', handler);
|
| 235 |
+
ws.send(JSON.stringify({ type: 'step', data: action }));
|
| 236 |
+
});
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
function sendReset(taskId) {
|
| 240 |
+
return new Promise(resolve => {
|
| 241 |
+
const handler = (ev) => {
|
| 242 |
+
const msg = JSON.parse(ev.data);
|
| 243 |
+
if (msg.type === 'observation') {
|
| 244 |
+
ws.removeEventListener('message', handler);
|
| 245 |
+
resolve(msg);
|
| 246 |
+
}
|
| 247 |
+
};
|
| 248 |
+
ws.addEventListener('message', handler);
|
| 249 |
+
ws.send(JSON.stringify({ type: 'reset', data: { task_id: taskId, seed: 42 } }));
|
| 250 |
+
});
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
async function runBaseline() {
|
| 254 |
const taskId = document.getElementById('taskSelect').value;
|
| 255 |
actions = []; rewards = []; cumRewards = [];
|
| 256 |
+
if (!ws || ws.readyState !== WebSocket.OPEN) return;
|
| 257 |
+
|
| 258 |
+
const delay = (ms) => new Promise(r => setTimeout(r, ms));
|
| 259 |
+
|
| 260 |
+
// Reset
|
| 261 |
+
await sendReset(taskId);
|
| 262 |
+
await delay(300);
|
| 263 |
+
|
| 264 |
+
// Step 1: Inspect gradients
|
| 265 |
+
await sendStep({ action_type: 'inspect_gradients' });
|
| 266 |
+
await delay(300);
|
| 267 |
+
|
| 268 |
+
const gs = obs && obs.gradient_stats ? obs.gradient_stats : [];
|
| 269 |
+
const anyExploding = gs.some(g => g.is_exploding);
|
| 270 |
+
const anyVanishing = gs.some(g => g.is_vanishing);
|
| 271 |
+
|
| 272 |
+
if (anyExploding) {
|
| 273 |
+
await sendStep({ action_type: 'modify_config', target: 'learning_rate', value: 0.001 });
|
| 274 |
+
await delay(300);
|
| 275 |
+
await sendStep({ action_type: 'restart_run' });
|
| 276 |
+
await delay(300);
|
| 277 |
+
await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'lr_too_high' });
|
| 278 |
+
return;
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
if (anyVanishing) {
|
| 282 |
+
await sendStep({ action_type: 'modify_config', target: 'learning_rate', value: 0.01 });
|
| 283 |
+
await delay(300);
|
| 284 |
+
await sendStep({ action_type: 'restart_run' });
|
| 285 |
+
await delay(300);
|
| 286 |
+
await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'vanishing_gradients' });
|
| 287 |
+
return;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
// Step 2: Inspect data
|
| 291 |
+
await sendStep({ action_type: 'inspect_data_batch' });
|
| 292 |
+
await delay(300);
|
| 293 |
+
|
| 294 |
+
const dbs = obs && obs.data_batch_stats ? obs.data_batch_stats : {};
|
| 295 |
+
if (dbs.class_overlap_score && dbs.class_overlap_score > 0.5) {
|
| 296 |
+
await sendStep({ action_type: 'patch_data_loader' });
|
| 297 |
+
await delay(300);
|
| 298 |
+
await sendStep({ action_type: 'restart_run' });
|
| 299 |
+
await delay(300);
|
| 300 |
+
await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'data_leakage' });
|
| 301 |
+
return;
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
// Check for overfitting (train loss low, val loss rising)
|
| 305 |
+
const tl = obs && obs.training_loss_history ? obs.training_loss_history : [];
|
| 306 |
+
const vl = obs && obs.val_loss_history ? obs.val_loss_history : [];
|
| 307 |
+
const lastTrainLoss = tl.length > 0 ? tl[tl.length - 1] : 999;
|
| 308 |
+
const lastValLoss = vl.length > 0 ? vl[vl.length - 1] : 0;
|
| 309 |
+
const earlyValLoss = vl.length > 5 ? vl[5] : lastValLoss;
|
| 310 |
+
const isOverfitting = lastTrainLoss < 0.1 && lastValLoss > earlyValLoss;
|
| 311 |
+
|
| 312 |
+
if (isOverfitting) {
|
| 313 |
+
await sendStep({ action_type: 'modify_config', target: 'weight_decay', value: 0.01 });
|
| 314 |
+
await delay(300);
|
| 315 |
+
await sendStep({ action_type: 'restart_run' });
|
| 316 |
+
await delay(300);
|
| 317 |
+
await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'overfitting' });
|
| 318 |
+
return;
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
// Step 3: Inspect model modes
|
| 322 |
+
await sendStep({ action_type: 'inspect_model_modes' });
|
| 323 |
+
await delay(300);
|
| 324 |
+
|
| 325 |
+
const modes = obs && obs.model_mode_info ? obs.model_mode_info : {};
|
| 326 |
+
const anyEval = Object.values(modes).some(m => m === 'eval');
|
| 327 |
+
if (anyEval) {
|
| 328 |
+
await sendStep({ action_type: 'fix_model_mode' });
|
| 329 |
+
await delay(300);
|
| 330 |
+
await sendStep({ action_type: 'restart_run' });
|
| 331 |
+
await delay(300);
|
| 332 |
+
await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'batchnorm_eval_mode' });
|
| 333 |
+
return;
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
// Step 4: Inspect code
|
| 337 |
+
await sendStep({ action_type: 'inspect_code' });
|
| 338 |
+
await delay(300);
|
| 339 |
+
|
| 340 |
+
if (obs && obs.code_snippet && obs.code_snippet.code) {
|
| 341 |
+
const code = obs.code_snippet.code;
|
| 342 |
+
const lines = code.split('\n');
|
| 343 |
+
let fixLine = null, fixReplacement = null;
|
| 344 |
+
for (let i = 0; i < lines.length; i++) {
|
| 345 |
+
const ln = lines[i].trim();
|
| 346 |
+
if (ln.includes('model.eval()')) { fixLine = i + 1; fixReplacement = lines[i].replace('model.eval()', 'model.train()'); break; }
|
| 347 |
+
if (ln.includes('.detach()') && ln.includes('criterion')) { fixLine = i + 1; fixReplacement = lines[i].replace('.detach()', ''); break; }
|
| 348 |
+
if (ln.includes('inplace=True')) { fixLine = i + 1; fixReplacement = lines[i].replace('inplace=True', ''); break; }
|
| 349 |
+
}
|
| 350 |
+
if (fixLine) {
|
| 351 |
+
await sendStep({ action_type: 'fix_code', line: fixLine, replacement: fixReplacement });
|
| 352 |
+
await delay(300);
|
| 353 |
+
} else {
|
| 354 |
+
// zero_grad_missing — find optimizer.step() and add zero_grad before it
|
| 355 |
+
for (let i = 0; i < lines.length; i++) {
|
| 356 |
+
if (lines[i].trim().includes('optimizer.step()')) {
|
| 357 |
+
fixLine = i + 1;
|
| 358 |
+
fixReplacement = ' optimizer.zero_grad()\n' + lines[i];
|
| 359 |
+
break;
|
| 360 |
+
}
|
| 361 |
+
}
|
| 362 |
+
if (fixLine) {
|
| 363 |
+
await sendStep({ action_type: 'fix_code', line: fixLine, replacement: fixReplacement });
|
| 364 |
+
await delay(300);
|
| 365 |
+
}
|
| 366 |
}
|
| 367 |
+
await sendStep({ action_type: 'restart_run' });
|
| 368 |
+
await delay(300);
|
| 369 |
+
await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'code_bug' });
|
| 370 |
+
return;
|
| 371 |
}
|
| 372 |
+
|
| 373 |
+
// Step 5: Check for scheduler issue
|
| 374 |
+
const va = obs && obs.val_accuracy_history ? obs.val_accuracy_history : [];
|
| 375 |
+
const midAcc = va.length > 10 ? va[9] : 0;
|
| 376 |
+
const endAcc = va.length > 0 ? va[va.length - 1] : 0;
|
| 377 |
+
const stagnated = midAcc > 0.3 && (endAcc - midAcc) < 0.05;
|
| 378 |
+
|
| 379 |
+
if (stagnated) {
|
| 380 |
+
await sendStep({ action_type: 'modify_config', target: 'learning_rate', value: 0.005 });
|
| 381 |
+
await delay(300);
|
| 382 |
+
await sendStep({ action_type: 'restart_run' });
|
| 383 |
+
await delay(300);
|
| 384 |
+
await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'scheduler_misconfigured' });
|
| 385 |
+
return;
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
// Fallback
|
| 389 |
+
await sendStep({ action_type: 'modify_config', target: 'weight_decay', value: 0.01 });
|
| 390 |
+
await delay(300);
|
| 391 |
+
await sendStep({ action_type: 'restart_run' });
|
| 392 |
+
await delay(300);
|
| 393 |
+
await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'overfitting' });
|
| 394 |
}
|
| 395 |
|
| 396 |
connect();
|
server/environment.py
CHANGED
|
@@ -294,6 +294,10 @@ class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation,
|
|
| 294 |
is_correct_fix: bool | None = None
|
| 295 |
convergence = False
|
| 296 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
try:
|
| 298 |
is_correct_fix, convergence = self._dispatch_action(action, session)
|
| 299 |
except Exception as exc:
|
|
@@ -306,7 +310,7 @@ class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation,
|
|
| 306 |
},
|
| 307 |
exc_info=True,
|
| 308 |
)
|
| 309 |
-
reward = compute_reward(action,
|
| 310 |
obs = self._build_observation(session, reward=reward)
|
| 311 |
obs.error_log = f"Internal error processing {action_type}: {exc}"
|
| 312 |
return obs
|
|
@@ -317,10 +321,10 @@ class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation,
|
|
| 317 |
else:
|
| 318 |
state.actions_taken.append(action_type)
|
| 319 |
|
| 320 |
-
# Compute reward
|
| 321 |
reward = compute_reward(
|
| 322 |
action,
|
| 323 |
-
|
| 324 |
scenario,
|
| 325 |
is_valid_action=True,
|
| 326 |
is_correct_fix=is_correct_fix,
|
|
|
|
| 294 |
is_correct_fix: bool | None = None
|
| 295 |
convergence = False
|
| 296 |
|
| 297 |
+
# Snapshot state BEFORE dispatch — reward engine needs pre-action state
|
| 298 |
+
# to correctly compute investigation bonuses and context-gated penalties
|
| 299 |
+
state_before = state.model_copy(deep=True)
|
| 300 |
+
|
| 301 |
try:
|
| 302 |
is_correct_fix, convergence = self._dispatch_action(action, session)
|
| 303 |
except Exception as exc:
|
|
|
|
| 310 |
},
|
| 311 |
exc_info=True,
|
| 312 |
)
|
| 313 |
+
reward = compute_reward(action, state_before, scenario, is_valid_action=False)
|
| 314 |
obs = self._build_observation(session, reward=reward)
|
| 315 |
obs.error_log = f"Internal error processing {action_type}: {exc}"
|
| 316 |
return obs
|
|
|
|
| 321 |
else:
|
| 322 |
state.actions_taken.append(action_type)
|
| 323 |
|
| 324 |
+
# Compute reward using pre-action state
|
| 325 |
reward = compute_reward(
|
| 326 |
action,
|
| 327 |
+
state_before,
|
| 328 |
scenario,
|
| 329 |
is_valid_action=True,
|
| 330 |
is_correct_fix=is_correct_fix,
|
tests/test_episode_lifecycle.py
CHANGED
|
@@ -51,6 +51,31 @@ class TestStepInspections:
|
|
| 51 |
assert len(obs.gradient_stats) > 0
|
| 52 |
assert obs.episode_state.gradients_inspected
|
| 53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
def test_inspect_data_batch(self, env):
|
| 55 |
env.reset(seed=42, episode_id="test", task_id="task_003")
|
| 56 |
obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
|
|
|
|
| 51 |
assert len(obs.gradient_stats) > 0
|
| 52 |
assert obs.episode_state.gradients_inspected
|
| 53 |
|
| 54 |
+
def test_inspect_gradients_gives_investigation_bonus(self, env):
|
| 55 |
+
"""First-time inspection must give +0.05 bonus (total +0.04 with step penalty)."""
|
| 56 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 57 |
+
obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
|
| 58 |
+
assert obs.reward == pytest.approx(0.04)
|
| 59 |
+
|
| 60 |
+
def test_inspect_data_batch_gives_investigation_bonus(self, env):
|
| 61 |
+
"""First-time data inspection must give +0.05 bonus."""
|
| 62 |
+
env.reset(seed=42, episode_id="test", task_id="task_003")
|
| 63 |
+
obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
|
| 64 |
+
assert obs.reward == pytest.approx(0.04)
|
| 65 |
+
|
| 66 |
+
def test_inspect_model_modes_gives_investigation_bonus(self, env):
|
| 67 |
+
"""First-time model modes inspection must give +0.05 bonus."""
|
| 68 |
+
env.reset(seed=42, episode_id="test", task_id="task_005")
|
| 69 |
+
obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
|
| 70 |
+
assert obs.reward == pytest.approx(0.04)
|
| 71 |
+
|
| 72 |
+
def test_repeat_inspection_no_bonus(self, env):
|
| 73 |
+
"""Second inspection of same type must NOT give bonus."""
|
| 74 |
+
env.reset(seed=42, episode_id="test", task_id="task_001")
|
| 75 |
+
env.step(MLTrainingAction(action_type="inspect_gradients"))
|
| 76 |
+
obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
|
| 77 |
+
assert obs.reward == pytest.approx(-0.01)
|
| 78 |
+
|
| 79 |
def test_inspect_data_batch(self, env):
|
| 80 |
env.reset(seed=42, episode_id="test", task_id="task_003")
|
| 81 |
obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
|