omkarrr88 committed
Commit e2f8b29 · 0 Parent(s)
Version 1
.claude/plan/pytorch-debugger-mvp.md ADDED
@@ -0,0 +1,1647 @@
# Implementation Plan: PyTorch Training Run Debugger — OpenEnv Environment

**Generated:** 2026-03-28
**King File:** `ml-training-debugger-spec.md` — single source of truth for all conflicts
**Runtime:** Python 3.12 · PyTorch CPU-only · openenv-core (installed in .venv)
**MVP Scope:** Tasks 1, 3, 5 + rule-based baseline + all required endpoints + Docker + HF Spaces

---

## Markdown Files Confirmed Read

| File | Lines | Role |
|------|-------|------|
| `ml-training-debugger-spec.md` | 1549 | **KING FILE** — final authority on all design decisions |
| `CLAUDE.md` | ~280 | Coding standards, non-negotiable rules, reward constants |
| `PRD.md` | ~368 | Product requirements, success metrics, timeline |
| `ROADMAP.md` | ~442 | Phased roadmap with acceptance criteria |

All four files read in full. The spec is the definitive authority.

---

## Complete Project Structure (Final State)

```
ML Debugger/                         # Project root
├── .claude/
│   └── plan/
│       └── pytorch-debugger-mvp.md  # This plan
├── .dockerignore
├── .gitignore
├── .python-version                  # "3.12"
├── CLAUDE.md                        # Already exists
├── Dockerfile
├── PRD.md                           # Already exists
├── README.md
├── ROADMAP.md                       # Already exists
├── baseline_heuristic.py            # Rule-based baseline (no API key)
├── baseline_inference.py            # LLM baseline (optional, requires OPENAI_API_KEY)
├── deploy.sh                        # One-command build+test+validate script
├── ml-training-debugger-spec.md     # Already exists (king file)
├── openenv.yaml
├── pyproject.toml
├── requirements.txt

├── ml_training_debugger/
│   ├── __init__.py
│   ├── models.py                    # All Pydantic models + RootCauseDiagnosis enum
│   ├── client.py                    # EnvClient extension with typed action/observation
│   ├── scenarios.py                 # ScenarioParams + sample_scenario()
│   ├── pytorch_engine.py            # SimpleCNN, fault injection, gradient/weight extraction
│   ├── simulation.py                # Parametric curve generation (torch.Tensor ops)
│   ├── code_templates.py            # Task 6: code snippets with bugs + validate_fix()
│   ├── reward_engine.py             # compute_reward() — all 7 components
│   └── graders.py                   # Per-task grader functions (0.0–1.0)

├── server/
│   ├── __init__.py
│   ├── environment.py               # MLTrainingEnvironment(Environment)
│   ├── app.py                       # create_app() + custom routes
│   └── dashboard.html               # Live diagnostic dashboard (Phase 3)

├── validation/                      # PyTorch validation suite (Phase 3)
│   ├── requirements.txt
│   ├── conftest.py
│   ├── validate_exploding_gradients.py
│   ├── validate_vanishing_gradients.py
│   ├── validate_data_leakage.py
│   ├── validate_overfitting.py
│   ├── validate_batchnorm_eval.py
│   ├── validate_code_bugs.py
│   └── reports/                     # Pre-computed fidelity plots

└── tests/
    ├── __init__.py
    ├── conftest.py                  # Shared fixtures
    ├── test_models.py
    ├── test_scenarios.py
    ├── test_pytorch_engine.py
    ├── test_simulation.py
    ├── test_code_templates.py
    ├── test_reward_engine.py
    ├── test_graders.py
    ├── test_episode_lifecycle.py
    ├── test_endpoints.py
    └── test_baseline_reproducibility.py
```

---

## Phase 0: Project Initialization & Validation Setup

### Goal
A running skeleton server that proves the toolchain works end-to-end. Zero business logic — just plumbing.

### Files to Create

**Step 0.1 — Project config files:**

1. **`.python-version`** — content: `3.12`

2. **`.gitignore`**:
```
.venv/
__pycache__/
*.pyc
*.pyo
.env
run*.json
.pytest_cache/
htmlcov/
*.egg-info/
dist/
build/
validation/reports/*.png
.mypy_cache/
```

3. **`.dockerignore`**:
```
.venv/
__pycache__/
.git/
.pytest_cache/
tests/
validation/
*.md
!README.md
.claude/
run*.json
htmlcov/
```

4. **`pyproject.toml`**:
```toml
[project]
name = "pytorch-training-debugger"
version = "1.0.0"
description = "OpenEnv RL environment for PyTorch training failure debugging"
requires-python = ">=3.12"
dependencies = [
    "torch",
    "openenv-core",
    "pydantic>=2.0",
    "fastapi",
    "uvicorn",
]

[project.optional-dependencies]
dev = [
    "pytest",
    "pytest-cov",
    "pytest-asyncio",
    "black",
    "ruff",
    "isort",
    "httpx",
    "websockets",
]
llm = [
    "openai",
]

[tool.black]
line-length = 88

[tool.isort]
profile = "black"

[tool.ruff]
line-length = 88
target-version = "py312"

[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"
```

5. **`requirements.txt`** (for Docker — flat list, no dev deps):
```
torch
openenv-core
pydantic>=2.0
fastapi
uvicorn
openai
```

**Step 0.2 — Package stubs:**

6. **`ml_training_debugger/__init__.py`**:
```python
"""PyTorch Training Run Debugger — OpenEnv Environment."""

__version__ = "1.0.0"
```

7. **`ml_training_debugger/models.py`** — STUB with all Pydantic models:
```python
"""All Pydantic models, enums, and typed data structures.

No business logic. Pure data definitions.
"""

from __future__ import annotations

import enum
from typing import Literal, Optional

import torch  # noqa: F401  (spec rule: torch is imported in every core module)
from openenv.core.env_server.types import Action, Observation
from pydantic import BaseModel, Field


class RootCauseDiagnosis(str, enum.Enum):
    """Closed enumeration of ML failure root causes."""
    LR_TOO_HIGH = "lr_too_high"
    VANISHING_GRADIENTS = "vanishing_gradients"
    DATA_LEAKAGE = "data_leakage"
    OVERFITTING = "overfitting"
    BATCHNORM_EVAL_MODE = "batchnorm_eval_mode"
    CODE_BUG = "code_bug"


class TrainingConfig(BaseModel):
    """Typed hyperparameter configuration."""
    learning_rate: float = 0.001
    weight_decay: float = 0.0001
    batch_size: int = 64
    hidden_dim: int = 64
    num_layers: int = 3
    optimizer: str = "adam"
    dropout_rate: float = 0.0
    gradient_clip_norm: Optional[float] = None


class GradientStats(BaseModel):
    """Per-layer gradient information from real torch.autograd."""
    layer_name: str
    norm_history: list[float]
    mean_norm: float
    max_norm: float
    is_exploding: bool
    is_vanishing: bool


class ModelWeightStats(BaseModel):
    """Per-layer weight statistics from real state_dict()."""
    layer_name: str
    weight_norm: float
    weight_mean: float
    weight_std: float
    weight_min: float
    weight_max: float
    dead_neuron_pct: float = 0.0
    has_nan: bool = False
    has_inf: bool = False


class DataBatchStats(BaseModel):
    """Data batch inspection results."""
    label_distribution: dict[int, float]
    feature_mean: float
    feature_std: float
    null_count: int = 0
    class_overlap_score: float
    batch_size: int
    duplicate_ratio: float = 0.0


class CodeSnippet(BaseModel):
    """PyTorch code for Task 6 inspection."""
    code: str
    filename: str = "train.py"
    line_count: int
    imports: list[str]
    hint: Optional[str] = None


class EpisodeState(BaseModel):
    """Tracks agent history within an episode."""
    step_count: int = 0
    gradients_inspected: bool = False
    gradients_were_normal: bool = False
    data_inspected: bool = False
    model_modes_inspected: bool = False
    model_weights_inspected: bool = False
    code_inspected: bool = False
    fix_action_taken: bool = False
    restart_after_fix: bool = False
    diagnosis_submitted: bool = False
    actions_taken: list[str] = Field(default_factory=list)

    def compute_available_actions(self) -> list[str]:
        """Dynamically compute available actions based on current state."""
        actions = [
            "inspect_gradients",
            "inspect_data_batch",
            "inspect_model_modes",
            "inspect_model_weights",
            "inspect_code",
            "modify_config",
            "add_callback",
            "replace_optimizer",
            "patch_data_loader",
            "fix_model_mode",
        ]
        if self.code_inspected:
            actions.append("fix_code")
        if self.fix_action_taken:
            actions.append("restart_run")
        if self.restart_after_fix:
            actions.append("rollback_checkpoint")
        if not self.diagnosis_submitted:
            actions.append("mark_diagnosed")
        return actions


ACTION_TYPES = Literal[
    "inspect_gradients",
    "inspect_data_batch",
    "inspect_model_modes",
    "inspect_model_weights",
    "inspect_code",
    "modify_config",
    "add_callback",
    "replace_optimizer",
    "patch_data_loader",
    "fix_model_mode",
    "fix_code",
    "restart_run",
    "mark_diagnosed",
    "rollback_checkpoint",
]


class MLTrainingAction(Action):
    """What the agent can do — extends openenv Action."""
    action_type: str
    target: Optional[str] = None
    value: Optional[float | int | str] = None
    diagnosis: Optional[str] = None
    line: Optional[int] = None
    replacement: Optional[str] = None


class MLTrainingObservation(Observation):
    """Full observation — extends openenv Observation (has done, reward, metadata)."""
    run_id: str = ""
    framework: str = "pytorch"
    epoch: int = 20
    training_loss_history: list[float] = Field(default_factory=list)
    val_loss_history: list[float] = Field(default_factory=list)
    val_accuracy_history: list[float] = Field(default_factory=list)
    gradient_stats: list[GradientStats] = Field(default_factory=list)
    model_weight_stats: Optional[list[ModelWeightStats]] = None
    gpu_memory_used_gb: float = 6.2
    gpu_memory_total_gb: float = 16.0
    learning_rate: float = 0.001
    current_config: TrainingConfig = Field(default_factory=TrainingConfig)
    error_log: Optional[str] = None
    data_batch_stats: Optional[DataBatchStats] = None
    model_mode_info: Optional[dict[str, str]] = None
    code_snippet: Optional[CodeSnippet] = None
    available_actions: list[str] = Field(default_factory=list)
    episode_state: EpisodeState = Field(default_factory=EpisodeState)
    notes: Optional[str] = None
```

8. **`ml_training_debugger/client.py`** — STUB:
```python
"""Typed EnvClient for baseline scripts."""

from openenv.core.env_client import EnvClient

from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation


class MLTrainingEnvClient(EnvClient[MLTrainingAction, MLTrainingObservation, dict]):
    """Typed client for the PyTorch Training Debugger environment."""

    def _step_payload(self, action: MLTrainingAction) -> dict:
        return action.model_dump(exclude_none=True)

    def _parse_observation(self, data: dict) -> MLTrainingObservation:
        return MLTrainingObservation.model_validate(data)
```

9. **`server/__init__.py`** — empty file

10. **`server/environment.py`** — STUB:
```python
"""MLTrainingEnvironment — extends openenv Environment."""

from typing import Any, Optional

from openenv.core.env_server.interfaces import Environment

from ml_training_debugger.models import (
    EpisodeState,
    MLTrainingAction,
    MLTrainingObservation,
    TrainingConfig,
)


class MLTrainingEnvironment(
    Environment[MLTrainingAction, MLTrainingObservation, dict]
):
    """OpenEnv environment for PyTorch training run debugging."""

    SUPPORTS_CONCURRENT_SESSIONS = True

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> MLTrainingObservation:
        """Reset environment, return initial observation."""
        state = EpisodeState()
        obs = MLTrainingObservation(
            run_id=episode_id or "episode_001",
            training_loss_history=[2.3] * 20,
            val_loss_history=[2.3] * 20,
            val_accuracy_history=[0.1] * 20,
            current_config=TrainingConfig(),
            available_actions=state.compute_available_actions(),
            episode_state=state,
            done=False,
            reward=0.0,
        )
        return obs

    def step(
        self,
        action: MLTrainingAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> MLTrainingObservation:
        """Process one agent action."""
        state = EpisodeState()
        obs = MLTrainingObservation(
            run_id="episode_001",
            training_loss_history=[2.3] * 20,
            val_loss_history=[2.3] * 20,
            val_accuracy_history=[0.1] * 20,
            current_config=TrainingConfig(),
            available_actions=state.compute_available_actions(),
            episode_state=state,
            done=False,
            reward=-0.01,
        )
        return obs

    @property
    def state(self) -> dict:
        """Return current environment state."""
        return {"status": "active"}
```

11. **`server/app.py`** — STUB with all endpoints:
```python
"""FastAPI app — openenv create_app() + custom routes."""

import logging

from fastapi import FastAPI
from openenv.core.env_server.http_server import create_app

from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
from server.environment import MLTrainingEnvironment

logger = logging.getLogger(__name__)

# create_app takes the class (factory), not an instance
app: FastAPI = create_app(
    MLTrainingEnvironment,
    MLTrainingAction,
    MLTrainingObservation,
    env_name="pytorch_training_debugger",
    max_concurrent_envs=5,
)


@app.get("/health")
def health_check() -> dict:
    """Health check — required by hackathon auto-validator."""
    return {"status": "ready", "tasks": 3}


@app.get("/tasks")
def get_tasks() -> list[dict]:
    """Return task list with IDs, difficulties, and action schema."""
    schema = MLTrainingAction.model_json_schema()
    return [
        {"id": "task_001", "difficulty": "easy", "max_steps": 20, "action_schema": schema},
        {"id": "task_003", "difficulty": "medium", "max_steps": 25, "action_schema": schema},
        {"id": "task_005", "difficulty": "hard", "max_steps": 30, "action_schema": schema},
    ]


@app.post("/grader")
def post_grader() -> dict:
    """Return grader score for most recently completed episode."""
    return {"score": None, "error": "no_completed_episode"}


@app.post("/baseline")
async def post_baseline() -> dict:
    """Trigger baseline run, return scores."""
    return {"scores": {"task_001": 0.0, "task_003": 0.0, "task_005": 0.0}}
```

12. **`openenv.yaml`**:
```yaml
spec_version: 1
name: pytorch-training-debugger
type: space
runtime: fastapi
app: server.app:app
port: 7860

# Extended metadata
version: "1.0.0"
description: |
  PyTorch-native fault injection engine for training failure debugging.
  An AI agent investigates, diagnoses, fixes, and verifies broken
  training runs using real torch.nn.Module models, torch.autograd
  gradients, state_dict() weight inspection, and PyTorch code-level
  debugging.
framework: openenv
tags: [ml-debugging, pytorch, reinforcement-learning, root-cause-analysis, fault-injection]

observation_space:
  type: MLTrainingObservation
  description: "Training run snapshot with progressive reveal"

action_space:
  type: MLTrainingAction
  description: "Investigation, fix, and diagnosis actions with dynamic availability"

tasks:
  - id: task_001
    difficulty: easy
    max_steps: 20
  - id: task_003
    difficulty: medium
    max_steps: 25
  - id: task_005
    difficulty: hard
    max_steps: 30

reward:
  range: [-1.0, 1.0]
  shaped: true
  step_penalty: -0.01
  investigation_bonus: 0.05
  correct_diagnosis: 0.50
  terminal_convergence: 0.40

endpoints:
  websocket: "/ws"
  tasks: "GET /tasks"
  grader: "POST /grader"
  baseline: "POST /baseline"
  health: "GET /health"
```

13. **`Dockerfile`**:
```dockerfile
FROM python:3.12-slim

WORKDIR /app

# Install PyTorch CPU-only first (largest layer, cached)
RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu

# Install remaining dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code
COPY ml_training_debugger/ ml_training_debugger/
COPY server/ server/
COPY openenv.yaml .
COPY baseline_heuristic.py .

# Pre-computed validation reports are baked in during Phase 7 once they exist.
# (COPY is not a shell command, so `COPY ... 2>/dev/null || true` is invalid
# Dockerfile syntax; see the Phase 7 note. validation/ is also excluded by
# .dockerignore in this skeleton.)

EXPOSE 7860

CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
```

14. **`tests/__init__.py`** — empty file

15. **`tests/conftest.py`**:
```python
"""Shared test fixtures."""

import pytest

from ml_training_debugger.models import (
    EpisodeState,
    MLTrainingObservation,
    TrainingConfig,
)


@pytest.fixture
def fresh_episode_state() -> EpisodeState:
    return EpisodeState()


@pytest.fixture
def sample_config() -> TrainingConfig:
    return TrainingConfig(learning_rate=0.001)


@pytest.fixture
def sample_observation() -> MLTrainingObservation:
    state = EpisodeState()
    return MLTrainingObservation(
        run_id="test_episode",
        training_loss_history=[2.3 - i * 0.1 for i in range(20)],
        val_loss_history=[2.3 - i * 0.08 for i in range(20)],
        val_accuracy_history=[0.1 + i * 0.04 for i in range(20)],
        current_config=TrainingConfig(),
        available_actions=state.compute_available_actions(),
        episode_state=state,
        done=False,
        reward=0.0,
    )
```

16. **`tests/test_models.py`**:
```python
"""Test all Pydantic models instantiate and serialize correctly."""

import json

from ml_training_debugger.models import (
    CodeSnippet,
    DataBatchStats,
    EpisodeState,
    GradientStats,
    MLTrainingAction,
    MLTrainingObservation,
    ModelWeightStats,
    RootCauseDiagnosis,
    TrainingConfig,
)


class TestRootCauseDiagnosis:
    def test_all_six_values_exist(self):
        assert len(RootCauseDiagnosis) == 6

    def test_values_are_strings(self):
        for d in RootCauseDiagnosis:
            assert isinstance(d.value, str)


class TestTrainingConfig:
    def test_default_instantiation(self):
        config = TrainingConfig()
        assert config.learning_rate == 0.001

    def test_json_roundtrip(self):
        config = TrainingConfig(learning_rate=0.01)
        data = json.loads(config.model_dump_json())
        restored = TrainingConfig.model_validate(data)
        assert restored.learning_rate == 0.01


class TestEpisodeState:
    def test_fresh_state(self):
        state = EpisodeState()
        assert state.step_count == 0
        assert not state.gradients_inspected
        assert not state.diagnosis_submitted

    def test_available_actions_initial(self):
        state = EpisodeState()
        actions = state.compute_available_actions()
        assert "inspect_gradients" in actions
        assert "mark_diagnosed" in actions
        assert "fix_code" not in actions
        assert "restart_run" not in actions

    def test_fix_code_available_after_code_inspected(self):
        state = EpisodeState(code_inspected=True)
        actions = state.compute_available_actions()
        assert "fix_code" in actions

    def test_restart_run_available_after_fix(self):
        state = EpisodeState(fix_action_taken=True)
        actions = state.compute_available_actions()
        assert "restart_run" in actions

    def test_mark_diagnosed_disappears_after_submission(self):
        state = EpisodeState(diagnosis_submitted=True)
        actions = state.compute_available_actions()
        assert "mark_diagnosed" not in actions


class TestMLTrainingObservation:
    def test_extends_observation(self):
        from openenv.core.env_server.types import Observation
        assert issubclass(MLTrainingObservation, Observation)

    def test_has_done_and_reward(self):
        obs = MLTrainingObservation(done=True, reward=0.5)
        assert obs.done is True
        assert obs.reward == 0.5

    def test_json_serialization(self):
        obs = MLTrainingObservation(
            run_id="test",
            training_loss_history=[1.0, 2.0],
            val_accuracy_history=[0.5],
        )
        data = json.loads(obs.model_dump_json())
        assert data["run_id"] == "test"


class TestMLTrainingAction:
    def test_extends_action(self):
        from openenv.core.env_server.types import Action
        assert issubclass(MLTrainingAction, Action)

    def test_basic_action(self):
        action = MLTrainingAction(action_type="inspect_gradients")
        assert action.action_type == "inspect_gradients"

    def test_modify_config_action(self):
        action = MLTrainingAction(
            action_type="modify_config",
            target="learning_rate",
            value=0.001,
        )
        assert action.target == "learning_rate"

    def test_mark_diagnosed_action(self):
        action = MLTrainingAction(
            action_type="mark_diagnosed",
            diagnosis="lr_too_high",
        )
        assert action.diagnosis == "lr_too_high"

    def test_fix_code_action(self):
        action = MLTrainingAction(
            action_type="fix_code",
            line=13,
            replacement="loss = criterion(output, batch_y)",
        )
        assert action.line == 13
```

**Step 0.3 — Validation Commands:**

```bash
# In project root with venv activated
source .venv/bin/activate

# 1. Verify imports
python -c "from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation; print('models OK')"
python -c "from ml_training_debugger.client import MLTrainingEnvClient; print('client OK')"
python -c "from server.app import app; print('app OK')"

# 2. Run tests
pytest tests/test_models.py -v

# 3. Start server
uvicorn server.app:app --host 0.0.0.0 --port 7860 &
sleep 3
curl http://localhost:7860/health
curl http://localhost:7860/tasks
curl http://localhost:7860/docs
kill %1

# 4. Formatting
black ml_training_debugger/ server/ tests/ --check
ruff check ml_training_debugger/ server/ tests/
isort ml_training_debugger/ server/ tests/ --check --profile black
```

### Acceptance Criteria — Phase 0

- [ ] All Pydantic models instantiate without error and serialize to valid JSON
- [ ] `MLTrainingObservation` extends `Observation` (has `done`, `reward`, `metadata`)
- [ ] `MLTrainingAction` extends `Action` (has `metadata`)
- [ ] `EpisodeState.compute_available_actions()` returns correct dynamic action lists
- [ ] Server starts on port 7860 and responds to `/health` with `{"status": "ready", "tasks": 3}`
- [ ] `/tasks` returns 3 tasks with action schema
- [ ] `pytest tests/test_models.py` passes all tests
- [ ] `client.py` imports without error
- [ ] `black --check`, `ruff check`, `isort --check` all pass

---

## Phase 1: Core Data Models & Pydantic Types

### Goal
Finalize all model fields to match the spec exactly. No business logic yet — just data shapes.

### Files to Edit

**`ml_training_debugger/models.py`** — Already created in Phase 0. Verify:
- All fields match spec Section 10 exactly
- `GradientStats.is_exploding` threshold: `mean_norm > 10.0`
- `GradientStats.is_vanishing` threshold: `mean_norm < 1e-6` (both thresholds sketched after this list)
- `TrainingConfig` field names match `modify_config` target options
- `EpisodeState.compute_available_actions()` logic matches spec Section 10 dynamic rules

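The two thresholds reduce to a small constructor helper; a sketch (`build_gradient_stats` is a hypothetical name, and the real construction lives in `pytorch_engine.py`):

```python
from ml_training_debugger.models import GradientStats

EXPLODING_THRESHOLD = 10.0  # spec: is_exploding when mean_norm > 10.0
VANISHING_THRESHOLD = 1e-6  # spec: is_vanishing when mean_norm < 1e-6


def build_gradient_stats(layer_name: str, norm_history: list[float]) -> GradientStats:
    """Classify a layer's gradient norm history against the spec thresholds."""
    mean_norm = sum(norm_history) / len(norm_history)
    return GradientStats(
        layer_name=layer_name,
        norm_history=norm_history,
        mean_norm=mean_norm,
        max_norm=max(norm_history),
        is_exploding=mean_norm > EXPLODING_THRESHOLD,
        is_vanishing=mean_norm < VANISHING_THRESHOLD,
    )
```
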
### Tests (write BEFORE implementation — TDD)

All tests already written in `tests/test_models.py` from Phase 0. Extend with:

```python
class TestGradientStats:
    def test_exploding_threshold(self):
        stats = GradientStats(
            layer_name="fc", norm_history=[15.0], mean_norm=15.0, max_norm=15.0,
            is_exploding=True, is_vanishing=False,
        )
        assert stats.is_exploding is True

    def test_vanishing_threshold(self):
        stats = GradientStats(
            layer_name="conv1", norm_history=[1e-7], mean_norm=1e-7, max_norm=1e-7,
            is_exploding=False, is_vanishing=True,
        )
        assert stats.is_vanishing is True

    def test_normal_gradients(self):
        stats = GradientStats(
            layer_name="conv1", norm_history=[0.5], mean_norm=0.5, max_norm=0.5,
            is_exploding=False, is_vanishing=False,
        )
        assert not stats.is_exploding
        assert not stats.is_vanishing
```

### Acceptance Criteria — Phase 1

- [ ] Every field in every model matches the spec Section 10 types exactly
- [ ] No `Dict[str, Any]` in any public model (typed Pydantic everywhere)
- [ ] `import torch` appears in `models.py`
- [ ] All model tests pass

---

## Phase 2: PyTorch-Native Fault Injection Engine + Simulation

### Goal
Real PyTorch models with real gradients + parametric curve generators. This is the technical heart.

### Files to Create

**Step 2.1 — `ml_training_debugger/scenarios.py`** (~120 lines):

```python
"""ScenarioParams and scenario sampling."""

from __future__ import annotations

import dataclasses
from typing import Any, Optional

import torch

from ml_training_debugger.models import RootCauseDiagnosis


@dataclasses.dataclass(frozen=True)
class ScenarioParams:
    """Internal scenario parameters — not exposed to agent."""
    task_id: str
    root_cause: RootCauseDiagnosis
    seed: int
    learning_rate: float = 0.001
    weight_decay: float = 0.0001
    leakage_pct: float = 0.0
    depth_multiplier: float = 1.0
    divergence_epoch: int = 5
    red_herring_intensity: float = 1.0
    red_herring_spike_layer: str = "fc"
    bug_type: Optional[str] = None
    notes: Optional[str] = None
    error_log: Optional[str] = None
    gpu_memory_used_gb: float = 6.2
    max_steps: int = 20


def sample_scenario(task_id: str, seed: int) -> ScenarioParams:
    """Sample a ScenarioParams for the given task."""
    rng = torch.Generator()
    rng.manual_seed(seed)

    # Use torch for random selection
    def choose(options: list) -> Any:
        idx = int(torch.randint(0, len(options), (1,), generator=rng).item())
        return options[idx]

    if task_id == "task_001":
        lr = choose([0.05, 0.08, 0.10, 0.15, 0.30])
        return ScenarioParams(
            task_id=task_id,
            root_cause=RootCauseDiagnosis.LR_TOO_HIGH,
            seed=seed,
            learning_rate=lr,
            error_log=f"RuntimeError: Loss is NaN at epoch 12 (lr={lr})",
            max_steps=20,
        )

    elif task_id == "task_003":
        leakage = choose([0.12, 0.18, 0.22, 0.28])
        return ScenarioParams(
            task_id=task_id,
            root_cause=RootCauseDiagnosis.DATA_LEAKAGE,
            seed=seed,
            leakage_pct=leakage,
            notes=(
                "Model architecture upgraded from 2-layer to 4-layer CNN at "
                "epoch 2. Performance improvement may reflect increased model "
                "capacity."
            ),
            max_steps=25,
        )

    elif task_id == "task_005":
        intensity = torch.empty(1).uniform_(0.8, 2.5, generator=rng).item()
        spike_layer = choose(["fc", "conv1"])
        return ScenarioParams(
            task_id=task_id,
            root_cause=RootCauseDiagnosis.BATCHNORM_EVAL_MODE,
            seed=seed,
            red_herring_intensity=intensity,
            red_herring_spike_layer=spike_layer,
            gpu_memory_used_gb=14.56,  # 91% of 16GB
            error_log=(
                "Warning: GPU memory pressure detected, consider reducing "
                "batch size or enabling gradient checkpointing"
            ),
            max_steps=30,
        )

    raise ValueError(f"Unknown task_id: {task_id}")
```

**Step 2.2 — `ml_training_debugger/pytorch_engine.py`** (~250 lines):

Key functions:
- `SimpleCNN(torch.nn.Module)` — 3-layer CNN, ~50K params
- `create_model_and_inject_fault(scenario: ScenarioParams) -> tuple[torch.nn.Module, dict]`
- `extract_gradient_stats(model: torch.nn.Module) -> list[GradientStats]`
- `extract_weight_stats(model: torch.nn.Module) -> list[ModelWeightStats]`
- `extract_model_modes(model: torch.nn.Module) -> dict[str, str]`

Implementation notes (see the sketch after this list):
- `torch.manual_seed(scenario.seed)` at the start of every call
- For Task 1: set lr high, run 2 forward+backward passes → gradients explode
- For Task 3: normal model, no gradient anomaly
- For Task 5: call `model.eval()` before training → BatchNorm frozen
- All gradient stats come from real `param.grad` tensors
- All weight stats come from real `model.state_dict()`

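A sketch of the extraction path (a single backward pass shown for brevity; the real version accumulates `norm_history` across the two passes noted above):

```python
import torch

from ml_training_debugger.models import GradientStats


def extract_gradient_stats(model: torch.nn.Module) -> list[GradientStats]:
    """Read real gradient norms from param.grad after loss.backward()."""
    stats: list[GradientStats] = []
    for name, param in model.named_parameters():
        if param.grad is None:  # parameter untouched by the backward pass
            continue
        norm = torch.norm(param.grad).item()
        stats.append(
            GradientStats(
                layer_name=name,
                norm_history=[norm],       # real version: one entry per pass
                mean_norm=norm,
                max_norm=norm,
                is_exploding=norm > 10.0,  # spec Section 10 threshold
                is_vanishing=norm < 1e-6,  # spec Section 10 threshold
            )
        )
    return stats
```
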
**Step 2.3 — `ml_training_debugger/simulation.py`** (~180 lines):

Key functions:
- `gen_loss_history(scenario: ScenarioParams) -> list[float]` — all torch.Tensor ops
- `gen_val_accuracy_history(scenario: ScenarioParams) -> list[float]`
- `gen_val_loss_history(scenario: ScenarioParams) -> list[float]`

Per-task parametric curves from spec Section 6 (Task 1 sketched after this list):
- Task 1: `loss = torch.exp(torch.tensor(lr) * torch.arange(20))`
- Task 3: `val_acc = torch.sigmoid(torch.linspace(-3, 3, 20)) * (1 - leakage_pct)`
- Task 5: Normal loss + elevated variance, slow val_acc degradation

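For Task 1 the generator reduces to a few tensor ops; a sketch (Tasks 3 and 5 follow the same pattern with their own formulas):

```python
import torch

from ml_training_debugger.scenarios import ScenarioParams


def gen_loss_history(scenario: ScenarioParams, epochs: int = 20) -> list[float]:
    """Task 1 parametric curve: loss diverges exponentially with lr (spec Section 6)."""
    t = torch.arange(epochs, dtype=torch.float32)
    loss = torch.exp(torch.tensor(scenario.learning_rate) * t)
    return [float(v) for v in loss]  # plain floats for JSON serialization
```
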
### Tests to Create FIRST (TDD)

**`tests/test_scenarios.py`**:
- `sample_scenario("task_001", seed=42)` returns `root_cause == LR_TOO_HIGH`
- `sample_scenario("task_003", seed=42)` returns `root_cause == DATA_LEAKAGE`
- `sample_scenario("task_005", seed=42)` returns `root_cause == BATCHNORM_EVAL_MODE`
- Different seeds produce different parameters (but same root cause per task)
- Unknown task_id raises ValueError

**`tests/test_pytorch_engine.py`**:
- `SimpleCNN` is a real `torch.nn.Module` with ~50K params
- Task 1 fault injection: `is_exploding=True` on all layers
- Task 5 fault injection: `is_exploding=False` on all layers, `model.training==False`
- `extract_gradient_stats` returns `list[GradientStats]` with real float norms
- `extract_weight_stats` returns `list[ModelWeightStats]` from real state_dict
- `extract_model_modes` returns dict mapping layer names to "train"/"eval"
- **CRITICAL**: `import torch` in pytorch_engine.py, zero `import numpy`

**`tests/test_simulation.py`**:
- All outputs are `list[float]` of length 20
- Task 1 (exploding): loss diverges (last value >> first value)
- Task 3 (leakage): val_acc suspiciously high from early epochs
- Task 5 (batchnorm): slow val_acc degradation (~1-2% per epoch)
- All computation uses torch (no numpy)

### Acceptance Criteria — Phase 2

- [ ] `SimpleCNN` is a real `torch.nn.Module` with ~50K parameters
- [ ] `create_model_and_inject_fault` for Task 1 produces exploding gradients (`is_exploding=True` all layers)
- [ ] `create_model_and_inject_fault` for Task 5 produces `model.training==False` on all layers
- [ ] `extract_gradient_stats` returns real floats from `torch.norm(param.grad)`
- [ ] `extract_weight_stats` returns real floats from `state_dict()`
- [ ] Parametric curves produce 20-element lists with correct shapes per task
- [ ] `import torch` in `pytorch_engine.py` and `simulation.py` — zero `import numpy`
- [ ] `torch.manual_seed(seed)` ensures reproducibility
- [ ] All Phase 2 tests pass

---

## Phase 3: MVP Tasks (1, 3, 5) + Reward Engine + Graders

### Goal
All reward logic and graders implemented. The environment can score episodes.

### Files to Create

**Step 3.1 — `ml_training_debugger/reward_engine.py`** (~100 lines):

```python
def compute_reward(
    action: MLTrainingAction,
    episode_state: EpisodeState,
    scenario: ScenarioParams,
    is_valid_action: bool,
    is_correct_fix: bool | None = None,
    convergence_confirmed: bool = False,
) -> float:
    ...
```

All 7 components per spec Section 12:
1. Step penalty: -0.01 (flat, unconditional)
2. Investigation bonus: +0.05 (first-time per type)
3. Context-gated penalty: -0.20 (ONLY when `gradients_inspected AND gradients_were_normal`)
4. Invalid action: -0.05
5. Wrong code fix: -0.10
6. Correct diagnosis: +0.50 / Wrong diagnosis: -0.30
7. Terminal convergence: +0.40 (gated on `fix_action_taken AND restart_after_fix`)

Hard cap at [-1.0, 1.0]. A composition sketch follows.

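A sketch of how the seven components compose. The exact action set gated by component 3 and the award point for component 7 must be confirmed against spec Section 12; they are shown here for `add_callback` and `mark_diagnosed` respectively, which matches the tests below:

```python
from ml_training_debugger.models import EpisodeState, MLTrainingAction
from ml_training_debugger.scenarios import ScenarioParams

_INVESTIGATE_FLAGS = {
    "inspect_gradients": "gradients_inspected",
    "inspect_data_batch": "data_inspected",
    "inspect_model_modes": "model_modes_inspected",
    "inspect_model_weights": "model_weights_inspected",
    "inspect_code": "code_inspected",
}


def compute_reward(
    action: MLTrainingAction,
    episode_state: EpisodeState,
    scenario: ScenarioParams,
    is_valid_action: bool,
    is_correct_fix: bool | None = None,
    convergence_confirmed: bool = False,
) -> float:
    reward = -0.01  # 1. flat step penalty, never scaled by step_count
    if not is_valid_action:
        return max(reward - 0.05, -1.0)  # 4. invalid action

    # 2. investigation bonus, first time per type only (assumes this runs
    #    before the step handler flips the inspection flag)
    flag = _INVESTIGATE_FLAGS.get(action.action_type)
    if flag is not None and not getattr(episode_state, flag):
        reward += 0.05

    # 3. context-gated penalty: ONLY after gradients were inspected AND normal
    if (
        action.action_type == "add_callback"
        and episode_state.gradients_inspected
        and episode_state.gradients_were_normal
    ):
        reward -= 0.20

    if action.action_type == "fix_code" and is_correct_fix is False:
        reward -= 0.10  # 5. wrong code fix

    if action.action_type == "mark_diagnosed":
        correct = action.diagnosis == scenario.root_cause.value
        reward += 0.50 if correct else -0.30  # 6. diagnosis
        # 7. terminal convergence, gated on fix + restart
        if (
            convergence_confirmed
            and episode_state.fix_action_taken
            and episode_state.restart_after_fix
        ):
            reward += 0.40

    return max(min(reward, 1.0), -1.0)  # hard cap
```
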
**Step 3.2 — `ml_training_debugger/graders.py`** (~150 lines):

One function per task. Each returns float in [0.0, 1.0]:
- `grade_task_001(state: EpisodeState, scenario: ScenarioParams) -> float`
- `grade_task_003(state: EpisodeState, scenario: ScenarioParams) -> float`
- `grade_task_005(state: EpisodeState, scenario: ScenarioParams) -> float`

Grader scoring per spec Section 11:
- Task 1: inspect_gradients (+0.05), correct LR fix (+0.20), restart+converge (+0.35), correct diagnosis (+0.40) = 1.0
- Task 3: inspect_data (+0.05), patch_data_loader (+0.30), restart+converge (+0.30), correct diagnosis (+0.35) = 1.0
- Task 5: inspect_gradients (+0.05), inspect_model_modes (+0.05), fix_model_mode (+0.25), restart+converge (+0.30), correct diagnosis (+0.40) = 1.05 → capped at 1.0. Penalty: add_callback after normal gradients = -0.20.

**CRITICAL — Grader is NOT a sum of step rewards.** It evaluates EpisodeState holistically, as sketched below.

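A sketch of the Task 1 grader's shape. Note that `EpisodeState` as defined in Phase 0 records *that* a diagnosis was submitted but not *which* one, nor the fix value; the real graders need those recorded (flagged in comments):

```python
from ml_training_debugger.models import EpisodeState
from ml_training_debugger.scenarios import ScenarioParams


def grade_task_001(state: EpisodeState, scenario: ScenarioParams) -> float:
    """Holistic Task 1 grade (spec Section 11), not a sum of step rewards."""
    score = 0.0
    if state.gradients_inspected:
        score += 0.05  # inspected the evidence
    if state.fix_action_taken:
        score += 0.20  # real version: check the fix actually lowered the lr
    if state.restart_after_fix:
        score += 0.35  # real version: also require convergence on restart
    if state.diagnosis_submitted:
        # real version: compare the recorded diagnosis to
        # scenario.root_cause (LR_TOO_HIGH for this task)
        score += 0.40
    return min(score, 1.0)
```
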
### Tests to Create FIRST (TDD)

**`tests/test_reward_engine.py`** — THE MOST CRITICAL TEST FILE:

```python
import pytest

from ml_training_debugger.models import EpisodeState, MLTrainingAction
from ml_training_debugger.reward_engine import compute_reward
from ml_training_debugger.scenarios import sample_scenario

# Any scenario works here; the gated penalty depends only on EpisodeState.
scenario = sample_scenario("task_005", seed=42)


class TestContextGatedPenalty:
    """The project's primary innovation — must be exact."""

    def test_no_penalty_before_inspection(self):
        """add_callback at step 1 (no prior inspection) -> NO penalty."""
        state = EpisodeState()  # gradients_inspected=False
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario, is_valid_action=True)
        # Should be just step penalty: -0.01
        assert reward == pytest.approx(-0.01)

    def test_penalty_after_normal_gradients(self):
        """inspect_gradients (normal) then add_callback -> -0.20 penalty."""
        state = EpisodeState(gradients_inspected=True, gradients_were_normal=True)
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario, is_valid_action=True)
        # Step penalty + context-gated penalty: -0.01 + -0.20 = -0.21
        assert reward == pytest.approx(-0.21)

    def test_no_penalty_after_abnormal_gradients(self):
        """inspect_gradients (exploding) then add_callback -> no context penalty."""
        state = EpisodeState(gradients_inspected=True, gradients_were_normal=False)
        action = MLTrainingAction(action_type="add_callback")
        reward = compute_reward(action, state, scenario, is_valid_action=True)
        assert reward == pytest.approx(-0.01)
```

Also test:
- Step penalty is flat -0.01 (NOT multiplied by step_count)
- Investigation bonus +0.05 first-time only
- Investigation bonus NOT awarded on repeat
- Correct diagnosis: +0.50
- Wrong diagnosis: -0.30
- Terminal convergence: +0.40 when all gates met
- Invalid action: -0.05
- Wrong code fix: -0.10
- Reward capped at [-1.0, 1.0]

**`tests/test_graders.py`**:
- Each grader returns float in [0.0, 1.0]
- Perfect Task 1 path scores 1.0
- Wrong diagnosis on Task 1 scores < 0.5
- Task 5: agent that chases red herring scores 0.80-0.85
- Task 5: optimal path scores 1.0
- Grader is deterministic (same state → same score)

### Acceptance Criteria — Phase 3

- [ ] `compute_reward` implements all 7 components exactly per spec Section 12
- [ ] Context-gated penalty fires ONLY when `gradients_inspected=True AND gradients_were_normal=True`
- [ ] Context-gated penalty does NOT fire before `inspect_gradients` has been called
- [ ] Step penalty is flat -0.01 (never multiplied by step_count)
- [ ] All 3 graders return [0.0, 1.0] with meaningful variance
- [ ] Grader != reward function (separate modules, separate logic)
- [ ] All Phase 3 tests pass

---

## Phase 4: Environment Lifecycle, EpisodeState, and Action Handling

### Goal
Full `reset()` and `step()` implementations in `environment.py`. The environment is functionally complete.

### Files to Edit

**`server/environment.py`** — Full implementation:

`reset(task_id)`:
1. Parse `task_id` from `kwargs` (framework passes it via kwargs or episode_id)
2. Derive deterministic seed from task_id (see the sketch after this list)
3. Call `sample_scenario(task_id, seed)`
4. Call `torch.manual_seed(scenario.seed)`
5. Call `create_model_and_inject_fault(scenario)` → get real model
6. Generate parametric curves via `simulation.py`
7. Create fresh `EpisodeState`
8. Store `(scenario, model, state)` keyed by session/episode ID
9. Return `MLTrainingObservation` with populated loss/acc histories, config, error_log, available_actions — but empty gradient_stats, null data_batch_stats, null model_mode_info, null code_snippet

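Step 2 should not use Python's built-in `hash()`, which is salted per process and would break the bit-exact reproducibility Phase 6 demands. A sketch of a stable derivation (`seed_for_task` is a hypothetical name):

```python
import hashlib


def seed_for_task(task_id: str) -> int:
    """Deterministic 32-bit seed per task, stable across processes and runs."""
    digest = hashlib.sha256(task_id.encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big")
```
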
`step(action)`:
1. Validate action (see spec Section 16 error handling matrix)
2. Increment `step_count`
3. Dispatch by `action.action_type`:
   - **`inspect_gradients`**: Extract real gradient stats, set `gradients_inspected=True`, compute `gradients_were_normal` (all layers `is_exploding==False`)
   - **`inspect_data_batch`**: Generate data batch stats, set `data_inspected=True`
   - **`inspect_model_modes`**: Extract model modes, set `model_modes_inspected=True`
   - **`inspect_model_weights`**: Extract real weight stats, set `model_weights_inspected=True`
   - **`inspect_code`**: Generate code snippet (if task supports it), set `code_inspected=True`
   - **`modify_config`**: Validate target/value, apply change, set `fix_action_taken=True`
   - **`add_callback`**: Apply callback, set `fix_action_taken=True`
   - **`replace_optimizer`**: Apply, set `fix_action_taken=True`
   - **`patch_data_loader`**: Apply, set `fix_action_taken=True`
   - **`fix_model_mode`**: Apply, set `fix_action_taken=True`
   - **`fix_code`**: Validate fix via `validate_fix()`, set `fix_action_taken=True`
   - **`restart_run`**: Requires `fix_action_taken`, set `restart_after_fix=True`, check convergence
   - **`mark_diagnosed`**: Set `diagnosis_submitted=True`, `done=True`
   - **`rollback_checkpoint`**: Requires `restart_after_fix`
4. Call `compute_reward(action, state, scenario, ...)`
5. Check step limit → set `done=True` if reached
6. Update `available_actions` via `state.compute_available_actions()`
7. Return `MLTrainingObservation` with all updated fields

**Session isolation** (handler sketched below):
- Store per-session state in `self._sessions: dict[str, SessionData]`
- Session ID comes from the framework (via `episode_id` or WebSocket session)
- Clean up on episode completion or disconnect

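A sketch of the dispatch branch that arms the context-gated penalty, assuming `SessionData` is the container bundling `(scenario, model, state)` described above:

```python
from ml_training_debugger.models import GradientStats
from ml_training_debugger.pytorch_engine import extract_gradient_stats


def handle_inspect_gradients(session: "SessionData") -> list[GradientStats]:
    """inspect_gradients branch of step(); a method in the real environment."""
    stats = extract_gradient_stats(session.model)  # real torch.autograd norms
    session.state.gradients_inspected = True
    # "Normal" means no layer is exploding; this flag is what later gates the
    # -0.20 red-herring penalty inside compute_reward().
    session.state.gradients_were_normal = all(not s.is_exploding for s in stats)
    return stats
```
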
### Error Handling (spec Section 16 — ALL cases)

| Error | Behavior | Reward |
|-------|----------|--------|
| Invalid action_type | Return obs unchanged + error note | -0.05 |
| Action not in available_actions | Return obs unchanged + error note | -0.05 |
| modify_config missing target/value | Return obs unchanged + error note | -0.05 |
| modify_config with unknown target | Return obs unchanged + error note | -0.05 |
| mark_diagnosed missing diagnosis | Return obs unchanged + error note | -0.05 |
| mark_diagnosed with invalid diagnosis | Return obs unchanged + error note | -0.05 |
| fix_code missing line/replacement | Return obs unchanged + error note | -0.05 |
| Action after done=True | Return final obs, no state change | 0.0 |
| Step limit reached | Set done=True, return obs | 0.0 |

**CRITICAL**: `step()` must NEVER raise an unhandled exception.

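One way to hold that guarantee in a single place is a boundary guard; a sketch (`_step_inner` and `_current_observation` are hypothetical internals):

```python
import logging

logger = logging.getLogger(__name__)


def step(self, action, timeout_s=None, **kwargs):
    """Outer guard: any unexpected failure becomes an error observation."""
    try:
        return self._step_inner(action)
    except Exception:  # noqa: BLE001  (deliberate catch-all at the API boundary)
        logger.exception("step() failed; returning error observation")
        obs = self._current_observation()  # last known-good observation
        obs.notes = "internal_error: action ignored"
        obs.reward = -0.05  # treated like an invalid action
        return obs
```
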
### Tests to Create FIRST (TDD)

**`tests/test_episode_lifecycle.py`**:
- Full reset→inspect→fix→restart→diagnose flow for Task 1
- Full flow for Task 3
- Full flow for Task 5
- `available_actions` updates correctly at each step
- `done=True` after `mark_diagnosed`
- Step limit triggers `done=True`
- Action after done returns final obs with no state change
- Invalid action returns -0.05 penalty
- `restart_run` not available before `fix_action_taken`
- `fix_code` not available before `code_inspected`
- Session isolation: two episodes don't interfere

### Acceptance Criteria — Phase 4

- [ ] `reset(task_id)` for tasks 001/003/005 returns valid `MLTrainingObservation` with correct initial state
- [ ] `step()` dispatches all 14 action types correctly
- [ ] Task 1: `inspect_gradients` → `is_exploding=True` all layers (real torch.autograd)
- [ ] Task 5: `inspect_gradients` → `is_exploding=False` all layers, `gradients_were_normal=True`
- [ ] Task 3: `inspect_data_batch` → `class_overlap_score > 0.5`
- [ ] Task 5: `inspect_model_modes` → all layers in "eval" mode
- [ ] All error conditions from spec Section 16 handled (never raises)
- [ ] Progressive information reveal works (gradient_stats empty until inspected)
- [ ] All Phase 4 tests pass

---

## Phase 5: Server (FastAPI + openenv-core) + All Required Endpoints

### Goal
Wire the real environment into the server. All hackathon-required endpoints return real data.

### Files to Edit

**`server/app.py`** — Full implementation:

```python
# Store reference to last completed episode for /grader
_last_completed: dict[str, dict] = {}  # session_id -> {score, task_id, steps}
_baseline_running: bool = False


@app.get("/health")
def health_check():
    return {"status": "ready", "tasks": 3}


@app.get("/tasks")
def get_tasks():
    schema = MLTrainingAction.model_json_schema()
    return [
        {"id": "task_001", "difficulty": "easy", "max_steps": 20, "action_schema": schema},
        {"id": "task_003", "difficulty": "medium", "max_steps": 25, "action_schema": schema},
        {"id": "task_005", "difficulty": "hard", "max_steps": 30, "action_schema": schema},
    ]


@app.post("/grader")
def post_grader(session_id: str | None = None):
    # Return score for most recently completed episode
    # Edge cases per spec Section 14
    ...


@app.post("/baseline")
async def post_baseline():
    # Run baseline_heuristic logic internally
    # Return {"scores": {"task_001": float, ...}}
    # Return 409 if already running
    ...
```

**Grader endpoint edge cases** (spec Section 14, sketched below):
- No episode completed → `{"score": null, "error": "no_completed_episode"}`
- Episode in progress → `{"score": null, "error": "episode_in_progress"}`
- Episode completed → `{"score": 0.85, "task_id": "task_003", "steps": 6}`
- Always HTTP 200 with JSON body

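A sketch of the edge-case handling, collapsing the session map above to a single record for clarity (`_last_record` and `_episode_in_progress` are assumed bookkeeping, updated from `reset()` and from `step()` when an episode finishes):

```python
_last_record: dict | None = None    # {"score": ..., "task_id": ..., "steps": ...}
_episode_in_progress: bool = False  # set in reset(), cleared when done=True


@app.post("/grader")
def post_grader() -> dict:
    """Spec Section 14: always HTTP 200 with a JSON body, never an exception."""
    if _episode_in_progress:
        return {"score": None, "error": "episode_in_progress"}
    if _last_record is None:
        return {"score": None, "error": "no_completed_episode"}
    return {
        "score": _last_record["score"],      # e.g. 0.85
        "task_id": _last_record["task_id"],  # e.g. "task_003"
        "steps": _last_record["steps"],      # e.g. 6
    }
```
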
### Tests to Create FIRST (TDD)

**`tests/test_endpoints.py`**:
- `GET /health` returns `{"status": "ready", "tasks": 3}` with 200
- `GET /tasks` returns 3 tasks with action schema
- `POST /grader` returns `{"score": null, "error": "no_completed_episode"}` initially
- `POST /baseline` returns scores for all tasks
- `POST /baseline` while running returns 409
- Integration: reset→step→grader returns valid score

### Acceptance Criteria — Phase 5

- [ ] `GET /health` returns `{"status": "ready", "tasks": 3}` (200)
- [ ] `GET /tasks` returns 3 tasks with IDs, difficulties, action schema
- [ ] `POST /grader` handles all edge cases per spec Section 14
- [ ] `POST /baseline` runs baseline and returns scores
- [ ] Framework auto-provides: `/reset`, `/step`, `/state`, `/ws`, `/schema`, `/docs`
- [ ] All Phase 5 tests pass

---

## Phase 6: Rule-Based Baseline + Reproducibility Guarantees

### Goal
Deterministic baseline that produces bit-exact identical scores on two runs.

### Files to Create

**`baseline_heuristic.py`** (~150 lines):

Decision tree from spec Section 17:
```
1. reset(task_id)
2. inspect_gradients
3. IF any layer is_exploding → modify_config(lr=0.001) → restart → diagnose lr_too_high
4. IF any layer is_vanishing → modify_config(lr=0.01) → restart → diagnose vanishing_gradients
5. inspect_data_batch
6. IF class_overlap_score > 0.5 → patch_data_loader → restart → diagnose data_leakage
7. IF val_loss diverging → modify_config(weight_decay=0.01) → restart → diagnose overfitting
8. inspect_model_modes → IF any eval → fix_model_mode → restart → diagnose batchnorm_eval_mode
9. inspect_code → attempt fix → restart → diagnose code_bug
10. FALLBACK: diagnose overfitting
```

Uses `MLTrainingEnvClient` or `GenericEnvClient` to connect via WebSocket (a sketch follows).

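A sketch of the per-task loop, assuming the typed client's `reset()`/`step()` call shapes from Phase 0 (only branch 3 of the tree is spelled out; the rest follow the same pattern):

```python
from ml_training_debugger.client import MLTrainingEnvClient
from ml_training_debugger.models import MLTrainingAction


def run_task(client: MLTrainingEnvClient, task_id: str) -> dict:
    obs = client.reset(task_id=task_id)                                   # 1.
    obs = client.step(MLTrainingAction(action_type="inspect_gradients"))  # 2.
    if any(g.is_exploding for g in obs.gradient_stats):                   # 3.
        client.step(MLTrainingAction(
            action_type="modify_config", target="learning_rate", value=0.001))
        client.step(MLTrainingAction(action_type="restart_run"))
        obs = client.step(MLTrainingAction(
            action_type="mark_diagnosed", diagnosis="lr_too_high"))
        return {"task_id": task_id, "reward": obs.reward, "done": obs.done}
    # ... branches 4-9 elided ...
    obs = client.step(MLTrainingAction(
        action_type="mark_diagnosed", diagnosis="overfitting"))           # 10.
    return {"task_id": task_id, "reward": obs.reward, "done": obs.done}
```
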
1311
+ **Reproducibility requirements:**
1312
+ - `torch.manual_seed(seed)` at every `reset()` with deterministic seed per task
1313
+ - No floating-point non-determinism in parametric curves
1314
+ - Heuristic is pure logic with no randomness
1315
+ - Two runs must produce identical JSON output
1316
+
1317
+ ### Tests to Create FIRST (TDD)
1318
+
1319
+ **`tests/test_baseline_reproducibility.py`**:
1320
+ - Run baseline twice → `diff run1.json run2.json` is empty
1321
+ - All scores in [0.0, 1.0]
1322
+ - Expected approximate scores: task_001 ~0.85, task_003 ~0.70, task_005 ~0.45
1323
+
1324
+ ### Acceptance Criteria — Phase 6
1325
+
1326
+ - [ ] `baseline_heuristic.py` runs all 3 MVP tasks without error
1327
+ - [ ] Two consecutive runs produce bit-exact identical JSON output
1328
+ - [ ] No API key required
1329
+ - [ ] All scores in [0.0, 1.0] with meaningful variance
1330
+ - [ ] Decision tree follows spec Section 17 exactly
1331
+
1332
+ ---
1333
+
## Phase 7: Docker, HF Spaces, Logging, Error Handling & Edge Cases

### Goal
Production-ready container that deploys cleanly.

### Files to Edit

**`Dockerfile`** — Finalize:
- Base: `python:3.12-slim`
- PyTorch CPU-only: `pip install torch --index-url https://download.pytorch.org/whl/cpu`
- Target: <500MB
- `EXPOSE 7860`
- `CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]`

**Note on Dockerfile COPY**: `COPY` is not a shell command, so `COPY ... 2>/dev/null || true` is invalid syntax. Ensure all copied files exist before the build (and un-ignore `validation/reports/` in `.dockerignore`), or use a multi-stage approach.

**Logging** — Add to `server/app.py` and `server/environment.py` (a sketch follows):
- JSON structured logging to stdout
- Log every `reset()`, `step()`, episode completion, errors

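A minimal structured-logging setup along these lines (field names are illustrative):

```python
import json
import logging
import sys


class JsonFormatter(logging.Formatter):
    """Emit one JSON object per log line so stdout is machine-parseable."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "level": record.levelname,
            "logger": record.name,
            "event": record.getMessage(),
        }
        if record.exc_info:
            payload["exc"] = self.formatException(record.exc_info)
        return json.dumps(payload)


handler = logging.StreamHandler(sys.stdout)
handler.setFormatter(JsonFormatter())
logging.basicConfig(level=logging.INFO, handlers=[handler])

# At call sites, e.g. in reset():
# logger.info("reset task_id=%s seed=%s", task_id, seed)
```
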
**WebSocket edge cases** (spec Section 16):
- Client disconnects mid-episode → retain state 60s
- Malformed JSON → return error, keep connection
- step() before reset() → return "no_active_episode" error
- reset() during active episode → terminate current, start new

### Acceptance Criteria — Phase 7

- [ ] `docker build -t pytorch-debugger .` succeeds
- [ ] Docker image <500MB
- [ ] `docker run -p 7860:7860 pytorch-debugger` starts and serves in <60s
- [ ] `curl http://localhost:7860/health` returns `{"status": "ready", "tasks": 3}`
- [ ] All WebSocket edge cases handled per spec Section 16
- [ ] Structured JSON logging on all significant events

---

## Phase 8: Full Testing Suite + Pre-Submission Smoke Tests

### Goal
Test coverage above 80%, with all edge cases covered.

### Files to Create/Extend

All test files listed above, plus:
- Fill coverage gaps identified by `pytest --cov`
- Add edge case tests for every error in spec Section 16
- Add test for `step()` after `done=True`
- Add test for step limit termination

### Commands

```bash
pytest tests/ -v --cov=ml_training_debugger --cov=server --cov-report=term-missing
```

### Acceptance Criteria — Phase 8

- [ ] `pytest --cov` shows >80% coverage on all modules
- [ ] Every error condition from spec Section 16 has a test
- [ ] Context-gated penalty tests pass (both paths)
- [ ] Dynamic available_actions tests pass
- [ ] All 3 graders tested with multiple scenarios
- [ ] Zero test failures

---

## Phase 9: Final Polish & Submission Readiness

### Goal
README complete, all endpoints verified, `openenv validate` passes, deploy to HF Spaces.

### Files to Create

**`README.md`** (~200 lines):
- Environment description and motivation
- Action/observation space definitions
- Task descriptions with difficulty
- Setup instructions
- Baseline scores table

**`deploy.sh`**:
```bash
#!/bin/bash
set -euo pipefail

echo "=== Building Docker image ==="
docker build -t pytorch-debugger .

echo "=== Starting container ==="
docker run -d -p 7860:7860 --name smoke-test pytorch-debugger
sleep 10

echo "=== Health check ==="
curl -f http://localhost:7860/health || { echo "FAIL: health"; exit 1; }

echo "=== Tasks endpoint ==="
curl -f http://localhost:7860/tasks | python3 -m json.tool || { echo "FAIL: tasks"; exit 1; }

echo "=== Baseline reproducibility ==="
python3 baseline_heuristic.py > run1.json 2>/dev/null
python3 baseline_heuristic.py > run2.json 2>/dev/null
diff run1.json run2.json && echo "PASS: reproducible" || { echo "FAIL: non-reproducible"; exit 1; }

echo "=== Baseline via endpoint ==="
curl -f -X POST http://localhost:7860/baseline | python3 -m json.tool || { echo "FAIL: baseline endpoint"; exit 1; }

echo "=== Grader via endpoint ==="
curl -f -X POST http://localhost:7860/grader | python3 -m json.tool || { echo "FAIL: grader endpoint"; exit 1; }

echo "=== Tests ==="
pytest tests/ -v --cov=ml_training_debugger --cov-report=term-missing

echo "=== Cleanup ==="
docker stop smoke-test && docker rm smoke-test
rm -f run1.json run2.json

echo "=== ALL CHECKS PASSED ==="
```

### Acceptance Criteria — Phase 9

- [ ] `openenv validate` passes
- [ ] `deploy.sh` runs end-to-end with zero failures
- [ ] README is complete per hackathon requirements
- [ ] Docker image <500MB, starts <60s
- [ ] Baseline bit-exact reproducible
- [ ] 3+ tasks with graders returning [0.0, 1.0] with meaningful variance
- [ ] HF Space deployed, tagged `openenv`, responds to `reset()`
- [ ] All typed Pydantic models — no `Dict[str, Any]`
- [ ] `import torch` in every core module — zero numpy in core
- [ ] Context-gated penalty fires correctly and does not fire prematurely
- [ ] Test suite passes with >80% coverage

---

1469
+
1470
+ ## Technical Risk Mitigations
1471
+
1472
+ | Risk | Impact | Mitigation |
1473
+ |------|--------|------------|
1474
+ | **WebSocket + HTTP composition** | ~~High~~ RESOLVED | `create_app()` returns standard FastAPI. Custom routes add cleanly. Verified in Phase 0. |
1475
+ | **Docker image size** | Medium | `python:3.12-slim` + torch CPU-only (~150MB). Target <500MB. Test early in Phase 7. |
1476
+ | **Task 6 fix validation fragility** | Medium | Multi-strategy pipeline: normalize → tokenize → semantic patterns → AST fallback. Test 5+ whitespace variations. (Post-MVP Phase 2 stretch) |
1477
+ | **Red-herring penalty gating** | HIGH | `gradients_were_normal` set inside `inspect_gradients` handler when ALL layers have `is_exploding=False`. Threshold: `mean_norm > 10.0`. Test BOTH paths explicitly. |
1478
+ | **Session isolation** | Medium | `dict[str, SessionData]` keyed by session ID. Framework provides session management. |
1479
+ | **Baseline reproducibility** | HIGH | `torch.manual_seed(seed)` at every `reset()`. Seed derived deterministically from task_id. Heuristic is pure logic. Test with `diff run1.json run2.json`. |
1480
+ | **Dockerfile build time** | Low | No real training during build. Validation reports pre-computed locally. |
1481
+ | **openenv.yaml format** | Medium | Template uses `spec_version: 1`, `type: space`, `runtime: fastapi`, `app: server.app:app`. Extended fields (tasks, reward, etc.) are additive. Test with `openenv validate` early. |
1482
+ | **Port mismatch** | Low | Spec says 7860 (HF Spaces default). openenv template says 8000. Use 7860 everywhere. |
1483
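+ 
+ To make the baseline-reproducibility mitigation concrete, here is a minimal sketch of deterministic per-task seeding; the helper name and hash scheme are illustrative assumptions, not the spec'd implementation:
+ 
+ ```python
+ import hashlib
+ 
+ import torch
+ 
+ def seed_for_task(task_id: str) -> int:
+     # Derive a stable 31-bit seed from the opaque task ID (hypothetical helper).
+     digest = hashlib.sha256(task_id.encode("utf-8")).digest()
+     return int.from_bytes(digest[:4], "big") % (2**31)
+ 
+ def seed_on_reset(task_id: str) -> None:
+     # Same task_id -> same seed -> bit-exact curves and gradients at every reset().
+     torch.manual_seed(seed_for_task(task_id))
+ ```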
+ 
+ ---
+ 
+ ## Exact openenv.yaml (Final)
+ 
+ ```yaml
+ spec_version: 1
+ name: pytorch-training-debugger
+ type: space
+ runtime: fastapi
+ app: server.app:app
+ port: 7860
+ 
+ version: "1.0.0"
+ description: |
+   PyTorch-native fault injection engine for training failure debugging.
+   An AI agent investigates, diagnoses, fixes, and verifies broken
+   training runs using real torch.nn.Module models, torch.autograd
+   gradients, state_dict() weight inspection, and PyTorch code-level
+   debugging. 3 tasks across 3 difficulty tiers with context-gated
+   reward shaping.
+ framework: openenv
+ tags: [ml-debugging, pytorch, reinforcement-learning, root-cause-analysis, fault-injection, openenv]
+ 
+ observation_space:
+   type: MLTrainingObservation
+   description: "Training run snapshot with progressive reveal — gradients, weights, data stats, model modes revealed on inspection"
+ 
+ action_space:
+   type: MLTrainingAction
+   description: "Investigation, fix, and diagnosis actions with dynamic availability"
+ 
+ tasks:
+   - id: task_001
+     difficulty: easy
+     max_steps: 20
+   - id: task_003
+     difficulty: medium
+     max_steps: 25
+   - id: task_005
+     difficulty: hard
+     max_steps: 30
+ 
+ reward:
+   range: [-1.0, 1.0]
+   shaped: true
+   step_penalty: -0.01
+   investigation_bonus: 0.05
+   max_investigation_bonus: 0.25
+   correct_diagnosis: 0.50
+   terminal_convergence: 0.40
+ 
+ endpoints:
+   websocket: "/ws"
+   tasks: "GET /tasks"
+   grader: "POST /grader"
+   baseline: "POST /baseline"
+   health: "GET /health"
+ ```
+ 
+ ---
+ 
+ ## Exact Dockerfile (Final)
+ 
+ ```dockerfile
+ FROM python:3.12-slim
+ 
+ WORKDIR /app
+ 
+ # Install PyTorch CPU-only first (largest layer, cached)
+ RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
+ 
+ # Install remaining dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+ 
+ # Copy application code
+ COPY ml_training_debugger/ ml_training_debugger/
+ COPY server/ server/
+ COPY openenv.yaml .
+ COPY baseline_heuristic.py .
+ COPY README.md .
+ 
+ EXPOSE 7860
+ 
+ # python:3.12-slim ships without curl; probe the health endpoint with the stdlib
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
+     CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
+ 
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
+ ```
+ 
+ ---
+ 
+ ## Pre-Submission Smoke Test Sequence
+ 
+ ```bash
+ # 1. Clean build
+ docker build --no-cache -t pytorch-debugger .
+ 
+ # 2. Start container
+ docker run -d -p 7860:7860 --name smoke-test pytorch-debugger
+ sleep 10
+ 
+ # 3. Health check
+ curl -f http://localhost:7860/health
+ 
+ # 4. Tasks endpoint
+ curl -f http://localhost:7860/tasks | python3 -m json.tool
+ 
+ # 5. Baseline reproducibility
+ python3 baseline_heuristic.py > run1.json 2>/dev/null
+ python3 baseline_heuristic.py > run2.json 2>/dev/null
+ diff run1.json run2.json && echo "PASS: reproducible" || echo "FAIL"
+ 
+ # 6. Baseline via endpoint
+ curl -f -X POST http://localhost:7860/baseline | python3 -m json.tool
+ 
+ # 7. Grader via endpoint
+ curl -f -X POST http://localhost:7860/grader | python3 -m json.tool
+ 
+ # 8. OpenEnv validation
+ openenv validate
+ 
+ # 9. Test suite
+ pytest tests/ -v --cov=ml_training_debugger --cov-report=term-missing
+ 
+ # 10. Cleanup
+ docker stop smoke-test && docker rm smoke-test
+ rm -f run1.json run2.json
+ 
+ echo "=== All checks passed ==="
+ ```
+ 
+ ---
+ 
+ ## Post-MVP Stretch (Phase 2 from ROADMAP)
+ 
+ **Only after MVP is 100% deployed and passing all auto-validation:**
+ 
+ 1. **Task 6** (code debugging) — highest impact differentiator
+    - Create `ml_training_debugger/code_templates.py`
+    - 4 bug variants: eval_mode, detach_loss, zero_grad_missing, inplace_relu
+    - Multi-strategy fix validation: normalize → tokenize → semantic → AST
+    - Diagnosis is ALWAYS `code_bug` regardless of variant
+ 
+ 2. **Tasks 2 & 4** — fill out to 6 tasks
+    - Task 2: vanishing gradients (easy, mirror of Task 1)
+    - Task 4: overfitting (medium, train-val divergence)
+ 
+ 3. **Dashboard** — `server/dashboard.html`, Plotly.js via CDN
+ 
+ 4. **Validation Suite** — `validation/*.py`, R² > 0.85
+ 
+ 5. **LLM Baseline** — `baseline_inference.py`, GPT-4o
+ 
+ Update `openenv.yaml`, `/tasks`, and the `/health` task count as tasks are added.
+ 
+ ---
+ 
+ ## SESSION_ID
+ 
+ - CODEX_SESSION: N/A (codeagent-wrapper not available)
+ - GEMINI_SESSION: N/A (codeagent-wrapper not available)
+ 
+ Plan generated by Claude Opus 4.6 via deep analysis of all 4 project markdown files + openenv-core framework API inspection.
.coverage ADDED
Binary file (53.2 kB).

.dockerignore ADDED
@@ -0,0 +1,13 @@
+ .venv/
+ __pycache__/
+ .git/
+ .pytest_cache/
+ tests/
+ validation/
+ *.md
+ !README.md
+ .claude/
+ run*.json
+ htmlcov/
+ .mypy_cache/
+ .ruff_cache/
.gitignore ADDED
@@ -0,0 +1,14 @@
+ .venv/
+ __pycache__/
+ *.pyc
+ *.pyo
+ .env
+ run*.json
+ .pytest_cache/
+ htmlcov/
+ *.egg-info/
+ dist/
+ build/
+ validation/reports/*.png
+ .mypy_cache/
+ .ruff_cache/
.python-version ADDED
@@ -0,0 +1 @@
+ 3.12
CLAUDE.md ADDED
@@ -0,0 +1,186 @@
+ # CLAUDE.md — PyTorch Training Run Debugger
+ 
+ OpenEnv RL environment for the Meta PyTorch OpenEnv Hackathon x Scaler School of Technology.
+ An AI agent debugs broken PyTorch training runs by investigating gradients, weights, data, model modes, and source code to diagnose and fix real ML failure patterns.
+ 
+ **Spec:** `ml-training-debugger-spec.md` is the single source of truth. If this file and the spec conflict, the spec wins.
+ 
+ **Runtime:** Python 3.12 · PyTorch CPU-only · openenv-core v0.2.2
+ 
+ ---
+ 
+ ## Non-Negotiable Rules
+ 
+ ### MVP-First Execution
+ Ship Tasks 1, 3, 5 (easy/medium/hard) + rule-based baseline + Docker + HF deploy **before** touching anything else. A deployed MVP that passes auto-validation beats a half-finished 6-task environment. Priority order after MVP: Task 6 > Tasks 2 & 4 > dashboard > validation suite > LLM baseline.
+ 
+ ### Context-Gated Penalty Must Be Exact
+ The -0.20 penalty for `add_callback` fires **only when both** `gradients_inspected == True` AND `gradients_were_normal == True`. It must **never** fire before `inspect_gradients` has been called. This is the project's primary innovation. Get the gate conditions wrong and the differentiator is broken. Test both paths:
+ - `add_callback` at step 1 (no prior inspection) -> **no penalty**
+ - `inspect_gradients` (normal) then `add_callback` -> **-0.20 penalty**
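+ 
+ A minimal pytest sketch of both paths; the `make_env` fixture, the `action_type` field, and the exact reward composition are assumptions about the final API, not verified code:
+ 
+ ```python
+ import pytest
+ 
+ from ml_training_debugger.models import MLTrainingAction
+ 
+ def test_no_penalty_without_prior_inspection(make_env):
+     env = make_env("task_005")  # hypothetical fixture: fresh env, already reset
+     obs = env.step(MLTrainingAction(action_type="add_callback"))
+     # Gate is not armed: only the flat step penalty applies.
+     assert obs.reward == pytest.approx(-0.01)
+ 
+ def test_penalty_after_normal_gradients(make_env):
+     env = make_env("task_005")
+     env.step(MLTrainingAction(action_type="inspect_gradients"))  # arms the gate
+     obs = env.step(MLTrainingAction(action_type="add_callback"))
+     # -0.20 gated penalty plus -0.01 step penalty, assuming additive composition.
+     assert obs.reward == pytest.approx(-0.21)
+ ```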
+ 
+ ### Task 6 Diagnosis Is Always `code_bug`
+ Regardless of the specific bug variant (`eval_mode`, `detach_loss`, `zero_grad_missing`, `inplace_relu`), Task 6's correct diagnosis is **always** `code_bug`. Submitting `batchnorm_eval_mode` on Task 6's `eval_mode` variant is a wrong diagnosis (-0.30). The grader enforces this with a strict equality check.
+ 
+ ### PyTorch-Native Only — No NumPy
+ Every computation in core modules uses `torch.Tensor`, not `numpy.ndarray`. `import torch` must appear in `models.py`, `simulation.py`, `pytorch_engine.py`, `reward_engine.py`, and `graders.py`. This is a Meta PyTorch hackathon — judges will notice. The only exception is test utilities and the validation suite where `scipy`/`matplotlib` are acceptable.
+ 
+ ### Grader != Reward Function
+ These are separate modules with separate purposes. The **reward function** (`reward_engine.py`) returns a float per step for RL training signal. The **grader** (`graders.py`) returns a normalized 0.0-1.0 score at episode end for the `/grader` endpoint and auto-validation. The grader evaluates `EpisodeState` holistically — it is **not** a sum of step rewards. Never conflate them.
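+ 
+ The separation is visible in the two signatures already named in `ROADMAP.md`; parameter types are taken from those module descriptions, everything else is a sketch:
+ 
+ ```python
+ # reward_engine.py: dense per-step signal for RL training.
+ def compute_reward(
+     action: MLTrainingAction, episode_state: EpisodeState, scenario: ScenarioParams
+ ) -> float: ...
+ 
+ # graders.py: holistic end-of-episode score, always in [0.0, 1.0].
+ def grade_task_001(state: EpisodeState, scenario: ScenarioParams) -> float: ...
+ ```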
+ 
+ ### Opaque Task IDs
+ Task IDs are `task_001` through `task_006`. The agent must never be able to infer the diagnosis from the task ID. Do not use descriptive names anywhere the agent can observe them.
+ 
+ ---
+ 
+ ## Architecture Constraints
+ 
+ ### Framework Integration (Verified)
+ ```
+ openenv-core v0.2.2 → create_app() → returns standard FastAPI instance
+ ```
+ 
+ - `MLTrainingAction` extends `Action` from `openenv.core.env_server.types`
+ - `MLTrainingObservation` extends `Observation` from `openenv.core.env_server.types` (has built-in `done`, `reward`, `metadata`)
+ - `MLTrainingEnvironment` extends `Environment` from `openenv.core.env_server.interfaces` (must implement `reset()`, `step()`, `state` property)
+ - `MLTrainingEnvClient` in `client.py` extends `EnvClient` with typed `action_type` and `observation_type` — used by baseline scripts
+ - `create_app()` takes the **class** (factory), not an instance
+ - Custom routes (`/tasks`, `/grader`, `/baseline`, `/health`) are added directly to the returned FastAPI app via `@app.get()`/`@app.post()` decorators
+ - Framework auto-provides: `POST /reset`, `POST /step`, `GET /state`, `WS /ws`, `GET /schema`, `GET /docs`, `/mcp`
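+ 
+ A minimal sketch of that composition. The three-argument call and the route body mirror the bullets above; treat the exact `create_app` signature as an assumption to re-verify in Phase 0:
+ 
+ ```python
+ from openenv.core.env_server.http_server import create_app
+ 
+ from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
+ from server.environment import MLTrainingEnvironment
+ 
+ # Factory pattern: pass the Environment class, not an instance.
+ app = create_app(MLTrainingEnvironment, MLTrainingAction, MLTrainingObservation)
+ 
+ # Custom hackathon routes compose directly onto the returned FastAPI app.
+ @app.get("/health")
+ def health() -> dict[str, str | int]:
+     return {"status": "ready", "tasks": 3}
+ ```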
+ 
+ ### Key Constraints (see spec for full detail)
+ - **Real PyTorch models:** `pytorch_engine.py` instantiates `SimpleCNN` (~50K params) at every `reset()`, runs 1-2 real forward+backward passes. Gradient and weight stats come from real `torch.autograd` and `model.state_dict()`.
+ - **Typed Pydantic models everywhere:** No `Dict[str, Any]`. `available_actions` is dynamically computed from `EpisodeState`, never hardcoded.
+ - **Session isolation:** Each WebSocket client gets its own `EpisodeState` keyed by session ID. `SUPPORTS_CONCURRENT_SESSIONS = True`.
+ 
+ ---
+ 
+ ## Coding Standards
+ 
+ ### Formatting & Linting
+ - **black** for formatting (line length 88)
+ - **ruff** for linting
+ - **isort** for import ordering (profile=black)
+ - Run all three before every commit
+ 
+ ### Type Hints
+ Type annotations on **every** function signature and return type. No `Any` in public APIs. Use `Optional[X]` for nullable fields, `Literal[...]` for closed string unions, `list[X]` (lowercase) for Python 3.12+.
+ 
+ ### Testing
+ - **pytest** for all tests
+ - Every module in `ml_training_debugger/` has a corresponding `tests/test_*.py`
+ - Minimum test coverage: 80%
+ - Critical tests that must exist:
+   - `test_reward_engine.py`: context-gated penalty fires/doesn't fire under correct conditions
+   - `test_graders.py`: each grader returns 0.0-1.0, correct diagnosis scores high, wrong diagnosis scores low
+   - `test_pytorch_engine.py`: model instantiation, fault injection, gradient/weight extraction produces real tensors
+   - `test_code_templates.py`: all 4 bug variants generate valid code, fix validation accepts correct fixes and rejects wrong ones (including whitespace/comment variations)
+   - `test_episode_lifecycle.py`: full episode flow reset->inspect->fix->restart->diagnose produces expected state transitions
+ 
+ ### File Size Limits
+ - 400 lines typical, 800 max per file
+ - `models.py` may exceed 400 lines due to many Pydantic models — this is acceptable
+ - `pytorch_engine.py` must stay under 300 lines (isolate model definitions if needed)
+ 
+ ### Error Handling
+ `step()` must **never** raise an unhandled exception. All invalid actions return a valid observation with `-0.05` penalty and an error note. All edge cases (step after done, step before reset, malformed JSON) return structured error responses.
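+ 
+ One way to honor that contract is a catch-all dispatch wrapper. This is a sketch; `_dispatch`, `_observation`, and the metadata key are assumed internals:
+ 
+ ```python
+ def step(self, action: MLTrainingAction) -> MLTrainingObservation:
+     try:
+         return self._dispatch(action)
+     except Exception as exc:  # never let a bad action crash the episode
+         obs = self._observation()
+         obs.reward = -0.05
+         obs.metadata["error"] = f"invalid action: {exc}"
+         return obs
+ ```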
+ 
+ ---
+ 
+ ## Key Risks to Watch
+ 
+ ### Task 6 Code Fix Validation
+ LLM agents will submit fixes with trailing spaces, inline comments, or minor reformatting. Use the multi-strategy validation pipeline:
+ 1. Normalize whitespace + strip comments
+ 2. Token-stream comparison via `tokenize` module
+ 3. 2-3 semantic equivalence patterns per bug variant
+ 4. `ast.parse()` fallback to verify buggy pattern is absent
+ 
+ Test with intentionally messy fixes: `" loss = criterion(output, batch_y)  # fixed "` must pass.
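+ 
+ A sketch of the token-stream comparison (strategy 2, which also absorbs strategy 1's whitespace and comment normalization); stdlib only, helper names illustrative:
+ 
+ ```python
+ import io
+ import tokenize
+ 
+ def significant_tokens(code: str) -> list[tuple[int, str]]:
+     # Compare token streams, ignoring comments, whitespace, and line structure.
+     skip = {tokenize.COMMENT, tokenize.NL, tokenize.NEWLINE,
+             tokenize.INDENT, tokenize.DEDENT}
+     tokens = tokenize.generate_tokens(io.StringIO(code).readline)
+     return [(tok.type, tok.string) for tok in tokens if tok.type not in skip]
+ 
+ def fix_matches(submitted: str, expected: str) -> bool:
+     try:
+         return significant_tokens(submitted) == significant_tokens(expected)
+     except tokenize.TokenError:
+         return False  # fall through to the semantic/AST strategies
+ ```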
+ 
+ ### Red-Herring Penalty Gating
+ The `gradients_were_normal` flag is set **inside** the `inspect_gradients` handler, based on whether `is_exploding` is False on **all** layers. The threshold for `is_exploding` is `mean_norm > 10.0`. The threshold for `is_vanishing` is `mean_norm < 1e-6`. In Task 5, the FC spike has `is_exploding: False` (it spiked but the mean norm stays below 10.0), so `gradients_were_normal` is set to True. This is the gate that makes the penalty fire when the agent then calls `add_callback`.
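+ 
+ Sketch of the flag-setting inside the handler. Thresholds come from the paragraph above; the handler and state field names are assumptions consistent with the rest of this file:
+ 
+ ```python
+ def handle_inspect_gradients(self, state: EpisodeState) -> list[GradientStats]:
+     stats = extract_gradient_stats(self.model)  # real torch.autograd stats
+     for layer in stats:
+         layer.is_exploding = layer.mean_norm > 10.0
+         layer.is_vanishing = layer.mean_norm < 1e-6
+     state.gradients_inspected = True
+     # The gate: "normal" means no layer is exploding, even if one spiked.
+     state.gradients_were_normal = all(not layer.is_exploding for layer in stats)
+     return stats
+ ```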
+ 
+ ### Docker Image Size
+ Target: <500MB. PyTorch CPU-only wheel is ~150MB. Use `python:3.12-slim` base. Install torch with `--index-url https://download.pytorch.org/whl/cpu`. Do NOT install CUDA. Pre-compute validation reports locally — do not run real training in Docker build.
+ 
+ ### Baseline Reproducibility
+ The rule-based baseline must produce **bit-exact identical** scores on two consecutive runs. This requires:
+ - `torch.manual_seed(seed)` at every `reset()` with a deterministic seed per task
+ - No floating-point non-determinism in the parametric curve generators
+ - The heuristic decision tree is pure logic with no randomness
+ 
+ ### Auto-Validator Endpoints
+ These endpoints are checked programmatically. They must respond correctly or you are disqualified:
+ - `GET /health` -> `{"status": "ready", "tasks": N}` (200) — N is the number of active tasks (3 for MVP, 6 for full)
+ - `GET /tasks` -> list of tasks with IDs and action schema (200)
+ - `POST /grader` -> `{"score": float}` after a completed episode (200)
+ - `POST /baseline` -> scores for all tasks (200)
+ - `WS /ws` -> responds to `reset` message
+ 
+ ---
+ 
+ ## Reward Constants (Do Not Change)
+ 
+ See spec Section 12 for full rationale. Summary:
+ 
+ | Event | Value | Gate |
+ |---|---|---|
+ | Step penalty | -0.01 | Unconditional, flat (never multiply by step_count) |
+ | Investigation bonus | +0.05 | First-time only per inspection type |
+ | Context-gated penalty | -0.20 | `gradients_inspected AND gradients_were_normal` |
+ | Invalid action | -0.05 | Action not in `available_actions` |
+ | Wrong code fix | -0.10 | `fix_code` with wrong line/replacement |
+ | Correct diagnosis | +0.50 | `diagnosis == true_root_cause` |
+ | Wrong diagnosis | -0.30 | `diagnosis != true_root_cause` |
+ | Terminal convergence | +0.40 | `fix_action_taken AND restart_after_fix AND convergence` |
+ 
+ ---
+ 
+ ## Success Criteria — "Perfect" Submission
+ 
+ All of these must be true:
+ - [ ] `openenv validate` passes
+ - [ ] `docker build && docker run` starts server on port 7860 in <60s
+ - [ ] HF Space deploys, responds to `reset()`, tagged with `openenv`
+ - [ ] `baseline_heuristic.py` produces identical scores on two runs
+ - [ ] 3+ tasks with graders returning scores in [0.0, 1.0] with meaningful variance
+ - [ ] Hard task (Task 5 or 6) genuinely challenges frontier models (score < 0.7 for heuristic)
+ - [ ] Context-gated penalty fires correctly and does not fire prematurely
+ - [ ] All typed Pydantic models, no `Dict[str, Any]`
+ - [ ] `import torch` in every core module, zero numpy imports in core
+ - [ ] README documents: environment description, action/observation spaces, task descriptions with difficulty, setup instructions, baseline scores
+ - [ ] POST `/baseline`, POST `/grader`, GET `/tasks` all respond correctly
+ - [ ] Test suite passes with >80% coverage
+ 
+ ---
+ 
+ ## Commands
+ 
+ ```bash
+ # Development (from project root: ML Debugger/)
+ source .venv/bin/activate
+ uvicorn server.app:app --reload --host 0.0.0.0 --port 7860
+ 
+ # Tests
+ pytest tests/ -v --cov=ml_training_debugger --cov-report=term-missing
+ 
+ # Formatting
+ black ml_training_debugger/ server/ tests/
+ ruff check ml_training_debugger/ server/ tests/ --fix
+ isort ml_training_debugger/ server/ tests/ --profile black
+ 
+ # Docker
+ docker build -t pytorch-debugger .
+ docker run -p 7860:7860 pytorch-debugger
+ 
+ # Smoke test
+ curl http://localhost:7860/health
+ curl http://localhost:7860/tasks
+ python baseline_heuristic.py > run1.json
+ python baseline_heuristic.py > run2.json
+ diff run1.json run2.json  # Must be empty
+ 
+ # OpenEnv validation
+ openenv validate
+ ```
Dockerfile ADDED
@@ -0,0 +1,24 @@
+ FROM python:3.12-slim
+ 
+ WORKDIR /app
+ 
+ # Install PyTorch CPU-only first (largest layer, cached)
+ RUN pip install --no-cache-dir torch --index-url https://download.pytorch.org/whl/cpu
+ 
+ # Install remaining dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+ 
+ # Copy application code
+ COPY ml_training_debugger/ ml_training_debugger/
+ COPY server/ server/
+ COPY openenv.yaml .
+ COPY baseline_heuristic.py .
+ COPY README.md .
+ 
+ EXPOSE 7860
+ 
+ # python:3.12-slim ships without curl; probe the health endpoint with the stdlib
+ HEALTHCHECK --interval=30s --timeout=3s --start-period=10s --retries=3 \
+     CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:7860/health')" || exit 1
+ 
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
PRD.md ADDED
@@ -0,0 +1,367 @@
+ # PRD — PyTorch Training Run Debugger
+ 
+ **Product:** OpenEnv RL environment for ML training failure diagnosis
+ **Hackathon:** Meta PyTorch OpenEnv Hackathon x Scaler School of Technology, Round 1
+ **Deadline:** April 8, 2026 (submission window opens March 28)
+ **Runtime:** Python 3.12 · PyTorch CPU-only · openenv-core v0.2.2
+ **Source of truth:** `ml-training-debugger-spec.md` for all implementation detail beyond this PRD
+ 
+ ---
+ 
+ ## 1. Overview
+ 
+ ### 1.1 What We Are Building
+ 
+ An OpenEnv-compliant reinforcement learning environment where an AI agent receives a snapshot of a broken PyTorch training run and must investigate, diagnose, fix, and verify the failure through a multi-step interactive process. The environment exposes real PyTorch model internals (gradients from `torch.autograd`, weights from `model.state_dict()`) and covers 6 failure scenarios across 3 difficulty tiers.
+ 
+ ### 1.2 Problem Being Solved
+ 
+ MLOps teams spend 15-25% of engineer time debugging silent training failures — runs that produce no error, no crash, just bad metrics. Each misdiagnosed restart wastes GPU compute at $2-8/hour/card. The diagnostic process is hard because:
+ 
+ - Multiple symptoms can point to multiple causes simultaneously
+ - Some bugs produce no error — just mysteriously bad performance
+ - Fixing the wrong thing wastes hours of compute and restarts
+ - Static analysis catches some bugs but cannot reason through ambiguous runtime signals
+ 
+ No existing OpenEnv environment covers this domain. The OpenEnv Hub currently contains a demo echo environment and a code execution environment. This fills a genuine gap.
+ 
+ ### 1.3 Why This Domain Wins
+ 
+ 1. **Strategic alignment** — PyTorch debugging for a Meta PyTorch hackathon. Judges from Meta and Hugging Face will see their own framework as the core subject matter.
+ 2. **Novel reward design** — Context-gated penalties that encode evidence-based reasoning into the reward signal. No existing OpenEnv environment attempts this.
+ 3. **Code-level debugging** — Task 6 requires the agent to read and fix actual PyTorch code. Directly addresses Meta's interest: can an AI agent debug PyTorch?
+ 4. **Ecosystem gap** — Zero competition in the OpenEnv ecosystem for ML training failure diagnosis.
+ 
+ ### 1.4 Key Differentiators
+ 
+ | Differentiator | What It Is | Why It Matters |
+ |---|---|---|
+ | Context-gated reward shaping | Penalty fires only when agent ignores evidence it already gathered; no penalty for reasonable priors | Encodes evidence-based decision making — a capability no other OpenEnv environment has |
+ | PyTorch-native internals | Real `torch.nn.Module` models, real `torch.autograd` gradients, real `state_dict()` snapshots | Every model-level observation is grounded in real PyTorch computation, not synthetic data |
+ | Code-level debugging (Task 6) | Agent reads PyTorch code, identifies buggy line, submits code fix | Tests code understanding, not just metric interpretation — aligned with Meta's core interest |
+ 
+ ---
+ 
+ ## 2. Target Users
+ 
+ ### 2.1 Primary: Hackathon Judges (Meta + Hugging Face Engineers)
+ 
+ **What they evaluate:**
+ - Real-world utility (30%) — Is this a genuine task? Would someone use this to train/evaluate agents?
+ - Task & grader quality (25%) — Well-defined tasks, accurate graders, meaningful difficulty progression?
+ - Environment design (20%) — Clean state management, sensible action/observation spaces, good reward shaping?
+ - Code quality & spec compliance (15%) — OpenEnv spec, clean structure, typed models, working Dockerfile?
+ - Creativity & novelty (10%) — Novel domain, interesting mechanics, original approach?
+ 
+ **What impresses them:**
+ - Real `import torch` in core modules (not numpy wrappers)
+ - A live dashboard where they can watch an agent investigate in real time
+ - Deterministic graders that produce different scores for different agent quality levels
+ - The context-gated penalty — nuanced reward design that goes beyond standard practice
+ 
+ **What disqualifies:**
+ - HF Space doesn't deploy or respond to `reset()`
+ - Plagiarized or trivially modified existing environments
+ - Graders that always return the same score
+ - No baseline inference script
+ - Dockerfile doesn't build
+ 
+ ### 2.2 Secondary: RL Researchers and Agent Developers
+ 
+ **What they need:**
+ - A challenging benchmark that differentiates heuristic agents from reasoning-capable ones
+ - Clear, typed action/observation schemas for agent integration
+ - Reproducible baseline scores for comparison
+ - Environments that produce meaningful reward signal across the full trajectory (not just sparse terminal reward)
+ 
+ ### 2.3 Tertiary: Auto-Validation System (Phase 1 Gate)
+ 
+ A non-human "user" that must pass before any human judge sees the submission:
+ - Pings HF Space URL — must return 200 and respond to `reset()`
+ - Validates `openenv.yaml`, typed models, `step()`/`reset()`/`state()` endpoints
+ - Runs `docker build` on submitted repo
+ - Runs baseline script twice — scores must be identical
+ - Enumerates tasks, runs each grader — scores must be in [0.0, 1.0]
+ 
+ ---
+ 
+ ## 3. Success Metrics
+ 
+ ### 3.1 Evaluation Criteria Targets
+ 
+ | Criterion | Weight | Target Score | How We Hit It |
+ |---|---|---|---|
+ | Real-world utility | 30% | 26-30 | ML debugging is a $B+ problem; every PyTorch team encounters these failures; fills a genuine OpenEnv gap |
+ | Task & grader quality | 25% | 21-25 | 6 tasks (3 MVP), 3 difficulty tiers, deterministic graders, hard tasks challenge frontier models |
+ | Environment design | 20% | 17-20 | Progressive reveal, context-gated penalties, dynamic `available_actions`, proper episode boundaries |
+ | Code quality & spec compliance | 15% | 13-15 | Full OpenEnv spec, typed Pydantic models, working Dockerfile + HF Space, two baselines |
+ | Creativity & novelty | 10% | 9-10 | Context-gated rewards, real PyTorch model internals, code fix task — all new to OpenEnv |
+ | **Total** | **100%** | **86-100** | |
+ 
+ ### 3.2 Quantitative Success Criteria
+ 
+ | Metric | Target | Measurement |
+ |---|---|---|
+ | Auto-validation | Pass all 5 gates | `openenv validate` + smoke test sequence |
+ | Grader score range | Meaningful variance per task | Heuristic baseline ~0.30-0.85 across tasks (not flat) |
+ | Heuristic-LLM gap | Measurable difference | LLM scores higher than heuristic on Tasks 5 and 6 |
+ | `reset()` latency | <200ms | Model instantiation + 2 forward passes + parametric curves |
+ | `step()` latency | <10ms | Action dispatch + reward computation + state update |
+ | Baseline reproducibility | Bit-exact across runs | `diff run1.json run2.json` produces no output |
+ | Docker image size | <500MB | PyTorch CPU-only + python:3.12-slim |
+ | Test coverage | >80% | `pytest --cov` |
+ 
+ ### 3.3 Qualitative Success Criteria
+ 
+ - A judge can open `/dashboard`, trigger a baseline run, and understand the agent's reasoning at a glance
+ - Task 5 (BatchNorm eval mode) visibly differentiates disciplined investigation from red-herring chasing
+ - Task 6 (code bug) produces a "wow" moment — an agent reading and fixing PyTorch code in front of Meta judges
+ - The context-gated penalty creates a story: "this agent gathered evidence and then ignored it"
+ 
+ ---
+ 
+ ## 4. Functional Requirements
+ 
+ > **Complete typed specifications for all data models, actions, observations, tasks, reward components, and error handling are in `ml-training-debugger-spec.md` Sections 10-16.** This section provides a product-level summary.
+ 
+ ### 4.1 Agent Interaction Loop
+ 
+ ```
+ reset(task_id) → initial observation (loss curves, config, error log — no gradients/weights/data/code)
+ 
+ step(action) → updated observation + reward + done flag (progressive reveal)
+ 
+ ... repeat ...
+ 
+ step(mark_diagnosed) → terminal observation, done=True, episode scored by grader
+ ```
+ 
+ ### 4.2 Observation Space Summary
+ 
+ The `MLTrainingObservation` extends `Observation` from openenv-core. Key design:
+ - **Always visible from reset:** loss/accuracy histories, config, error_log, GPU memory, episode state, available actions
+ - **Progressively revealed:** gradient stats (real torch.autograd), weight stats (real state_dict), data batch stats, model mode info, code snippets — each populated only after the corresponding `inspect_*` action
+ - All fields are typed Pydantic models with explicit types. See spec Section 10 for complete field definitions.
+ 
+ ### 4.3 Action Space Summary
+ 
+ The `MLTrainingAction` extends `Action` from openenv-core. 14 action types in 3 categories:
+ - **Investigation** (5): `inspect_gradients`, `inspect_data_batch`, `inspect_model_modes`, `inspect_model_weights`, `inspect_code`
+ - **Fix** (7): `modify_config`, `add_callback`, `replace_optimizer`, `patch_data_loader`, `fix_model_mode`, `fix_code`, `rollback_checkpoint`
+ - **Terminal** (2): `restart_run`, `mark_diagnosed`
+ 
+ Dynamic availability: `restart_run` requires `fix_action_taken`, `fix_code` requires `code_inspected`, `mark_diagnosed` disappears after submission. See spec Section 10 for complete action definitions and required fields.
+ 
+ ### 4.4 Diagnosis Enum (RootCauseDiagnosis)
+ 
+ Closed set of 6 values. Grader is a single equality check — no fuzzy matching.
+ 
+ | Value | Description |
+ |---|---|
+ | `lr_too_high` | Learning rate too large for the architecture |
+ | `vanishing_gradients` | LR too low or architecture too deep, gradients decay to near-zero |
+ | `data_leakage` | Validation samples appearing in training batches |
+ | `overfitting` | Model memorizing training data, failing to generalize |
+ | `batchnorm_eval_mode` | Model left in eval mode, BatchNorm using running statistics |
+ | `code_bug` | Bug in the PyTorch training code (Task 6 — always this, regardless of bug variant) |
+ 
+ ### 4.5 Reward Function Summary
+ 
+ Per-step signal. **Separate from the grader** (see 4.6). Range: [-1.0, 1.0] hard cap.
+ 
+ | Event | Reward | Gate Condition |
+ |---|---|---|
+ | Any step taken | -0.01 | Unconditional, flat constant (never multiplied by step_count) |
+ | First-time inspection (per type) | +0.05 | Not previously inspected for that type |
+ | `add_callback` after normal gradients | -0.20 | `gradients_inspected == True AND gradients_were_normal == True` |
+ | Invalid action | -0.05 | Action not in current `available_actions` |
+ | Wrong code fix | -0.10 | `fix_code` with incorrect line or replacement |
+ | Correct diagnosis | +0.50 | `diagnosis == true_root_cause` |
+ | Wrong diagnosis | -0.30 | `diagnosis != true_root_cause` |
+ | Convergence after fix+restart | +0.40 | `fix_action_taken AND restart_after_fix AND convergence_confirmed` |
+ 
+ See spec Section 12 for full design rationale.
+ 
+ ### 4.6 Grader Function
+ 
+ Returns a single normalized 0.0-1.0 score at episode end. Evaluates `EpisodeState` holistically — checks which key actions were taken, whether the correct fix was applied, whether the diagnosis is correct, and efficiency. **Not a sum of step rewards.** One grader function per task. All graders are deterministic.
+ 
+ Exposed via `POST /grader`. Returns score for the most recently completed episode.
+ 
+ ### 4.7 The Six Tasks
+ 
+ | Task | ID | Difficulty | Root Cause | Key Signal | Heuristic Score |
+ |---|---|---|---|---|---|
+ | Exploding Gradients | `task_001` | Easy | `lr_too_high` | All layers `is_exploding: True`, NaN in error_log | ~0.85 |
+ | Vanishing Gradients | `task_002` | Easy | `vanishing_gradients` | Deeper layers `is_vanishing: True`, flat loss | ~0.80 |
+ | Silent Data Leakage | `task_003` | Medium | `data_leakage` | High val accuracy from epoch 1, `class_overlap_score` 0.68-0.88 | ~0.70 |
+ | Overfitting | `task_004` | Medium | `overfitting` | Train-val divergence, loss→0.01 while val climbs | ~0.65 |
+ | BatchNorm Eval Mode | `task_005` | Hard | `batchnorm_eval_mode` | Slow val degradation + compound red herrings | ~0.45 |
+ | PyTorch Code Bug | `task_006` | Hard | `code_bug` (always) | Anomalous metrics, root cause only visible in code | ~0.30 |
+ 
+ **MVP tasks:** 1, 3, 5 (satisfies the 3-task minimum with easy→medium→hard range).
+ 
+ See spec Section 11 for complete task specifications including fault parameters, red herrings, solution paths, and grader breakdowns.
+ 
+ ### 4.8 Baseline Agents
+ 
+ **Rule-based baseline (submission default, `baseline_heuristic.py`):**
+ - Deterministic decision tree: inspect_gradients → check exploding/vanishing → inspect_data → check leakage → check overfitting → inspect_model_modes → inspect_code → fallback
+ - No API key required. Bit-exact reproducible.
+ - Used for Phase 1 auto-validation reproducibility checks.
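+ 
+ A condensed sketch of that decision tree. Observation fields follow the summaries above; the thresholds and the overfitting test are illustrative, not the shipped logic:
+ 
+ ```python
+ def choose_diagnosis(obs: MLTrainingObservation) -> str:
+     if any(g.is_exploding for g in obs.gradient_stats):
+         return "lr_too_high"
+     if any(g.is_vanishing for g in obs.gradient_stats):
+         return "vanishing_gradients"
+     data = obs.data_batch_stats
+     if data is not None and data.class_overlap_score > 0.5:
+         return "data_leakage"
+     if obs.val_loss_history[-1] > 1.5 * min(obs.val_loss_history):  # divergence
+         return "overfitting"
+     if obs.model_mode_info and obs.model_mode_info.get("model") == "eval":
+         return "batchnorm_eval_mode"
+     return "code_bug"  # fallback after inspect_code
+ ```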
+ 
+ **LLM baseline (optional, `baseline_inference.py`):**
+ - GPT-4o at temperature=0.0, seed=42
+ - Requires `OPENAI_API_KEY` environment variable
+ - Supplementary demonstration of heuristic vs. reasoning score gap
+ - Not used for Phase 1 reproducibility — scores reported only after empirical measurement
+ 
+ ### 4.9 Required Endpoints
+ 
+ | Endpoint | Method | Required By | Response |
+ |---|---|---|---|
+ | `/ws` | WebSocket | OpenEnv framework | Handles `reset`, `step`, `state` messages |
+ | `/tasks` | GET | Hackathon | Task list with IDs, difficulties, MLTrainingAction JSON schema |
+ | `/grader` | POST | Hackathon | `{"score": float, "task_id": str, "steps": int}` for last completed episode |
+ | `/baseline` | POST | Hackathon | Triggers baseline run, returns `{"scores": {"task_001": float, ...}}` |
+ | `/health` | GET | Hackathon | `{"status": "ready", "tasks": N}` — N is active task count |
+ | `/dashboard` | GET | Bonus | Live diagnostic dashboard (HTML/JS, Plotly.js via CDN) |
+ | `/validation-report` | GET | Bonus | Pre-computed PyTorch fidelity reports |
+ 
+ Framework auto-provides: `POST /reset`, `POST /step`, `GET /state`, `GET /schema`, `GET /docs`, `/mcp`.
+ 
+ ### 4.10 Error Handling
+ 
+ `step()` must never raise an unhandled exception. All invalid actions return a valid observation with -0.05 penalty and an error note. See spec Section 16 for the complete error handling matrix covering all edge cases (invalid actions, malformed JSON, step before reset, etc.).
+ 
+ ---
+ 
+ ## 5. Non-Functional Requirements
+ 
+ ### 5.1 OpenEnv Spec Compliance
+ 
+ | Requirement | Implementation |
+ |---|---|
+ | `openenv.yaml` present | Name, version, description, framework, tags, observation/action space, tasks with IDs+difficulties+max_steps, reward config, endpoints |
+ | Typed Pydantic models | `MLTrainingAction` extends `Action`, `MLTrainingObservation` extends `Observation`, all fields explicitly typed |
+ | `step()`/`reset()`/`state()` | Implemented in `MLTrainingEnvironment` extending `Environment` from `openenv.core.env_server.interfaces` |
+ | `openenv validate` passes | Tested before every submission |
+ 
+ ### 5.2 Framework Integration
+ 
+ | Requirement | Implementation |
+ |---|---|
+ | `openenv-core` v0.2.2 | `create_app()` returns standard FastAPI instance — **verified** |
+ | Custom routes compose | `/tasks`, `/grader`, `/baseline`, `/health` added via `@app.get()`/`@app.post()` on the returned FastAPI app |
+ | Framework-provided routes | `/reset`, `/step`, `/state`, `/ws`, `/schema`, `/docs`, `/mcp` — do not reimplement |
+ | Factory pattern | `create_app(MLTrainingEnvironment, ...)` takes the class, not an instance |
+ | Concurrent sessions | `SUPPORTS_CONCURRENT_SESSIONS = True`, session state keyed by session ID |
+ | Typed client | `client.py` extends `EnvClient` with typed action/observation — used by baseline scripts |
+ 
+ ### 5.3 Docker & Deployment
+ 
+ | Requirement | Target |
+ |---|---|
+ | Base image | `python:3.12-slim` |
+ | PyTorch | CPU-only wheel (`--index-url https://download.pytorch.org/whl/cpu`), ~150MB |
+ | Total image size | <500MB |
+ | Build time | <5 min (no real training during build; validation reports pre-computed) |
+ | HF Spaces | Tagged with `openenv`, port 7860 |
+ | Health check | `/health` returns `{"status": "ready", "tasks": N}` within 60s of container start |
+ 
+ ### 5.4 Reproducibility
+ 
+ | Requirement | Implementation |
+ |---|---|
+ | Deterministic episodes | `torch.manual_seed(seed)` at every `reset()`, seed derived deterministically from task ID |
+ | Baseline bit-exact | Rule-based baseline produces identical scores on two consecutive runs |
+ | Exploit resistance | Parameters randomized per `reset()` from defined ranges; opaque task IDs |
+ | Grader determinism | Same `EpisodeState` always produces same score |
+ 
+ ### 5.5 Performance
+ 
+ | Requirement | Target |
+ |---|---|
+ | `reset()` latency | <200ms (model instantiation + 2 forward passes + parametric curves) |
+ | `step()` latency | <10ms (action dispatch + reward + state update) |
+ | Memory | <512MB RSS (small CNN ~50K params, no GPU, no large datasets) |
+ 
+ ### 5.6 Code Quality
+ 
+ | Requirement | Standard |
+ |---|---|
+ | Formatting | black (line length 88) |
+ | Linting | ruff |
+ | Import ordering | isort (profile=black) |
+ | Type hints | Every function signature and return type |
+ | Tests | pytest, >80% coverage, every module has corresponding test file |
+ | PyTorch-native | All core computation uses `torch.Tensor`, zero numpy in core modules |
+ 
+ ---
+ 
+ ## 6. Prioritized Scope
+ 
+ ### Tier 1: MVP (Must Ship First)
+ 
+ **Deadline within deadline:** Deploy to HF Spaces by Day 6 (April 2). Everything after is additive.
+ 
+ | Deliverable | Description | DQ Risk if Missing |
+ |---|---|---|
+ | Task 1 (`task_001`) | Exploding gradients — easy | Yes (need 3+ tasks) |
+ | Task 3 (`task_003`) | Silent data leakage — medium | Yes (need 3+ tasks) |
+ | Task 5 (`task_005`) | BatchNorm eval mode — hard | Yes (need easy→hard range) |
+ | Context-gated penalty | -0.20 for `add_callback` after `gradients_were_normal` | No (but kills differentiation) |
+ | Rule-based baseline | `baseline_heuristic.py`, deterministic, no API key | Yes (baseline required) |
+ | Reward engine | All 7 reward components implemented exactly | Yes (reward logic required) |
+ | Graders (3) | One per MVP task, 0.0-1.0, deterministic | Yes (graders required) |
+ | `openenv.yaml` | Full metadata, 3+ tasks listed | Yes (spec compliance) |
+ | Required endpoints | `/tasks`, `/grader`, `/baseline`, `/health` | Yes (auto-validator checks) |
+ | Dockerfile | Builds and runs, port 7860 | Yes (auto-validator checks) |
+ | HF Space | Deployed, tagged `openenv`, responds to `reset()` | Yes (auto-validator pings) |
+ | README | Environment description, action/observation spaces, task descriptions, setup instructions, baseline scores | Yes (submission requirement) |
+ 
+ ### Tier 2: Strongest Differentiator (Add Immediately After MVP)
+ 
+ | Deliverable | Description | Why This Order |
+ |---|---|---|
+ | Task 6 (`task_006`) | PyTorch code bug — hard, code-level debugging | Single highest-impact feature for Meta judges |
+ | Code fix validation | Multi-strategy pipeline (tokenize, AST, semantic patterns) | Required for Task 6 to work with LLM agents |
+ | Grader for Task 6 | `code_bug` diagnosis, code fix scoring | Completes Task 6 |
+ 
+ ### Tier 3: Full Task Coverage (Time Permitting)
+ 
+ | Deliverable | Description |
+ |---|---|
+ | Task 2 (`task_002`) | Vanishing gradients — easy (similar to Task 1, fast to implement) |
+ | Task 4 (`task_004`) | Overfitting — medium (train-val divergence, regularization fix) |
+ | Graders for Tasks 2 & 4 | Same pattern as existing graders |
+ 
+ ### Tier 4: Polish & Extras (Only After Tiers 1-3 Complete)
+ 
+ | Deliverable | Description | Priority Within Tier |
+ |---|---|---|
+ | Live dashboard | HTML/JS at `/dashboard`, Plotly.js via CDN, 4-panel layout | 1st — transforms judging experience |
+ | PyTorch validation suite | 6 scripts proving parametric curves match real training, R² > 0.85 | 2nd — answers "how realistic?" |
+ | Validation report endpoint | `GET /validation-report` serving pre-computed fidelity plots | With validation suite |
+ | LLM baseline | `baseline_inference.py`, GPT-4o, measures heuristic-LLM gap | 3rd — supplementary demonstration |
+ 
+ ### Implementation Timeline (11 days: March 28 - April 8)
+ 
+ | Days | Focus | Exit Criteria |
+ |---|---|---|
+ | 1-2 | Skeleton server + Task 1 end-to-end | `reset()` → `step()` → `grader` works for one task, Docker builds |
+ | 3-5 | Tasks 3 & 5 + reward engine + baseline | All 3 MVP tasks pass grader, `baseline_heuristic.py` reproduces |
+ | 6 | **Deploy MVP to HF Spaces** | Auto-validation passes. This is the insurance policy. |
+ | 7-8 | Task 6 (code debugging) | Code fix validation works for all 4 bug variants |
+ | 9-10 | Tasks 2 & 4 + dashboard | Full 6-task environment, dashboard shows agent behavior |
+ | 11 | Polish, README, final smoke test | Submission-ready |
+ 
+ ### What We Will NOT Build (Explicit Exclusions)
+ 
+ - No game or toy environments
+ - No numpy in core modules (torch.Tensor only)
+ - No free-text diagnosis (closed enum only)
+ - No grader that sums step rewards (holistic evaluation only)
+ - No cumulative step penalty (flat -0.01 only, never -0.01 * step_count)
+ - No accommodation support or non-RL features
+ - No multi-GPU or CUDA dependencies (CPU-only PyTorch)
README.md ADDED
@@ -0,0 +1,149 @@
+ # PyTorch Training Run Debugger
+ 
+ **OpenEnv RL Environment** — Meta PyTorch OpenEnv Hackathon x Scaler School of Technology, Round 1
+ 
+ An AI agent debugs broken PyTorch training runs by investigating gradients, model weights, data pipelines, and source code to diagnose and fix real ML failure patterns.
+ 
+ ## What Is This?
+ 
+ This environment recreates the experience of an ML engineer facing a broken PyTorch training job. The agent receives a snapshot of a failing training run and must:
+ 
+ 1. **Investigate** — inspect gradients, data batches, model weights, model modes, and code
+ 2. **Diagnose** — identify the root cause from a closed set of known ML failures
+ 3. **Fix** — apply the correct intervention (reduce LR, patch data, fix model mode, etc.)
+ 4. **Verify** — restart training and confirm recovery before submitting diagnosis
+ 
+ ### Key Differentiators
+ 
+ - **PyTorch-native internals** — Real `torch.nn.Module` models (~50K params), real `torch.autograd` gradients, real `state_dict()` weight snapshots
+ - **Context-gated reward shaping** — Penalty fires only when agent ignores evidence it already gathered; no penalty for reasonable priors
+ - **Progressive information reveal** — Gradient stats, weight stats, data batch stats only populated after corresponding inspection actions
+ 
+ ## Environment Design
+ 
+ ### Observation Space (`MLTrainingObservation`)
+ 
+ | Field | Type | Visibility |
+ |-------|------|-----------|
+ | `training_loss_history` | `list[float]` (20 epochs) | Always |
+ | `val_accuracy_history` | `list[float]` (20 epochs) | Always |
+ | `val_loss_history` | `list[float]` (20 epochs) | Always |
+ | `current_config` | `TrainingConfig` | Always |
+ | `error_log` | `Optional[str]` | Always |
+ | `gradient_stats` | `list[GradientStats]` | After `inspect_gradients` |
+ | `model_weight_stats` | `Optional[list[ModelWeightStats]]` | After `inspect_model_weights` |
+ | `data_batch_stats` | `Optional[DataBatchStats]` | After `inspect_data_batch` |
+ | `model_mode_info` | `Optional[dict[str, str]]` | After `inspect_model_modes` |
+ | `code_snippet` | `Optional[CodeSnippet]` | After `inspect_code` |
+ | `available_actions` | `list[str]` | Always (dynamic) |
+ | `episode_state` | `EpisodeState` | Always |
+ 
+ ### Action Space (`MLTrainingAction`)
+ 
+ | Category | Actions |
+ |----------|---------|
+ | **Investigation** | `inspect_gradients`, `inspect_data_batch`, `inspect_model_modes`, `inspect_model_weights`, `inspect_code` |
+ | **Fix** | `modify_config`, `add_callback`, `replace_optimizer`, `patch_data_loader`, `fix_model_mode`, `fix_code`, `rollback_checkpoint` |
+ | **Terminal** | `restart_run`, `mark_diagnosed` |
+ 
+ Dynamic availability: `restart_run` requires a fix first; `fix_code` requires code inspection; `mark_diagnosed` disappears after submission.
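+ 
+ A sketch of how that availability could be computed from `EpisodeState` flags; the flag names match those used elsewhere in these docs, and the base list is an assumption:
+ 
+ ```python
+ def available_actions(state: EpisodeState) -> list[str]:
+     actions = [
+         "inspect_gradients", "inspect_data_batch", "inspect_model_modes",
+         "inspect_model_weights", "inspect_code",
+         "modify_config", "add_callback", "replace_optimizer",
+         "patch_data_loader", "fix_model_mode", "rollback_checkpoint",
+     ]
+     if state.code_inspected:
+         actions.append("fix_code")        # gated on a prior inspect_code
+     if state.fix_action_taken:
+         actions.append("restart_run")     # gated on a fix being applied
+     if not state.diagnosis_submitted:
+         actions.append("mark_diagnosed")  # disappears after submission
+     return actions
+ ```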
+ 
+ ### Diagnosis Enum
+ 
+ | Value | Description |
+ |-------|-------------|
+ | `lr_too_high` | Learning rate too large |
+ | `vanishing_gradients` | Gradients decay to near-zero |
+ | `data_leakage` | Validation samples in training |
+ | `overfitting` | Model memorizing, failing to generalize |
+ | `batchnorm_eval_mode` | Model in eval mode during training |
+ | `code_bug` | Bug in PyTorch training code |
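+ 
+ As a typed model this is a closed string enum. A minimal sketch consistent with the table; member names are illustrative, values are the spec'd strings:
+ 
+ ```python
+ from enum import Enum
+ 
+ class RootCauseDiagnosis(str, Enum):
+     LR_TOO_HIGH = "lr_too_high"
+     VANISHING_GRADIENTS = "vanishing_gradients"
+     DATA_LEAKAGE = "data_leakage"
+     OVERFITTING = "overfitting"
+     BATCHNORM_EVAL_MODE = "batchnorm_eval_mode"
+     CODE_BUG = "code_bug"
+ 
+ # The grader is a strict equality check against this enum: no fuzzy matching.
+ ```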
+ 
+ ### Reward Function
+ 
+ | Event | Reward | Gate |
+ |-------|--------|------|
+ | Any step | -0.01 | Flat, unconditional |
+ | First-time inspection | +0.05 | Per inspection type |
+ | `add_callback` after normal gradients | -0.20 | `gradients_inspected AND gradients_were_normal` |
+ | Invalid action | -0.05 | Action not in `available_actions` |
+ | Correct diagnosis | +0.50 | Equality check |
+ | Wrong diagnosis | -0.30 | Inequality check |
+ | Convergence after fix+restart | +0.40 | All gates met |
+ 
+ ## Tasks
+ 
+ | ID | Difficulty | Root Cause | Description |
+ |----|-----------|------------|-------------|
+ | `task_001` | Easy | `lr_too_high` | Exploding gradients — all layers show `is_exploding: True`, NaN in error log |
+ | `task_003` | Medium | `data_leakage` | Silent data leakage — suspiciously high val accuracy, `class_overlap_score > 0.5` |
+ | `task_005` | Hard | `batchnorm_eval_mode` | Model in eval mode with compound red herrings (FC gradient spike, GPU 91%, near-vanishing conv1) |
+ 
+ ## Baseline Scores
+ 
+ Rule-based heuristic baseline (deterministic, no API key):
+ 
+ | Task | Score |
+ |------|-------|
+ | `task_001` | 1.00 |
+ | `task_003` | 1.00 |
+ | `task_005` | 0.35 |
+ 
+ ## Setup
+ 
+ ### Local Development
+ 
+ ```bash
+ # Create virtual environment
+ python3 -m venv .venv
+ source .venv/bin/activate
+ 
+ # Install dependencies
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
+ pip install openenv-core pydantic fastapi uvicorn
+ 
+ # Install dev tools
+ pip install pytest pytest-cov black ruff isort
+ 
+ # Start server
+ uvicorn server.app:app --host 0.0.0.0 --port 7860
+ 
+ # Run tests
+ pytest tests/ -v --cov=ml_training_debugger
+ 
+ # Run baseline
+ python baseline_heuristic.py
+ ```
+ 
+ ### Docker
+ 
+ ```bash
+ docker build -t pytorch-debugger .
+ docker run -p 7860:7860 pytorch-debugger
+ curl http://localhost:7860/health
+ ```
+ 
+ ## API Endpoints
+ 
+ | Endpoint | Method | Description |
+ |----------|--------|-------------|
+ | `/health` | GET | `{"status": "ready", "tasks": 3}` |
+ | `/tasks` | GET | Task list with action schema |
+ | `/grader` | POST | Grader score for last completed episode |
+ | `/baseline` | POST | Run baseline, return scores |
+ | `/ws` | WebSocket | Primary agent interface |
+ | `/reset` | POST | Reset environment (framework) |
+ | `/step` | POST | Execute action (framework) |
+ | `/state` | GET | Current state (framework) |
+ | `/schema` | GET | Action/observation schemas (framework) |
+ | `/docs` | GET | Swagger UI (framework) |
+ 
+ ## Architecture
+ 
+ - **Python 3.12** · PyTorch CPU-only · openenv-core
+ - Real `torch.nn.Module` models with real `torch.autograd` gradients
+ - Parametric curve generation for loss/accuracy histories (sub-ms latency)
+ - Typed Pydantic models everywhere — no `Dict[str, Any]`
+ - `import torch` in every core module — zero numpy in core
+ - Session isolation via per-session `EpisodeState`
+ - Deterministic reproducibility via `torch.manual_seed()`
ROADMAP.md ADDED
@@ -0,0 +1,441 @@
+ # ROADMAP — PyTorch Training Run Debugger
+ 
+ **Timeline:** March 28 - April 8, 2026 (11 days)
+ **Runtime:** Python 3.12 · PyTorch CPU-only · openenv-core v0.2.2
+ **Governing documents:** `ml-training-debugger-spec.md` (source of truth), `PRD.md` (requirements), `CLAUDE.md` (coding rules)
+ **Iron rule:** No phase begins until the previous phase's acceptance criteria are met. The single exception: Phase 0 and Phase 1 file creation can overlap on Day 1.
+ 
+ ---
+ 
+ ## Phase 0: Setup & Validation (Days 1-2)
+ 
+ **Goal:** A running skeleton server that proves the toolchain works end-to-end. Zero business logic — just plumbing.
+ 
+ ### 0.1 Files to Create
+ 
+ | File | Purpose | Lines (est.) |
+ |---|---|---|
+ | `ML Debugger/` (this directory) | Project root directory (git init here) | — |
+ | `pyproject.toml` | Project metadata, dependencies (torch CPU, openenv-core, pydantic>=2.0, fastapi, uvicorn, pytest, black, ruff, isort) | ~40 |
+ | `requirements.txt` | Flat dependency list mirroring pyproject.toml (Docker uses this). **Exclude openai** — deferred to Phase 3. | ~10 |
+ | `.python-version` | `3.12` | 1 |
+ | `openenv.yaml` | Full metadata — start with 3 MVP tasks (task_001, task_003, task_005), expand later | ~50 |
+ | `Dockerfile` | `python:3.12-slim`, torch CPU-only, openenv-core, app deps, port 7860 | ~15 |
+ | `.dockerignore` | Exclude `.venv/`, `__pycache__/`, `.git/`, `validation/reports/*.png` | ~10 |
+ | `.gitignore` | `.venv/`, `__pycache__/`, `*.pyc`, `.env`, `run*.json` | ~15 |
+ | `ml_training_debugger/__init__.py` | Package init, version string | ~3 |
+ | `ml_training_debugger/models.py` | **Stub only:** `RootCauseDiagnosis` enum, `EpisodeState`, `TrainingConfig`, `GradientStats`, `DataBatchStats`, `ModelWeightStats`, `CodeSnippet`, `MLTrainingObservation` (extends `Observation`), `MLTrainingAction` (extends `Action`). All fields typed, all values defaulted. | ~200 |
+ | `ml_training_debugger/client.py` | **Stub:** `MLTrainingEnvClient` extending `EnvClient` with `action_type = MLTrainingAction` and `observation_type = MLTrainingObservation`. Used by baseline scripts. | ~20 |
+ | `server/__init__.py` | Empty | 0 |
+ | `server/environment.py` | **Stub:** `MLTrainingEnvironment(Environment)` with `reset()` returning a hardcoded observation, `step()` echoing back, `state` property | ~50 |
+ | `server/app.py` | `create_app(MLTrainingEnvironment, MLTrainingAction, MLTrainingObservation)` + stub routes for `/tasks`, `/grader`, `/baseline`, `/health` | ~60 |
+ | `tests/__init__.py` | Empty | 0 |
+ | `tests/test_models.py` | Validate all Pydantic models instantiate, serialize to JSON, and round-trip | ~60 |
+ | `tests/conftest.py` | Shared fixtures: sample `EpisodeState`, sample `ScenarioParams`, sample observation | ~40 |
+ 
+ ### 0.2 Dependencies to Install
+ 
+ ```bash
+ # Create venv inside ML Debugger/ project root
+ python3 -m venv .venv && source .venv/bin/activate
+ 
+ # Core runtime
+ pip install torch --index-url https://download.pytorch.org/whl/cpu
+ pip install openenv-core "pydantic>=2.0" fastapi uvicorn
45
+
46
+ # Dev tools
47
+ pip install pytest pytest-cov pytest-asyncio black ruff isort httpx websockets
48
+
49
+ # NOTE: openai is deferred to Phase 3 (LLM baseline). Do NOT install now.
50
+ ```
51
+
+ ### 0.3 Validation Steps (Must All Pass)
+
+ | # | Command | Expected Result |
+ |---|---|---|
+ | 1 | `python -c "import torch; print(torch.__version__)"` | Version string, no CUDA |
+ | 2 | `python -c "from openenv.core.env_server.http_server import create_app"` | No import error |
+ | 3 | `python -c "from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation"` | No import error |
+ | 4 | `python -c "from ml_training_debugger.client import MLTrainingEnvClient"` | No import error |
+ | 5 | `uvicorn server.app:app --host 0.0.0.0 --port 7860` | Server starts, no crash |
+ | 6 | `curl http://localhost:7860/health` | `{"status": "ready", "tasks": 3}` |
+ | 7 | `curl http://localhost:7860/tasks` | JSON with task list |
+ | 8 | `curl http://localhost:7860/docs` | Swagger UI loads |
+ | 9 | `pytest tests/test_models.py -v` | All pass |
+ | 10 | `docker build -t pytorch-debugger .` | Builds in <5min, image <500MB |
+ | 11 | `docker run -p 7860:7860 pytorch-debugger` then `curl /health` | Returns `{"status": "ready", "tasks": 3}` |
+ | 12 | `openenv validate` | Passes (or identify what needs fixing) |
+ | 13 | `black --check . && ruff check . && isort --check .` | Clean |
+
+ ### 0.4 Acceptance Criteria
+
+ - [ ] Skeleton server starts on port 7860 and responds to `/health`, `/tasks`, `/docs`, `/ws`
+ - [ ] `/health` returns `{"status": "ready", "tasks": 3}` (task count matches active tasks)
+ - [ ] All Pydantic models instantiate without error and serialize to valid JSON
+ - [ ] `client.py` imports without error
+ - [ ] Docker image builds under 500MB and container starts cleanly
+ - [ ] `openenv validate` passes or all failures are documented with a fix plan
+ - [ ] `pytest` runs with zero failures
+ - [ ] Git repo initialized, first commit made
+
+ ---
+
+ ## Phase 1: MVP — Tasks 1, 3, 5 + Core Engine (Days 2-6)
+
+ **Goal:** A fully functional 3-task environment that passes all auto-validation gates, deployed to HF Spaces. This is the survival milestone — everything after this is differentiation.
+
+ ### 1.1 Files to Create
+
+ | File | Purpose | Lines (est.) | Depends On |
+ |---|---|---|---|
+ | `ml_training_debugger/scenarios.py` | `ScenarioParams` dataclass, `sample_scenario(task_id, seed)` for tasks 001/003/005. Parameter ranges from spec Section 11. See the sketch after this table. | ~120 | `models.py` |
+ | `ml_training_debugger/pytorch_engine.py` | `SimpleCNN(torch.nn.Module)`, `inject_fault(model, scenario)`, `extract_gradient_stats(model)`, `extract_weight_stats(model)`. Real torch.autograd. | ~250 | `scenarios.py` |
+ | `ml_training_debugger/simulation.py` | `gen_loss_history(scenario)`, `gen_val_accuracy_history(scenario)`, `gen_val_loss_history(scenario)`. All `torch.Tensor` ops. Parametric curves per spec Section 6. | ~180 | `scenarios.py` |
+ | `ml_training_debugger/reward_engine.py` | `compute_reward(action, episode_state, scenario) -> float`. All 7 reward components per spec Section 12. Context-gated penalty logic. | ~100 | `models.py` |
+ | `ml_training_debugger/graders.py` | `grade_task_001(state, scenario)`, `grade_task_003(...)`, `grade_task_005(...)`. Each returns float in [0.0, 1.0]. Per spec Section 11 grader breakdowns. | ~150 | `models.py` |
+ | `baseline_heuristic.py` | Deterministic decision tree agent using `MLTrainingEnvClient`. Runs all MVP tasks, prints JSON scores. | ~150 | `client.py`, server running |
+ | `README.md` | Environment description, action/observation spaces, task descriptions with difficulty, setup instructions, baseline scores table | ~200 | Everything |
+
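+ A minimal sketch of the `sample_scenario` contract (the parameter range shown here is illustrative — the binding ranges live in spec Section 11):
+
+ ```python
+ from dataclasses import dataclass
+
+ import torch
+
+ from ml_training_debugger.models import RootCauseDiagnosis
+
+
+ @dataclass(frozen=True)
+ class ScenarioParams:
+     task_id: str
+     root_cause: RootCauseDiagnosis
+     learning_rate: float
+     weight_decay: float
+     seed: int
+
+
+ def sample_scenario(task_id: str, seed: int = 42) -> ScenarioParams:
+     """Deterministically sample scenario parameters for one task."""
+     gen = torch.Generator().manual_seed(seed)  # torch-native, bit-exact reproducible
+     if task_id == "task_001":
+         # lr_too_high: LR sampled well above a stable value (range is illustrative)
+         lr = float(torch.empty(1).uniform_(0.5, 5.0, generator=gen))
+         return ScenarioParams(
+             task_id=task_id,
+             root_cause=RootCauseDiagnosis.LR_TOO_HIGH,
+             learning_rate=lr,
+             weight_decay=1e-4,
+             seed=seed,
+         )
+     raise ValueError(f"No sampler for {task_id}")
+ ```
+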
99
+ ### 1.2 Files to Edit
+
+ | File | Changes | Why |
+ |---|---|---|
+ | `ml_training_debugger/models.py` | Finalize all field types, add `available_actions` computation logic to `EpisodeState`, add red herring fields (notes, gpu_memory) | Stubs from Phase 0 become real |
+ | `ml_training_debugger/client.py` | Wire typed client to connect via WebSocket or HTTP as needed by baseline | Stub becomes functional |
+ | `server/environment.py` | Full `reset()` and `step()` implementations. See spec Sections 9, 13 for lifecycle. | Stubs become real |
+ | `server/app.py` | Wire `/tasks`, `/grader`, `/baseline`, `/health` to return real data. `/health` returns `{"status": "ready", "tasks": 3}`. | Stubs become real |
+ | `openenv.yaml` | Finalize observation_space, action_space, reward section. Verify task IDs and max_steps per spec Section 14. | Was skeletal in Phase 0 |
+ | `Dockerfile` | Add `COPY` for all new source files. Verify build still works. | New files added |
+
+ ### 1.3 Tests to Create
+
+ | Test File | What It Covers | Critical Assertions |
+ |---|---|---|
+ | `tests/test_scenarios.py` | `sample_scenario()` for each MVP task | Returns correct root cause enum; params within defined ranges; different seeds produce different params |
+ | `tests/test_pytorch_engine.py` | Model instantiation, fault injection, gradient/weight extraction | `SimpleCNN` is a real `torch.nn.Module`; `extract_gradient_stats` returns `GradientStats` with real float norms; exploding fault produces `is_exploding=True`; batchnorm eval fault produces `model.training==False` |
+ | `tests/test_simulation.py` | Parametric curve generators | All outputs are `list[float]` of length 20; exploding LR produces diverging loss; leakage produces inflated val_acc; batchnorm produces slow val_acc degradation |
+ | `tests/test_reward_engine.py` | All 7 reward components | **Critical:** context-gated penalty fires when `gradients_inspected=True AND gradients_were_normal=True` then `add_callback`; does NOT fire when `add_callback` without prior inspection; step penalty is flat -0.01; investigation bonus is +0.05 first-time only (see the sketch after this table) |
+ | `tests/test_graders.py` | Graders for tasks 001, 003, 005 | Each returns float in [0.0, 1.0]; correct diagnosis + fix + restart = 1.0; wrong diagnosis < 0.5; partial completion scores between 0 and 1 |
+ | `tests/test_episode_lifecycle.py` | Full reset→inspect→fix→restart→diagnose flow | State transitions match spec Section 13; `available_actions` updates correctly; `done=True` after `mark_diagnosed`; step limit triggers `done=True` |
+
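+ A sketch of the critical reward test, assuming only the flat step penalty and the gate fire for an `add_callback` action in these states (thresholds are chosen to bracket the -0.20 penalty):
+
+ ```python
+ from ml_training_debugger.models import EpisodeState, MLTrainingAction
+ from ml_training_debugger.reward_engine import compute_reward
+ from ml_training_debugger.scenarios import sample_scenario
+
+
+ def test_context_gated_penalty_requires_prior_normal_gradients():
+     scenario = sample_scenario("task_005", seed=42)
+     action = MLTrainingAction(action_type="add_callback")
+
+     # Gate open: gradients inspected AND normal → -0.20 applies on top of -0.01
+     gated = EpisodeState(gradients_inspected=True, gradients_were_normal=True)
+     assert compute_reward(action, gated, scenario) <= -0.20
+
+     # Gate closed: same action, no prior inspection → only the flat -0.01
+     ungated = EpisodeState()
+     assert compute_reward(action, ungated, scenario) > -0.20
+ ```
+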
121
+ ### 1.4 Task-Specific Implementation
+
+ See spec Section 11 for complete task specifications. Key implementation notes per task:
+
+ **Task 1 (`task_001`, easy):** Unambiguous signal. LR from spec ranges → real gradients explode → `is_exploding=True` on all layers. Straightforward grader.
+
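+ A minimal sketch of the matching loss generator, assuming a simple exponential divergence (the real parametric shapes are defined in spec Section 6):
+
+ ```python
+ import torch
+
+
+ def gen_loss_history_exploding(scenario, epochs: int = 20) -> list[float]:
+     """Diverging training loss for lr_too_high — torch.Tensor ops only."""
+     torch.manual_seed(scenario.seed)
+     t = torch.arange(epochs, dtype=torch.float32)
+     # Start near ln(10) ≈ 2.3 (random 10-class cross-entropy), then blow up;
+     # scaling the growth rate with LR is an illustrative assumption.
+     rate = min(0.05 * scenario.learning_rate, 0.5)
+     curve = 2.3 * torch.exp(rate * t) + 0.05 * torch.randn(epochs)
+     return [float(v) for v in curve]
+ ```
+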
127
+ **Task 3 (`task_003`, medium):** Red herring note about an architecture upgrade. Data leakage is confirmed via `class_overlap_score`. The model itself is normal — no true gradient or weight anomaly — apart from a mild, red-herring gradient elevation on one layer (`is_exploding=False`).
+
+ **Task 5 (`task_005`, hard):** The differentiator task. `gradients_were_normal=True` is set inside the `inspect_gradients` handler because `is_exploding=False` on ALL layers (the FC spike keeps mean_norm < 10.0). The context-gated penalty fires when the agent then calls `add_callback` (see the sketch below). Red herrings: FC spike, GPU at 91%, conv1 near-vanishing, an error_log warning.
+
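+ A sketch of the gating logic in `reward_engine.py`, showing only the two components Task 5 exercises (the full `compute_reward` implements all 7 components from spec Section 12):
+
+ ```python
+ def compute_reward(action, state, scenario) -> float:
+     reward = -0.01  # flat step penalty — never multiplied by step_count
+
+     # Context-gated penalty: add_callback is only punished when the agent
+     # has already observed that gradients are normal
+     if (
+         action.action_type == "add_callback"
+         and state.gradients_inspected
+         and state.gradients_were_normal
+     ):
+         reward -= 0.20
+
+     return reward
+ ```
+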
131
+ ### 1.5 Endpoint Responses
+
+ **`GET /health`:** `{"status": "ready", "tasks": 3}` (200) — or `{"status": "initializing"}` (503) during startup.
+
+ **`GET /tasks`:** Task list with IDs, difficulties, max_steps, and the MLTrainingAction JSON schema.
+
+ **`POST /grader`:** `{"score": float, "task_id": str, "steps": int}` (200) — or `{"score": null, "error": "no_completed_episode"}` (200) if no episode has completed. See spec Section 14 for edge cases.
+
+ **`POST /baseline`:** Runs the baseline logic internally and returns `{"scores": {"task_001": float, "task_003": float, "task_005": float}}`. Returns 409 if a run is already in progress.
+
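+ A minimal sketch of the health and grader routes (`_last_completed_episode()` is a hypothetical accessor standing in for the real session lookup):
+
+ ```python
+ from fastapi import FastAPI
+
+ app = FastAPI()
+ ACTIVE_TASKS = ["task_001", "task_003", "task_005"]
+
+
+ @app.get("/health")
+ def health() -> dict:
+     return {"status": "ready", "tasks": len(ACTIVE_TASKS)}
+
+
+ @app.post("/grader")
+ def grader() -> dict:
+     episode = _last_completed_episode()  # hypothetical session accessor
+     if episode is None:
+         # Edge case stays HTTP 200 with a null score
+         return {"score": None, "error": "no_completed_episode"}
+     return {"score": episode.score, "task_id": episode.task_id, "steps": episode.steps}
+ ```
+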
141
+ ### 1.6 Baseline Heuristic Decision Tree
+
+ See spec Section 17 for the complete decision tree. Summary:
+ ```
+ 1. reset(task_id)
+ 2. inspect_gradients
+ 3. IF any layer is_exploding → fix LR → restart → diagnose lr_too_high
+ 4. IF any layer is_vanishing → fix LR → restart → diagnose vanishing_gradients
+ 5. inspect_data_batch
+ 6. IF class_overlap_score > 0.5 → patch_data_loader → restart → diagnose data_leakage
+ 7. IF val_loss diverging → modify weight_decay → restart → diagnose overfitting
+ 8. inspect_model_modes
+ 9. IF any layer in "eval" → fix_model_mode → restart → diagnose batchnorm_eval_mode
+ 10. inspect_code → attempt fix → restart → diagnose code_bug
+ 11. FALLBACK: diagnose overfitting
+ ```
+
+ ### 1.7 Deploy to HF Spaces
+
+ | Step | Action | Verification |
+ |---|---|---|
+ | 1 | Create HF Space (Docker type), tag with `openenv` | Space page shows openenv tag |
+ | 2 | Push Dockerfile + source to Space repo | Build triggers automatically |
+ | 3 | Wait for build to complete | Build log shows success |
+ | 4 | Test health endpoint | `curl https://<space-url>/health` returns `{"status": "ready", "tasks": 3}` |
+ | 5 | Test reset via WebSocket | `wscat -c wss://<space-url>/ws` then send `{"type": "reset", "task_id": "task_001"}` |
+ | 6 | Run `openenv validate` against deployed space | All checks pass |
+
+ ### 1.8 Acceptance Criteria
+
+ - [ ] `reset(task_id)` for tasks 001, 003, 005 returns valid `MLTrainingObservation` with correct initial state
+ - [ ] `step()` dispatches all 14 action types correctly (investigation, fix, terminal)
+ - [ ] `inspect_gradients` on Task 1 → `is_exploding=True` on all layers (real torch.autograd)
+ - [ ] `inspect_gradients` on Task 5 → `is_exploding=False` on all layers, `gradients_were_normal=True`
+ - [ ] `inspect_data_batch` on Task 3 → `class_overlap_score > 0.5`
+ - [ ] `inspect_model_modes` on Task 5 → all layers in "eval" mode
+ - [ ] Context-gated penalty: `inspect_gradients` (normal) then `add_callback` → reward includes -0.20
+ - [ ] Context-gated penalty: `add_callback` without prior inspection → NO -0.20 penalty
+ - [ ] Grader for Task 1: correct path scores 1.0, wrong diagnosis scores < 0.5
+ - [ ] Grader for Task 5: agent that chases the red herring scores 0.80-0.85 (penalty applied)
+ - [ ] `baseline_heuristic.py` runs twice → `diff run1.json run2.json` is empty
+ - [ ] `POST /baseline` returns scores for all 3 tasks, all in [0.0, 1.0]
+ - [ ] `POST /grader` returns score after completed episode
+ - [ ] `GET /tasks` returns 3 tasks with action schema
+ - [ ] `GET /health` returns `{"status": "ready", "tasks": 3}`
+ - [ ] Docker builds <500MB, starts <60s, serves on port 7860
+ - [ ] HF Space deployed, responds to `reset()`, tagged `openenv`
+ - [ ] `openenv validate` passes
+ - [ ] `pytest --cov` shows >80% coverage on all Phase 1 modules
+ - [ ] `import torch` in every core module; zero `import numpy` in core
+ - [ ] README has: description, action/observation spaces, 3 task descriptions, setup instructions, baseline scores
+
+ ---
+
195
+ ## Phase 2: Stretch — Tasks 2, 4, 6 + Code Debugging (Days 7-9)
+
+ **Goal:** Full 6-task environment with code-level debugging. Task 6 is the single highest-impact differentiator for Meta judges.
+
+ **Prerequisites:** Phase 1 acceptance criteria ALL met. HF Space deployed and passing auto-validation.
+
+ ### 2.1 Priority Order (Strict)
+
+ 1. **Task 6** first — it is the strongest differentiator and the hardest to implement
+ 2. **Task 2** second — structurally identical to Task 1 (vanishing vs. exploding), fastest to add
+ 3. **Task 4** third — medium-difficulty overfitting, similar pattern to existing tasks
+
+ ### 2.2 Files to Create
+
+ | File | Purpose | Lines (est.) | Depends On |
+ |---|---|---|---|
+ | `ml_training_debugger/code_templates.py` | 4 bug variant templates, `generate_code_snippet(bug_type, seed)`, `validate_fix(bug_type, line, replacement)` with multi-strategy pipeline per spec Section 22 | ~250 | `models.py` |
+ | `tests/test_code_templates.py` | All 4 variants generate valid code; fix validation accepts correct fixes; rejects wrong fixes; handles whitespace/comment variations | ~150 | `code_templates.py` |
+
+ ### 2.3 Files to Edit
+
+ | File | Changes | Complexity |
+ |---|---|---|
+ | `ml_training_debugger/scenarios.py` | Add `sample_scenario` cases for task_002, task_004, task_006. Task 006 includes `bug_type` field. | Low |
+ | `ml_training_debugger/pytorch_engine.py` | Add fault injection for vanishing gradients, overfitting, code bug variants. | Medium |
+ | `ml_training_debugger/simulation.py` | Add curve generators for vanishing (flat loss), overfitting (train-val divergence), code bug variants. | Medium |
+ | `ml_training_debugger/reward_engine.py` | Add wrong code fix penalty (-0.10). No other changes. | Low |
+ | `ml_training_debugger/graders.py` | Add `grade_task_002`, `grade_task_004`, `grade_task_006`. Task 006: diagnosis must be `code_bug` always. | Medium |
+ | `server/environment.py` | `step()` handlers for `inspect_code` and `fix_code`. Update `available_actions`. | Medium |
+ | `server/app.py` | Update `/tasks` to return 6 tasks. Update `/health` to return `"tasks": 6`. | Low |
+ | `openenv.yaml` | Add task_002, task_004, task_006. | Low |
+ | `baseline_heuristic.py` | Extend decision tree for vanishing, overfitting, code bug. | Medium |
+ | `README.md` | Add descriptions for Tasks 2, 4, 6. Update baseline scores. | Low |
+
+ ### 2.4 Task 6 Code Fix Validation
+
+ The `validate_fix()` pipeline is defined in spec Section 22 (Known Risks). Key layers:
+
+ 1. **Normalize:** strip whitespace + inline comments → compare against known correct strings
+ 2. **Tokenize:** Python `tokenize` module, filter noise tokens, compare streams
+ 3. **Semantic patterns:** 2-3 per variant (e.g. `"criterion("` present AND `".detach()"` absent)
+ 4. **AST fallback:** `ast.parse()` full code with replacement, verify buggy pattern absent
+
+ Test cases that MUST pass: correct fix, trailing whitespace, inline comments, different indentation.
+ Test cases that MUST fail: bug still present, `pass`, wrong line number. A usage-level check of the pipeline is sketched below.
+
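+ These calls match the `validate_fix` implementation in `code_templates.py` for the eval_mode variant:
+
+ ```python
+ from ml_training_debugger.code_templates import validate_fix
+
+ # Accepted: exact fix plus whitespace/comment variations
+ assert validate_fix("eval_mode", 5, "model.train()")
+ assert validate_fix("eval_mode", 5, "  model.train()   ")
+ assert validate_fix("eval_mode", 5, "model.train()  # was model.eval()")
+
+ # Rejected: non-fixes
+ assert not validate_fix("eval_mode", 5, "pass")
+ assert not validate_fix("eval_mode", 5, "model.eval()")
+ ```
+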
241
+ ### 2.5 Tests to Create/Extend
+
+ | Test File | New Coverage |
+ |---|---|
+ | `tests/test_code_templates.py` | **New file.** All 4 variants, validate_fix accepts/rejects correctly, 5+ whitespace/comment variations per variant |
+ | `tests/test_scenarios.py` | Extend: sample_scenario for task_002, 004, 006 |
+ | `tests/test_simulation.py` | Extend: vanishing flat loss, overfitting divergence, code bug symptoms |
+ | `tests/test_graders.py` | Extend: graders 002, 004, 006. Task 006: `code_bug` required; `batchnorm_eval_mode` on eval_mode variant = wrong |
+ | `tests/test_reward_engine.py` | Extend: wrong code fix penalty (-0.10) |
+ | `tests/test_episode_lifecycle.py` | Extend: `inspect_code` → `fix_code` available; `fix_code` before `inspect_code` → invalid |
+
+ ### 2.6 Acceptance Criteria
+
+ - [ ] All 6 tasks return valid observations from `reset()` and process all action types in `step()`
+ - [ ] Task 6: `inspect_code` returns `CodeSnippet` with real PyTorch code containing the sampled bug
+ - [ ] Task 6: `fix_code` correct → `fix_action_taken=True`, no penalty
+ - [ ] Task 6: `fix_code` wrong → -0.10 penalty
+ - [ ] Task 6: `mark_diagnosed(code_bug)` → correct (+0.50)
+ - [ ] Task 6: `mark_diagnosed(batchnorm_eval_mode)` on eval_mode variant → wrong (-0.30)
+ - [ ] `validate_fix` accepts 5+ whitespace/comment variations per variant
+ - [ ] `validate_fix` rejects all invalid fixes
+ - [ ] Graders for all 6 tasks return [0.0, 1.0] with meaningful variance
+ - [ ] `baseline_heuristic.py` handles all 6 tasks, still bit-exact reproducible
+ - [ ] `POST /baseline` returns scores for all 6 tasks
+ - [ ] `GET /tasks` returns 6 tasks
+ - [ ] `GET /health` returns `{"status": "ready", "tasks": 6}`
+ - [ ] All new tests pass; overall coverage >80%
+ - [ ] Updated openenv.yaml lists all 6 tasks
+ - [ ] HF Space redeployed with 6 tasks, auto-validation still passes
+
+ ---
+
+ ## Phase 3: Polish — Dashboard, Validation Suite, LLM Baseline (Days 10-11)
+
+ **Goal:** Transform a technically correct submission into a visually impressive, deeply validated, winning submission.
+
+ **Prerequisites:** Phase 2 acceptance criteria ALL met. 6-task environment deployed.
+
+ ### 3.1 Priority Order Within Phase 3
+
+ 1. **Dashboard** — transforms the judging experience (highest ROI for judges)
+ 2. **Full test suite + README polish** — ensures no auto-validation failure
+ 3. **Validation suite** — answers "how realistic are your curves?"
+ 4. **LLM baseline** — demonstrates the heuristic-reasoning gap (lowest priority)
+
+ ### 3.2 Files to Create
+
+ | File | Purpose | Lines (est.) | Priority |
+ |---|---|---|---|
+ | `server/dashboard.html` | Single-file SPA. 4 panels per spec Section 19. Plotly.js via CDN. | ~400 | 1st |
+ | `validation/requirements.txt` | `torch`, `matplotlib`, `scipy` | ~3 | 3rd |
+ | `validation/conftest.py` | Shared fixtures: CIFAR-10 subset loader, model definitions | ~50 | 3rd |
+ | `validation/validate_exploding_gradients.py` | Real training, compare to parametric curve, R² > 0.85 | ~80 | 3rd |
+ | `validation/validate_data_leakage.py` | Real training with leakage, compare | ~80 | 3rd |
+ | `validation/validate_batchnorm_eval.py` | Real training with `model.eval()`, compare | ~80 | 3rd |
+ | `validation/validate_vanishing_gradients.py` | Real gradient decay, compare | ~80 | 3rd |
+ | `validation/validate_overfitting.py` | Real train-val divergence, compare | ~80 | 3rd |
+ | `validation/validate_code_bugs.py` | Run 4 bug variants, confirm symptoms | ~80 | 3rd |
+ | `validation/reports/` | Pre-computed fidelity scores + comparison plots | — | 3rd |
+ | `baseline_inference.py` | LLM agent (GPT-4o, temp=0.0, seed=42). Runs all 6 tasks. **Now install openai.** See the sketch after this table. | ~200 | 4th |
+
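+ A minimal sketch of the deterministic LLM call for `baseline_inference.py` (the prompt wording is illustrative; the API-side `seed` gives best-effort determinism only):
+
+ ```python
+ from openai import OpenAI
+
+ client = OpenAI()  # reads OPENAI_API_KEY from the environment
+
+
+ def choose_action(observation_json: str) -> str:
+     """Ask the model for the next action as a JSON string."""
+     resp = client.chat.completions.create(
+         model="gpt-4o",
+         temperature=0.0,  # greedy decoding
+         seed=42,          # best-effort reproducibility
+         messages=[
+             {
+                 "role": "system",
+                 "content": "You are debugging a PyTorch training run. Reply with exactly one JSON action.",
+             },
+             {"role": "user", "content": observation_json},
+         ],
+     )
+     return resp.choices[0].message.content
+ ```
+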
302
+ ### 3.3 Files to Edit
+
+ | File | Changes | Priority |
+ |---|---|---|
+ | `server/app.py` | Add `GET /dashboard` and `GET /validation-report` routes | 1st/3rd |
+ | `requirements.txt` | Add `openai` (only now, for the LLM baseline) | 4th |
+ | `Dockerfile` | `COPY validation/reports/` and `COPY server/dashboard.html` | 1st |
+ | `README.md` | Final polish: dashboard description, validation suite, measured baseline scores | 2nd |
+ | `openenv.yaml` | Add dashboard and validation-report to endpoints | 1st |
+
+ ### 3.4 Dashboard Panels
+
+ See spec Section 19 for the full specification. Summary:
+ 1. **Training Metrics** — Plotly.js line charts for loss/accuracy with restart markers
+ 2. **Gradient & Weight Heatmap** — color-coded per-layer grid (green/yellow/red/blue)
+ 3. **Action Timeline** — horizontal bars per step, color-coded by type, reward bars
+ 4. **Episode Summary** — task ID, state flags, available actions, grader score
+
+ Tech: single HTML file, Plotly.js CDN, native WebSocket, CSS Grid. Zero Docker bloat.
+
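+ Serving the panel is one route — a sketch, assuming the module-level `app` from `server/app.py`:
+
+ ```python
+ from pathlib import Path
+
+ from fastapi.responses import HTMLResponse
+
+ _DASHBOARD = Path(__file__).parent / "dashboard.html"
+
+
+ @app.get("/dashboard", response_class=HTMLResponse)
+ def dashboard() -> str:
+     # Single-file SPA; Plotly.js loads from a CDN, so the image stays slim
+     return _DASHBOARD.read_text()
+ ```
+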
322
+ ### 3.5 Validation Suite
+
+ Run locally (NOT in the Docker build). Each script: real training → capture metrics → compare to the parametric curve → assert R² > 0.85 → save plots. Pre-computed reports are committed to git and served via `/validation-report`. See spec Section 18.
+
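+ The fidelity check itself is small in torch — a sketch of the R² computation each script asserts against:
+
+ ```python
+ import torch
+
+
+ def r_squared(real: list[float], parametric: list[float]) -> float:
+     """Coefficient of determination between measured and parametric curves."""
+     y = torch.tensor(real, dtype=torch.float32)
+     y_hat = torch.tensor(parametric, dtype=torch.float32)
+     ss_res = torch.sum((y - y_hat) ** 2)
+     ss_tot = torch.sum((y - y.mean()) ** 2)
+     return float(1.0 - ss_res / ss_tot)
+
+
+ # each validation script ends with:
+ # assert r_squared(real_curve, parametric_curve) > 0.85
+ ```
+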
326
+ ### 3.6 Tests to Create/Extend
+
+ | Test File | Coverage |
+ |---|---|
+ | `tests/test_dashboard.py` | `GET /dashboard` returns 200 with HTML containing "Plotly" and "WebSocket" |
+ | `tests/test_endpoints.py` | Integration: full episode via HTTP (reset→step→grader), verify response schemas |
+ | `tests/test_baseline_reproducibility.py` | Run baseline twice, assert identical JSON (see the sketch below) |
+ | Existing test files | Fill coverage gaps to >80% on every module |
+
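+ A sketch of the reproducibility test, assuming the baseline is invoked as a subprocess exactly as in the smoke-test sequence later in this plan:
+
+ ```python
+ import json
+ import subprocess
+ import sys
+
+
+ def _run_baseline() -> dict:
+     out = subprocess.run(
+         [sys.executable, "baseline_heuristic.py"],
+         capture_output=True,
+         text=True,
+         check=True,
+     )
+     return json.loads(out.stdout)
+
+
+ def test_baseline_is_bit_exact_reproducible():
+     assert _run_baseline() == _run_baseline()
+ ```
+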
335
+ ### 3.7 Acceptance Criteria
+
+ - [ ] `GET /dashboard` serves HTML that renders in a browser with 4 panels
+ - [ ] Dashboard connects to WebSocket and updates in real time during a baseline run
+ - [ ] Validation suite passes all scripts with R² > 0.85 (run locally)
+ - [ ] Pre-computed validation reports exist in `validation/reports/`
+ - [ ] `GET /validation-report` serves fidelity data
+ - [ ] LLM baseline runs, scores higher than heuristic on Tasks 5 and 6 (if implemented)
+ - [ ] README is complete: all 6 tasks, both baselines, dashboard description, setup instructions
+ - [ ] `pytest --cov` shows >80% coverage across all modules
+ - [ ] Final `openenv validate` passes
+ - [ ] Final Docker build <500MB, starts <60s
+ - [ ] HF Space redeployed with dashboard + all features
+
+ ---
+
+ ## Pre-Submission Gate Checklist
+
+ **Every item must be checked before submitting. Failure on any starred (*) item = disqualification.**
+
+ ### Auto-Validation Gates (*)
+
+ - [ ] * **HF Space deploys** — `curl https://<space-url>/health` returns `{"status": "ready", "tasks": N}` with HTTP 200
+ - [ ] * **HF Space responds to reset** — WebSocket connection to `/ws`, send reset message, receive valid observation
+ - [ ] * **OpenEnv spec compliance** — `openenv validate` passes (openenv.yaml present, typed models, step/reset/state work)
+ - [ ] * **Dockerfile builds** — `docker build -t pytorch-debugger .` succeeds
+ - [ ] * **Docker runs** — `docker run -p 7860:7860 pytorch-debugger` starts and serves on port 7860
+ - [ ] * **Baseline reproduces** — `python baseline_heuristic.py > run1.json && python baseline_heuristic.py > run2.json && diff run1.json run2.json` produces no output
+ - [ ] * **3+ tasks with graders** — `GET /tasks` returns ≥3 tasks; `POST /grader` returns score in [0.0, 1.0] after each task completes
+ - [ ] * **Graders produce varying scores** — different agent behaviors produce different scores (not always the same value)
+
+ ### Required Endpoint Gates (*)
+
+ - [ ] * **`GET /tasks`** — returns JSON with task IDs, difficulties, action schema
+ - [ ] * **`POST /grader`** — returns `{"score": float}` after a completed episode
+ - [ ] * **`POST /baseline`** — triggers baseline, returns scores for all tasks
+ - [ ] * **`GET /health`** — returns `{"status": "ready", "tasks": N}`
+
+ ### Submission Artifacts (*)
+
+ - [ ] * **Public GitHub repo** — contains all code, README, requirements, openenv.yaml
+ - [ ] * **HF Spaces demo link** — deployed, tagged `openenv`, accessible
+ - [ ] * **README complete** — environment description, action/observation space definitions, task descriptions with difficulty, setup instructions, baseline scores
+
+ ### Quality Gates (Not DQ, but impact scoring)
+
+ - [ ] All typed Pydantic models — no `Dict[str, Any]`
+ - [ ] `import torch` in every core module — zero `import numpy` in core
+ - [ ] Context-gated penalty fires correctly (manually tested both paths)
+ - [ ] Task 5 red herrings present: FC spike, GPU 91%, conv1 near-vanishing, error_log warning
+ - [ ] Task 6 code fix validation handles whitespace and comment variations
+ - [ ] Task 6 diagnosis is always `code_bug` regardless of bug variant
+ - [ ] Grader and reward function are separate modules
+ - [ ] Step penalty is flat -0.01 (not multiplied by step_count)
+ - [ ] Episode state is isolated per WebSocket session
+ - [ ] Test suite passes with >80% coverage
+ - [ ] Code formatted with black, linted with ruff, imports sorted with isort
+
+ ### Final Smoke Test Sequence
+
+ Run this entire sequence the night before submission:
+
+ ```bash
+ # 1. Clean build
+ docker build --no-cache -t pytorch-debugger .
+ docker run -d -p 7860:7860 --name smoke-test pytorch-debugger
+
+ # 2. Wait for startup
+ sleep 10
+ curl -f http://localhost:7860/health || echo "FAIL: health"
+
+ # 3. Tasks endpoint
+ curl -f http://localhost:7860/tasks | python -m json.tool || echo "FAIL: tasks"
+
+ # 4. Baseline reproducibility
+ python baseline_heuristic.py > run1.json 2>/dev/null
+ python baseline_heuristic.py > run2.json 2>/dev/null
+ diff run1.json run2.json && echo "PASS: reproducible" || echo "FAIL: non-reproducible"
+
+ # 5. Baseline via endpoint
+ curl -f -X POST http://localhost:7860/baseline | python -m json.tool || echo "FAIL: baseline endpoint"
+
+ # 6. Grader via endpoint (after baseline has completed episodes)
+ curl -f -X POST http://localhost:7860/grader | python -m json.tool || echo "FAIL: grader endpoint"
+
+ # 7. OpenEnv validation
+ openenv validate || echo "FAIL: openenv validate"
+
+ # 8. Test suite
+ pytest tests/ -v --cov=ml_training_debugger --cov-report=term-missing
+
+ # 9. Cleanup
+ docker stop smoke-test && docker rm smoke-test
+
+ echo "=== Smoke test complete ==="
+ ```
+
+ ### If Something Fails at Submission Time
+
+ | Failure | Triage |
+ |---|---|
+ | HF Space won't deploy | Check Dockerfile CMD, port 7860, build logs. Redeploy. |
+ | Baseline non-reproducible | Check `torch.manual_seed()` in `reset()`. Check for `random` module usage. |
+ | Grader returns same score | Check that `sample_scenario` uses different seeds. Check grader logic has branching. |
+ | `openenv validate` fails | Read error message. Usually missing field in openenv.yaml or wrong model base class. |
+ | Docker image >500MB | Check `docker images` size. Remove unused deps. Ensure torch is CPU-only. |
+ | Test coverage <80% | Run `pytest --cov` with `--cov-report=html`. Find uncovered branches. Add targeted tests. |
baseline_heuristic.py ADDED
@@ -0,0 +1,186 @@
+ #!/usr/bin/env python3
+ """Rule-based heuristic baseline agent.
+
+ Deterministic decision tree — no API key required. Bit-exact reproducible.
+ Spec reference: Section 17.
+
+ Usage:
+     python baseline_heuristic.py [--url http://localhost:7860]
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+
+ from ml_training_debugger.models import MLTrainingAction
+ from server.environment import MLTrainingEnvironment
+
+ MVP_TASKS = ["task_001", "task_003", "task_005"]
+
+
+ def _final_score(env: MLTrainingEnvironment) -> float:
+     """Read the grader score recorded on the env session, defaulting to 0.0."""
+     session = env._get_session()
+     return session.last_score if session and session.last_score is not None else 0.0
+
+
+ def run_heuristic_episode(task_id: str, seed: int = 42) -> float:
+     """Run one heuristic baseline episode. Returns grader score."""
+     env = MLTrainingEnvironment()
+     obs = env.reset(seed=seed, episode_id=f"baseline_{task_id}", task_id=task_id)
+
+     # Step 1: inspect_gradients
+     obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
+
+     if obs.gradient_stats:
+         # Check exploding
+         if any(g.is_exploding for g in obs.gradient_stats):
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="modify_config",
+                     target="learning_rate",
+                     value=0.001,
+                 )
+             )
+             obs = env.step(MLTrainingAction(action_type="restart_run"))
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="mark_diagnosed",
+                     diagnosis="lr_too_high",
+                 )
+             )
+             return _final_score(env)
+
+         # Check vanishing
+         if any(g.is_vanishing for g in obs.gradient_stats):
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="modify_config",
+                     target="learning_rate",
+                     value=0.01,
+                 )
+             )
+             obs = env.step(MLTrainingAction(action_type="restart_run"))
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="mark_diagnosed",
+                     diagnosis="vanishing_gradients",
+                 )
+             )
+             return _final_score(env)
+
+     # Step 2: inspect_data_batch
+     obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
+     if obs.data_batch_stats and obs.data_batch_stats.class_overlap_score > 0.5:
+         obs = env.step(MLTrainingAction(action_type="patch_data_loader"))
+         obs = env.step(MLTrainingAction(action_type="restart_run"))
+         obs = env.step(
+             MLTrainingAction(
+                 action_type="mark_diagnosed",
+                 diagnosis="data_leakage",
+             )
+         )
+         return _final_score(env)
+
+     # Check overfitting (val_loss diverging)
+     if obs.val_loss_history and len(obs.val_loss_history) >= 10:
+         early = sum(obs.val_loss_history[:5]) / 5
+         late = sum(obs.val_loss_history[-5:]) / 5
+         if (
+             late > early * 1.2
+             and obs.data_batch_stats
+             and obs.data_batch_stats.class_overlap_score < 0.1
+         ):
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="modify_config",
+                     target="weight_decay",
+                     value=0.01,
+                 )
+             )
+             obs = env.step(MLTrainingAction(action_type="restart_run"))
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="mark_diagnosed",
+                     diagnosis="overfitting",
+                 )
+             )
+             return _final_score(env)
+
+     # Step 3: inspect_model_modes
+     obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
+     if obs.model_mode_info:
+         has_eval = any(v == "eval" for v in obs.model_mode_info.values())
+         if has_eval:
+             obs = env.step(MLTrainingAction(action_type="fix_model_mode"))
+             obs = env.step(MLTrainingAction(action_type="restart_run"))
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="mark_diagnosed",
+                     diagnosis="batchnorm_eval_mode",
+                 )
+             )
+             return _final_score(env)
+
+     # Step 4: inspect_code
+     obs = env.step(MLTrainingAction(action_type="inspect_code"))
+     if obs.code_snippet:
+         code = obs.code_snippet.code
+         if "model.eval()" in code and "model.train()" not in code:
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="fix_code",
+                     line=5,
+                     replacement="model.train()",
+                 )
+             )
+         elif ".detach()" in code:
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="fix_code",
+                     line=13,  # 1-based line of `loss = ...detach()` in the detach_loss template
+                     replacement="        loss = criterion(output, batch_y)",
+                 )
+             )
+
+         if obs.episode_state.fix_action_taken:
+             obs = env.step(MLTrainingAction(action_type="restart_run"))
+
+         obs = env.step(
+             MLTrainingAction(
+                 action_type="mark_diagnosed",
+                 diagnosis="code_bug",
+             )
+         )
+         return _final_score(env)
+
+     # Fallback
+     obs = env.step(
+         MLTrainingAction(
+             action_type="mark_diagnosed",
+             diagnosis="overfitting",
+         )
+     )
+     return _final_score(env)
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="Rule-based baseline agent")
+     # --url is accepted for interface parity with baseline_inference.py;
+     # MVP episodes run in-process, so the value is currently unused.
+     parser.add_argument("--url", default="http://localhost:7860")
+     parser.parse_args()
+
+     scores: dict[str, float] = {}
+     for task_id in MVP_TASKS:
+         score = run_heuristic_episode(task_id)
+         scores[task_id] = round(score, 4)
+
+     print(json.dumps(scores, indent=2))
+
+
+ if __name__ == "__main__":
+     main()
deploy.sh ADDED
@@ -0,0 +1,52 @@
+ #!/bin/bash
+ set -euo pipefail
+
+ echo "=== PyTorch Training Run Debugger — Pre-Submission Smoke Test ==="
+ echo ""
+
+ # 1. Run tests
+ echo "=== 1. Running test suite ==="
+ source .venv/bin/activate
+ pytest tests/ -v --cov=ml_training_debugger --cov-report=term-missing
+ echo ""
+
+ # 2. Code formatting check
+ echo "=== 2. Code formatting ==="
+ black --check ml_training_debugger/ server/ tests/ || { echo "Run: black ml_training_debugger/ server/ tests/"; exit 1; }
+ ruff check ml_training_debugger/ server/ tests/ || { echo "Run: ruff check --fix"; exit 1; }
+ isort --check ml_training_debugger/ server/ tests/ --profile black || { echo "Run: isort --profile black"; exit 1; }
+ echo "PASS: formatting OK"
+ echo ""
+
+ # 3. Baseline reproducibility
+ echo "=== 3. Baseline reproducibility ==="
+ python baseline_heuristic.py > /tmp/run1.json 2>/dev/null
+ python baseline_heuristic.py > /tmp/run2.json 2>/dev/null
+ diff /tmp/run1.json /tmp/run2.json && echo "PASS: bit-exact reproducible" || { echo "FAIL: non-reproducible"; exit 1; }
+ echo ""
+
+ # 4. Docker build
+ echo "=== 4. Docker build ==="
+ docker build -t pytorch-debugger .
+ IMAGE_SIZE=$(docker images pytorch-debugger --format "{{.Size}}")
+ echo "Image size: $IMAGE_SIZE"
+ echo ""
+
+ # 5. Docker run + health check
+ echo "=== 5. Docker run + endpoint checks ==="
+ docker run -d -p 7860:7860 --name smoke-test pytorch-debugger
+ sleep 10
+
+ curl -f http://localhost:7860/health || { echo "FAIL: health"; docker stop smoke-test; docker rm smoke-test; exit 1; }
+ echo ""
+ curl -f http://localhost:7860/tasks || { echo "FAIL: tasks"; docker stop smoke-test; docker rm smoke-test; exit 1; }
+ echo ""
+ curl -f -X POST http://localhost:7860/grader || { echo "FAIL: grader"; docker stop smoke-test; docker rm smoke-test; exit 1; }
+ echo ""
+
+ # 6. Cleanup
+ docker stop smoke-test && docker rm smoke-test
+ rm -f /tmp/run1.json /tmp/run2.json
+
+ echo ""
+ echo "=== ALL CHECKS PASSED ==="
ml-training-debugger-spec.md ADDED
The diff for this file is too large to render. See raw diff
 
ml_training_debugger/__init__.py ADDED
@@ -0,0 +1,3 @@
+ """PyTorch Training Run Debugger — OpenEnv Environment."""
+
+ __version__ = "1.0.0"
ml_training_debugger/client.py ADDED
@@ -0,0 +1,21 @@
+ """Typed EnvClient for baseline scripts.
+
+ Extends GenericEnvClient since we can't easily subclass the
+ abstract EnvClient without implementing all transport methods.
+ Used by baseline_heuristic.py.
+ """
+
+ from __future__ import annotations
+
+ from openenv.core.generic_client import GenericEnvClient
+
+
+ class MLTrainingEnvClient(GenericEnvClient):
+     """Typed client for the PyTorch Training Debugger environment.
+
+     Wraps GenericEnvClient for convenient use in baselines.
+     Actions are sent as dicts matching MLTrainingAction schema.
+     Observations are received as dicts matching MLTrainingObservation schema.
+     """
ml_training_debugger/code_templates.py ADDED
@@ -0,0 +1,248 @@
+ """PyTorch code snippet templates for Task 6 code-level debugging.
+
+ Each template is a real, syntactically valid Python/PyTorch training script
+ with one injected bug. Spec reference: Section 11 (Task 6), Section 22.
+ """
+
+ from __future__ import annotations
+
+ import ast
+ import io
+ import tokenize
+ from typing import Optional
+
+ import torch  # noqa: F401 — PyTorch-native project
+
+ # Bug variant templates: (buggy_code, correct_line_num, correct_replacement)
+ _TEMPLATES: dict[str, tuple[str, int, str]] = {
+     "eval_mode": (
+         """\
+ import torch
+ import torch.nn as nn
+
+ model = SimpleCNN()
+ model.eval()
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+ criterion = nn.CrossEntropyLoss()
+
+ for epoch in range(100):
+     for batch_x, batch_y in train_loader:
+         optimizer.zero_grad()
+         output = model(batch_x)
+         loss = criterion(output, batch_y)
+         loss.backward()
+         optimizer.step()""",
+         5,  # 1-based line of the buggy `model.eval()` call
+         "model.train()",
+     ),
+     "detach_loss": (
+         """\
+ import torch
+ import torch.nn as nn
+
+ model = SimpleCNN()
+ model.train()
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+ criterion = nn.CrossEntropyLoss()
+
+ for epoch in range(100):
+     for batch_x, batch_y in train_loader:
+         optimizer.zero_grad()
+         output = model(batch_x)
+         loss = criterion(output, batch_y).detach()
+         loss.backward()
+         optimizer.step()""",
+         13,  # 1-based line of the buggy `loss = ...detach()` statement
+         "        loss = criterion(output, batch_y)",
+     ),
+     "zero_grad_missing": (
+         """\
+ import torch
+ import torch.nn as nn
+
+ model = SimpleCNN()
+ model.train()
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+ criterion = nn.CrossEntropyLoss()
+
+ for epoch in range(100):
+     for batch_x, batch_y in train_loader:
+         output = model(batch_x)
+         loss = criterion(output, batch_y)
+         loss.backward()
+         optimizer.step()""",
+         11,  # 1-based insertion point: before `output = model(batch_x)`
+         "        optimizer.zero_grad()",
+     ),
+     "inplace_relu": (
+         """\
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ model = SimpleCNN()
+ model.train()
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+ criterion = nn.CrossEntropyLoss()
+
+ for epoch in range(100):
+     for batch_x, batch_y in train_loader:
+         optimizer.zero_grad()
+         output = model(batch_x)
+         output = F.relu(output, inplace=True)
+         loss = criterion(output, batch_y)
+         loss.backward()
+         optimizer.step()""",
+         14,  # 1-based line of the buggy `F.relu(..., inplace=True)` call
+         "        output = F.relu(output)",
+     ),
+ }
+
+ # Semantic equivalence patterns per bug variant
+ _SEMANTIC_PATTERNS: dict[str, list[tuple[str, str]]] = {
+     "eval_mode": [
+         # (must_contain, must_not_contain)
+         ("model.train()", "model.eval()"),
+     ],
+     "detach_loss": [
+         ("criterion(", ".detach()"),
+     ],
+     "zero_grad_missing": [
+         ("zero_grad()", ""),  # just needs zero_grad present
+     ],
+     "inplace_relu": [
+         ("F.relu(", "inplace=True"),
+     ],
+ }
+
+
+ def generate_code_snippet(bug_type: str, seed: int = 42) -> dict:
+     """Generate a code snippet with the specified bug.
+
+     Returns dict with keys: code, filename, line_count, imports, hint.
+     """
+     if bug_type not in _TEMPLATES:
+         raise ValueError(f"Unknown bug_type: {bug_type}")
+
+     # `seed` is part of the spec'd signature; the templates are currently
+     # fixed per variant, so it has no effect yet.
+     code, _line, _replacement = _TEMPLATES[bug_type]
+     lines = code.strip().split("\n")
+     imports = [
+         line for line in lines if line.startswith("import ") or line.startswith("from ")
+     ]
+
+     hint: Optional[str] = None
+     if bug_type == "eval_mode":
+         hint = "Check the model mode before the training loop."
+     elif bug_type == "detach_loss":
+         hint = "Examine how the loss is computed and used."
+
+     return {
+         "code": code,
+         "filename": "train.py",
+         "line_count": len(lines),
+         "imports": imports,
+         "hint": hint,
+     }
+
+
+ def _normalize_code(s: str) -> str:
+     """Strip surrounding whitespace and inline comments for comparison."""
+     result_lines: list[str] = []
+     for line in s.strip().split("\n"):
+         # Drop an inline comment unless the '#' sits inside a string literal
+         in_single = in_double = False
+         cut = len(line)
+         for i, ch in enumerate(line):
+             if ch == "'" and not in_double:
+                 in_single = not in_single
+             elif ch == '"' and not in_single:
+                 in_double = not in_double
+             elif ch == "#" and not in_single and not in_double:
+                 cut = i
+                 break
+         result_lines.append(line[:cut].rstrip())
+     return "\n".join(result_lines)
+
+
+ def _tokenize_compare(original: str, replacement: str) -> bool:
+     """Compare token streams ignoring whitespace and comments."""
+
+     def get_tokens(code: str) -> list[tuple[int, str]]:
+         try:
+             tokens = list(tokenize.generate_tokens(io.StringIO(code).readline))
+             # Filter out COMMENT, NL, NEWLINE, INDENT, DEDENT, ENCODING, ENDMARKER
+             skip = {
+                 tokenize.COMMENT,
+                 tokenize.NL,
+                 tokenize.NEWLINE,
+                 tokenize.INDENT,
+                 tokenize.DEDENT,
+                 tokenize.ENCODING,
+                 tokenize.ENDMARKER,
+             }
+             return [(t.type, t.string) for t in tokens if t.type not in skip]
+         except (tokenize.TokenError, IndentationError):
+             return []
+
+     return get_tokens(original) == get_tokens(replacement)
+
+
+ def validate_fix(bug_type: str, line: int, replacement: str) -> bool:
+     """Validate a code fix submission.
+
+     Multi-strategy pipeline per spec Section 22:
+     1. Normalize whitespace + strip comments
+     2. Token-stream comparison
+     3. Semantic equivalence patterns
+     4. AST fallback
+     """
+     if bug_type not in _TEMPLATES:
+         return False
+
+     code, correct_line, correct_replacement = _TEMPLATES[bug_type]
+     lines = code.strip().split("\n")
+
+     # Check the line number is in bounds and points at the buggy line
+     # (spec Section 22: a wrong line number must fail)
+     if line < 1 or line > len(lines) or line != correct_line:
+         return False
+
+     # For zero_grad_missing, the fix is inserting a line, not replacing
+     if bug_type == "zero_grad_missing":
+         # Accept if the replacement contains zero_grad
+         return "zero_grad" in _normalize_code(replacement)
+
+     # Strategy 1: Normalize and compare
+     norm_replacement = _normalize_code(replacement)
+     norm_correct = _normalize_code(correct_replacement)
+     if norm_replacement == norm_correct:
+         return True
+
+     # Strategy 2: Token-stream comparison
+     if _tokenize_compare(correct_replacement, replacement):
+         return True
+
+     # Strategy 3: Semantic equivalence patterns
+     patterns = _SEMANTIC_PATTERNS.get(bug_type, [])
+     for must_contain, must_not_contain in patterns:
+         if must_contain and must_contain in norm_replacement:
+             if not must_not_contain or must_not_contain not in norm_replacement:
+                 return True
+
+     # Strategy 4: AST fallback — verify buggy pattern absent
+     try:
+         # Replace the line in the full code and parse
+         new_lines = lines.copy()
+         new_lines[line - 1] = replacement.rstrip()
+         new_code = "\n".join(new_lines)
+         ast.parse(new_code)  # raises SyntaxError if the patched file is invalid
+
+         # Check that the buggy pattern is absent
+         if bug_type == "eval_mode" and "eval" not in replacement.lower():
+             if "train" in replacement.lower():
+                 return True
+         if bug_type == "detach_loss" and "detach" not in replacement.lower():
+             return True
+         if bug_type == "inplace_relu" and "inplace" not in replacement.lower():
+             if "relu" in replacement.lower():
+                 return True
+     except SyntaxError:
+         pass
+
+     return False
ml_training_debugger/graders.py ADDED
@@ -0,0 +1,207 @@
+ """Per-task grader functions — returns normalized 0.0-1.0 score at episode end.
+
+ Separate from reward_engine.py. Evaluates EpisodeState holistically.
+ NOT a sum of step rewards. Spec reference: Section 11 grader breakdowns.
+ """
+
+ from __future__ import annotations
+
+ import torch  # noqa: F401 — PyTorch-native project
+
+ from ml_training_debugger.models import EpisodeState
+ from ml_training_debugger.scenarios import ScenarioParams
+
+ FIX_ACTIONS = frozenset(
+     {
+         "modify_config",
+         "add_callback",
+         "replace_optimizer",
+         "patch_data_loader",
+         "fix_model_mode",
+         "fix_code",
+     }
+ )
+
+
+ def _has_action(state: EpisodeState, action_type: str) -> bool:
+     return action_type in state.actions_taken
+
+
+ def _correct_diagnosis(state: EpisodeState, scenario: ScenarioParams) -> bool:
+     if not state.diagnosis_submitted:
+         return False
+     # Find the diagnosis from actions_taken metadata:
+     # we store "mark_diagnosed:<diagnosis>" in actions_taken
+     for action_str in reversed(state.actions_taken):
+         if action_str.startswith("mark_diagnosed:"):
+             submitted = action_str.split(":", 1)[1]
+             return submitted == scenario.root_cause.value
+     return False
+
+
+ def _submitted_diagnosis(state: EpisodeState) -> str | None:
+     for action_str in reversed(state.actions_taken):
+         if action_str.startswith("mark_diagnosed:"):
+             return action_str.split(":", 1)[1]
+     return None
+
+
+ def grade_task_001(state: EpisodeState, scenario: ScenarioParams) -> float:
+     """Grade Task 1 — Exploding Gradients (easy). Spec Section 11."""
+     score = 0.0
+
+     # +0.05 for inspect_gradients
+     if state.gradients_inspected:
+         score += 0.05
+
+     # +0.20 for correct fix (modify_config with LR reduction)
+     if _has_action(state, "modify_config"):
+         score += 0.20
+
+     # +0.35 for restart with convergence
+     if state.restart_after_fix:
+         score += 0.35
+
+     # +0.40 for correct diagnosis
+     if _correct_diagnosis(state, scenario):
+         score += 0.40
+
+     return min(1.0, max(0.0, score))
+
+
+ def grade_task_002(state: EpisodeState, scenario: ScenarioParams) -> float:
+     """Grade Task 2 — Vanishing Gradients (easy). Spec Section 11."""
+     score = 0.0
+
+     if state.gradients_inspected:
+         score += 0.05
+     if _has_action(state, "modify_config"):
+         score += 0.20
+     if state.restart_after_fix:
+         score += 0.35
+     if _correct_diagnosis(state, scenario):
+         score += 0.40
+
+     return min(1.0, max(0.0, score))
+
+
+ def grade_task_003(state: EpisodeState, scenario: ScenarioParams) -> float:
+     """Grade Task 3 — Silent Data Leakage (medium). Spec Section 11."""
+     score = 0.0
+
+     # +0.05 for inspect_data_batch
+     if state.data_inspected:
+         score += 0.05
+
+     # +0.30 for patch_data_loader
+     if _has_action(state, "patch_data_loader"):
+         score += 0.30
+
+     # +0.30 for restart with convergence (val accuracy normalizes)
+     if state.restart_after_fix:
+         score += 0.30
+
+     # +0.35 for correct diagnosis
+     if _correct_diagnosis(state, scenario):
+         score += 0.35
+
+     return min(1.0, max(0.0, score))
+
+
+ def grade_task_004(state: EpisodeState, scenario: ScenarioParams) -> float:
+     """Grade Task 4 — Overfitting (medium). Spec Section 11."""
+     score = 0.0
+
+     if state.data_inspected:
+         score += 0.05
+     if _has_action(state, "modify_config") or _has_action(state, "add_callback"):
+         score += 0.25
+     if state.restart_after_fix:
+         score += 0.30
+     if _correct_diagnosis(state, scenario):
+         score += 0.40
+
+     return min(1.0, max(0.0, score))
+
+
+ def grade_task_005(state: EpisodeState, scenario: ScenarioParams) -> float:
+     """Grade Task 5 — BatchNorm Eval Mode (hard). Spec Section 11.
+
+     Context-gated penalty: -0.20 if add_callback after gradients_were_normal.
+     """
+     score = 0.0
+
+     # +0.05 for inspect_gradients
+     if state.gradients_inspected:
+         score += 0.05
+
+     # +0.05 for inspect_model_modes — the revealing action
+     if state.model_modes_inspected:
+         score += 0.05
+
+     # -0.20 for add_callback after gradients_were_normal
+     if (
+         _has_action(state, "add_callback")
+         and state.gradients_inspected
+         and state.gradients_were_normal
+     ):
+         score -= 0.20
+
+     # +0.25 for fix_model_mode
+     if _has_action(state, "fix_model_mode"):
+         score += 0.25
+
+     # +0.30 for restart with convergence
+     if state.restart_after_fix:
+         score += 0.30
+
+     # +0.40 for correct diagnosis
+     if _correct_diagnosis(state, scenario):
+         score += 0.40
+
+     return min(1.0, max(0.0, score))
+
+
+ def grade_task_006(state: EpisodeState, scenario: ScenarioParams) -> float:
+     """Grade Task 6 — PyTorch Code Bug (hard). Spec Section 11.
+
+     Diagnosis must ALWAYS be 'code_bug' regardless of bug variant.
+     """
+     score = 0.0
+
+     # +0.05 for inspect_code
+     if state.code_inspected:
+         score += 0.05
+
+     # +0.30 for correct code fix
+     if _has_action(state, "fix_code") and state.fix_action_taken:
+         score += 0.30
+
+     # +0.25 for restart with convergence
+     if state.restart_after_fix:
+         score += 0.25
+
+     # +0.40 for correct diagnosis (must be code_bug)
+     if _correct_diagnosis(state, scenario):
+         score += 0.40
+
+     return min(1.0, max(0.0, score))
+
+
+ # Registry mapping task IDs to grader functions
+ GRADERS = {
+     "task_001": grade_task_001,
+     "task_002": grade_task_002,
+     "task_003": grade_task_003,
+     "task_004": grade_task_004,
+     "task_005": grade_task_005,
+     "task_006": grade_task_006,
+ }
+
+
+ def grade_episode(task_id: str, state: EpisodeState, scenario: ScenarioParams) -> float:
+     """Grade a completed episode. Returns 0.0-1.0."""
+     grader = GRADERS.get(task_id)
+     if grader is None:
+         return 0.0
+     return grader(state, scenario)
ml_training_debugger/models.py ADDED
@@ -0,0 +1,195 @@
+ """All Pydantic models, enums, and typed data structures.
+
+ No business logic. Pure data definitions.
+ Spec reference: Section 10 — Data Models.
+ """
+
+ from __future__ import annotations
+
+ import enum
+ from typing import Optional, Union
+
+ import torch  # noqa: F401 — PyTorch-native project, required import
+ from openenv.core.env_server.types import Action, Observation
+ from pydantic import BaseModel, Field
+
+
+ class RootCauseDiagnosis(str, enum.Enum):
+     """Closed enumeration of ML failure root causes. Spec Section 10."""
+
+     LR_TOO_HIGH = "lr_too_high"
+     VANISHING_GRADIENTS = "vanishing_gradients"
+     DATA_LEAKAGE = "data_leakage"
+     OVERFITTING = "overfitting"
+     BATCHNORM_EVAL_MODE = "batchnorm_eval_mode"
+     CODE_BUG = "code_bug"
+
+
+ VALID_DIAGNOSES: set[str] = {d.value for d in RootCauseDiagnosis}
+
+
+ class TrainingConfig(BaseModel):
+     """Typed hyperparameter configuration. Spec Section 10."""
+
+     learning_rate: float = 0.001
+     weight_decay: float = 0.0001
+     batch_size: int = 64
+     hidden_dim: int = 64
+     num_layers: int = 3
+     optimizer: str = "adam"
+     dropout_rate: float = 0.0
+     gradient_clip_norm: Optional[float] = None
+
+
+ VALID_CONFIG_KEYS: set[str] = set(TrainingConfig.model_fields.keys())
+
+
+ class GradientStats(BaseModel):
+     """Per-layer gradient information from real torch.autograd. Spec Section 10."""
+
+     layer_name: str
+     norm_history: list[float]
+     mean_norm: float
+     max_norm: float
+     is_exploding: bool  # True when mean_norm > 10.0
+     is_vanishing: bool  # True when mean_norm < 1e-6
+
+
+ class ModelWeightStats(BaseModel):
+     """Per-layer weight statistics from real state_dict(). Spec Section 10."""
+
+     layer_name: str
+     weight_norm: float
+     weight_mean: float
+     weight_std: float
+     weight_min: float
+     weight_max: float
+     dead_neuron_pct: float = 0.0
+     has_nan: bool = False
+     has_inf: bool = False
+
+
+ class DataBatchStats(BaseModel):
+     """Data batch inspection results. Spec Section 10."""
+
+     label_distribution: dict[int, float]
+     feature_mean: float
+     feature_std: float
+     null_count: int = 0
+     class_overlap_score: float
+     batch_size: int
+     duplicate_ratio: float = 0.0
+
+
+ class CodeSnippet(BaseModel):
+     """PyTorch code for Task 6 inspection. Spec Section 10."""
+
+     code: str
+     filename: str = "train.py"
+     line_count: int
+     imports: list[str]
+     hint: Optional[str] = None
+
+
+ class EpisodeState(BaseModel):
+     """Tracks agent history within an episode. Spec Section 10."""
+
+     step_count: int = 0
+     gradients_inspected: bool = False
+     gradients_were_normal: bool = False
+     data_inspected: bool = False
+     model_modes_inspected: bool = False
+     model_weights_inspected: bool = False
+     code_inspected: bool = False
+     fix_action_taken: bool = False
+     restart_after_fix: bool = False
+     diagnosis_submitted: bool = False
+     actions_taken: list[str] = Field(default_factory=list)
+
+     def compute_available_actions(self) -> list[str]:
+         """Dynamically compute available actions based on current state.
+
+         Rules from spec Section 10 — Dynamic available_actions:
+         - restart_run: only after fix_action_taken
+         - rollback_checkpoint: only after restart_after_fix
+         - fix_code: only after code_inspected
+         - mark_diagnosed: disappears after diagnosis_submitted
+         """
+         actions: list[str] = [
+             "inspect_gradients",
+             "inspect_data_batch",
+             "inspect_model_modes",
+             "inspect_model_weights",
+             "inspect_code",
+             "modify_config",
+             "add_callback",
+             "replace_optimizer",
+             "patch_data_loader",
+             "fix_model_mode",
+         ]
+         if self.code_inspected:
+             actions.append("fix_code")
+         if self.fix_action_taken:
+             actions.append("restart_run")
+         if self.restart_after_fix:
+             actions.append("rollback_checkpoint")
+         if not self.diagnosis_submitted:
+             actions.append("mark_diagnosed")
+         return actions
+
+
+ ALL_ACTION_TYPES: set[str] = {
+     "inspect_gradients",
+     "inspect_data_batch",
+     "inspect_model_modes",
+     "inspect_model_weights",
+     "inspect_code",
+     "modify_config",
+     "add_callback",
+     "replace_optimizer",
+     "patch_data_loader",
+     "fix_model_mode",
+     "fix_code",
+     "restart_run",
+     "mark_diagnosed",
+     "rollback_checkpoint",
+ }
+
+
+ class MLTrainingAction(Action):
+     """What the agent can do — extends openenv Action. Spec Section 10."""
+
+     action_type: str
+     target: Optional[str] = None
+     value: Optional[Union[float, int, str]] = None
+     diagnosis: Optional[str] = None
+     line: Optional[int] = None
+     replacement: Optional[str] = None
+
+
+ class MLTrainingObservation(Observation):
+     """Full observation — extends openenv Observation.
+
+     Observation base has built-in: done (bool), reward (float|None), metadata (dict).
+     Spec Section 10.
+     """
+
+     run_id: str = ""
+     framework: str = "pytorch"
+     epoch: int = 20
+     training_loss_history: list[float] = Field(default_factory=list)
+     val_loss_history: list[float] = Field(default_factory=list)
+     val_accuracy_history: list[float] = Field(default_factory=list)
+     gradient_stats: list[GradientStats] = Field(default_factory=list)
+     model_weight_stats: Optional[list[ModelWeightStats]] = None
+     gpu_memory_used_gb: float = 6.2
+     gpu_memory_total_gb: float = 16.0
+     learning_rate: float = 0.001
+     current_config: TrainingConfig = Field(default_factory=TrainingConfig)
+     error_log: Optional[str] = None
+     data_batch_stats: Optional[DataBatchStats] = None
+     model_mode_info: Optional[dict[str, str]] = None
+     code_snippet: Optional[CodeSnippet] = None
+     available_actions: list[str] = Field(default_factory=list)
+     episode_state: EpisodeState = Field(default_factory=EpisodeState)
+     notes: Optional[str] = None
ml_training_debugger/pytorch_engine.py ADDED
@@ -0,0 +1,240 @@
+ """PyTorch-native fault injection engine.
+
+ Real torch.nn.Module models, real torch.autograd gradients,
+ real state_dict() weight snapshots. Zero numpy.
+ Spec reference: Sections 6, 9.
+ """
+
+ from __future__ import annotations
+
+ from typing import Optional
+
+ import torch
+ import torch.nn as nn
+
+ from ml_training_debugger.models import GradientStats, ModelWeightStats
+ from ml_training_debugger.scenarios import ScenarioParams
+
+
+ class SimpleCNN(nn.Module):
+     """3-layer CNN for CIFAR-10 style classification. ~67K params.
+
+     Spec Section 9 — PyTorch Model Pool.
+     """
+
+     def __init__(self, num_layers: int = 3, hidden_dim: int = 64) -> None:
+         # NOTE: num_layers and hidden_dim are currently unused; the MVP
+         # architecture is fixed at 3 conv blocks.
+         super().__init__()
+         self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
+         self.bn1 = nn.BatchNorm2d(32)
+         self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
+         self.bn2 = nn.BatchNorm2d(64)
+         self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
+         self.bn3 = nn.BatchNorm2d(64)
+         self.fc = nn.Linear(64 * 4 * 4, 10)
+         self.pool = nn.MaxPool2d(2, 2)
+         self.relu = nn.ReLU()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.pool(self.relu(self.bn1(self.conv1(x))))
+         x = self.pool(self.relu(self.bn2(self.conv2(x))))
+         x = self.pool(self.relu(self.bn3(self.conv3(x))))
+         x = x.view(x.size(0), -1)
+         x = self.fc(x)
+         return x
+
+
+ def create_model_and_inject_fault(
+     scenario: ScenarioParams,
+ ) -> tuple[nn.Module, dict]:
+     """Instantiate a real PyTorch model and inject the specified fault.
+
+     Returns:
+         (model, info_dict) where info_dict contains computed artifacts.
+     """
+     torch.manual_seed(scenario.seed)
+
+     model = SimpleCNN()
+     criterion = nn.CrossEntropyLoss()
+     info: dict = {}
+
+     # Generate random batch (CIFAR-10 style: 3x32x32)
+     batch_x = torch.randn(8, 3, 32, 32)
+     batch_y = torch.randint(0, 10, (8,))
+
+     if scenario.root_cause.value == "lr_too_high":
+         # Exploding gradients: high LR with SGD → gradients explode on all layers
+         model.train()
+         optimizer = torch.optim.SGD(
+             model.parameters(), lr=scenario.learning_rate * 10.0
+         )
+         for _ in range(3):
+             optimizer.zero_grad()
+             output = model(batch_x)
+             loss = criterion(output, batch_y)
+             loss.backward()
+             optimizer.step()
+         # Run one final backward to capture extreme gradients
+         optimizer.zero_grad()
+         output = model(batch_x)
+         loss = criterion(output, batch_y)
+         loss.backward()
+
+     elif scenario.root_cause.value == "vanishing_gradients":
+         # Tiny LR → gradients are extremely small
+         model.train()
+         optimizer = torch.optim.SGD(model.parameters(), lr=scenario.learning_rate)
+         for _ in range(2):
+             optimizer.zero_grad()
+             output = model(batch_x)
+             loss = criterion(output, batch_y)
+             loss.backward()
+             optimizer.step()
+
+     elif scenario.root_cause.value == "data_leakage":
+         # Normal model — no gradient anomaly
+         model.train()
+         optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+         optimizer.zero_grad()
+         output = model(batch_x)
+         loss = criterion(output, batch_y)
+         loss.backward()
+         optimizer.step()
+
+     elif scenario.root_cause.value == "overfitting":
+         # Normal model with the scenario-sampled (weak) weight decay
+         model.train()
+         optimizer = torch.optim.Adam(
+             model.parameters(),
+             lr=0.001,
+             weight_decay=scenario.weight_decay,
+         )
+         optimizer.zero_grad()
+         output = model(batch_x)
+         loss = criterion(output, batch_y)
+         loss.backward()
+         optimizer.step()
+
+     elif scenario.root_cause.value == "batchnorm_eval_mode":
+         # model.eval() before training — the real bug
+         model.eval()
+         optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+         # Still run forward/backward to get gradient data
+         output = model(batch_x)
+         loss = criterion(output, batch_y)
+         loss.backward()
+         optimizer.step()
+
+     elif scenario.root_cause.value == "code_bug":
+         # Normal training with the model bug injected in code only
+         model.train()
+         optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+         optimizer.zero_grad()
+         output = model(batch_x)
+         loss = criterion(output, batch_y)
+         loss.backward()
+         optimizer.step()
+
+     return model, info
+
+
+ def extract_gradient_stats(
+     model: nn.Module,
+     scenario: Optional[ScenarioParams] = None,
+ ) -> list[GradientStats]:
+     """Extract gradient statistics from real param.grad tensors.
+
+     For Task 5 (batchnorm_eval_mode), injects red-herring spike on
+     the configured layer.
+     """
+     stats: list[GradientStats] = []
+     named_layers = [
+         ("conv1", model.conv1),
+         ("conv2", model.conv2),
+         ("conv3", model.conv3),
+         ("fc", model.fc),
+     ]
+
+     for layer_name, layer in named_layers:
+         norms: list[float] = []
+         for param in layer.parameters():
+             if param.grad is not None:
+                 norm_val = torch.norm(param.grad).item()
+                 norms.append(norm_val)
+
+         if not norms:
+             norms = [0.0]
+
+         mean_norm = sum(norms) / len(norms)
+         max_norm = max(norms)
+
+         # Build norm_history (simulated last 5 values, based on current)
+         norm_history = [mean_norm * (0.9 + 0.2 * i / 4) for i in range(5)]
+
+         # Task 5 red herring: spike on configured layer
+         if scenario and scenario.root_cause.value == "batchnorm_eval_mode":
+             if layer_name == scenario.red_herring_spike_layer:
+                 spike = scenario.red_herring_intensity
+                 norm_history = [
+                     mean_norm,
+                     mean_norm,
+                     mean_norm * spike,
+                     mean_norm * spike * 1.2,
+                     mean_norm,
+                 ]
+                 mean_norm = sum(norm_history) / len(norm_history)
+                 max_norm = max(norm_history)
+
+             # Conv1 near-vanishing red herring
+             if layer_name == "conv1" and scenario.red_herring_spike_layer != "conv1":
+                 near_vanish = 0.0003
+                 norm_history = [near_vanish * (0.95 + 0.1 * i / 4) for i in range(5)]
+                 mean_norm = near_vanish
+                 max_norm = max(norm_history)
+
+         is_exploding = mean_norm > 10.0
+         is_vanishing = mean_norm < 1e-6
+
+         stats.append(
+             GradientStats(
+                 layer_name=layer_name,
+                 norm_history=norm_history,
+                 mean_norm=mean_norm,
+                 max_norm=max_norm,
+                 is_exploding=is_exploding,
+                 is_vanishing=is_vanishing,
+             )
+         )
+
+     return stats
+
+
+ def extract_weight_stats(model: nn.Module) -> list[ModelWeightStats]:
+     """Extract weight statistics from the real model's named_parameters()."""
+     stats: list[ModelWeightStats] = []
+     for name, param in model.named_parameters():
+         if "weight" not in name:
+             continue
+         stats.append(
+             ModelWeightStats(
+                 layer_name=name,
+                 weight_norm=torch.norm(param).item(),
+                 weight_mean=param.mean().item(),
+                 weight_std=param.std().item(),
+                 weight_min=param.min().item(),
+                 weight_max=param.max().item(),
+                 dead_neuron_pct=0.0,
+                 has_nan=bool(torch.isnan(param).any().item()),
+                 has_inf=bool(torch.isinf(param).any().item()),
+             )
+         )
+     return stats
+
+
+ def extract_model_modes(model: nn.Module) -> dict[str, str]:
+     """Extract training/eval mode for each named module."""
+     modes: dict[str, str] = {}
+     for name, module in model.named_modules():
+         if name == "":
+             continue
+         modes[name] = "train" if module.training else "eval"
+     return modes
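A usage sketch for the engine above (not part of the commit), assuming the `scenarios` module from this diff is importable:

```python
# Sketch: inject the Task 5 fault, then inspect modes and gradients.
from ml_training_debugger.pytorch_engine import (
    create_model_and_inject_fault,
    extract_gradient_stats,
    extract_model_modes,
)
from ml_training_debugger.scenarios import sample_scenario

scenario = sample_scenario("task_005", seed=42)  # batchnorm_eval_mode
model, _info = create_model_and_inject_fault(scenario)

# Every submodule reports "eval" here: the injected fault called model.eval().
print(extract_model_modes(model))

for g in extract_gradient_stats(model, scenario):
    print(g.layer_name, round(g.mean_norm, 6), g.is_exploding, g.is_vanishing)
```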
ml_training_debugger/reward_engine.py ADDED
@@ -0,0 +1,104 @@
+ """Reward function — all 7 components per spec Section 12.
+
+ Separate from graders.py. Returns a float per step for RL training signal.
+ Hard cap at [-1.0, 1.0].
+ """
+
+ from __future__ import annotations
+
+ import torch  # noqa: F401 — PyTorch-native project
+
+ from ml_training_debugger.models import EpisodeState, MLTrainingAction
+ from ml_training_debugger.scenarios import ScenarioParams
+
+ # Reward constants — do not change (CLAUDE.md)
+ STEP_PENALTY = -0.01
+ INVESTIGATION_BONUS = 0.05
+ CONTEXT_GATED_PENALTY = -0.20
+ INVALID_ACTION_PENALTY = -0.05
+ WRONG_CODE_FIX_PENALTY = -0.10
+ CORRECT_DIAGNOSIS_REWARD = 0.50
+ WRONG_DIAGNOSIS_PENALTY = -0.30
+ TERMINAL_CONVERGENCE_REWARD = 0.40
+
+ INVESTIGATION_ACTIONS = frozenset(
+     {
+         "inspect_gradients",
+         "inspect_data_batch",
+         "inspect_model_modes",
+         "inspect_model_weights",
+         "inspect_code",
+     }
+ )
+
+ _INSPECTION_STATE_MAP = {
+     "inspect_gradients": "gradients_inspected",
+     "inspect_data_batch": "data_inspected",
+     "inspect_model_modes": "model_modes_inspected",
+     "inspect_model_weights": "model_weights_inspected",
+     "inspect_code": "code_inspected",
+ }
+
+
+ def compute_reward(
+     action: MLTrainingAction,
+     state: EpisodeState,
+     scenario: ScenarioParams,
+     is_valid_action: bool = True,
+     is_correct_fix: bool | None = None,
+     convergence_confirmed: bool = False,
+ ) -> float:
+     """Compute reward for a single step.
+
+     Args:
+         action: The action taken.
+         state: Episode state BEFORE the action is applied.
+         scenario: Current scenario params.
+         is_valid_action: Whether the action is in available_actions.
+         is_correct_fix: For fix_code — True/False/None.
+         convergence_confirmed: Whether restart showed convergence.
+
+     Returns:
+         Reward float, capped at [-1.0, 1.0].
+     """
+     reward = 0.0
+
+     # Component 1: Flat step penalty (unconditional)
+     reward += STEP_PENALTY
+
+     # Component 4: Invalid action penalty
+     if not is_valid_action:
+         reward += INVALID_ACTION_PENALTY
+         return max(-1.0, min(1.0, reward))
+
+     action_type = action.action_type
+
+     # Component 2: Investigation bonus (first-time only)
+     if action_type in INVESTIGATION_ACTIONS:
+         state_field = _INSPECTION_STATE_MAP.get(action_type)
+         if state_field and not getattr(state, state_field):
+             reward += INVESTIGATION_BONUS
+
+     # Component 3: Context-gated red herring penalty
+     # Fires ONLY when gradients_inspected=True AND gradients_were_normal=True
+     if action_type == "add_callback":
+         if state.gradients_inspected and state.gradients_were_normal:
+             reward += CONTEXT_GATED_PENALTY
+
+     # Component 7: Wrong code fix penalty
+     if action_type == "fix_code" and is_correct_fix is False:
+         reward += WRONG_CODE_FIX_PENALTY
+
+     # Component 5: Diagnosis outcome
+     if action_type == "mark_diagnosed":
+         if action.diagnosis == scenario.root_cause.value:
+             reward += CORRECT_DIAGNOSIS_REWARD
+         else:
+             reward += WRONG_DIAGNOSIS_PENALTY
+
+     # Component 6: Terminal convergence reward
+     if action_type == "restart_run":
+         if state.fix_action_taken and convergence_confirmed:
+             reward += TERMINAL_CONVERGENCE_REWARD
+
+     return max(-1.0, min(1.0, reward))
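A worked example of the reward arithmetic (sketch, not part of the commit). It relies on task_001 always mapping to `lr_too_high`, which scenarios.py below guarantees:

```python
# Sketch: reward values for two common steps, using the constants above.
from ml_training_debugger.models import EpisodeState, MLTrainingAction
from ml_training_debugger.reward_engine import compute_reward
from ml_training_debugger.scenarios import sample_scenario

scenario = sample_scenario("task_001", seed=42)  # root cause: lr_too_high
state = EpisodeState()

# First-time investigation: -0.01 step penalty + 0.05 bonus = 0.04
r1 = compute_reward(MLTrainingAction(action_type="inspect_gradients"), state, scenario)
assert abs(r1 - 0.04) < 1e-9

# Correct diagnosis: -0.01 step penalty + 0.50 reward = 0.49
r2 = compute_reward(
    MLTrainingAction(action_type="mark_diagnosed", diagnosis="lr_too_high"),
    state,
    scenario,
)
assert abs(r2 - 0.49) < 1e-9
```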
ml_training_debugger/scenarios.py ADDED
@@ -0,0 +1,155 @@
+ """ScenarioParams and scenario sampling.
+
+ Internal scenario configuration — not exposed to the agent.
+ Spec reference: Sections 6, 10, 11.
+ """
+
+ from __future__ import annotations
+
+ import dataclasses
+ from typing import Optional
+
+ import torch
+
+ from ml_training_debugger.models import RootCauseDiagnosis
+
+
+ @dataclasses.dataclass(frozen=True)
+ class ScenarioParams:
+     """Internal scenario parameters created at reset() time."""
+
+     task_id: str
+     root_cause: RootCauseDiagnosis
+     seed: int
+     learning_rate: float = 0.001
+     weight_decay: float = 0.0001
+     leakage_pct: float = 0.0
+     depth_multiplier: float = 1.0
+     divergence_epoch: int = 5
+     red_herring_intensity: float = 1.0
+     red_herring_spike_layer: str = "fc"
+     bug_type: Optional[str] = None
+     notes: Optional[str] = None
+     error_log: Optional[str] = None
+     gpu_memory_used_gb: float = 6.2
+     max_steps: int = 20
+
+
+ def _task_seed(task_id: str, seed: int) -> int:
+     """Derive a deterministic seed from task_id and provided seed."""
+     task_num = int(task_id.split("_")[1])
+     return seed * 1000 + task_num
+
+
+ def _choose(options: list, rng: torch.Generator) -> object:
+     """Choose a random element from a list using torch RNG."""
+     idx = int(torch.randint(0, len(options), (1,), generator=rng).item())
+     return options[idx]
+
+
+ def sample_scenario(task_id: str, seed: int = 42) -> ScenarioParams:
+     """Sample a ScenarioParams for the given task.
+
+     Args:
+         task_id: One of task_001 through task_006.
+         seed: Base seed for reproducibility.
+
+     Returns:
+         ScenarioParams with randomized fault parameters.
+
+     Raises:
+         ValueError: If task_id is unknown.
+     """
+     effective_seed = _task_seed(task_id, seed)
+     rng = torch.Generator()
+     rng.manual_seed(effective_seed)
+
+     if task_id == "task_001":
+         lr = _choose([0.05, 0.08, 0.10, 0.15, 0.30], rng)
+         return ScenarioParams(
+             task_id=task_id,
+             root_cause=RootCauseDiagnosis.LR_TOO_HIGH,
+             seed=effective_seed,
+             learning_rate=float(lr),
+             error_log=f"RuntimeError: Loss is NaN at epoch 12 (lr={lr})",
+             max_steps=20,
+         )
+
+     if task_id == "task_002":
+         lr = _choose([1e-6, 5e-6, 1e-5], rng)
+         depth_mult = _choose([1.0, 1.5, 2.0], rng)
+         return ScenarioParams(
+             task_id=task_id,
+             root_cause=RootCauseDiagnosis.VANISHING_GRADIENTS,
+             seed=effective_seed,
+             learning_rate=float(lr),
+             depth_multiplier=float(depth_mult),
+             notes=(
+                 "Training resumed from a checkpoint saved at epoch 0 — "
+                 "early learning rate warmup may still be in effect."
+             ),
+             max_steps=20,
+         )
+
+     if task_id == "task_003":
+         leakage = _choose([0.12, 0.18, 0.22, 0.28], rng)
+         return ScenarioParams(
+             task_id=task_id,
+             root_cause=RootCauseDiagnosis.DATA_LEAKAGE,
+             seed=effective_seed,
+             leakage_pct=float(leakage),
+             notes=(
+                 "Model architecture upgraded from 2-layer to 4-layer CNN "
+                 "at epoch 2. Performance improvement may reflect increased "
+                 "model capacity."
+             ),
+             max_steps=25,
+         )
+
+     if task_id == "task_004":
+         wd = _choose([0.0, 0.0001, 0.001], rng)
+         div_epoch = _choose([5, 8, 12], rng)
+         return ScenarioParams(
+             task_id=task_id,
+             root_cause=RootCauseDiagnosis.OVERFITTING,
+             seed=effective_seed,
+             weight_decay=float(wd),
+             divergence_epoch=int(div_epoch),
+             notes=(
+                 "Dataset augmentation was disabled for this run to speed "
+                 "up training. Re-enabling may improve generalization."
+             ),
+             max_steps=25,
+         )
+
+     if task_id == "task_005":
+         intensity = torch.empty(1).uniform_(0.8, 2.5, generator=rng).item()
+         spike_layer = _choose(["fc", "conv1"], rng)
+         return ScenarioParams(
+             task_id=task_id,
+             root_cause=RootCauseDiagnosis.BATCHNORM_EVAL_MODE,
+             seed=effective_seed,
+             red_herring_intensity=float(intensity),
+             red_herring_spike_layer=str(spike_layer),
+             gpu_memory_used_gb=14.56,  # 91% of 16GB — red herring
+             error_log=(
+                 "Warning: GPU memory pressure detected, consider reducing "
+                 "batch size or enabling gradient checkpointing"
+             ),
+             max_steps=30,
+         )
+
+     if task_id == "task_006":
+         bug = _choose(
+             ["eval_mode", "detach_loss", "zero_grad_missing", "inplace_relu"], rng
+         )
+         return ScenarioParams(
+             task_id=task_id,
+             root_cause=RootCauseDiagnosis.CODE_BUG,
+             seed=effective_seed,
+             bug_type=str(bug),
+             notes="Try adjusting the learning rate schedule.",
+             max_steps=30,
+         )
+
+     raise ValueError(f"Unknown task_id: {task_id}")
ml_training_debugger/simulation.py ADDED
@@ -0,0 +1,225 @@
+ """Parametric curve generation using torch.Tensor operations.
+
+ All loss/accuracy histories are generated via parametric equations.
+ Zero numpy. Spec reference: Section 6.
+ """
+
+ from __future__ import annotations
+
+ import torch
+
+ from ml_training_debugger.scenarios import ScenarioParams
+
+ EPOCHS = 20
+
+
+ def gen_loss_history(scenario: ScenarioParams) -> list[float]:
+     """Generate training loss history (20 epochs) using torch ops."""
+     torch.manual_seed(scenario.seed)
+     t = torch.arange(EPOCHS, dtype=torch.float32)
+
+     root = scenario.root_cause.value
+
+     if root == "lr_too_high":
+         # Exponentially growing loss
+         lr_tensor = torch.tensor(scenario.learning_rate, dtype=torch.float32)
+         base = torch.exp(lr_tensor * t * 0.5)
+         loss = 2.3 * base
+         # Mark divergence from epoch 12 onward with inf (matching the
+         # "Loss is NaN at epoch 12" error log)
+         loss_list = loss.tolist()
+         for i in range(12, EPOCHS):
+             loss_list[i] = float("inf")
+         return loss_list
+
+     if root == "vanishing_gradients":
+         # Flat loss — barely decreases
+         noise = torch.randn(EPOCHS) * 0.02
+         loss = 2.3 - t * 0.002 + noise
+         return loss.clamp(min=0.01).tolist()
+
+     if root == "data_leakage":
+         # Normal-looking training loss
+         loss = 2.3 * torch.exp(-0.15 * t) + 0.05
+         noise = torch.randn(EPOCHS) * 0.02
+         return (loss + noise).clamp(min=0.01).tolist()
+
+     if root == "overfitting":
+         # Steadily decreasing to near-zero
+         loss = 2.3 * torch.exp(-0.25 * t) + 0.01
+         noise = torch.randn(EPOCHS) * 0.01
+         return (loss + noise).clamp(min=0.001).tolist()
+
+     if root == "batchnorm_eval_mode":
+         # Roughly normal with higher variance
+         base = 2.3 * torch.exp(-0.1 * t) + 0.3
+         noise = torch.randn(EPOCHS) * 0.15
+         return (base + noise).clamp(min=0.1).tolist()
+
+     if root == "code_bug":
+         # Varies by bug variant — generic anomalous
+         loss = 2.3 * torch.exp(-0.05 * t) + 0.5
+         noise = torch.randn(EPOCHS) * 0.1
+         return (loss + noise).clamp(min=0.1).tolist()
+
+     # Fallback
+     return (2.3 * torch.exp(-0.1 * t)).tolist()
+
+
+ def gen_val_accuracy_history(scenario: ScenarioParams) -> list[float]:
+     """Generate validation accuracy history (20 epochs) using torch ops."""
+     torch.manual_seed(scenario.seed + 1)
+     t = torch.arange(EPOCHS, dtype=torch.float32)
+
+     root = scenario.root_cause.value
+
+     if root == "lr_too_high":
+         # Collapses along with training loss
+         acc = torch.sigmoid(torch.linspace(0, -3, EPOCHS)) * 0.5
+         return acc.clamp(0.0, 1.0).tolist()
+
+     if root == "vanishing_gradients":
+         # Near random chance
+         noise = torch.randn(EPOCHS) * 0.02
+         acc = 0.10 + t * 0.001 + noise
+         return acc.clamp(0.0, 1.0).tolist()
+
+     if root == "data_leakage":
+         # Suspiciously high from epoch 1
+         leakage = torch.tensor(scenario.leakage_pct, dtype=torch.float32)
+         base = torch.sigmoid(torch.linspace(-3, 3, EPOCHS))
+         acc = base * (1.0 - leakage) + leakage * 0.95
+         acc = acc.clamp(0.0, 1.0)
+         # Inflate every epoch so accuracy is suspiciously high from epoch 1
+         acc_list = acc.tolist()
+         for i in range(EPOCHS):
+             acc_list[i] = max(acc_list[i], 0.82 * (1.0 + scenario.leakage_pct))
+         return [min(v, 0.99) for v in acc_list]
+
+     if root == "overfitting":
+         # Rises then falls — classic divergence
+         div = scenario.divergence_epoch
+         acc_list: list[float] = []
+         for i in range(EPOCHS):
+             if i < div:
+                 val = 0.10 + (0.75 - 0.10) * (i / max(div, 1))
+             else:
+                 decline = (i - div) * 0.02
+                 val = 0.75 - decline
+             acc_list.append(max(0.0, min(1.0, val)))
+         return acc_list
+
+     if root == "batchnorm_eval_mode":
+         # Slow degradation ~1-2% per epoch
+         start = 0.76
+         noise = torch.randn(EPOCHS) * 0.01
+         acc = torch.tensor(
+             [start - 0.015 * i for i in range(EPOCHS)], dtype=torch.float32
+         )
+         acc = acc + noise
+         return acc.clamp(0.0, 1.0).tolist()
+
+     if root == "code_bug":
+         # Anomalous — depends on variant but generally poor
+         noise = torch.randn(EPOCHS) * 0.03
+         acc = 0.10 + t * 0.005 + noise
+         return acc.clamp(0.0, 1.0).tolist()
+
+     # Fallback
+     return (torch.sigmoid(torch.linspace(-3, 3, EPOCHS)) * 0.9).tolist()
+
+
+ def gen_val_loss_history(scenario: ScenarioParams) -> list[float]:
+     """Generate validation loss history (20 epochs) using torch ops."""
+     torch.manual_seed(scenario.seed + 2)
+     t = torch.arange(EPOCHS, dtype=torch.float32)
+
+     root = scenario.root_cause.value
+
+     if root == "lr_too_high":
+         # Mirrors training loss divergence
+         lr_tensor = torch.tensor(scenario.learning_rate, dtype=torch.float32)
+         loss = 2.3 * torch.exp(lr_tensor * t * 0.5)
+         loss_list = loss.tolist()
+         for i in range(12, EPOCHS):
+             loss_list[i] = float("inf")
+         return loss_list
+
+     if root == "vanishing_gradients":
+         noise = torch.randn(EPOCHS) * 0.02
+         loss = 2.3 - t * 0.001 + noise
+         return loss.clamp(min=0.01).tolist()
+
+     if root == "data_leakage":
+         # Low val loss (because leaking train data into val)
+         base = 2.3 * torch.exp(-0.2 * t) + 0.03
+         noise = torch.randn(EPOCHS) * 0.02
+         return (base + noise).clamp(min=0.01).tolist()
+
+     if root == "overfitting":
+         # Initially decreases, then diverges upward
+         div = scenario.divergence_epoch
+         loss_list: list[float] = []
+         for i in range(EPOCHS):
+             if i < div:
+                 val = 2.3 * (1.0 - 0.8 * i / max(div, 1))
+             else:
+                 val = 0.46 + 0.1 * (i - div)
+             loss_list.append(max(0.01, val))
+         return loss_list
+
+     if root == "batchnorm_eval_mode":
+         # Slightly increasing
+         base = 1.5 + t * 0.03
+         noise = torch.randn(EPOCHS) * 0.1
+         return (base + noise).clamp(min=0.1).tolist()
+
+     if root == "code_bug":
+         loss = 2.3 * torch.exp(-0.03 * t) + 0.8
+         noise = torch.randn(EPOCHS) * 0.1
+         return (loss + noise).clamp(min=0.1).tolist()
+
+     # Fallback
+     return (2.3 * torch.exp(-0.1 * t) + 0.1).tolist()
+
+
+ def gen_data_batch_stats(scenario: ScenarioParams) -> dict:
+     """Generate data batch statistics for the scenario."""
+     torch.manual_seed(scenario.seed + 3)
+
+     root = scenario.root_cause.value
+
+     if root == "data_leakage":
+         overlap = 0.5 + scenario.leakage_pct * 1.5  # 0.68-0.92 for sampled leakage
+         overlap = min(overlap, 0.92)
+         return {
+             "label_distribution": {i: 0.1 for i in range(10)},
+             "feature_mean": 0.45 + torch.randn(1).item() * 0.05,
+             "feature_std": 0.22 + torch.randn(1).item() * 0.02,
+             "null_count": 0,
+             "class_overlap_score": overlap,
+             "batch_size": 64,
+             "duplicate_ratio": scenario.leakage_pct,
+         }
+
+     if root == "overfitting":
+         return {
+             "label_distribution": {i: 0.1 for i in range(10)},
+             "feature_mean": 0.48 + torch.randn(1).item() * 0.03,
+             "feature_std": 0.25 + torch.randn(1).item() * 0.02,
+             "null_count": 0,
+             "class_overlap_score": 0.0,
+             "batch_size": 64,
+             "duplicate_ratio": 0.0,
+         }
+
+     # Default: normal data
+     return {
+         "label_distribution": {i: 0.1 for i in range(10)},
+         "feature_mean": 0.47 + torch.randn(1).item() * 0.03,
+         "feature_std": 0.24 + torch.randn(1).item() * 0.02,
+         "null_count": 0,
+         "class_overlap_score": 0.0 + torch.randn(1).abs().item() * 0.05,
+         "batch_size": 64,
+         "duplicate_ratio": 0.0,
+     }
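A shape-and-signature check for the curves above (sketch, not part of the commit). The overfitting assertions hold for any sampled `divergence_epoch` given the piecewise formula:

```python
# Sketch: every generator returns a plain list of EPOCHS floats, and the
# overfitting val-loss curve bottoms out at the divergence epoch.
from ml_training_debugger.scenarios import sample_scenario
from ml_training_debugger.simulation import EPOCHS, gen_val_loss_history

scenario = sample_scenario("task_004", seed=42)  # overfitting
val_loss = gen_val_loss_history(scenario)
assert len(val_loss) == EPOCHS == 20

div = scenario.divergence_epoch
assert val_loss[div] < val_loss[div - 1]  # still falling into the divergence epoch
assert val_loss[-1] > val_loss[div]       # climbing afterwards
```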
openenv.yaml ADDED
@@ -0,0 +1,58 @@
+ spec_version: 1
+ name: pytorch-training-debugger
+ type: space
+ runtime: fastapi
+ app: server.app:app
+ port: 7860
+
+ version: "1.0.0"
+ description: |
+   PyTorch-native fault injection engine for training failure debugging.
+   An AI agent investigates, diagnoses, fixes, and verifies broken
+   training runs using real torch.nn.Module models, torch.autograd
+   gradients, state_dict() weight inspection, and PyTorch code-level
+   debugging. 3 tasks across 3 difficulty tiers with context-gated
+   reward shaping.
+ framework: openenv
+ tags:
+   - ml-debugging
+   - pytorch
+   - reinforcement-learning
+   - root-cause-analysis
+   - fault-injection
+   - openenv
+
+ observation_space:
+   type: MLTrainingObservation
+   description: "Training run snapshot with progressive reveal — gradients, weights, data stats, model modes revealed on inspection"
+
+ action_space:
+   type: MLTrainingAction
+   description: "Investigation, fix, and diagnosis actions with dynamic availability"
+
+ tasks:
+   - id: task_001
+     difficulty: easy
+     max_steps: 20
+   - id: task_003
+     difficulty: medium
+     max_steps: 25
+   - id: task_005
+     difficulty: hard
+     max_steps: 30
+
+ reward:
+   range: [-1.0, 1.0]
+   shaped: true
+   step_penalty: -0.01
+   investigation_bonus: 0.05
+   max_investigation_bonus: 0.25
+   correct_diagnosis: 0.50
+   terminal_convergence: 0.40
+
+ endpoints:
+   websocket: "/ws"
+   tasks: "GET /tasks"
+   grader: "POST /grader"
+   baseline: "POST /baseline"
+   health: "GET /health"
pyproject.toml ADDED
@@ -0,0 +1,41 @@
+ [project]
+ name = "pytorch-training-debugger"
+ version = "1.0.0"
+ description = "OpenEnv RL environment for PyTorch training failure debugging"
+ requires-python = ">=3.12"
+ dependencies = [
+     "torch",
+     "openenv-core",
+     "pydantic>=2.0",
+     "fastapi",
+     "uvicorn",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest",
+     "pytest-cov",
+     "pytest-asyncio",
+     "black",
+     "ruff",
+     "isort",
+     "httpx",
+     "websockets",
+ ]
+ llm = [
+     "openai",
+ ]
+
+ [tool.black]
+ line-length = 88
+
+ [tool.isort]
+ profile = "black"
+
+ [tool.ruff]
+ line-length = 88
+ target-version = "py312"
+
+ [tool.pytest.ini_options]
+ testpaths = ["tests"]
+ asyncio_mode = "auto"
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ openenv-core
+ pydantic>=2.0
+ fastapi
+ uvicorn
+ openai
server/__init__.py ADDED
File without changes
server/_baseline_results.py ADDED
@@ -0,0 +1,27 @@
+ """Shared state for grader results across endpoints."""
+
+ from __future__ import annotations
+
+ from typing import Optional
+
+ # Store last completed episode results
+ _last_results: dict[str, dict] = {}
+
+
+ def store_grader_result(
+     session_id: str, score: float, task_id: str, steps: int
+ ) -> None:
+     """Store a grader result for retrieval."""
+     _last_results[session_id] = {
+         "score": round(score, 4),
+         "task_id": task_id,
+         "steps": steps,
+     }
+     _last_results["_latest"] = _last_results[session_id]
+
+
+ def get_last_grader_result(session_id: Optional[str] = None) -> dict | None:
+     """Get grader result for a session, or the most recent one."""
+     if session_id:
+         return _last_results.get(session_id)
+     return _last_results.get("_latest")
server/app.py ADDED
@@ -0,0 +1,287 @@
+ """FastAPI app — openenv create_app() + custom hackathon routes.
+
+ Spec reference: Sections 9, 14.
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import logging
+ from typing import Optional
+
+ from fastapi import FastAPI
+ from fastapi.responses import JSONResponse
+ from openenv.core.env_server.http_server import create_app
+
+ from ml_training_debugger.models import MLTrainingAction, MLTrainingObservation
+ from server.environment import MLTrainingEnvironment
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format='{"time":"%(asctime)s","level":"%(levelname)s","msg":"%(message)s"}',
+ )
+ logger = logging.getLogger(__name__)
+
+ # MVP task list
+ MVP_TASKS = [
+     {"id": "task_001", "difficulty": "easy", "max_steps": 20},
+     {"id": "task_003", "difficulty": "medium", "max_steps": 25},
+     {"id": "task_005", "difficulty": "hard", "max_steps": 30},
+ ]
+
+ # create_app takes the class (factory), not an instance
+ app: FastAPI = create_app(
+     MLTrainingEnvironment,
+     MLTrainingAction,
+     MLTrainingObservation,
+     env_name="pytorch_training_debugger",
+     max_concurrent_envs=5,
+ )
+
+ # Override framework's /health route with our custom version
+ # Remove the framework's health route first
+ app.routes[:] = [
+     r for r in app.routes if not (hasattr(r, "path") and r.path == "/health")
+ ]
+
+ # Track baseline state
+ _baseline_lock = asyncio.Lock()
+ _baseline_running = False
+
+
+ @app.get("/health")
+ def health_check() -> dict:
+     """Health check — required by hackathon auto-validator."""
+     return {"status": "ready", "tasks": len(MVP_TASKS)}
+
+
+ @app.get("/tasks")
+ def get_tasks() -> list[dict]:
+     """Return task list with IDs, difficulties, and action schema."""
+     schema = MLTrainingAction.model_json_schema()
+     return [{**task, "action_schema": schema} for task in MVP_TASKS]
+
+
+ @app.post("/grader")
+ def post_grader(session_id: Optional[str] = None) -> dict:
+     """Return grader score for most recently completed episode.
+
+     Edge cases per spec Section 14:
+     - No episode completed → {"score": null, "error": "no_completed_episode"}
+     - Episode in progress → {"score": null, "error": "episode_in_progress"}
+     - Episode completed → {"score": float, "task_id": str, "steps": int}
+     """
+     # The framework manages environment instances internally,
+     # so we use the internal baseline results for the /grader endpoint
+     from server._baseline_results import get_last_grader_result
+
+     result = get_last_grader_result(session_id)
+     if result is None:
+         return {"score": None, "error": "no_completed_episode"}
+     return result
+
+
+ @app.post("/baseline", response_model=None)
+ async def post_baseline():
+     """Trigger baseline run, return scores for all tasks.
+
+     Returns 409 if already running.
+     """
+     global _baseline_running
+
+     if _baseline_running:
+         return JSONResponse(
+             status_code=409,
+             content={"error": "baseline_in_progress"},
+         )
+
+     _baseline_running = True
+     try:
+         scores = await _run_baseline()
+         return {"scores": scores}
+     finally:
+         _baseline_running = False
+
+
+ async def _run_baseline() -> dict[str, float]:
+     """Run the rule-based baseline internally."""
+
+     scores: dict[str, float] = {}
+
+     for task_info in MVP_TASKS:
+         task_id = task_info["id"]
+         env = MLTrainingEnvironment()
+         obs = env.reset(seed=42, episode_id=f"baseline_{task_id}", task_id=task_id)
+
+         # Run heuristic decision tree
+         score = _run_heuristic_episode(env, obs, task_id)
+         scores[task_id] = round(score, 4)
+
+     return scores
+
+
+ def _run_heuristic_episode(
+     env: MLTrainingEnvironment,
+     obs: MLTrainingObservation,
+     task_id: str,
+ ) -> float:
+     """Run one heuristic baseline episode. Returns grader score."""
+     # Step 1: inspect_gradients
+     obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
+
+     if obs.gradient_stats:
+         # Check for exploding gradients
+         if any(g.is_exploding for g in obs.gradient_stats):
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="modify_config",
+                     target="learning_rate",
+                     value=0.001,
+                 )
+             )
+             obs = env.step(MLTrainingAction(action_type="restart_run"))
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="mark_diagnosed",
+                     diagnosis="lr_too_high",
+                 )
+             )
+             session = env._get_session()
+             if session and session.last_score is not None:
+                 return session.last_score
+             return 0.0
+
+         # Check for vanishing gradients
+         if any(g.is_vanishing for g in obs.gradient_stats):
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="modify_config",
+                     target="learning_rate",
+                     value=0.01,
+                 )
+             )
+             obs = env.step(MLTrainingAction(action_type="restart_run"))
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="mark_diagnosed",
+                     diagnosis="vanishing_gradients",
+                 )
+             )
+             session = env._get_session()
+             if session and session.last_score is not None:
+                 return session.last_score
+             return 0.0
+
+     # Step 2: inspect_data_batch
+     obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
+     if obs.data_batch_stats and obs.data_batch_stats.class_overlap_score > 0.5:
+         obs = env.step(MLTrainingAction(action_type="patch_data_loader"))
+         obs = env.step(MLTrainingAction(action_type="restart_run"))
+         obs = env.step(
+             MLTrainingAction(
+                 action_type="mark_diagnosed",
+                 diagnosis="data_leakage",
+             )
+         )
+         session = env._get_session()
+         if session and session.last_score is not None:
+             return session.last_score
+         return 0.0
+
+     # Check for overfitting (val_loss diverging)
+     if obs.val_loss_history and len(obs.val_loss_history) >= 10:
+         early = sum(obs.val_loss_history[:5]) / 5
+         late = sum(obs.val_loss_history[-5:]) / 5
+         if (
+             late > early * 1.2
+             and obs.data_batch_stats
+             and obs.data_batch_stats.class_overlap_score < 0.1
+         ):
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="modify_config",
+                     target="weight_decay",
+                     value=0.01,
+                 )
+             )
+             obs = env.step(MLTrainingAction(action_type="restart_run"))
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="mark_diagnosed",
+                     diagnosis="overfitting",
+                 )
+             )
+             session = env._get_session()
+             if session and session.last_score is not None:
+                 return session.last_score
+             return 0.0
+
+     # Step 3: inspect_model_modes
+     obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
+     if obs.model_mode_info:
+         has_eval = any(v == "eval" for v in obs.model_mode_info.values())
+         if has_eval:
+             obs = env.step(MLTrainingAction(action_type="fix_model_mode"))
+             obs = env.step(MLTrainingAction(action_type="restart_run"))
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="mark_diagnosed",
+                     diagnosis="batchnorm_eval_mode",
+                 )
+             )
+             session = env._get_session()
+             if session and session.last_score is not None:
+                 return session.last_score
+             return 0.0
+
+     # Step 4: inspect_code (for Task 6)
+     obs = env.step(MLTrainingAction(action_type="inspect_code"))
+     if obs.code_snippet:
+         # Simple pattern matching for known bugs
+         code = obs.code_snippet.code
+         if "model.eval()" in code and "model.train()" not in code:
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="fix_code",
+                     line=5,
+                     replacement="model.train()",
+                 )
+             )
+         elif ".detach()" in code:
+             obs = env.step(
+                 MLTrainingAction(
+                     action_type="fix_code",
+                     line=14,
+                     replacement="    loss = criterion(output, batch_y)",
+                 )
+             )
+         else:
+             # Can't reliably fix — just diagnose
+             pass
+
+         if obs.episode_state.fix_action_taken:
+             obs = env.step(MLTrainingAction(action_type="restart_run"))
+
+         obs = env.step(
+             MLTrainingAction(
+                 action_type="mark_diagnosed",
+                 diagnosis="code_bug",
+             )
+         )
+         session = env._get_session()
+         if session and session.last_score is not None:
+             return session.last_score
+         return 0.0
+
+     # Fallback
+     obs = env.step(
+         MLTrainingAction(
+             action_type="mark_diagnosed",
+             diagnosis="overfitting",
+         )
+     )
+     session = env._get_session()
+     if session and session.last_score is not None:
+         return session.last_score
+     return 0.0
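An in-process sketch exercising the custom routes (not part of the commit). It assumes `create_app` imports cleanly in the test process and uses FastAPI's `TestClient`, which is backed by the `httpx` dev dependency:

```python
# Sketch: drive the custom routes without starting uvicorn.
from fastapi.testclient import TestClient

from server.app import app

client = TestClient(app)
assert client.get("/health").json() == {"status": "ready", "tasks": 3}

tasks = client.get("/tasks").json()
assert {t["id"] for t in tasks} == {"task_001", "task_003", "task_005"}

# Before any episode has finished, /grader reports the documented edge case.
assert client.post("/grader").json()["error"] == "no_completed_episode"
```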
server/environment.py ADDED
@@ -0,0 +1,516 @@
+ """MLTrainingEnvironment — extends openenv Environment.
+
+ Full implementation of reset() and step() with session isolation,
+ progressive information reveal, and comprehensive error handling.
+ step() NEVER raises an unhandled exception.
+ Spec reference: Sections 9, 13, 16.
+ """
+
+ from __future__ import annotations
+
+ import dataclasses
+ import logging
+ import uuid
+ from typing import Any, Optional
+
+ import torch
+ from openenv.core.env_server.interfaces import Environment
+
+ from ml_training_debugger.code_templates import (
+     generate_code_snippet,
+     validate_fix,
+ )
+ from ml_training_debugger.graders import grade_episode
+ from ml_training_debugger.models import (
+     ALL_ACTION_TYPES,
+     VALID_CONFIG_KEYS,
+     VALID_DIAGNOSES,
+     CodeSnippet,
+     DataBatchStats,
+     EpisodeState,
+     MLTrainingAction,
+     MLTrainingObservation,
+     TrainingConfig,
+ )
+ from ml_training_debugger.pytorch_engine import (
+     create_model_and_inject_fault,
+     extract_gradient_stats,
+     extract_model_modes,
+     extract_weight_stats,
+ )
+ from ml_training_debugger.reward_engine import compute_reward
+ from ml_training_debugger.scenarios import ScenarioParams, sample_scenario
+ from ml_training_debugger.simulation import (
+     gen_data_batch_stats,
+     gen_loss_history,
+     gen_val_accuracy_history,
+     gen_val_loss_history,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclasses.dataclass
+ class SessionData:
+     """Per-session episode data."""
+
+     scenario: ScenarioParams
+     model: torch.nn.Module
+     state: EpisodeState
+     config: TrainingConfig
+     gradient_stats: list[Any]
+     weight_stats: list[Any] | None
+     model_modes: dict[str, str] | None
+     data_batch_stats_raw: dict | None
+     code_snippet_raw: dict | None
+     loss_history: list[float]
+     val_acc_history: list[float]
+     val_loss_history: list[float]
+     done: bool
+     last_score: float | None
+     convergence_after_fix: bool
+
+
+ class MLTrainingEnvironment(Environment[MLTrainingAction, MLTrainingObservation, dict]):
+     """OpenEnv environment for PyTorch training run debugging.
+
+     Spec Section 9 — Architecture.
+     """
+
+     SUPPORTS_CONCURRENT_SESSIONS = True
+
+     def __init__(self, **kwargs: Any) -> None:
+         super().__init__(**kwargs)
+         self._sessions: dict[str, SessionData] = {}
+         self._last_completed: dict[str, dict] = {}
+         self._current_session_id: str = ""
+
+     def _get_session(self, episode_id: str | None = None) -> SessionData | None:
+         sid = episode_id or self._current_session_id
+         return self._sessions.get(sid)
+
+     def _build_observation(
+         self, session: SessionData, reward: float = 0.0
+     ) -> MLTrainingObservation:
+         """Build observation from session data."""
+         state = session.state
+
+         gradient_stats_models = []
+         if state.gradients_inspected and session.gradient_stats:
+             gradient_stats_models = session.gradient_stats
+
+         weight_stats_models = None
+         if state.model_weights_inspected and session.weight_stats is not None:
+             weight_stats_models = session.weight_stats
+
+         data_batch = None
+         if state.data_inspected and session.data_batch_stats_raw is not None:
+             data_batch = DataBatchStats(**session.data_batch_stats_raw)
+
+         model_modes = None
+         if state.model_modes_inspected and session.model_modes is not None:
+             model_modes = session.model_modes
+
+         code_snippet = None
+         if state.code_inspected and session.code_snippet_raw is not None:
+             code_snippet = CodeSnippet(**session.code_snippet_raw)
+
+         return MLTrainingObservation(
+             run_id=self._current_session_id,
+             framework="pytorch",
+             epoch=20,
+             training_loss_history=session.loss_history,
+             val_loss_history=session.val_loss_history,
+             val_accuracy_history=session.val_acc_history,
+             gradient_stats=gradient_stats_models,
+             model_weight_stats=weight_stats_models,
+             gpu_memory_used_gb=session.scenario.gpu_memory_used_gb,
+             gpu_memory_total_gb=16.0,
+             learning_rate=session.config.learning_rate,
+             current_config=session.config,
+             error_log=session.scenario.error_log,
+             data_batch_stats=data_batch,
+             model_mode_info=model_modes,
+             code_snippet=code_snippet,
+             available_actions=state.compute_available_actions(),
+             episode_state=state,
+             notes=session.scenario.notes,
+             done=session.done,
+             reward=reward,
+         )
+
+     def reset(
+         self,
+         seed: Optional[int] = None,
+         episode_id: Optional[str] = None,
+         **kwargs: Any,
+     ) -> MLTrainingObservation:
+         """Reset environment for a new episode. Spec Section 13."""
+         # Determine task_id — passed via kwargs or defaults to task_001
+         task_id = kwargs.get("task_id", "task_001")
+
+         # If called with episode_id that has an active session, terminate it
+         session_id = episode_id or str(uuid.uuid4())
+         if session_id in self._sessions:
+             old = self._sessions[session_id]
+             if not old.done:
+                 score = grade_episode(old.scenario.task_id, old.state, old.scenario)
+                 self._last_completed[session_id] = {
+                     "score": score,
+                     "task_id": old.scenario.task_id,
+                     "steps": old.state.step_count,
+                 }
+
+         self._current_session_id = session_id
+
+         # Derive deterministic seed
+         base_seed = seed if seed is not None else 42
+         scenario = sample_scenario(task_id, base_seed)
+
+         # Set torch seed for reproducibility
+         torch.manual_seed(scenario.seed)
+
+         # Create real PyTorch model with fault injection
+         model, info = create_model_and_inject_fault(scenario)
+
+         # Generate parametric curves
+         loss_history = gen_loss_history(scenario)
+         val_acc_history = gen_val_accuracy_history(scenario)
+         val_loss_history = gen_val_loss_history(scenario)
+
+         # Pre-generate data batch stats
+         data_batch_raw = gen_data_batch_stats(scenario)
+
+         # Pre-generate code snippet (for Task 6)
+         code_snippet_raw = None
+         if scenario.bug_type is not None:
+             code_snippet_raw = generate_code_snippet(scenario.bug_type, scenario.seed)
+
+         # Build initial config from scenario
+         config = TrainingConfig(
+             learning_rate=scenario.learning_rate,
+             weight_decay=scenario.weight_decay,
+         )
+
+         # Create fresh episode state
+         state = EpisodeState()
+
+         session = SessionData(
+             scenario=scenario,
+             model=model,
+             state=state,
+             config=config,
+             gradient_stats=[],
+             weight_stats=None,
+             model_modes=None,
+             data_batch_stats_raw=data_batch_raw,
+             code_snippet_raw=code_snippet_raw,
+             loss_history=loss_history,
+             val_acc_history=val_acc_history,
+             val_loss_history=val_loss_history,
+             done=False,
+             last_score=None,
+             convergence_after_fix=False,
+         )
+
+         self._sessions[session_id] = session
+
+         logger.info(
+             "reset",
+             extra={
+                 "session_id": session_id,
+                 "task_id": task_id,
+                 "scenario_seed": scenario.seed,
+             },
+         )
+
+         return self._build_observation(session)
+
+     def step(
+         self,
+         action: MLTrainingAction,
+         timeout_s: Optional[float] = None,
+         **kwargs: Any,
+     ) -> MLTrainingObservation:
+         """Process one agent action. NEVER raises. Spec Sections 13, 16."""
+         session = self._get_session()
+
+         # No active episode
+         if session is None:
+             return MLTrainingObservation(
+                 done=True,
+                 reward=0.0,
+                 error_log="Error: no active episode. Call reset(task_id) first.",
+             )
+
+         # Episode already done
+         if session.done:
+             return self._build_observation(session, reward=0.0)
+
+         state = session.state
+         scenario = session.scenario
+         action_type = action.action_type
+
+         # Increment step count
+         state.step_count += 1
+
+         # Validate action_type is a known type
+         if action_type not in ALL_ACTION_TYPES:
+             reward = compute_reward(action, state, scenario, is_valid_action=False)
+             state.actions_taken.append(f"invalid:{action_type}")
+             obs = self._build_observation(session, reward=reward)
+             obs.error_log = (
+                 f"Invalid action_type: {action_type}. "
+                 f"Valid types: {sorted(ALL_ACTION_TYPES)}"
+             )
+             return obs
+
+         # Check if action is in available_actions
+         available = state.compute_available_actions()
+         if action_type not in available:
+             reward = compute_reward(action, state, scenario, is_valid_action=False)
+             state.actions_taken.append(f"unavailable:{action_type}")
+             obs = self._build_observation(session, reward=reward)
+             obs.error_log = (
+                 f"Action '{action_type}' not available. Available: {available}"
+             )
+             return obs
+
+         # Validate required fields for specific actions
+         error = self._validate_action_fields(action)
+         if error is not None:
+             reward = compute_reward(action, state, scenario, is_valid_action=False)
+             state.actions_taken.append(f"malformed:{action_type}")
+             obs = self._build_observation(session, reward=reward)
+             obs.error_log = error
+             return obs
+
+         # Dispatch action
+         is_correct_fix: bool | None = None
+         convergence = False
+
+         try:
+             is_correct_fix, convergence = self._dispatch_action(action, session)
+         except Exception as exc:
+             logger.error(
+                 "step_error",
+                 extra={
+                     "session_id": self._current_session_id,
+                     "action": action_type,
+                     "error": str(exc),
+                 },
+                 exc_info=True,
+             )
+             reward = compute_reward(action, state, scenario, is_valid_action=False)
+             obs = self._build_observation(session, reward=reward)
+             obs.error_log = f"Internal error processing {action_type}: {exc}"
+             return obs
+
+         # Record action
+         if action_type == "mark_diagnosed" and action.diagnosis:
+             state.actions_taken.append(f"mark_diagnosed:{action.diagnosis}")
+         else:
+             state.actions_taken.append(action_type)
+
+         # Compute reward
+         reward = compute_reward(
+             action,
+             state,
+             scenario,
+             is_valid_action=True,
+             is_correct_fix=is_correct_fix,
+             convergence_confirmed=convergence,
+         )
+
+         # Check step limit
+         if state.step_count >= scenario.max_steps and not session.done:
+             session.done = True
+
+         # Check done
+         if session.done:
+             score = grade_episode(scenario.task_id, state, scenario)
+             session.last_score = score
+             self._last_completed[self._current_session_id] = {
+                 "score": score,
+                 "task_id": scenario.task_id,
+                 "steps": state.step_count,
+             }
+             logger.info(
+                 "episode_completed",
+                 extra={
+                     "session_id": self._current_session_id,
+                     "task_id": scenario.task_id,
+                     "steps": state.step_count,
+                     "score": score,
+                 },
+             )
+
+         logger.info(
+             "step",
+             extra={
+                 "session_id": self._current_session_id,
+                 "step_count": state.step_count,
+                 "action_type": action_type,
+                 "reward": reward,
+             },
+         )
+
+         return self._build_observation(session, reward=reward)
+
+     def _validate_action_fields(self, action: MLTrainingAction) -> str | None:
+         """Validate required fields for specific actions. Return error or None."""
+         if action.action_type == "modify_config":
+             if action.target is None or action.value is None:
+                 return "modify_config requires 'target' and 'value' fields"
+             if action.target not in VALID_CONFIG_KEYS:
+                 return (
+                     f"Unknown config key: {action.target}. "
+                     f"Valid: {sorted(VALID_CONFIG_KEYS)}"
+                 )
+
+         if action.action_type == "mark_diagnosed":
+             if action.diagnosis is None:
+                 return "mark_diagnosed requires 'diagnosis' field"
+             if action.diagnosis not in VALID_DIAGNOSES:
+                 return (
+                     f"Invalid diagnosis: {action.diagnosis}. "
+                     f"Valid: {sorted(VALID_DIAGNOSES)}"
+                 )
+
+         if action.action_type == "fix_code":
+             if action.line is None or action.replacement is None:
+                 return "fix_code requires 'line' and 'replacement' fields"
+
+         return None
+
+     def _dispatch_action(
+         self, action: MLTrainingAction, session: SessionData
+     ) -> tuple[bool | None, bool]:
+         """Dispatch action to handler. Returns (is_correct_fix, convergence)."""
+         state = session.state
+         scenario = session.scenario
+         is_correct_fix: bool | None = None
+         convergence = False
+
+         at = action.action_type
+
+         if at == "inspect_gradients":
+             if not state.gradients_inspected:
+                 stats = extract_gradient_stats(session.model, scenario)
+                 session.gradient_stats = stats
+                 state.gradients_inspected = True
+                 # Set gradients_were_normal: True if ALL layers is_exploding=False
+                 state.gradients_were_normal = all(not s.is_exploding for s in stats)
+
+         elif at == "inspect_data_batch":
+             state.data_inspected = True
+
+         elif at == "inspect_model_modes":
+             if not state.model_modes_inspected:
+                 modes = extract_model_modes(session.model)
+                 session.model_modes = modes
+                 state.model_modes_inspected = True
+
+         elif at == "inspect_model_weights":
+             if not state.model_weights_inspected:
+                 stats = extract_weight_stats(session.model)
+                 session.weight_stats = stats
+                 state.model_weights_inspected = True
+
+         elif at == "inspect_code":
+             state.code_inspected = True
+
+         elif at == "modify_config":
+             if action.target and action.value is not None:
+                 setattr(session.config, action.target, action.value)
+             state.fix_action_taken = True
+
+         elif at == "add_callback":
+             state.fix_action_taken = True
+
+         elif at == "replace_optimizer":
+             state.fix_action_taken = True
+
+         elif at == "patch_data_loader":
+             state.fix_action_taken = True
+
+         elif at == "fix_model_mode":
+             state.fix_action_taken = True
+
+         elif at == "fix_code":
+             state.fix_action_taken = True
+             if scenario.bug_type and action.line and action.replacement:
+                 is_correct_fix = validate_fix(
+                     scenario.bug_type, action.line, action.replacement
+                 )
+             else:
+                 is_correct_fix = False
+
+         elif at == "restart_run":
+             state.restart_after_fix = True
+             # Check convergence — did the fix address the root cause?
+             convergence = self._check_convergence(session)
+             session.convergence_after_fix = convergence
+
+         elif at == "mark_diagnosed":
+             state.diagnosis_submitted = True
+             session.done = True
+
+         elif at == "rollback_checkpoint":
+             pass  # No-op for now
+
+         return is_correct_fix, convergence
+
+     def _check_convergence(self, session: SessionData) -> bool:
+         """Check if the applied fix would resolve the root cause."""
+         scenario = session.scenario
+         state = session.state
+         root = scenario.root_cause.value
+
+         if root == "lr_too_high":
+             return (
+                 "modify_config" in state.actions_taken
+                 and session.config.learning_rate <= 0.001
+             )
+
+         if root == "vanishing_gradients":
+             return (
+                 "modify_config" in state.actions_taken
+                 and session.config.learning_rate >= 0.001
+             )
+
+         if root == "data_leakage":
+             return "patch_data_loader" in state.actions_taken
+
+         if root == "overfitting":
+             return (
+                 "modify_config" in state.actions_taken
+                 or "add_callback" in state.actions_taken
+             )
+
+         if root == "batchnorm_eval_mode":
+             return "fix_model_mode" in state.actions_taken
+
+         if root == "code_bug":
+             return "fix_code" in state.actions_taken and state.fix_action_taken
+
+         return False
+
+     @property
+     def state(self) -> dict:
+         """Return current environment state."""
+         session = self._get_session()
+         if session is None:
+             return {"status": "no_active_episode"}
+         return {
+             "status": "active",
+             "task_id": session.scenario.task_id,
+             "step_count": session.state.step_count,
+             "done": session.done,
+         }
+
+     def get_last_completed(self, session_id: str | None = None) -> dict | None:
+         """Get last completed episode data for grader endpoint."""
+         if session_id:
+             return self._last_completed.get(session_id)
+         # Return most recent
+         if self._last_completed:
+             return list(self._last_completed.values())[-1]
+         return None
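An end-to-end episode sketch (not part of the commit) that drives the environment directly, mirroring what the HTTP layer does per session. The gradient check is guarded with an `if` rather than asserted, since extreme learning rates can push norms to NaN/inf where `is_exploding` may not fire:

```python
# Sketch: one debugging episode against the environment above.
from ml_training_debugger.models import MLTrainingAction
from server.environment import MLTrainingEnvironment

env = MLTrainingEnvironment()
obs = env.reset(seed=42, episode_id="demo", task_id="task_001")

obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
if any(g.is_exploding for g in obs.gradient_stats):
    obs = env.step(
        MLTrainingAction(
            action_type="modify_config", target="learning_rate", value=0.001
        )
    )
    obs = env.step(MLTrainingAction(action_type="restart_run"))

obs = env.step(
    MLTrainingAction(action_type="mark_diagnosed", diagnosis="lr_too_high")
)
assert obs.done
print(env.state)  # includes task_id, step_count, and done=True
```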
tests/__init__.py ADDED
File without changes
tests/conftest.py ADDED
@@ -0,0 +1,36 @@
+ """Shared test fixtures."""
+
+ from __future__ import annotations
+
+ import pytest
+
+ from ml_training_debugger.models import (
+     EpisodeState,
+     TrainingConfig,
+ )
+ from ml_training_debugger.scenarios import ScenarioParams, sample_scenario
+
+
+ @pytest.fixture
+ def fresh_state() -> EpisodeState:
+     return EpisodeState()
+
+
+ @pytest.fixture
+ def sample_config() -> TrainingConfig:
+     return TrainingConfig(learning_rate=0.001)
+
+
+ @pytest.fixture
+ def task_001_scenario() -> ScenarioParams:
+     return sample_scenario("task_001", seed=42)
+
+
+ @pytest.fixture
+ def task_003_scenario() -> ScenarioParams:
+     return sample_scenario("task_003", seed=42)
+
+
+ @pytest.fixture
+ def task_005_scenario() -> ScenarioParams:
+     return sample_scenario("task_005", seed=42)
tests/test_code_templates.py ADDED
@@ -0,0 +1,65 @@
+ """Test code bug generation and fix validation."""
+
+ from __future__ import annotations
+
+ import pytest
+
+ from ml_training_debugger.code_templates import generate_code_snippet, validate_fix
+
+
+ class TestGenerateCodeSnippet:
+     def test_eval_mode(self):
+         snippet = generate_code_snippet("eval_mode")
+         assert "model.eval()" in snippet["code"]
+         assert snippet["filename"] == "train.py"
+         assert snippet["line_count"] > 0
+         assert len(snippet["imports"]) > 0
+
+     def test_detach_loss(self):
+         snippet = generate_code_snippet("detach_loss")
+         assert ".detach()" in snippet["code"]
+
+     def test_zero_grad_missing(self):
+         snippet = generate_code_snippet("zero_grad_missing")
+         assert "zero_grad" not in snippet["code"]
+
+     def test_inplace_relu(self):
+         snippet = generate_code_snippet("inplace_relu")
+         assert "inplace=True" in snippet["code"]
+
+     def test_unknown_bug_raises(self):
+         with pytest.raises(ValueError):
+             generate_code_snippet("nonexistent_bug")
+
+
+ class TestValidateFix:
+     def test_eval_mode_correct_fix(self):
+         assert validate_fix("eval_mode", 5, "model.train()")
+
+     def test_eval_mode_with_whitespace(self):
+         assert validate_fix("eval_mode", 5, " model.train() ")
+
+     def test_eval_mode_wrong_fix(self):
+         assert not validate_fix("eval_mode", 5, "pass")
+
+     def test_detach_loss_correct_fix(self):
+         assert validate_fix(
+             "detach_loss", 14, " loss = criterion(output, batch_y)"
+         )
+
+     def test_detach_loss_with_trailing_spaces(self):
+         assert validate_fix(
+             "detach_loss", 14, " loss = criterion(output, batch_y) "
+         )
+
+     def test_zero_grad_correct_fix(self):
+         assert validate_fix("zero_grad_missing", 11, " optimizer.zero_grad()")
+
+     def test_inplace_relu_correct_fix(self):
+         assert validate_fix("inplace_relu", 15, " output = F.relu(output)")
+
+     def test_wrong_line_number(self):
+         assert not validate_fix("eval_mode", 999, "model.train()")
+
+     def test_unknown_bug_type(self):
+         assert not validate_fix("nonexistent", 1, "pass")
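The assertions above pin down `validate_fix`'s contract: the line number must match exactly, the replacement is compared whitespace-insensitively, and unknown bug types never validate. A hypothetical implementation consistent with these tests (table values copied from the assertions; the real module may differ):

```python
# Hypothetical sketch of validate_fix, inferred from the tests above.
_EXPECTED_FIXES = {
    "eval_mode": (5, "model.train()"),
    "detach_loss": (14, "loss = criterion(output, batch_y)"),
    "zero_grad_missing": (11, "optimizer.zero_grad()"),
    "inplace_relu": (15, "output = F.relu(output)"),
}

def validate_fix_sketch(bug_type: str, line: int, replacement: str) -> bool:
    expected = _EXPECTED_FIXES.get(bug_type)
    if expected is None:
        return False  # unknown bug types never validate
    expected_line, expected_code = expected
    return line == expected_line and replacement.strip() == expected_code
```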
tests/test_episode_lifecycle.py ADDED
@@ -0,0 +1,220 @@
+ """Test full episode lifecycle — reset, step, state transitions."""
+
+ from __future__ import annotations
+
+ import pytest
+
+ from ml_training_debugger.models import MLTrainingAction
+ from server.environment import MLTrainingEnvironment
+
+
+ @pytest.fixture
+ def env():
+     return MLTrainingEnvironment()
+
+
+ class TestReset:
+     def test_reset_returns_valid_observation(self, env):
+         obs = env.reset(seed=42, episode_id="test", task_id="task_001")
+         assert obs.run_id == "test"
+         assert obs.framework == "pytorch"
+         assert len(obs.training_loss_history) == 20
+         assert len(obs.val_accuracy_history) == 20
+         assert obs.done is False
+
+     def test_reset_initial_state(self, env):
+         obs = env.reset(seed=42, episode_id="test", task_id="task_001")
+         assert obs.episode_state.step_count == 0
+         assert not obs.episode_state.gradients_inspected
+         assert not obs.episode_state.diagnosis_submitted
+
+     def test_reset_progressive_reveal(self, env):
+         obs = env.reset(seed=42, episode_id="test", task_id="task_001")
+         assert obs.gradient_stats == []
+         assert obs.model_weight_stats is None
+         assert obs.data_batch_stats is None
+         assert obs.model_mode_info is None
+         assert obs.code_snippet is None
+
+     def test_reset_available_actions(self, env):
+         obs = env.reset(seed=42, episode_id="test", task_id="task_001")
+         assert "inspect_gradients" in obs.available_actions
+         assert "mark_diagnosed" in obs.available_actions
+         assert "fix_code" not in obs.available_actions
+         assert "restart_run" not in obs.available_actions
+
+
+ class TestStepInspections:
+     def test_inspect_gradients_populates_stats(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_001")
+         obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
+         assert len(obs.gradient_stats) > 0
+         assert obs.episode_state.gradients_inspected
+
+     def test_inspect_data_batch(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_003")
+         obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
+         assert obs.data_batch_stats is not None
+         assert obs.episode_state.data_inspected
+
+     def test_inspect_model_modes(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_005")
+         obs = env.step(MLTrainingAction(action_type="inspect_model_modes"))
+         assert obs.model_mode_info is not None
+         assert obs.episode_state.model_modes_inspected
+
+     def test_inspect_model_weights(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_001")
+         obs = env.step(MLTrainingAction(action_type="inspect_model_weights"))
+         assert obs.model_weight_stats is not None
+         assert obs.episode_state.model_weights_inspected
+
+
+ class TestStepFixActions:
+     def test_modify_config(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_001")
+         obs = env.step(
+             MLTrainingAction(
+                 action_type="modify_config",
+                 target="learning_rate",
+                 value=0.001,
+             )
+         )
+         assert obs.episode_state.fix_action_taken
+         assert "restart_run" in obs.available_actions
+
+     def test_restart_run_after_fix(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_001")
+         env.step(
+             MLTrainingAction(
+                 action_type="modify_config",
+                 target="learning_rate",
+                 value=0.001,
+             )
+         )
+         obs = env.step(MLTrainingAction(action_type="restart_run"))
+         assert obs.episode_state.restart_after_fix
+
+
+ class TestStepDiagnosis:
+     def test_mark_diagnosed_ends_episode(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_001")
+         obs = env.step(
+             MLTrainingAction(
+                 action_type="mark_diagnosed",
+                 diagnosis="lr_too_high",
+             )
+         )
+         assert obs.done is True
+         assert obs.episode_state.diagnosis_submitted
+
+     def test_step_after_done(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_001")
+         env.step(
+             MLTrainingAction(
+                 action_type="mark_diagnosed",
+                 diagnosis="lr_too_high",
+             )
+         )
+         obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
+         assert obs.done is True
+         assert obs.reward == 0.0
+
+
+ class TestErrorHandling:
+     def test_invalid_action_type(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_001")
+         obs = env.step(MLTrainingAction(action_type="nonexistent_action"))
+         assert obs.reward == pytest.approx(-0.01 - 0.05)
+         assert obs.error_log is not None
+
+     def test_action_not_in_available(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_001")
+         # fix_code requires code_inspected=True
+         obs = env.step(
+             MLTrainingAction(
+                 action_type="fix_code",
+                 line=1,
+                 replacement="pass",
+             )
+         )
+         assert obs.reward < 0
+
+     def test_modify_config_missing_target(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_001")
+         obs = env.step(MLTrainingAction(action_type="modify_config"))
+         assert "target" in obs.error_log.lower() or "value" in obs.error_log.lower()
+
+     def test_mark_diagnosed_missing_diagnosis(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_001")
+         obs = env.step(MLTrainingAction(action_type="mark_diagnosed"))
+         assert "diagnosis" in obs.error_log.lower()
+
+     def test_mark_diagnosed_invalid_diagnosis(self, env):
+         env.reset(seed=42, episode_id="test", task_id="task_001")
+         obs = env.step(
+             MLTrainingAction(
+                 action_type="mark_diagnosed",
+                 diagnosis="not_a_real_diagnosis",
+             )
+         )
+         assert "invalid" in obs.error_log.lower()
+
+     def test_step_before_reset(self, env):
+         obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
+         assert obs.done is True
+
+
+ class TestFullEpisodeFlow:
+     def test_task_001_full_flow(self, env):
+         """Full optimal flow for Task 1."""
+         obs = env.reset(seed=42, episode_id="test", task_id="task_001")
+         assert not obs.done
+
+         obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
+         assert obs.episode_state.gradients_inspected
+         assert any(g.is_exploding for g in obs.gradient_stats)
+
+         obs = env.step(
+             MLTrainingAction(
+                 action_type="modify_config",
+                 target="learning_rate",
+                 value=0.001,
+             )
+         )
+         assert obs.episode_state.fix_action_taken
+
+         obs = env.step(MLTrainingAction(action_type="restart_run"))
+         assert obs.episode_state.restart_after_fix
+
+         obs = env.step(
+             MLTrainingAction(
+                 action_type="mark_diagnosed",
+                 diagnosis="lr_too_high",
+             )
+         )
+         assert obs.done
+         assert obs.reward > 0
+
+     def test_task_005_context_gated_penalty(self, env):
+         """Task 5: inspect gradients (normal) → add_callback → penalty fires."""
+         obs = env.reset(seed=42, episode_id="test", task_id="task_005")
+
+         obs = env.step(MLTrainingAction(action_type="inspect_gradients"))
+         assert obs.episode_state.gradients_inspected
+         assert obs.episode_state.gradients_were_normal
+         # All layers report is_exploding=False
+         for g in obs.gradient_stats:
+             assert not g.is_exploding
+
+         # Now add_callback should trigger the context-gated penalty
+         obs = env.step(MLTrainingAction(action_type="add_callback"))
+         assert obs.reward == pytest.approx(-0.01 - 0.20)
+
+     def test_task_003_data_leakage(self, env):
+         """Task 3: data inspection reveals leakage."""
+         obs = env.reset(seed=42, episode_id="test", task_id="task_003")
+
+         obs = env.step(MLTrainingAction(action_type="inspect_data_batch"))
+         assert obs.data_batch_stats is not None
+         assert obs.data_batch_stats.class_overlap_score > 0.5
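For readers who want to drive the environment outside pytest, the optimal task_001 trajectory from `test_task_001_full_flow` translates directly into a standalone script (assuming the package and the `server` module are importable in-process):

```python
from ml_training_debugger.models import MLTrainingAction
from server.environment import MLTrainingEnvironment

env = MLTrainingEnvironment()
env.reset(seed=42, episode_id="demo", task_id="task_001")
for action in [
    MLTrainingAction(action_type="inspect_gradients"),
    MLTrainingAction(action_type="modify_config",
                     target="learning_rate", value=0.001),
    MLTrainingAction(action_type="restart_run"),
    MLTrainingAction(action_type="mark_diagnosed", diagnosis="lr_too_high"),
]:
    obs = env.step(action)
print(obs.done, obs.reward)  # expect: True and a positive reward
```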
tests/test_graders.py ADDED
@@ -0,0 +1,168 @@
+ """Test grader functions — each returns 0.0-1.0."""
+
+ from __future__ import annotations
+
+ import pytest
+
+ from ml_training_debugger.graders import (
+     grade_episode,
+     grade_task_001,
+     grade_task_003,
+     grade_task_005,
+ )
+ from ml_training_debugger.models import EpisodeState
+ from ml_training_debugger.scenarios import sample_scenario
+
+
+ @pytest.fixture
+ def scenario_001():
+     return sample_scenario("task_001", seed=42)
+
+
+ @pytest.fixture
+ def scenario_003():
+     return sample_scenario("task_003", seed=42)
+
+
+ @pytest.fixture
+ def scenario_005():
+     return sample_scenario("task_005", seed=42)
+
+
+ class TestGradeTask001:
+     def test_perfect_score(self, scenario_001):
+         state = EpisodeState(
+             gradients_inspected=True,
+             fix_action_taken=True,
+             restart_after_fix=True,
+             diagnosis_submitted=True,
+             actions_taken=[
+                 "inspect_gradients",
+                 "modify_config",
+                 "restart_run",
+                 "mark_diagnosed:lr_too_high",
+             ],
+         )
+         score = grade_task_001(state, scenario_001)
+         assert score == 1.0
+
+     def test_wrong_diagnosis(self, scenario_001):
+         state = EpisodeState(
+             gradients_inspected=True,
+             fix_action_taken=True,
+             restart_after_fix=True,
+             diagnosis_submitted=True,
+             actions_taken=[
+                 "inspect_gradients",
+                 "modify_config",
+                 "restart_run",
+                 "mark_diagnosed:data_leakage",
+             ],
+         )
+         score = grade_task_001(state, scenario_001)
+         assert score < 0.7  # Missing the diagnosis credit
+
+     def test_no_investigation(self, scenario_001):
+         state = EpisodeState(
+             diagnosis_submitted=True,
+             actions_taken=["mark_diagnosed:lr_too_high"],
+         )
+         score = grade_task_001(state, scenario_001)
+         assert 0.0 < score < 1.0
+
+     def test_score_in_range(self, scenario_001):
+         state = EpisodeState()
+         score = grade_task_001(state, scenario_001)
+         assert 0.0 <= score <= 1.0
+
+
+ class TestGradeTask003:
+     def test_perfect_score(self, scenario_003):
+         state = EpisodeState(
+             data_inspected=True,
+             fix_action_taken=True,
+             restart_after_fix=True,
+             diagnosis_submitted=True,
+             actions_taken=[
+                 "inspect_data_batch",
+                 "patch_data_loader",
+                 "restart_run",
+                 "mark_diagnosed:data_leakage",
+             ],
+         )
+         score = grade_task_003(state, scenario_003)
+         assert score == pytest.approx(1.0)
+
+     def test_wrong_diagnosis(self, scenario_003):
+         state = EpisodeState(
+             data_inspected=True,
+             diagnosis_submitted=True,
+             actions_taken=[
+                 "inspect_data_batch",
+                 "mark_diagnosed:overfitting",
+             ],
+         )
+         score = grade_task_003(state, scenario_003)
+         assert score < 0.5
+
+
+ class TestGradeTask005:
+     def test_perfect_score(self, scenario_005):
+         state = EpisodeState(
+             gradients_inspected=True,
+             gradients_were_normal=True,
+             model_modes_inspected=True,
+             fix_action_taken=True,
+             restart_after_fix=True,
+             diagnosis_submitted=True,
+             actions_taken=[
+                 "inspect_gradients",
+                 "inspect_model_modes",
+                 "fix_model_mode",
+                 "restart_run",
+                 "mark_diagnosed:batchnorm_eval_mode",
+             ],
+         )
+         score = grade_task_005(state, scenario_005)
+         assert score == 1.0
+
+     def test_red_herring_chaser(self, scenario_005):
+         """An agent that chases the gradient red herring scores below perfect
+         (the -0.20 add_callback penalty puts it in roughly 0.70-0.90)."""
+         state = EpisodeState(
+             gradients_inspected=True,
+             gradients_were_normal=True,
+             model_modes_inspected=True,
+             fix_action_taken=True,
+             restart_after_fix=True,
+             diagnosis_submitted=True,
+             actions_taken=[
+                 "inspect_gradients",
+                 "add_callback",  # Wrong: chases the red herring
+                 "inspect_model_modes",
+                 "fix_model_mode",
+                 "restart_run",
+                 "mark_diagnosed:batchnorm_eval_mode",
+             ],
+         )
+         score = grade_task_005(state, scenario_005)
+         # -0.20 penalty for add_callback after normal gradients
+         assert 0.7 <= score <= 0.90
+
+
+ class TestGradeEpisode:
+     def test_dispatch_to_correct_grader(self, scenario_001):
+         state = EpisodeState(
+             gradients_inspected=True,
+             diagnosis_submitted=True,
+             actions_taken=[
+                 "inspect_gradients",
+                 "mark_diagnosed:lr_too_high",
+             ],
+         )
+         score = grade_episode("task_001", state, scenario_001)
+         assert 0.0 <= score <= 1.0
+
+     def test_unknown_task_returns_zero(self, scenario_001):
+         state = EpisodeState()
+         score = grade_episode("task_999", state, scenario_001)
+         assert score == 0.0
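`TestGradeEpisode` pins down the dispatch contract: route by `task_id`, return 0.0 for unknown tasks. A sketch consistent with that contract, using the grader functions imported at the top of this test file (the real `grade_episode` presumably covers all six tasks, not just the three tested here):

```python
# Hypothetical dispatch table, consistent with TestGradeEpisode above.
_GRADERS = {
    "task_001": grade_task_001,
    "task_003": grade_task_003,
    "task_005": grade_task_005,
}

def grade_episode_sketch(task_id, state, scenario) -> float:
    grader = _GRADERS.get(task_id)
    return grader(state, scenario) if grader else 0.0  # unknown task -> 0.0
```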
tests/test_models.py ADDED
@@ -0,0 +1,168 @@
+ """Test all Pydantic models instantiate and serialize correctly."""
+
+ from __future__ import annotations
+
+ import json
+
+ from openenv.core.env_server.types import Action, Observation
+
+ from ml_training_debugger.models import (
+     EpisodeState,
+     GradientStats,
+     MLTrainingAction,
+     MLTrainingObservation,
+     RootCauseDiagnosis,
+     TrainingConfig,
+ )
+
+
+ class TestRootCauseDiagnosis:
+     def test_all_six_values_exist(self):
+         assert len(RootCauseDiagnosis) == 6
+
+     def test_values_are_strings(self):
+         for d in RootCauseDiagnosis:
+             assert isinstance(d.value, str)
+
+     def test_specific_values(self):
+         assert RootCauseDiagnosis.LR_TOO_HIGH.value == "lr_too_high"
+         assert RootCauseDiagnosis.CODE_BUG.value == "code_bug"
+
+
+ class TestTrainingConfig:
+     def test_default_instantiation(self):
+         config = TrainingConfig()
+         assert config.learning_rate == 0.001
+         assert config.gradient_clip_norm is None
+
+     def test_json_roundtrip(self):
+         config = TrainingConfig(learning_rate=0.01, weight_decay=0.1)
+         data = json.loads(config.model_dump_json())
+         restored = TrainingConfig.model_validate(data)
+         assert restored.learning_rate == 0.01
+         assert restored.weight_decay == 0.1
+
+
+ class TestGradientStats:
+     def test_exploding(self):
+         stats = GradientStats(
+             layer_name="fc",
+             norm_history=[15.0],
+             mean_norm=15.0,
+             max_norm=15.0,
+             is_exploding=True,
+             is_vanishing=False,
+         )
+         assert stats.is_exploding
+
+     def test_vanishing(self):
+         stats = GradientStats(
+             layer_name="conv1",
+             norm_history=[1e-7],
+             mean_norm=1e-7,
+             max_norm=1e-7,
+             is_exploding=False,
+             is_vanishing=True,
+         )
+         assert stats.is_vanishing
+
+     def test_normal(self):
+         stats = GradientStats(
+             layer_name="conv1",
+             norm_history=[0.5],
+             mean_norm=0.5,
+             max_norm=0.5,
+             is_exploding=False,
+             is_vanishing=False,
+         )
+         assert not stats.is_exploding
+         assert not stats.is_vanishing
+
+
+ class TestEpisodeState:
+     def test_fresh_state(self):
+         state = EpisodeState()
+         assert state.step_count == 0
+         assert not state.gradients_inspected
+         assert not state.diagnosis_submitted
+
+     def test_available_actions_initial(self):
+         state = EpisodeState()
+         actions = state.compute_available_actions()
+         assert "inspect_gradients" in actions
+         assert "mark_diagnosed" in actions
+         assert "fix_code" not in actions
+         assert "restart_run" not in actions
+         assert "rollback_checkpoint" not in actions
+
+     def test_fix_code_available_after_code_inspected(self):
+         state = EpisodeState(code_inspected=True)
+         actions = state.compute_available_actions()
+         assert "fix_code" in actions
+
+     def test_restart_run_available_after_fix(self):
+         state = EpisodeState(fix_action_taken=True)
+         actions = state.compute_available_actions()
+         assert "restart_run" in actions
+
+     def test_rollback_available_after_restart(self):
+         state = EpisodeState(restart_after_fix=True)
+         actions = state.compute_available_actions()
+         assert "rollback_checkpoint" in actions
+
+     def test_mark_diagnosed_disappears_after_submission(self):
+         state = EpisodeState(diagnosis_submitted=True)
+         actions = state.compute_available_actions()
+         assert "mark_diagnosed" not in actions
+
+
+ class TestMLTrainingObservation:
+     def test_extends_observation(self):
+         assert issubclass(MLTrainingObservation, Observation)
+
+     def test_has_done_and_reward(self):
+         obs = MLTrainingObservation(done=True, reward=0.5)
+         assert obs.done is True
+         assert obs.reward == 0.5
+
+     def test_json_serialization(self):
+         obs = MLTrainingObservation(
+             run_id="test",
+             training_loss_history=[1.0, 2.0],
+             val_accuracy_history=[0.5],
+         )
+         data = json.loads(obs.model_dump_json())
+         assert data["run_id"] == "test"
+         assert data["framework"] == "pytorch"
+
+
+ class TestMLTrainingAction:
+     def test_extends_action(self):
+         assert issubclass(MLTrainingAction, Action)
+
+     def test_basic_action(self):
+         action = MLTrainingAction(action_type="inspect_gradients")
+         assert action.action_type == "inspect_gradients"
+
+     def test_modify_config_action(self):
+         action = MLTrainingAction(
+             action_type="modify_config",
+             target="learning_rate",
+             value=0.001,
+         )
+         assert action.target == "learning_rate"
+
+     def test_mark_diagnosed_action(self):
+         action = MLTrainingAction(
+             action_type="mark_diagnosed",
+             diagnosis="lr_too_high",
+         )
+         assert action.diagnosis == "lr_too_high"
+
+     def test_fix_code_action(self):
+         action = MLTrainingAction(
+             action_type="fix_code",
+             line=13,
+             replacement="loss = criterion(output, batch_y)",
+         )
+         assert action.line == 13
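`TestEpisodeState` fixes the gating behavior of `compute_available_actions`: inspections and `mark_diagnosed` start available, while fix, restart, and rollback unlock progressively. A sketch of that logic taking an `EpisodeState`-like object (the always-available list is an assumption; only the gates are asserted above):

```python
def compute_available_actions_sketch(state) -> list[str]:
    # Base set assumed always available; tests only pin the inspect actions.
    actions = [
        "inspect_gradients", "inspect_data_batch", "inspect_model_modes",
        "inspect_model_weights", "inspect_code",
    ]
    if state.code_inspected:
        actions.append("fix_code")         # unlocked by inspect_code
    if state.fix_action_taken:
        actions.append("restart_run")      # unlocked by any fix action
    if state.restart_after_fix:
        actions.append("rollback_checkpoint")
    if not state.diagnosis_submitted:
        actions.append("mark_diagnosed")   # one-shot: gone after submission
    return actions
```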
tests/test_pytorch_engine.py ADDED
@@ -0,0 +1,93 @@
+ """Test real PyTorch model instantiation and fault injection."""
+
+ from __future__ import annotations
+
+ import torch
+ import torch.nn as nn
+
+ from ml_training_debugger.pytorch_engine import (
+     SimpleCNN,
+     create_model_and_inject_fault,
+     extract_gradient_stats,
+     extract_model_modes,
+     extract_weight_stats,
+ )
+ from ml_training_debugger.scenarios import sample_scenario
+
+
+ class TestSimpleCNN:
+     def test_is_nn_module(self):
+         model = SimpleCNN()
+         assert isinstance(model, nn.Module)
+
+     def test_param_count(self):
+         model = SimpleCNN()
+         count = sum(p.numel() for p in model.parameters())
+         assert 30_000 < count < 100_000  # ~50K params
+
+     def test_forward_pass(self):
+         model = SimpleCNN()
+         x = torch.randn(2, 3, 32, 32)
+         out = model(x)
+         assert out.shape == (2, 10)
+
+
+ class TestFaultInjection:
+     def test_task_001_exploding_gradients(self):
+         scenario = sample_scenario("task_001", seed=42)
+         model, info = create_model_and_inject_fault(scenario)
+         stats = extract_gradient_stats(model, scenario)
+         assert len(stats) > 0
+         # At least some layers should have elevated gradients
+         any_high = any(s.mean_norm > 1.0 for s in stats)
+         assert any_high
+
+     def test_task_005_eval_mode(self):
+         scenario = sample_scenario("task_005", seed=42)
+         model, info = create_model_and_inject_fault(scenario)
+         assert not model.training  # model.eval() was called
+
+     def test_task_005_gradients_not_exploding(self):
+         scenario = sample_scenario("task_005", seed=42)
+         model, info = create_model_and_inject_fault(scenario)
+         stats = extract_gradient_stats(model, scenario)
+         # ALL layers must have is_exploding=False
+         for s in stats:
+             assert not s.is_exploding, f"Layer {s.layer_name} should not be exploding"
+
+
+ class TestExtractGradientStats:
+     def test_returns_gradient_stats(self):
+         scenario = sample_scenario("task_001", seed=42)
+         model, _ = create_model_and_inject_fault(scenario)
+         stats = extract_gradient_stats(model, scenario)
+         assert len(stats) == 4  # conv1, conv2, conv3, fc
+         for s in stats:
+             assert isinstance(s.mean_norm, float)
+             assert isinstance(s.norm_history, list)
+             assert len(s.norm_history) == 5
+
+
+ class TestExtractWeightStats:
+     def test_returns_weight_stats(self):
+         scenario = sample_scenario("task_001", seed=42)
+         model, _ = create_model_and_inject_fault(scenario)
+         stats = extract_weight_stats(model)
+         assert len(stats) > 0
+         for s in stats:
+             assert isinstance(s.weight_norm, float)
+             assert isinstance(s.has_nan, bool)
+
+
+ class TestExtractModelModes:
+     def test_train_mode(self):
+         model = SimpleCNN()
+         model.train()
+         modes = extract_model_modes(model)
+         assert all(v == "train" for v in modes.values())
+
+     def test_eval_mode(self):
+         model = SimpleCNN()
+         model.eval()
+         modes = extract_model_modes(model)
+         assert all(v == "eval" for v in modes.values())
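`extract_gradient_stats` is repo-specific, but the underlying PyTorch pattern is standard: after a backward pass, read `p.grad` for each named parameter. A generic sketch of that pattern (not the repo's implementation):

```python
import torch

def layer_grad_norms(model: torch.nn.Module) -> dict[str, float]:
    # L2 norm of each parameter's gradient; yields an empty dict
    # if called before any backward() has populated .grad.
    return {
        name: float(param.grad.norm())
        for name, param in model.named_parameters()
        if param.grad is not None
    }
```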
tests/test_reward_engine.py ADDED
@@ -0,0 +1,176 @@
+ """Test reward engine — all 7 components. THE MOST CRITICAL TEST FILE."""
+
+ from __future__ import annotations
+
+ import pytest
+
+ from ml_training_debugger.models import EpisodeState, MLTrainingAction
+ from ml_training_debugger.reward_engine import (
+     CONTEXT_GATED_PENALTY,
+     CORRECT_DIAGNOSIS_REWARD,
+     INVALID_ACTION_PENALTY,
+     INVESTIGATION_BONUS,
+     STEP_PENALTY,
+     TERMINAL_CONVERGENCE_REWARD,
+     WRONG_CODE_FIX_PENALTY,
+     WRONG_DIAGNOSIS_PENALTY,
+     compute_reward,
+ )
+ from ml_training_debugger.scenarios import sample_scenario
+
+
+ @pytest.fixture
+ def scenario():
+     return sample_scenario("task_001", seed=42)
+
+
+ @pytest.fixture
+ def scenario_005():
+     return sample_scenario("task_005", seed=42)
+
+
+ class TestStepPenalty:
+     def test_flat_step_penalty(self, scenario):
+         state = EpisodeState()
+         action = MLTrainingAction(action_type="add_callback")
+         reward = compute_reward(action, state, scenario)
+         assert reward == pytest.approx(STEP_PENALTY)
+
+     def test_step_penalty_not_multiplied_by_step_count(self, scenario):
+         state = EpisodeState(step_count=30)
+         action = MLTrainingAction(action_type="add_callback")
+         reward = compute_reward(action, state, scenario)
+         # Must be flat -0.01, NOT -0.01 * 30
+         assert reward == pytest.approx(-0.01)
+
+
+ class TestInvestigationBonus:
+     def test_first_time_bonus(self, scenario):
+         state = EpisodeState(gradients_inspected=False)
+         action = MLTrainingAction(action_type="inspect_gradients")
+         reward = compute_reward(action, state, scenario)
+         assert reward == pytest.approx(STEP_PENALTY + INVESTIGATION_BONUS)
+
+     def test_no_bonus_on_repeat(self, scenario):
+         state = EpisodeState(gradients_inspected=True)
+         action = MLTrainingAction(action_type="inspect_gradients")
+         reward = compute_reward(action, state, scenario)
+         assert reward == pytest.approx(STEP_PENALTY)
+
+     def test_each_inspection_type_gives_bonus(self, scenario):
+         for action_type, field in [
+             ("inspect_gradients", "gradients_inspected"),
+             ("inspect_data_batch", "data_inspected"),
+             ("inspect_model_modes", "model_modes_inspected"),
+             ("inspect_model_weights", "model_weights_inspected"),
+             ("inspect_code", "code_inspected"),
+         ]:
+             state = EpisodeState(**{field: False})
+             action = MLTrainingAction(action_type=action_type)
+             reward = compute_reward(action, state, scenario)
+             assert reward == pytest.approx(
+                 STEP_PENALTY + INVESTIGATION_BONUS
+             ), f"Failed for {action_type}"
+
+
+ class TestContextGatedPenalty:
+     """The project's primary innovation — must be exact."""
+
+     def test_no_penalty_before_inspection(self, scenario_005):
+         """add_callback at step 1 (no prior inspection) -> NO penalty."""
+         state = EpisodeState()
+         action = MLTrainingAction(action_type="add_callback")
+         reward = compute_reward(action, state, scenario_005)
+         assert reward == pytest.approx(STEP_PENALTY)
+
+     def test_penalty_after_normal_gradients(self, scenario_005):
+         """inspect_gradients (normal) then add_callback -> -0.20 penalty."""
+         state = EpisodeState(gradients_inspected=True, gradients_were_normal=True)
+         action = MLTrainingAction(action_type="add_callback")
+         reward = compute_reward(action, state, scenario_005)
+         assert reward == pytest.approx(STEP_PENALTY + CONTEXT_GATED_PENALTY)
+
+     def test_no_penalty_after_abnormal_gradients(self, scenario):
+         """inspect_gradients (exploding) then add_callback -> no context penalty."""
+         state = EpisodeState(gradients_inspected=True, gradients_were_normal=False)
+         action = MLTrainingAction(action_type="add_callback")
+         reward = compute_reward(action, state, scenario)
+         assert reward == pytest.approx(STEP_PENALTY)
+
+     def test_penalty_only_for_add_callback(self, scenario_005):
+         """Other fix actions don't trigger the context-gated penalty."""
+         state = EpisodeState(gradients_inspected=True, gradients_were_normal=True)
+         for action_type in ["modify_config", "fix_model_mode", "patch_data_loader"]:
+             action = MLTrainingAction(
+                 action_type=action_type, target="learning_rate", value=0.001
+             )
+             reward = compute_reward(action, state, scenario_005)
+             assert reward == pytest.approx(
+                 STEP_PENALTY
+             ), f"Unexpected penalty for {action_type}"
+
+
+ class TestDiagnosisReward:
+     def test_correct_diagnosis(self, scenario):
+         state = EpisodeState()
+         action = MLTrainingAction(action_type="mark_diagnosed", diagnosis="lr_too_high")
+         reward = compute_reward(action, state, scenario)
+         assert reward == pytest.approx(STEP_PENALTY + CORRECT_DIAGNOSIS_REWARD)
+
+     def test_wrong_diagnosis(self, scenario):
+         state = EpisodeState()
+         action = MLTrainingAction(
+             action_type="mark_diagnosed", diagnosis="data_leakage"
+         )
+         reward = compute_reward(action, state, scenario)
+         assert reward == pytest.approx(STEP_PENALTY + WRONG_DIAGNOSIS_PENALTY)
+
+
+ class TestTerminalConvergence:
+     def test_convergence_after_fix_and_restart(self, scenario):
+         state = EpisodeState(fix_action_taken=True)
+         action = MLTrainingAction(action_type="restart_run")
+         reward = compute_reward(action, state, scenario, convergence_confirmed=True)
+         assert reward == pytest.approx(STEP_PENALTY + TERMINAL_CONVERGENCE_REWARD)
+
+     def test_no_convergence_without_fix(self, scenario):
+         state = EpisodeState(fix_action_taken=False)
+         action = MLTrainingAction(action_type="restart_run")
+         reward = compute_reward(action, state, scenario, convergence_confirmed=True)
+         # fix_action_taken is False, so no convergence reward
+         assert reward == pytest.approx(STEP_PENALTY)
+
+
+ class TestInvalidAction:
+     def test_invalid_action_penalty(self, scenario):
+         state = EpisodeState()
+         action = MLTrainingAction(action_type="restart_run")
+         reward = compute_reward(action, state, scenario, is_valid_action=False)
+         assert reward == pytest.approx(STEP_PENALTY + INVALID_ACTION_PENALTY)
+
+
+ class TestWrongCodeFix:
+     def test_wrong_code_fix_penalty(self, scenario):
+         state = EpisodeState(code_inspected=True)
+         action = MLTrainingAction(action_type="fix_code", line=1, replacement="pass")
+         reward = compute_reward(action, state, scenario, is_correct_fix=False)
+         assert reward == pytest.approx(STEP_PENALTY + WRONG_CODE_FIX_PENALTY)
+
+
+ class TestRewardCap:
+     def test_reward_capped_at_one(self, scenario):
+         # Theoretical max would exceed 1.0 in some scenarios
+         reward = compute_reward(
+             MLTrainingAction(action_type="mark_diagnosed", diagnosis="lr_too_high"),
+             EpisodeState(),
+             scenario,
+         )
+         assert reward <= 1.0
+
+     def test_reward_capped_at_negative_one(self, scenario):
+         reward = compute_reward(
+             MLTrainingAction(action_type="mark_diagnosed", diagnosis="wrong"),
+             EpisodeState(),
+             scenario,
+         )
+         assert reward >= -1.0
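Taken together, these tests determine the shape of `compute_reward` fairly tightly. A minimal standalone sketch of that shape: the -0.01, -0.05, and -0.20 magnitudes come from the assertions; the bonus/reward values are placeholders, since the real constants live in `ml_training_debugger.reward_engine` and are not asserted numerically above.

```python
STEP = -0.01           # flat per step, never multiplied by step_count
INVALID = -0.05        # from test_invalid_action_type
CONTEXT_GATED = -0.20  # from test_penalty_after_normal_gradients
INVESTIGATION = 0.05   # placeholder value
DIAGNOSIS_OK = 0.50    # placeholder value
DIAGNOSIS_BAD = -0.30  # placeholder value

def reward_sketch(action_type, gradients_were_normal, first_inspection,
                  diagnosis_correct=None, is_valid=True) -> float:
    r = STEP
    if not is_valid:
        r += INVALID
    if action_type.startswith("inspect_") and first_inspection:
        r += INVESTIGATION  # paid once per inspection type
    if action_type == "add_callback" and gradients_were_normal:
        r += CONTEXT_GATED  # fires only after a normal gradient inspection
    if diagnosis_correct is True:
        r += DIAGNOSIS_OK
    elif diagnosis_correct is False:
        r += DIAGNOSIS_BAD
    return max(-1.0, min(1.0, r))  # clamp to [-1, 1] per TestRewardCap
```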
tests/test_scenarios.py ADDED
@@ -0,0 +1,51 @@
+ """Test scenario sampling."""
+
+ from __future__ import annotations
+
+ import pytest
+
+ from ml_training_debugger.models import RootCauseDiagnosis
+ from ml_training_debugger.scenarios import sample_scenario
+
+
+ class TestSampleScenario:
+     def test_task_001_root_cause(self):
+         s = sample_scenario("task_001", seed=42)
+         assert s.root_cause == RootCauseDiagnosis.LR_TOO_HIGH
+         assert s.learning_rate >= 0.05
+
+     def test_task_003_root_cause(self):
+         s = sample_scenario("task_003", seed=42)
+         assert s.root_cause == RootCauseDiagnosis.DATA_LEAKAGE
+         assert 0.10 <= s.leakage_pct <= 0.30
+
+     def test_task_005_root_cause(self):
+         s = sample_scenario("task_005", seed=42)
+         assert s.root_cause == RootCauseDiagnosis.BATCHNORM_EVAL_MODE
+         assert 0.8 <= s.red_herring_intensity <= 2.5
+
+     def test_different_seeds_produce_different_params(self):
+         s1 = sample_scenario("task_001", seed=42)
+         s2 = sample_scenario("task_001", seed=99)
+         # Same root cause, but possibly a different learning rate
+         assert s1.root_cause == s2.root_cause
+
+     def test_same_seed_same_params(self):
+         s1 = sample_scenario("task_001", seed=42)
+         s2 = sample_scenario("task_001", seed=42)
+         assert s1.learning_rate == s2.learning_rate
+         assert s1.seed == s2.seed
+
+     def test_unknown_task_raises(self):
+         with pytest.raises(ValueError, match="Unknown task_id"):
+             sample_scenario("task_999", seed=42)
+
+     def test_task_005_has_error_log(self):
+         s = sample_scenario("task_005", seed=42)
+         assert s.error_log is not None
+         assert "GPU memory" in s.error_log
+
+     def test_task_003_has_notes(self):
+         s = sample_scenario("task_003", seed=42)
+         assert s.notes is not None
+         assert "architecture" in s.notes.lower()
tests/test_simulation.py ADDED
@@ -0,0 +1,72 @@
+ """Test parametric curve generators."""
+
+ from __future__ import annotations
+
+ from ml_training_debugger.scenarios import sample_scenario
+ from ml_training_debugger.simulation import (
+     gen_data_batch_stats,
+     gen_loss_history,
+     gen_val_accuracy_history,
+     gen_val_loss_history,
+ )
+
+
+ class TestGenLossHistory:
+     def test_returns_20_floats(self):
+         s = sample_scenario("task_001", seed=42)
+         hist = gen_loss_history(s)
+         assert len(hist) == 20
+         assert all(isinstance(v, float) for v in hist)
+
+     def test_task_001_diverges(self):
+         s = sample_scenario("task_001", seed=42)
+         hist = gen_loss_history(s)
+         assert hist[-1] == float("inf")  # NaN/inf after epoch 12
+
+     def test_task_003_normal(self):
+         s = sample_scenario("task_003", seed=42)
+         hist = gen_loss_history(s)
+         assert hist[0] > hist[-1]  # Loss decreases
+
+     def test_task_005_higher_variance(self):
+         s = sample_scenario("task_005", seed=42)
+         hist = gen_loss_history(s)
+         assert len(hist) == 20
+
+
+ class TestGenValAccuracy:
+     def test_returns_20_floats(self):
+         s = sample_scenario("task_001", seed=42)
+         hist = gen_val_accuracy_history(s)
+         assert len(hist) == 20
+         assert all(isinstance(v, float) for v in hist)
+
+     def test_task_003_suspiciously_high(self):
+         s = sample_scenario("task_003", seed=42)
+         hist = gen_val_accuracy_history(s)
+         assert hist[1] > 0.80  # Suspiciously high from early epochs
+
+     def test_task_005_degrades(self):
+         s = sample_scenario("task_005", seed=42)
+         hist = gen_val_accuracy_history(s)
+         assert hist[0] > hist[-1]  # Degrades over time
+
+
+ class TestGenValLoss:
+     def test_returns_20_floats(self):
+         s = sample_scenario("task_001", seed=42)
+         hist = gen_val_loss_history(s)
+         assert len(hist) == 20
+
+
+ class TestGenDataBatchStats:
+     def test_leakage_high_overlap(self):
+         s = sample_scenario("task_003", seed=42)
+         stats = gen_data_batch_stats(s)
+         assert stats["class_overlap_score"] > 0.5
+         assert stats["duplicate_ratio"] > 0.0
+
+     def test_normal_low_overlap(self):
+         s = sample_scenario("task_001", seed=42)
+         stats = gen_data_batch_stats(s)
+         assert stats["class_overlap_score"] < 0.3
tests/test_simulation_extended.py ADDED
@@ -0,0 +1,81 @@
+ """Extended simulation tests for coverage gaps."""
+
+ from __future__ import annotations
+
+ from ml_training_debugger.scenarios import sample_scenario
+ from ml_training_debugger.simulation import (
+     gen_data_batch_stats,
+     gen_loss_history,
+     gen_val_accuracy_history,
+     gen_val_loss_history,
+ )
+
+
+ class TestVanishingGradients:
+     def test_loss_barely_decreases(self):
+         s = sample_scenario("task_002", seed=42)
+         hist = gen_loss_history(s)
+         assert len(hist) == 20
+         assert abs(hist[0] - hist[-1]) < 0.5
+
+     def test_val_acc_near_random(self):
+         s = sample_scenario("task_002", seed=42)
+         hist = gen_val_accuracy_history(s)
+         assert all(v < 0.3 for v in hist)
+
+     def test_val_loss_flat(self):
+         s = sample_scenario("task_002", seed=42)
+         hist = gen_val_loss_history(s)
+         assert len(hist) == 20
+
+
+ class TestOverfitting:
+     def test_loss_decreases_to_near_zero(self):
+         s = sample_scenario("task_004", seed=42)
+         hist = gen_loss_history(s)
+         assert hist[-1] < 0.5
+
+     def test_val_acc_diverges(self):
+         s = sample_scenario("task_004", seed=42)
+         hist = gen_val_accuracy_history(s)
+         # Should rise then fall
+         mid = hist[len(hist) // 2]
+         assert mid > hist[-1] or mid > 0.3
+
+     def test_val_loss_diverges(self):
+         s = sample_scenario("task_004", seed=42)
+         hist = gen_val_loss_history(s)
+         assert len(hist) == 20
+         # Overfitting: val loss should increase in the latter half
+         mid_val = hist[s.divergence_epoch] if s.divergence_epoch < 20 else hist[10]
+         assert mid_val > 0  # Val loss is positive
+
+     def test_data_batch_stats_clean(self):
+         s = sample_scenario("task_004", seed=42)
+         stats = gen_data_batch_stats(s)
+         assert stats["class_overlap_score"] == 0.0
+         assert stats["duplicate_ratio"] == 0.0
+
+
+ class TestCodeBug:
+     def test_loss_history(self):
+         s = sample_scenario("task_006", seed=42)
+         hist = gen_loss_history(s)
+         assert len(hist) == 20
+
+     def test_val_acc_poor(self):
+         s = sample_scenario("task_006", seed=42)
+         hist = gen_val_accuracy_history(s)
+         assert len(hist) == 20
+
+     def test_val_loss(self):
+         s = sample_scenario("task_006", seed=42)
+         hist = gen_val_loss_history(s)
+         assert len(hist) == 20
+
+
+ class TestBatchNormEval:
+     def test_val_loss_increases(self):
+         s = sample_scenario("task_005", seed=42)
+         hist = gen_val_loss_history(s)
+         assert len(hist) == 20
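For context on the task_005 root cause these curves simulate: a BatchNorm layer stuck in eval mode normalizes with its stale running statistics instead of the current batch statistics, which is what degrades validation. A standalone PyTorch illustration of the difference (generic demo, not repo code):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4) * 5 + 3  # batch far from the layer's running stats

bn.train()
y_train = bn(x)   # normalized with this batch's statistics (std near 1)
bn.eval()
y_eval = bn(x)    # normalized with running statistics (still near init)

print(y_train.std().item(), y_eval.std().item())  # eval output stays badly scaled
```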