Priyansh Saxena committed on
Commit
8097081
·
0 Parent(s):

feat: complete files

Browse files
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ .env
5
+ .pytest_cache/
6
+ dist/
7
+ *.egg-info/
Dockerfile ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Slim Python 3.11 base keeps the image small; no compilers needed since all
# dependencies ship wheels.
FROM python:3.11-slim

# PYTHONDONTWRITEBYTECODE: skip .pyc files in the container.
# PYTHONUNBUFFERED: flush stdout/stderr immediately (JSON log lines).
# PORT: default serving port, matches EXPOSE and the CMD below.
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PORT=7860

WORKDIR /app

# Copy requirements first so dependency installation is cached separately
# from source changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 7860
CMD ["uvicorn", "src.pytorch_debug_env.server:app", "--host", "0.0.0.0", "--port", "7860"]
inference.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference.py
2
+ import asyncio
3
+ import json
4
+ import os
5
+ from typing import List
6
+
7
+ from openai import OpenAI
8
+ import httpx
9
+
10
# Runtime configuration — every value is overridable via environment variables.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")  # OpenAI-compatible endpoint
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
API_KEY = os.environ.get("OPENAI_API_KEY", "dummy")
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")  # debug-env FastAPI server
TASK_NAME = os.environ.get("TASK_NAME", "easy")
MAX_STEPS = int(os.environ.get("MAX_STEPS", "5"))
SUCCESS_SCORE_THRESHOLD = float(os.environ.get("SUCCESS_SCORE_THRESHOLD", "0.7"))
# NOTE(review): MAX_TOTAL_REWARD is defined but not read anywhere in this
# module — confirm whether an external consumer depends on it.
MAX_TOTAL_REWARD = float(os.environ.get("MAX_TOTAL_REWARD", "1.0"))
18
+
19
+
20
def log_start(task, env, model):
    """Emit a single-line JSON START event on stdout (flushed immediately)."""
    record = {
        "type": "START",
        "task": task,
        "env": env,
        "model": model,
    }
    print(json.dumps(record), flush=True)
27
+
28
+
29
def log_step(step, action, reward, done, error):
    """Emit a single-line JSON STEP event describing one episode step."""
    payload = {
        "type": "STEP",
        "step": step,
        "action": action,
        # Coerce to plain JSON-safe scalars (reward may arrive as numpy/str).
        "reward": float(reward),
        "done": bool(done),
        "error": error,
    }
    print(json.dumps(payload), flush=True)
38
+
39
+
40
def log_end(success, steps, score, rewards):
    """Emit a single-line JSON END event summarizing the whole episode."""
    summary = {
        "type": "END",
        "success": bool(success),
        "steps": steps,
        "score": float(score),
        # Normalize every reward to a plain float for JSON serialization.
        "rewards": [float(r) for r in rewards],
    }
    print(json.dumps(summary), flush=True)
48
+
49
+
50
def get_model_message(client: OpenAI, observation: dict, history: List[str]) -> str:
    """Ask the LLM for its next action given the current observation.

    Returns the raw (stripped) completion text; the caller is responsible
    for parsing it as JSON. Observation JSON is truncated to 8000 chars to
    bound prompt size.
    """
    prompt = f"""
You are debugging a PyTorch training job. Respond ONLY with valid JSON matching this exact schema:
{{
"current_hypothesis": {{"bug_type": "<string>", "affected_file": "<string>", "confidence": <0.0-1.0>}},
"investigation_action": {{"action": "reveal_file", "target": "<filename>"}},
"commit_diagnosis": false,
"final_diagnosis": null
}}

Valid action types: reveal_file, extend_loss_curve, extend_gpu_profile, reveal_log_chunk, run_diagnostic
Valid bug types: missing_zero_grad, data_leakage, memory_leak, learning_rate_too_high, gradient_explosion

Observation:
{json.dumps(observation)[:8000]}
History: {history}
"""
    # temperature=0 for deterministic, schema-following output.
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=500,
    )
    # The API may return None content; normalize to an empty string.
    return (completion.choices[0].message.content or "").strip()
74
+
75
+
76
async def main():
    """Run one evaluation episode against the debug environment server.

    Flow: POST /reset, then loop up to MAX_STEPS times asking the model for
    a JSON action and forwarding it to POST /step. Emits START/STEP/END
    JSON lines on stdout via the log_* helpers.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    rewards = []
    history = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=TASK_NAME, env="pytorch-debug-env", model=MODEL_NAME)

    async with httpx.AsyncClient(timeout=60.0) as session:
        reset_resp = await session.post(f"{ENV_URL}/reset", params={"task_id": TASK_NAME})
        reset_resp.raise_for_status()
        result = reset_resp.json()
        session_id = result.get("session_id")
        observation = result["observation"]

        for step in range(1, MAX_STEPS + 1):
            # The environment may already have terminated the episode.
            if result.get("done"):
                break

            action_text = get_model_message(client, observation, history)
            try:
                # A JSON parse failure or failed HTTP call ends the episode
                # with zero reward for this step (no retry).
                action_json = json.loads(action_text)
                step_resp = await session.post(f"{ENV_URL}/step", params={"session_id": session_id}, json=action_json)
                step_resp.raise_for_status()
                result = step_resp.json()
                reward = result.get("reward", 0.0)
                done = result.get("done", False)
                error = None
                observation = result["observation"]
            except Exception as exc:
                reward = 0.0
                done = True
                error = str(exc)

            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_text, reward=reward, done=done, error=error)
            history.append(f"step={step} reward={reward:.3f}")

            if done:
                break

    # Final score is the LAST step's reward (not a sum), clamped to [0, 1].
    score = min(max(rewards[-1] if rewards else 0.0, 0.0), 1.0)
    success = score >= SUCCESS_SCORE_THRESHOLD
    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)


if __name__ == "__main__":
    asyncio.run(main())
openenv.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: pytorch-debug-env
2
+ version: 1.0.0
3
+ description: Multi-step OpenEnv environment for diagnosing broken PyTorch training jobs.
4
+ author: Priyansh Saxena
5
+
6
+ client:
7
+ class_name: PyTorchDebugEnv
8
+ module: src.pytorch_debug_env.environment
9
+
10
+ action:
11
+ class_name: PyTorchDebugAction
12
+ module: src.pytorch_debug_env.models
13
+
14
+ observation:
15
+ class_name: PyTorchDebugObservation
16
+ module: src.pytorch_debug_env.models
17
+
18
+ default_image: pytorch-debug-env:latest
19
+ spec_version: 1
20
+
21
+ tags:
22
+ - openenv
23
+ - pytorch
24
+ - debugging
25
+ - reinforcement-learning
26
+
27
+ tasks:
28
+ - id: easy
29
+ name: Single-file bug detection
30
+ difficulty: easy
31
+ - id: medium
32
+ name: Multi-file root cause analysis
33
+ difficulty: medium
34
+ - id: hard
35
+ name: Silent failure diagnosis
36
+ difficulty: hard
37
+
38
+ runtime:
39
+ framework: fastapi
40
+ container_port: 7860
pytest.ini ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [pytest]
2
+ asyncio_mode = auto
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.115.0
2
+ uvicorn[standard]==0.30.6
3
+ pydantic==2.9.2
4
+ numpy==2.1.1
5
+ openai==1.51.0
6
+ httpx==0.27.2
7
+ pytest==8.3.3
8
+ pytest-asyncio==0.24.0
9
+ openenv>=0.1.0
scenarios/seeds.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
src/pytorch_debug_env/__init__.py ADDED
File without changes
src/pytorch_debug_env/bug_library.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/pytorch_debug_env/bug_library.py
2
+ from dataclasses import dataclass, field
3
+ from typing import Callable, Dict, List, Optional
4
+ import numpy as np
5
+
6
+
7
@dataclass
class BugTemplate:
    """Static description of one injectable training bug.

    A template couples the ground-truth metadata used for grading with the
    callables that synthesize the scenario (artifacts + mutated repo files).
    """

    bug_type: str                      # canonical bug identifier (key into BUG_CATEGORIES)
    category: str                      # coarse grouping: model/optimization/data/resource/...
    difficulty: str                    # "easy" | "medium" | "hard"; selects the task tier
    primary_bug_file: str              # file containing the root cause
    related_files: List[str]           # files that partially implicate the bug
    red_herring_file: Optional[str]    # plausible-but-wrong file used to penalize distraction
    fix_strategy: str                  # exact string compared against the diagnosis
    line_range: List[int]              # inclusive [start, end] lines of the bug
    description: str
    artifact_generator: Callable       # (artifact_type, rng) -> loss curve / gpu profile / log
    repo_mutator: Callable             # (repo_files, rng) -> repo_files with the bug injected
    metadata: Dict = field(default_factory=dict)
21
+
22
+
23
# Maps each known bug type to its coarse category. Used by the reward code
# to give partial credit when a hypothesis names the right category but the
# wrong specific bug.
BUG_CATEGORIES = {
    "shape_mismatch": "model",
    "missing_zero_grad": "optimization",
    "wrong_loss_function": "optimization",
    "learning_rate_too_high": "optimization",
    "gradient_explosion": "optimization",
    "memory_leak": "resource",
    "data_leakage": "data",
    "incorrect_normalization": "data",
    "distributed_sync_error": "distributed",
    "amp_overflow": "numerics",
}
35
+
36
# Realistic artifact generator
def dummy_artifact_generator(artifact_type: str, rng):
    """Synthesize a training artifact for a scenario.

    Supported artifact types:
      * "loss_curve"   -> 100 {"step", "train_loss"} dicts (decaying + damped oscillation)
      * "gpu_profile"  -> 100 {"step", "allocated_mb"} dicts (linear growth from 2048 MB)
      * "training_log" -> short fixed log text
    Any other type yields an empty list. `rng` is accepted for interface
    compatibility but not used by this dummy implementation.
    """
    steps = np.arange(100)
    if artifact_type == "loss_curve":
        decay = 2.3 * np.exp(-0.01 * steps) + 0.15
        wobble = 0.22 * np.sin(0.25 * steps) * np.exp(-0.002 * steps)
        curve = decay + wobble
        return [{"step": int(s), "train_loss": float(curve[s])} for s in range(100)]
    if artifact_type == "gpu_profile":
        usage = 2048 + 2.4 * steps
        return [{"step": int(s), "allocated_mb": float(usage[s])} for s in range(100)]
    if artifact_type == "training_log":
        return "Epoch 1, Step 0: loss 2.45\nEpoch 1, Step 1: loss 2.43\n"
    return []
56
+
57
def mutate_missing_zero_grad(repo_files, rng):
    """Inject the missing-zero_grad bug by rewriting train.py.

    The training loop never clears gradients, so they accumulate across
    steps. `rng` is unused here; kept for mutator interface uniformity.
    """
    repo_files["train.py"] = """import torch
from model.architecture import Net

model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(10):
    for x, y in dataloader:
        # optimizer.zero_grad() # BUG: commented out
        output = model(x)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()
"""
    return repo_files
74
+
75
def mutate_data_leakage(repo_files, rng):
    """Inject the data-leakage bug by rewriting data/dataset.py.

    The dataset ignores the requested split and keeps all samples, so
    validation data leaks into training. `rng` is unused.
    """
    repo_files["data/dataset.py"] = """from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self, data, split="train"):
        # BUG: We use the entire data instead of just the split
        self.data = data
        self.split = split
"""
    return repo_files
85
+
86
def mutate_memory_leak(repo_files, rng):
    """Inject the memory-leak bug by rewriting data/dataset.py.

    Every loaded tensor is appended to an ever-growing cache, so memory
    usage climbs monotonically. `rng` is unused.
    """
    repo_files["data/dataset.py"] = """from torch.utils.data import Dataset

class ImageDataset(Dataset):
    def __init__(self):
        # BUG: Storing huge tensors in a class-level variable leading to memory accumulation
        self.cache = []

    def load(self, x):
        self.cache.append(x)
        return x
"""
    return repo_files
99
+
100
# Registry of all playable scenarios; one template per difficulty tier.
# ScenarioGenerator filters this list by `difficulty` at reset time.
BUG_TEMPLATES = [
    BugTemplate(
        bug_type="missing_zero_grad",
        category="optimization",
        difficulty="easy",
        primary_bug_file="train.py",
        related_files=[],
        red_herring_file="model/architecture.py",
        fix_strategy="Call optimizer.zero_grad() before loss.backward()",
        line_range=[9, 14],
        description="Missing zero grad",
        artifact_generator=dummy_artifact_generator,
        repo_mutator=mutate_missing_zero_grad,
    ),
    BugTemplate(
        bug_type="data_leakage",
        category="data",
        difficulty="medium",
        primary_bug_file="data/dataset.py",
        related_files=["data/preprocessing.py"],
        red_herring_file="train.py",
        fix_strategy="Ensure validation split is strictly separate from training",
        line_range=[4, 6],
        description="Data leakage",
        artifact_generator=dummy_artifact_generator,
        repo_mutator=mutate_data_leakage,
    ),
    BugTemplate(
        bug_type="memory_leak",
        category="resource",
        difficulty="hard",
        primary_bug_file="data/dataset.py",
        related_files=["train.py"],
        red_herring_file="model/attention.py",
        fix_strategy="Avoid holding reference to tensors in class cache",
        line_range=[5, 9],
        description="Memory leak",
        artifact_generator=dummy_artifact_generator,
        repo_mutator=mutate_memory_leak,
    )
]
src/pytorch_debug_env/environment.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/pytorch_debug_env/environment.py
2
+ from __future__ import annotations
3
+
4
+ from dataclasses import dataclass, field
5
+ from typing import Dict, List
6
+
7
+ from .models import (
8
+ HypothesisRecord,
9
+ PyTorchDebugAction,
10
+ PyTorchDebugObservation,
11
+ PyTorchDebugState,
12
+ )
13
+ from .reward import compute_step_reward
14
+ from .scenario_generator import ScenarioGenerator
15
+ from .graders import grade_easy, grade_medium, grade_hard
16
+
17
# Dispatch table: task tier -> grading function applied to a committed diagnosis.
GRADER_MAP = {"easy": grade_easy, "medium": grade_medium, "hard": grade_hard}


@dataclass
class RuntimeState:
    """Mutable per-episode state; replaced wholesale on every reset()."""

    scenario: object | None = None            # active Scenario, None before first reset
    max_steps: int = 5                        # episode step budget
    current_step: int = 0
    revealed_files: List[str] = field(default_factory=list)   # filenames visible to the agent
    hypothesis_history: List[HypothesisRecord] = field(default_factory=list)
    done: bool = False                        # True once diagnosis committed or budget exhausted
    final_score: float = 0.0                  # reward of the terminating step
+
30
+
31
class PyTorchDebugEnv:
    """Multi-step debugging environment: agents investigate a broken PyTorch
    training job, refine a hypothesis, and eventually commit a diagnosis.

    Lifecycle: reset() builds a fresh scenario; step() processes one action
    (investigation + hypothesis update, optionally a committed diagnosis);
    state() exposes a read-only snapshot.
    """

    def __init__(self, generator: ScenarioGenerator, max_steps: int = 5):
        self.generator = generator
        self.runtime = RuntimeState(max_steps=max_steps)

    async def reset(self, task_id: str = "easy"):
        """Start a new episode for the given task tier and return the first observation."""
        scenario = self.generator.generate(task_id)
        self.runtime = RuntimeState(
            scenario=scenario,
            # NOTE(review): the constructor's max_steps is discarded here —
            # easy gets 5 steps, everything else 6. Confirm this is intended.
            max_steps=5 if task_id == "easy" else 6,
            current_step=0,
            # The agent always starts with the entry point and config visible.
            revealed_files=["train.py", "config/training_config.yaml"],
            hypothesis_history=[],
            done=False,
            final_score=0.0,
        )
        return self._build_observation(last_feedback="Episode reset.")

    async def step(self, action: PyTorchDebugAction):
        """Apply one action; returns {observation, reward, done, info}.

        Raises RuntimeError if called before reset() or after the episode
        has terminated.
        """
        if self.runtime.scenario is None:
            raise RuntimeError("Call /reset before /step")

        if self.runtime.done:
            raise RuntimeError("Episode already completed")

        self.runtime.current_step += 1
        scenario = self.runtime.scenario
        # Reward is shaped on the *improvement* of hypothesis quality.
        previous_quality = self.runtime.hypothesis_history[-1].quality if self.runtime.hypothesis_history else 0.0

        # Only reveal_file actions change visible state; a valid target is
        # added at most once to revealed_files.
        investigation_target = None
        if action.investigation_action and action.investigation_action.action == "reveal_file":
            investigation_target = action.investigation_action.target
            if investigation_target in scenario.repo_files and investigation_target not in self.runtime.revealed_files:
                self.runtime.revealed_files.append(investigation_target)

        committed = action.final_diagnosis.model_dump() if action.commit_diagnosis and action.final_diagnosis else None
        reward, components = compute_step_reward(
            previous_quality=previous_quality,
            current_hypothesis=action.current_hypothesis.model_dump(),
            ground_truth=scenario.ground_truth,
            investigation_target=investigation_target,
            committed_diagnosis=None,  # Temporarily don't compute diagnosis reward here to use grader
            step_num=self.runtime.current_step,
            max_steps=self.runtime.max_steps,
        )

        if committed:
            # Grade the committed diagnosis with the tier-specific grader
            # instead of compute_step_reward's generic scorer.
            grader = GRADER_MAP.get(scenario.task_id, grade_easy)
            diagnosis_reward = grader(committed, scenario.ground_truth)

            # Early-commit bonus: mirrors the bonus inside compute_step_reward.
            if diagnosis_reward > 0.7:
                diagnosis_reward += max(0.0, 0.08 * (self.runtime.max_steps - self.runtime.current_step))

            # Recombine the total using the same weights as compute_step_reward.
            # NOTE(review): this path clamps at 0.0 while compute_step_reward
            # clamps at -0.2 — confirm the asymmetry is deliberate.
            components["diagnosis_reward"] = round(diagnosis_reward, 4)
            delta = components["hypothesis_delta"]
            inv_reward = components["investigation_reward"]
            conf_bonus = components["confirmation_bonus"]

            total = 0.60 * delta + 0.20 * inv_reward + 0.20 * diagnosis_reward + conf_bonus
            reward = round(min(max(total, 0.0), 1.0), 4)

        self.runtime.hypothesis_history.append(
            HypothesisRecord(
                step=self.runtime.current_step,
                hypothesis=action.current_hypothesis,
                quality=components["hypothesis_quality"],
            )
        )

        # Episode ends on any commit attempt or when the budget is spent.
        if action.commit_diagnosis or self.runtime.current_step >= self.runtime.max_steps:
            self.runtime.done = True
            self.runtime.final_score = reward

        observation = self._build_observation(
            last_feedback=self._feedback(action, scenario.ground_truth)
        )
        return {
            "observation": observation,
            "reward": reward,
            "done": self.runtime.done,
            "info": components,
        }

    async def state(self):
        """Return a PyTorchDebugState snapshot, or None before the first reset."""
        scenario = self.runtime.scenario
        if not scenario:
            return None
        return PyTorchDebugState(
            scenario_id=scenario.scenario_id,
            task_id=scenario.task_id,
            max_steps=self.runtime.max_steps,
            current_step=self.runtime.current_step,
            revealed_files=self.runtime.revealed_files,
            remaining_files=[
                f for f in scenario.repo_files.keys() if f not in self.runtime.revealed_files
            ],
            done=self.runtime.done,
            final_score=self.runtime.final_score,
        )

    def _build_observation(self, last_feedback: str) -> PyTorchDebugObservation:
        """Assemble the agent-visible view: revealed files plus artifact
        windows that grow with each step (drip-fed evidence)."""
        scenario = self.runtime.scenario
        revealed = {k: v for k, v in scenario.repo_files.items() if k in self.runtime.revealed_files}
        available = [k for k in scenario.repo_files.keys() if k not in self.runtime.revealed_files]

        # 100 more curve points and 10 more log lines become visible per step.
        loss_window_size = min(len(scenario.loss_curve), 100 * (self.runtime.current_step + 1))
        gpu_window_size = min(len(scenario.gpu_profile), 100 * (self.runtime.current_step + 1))
        log_lines = scenario.training_log.splitlines()
        visible_log = "\n".join(log_lines[-min(len(log_lines), 10 * (self.runtime.current_step + 1)):])

        return PyTorchDebugObservation(
            scenario_id=scenario.scenario_id,
            task_id=scenario.task_id,
            revealed_files=revealed,
            available_files=available,
            loss_curve_window=scenario.loss_curve[:loss_window_size],
            gpu_profile_window=scenario.gpu_profile[:gpu_window_size],
            training_log_tail=visible_log,
            step_num=self.runtime.current_step,
            steps_remaining=max(0, self.runtime.max_steps - self.runtime.current_step),
            investigation_budget=max(0, self.runtime.max_steps - self.runtime.current_step),
            hypothesis_history=self.runtime.hypothesis_history,
            last_feedback=last_feedback,
        )

    def _feedback(self, action: PyTorchDebugAction, gt: Dict) -> str:
        """Produce directional (never exact) feedback on the current hypothesis."""
        suspected_file = action.current_hypothesis.affected_file
        suspected_bug = action.current_hypothesis.bug_type

        if suspected_file == gt.get("red_herring_file"):
            return "That file contains a plausible symptom, but not the root cause. Investigate upstream causes."
        if suspected_file == gt["primary_bug_file"] and suspected_bug != gt["bug_type"]:
            return "Correct region, wrong failure mode. Re-check the training artifacts more carefully."
        if suspected_bug == gt["bug_type"] and suspected_file != gt["primary_bug_file"]:
            return "The bug class looks right, but the faulty implementation is in another file."
        return "Continue refining the hypothesis using newly revealed evidence."
src/pytorch_debug_env/graders.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/pytorch_debug_env/graders.py
2
+ from __future__ import annotations
3
+
4
+ from .reward import final_diagnosis_score
5
+
6
+
7
def grade_easy(action: dict, gt: dict) -> float:
    """Easy tier: the generic diagnosis score with no adjustments."""
    return final_diagnosis_score(action, gt)
9
+
10
+
11
def grade_medium(action: dict, gt: dict) -> float:
    """Medium tier: base score plus a small bonus (+0.05, capped at 1.0)
    for pointing at a file known to be related to the root cause."""
    score = final_diagnosis_score(action, gt)
    if action.get("affected_file") in gt.get("related_files", []):
        score = min(1.0, score + 0.05)
    return round(score, 4)
16
+
17
+
18
def grade_hard(action: dict, gt: dict) -> float:
    """Hard tier: base score with two adjustments — a 0.18 floor for naming
    the right bug *category* when the specific diagnosis missed, and a -0.1
    penalty for blaming the planted red-herring file."""
    score = final_diagnosis_score(action, gt)

    # partial credit if model gets the right category on subtle bugs
    if score < 0.2 and action.get("bug_type"):
        if gt.get("category"):
            # Local import avoids a module-level cycle with bug_library.
            from .bug_library import BUG_CATEGORIES
            if BUG_CATEGORIES.get(action["bug_type"]) == gt["category"]:
                score = max(score, 0.18)

    if action.get("affected_file") == gt.get("red_herring_file"):
        score = max(0.0, score - 0.1)

    return round(min(score, 1.0), 4)
src/pytorch_debug_env/models.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/pytorch_debug_env/models.py
2
+ from __future__ import annotations
3
+
4
+ from typing import Dict, List, Literal, Optional
5
+ from pydantic import BaseModel, Field
6
+
7
+
8
class Hypothesis(BaseModel):
    """The agent's current working theory about the bug."""

    bug_type: str = Field(..., description="Current suspected bug type")
    affected_file: str = Field(..., description="Current suspected file")
    confidence: float = Field(..., ge=0.0, le=1.0)  # self-reported, used for calibration scoring


class InvestigationAction(BaseModel):
    """One evidence-gathering request; `target` is the filename for
    reveal_file and may be None for the other action kinds."""

    action: Literal[
        "reveal_file",
        "extend_loss_curve",
        "extend_gpu_profile",
        "reveal_log_chunk",
        "run_diagnostic",
    ]
    target: Optional[str] = None


class FinalDiagnosis(BaseModel):
    """Committed diagnosis; graded against the scenario's ground truth."""

    bug_type: str
    affected_file: str
    line_range: List[int]  # inclusive [start, end]; scored by interval overlap
    fix_strategy: str      # compared verbatim against ground truth
    confidence: float = Field(..., ge=0.0, le=1.0)
31
+
32
+
33
class PyTorchDebugAction(BaseModel):
    """Full action payload for one step: an updated hypothesis, an optional
    investigation, and an optional terminal diagnosis commit."""

    current_hypothesis: Hypothesis
    investigation_action: Optional[InvestigationAction] = None
    commit_diagnosis: bool = False                 # True ends the episode
    final_diagnosis: Optional[FinalDiagnosis] = None  # only graded when commit_diagnosis is True


class HypothesisRecord(BaseModel):
    """A hypothesis as recorded after a step, with its scored quality."""

    step: int
    hypothesis: Hypothesis
    quality: float  # output of reward.hypothesis_quality, in [0, 1]
44
+
45
+
46
class PyTorchDebugObservation(BaseModel):
    """Everything the agent can see at a given step. Artifact windows grow
    as the episode progresses; unrevealed file contents stay hidden."""

    scenario_id: str
    task_id: str
    revealed_files: Dict[str, str]   # filename -> full file contents
    available_files: List[str]       # filenames that can still be revealed
    loss_curve_window: List[Dict]
    gpu_profile_window: List[Dict]
    training_log_tail: str
    step_num: int
    steps_remaining: int
    investigation_budget: int        # currently mirrors steps_remaining
    hypothesis_history: List[HypothesisRecord]
    last_feedback: str               # directional hint from the environment


class PyTorchDebugState(BaseModel):
    """Server-side episode snapshot returned by /state (no file contents)."""

    scenario_id: str
    task_id: str
    max_steps: int
    current_step: int
    revealed_files: List[str]
    remaining_files: List[str]
    done: bool
    final_score: float = 0.0


class PyTorchDebugReward(BaseModel):
    """Reward value plus its named shaping components."""

    value: float = Field(..., ge=0.0, le=1.0)
    components: Dict[str, float]
src/pytorch_debug_env/reward.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/pytorch_debug_env/reward.py
2
+ from __future__ import annotations
3
+
4
+ from .bug_library import BUG_CATEGORIES
5
+
6
+
7
def hypothesis_quality(hypothesis: dict, ground_truth: dict) -> float:
    """Score a working hypothesis against ground truth, returning [0, 1].

    Components:
      * affected_file: 0.45 for the primary bug file, 0.15 for a related file.
      * bug_type: 0.40 for an exact match, 0.13 for the right bug category.
      * calibration: up to 0.15 for a confidence close to the quality
        accumulated from the first two components.
    """
    q = 0.0

    if hypothesis.get("affected_file") == ground_truth["primary_bug_file"]:
        q += 0.45
    elif hypothesis.get("affected_file") in ground_truth.get("related_files", []):
        q += 0.15

    hyp_category = BUG_CATEGORIES.get(hypothesis.get("bug_type"))
    gt_category = BUG_CATEGORIES.get(ground_truth["bug_type"])
    if hypothesis.get("bug_type") == ground_truth["bug_type"]:
        q += 0.40
    # BUG FIX: the original compared BUG_CATEGORIES.get(a) == BUG_CATEGORIES.get(b)
    # directly, which awarded the 0.13 category credit whenever *both* bug
    # types were unknown (None == None). Require a real, matching category.
    elif hyp_category is not None and hyp_category == gt_category:
        q += 0.13

    # Reward well-calibrated confidence: 1.0 when confidence equals the
    # evidence-based quality so far, decreasing linearly with the gap.
    calibration = 1.0 - abs(hypothesis.get("confidence", 0.5) - min(q, 1.0))
    q += 0.15 * calibration
    return round(min(q, 1.0), 4)
23
+
24
+
25
def final_diagnosis_score(diagnosis: dict, ground_truth: dict) -> float:
    """Score a committed diagnosis in [0, 1].

    Weights: bug_type 0.40, affected_file 0.25, line-range overlap 0.20
    (interval IoU via line_overlap), exact fix_strategy string match 0.15.
    """
    score = 0.0

    if diagnosis.get("bug_type") == ground_truth["bug_type"]:
        score += 0.40
    if diagnosis.get("affected_file") == ground_truth["primary_bug_file"]:
        score += 0.25

    # Missing ranges default to [0, 0], which cannot overlap a real range.
    predicted = diagnosis.get("line_range", [0, 0])
    actual = ground_truth.get("line_range", [0, 0])
    overlap = line_overlap(predicted, actual)
    score += 0.20 * overlap

    # NOTE(review): fix_strategy is compared verbatim — any paraphrase
    # scores zero. Confirm that is the intended grading strictness.
    if diagnosis.get("fix_strategy") == ground_truth["fix_strategy"]:
        score += 0.15

    return round(min(score, 1.0), 4)
42
+
43
+
44
def line_overlap(pred: list[int], actual: list[int]) -> float:
    """Return the interval IoU of two inclusive [start, end] line ranges.

    Robustness fixes over the original:
      * malformed ranges (wrong length) now score 0.0 instead of raising
        ValueError on unpacking — `pred` ultimately comes from model output;
      * inverted ranges (end < start) are normalized before scoring.
    Well-formed ascending input scores identically to the original.
    """
    if len(pred) != 2 or len(actual) != 2:
        return 0.0
    p1, p2 = sorted(pred)
    a1, a2 = sorted(actual)
    inter = max(0, min(p2, a2) - max(p1, a1) + 1)
    union = max(p2, a2) - min(p1, a1) + 1
    return inter / union if union else 0.0
50
+
51
+
52
def compute_step_reward(
    previous_quality: float,
    current_hypothesis: dict,
    ground_truth: dict,
    investigation_target: str | None = None,
    committed_diagnosis: dict | None = None,
    step_num: int = 1,
    max_steps: int = 5,
) -> tuple[float, dict]:
    """Compute the shaped per-step reward and its components.

    Returns (total, components) where total is clamped to [-0.2, 1.0] and
    components exposes the individual terms for logging/debugging.
    Weights: 0.60 * hypothesis improvement + 0.20 * investigation +
    0.20 * diagnosis + confirmation bonus.
    """
    current_quality = hypothesis_quality(current_hypothesis, ground_truth)
    # Reward is driven by *improvement* over the previous step's hypothesis.
    delta = current_quality - previous_quality

    # Small bonus for holding a stable (unchanged-quality) hypothesis,
    # scaled by how good that hypothesis is.
    confirmation_bonus = 0.03 * current_quality if abs(delta) < 0.01 else 0.0

    # Investigation shaping: primary file > related file > anything else,
    # with an explicit penalty for chasing the planted red herring.
    investigation_reward = 0.0
    if investigation_target:
        if investigation_target == ground_truth["primary_bug_file"]:
            investigation_reward = 0.07
        elif investigation_target in ground_truth.get("related_files", []):
            investigation_reward = 0.025
        elif investigation_target == ground_truth.get("red_herring_file"):
            investigation_reward = -0.04
        else:
            investigation_reward = -0.01

    # Diagnosis term; a strong diagnosis (>0.7) earns an early-commit bonus
    # proportional to the steps left unused.
    diagnosis_reward = 0.0
    if committed_diagnosis:
        diagnosis_reward = final_diagnosis_score(committed_diagnosis, ground_truth)
        if diagnosis_reward > 0.7:
            diagnosis_reward += max(0.0, 0.08 * (max_steps - step_num))

    total = 0.60 * delta + 0.20 * investigation_reward + 0.20 * diagnosis_reward + confirmation_bonus
    total = round(min(max(total, -0.2), 1.0), 4)

    return total, {
        "hypothesis_quality": current_quality,
        "hypothesis_delta": round(delta, 4),
        "investigation_reward": round(investigation_reward, 4),
        "diagnosis_reward": round(diagnosis_reward, 4),
        "confirmation_bonus": round(confirmation_bonus, 4),
    }
src/pytorch_debug_env/scenario_generator.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/pytorch_debug_env/scenario_generator.py
2
+ from __future__ import annotations
3
+
4
+ import random
5
+ import uuid
6
+ from dataclasses import dataclass
7
+ from typing import Dict, List
8
+
9
+ import numpy as np
10
+
11
+ from .bug_library import BugTemplate
12
+
13
+
14
@dataclass
class Scenario:
    """A fully materialized episode: the buggy repo, synthetic training
    artifacts, and the ground-truth record used for grading."""

    scenario_id: str          # short random id (first 8 chars of a uuid4)
    task_id: str              # difficulty tier the scenario was built for
    repo_files: Dict[str, str]  # path -> file contents (bug already injected)
    loss_curve: List[Dict]
    gpu_profile: List[Dict]
    training_log: str
    ground_truth: Dict        # bug_type, category, files, fix_strategy, line_range
23
+
24
+
25
class ScenarioGenerator:
    """Builds Scenario instances: picks a bug template for the requested
    difficulty, mutates a clean base repo, and generates artifacts."""

    def __init__(self, bug_templates: List[BugTemplate]):
        self.bug_templates = bug_templates

    def generate(self, difficulty: str, seed: int | None = None) -> Scenario:
        """Create a scenario for `difficulty`; `seed` makes it reproducible.

        Raises IndexError (via rng.choice) if no template exists for the
        requested difficulty.
        """
        rng = random.Random(seed)
        template = rng.choice([b for b in self.bug_templates if b.difficulty == difficulty])

        # Start from a clean repo, then let the template inject its bug.
        repo_files = self._base_repo(rng)
        repo_files = template.repo_mutator(repo_files, rng)

        loss_curve = template.artifact_generator("loss_curve", rng)
        gpu_profile = template.artifact_generator("gpu_profile", rng)
        training_log = template.artifact_generator("training_log", rng)

        ground_truth = {
            "bug_type": template.bug_type,
            "category": template.category,
            "primary_bug_file": template.primary_bug_file,
            "related_files": template.related_files,
            "red_herring_file": template.red_herring_file,
            "fix_strategy": template.fix_strategy,
            "line_range": template.line_range,
        }

        return Scenario(
            scenario_id=str(uuid.uuid4())[:8],
            task_id=difficulty,
            repo_files=repo_files,
            loss_curve=loss_curve,
            gpu_profile=gpu_profile,
            training_log=training_log,
            ground_truth=ground_truth,
        )

    def _base_repo(self, rng: random.Random) -> Dict[str, str]:
        """Return the pristine (bug-free) placeholder repository."""
        return {
            "train.py": self._train_py(),
            "model/architecture.py": self._model_py(),
            "model/attention.py": self._attention_py(),
            "data/dataset.py": self._dataset_py(),
            "data/preprocessing.py": self._preprocess_py(),
            "config/training_config.yaml": self._config_yaml(),
        }

    # Each helper below returns the placeholder contents for one repo file.

    def _train_py(self) -> str:
        return """import torch\nfrom model.architecture import Net\n\n# training loop placeholder\n"""

    def _model_py(self) -> str:
        return """import torch.nn as nn\n\nclass Net(nn.Module):\n    def __init__(self):\n        super().__init__()\n"""

    def _attention_py(self) -> str:
        return """# custom attention layer\n"""

    def _dataset_py(self) -> str:
        return """from torch.utils.data import Dataset\n\nclass ImageDataset(Dataset):\n    pass\n"""

    def _preprocess_py(self) -> str:
        return """def normalize(x):\n    return x\n"""

    def _config_yaml(self) -> str:
        return "lr: 0.001\nbatch_size: 32\n"
src/pytorch_debug_env/server.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # src/pytorch_debug_env/server.py
2
+ from fastapi import FastAPI, Query
3
+ from uuid import uuid4
4
+
5
+ from .environment import PyTorchDebugEnv
6
+ from .models import PyTorchDebugAction
7
+ from .scenario_generator import ScenarioGenerator
8
+ from .bug_library import BUG_TEMPLATES
9
+
10
app = FastAPI(title="PyTorch Debug Env")

# In-memory session registry: session_id -> PyTorchDebugEnv instance.
# NOTE(review): process-local and unbounded — sessions are never evicted.
sessions = {}
# Fallback session for clients that omit ?session_id on /step and /state.
latest_session_id = None

@app.get("/")
async def root():
    """Service metadata: name, version, endpoints, and available tasks."""
    return {
        "name": "pytorch-debug-env",
        "version": "1.0.0",
        "endpoints": ["/reset", "/step", "/state", "/health"],
        "tasks": ["easy", "medium", "hard"]
    }

@app.get("/health")
async def health():
    """Liveness probe."""
    return {"status": "ok"}


@app.post("/reset")
async def reset(task_id: str = "easy"):
    """Create a fresh environment/session and return its first observation."""
    global latest_session_id
    session_id = str(uuid4())
    env = PyTorchDebugEnv(generator=ScenarioGenerator(BUG_TEMPLATES))
    sessions[session_id] = env
    latest_session_id = session_id
    obs = await env.reset(task_id=task_id)
    return {"session_id": session_id, "observation": obs, "done": False}


@app.post("/step")
async def step(action: PyTorchDebugAction, session_id: str = Query(None)):
    """Apply one action to a session (defaults to the most recent one).

    Returns a 200 response with an "error" key for unknown sessions rather
    than an HTTP error status.
    """
    sid = session_id or latest_session_id
    env = sessions.get(sid)
    if not env:
        return {"error": "Invalid session_id"}
    return await env.step(action)


@app.get("/state")
async def state(session_id: str = Query(None)):
    """Read-only snapshot of a session (defaults to the most recent one)."""
    sid = session_id or latest_session_id
    env = sessions.get(sid)
    if not env:
        return {"error": "Invalid session_id"}
    return await env.state()
tests/conftest.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ import pytest
2
+
3
@pytest.fixture
def anyio_backend():
    # Pin anyio-based tests to the asyncio backend (skip trio).
    return "asyncio"
tests/test_environment.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/test_environment.py
2
+ import pytest
3
+ from src.pytorch_debug_env.environment import PyTorchDebugEnv
4
+ from src.pytorch_debug_env.scenario_generator import ScenarioGenerator
5
+ from src.pytorch_debug_env.bug_library import BUG_TEMPLATES
6
+
7
@pytest.mark.asyncio
async def test_env_reset():
    """reset() should return an observation tagged with the requested task."""
    generator = ScenarioGenerator(BUG_TEMPLATES)
    env = PyTorchDebugEnv(generator=generator)
    obs = await env.reset("easy")
    assert obs.task_id == "easy"
tests/test_graders.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/test_graders.py
2
+ from src.pytorch_debug_env.graders import grade_easy
3
+
4
def test_grade_easy():
    """A diagnosis matching ground truth on every field should score high
    (0.40 bug_type + 0.25 file + 0.20 overlap + 0.15 fix_strategy = 1.0)."""
    gt = {
        "bug_type": "missing_zero_grad",
        "primary_bug_file": "train.py",
        "related_files": [],
        "line_range": [10, 15],
        "fix_strategy": "Call optimizer.zero_grad() before loss.backward()",
    }
    action = {
        "bug_type": "missing_zero_grad",
        "affected_file": "train.py",
        "line_range": [10, 15],
        "fix_strategy": "Call optimizer.zero_grad() before loss.backward()",
        "confidence": 0.8
    }
    assert grade_easy(action, gt) > 0.8
tests/test_reward.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # tests/test_reward.py
2
+ from src.pytorch_debug_env.reward import hypothesis_quality
3
+
4
+
5
def test_hypothesis_quality_exact_match():
    """An exactly-right hypothesis (file + bug type) should score high:
    0.45 + 0.40 plus most of the 0.15 calibration term."""
    gt = {
        "bug_type": "missing_zero_grad",
        "primary_bug_file": "train.py",
        "related_files": [],
    }
    hyp = {
        "bug_type": "missing_zero_grad",
        "affected_file": "train.py",
        "confidence": 0.8,
    }
    assert hypothesis_quality(hyp, gt) > 0.8