UjjwalPardeshi committed
Commit eeb6913 · 1 Parent(s): f4c428c

fix: dashboard, debug logs
.coverage CHANGED
Binary files a/.coverage and b/.coverage differ
 
PROJECT_GUIDE.md ADDED
@@ -0,0 +1,691 @@
# PyTorch Training Run Debugger — Complete Project Guide

## What Is This?

A game where an AI agent plays detective to fix broken PyTorch training runs. The agent sees a failing training run, investigates clues (gradients, data, code), applies a fix, and submits a diagnosis. Built as an [OpenEnv](https://github.com/openenv) RL environment for the **Meta PyTorch OpenEnv Hackathon**.

---

## How a Game Works

```
1. Agent receives a broken training run (loss curves, config, error log)
2. Agent investigates (inspect gradients, data, weights, model modes, code)
3. Agent applies a fix (reduce LR, patch data, fix code, etc.)
4. Agent restarts training and confirms recovery
5. Agent submits diagnosis ("the problem was lr_too_high")
6. Grader scores the agent 0.0 to 1.0
```

---

## The 7 Tasks

| Task | Problem | Difficulty | Root Cause | Key Clue |
|------|---------|-----------|------------|----------|
| `task_001` | Gradients explode | Easy | `lr_too_high` | All layers `is_exploding: true` |
| `task_002` | Gradients vanish | Easy | `vanishing_gradients` | Deep layers `is_vanishing: true` |
| `task_003` | Test data leaked into training | Medium | `data_leakage` | `class_overlap_score > 0.5` |
| `task_004` | Model memorizes, doesn't learn | Medium | `overfitting` | Train loss drops, val loss rises |
| `task_005` | BatchNorm stuck in eval mode | Hard | `batchnorm_eval_mode` | Model modes show "eval" + red herrings |
| `task_006` | Bug in Python training code | Hard | `code_bug` | Bug visible in code snippet |
| `task_007` | LR scheduler decays too fast | Medium-Hard | `scheduler_misconfigured` | Early progress then stagnation |

---

## Reward System

Every action earns or costs points (capped at -1.0 to 1.0):

| Event | Reward | When |
|-------|--------|------|
| Any step taken | **-0.01** | Always (encourages efficiency) |
| First-time inspection | **+0.05** | Once per inspection type |
| Correct diagnosis | **+0.50** | Diagnosis matches root cause |
| Wrong diagnosis | **-0.30** | Diagnosis doesn't match |
| Fix works + training recovers | **+0.40** | After fix + restart + convergence |
| Invalid action | **-0.05** | Action not available |
| Wrong code fix | **-0.10** | `fix_code` with wrong line/replacement |
| **Context-gated penalty** | **-0.20** | Inspected gradients, saw they're normal, then added gradient clipping anyway |

### The Context-Gated Penalty (Core Innovation)

- Agent checks gradients -> finds them **normal** -> adds gradient clipping = **-0.20 penalty** (ignoring evidence)
- Agent adds gradient clipping **before** checking gradients = **no penalty** (reasonable prior)

This teaches: *don't ignore what you've already learned*.
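The table above reads as a small piece of per-step bookkeeping. A sketch with the constants from the table; the function name and state fields here are illustrative, not the real `reward_engine.py` API:

```python
# Hypothetical sketch of the per-step reward rules described above.
# Constants come from the Reward System table; everything else is illustrative.
STEP_PENALTY = -0.01
FIRST_INSPECT_BONUS = 0.05
CONTEXT_GATED_PENALTY = -0.20

def step_reward(action_type: str, state: dict) -> float:
    reward = STEP_PENALTY  # every step costs a little
    # First-time inspections earn a small exploration bonus, once per type.
    if action_type.startswith("inspect_") and action_type not in state["inspected"]:
        state["inspected"].add(action_type)
        reward += FIRST_INSPECT_BONUS
    # Context-gated penalty: adding clipping *after* seeing normal gradients.
    if action_type == "add_callback" and state.get("gradients_were_normal"):
        reward += CONTEXT_GATED_PENALTY
    return round(reward, 2)

state = {"inspected": set(), "gradients_were_normal": True}
print(step_reward("inspect_gradients", state))  # -0.01 + 0.05 = 0.04
print(step_reward("add_callback", state))       # -0.01 - 0.20 = -0.21
```

Note how the same `add_callback` action would cost only -0.01 if the agent had not yet inspected gradients.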
---

## Architecture

```
ml_training_debugger/       # Core logic
    models.py               # All data types (Pydantic)
    scenarios.py            # Creates the 7 tasks with random params
    pytorch_engine.py       # Real PyTorch model + fault injection
    simulation.py           # Loss/accuracy curve generation
    reward_engine.py        # Per-step reward calculation
    graders.py              # Final 0.0-1.0 scoring per task
    code_templates.py       # Buggy code for Task 6
    client.py               # Client for connecting to the environment

server/                     # Web server
    app.py                  # FastAPI + all endpoints
    environment.py          # Game logic (reset, step, state)

tests/                      # 183 tests, 97% coverage
baseline_heuristic.py       # Rule-based agent (deterministic)
baseline_inference.py       # LLM agent (Llama/GPT-4o)
```

---

## API Endpoints

### GET /health

Server status check.

**Response:**
```json
{
  "status": "ready",
  "tasks": 7
}
```

---

### GET /tasks

List all available tasks with action schema.

**Response:**
```json
[
  {
    "id": "task_001",
    "difficulty": "easy",
    "max_steps": 20,
    "action_schema": {
      "title": "MLTrainingAction",
      "type": "object",
      "properties": {
        "action_type": { "type": "string" },
        "target": { "type": ["string", "null"] },
        "value": { "type": ["number", "integer", "string", "null"] },
        "diagnosis": { "type": ["string", "null"] },
        "line": { "type": ["integer", "null"] },
        "replacement": { "type": ["string", "null"] }
      },
      "required": ["action_type"]
    }
  }
]
```

---

### POST /baseline

Run the heuristic baseline agent on all 7 tasks.

**Response:**
```json
{
  "scores": {
    "task_001": 1.00,
    "task_002": 1.00,
    "task_003": 1.00,
    "task_004": 0.45,
    "task_005": 0.35,
    "task_006": 1.00,
    "task_007": 1.00
  }
}
```

Returns `409` if baseline is already running.

---

### POST /grader

Get the grader score for the last completed episode.

**Query params:** `session_id` (optional)

**Response:**
```json
{
  "score": 0.85,
  "task_id": "task_001",
  "steps": 5
}
```

If no episode completed:
```json
{
  "score": null,
  "error": "no_completed_episode"
}
```

---

### GET /dashboard

Live diagnostic dashboard (HTML page with Plotly.js charts). Open in a browser.

**Panels:**
1. Training metrics (loss/accuracy curves)
2. Gradient & weight heatmap
3. Action timeline with rewards
4. Episode summary with state flags

---

### GET /validation-report

Pre-computed fidelity report comparing parametric curves to real PyTorch training runs.

---

### GET /curriculum

Recommended task order for progressive training (easy to hard, 3 difficulty levels each).

**Response:**
```json
{
  "curriculum": [
    { "task_id": "task_001", "difficulty": "easy", "difficulty_level": 1, "max_steps": 20 },
    { "task_id": "task_001", "difficulty": "easy", "difficulty_level": 3, "max_steps": 20 },
    { "task_id": "task_001", "difficulty": "easy", "difficulty_level": 5, "max_steps": 20 }
  ],
  "total_episodes": 21
}
```
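The `total_episodes` count follows from 7 tasks times 3 difficulty levels. A sketch of generating such a schedule (task IDs from the task table, level values 1/3/5 as in the sample response; the function itself is illustrative):

```python
# Illustrative sketch: build the 7-tasks x 3-levels curriculum described above.
TASKS = [f"task_{i:03d}" for i in range(1, 8)]  # task_001 .. task_007
LEVELS = [1, 3, 5]

def build_curriculum() -> dict:
    entries = [
        {"task_id": t, "difficulty_level": lvl}
        for t in TASKS       # easy-to-hard task order
        for lvl in LEVELS    # then ramp difficulty within each task
    ]
    return {"curriculum": entries, "total_episodes": len(entries)}

print(build_curriculum()["total_episodes"])  # 21
```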
---

### GET /leaderboard

Sorted episode scores from baseline runs.

**Response:**
```json
{
  "entries": [
    { "score": 1.00, "task_id": "task_001", "steps": 5, "episode_id": "baseline_task_001" }
  ],
  "total": 7
}
```

---

### GET /replay/{episode_id}

Full action/observation trace for a completed episode.

**Response:**
```json
{
  "episode_id": "baseline_task_001",
  "score": 1.00,
  "task_id": "task_001",
  "steps": 5
}
```

---

## WebSocket Interface (Primary Agent Interface)

**Endpoint:** `ws://localhost:7860/ws`

This is the main way agents interact with the environment. HTTP endpoints are stateless — the WebSocket maintains session state across a full episode.

### Reset (Start New Episode)

**Send:**
```json
{
  "type": "reset",
  "seed": 42,
  "kwargs": {
    "task_id": "task_003",
    "difficulty_level": 3
  }
}
```

Without `kwargs`, defaults to `task_001`.

**Receive:**
```json
{
  "type": "observation",
  "observation": {
    "run_id": "ep_12345",
    "framework": "pytorch",
    "epoch": 20,
    "training_loss_history": [2.3, 2.1, 1.9, ...],
    "val_loss_history": [2.4, 2.2, 2.0, ...],
    "val_accuracy_history": [0.3, 0.35, 0.4, ...],
    "gradient_stats": [],
    "model_weight_stats": null,
    "data_batch_stats": null,
    "model_mode_info": null,
    "code_snippet": null,
    "current_config": {
      "learning_rate": 0.001,
      "weight_decay": 0.0001,
      "batch_size": 64,
      "hidden_dim": 64,
      "num_layers": 3,
      "optimizer": "adam",
      "dropout_rate": 0.0,
      "gradient_clip_norm": null
    },
    "error_log": null,
    "gpu_memory_used_gb": 6.2,
    "gpu_memory_total_gb": 16.0,
    "available_actions": [
      "inspect_gradients",
      "inspect_data_batch",
      "inspect_model_modes",
      "inspect_model_weights",
      "inspect_code",
      "modify_config",
      "add_callback",
      "replace_optimizer",
      "patch_data_loader",
      "fix_model_mode",
      "mark_diagnosed"
    ],
    "episode_state": {
      "step_count": 0,
      "gradients_inspected": false,
      "gradients_were_normal": false,
      "data_inspected": false,
      "model_modes_inspected": false,
      "model_weights_inspected": false,
      "code_inspected": false,
      "fix_action_taken": false,
      "restart_after_fix": false,
      "diagnosis_submitted": false,
      "actions_taken": []
    },
    "notes": null,
    "done": false,
    "reward": null,
    "metadata": {}
  }
}
```

### Step (Take an Action)

**Investigation actions** (no extra fields needed):
```json
{"type": "step", "action": {"action_type": "inspect_gradients"}}
{"type": "step", "action": {"action_type": "inspect_data_batch"}}
{"type": "step", "action": {"action_type": "inspect_model_modes"}}
{"type": "step", "action": {"action_type": "inspect_model_weights"}}
{"type": "step", "action": {"action_type": "inspect_code"}}
```

**Fix actions:**
```json
{"type": "step", "action": {"action_type": "modify_config", "target": "learning_rate", "value": 0.001}}
{"type": "step", "action": {"action_type": "add_callback"}}
{"type": "step", "action": {"action_type": "replace_optimizer"}}
{"type": "step", "action": {"action_type": "patch_data_loader"}}
{"type": "step", "action": {"action_type": "fix_model_mode"}}
{"type": "step", "action": {"action_type": "fix_code", "line": 5, "replacement": "model.train()"}}
```

**Terminal actions:**
```json
{"type": "step", "action": {"action_type": "restart_run"}}
{"type": "step", "action": {"action_type": "mark_diagnosed", "diagnosis": "lr_too_high"}}
```

**Receive (after each step):**
```json
{
  "type": "observation",
  "observation": {
    "...same structure as reset response...",
    "gradient_stats": [
      {
        "layer_name": "conv1",
        "norm_history": [0.5, 0.6, 0.7],
        "mean_norm": 51.1,
        "max_norm": 98.3,
        "is_exploding": true,
        "is_vanishing": false
      }
    ],
    "episode_state": {
      "step_count": 1,
      "gradients_inspected": true,
      "actions_taken": ["inspect_gradients"]
    },
    "done": false,
    "reward": 0.04
  }
}
```

When `done: true`, the episode is over.
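A minimal agent-side client for this protocol can be sketched as follows. The message shapes come from the Send/Receive examples above; the helper names (`make_reset`, `make_step`, `run_episode`) and the use of the third-party `websockets` package are our own choices, not part of the project.

```python
# Minimal client sketch for the WebSocket protocol above.
# make_reset / make_step / run_episode are illustrative helper names;
# the agent loop assumes the third-party `websockets` package is installed.
import asyncio
import json

def make_reset(task_id="task_001", seed=42, difficulty_level=None):
    kwargs = {"task_id": task_id}
    if difficulty_level is not None:
        kwargs["difficulty_level"] = difficulty_level
    return {"type": "reset", "seed": seed, "kwargs": kwargs}

def make_step(action_type, **fields):
    return {"type": "step", "action": {"action_type": action_type, **fields}}

async def run_episode(url="ws://localhost:7860/ws", task_id="task_001"):
    import websockets  # third-party; imported lazily so the builders work without it
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps(make_reset(task_id)))
        json.loads(await ws.recv())  # initial observation
        total = 0.0
        # Trivial policy: one inspection, then a guess at the diagnosis.
        for msg in (make_step("inspect_gradients"),
                    make_step("mark_diagnosed", diagnosis="lr_too_high")):
            await ws.send(json.dumps(msg))
            obs = json.loads(await ws.recv())["observation"]
            total += obs["reward"] or 0.0
            if obs["done"]:
                break
        return total

# asyncio.run(run_episode())  # requires a running server on port 7860
```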
---

## All 14 Action Types

| Action | Required Fields | Description |
|--------|----------------|-------------|
| `inspect_gradients` | none | View per-layer gradient stats |
| `inspect_data_batch` | none | View data batch statistics |
| `inspect_model_modes` | none | View train/eval mode per layer |
| `inspect_model_weights` | none | View per-layer weight stats |
| `inspect_code` | none | View source code (Task 6) |
| `modify_config` | `target`, `value` | Change a hyperparameter |
| `add_callback` | none | Add gradient clipping callback |
| `replace_optimizer` | none | Switch optimizer |
| `patch_data_loader` | none | Fix data pipeline |
| `fix_model_mode` | none | Switch model to train mode |
| `fix_code` | `line`, `replacement` | Fix a line of code |
| `restart_run` | none | Restart training (requires fix first) |
| `mark_diagnosed` | `diagnosis` | Submit final diagnosis |
| `rollback_checkpoint` | none | Rollback to checkpoint |

### Valid `target` values for `modify_config`
`learning_rate`, `weight_decay`, `batch_size`, `hidden_dim`, `num_layers`, `optimizer`, `dropout_rate`, `gradient_clip_norm`

### Valid `diagnosis` values for `mark_diagnosed`
`lr_too_high`, `vanishing_gradients`, `data_leakage`, `overfitting`, `batchnorm_eval_mode`, `code_bug`, `scheduler_misconfigured`

---

## Dynamic Action Availability

Actions appear/disappear based on episode state:

| Action | Available When |
|--------|---------------|
| `fix_code` | Only after `inspect_code` (`code_inspected = true`) |
| `restart_run` | Only after a fix action (`fix_action_taken = true`) |
| `rollback_checkpoint` | Only after restart (`restart_after_fix = true`) |
| `mark_diagnosed` | Only while `diagnosis_submitted = false` |
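These rules can be sketched as a pure function of the episode-state flags (flag names match `EpisodeState`; the function and the `ALWAYS` list are illustrative, not the real server code):

```python
# Illustrative sketch of the availability rules in the table above.
ALWAYS = [
    "inspect_gradients", "inspect_data_batch", "inspect_model_modes",
    "inspect_model_weights", "inspect_code", "modify_config",
    "add_callback", "replace_optimizer", "patch_data_loader", "fix_model_mode",
]

def available_actions(state: dict) -> list[str]:
    actions = list(ALWAYS)
    if state["code_inspected"]:
        actions.append("fix_code")        # must see the code before fixing it
    if state["fix_action_taken"]:
        actions.append("restart_run")     # restart only makes sense after a fix
    if state["restart_after_fix"]:
        actions.append("rollback_checkpoint")
    if not state["diagnosis_submitted"]:
        actions.append("mark_diagnosed")  # diagnosis can be submitted once
    return actions

fresh = {"code_inspected": False, "fix_action_taken": False,
         "restart_after_fix": False, "diagnosis_submitted": False}
print(len(available_actions(fresh)))  # 11 — matches the reset observation above
```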
---

## Observation Fields — Progressive Reveal

On reset, the agent sees loss curves, config, and error log. Everything else is `null` until inspected:

| Field | Starts As | Populated After |
|-------|-----------|----------------|
| `training_loss_history` | 20 floats | Always visible |
| `val_accuracy_history` | 20 floats | Always visible |
| `val_loss_history` | 20 floats | Always visible |
| `current_config` | Full config | Always visible |
| `error_log` | String or null | Always visible |
| `gradient_stats` | `[]` | `inspect_gradients` |
| `model_weight_stats` | `null` | `inspect_model_weights` |
| `data_batch_stats` | `null` | `inspect_data_batch` |
| `model_mode_info` | `null` | `inspect_model_modes` |
| `code_snippet` | `null` | `inspect_code` |

---

## Data Types

### GradientStats (per layer)
```json
{
  "layer_name": "conv1",
  "norm_history": [0.5, 0.6, 0.7],
  "mean_norm": 12.5,
  "max_norm": 25.3,
  "is_exploding": true,
  "is_vanishing": false
}
```
- Exploding: `mean_norm > 10.0`
- Vanishing: `mean_norm < 1e-6`
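The two thresholds above can be expressed directly (values from the bullets; the helper itself is illustrative, the real logic lives in `ml_training_debugger`):

```python
# Tiny classifier for the gradient thresholds listed above (illustrative).
EXPLODING_NORM = 10.0
VANISHING_NORM = 1e-6

def classify(mean_norm: float) -> dict:
    return {
        "is_exploding": mean_norm > EXPLODING_NORM,
        "is_vanishing": mean_norm < VANISHING_NORM,
    }

print(classify(51.1))  # {'is_exploding': True, 'is_vanishing': False}
print(classify(4.2))   # neither flag set: the "normal but spiky" Task 5 case
```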
### ModelWeightStats (per layer)
```json
{
  "layer_name": "conv1",
  "weight_norm": 1.234,
  "weight_mean": 0.001,
  "weight_std": 0.05,
  "weight_min": -0.15,
  "weight_max": 0.16,
  "dead_neuron_pct": 0.0,
  "has_nan": false,
  "has_inf": false
}
```

### DataBatchStats
```json
{
  "label_distribution": {"0": 0.25, "1": 0.25, "2": 0.25, "3": 0.25},
  "feature_mean": 0.5,
  "feature_std": 0.2,
  "null_count": 0,
  "class_overlap_score": 0.15,
  "batch_size": 64,
  "duplicate_ratio": 0.0,
  "confusion_matrix": [[10, 2, 1], [1, 9, 3], [2, 1, 11]]
}
```
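Task 3's key clue (`class_overlap_score > 0.5`, from the task table) can be checked straight off these stats. Field names come from the JSON above; the helper is our own:

```python
# Illustrative leakage check using the DataBatchStats fields shown above.
def looks_like_leakage(stats: dict) -> bool:
    # The task table lists class_overlap_score > 0.5 as the data_leakage clue.
    return stats.get("class_overlap_score", 0.0) > 0.5

healthy = {"class_overlap_score": 0.15, "duplicate_ratio": 0.0}
leaky = {"class_overlap_score": 0.72, "duplicate_ratio": 0.3}
print(looks_like_leakage(healthy), looks_like_leakage(leaky))  # False True
```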
### CodeSnippet (Task 6 only)
```json
{
  "code": "import torch\nimport torch.nn as nn\n...",
  "filename": "train.py",
  "line_count": 50,
  "imports": ["torch", "torch.nn", "torch.optim"],
  "hint": "Look for .detach() preventing gradient flow"
}
```

### EpisodeState
```json
{
  "step_count": 0,
  "gradients_inspected": false,
  "gradients_were_normal": false,
  "data_inspected": false,
  "model_modes_inspected": false,
  "model_weights_inspected": false,
  "code_inspected": false,
  "fix_action_taken": false,
  "restart_after_fix": false,
  "diagnosis_submitted": false,
  "actions_taken": []
}
```

---

## Grading Breakdown (per task)

Each task has its own grader that scores 0.0 to 1.0 based on what the agent did:

### Task 1 — Exploding Gradients
| Component | Points |
|-----------|--------|
| Inspected gradients | +0.05 |
| Applied config fix | +0.20 |
| Restarted training | +0.35 |
| Correct diagnosis (`lr_too_high`) | +0.40 |

### Task 2 — Vanishing Gradients
| Component | Points |
|-----------|--------|
| Inspected gradients | +0.05 |
| Applied config fix | +0.20 |
| Restarted training | +0.35 |
| Correct diagnosis (`vanishing_gradients`) | +0.40 |

### Task 3 — Data Leakage
| Component | Points |
|-----------|--------|
| Inspected data | +0.05 |
| Patched data loader | +0.30 |
| Restarted training | +0.30 |
| Correct diagnosis (`data_leakage`) | +0.35 |

### Task 4 — Overfitting
| Component | Points |
|-----------|--------|
| Inspected data | +0.05 |
| Applied fix (config or callback) | +0.25 |
| Restarted training | +0.30 |
| Correct diagnosis (`overfitting`) | +0.40 |

### Task 5 — BatchNorm Eval Mode (with red herrings)
| Component | Points |
|-----------|--------|
| Inspected gradients | +0.05 |
| Inspected model modes | +0.05 |
| **Fell for red herring** (add_callback after normal gradients) | **-0.20** |
| Fixed model mode | +0.25 |
| Restarted training | +0.30 |
| Correct diagnosis (`batchnorm_eval_mode`) | +0.40 |

### Task 6 — Code Bug
| Component | Points |
|-----------|--------|
| Inspected code | +0.05 |
| Fixed code correctly | +0.30 |
| Restarted training | +0.25 |
| Correct diagnosis (`code_bug`) | +0.40 |

### Task 7 — Scheduler Misconfigured
| Component | Points |
|-----------|--------|
| Inspected gradients | +0.05 |
| Inspected data | +0.05 |
| Applied config fix | +0.25 |
| Restarted training | +0.25 |
| Correct diagnosis (`scheduler_misconfigured`) | +0.40 |
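The component tables read as a simple sum. Task 5's positive components total 1.05, so we assume the grader clamps the final score to the 0.0-1.0 range; that clamp, and the function shape, are our assumptions. An illustrative sketch using the Task 5 values:

```python
# Illustrative grader sketch using the Task 5 component values above.
# The clamp to [0.0, 1.0] is an assumption (the positive components sum to 1.05).
TASK_5_COMPONENTS = {
    "inspected_gradients": 0.05,
    "inspected_model_modes": 0.05,
    "fell_for_red_herring": -0.20,
    "fixed_model_mode": 0.25,
    "restarted": 0.30,
    "correct_diagnosis": 0.40,
}

def grade(earned: set) -> float:
    total = sum(v for k, v in TASK_5_COMPONENTS.items() if k in earned)
    return round(min(max(total, 0.0), 1.0), 2)

perfect = set(TASK_5_COMPONENTS) - {"fell_for_red_herring"}
print(grade(perfect))                             # 1.0 (1.05 clamped)
print(grade(perfect | {"fell_for_red_herring"}))  # 0.85
```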
---

## Baseline Scores

| Task | Heuristic | Llama 3.3 70B | Llama 3.1 8B |
|------|-----------|---------------|--------------|
| task_001 | **1.00** | 1.00 | 0.60 |
| task_002 | **1.00** | 1.00 | 0.05 |
| task_003 | **1.00** | 0.40 | 0.40 |
| task_004 | 0.45 | 0.45 | **0.60** |
| task_005 | **1.00** | 1.00 | 1.00 |
| task_006 | **1.00** | — | 0.60-1.00 |
| task_007 | **1.00** | — | 0.60 |
| **Average** | **0.92** | ~0.69 | 0.55 |

---

## Walkthrough: Solving Task 1 (Exploding Gradients)

```
Step 1: Reset
  Send: {"type": "reset", "kwargs": {"task_id": "task_001"}}
  See:  Loss history going to infinity, error_log says "NaN at epoch 12"

Step 2: Inspect gradients
  Send: {"type": "step", "action": {"action_type": "inspect_gradients"}}
  See:  All layers is_exploding: true, mean_norm > 10.0
  Reward: +0.04 (-0.01 step + 0.05 investigation)

Step 3: Reduce learning rate
  Send: {"type": "step", "action": {"action_type": "modify_config", "target": "learning_rate", "value": 0.001}}
  Reward: -0.01 (step penalty)

Step 4: Restart training
  Send: {"type": "step", "action": {"action_type": "restart_run"}}
  See:  Convergence detected!
  Reward: +0.39 (-0.01 step + 0.40 convergence)

Step 5: Submit diagnosis
  Send: {"type": "step", "action": {"action_type": "mark_diagnosed", "diagnosis": "lr_too_high"}}
  See:  done: true
  Reward: +0.49 (-0.01 step + 0.50 correct diagnosis)

Grader score: 1.0 (perfect)
```
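The per-step rewards in this walkthrough follow from the Reward System table; checking the arithmetic:

```python
# Verifying the walkthrough's per-step rewards against the Reward System table.
steps = [
    -0.01 + 0.05,  # Step 2 inspect_gradients: step penalty + first-time inspection
    -0.01,         # Step 3 modify_config: step penalty only
    -0.01 + 0.40,  # Step 4 restart_run: step penalty + convergence bonus
    -0.01 + 0.50,  # Step 5 mark_diagnosed: step penalty + correct diagnosis
]
print([round(s, 2) for s in steps])  # [0.04, -0.01, 0.39, 0.49]
print(round(sum(steps), 2))          # 0.91 cumulative episode reward
```

Note the cumulative 0.91 is the episode reward; the grader score of 1.0 is computed separately from the Task 1 component table.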
---

## Walkthrough: Task 5 Trap (Red Herring)

```
Step 1: Reset task_005
Step 2: Inspect gradients
  -> FC layer has a spike (mean_norm=4.2, but is_exploding: false)
  -> gradients_were_normal is set to TRUE (nothing actually exploding)

Step 3 (BAD): Add gradient clipping
  -> Reward: -0.21 (-0.01 step - 0.20 context-gated penalty!)
  -> Agent IGNORED the evidence that gradients were normal

Step 3 (GOOD): Inspect model modes instead
  -> Sees all layers in "eval" mode — that's the real problem!

Step 4: Fix model mode
Step 5: Restart training
Step 6: Diagnose batchnorm_eval_mode -> correct!
```

---

## Quick Start

```bash
# Setup
python3 -m venv .venv && source .venv/bin/activate
pip install torch --index-url https://download.pytorch.org/whl/cpu
pip install -r requirements.txt
pip install pytest pytest-cov

# Run server
uvicorn server.app:app --host 0.0.0.0 --port 7860

# Test
pytest tests/ -v --cov=ml_training_debugger
curl http://localhost:7860/health
curl http://localhost:7860/tasks | python3 -m json.tool
curl -X POST http://localhost:7860/baseline | python3 -m json.tool

# Docker
docker build -t pytorch-debugger .
docker run -p 7860:7860 pytorch-debugger
```

---

## Tech Stack

| Component | Purpose |
|-----------|---------|
| Python 3.12 | Runtime |
| PyTorch (CPU-only) | Real neural networks, real gradients |
| FastAPI | Web server |
| OpenEnv | RL environment framework (step/reset/state API) |
| Pydantic v2 | Typed data models |
| Plotly.js | Dashboard charts |
| Docker | Containerized deployment |
server/dashboard.html CHANGED
@@ -17,6 +17,7 @@ body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; b
17
  .panel { background: #161b22; border: 1px solid #30363d; border-radius: 8px; overflow: hidden; display: flex; flex-direction: column; }
18
  .panel-title { padding: 10px 16px; font-size: 14px; font-weight: 600; color: #58a6ff; border-bottom: 1px solid #30363d; background: #0d1117; }
19
  .panel-body { flex: 1; padding: 8px; position: relative; min-height: 0; }
 
20
  .placeholder { display: flex; align-items: center; justify-content: center; height: 100%; color: #484f58; font-style: italic; }
21
  #controls { display: flex; gap: 8px; align-items: center; }
22
  #controls select, #controls button { background: #21262d; color: #c9d1d9; border: 1px solid #30363d; padding: 6px 12px; border-radius: 6px; cursor: pointer; font-size: 13px; }
@@ -47,6 +48,7 @@ body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif; b
47
  <option value="task_004">Task 4 — Overfitting (Medium)</option>
48
  <option value="task_005">Task 5 — BatchNorm Eval (Hard)</option>
49
  <option value="task_006">Task 6 — Code Bug (Hard)</option>
 
50
  </select>
51
  <button class="primary" onclick="runBaseline()">Run Baseline</button>
52
  </div>
@@ -220,26 +222,175 @@ function updateSummary(d) {
220
  document.getElementById('summary').innerHTML = html;
221
  }
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  async function runBaseline() {
224
  const taskId = document.getElementById('taskSelect').value;
225
  actions = []; rewards = []; cumRewards = [];
226
- if (ws && ws.readyState === WebSocket.OPEN) {
227
- ws.send(JSON.stringify({ type: 'reset', data: { task_id: taskId, seed: 42 } }));
228
- await new Promise(r => setTimeout(r, 500));
229
- // Run the heuristic steps
230
- const steps = [
231
- { action_type: 'inspect_gradients' },
232
- { action_type: 'inspect_data_batch' },
233
- { action_type: 'inspect_model_modes' },
234
- { action_type: 'inspect_model_weights' },
235
- { action_type: 'inspect_code' },
236
- ];
237
- for (const step of steps) {
238
- ws.send(JSON.stringify({ type: 'step', data: step }));
239
- await new Promise(r => setTimeout(r, 300));
240
- if (obs && obs.done) break;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  }
 
 
 
 
242
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  }
244
 
245
  connect();
 
17
  .panel { background: #161b22; border: 1px solid #30363d; border-radius: 8px; overflow: hidden; display: flex; flex-direction: column; }
18
  .panel-title { padding: 10px 16px; font-size: 14px; font-weight: 600; color: #58a6ff; border-bottom: 1px solid #30363d; background: #0d1117; }
19
  .panel-body { flex: 1; padding: 8px; position: relative; min-height: 0; }
20
+ .panel-body > div:first-child { width: 100%; height: 100%; }
21
  .placeholder { display: flex; align-items: center; justify-content: center; height: 100%; color: #484f58; font-style: italic; }
22
  #controls { display: flex; gap: 8px; align-items: center; }
23
  #controls select, #controls button { background: #21262d; color: #c9d1d9; border: 1px solid #30363d; padding: 6px 12px; border-radius: 6px; cursor: pointer; font-size: 13px; }
 
48
  <option value="task_004">Task 4 — Overfitting (Medium)</option>
49
  <option value="task_005">Task 5 — BatchNorm Eval (Hard)</option>
50
  <option value="task_006">Task 6 — Code Bug (Hard)</option>
51
+ <option value="task_007">Task 7 — Scheduler Misconfigured (Med-Hard)</option>
52
  </select>
53
  <button class="primary" onclick="runBaseline()">Run Baseline</button>
54
  </div>
 
222
  document.getElementById('summary').innerHTML = html;
223
  }
224
 
225
+ function sendStep(action) {
226
+ return new Promise(resolve => {
227
+ const handler = (ev) => {
228
+ const msg = JSON.parse(ev.data);
229
+ if (msg.type === 'observation') {
230
+ ws.removeEventListener('message', handler);
231
+ resolve(msg);
232
+ }
233
+ };
234
+ ws.addEventListener('message', handler);
235
+ ws.send(JSON.stringify({ type: 'step', data: action }));
236
+ });
237
+ }
238
+
239
+ function sendReset(taskId) {
240
+ return new Promise(resolve => {
241
+ const handler = (ev) => {
242
+ const msg = JSON.parse(ev.data);
243
+ if (msg.type === 'observation') {
244
+ ws.removeEventListener('message', handler);
245
+ resolve(msg);
246
+ }
247
+ };
248
+ ws.addEventListener('message', handler);
249
+ ws.send(JSON.stringify({ type: 'reset', data: { task_id: taskId, seed: 42 } }));
250
+ });
251
+ }
252
+
253
  async function runBaseline() {
254
  const taskId = document.getElementById('taskSelect').value;
255
  actions = []; rewards = []; cumRewards = [];
256
+ if (!ws || ws.readyState !== WebSocket.OPEN) return;
257
+
258
+ const delay = (ms) => new Promise(r => setTimeout(r, ms));
259
+
260
+ // Reset
261
+ await sendReset(taskId);
262
+ await delay(300);
263
+
264
+ // Step 1: Inspect gradients
265
+ await sendStep({ action_type: 'inspect_gradients' });
266
+ await delay(300);
267
+
268
+ const gs = obs && obs.gradient_stats ? obs.gradient_stats : [];
269
+ const anyExploding = gs.some(g => g.is_exploding);
270
+ const anyVanishing = gs.some(g => g.is_vanishing);
271
+
272
+ if (anyExploding) {
273
+ await sendStep({ action_type: 'modify_config', target: 'learning_rate', value: 0.001 });
274
+ await delay(300);
275
+ await sendStep({ action_type: 'restart_run' });
276
+ await delay(300);
277
+ await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'lr_too_high' });
278
+ return;
279
+ }
280
+
281
+ if (anyVanishing) {
282
+ await sendStep({ action_type: 'modify_config', target: 'learning_rate', value: 0.01 });
283
+ await delay(300);
284
+ await sendStep({ action_type: 'restart_run' });
285
+ await delay(300);
286
+ await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'vanishing_gradients' });
287
+ return;
288
+ }
289
+
290
+ // Step 2: Inspect data
291
+ await sendStep({ action_type: 'inspect_data_batch' });
292
+ await delay(300);
293
+
294
+ const dbs = obs && obs.data_batch_stats ? obs.data_batch_stats : {};
295
+ if (dbs.class_overlap_score && dbs.class_overlap_score > 0.5) {
296
+ await sendStep({ action_type: 'patch_data_loader' });
297
+ await delay(300);
298
+ await sendStep({ action_type: 'restart_run' });
299
+ await delay(300);
300
+ await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'data_leakage' });
301
+ return;
302
+ }
303
+
304
+ // Check for overfitting (train loss low, val loss rising)
305
+ const tl = obs && obs.training_loss_history ? obs.training_loss_history : [];
306
+  const vl = obs && obs.val_loss_history ? obs.val_loss_history : [];
+  const lastTrainLoss = tl.length > 0 ? tl[tl.length - 1] : 999;
+  const lastValLoss = vl.length > 0 ? vl[vl.length - 1] : 0;
+  const earlyValLoss = vl.length > 5 ? vl[5] : lastValLoss;
+  const isOverfitting = lastTrainLoss < 0.1 && lastValLoss > earlyValLoss;
+
+  if (isOverfitting) {
+    await sendStep({ action_type: 'modify_config', target: 'weight_decay', value: 0.01 });
+    await delay(300);
+    await sendStep({ action_type: 'restart_run' });
+    await delay(300);
+    await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'overfitting' });
+    return;
+  }
+
+  // Step 3: Inspect model modes
+  await sendStep({ action_type: 'inspect_model_modes' });
+  await delay(300);
+
+  const modes = obs && obs.model_mode_info ? obs.model_mode_info : {};
+  const anyEval = Object.values(modes).some(m => m === 'eval');
+  if (anyEval) {
+    await sendStep({ action_type: 'fix_model_mode' });
+    await delay(300);
+    await sendStep({ action_type: 'restart_run' });
+    await delay(300);
+    await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'batchnorm_eval_mode' });
+    return;
+  }
+
+  // Step 4: Inspect code
+  await sendStep({ action_type: 'inspect_code' });
+  await delay(300);
+
+  if (obs && obs.code_snippet && obs.code_snippet.code) {
+    const code = obs.code_snippet.code;
+    const lines = code.split('\n');
+    let fixLine = null, fixReplacement = null;
+    for (let i = 0; i < lines.length; i++) {
+      const ln = lines[i].trim();
+      if (ln.includes('model.eval()')) { fixLine = i + 1; fixReplacement = lines[i].replace('model.eval()', 'model.train()'); break; }
+      if (ln.includes('.detach()') && ln.includes('criterion')) { fixLine = i + 1; fixReplacement = lines[i].replace('.detach()', ''); break; }
+      if (ln.includes('inplace=True')) { fixLine = i + 1; fixReplacement = lines[i].replace('inplace=True', ''); break; }
+    }
+    if (fixLine) {
+      await sendStep({ action_type: 'fix_code', line: fixLine, replacement: fixReplacement });
+      await delay(300);
+    } else {
+      // zero_grad_missing — find optimizer.step() and add zero_grad before it
+      for (let i = 0; i < lines.length; i++) {
+        if (lines[i].trim().includes('optimizer.step()')) {
+          fixLine = i + 1;
+          fixReplacement = ' optimizer.zero_grad()\n' + lines[i];
+          break;
+        }
+      }
+      if (fixLine) {
+        await sendStep({ action_type: 'fix_code', line: fixLine, replacement: fixReplacement });
+        await delay(300);
+      }
     }
+    await sendStep({ action_type: 'restart_run' });
+    await delay(300);
+    await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'code_bug' });
+    return;
   }
+
+  // Step 5: Check for scheduler issue
+  const va = obs && obs.val_accuracy_history ? obs.val_accuracy_history : [];
+  const midAcc = va.length > 10 ? va[9] : 0;
+  const endAcc = va.length > 0 ? va[va.length - 1] : 0;
+  const stagnated = midAcc > 0.3 && (endAcc - midAcc) < 0.05;
+
+  if (stagnated) {
+    await sendStep({ action_type: 'modify_config', target: 'learning_rate', value: 0.005 });
+    await delay(300);
+    await sendStep({ action_type: 'restart_run' });
+    await delay(300);
+    await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'scheduler_misconfigured' });
+    return;
+  }
+
+  // Fallback
+  await sendStep({ action_type: 'modify_config', target: 'weight_decay', value: 0.01 });
+  await delay(300);
+  await sendStep({ action_type: 'restart_run' });
+  await delay(300);
+  await sendStep({ action_type: 'mark_diagnosed', diagnosis: 'overfitting' });
 }
 
 connect();
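The dashboard's Step-2 overfitting check above reduces to two signals: the final train loss has collapsed while the validation loss now sits above an early-epoch baseline (`vl[5]`). A minimal Python sketch of the same heuristic, with a hypothetical function name and the dashboard's thresholds carried over as defaults:

```python
def looks_like_overfitting(train_losses, val_losses,
                           train_threshold=0.1, early_idx=5):
    """Train loss has collapsed while validation loss exceeds its early baseline."""
    if not train_losses or not val_losses:
        return False
    last_train = train_losses[-1]
    last_val = val_losses[-1]
    # Mirror of the dashboard's vl[5] baseline: fall back to the last value
    # when there is not enough history, which disables the check
    early_val = val_losses[early_idx] if len(val_losses) > early_idx else last_val
    return last_train < train_threshold and last_val > early_val

print(looks_like_overfitting(
    [0.9, 0.5, 0.3, 0.2, 0.1, 0.06, 0.03],
    [0.8, 0.6, 0.5, 0.45, 0.44, 0.43, 0.50],
))  # True
```

Note the check stays silent on short histories by design: with six or fewer validation points the baseline equals the last value and the comparison can never fire.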
server/environment.py CHANGED
@@ -294,6 +294,10 @@ class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation,
         is_correct_fix: bool | None = None
         convergence = False
 
+        # Snapshot state BEFORE dispatch — reward engine needs pre-action state
+        # to correctly compute investigation bonuses and context-gated penalties
+        state_before = state.model_copy(deep=True)
+
         try:
             is_correct_fix, convergence = self._dispatch_action(action, session)
         except Exception as exc:
@@ -306,7 +310,7 @@
                 },
                 exc_info=True,
             )
-            reward = compute_reward(action, state, scenario, is_valid_action=False)
+            reward = compute_reward(action, state_before, scenario, is_valid_action=False)
             obs = self._build_observation(session, reward=reward)
             obs.error_log = f"Internal error processing {action_type}: {exc}"
             return obs
@@ -317,10 +321,10 @@
         else:
             state.actions_taken.append(action_type)
 
-        # Compute reward
+        # Compute reward using pre-action state
         reward = compute_reward(
             action,
-            state,
+            state_before,
             scenario,
             is_valid_action=True,
             is_correct_fix=is_correct_fix,
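The snapshot in the hunk above matters because `_dispatch_action` mutates the live episode state (for example, setting `gradients_inspected`) before the reward is computed; scoring the post-action state would make every first inspection look like a repeat and suppress the bonus. A minimal sketch of the bug and the fix, using a stand-in state object and an illustrative `compute_reward` (the constants follow the test suite's expectations, not the repo's actual reward engine):

```python
import copy
from dataclasses import dataclass


@dataclass
class EpisodeState:
    gradients_inspected: bool = False


STEP_PENALTY = -0.01
INSPECTION_BONUS = 0.05


def compute_reward(state):
    # Bonus is paid only when the flag was still False BEFORE the action ran
    bonus = INSPECTION_BONUS if not state.gradients_inspected else 0.0
    return STEP_PENALTY + bonus


state = EpisodeState()
state_before = copy.deepcopy(state)  # snapshot taken before dispatch
state.gradients_inspected = True     # dispatch mutates the live state

buggy = compute_reward(state)         # scored after mutation: bonus lost
fixed = compute_reward(state_before)  # scored from snapshot: bonus paid
print(round(buggy, 2), round(fixed, 2))  # -0.01 0.04
```

The same reasoning applies to any context-gated penalty: the reward engine has to judge the action against the world as it was when the agent chose it, not as the action left it.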
tests/test_episode_lifecycle.py CHANGED
@@ -51,6 +51,31 @@ class TestStepInspections:
         assert len(obs.gradient_stats) > 0
         assert obs.episode_state.gradients_inspected
 
+    def test_inspect_gradients_gives_investigation_bonus(self, env):
+        """First-time inspection must give +0.05 bonus (total +0.04 with step penalty)."""
+        env.reset(seed=42, episode_id="test", task_id="task_001")
+        obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
+        assert obs.reward == pytest.approx(0.04)
+
+    def test_inspect_data_batch_gives_investigation_bonus(self, env):
+        """First-time data inspection must give +0.05 bonus."""
+        env.reset(seed=42, episode_id="test", task_id="task_003")
+        obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
+        assert obs.reward == pytest.approx(0.04)
+
+    def test_inspect_model_modes_gives_investigation_bonus(self, env):
+        """First-time model modes inspection must give +0.05 bonus."""
+        env.reset(seed=42, episode_id="test", task_id="task_005")
+        obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
+        assert obs.reward == pytest.approx(0.04)
+
+    def test_repeat_inspection_no_bonus(self, env):
+        """Second inspection of same type must NOT give bonus."""
+        env.reset(seed=42, episode_id="test", task_id="task_001")
+        env.step(MLTrainingAction(action_type="inspect_gradients"))
+        obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
+        assert obs.reward == pytest.approx(-0.01)
+
     def test_inspect_data_batch(self, env):
         env.reset(seed=42, episode_id="test", task_id="task_003")
         obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
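The expected values in the new tests follow from two constants the suite assumes: a per-step cost of -0.01 and a first-inspection bonus of +0.05. A quick arithmetic check (the constant names here are illustrative, not taken from the repo):

```python
STEP_PENALTY = -0.01
INVESTIGATION_BONUS = 0.05

# First inspection of a given type: bonus plus step cost
first_inspection = STEP_PENALTY + INVESTIGATION_BONUS

# Repeat inspection of the same type: step cost only, no bonus
repeat_inspection = STEP_PENALTY

# Compare with a tolerance, as the tests do via pytest.approx,
# since 0.05 - 0.01 is not exactly 0.04 in binary floating point
assert abs(first_inspection - 0.04) < 1e-9
assert repeat_inspection == -0.01
```

This is also why the tests assert `pytest.approx(0.04)` rather than exact equality.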