Priyansh Saxena committed
Commit 0feb6c8 · unverified · 0 Parent(s)

Add files via upload

Files changed (7)
  1. .gitignore +7 -0
  2. Dockerfile +15 -0
  3. README.md +81 -0
  4. inference.py +126 -0
  5. openenv.yaml +40 -0
  6. pytest.ini +2 -0
  7. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ .env
+ .pytest_cache/
+ dist/
+ *.egg-info/
Dockerfile ADDED
@@ -0,0 +1,15 @@
+ FROM python:3.11-slim
+
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     PORT=7860
+
+ WORKDIR /app
+
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY . .
+
+ EXPOSE 7860
+ CMD ["uvicorn", "src.pytorch_debug_env.server:app", "--host", "0.0.0.0", "--port", "7860"]
README.md ADDED
@@ -0,0 +1,81 @@
+ ---
+ title: PyTorch Debug Env
+ emoji: 🔥
+ colorFrom: orange
+ colorTo: red
+ sdk: docker
+ app_port: 7860
+ short_description: Multi-step RL environment for diagnosing broken PyTorch training jobs
+ tags:
+ - openenv
+ - pytorch
+ - reinforcement-learning
+ - debugging
+ - ml-training
+ - agent
+ pinned: true
+ ---
+
+ # PyTorch Debug Env 🔥
+
+ A complete [OpenEnv](https://meta-pytorch.org/OpenEnv/) environment for the **Meta PyTorch Hackathon** where an AI agent investigates and diagnoses broken PyTorch training jobs.
+
+ ## Quick Start
+
+ ```python
+ from openenv import AutoEnv, AutoAction
+
+ env = AutoEnv.from_env("ArchCoder/pytorch-debug-env")
+ Action = AutoAction.from_env("ArchCoder/pytorch-debug-env")
+
+ with env.sync() as client:
+     result = client.reset(task_id="easy")
+     action = Action(
+         current_hypothesis={
+             "bug_type": "missing_zero_grad",
+             "affected_file": "train.py",
+             "confidence": 0.7
+         },
+         commit_diagnosis=False
+     )
+     step_result = client.step(action)
+ ```
+
+ ## API Endpoints
+
+ | Endpoint | Method | Description |
+ |----------|--------|-------------|
+ | `/` | GET | Environment info |
+ | `/health` | GET | Health check |
+ | `/reset?task_id=easy` | POST | Start a new episode |
+ | `/step` | POST | Submit hypothesis + action |
+ | `/state` | GET | Current episode state |
+
+ ## Tasks
+
+ | Task | Difficulty | Description |
+ |------|-----------|-------------|
+ | `easy` | ⭐ | Single-file bug — missing `zero_grad`, wrong loss |
+ | `medium` | ⭐⭐ | Multi-file root cause — data leakage, scheduler mismatch |
+ | `hard` | ⭐⭐⭐ | Silent failure — memory leak, AMP overflow, red herrings |
+
+ ## Reward Structure
+
+ - **Hypothesis delta** (60%) — reward for improving your bug hypothesis each step
+ - **Investigation** (20%) — reward for inspecting the right files
+ - **Final diagnosis** (20%) — accuracy of the committed diagnosis vs. ground truth
+
+ Scores range from `0.0` to `1.0`. Partial credit is given for the correct bug category on hard tasks.
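
The weighting above can be sketched as a small scoring helper. This is illustrative only: the function name is hypothetical and the server-side formula may differ; it simply blends the three documented components with the 60/20/20 weights and clamps to the documented `0.0`–`1.0` range.

```python
# Hypothetical helper illustrating the 60/20/20 reward weighting.
# Each component score is assumed to lie in [0.0, 1.0].

def combined_score(hypothesis_delta: float, investigation: float,
                   final_diagnosis: float) -> float:
    """Blend the three reward components into a single score in [0.0, 1.0]."""
    total = (0.6 * hypothesis_delta
             + 0.2 * investigation
             + 0.2 * final_diagnosis)
    # Clamp, mirroring the documented score range.
    return min(max(total, 0.0), 1.0)

# A strong hypothesis trajectory plus a correct final diagnosis clears the
# 0.7 success threshold used by the bundled inference.py client.
score = combined_score(0.8, 0.5, 1.0)  # ≈ 0.78
```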
+
+ ## Environment State
+
+ Each episode provides a synthetic PyTorch repo with:
+ - Source files (`train.py`, `model/`, `data/`, `config/`)
+ - Loss curves and GPU memory profiles
+ - Training logs with realistic noise and red herrings
+
+ The agent reveals files progressively across a bounded number of steps (the bundled `inference.py` client caps episodes at five), refining its hypothesis before committing a final diagnosis.
+
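
The progressive reveal-then-commit loop can be sketched as a two-step action sequence. The field names follow the JSON schema used in the `inference.py` prompt; the specific files, bug types, and confidence values here are illustrative only.

```python
# Hypothetical two-step episode: investigate first, then commit a diagnosis.
import json

step_1 = {
    "current_hypothesis": {
        "bug_type": "missing_zero_grad",  # initial low-confidence guess
        "affected_file": "train.py",
        "confidence": 0.4,
    },
    "investigation_action": {"action": "reveal_file", "target": "train.py"},
    "commit_diagnosis": False,
    "final_diagnosis": None,
}

step_2 = {
    "current_hypothesis": {
        "bug_type": "missing_zero_grad",
        "affected_file": "train.py",
        "confidence": 0.9,                # raised after reading the file
    },
    "investigation_action": {"action": "run_diagnostic", "target": "train.py"},
    "commit_diagnosis": True,             # lock in the final answer
    "final_diagnosis": "missing_zero_grad",
}

# Each step is serialized and POSTed to /step as a JSON body.
payloads = [json.dumps(step) for step in (step_1, step_2)]
```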
+ ## Author
+
+ **Priyansh Saxena** — IIIT Gwalior
inference.py ADDED
@@ -0,0 +1,126 @@
+ # inference.py
+ import asyncio
+ import json
+ import os
+ from typing import List
+
+ from openai import OpenAI
+ import httpx
+
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
+ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
+ API_KEY = os.environ.get("OPENAI_API_KEY", "dummy")
+ ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
+ TASK_NAME = os.environ.get("TASK_NAME", "easy")
+ MAX_STEPS = int(os.environ.get("MAX_STEPS", "5"))
+ SUCCESS_SCORE_THRESHOLD = float(os.environ.get("SUCCESS_SCORE_THRESHOLD", "0.7"))
+ MAX_TOTAL_REWARD = float(os.environ.get("MAX_TOTAL_REWARD", "1.0"))
+
+
+ def log_start(task, env, model):
+     print(json.dumps({
+         "type": "START",
+         "task": task,
+         "env": env,
+         "model": model,
+     }), flush=True)
+
+
+ def log_step(step, action, reward, done, error):
+     print(json.dumps({
+         "type": "STEP",
+         "step": step,
+         "action": action,
+         "reward": float(reward),
+         "done": bool(done),
+         "error": error,
+     }), flush=True)
+
+
+ def log_end(success, steps, score, rewards):
+     print(json.dumps({
+         "type": "END",
+         "success": bool(success),
+         "steps": steps,
+         "score": float(score),
+         "rewards": [float(r) for r in rewards],
+     }), flush=True)
+
+
+ def get_model_message(client: OpenAI, observation: dict, history: List[str]) -> str:
+     prompt = f"""
+ You are debugging a PyTorch training job. Respond ONLY with valid JSON matching this exact schema:
+ {{
+ "current_hypothesis": {{"bug_type": "<string>", "affected_file": "<string>", "confidence": <0.0-1.0>}},
+ "investigation_action": {{"action": "reveal_file", "target": "<filename>"}},
+ "commit_diagnosis": false,
+ "final_diagnosis": null
+ }}
+
+ Valid action types: reveal_file, extend_loss_curve, extend_gpu_profile, reveal_log_chunk, run_diagnostic
+ Valid bug types: missing_zero_grad, data_leakage, memory_leak, learning_rate_too_high, gradient_explosion
+
+ Observation:
+ {json.dumps(observation)[:8000]}
+ History: {history}
+ """
+     completion = client.chat.completions.create(
+         model=MODEL_NAME,
+         messages=[{"role": "user", "content": prompt}],
+         temperature=0,
+         max_tokens=500,
+     )
+     return (completion.choices[0].message.content or "").strip()
+
+
+ async def main():
+     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
+     rewards = []
+     history = []
+     steps_taken = 0
+     score = 0.0
+     success = False
+
+     log_start(task=TASK_NAME, env="pytorch-debug-env", model=MODEL_NAME)
+
+     async with httpx.AsyncClient(timeout=60.0) as session:
+         reset_resp = await session.post(f"{ENV_URL}/reset", params={"task_id": TASK_NAME})
+         reset_resp.raise_for_status()
+         result = reset_resp.json()
+         session_id = result.get("session_id")
+         observation = result["observation"]
+
+         for step in range(1, MAX_STEPS + 1):
+             if result.get("done"):
+                 break
+
+             action_text = get_model_message(client, observation, history)
+             try:
+                 action_json = json.loads(action_text)
+                 step_resp = await session.post(f"{ENV_URL}/step", params={"session_id": session_id}, json=action_json)
+                 step_resp.raise_for_status()
+                 result = step_resp.json()
+                 reward = result.get("reward", 0.0)
+                 done = result.get("done", False)
+                 error = None
+                 observation = result["observation"]
+             except Exception as exc:
+                 reward = 0.0
+                 done = True
+                 error = str(exc)
+
+             rewards.append(reward)
+             steps_taken = step
+             log_step(step=step, action=action_text, reward=reward, done=done, error=error)
+             history.append(f"step={step} reward={reward:.3f}")
+
+             if done:
+                 break
+
+     score = min(max(rewards[-1] if rewards else 0.0, 0.0), 1.0)
+     success = score >= SUCCESS_SCORE_THRESHOLD
+     log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
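
One fragility worth noting in the loop above: `json.loads(action_text)` is applied to the raw model reply, and chat models often wrap JSON in markdown fences, which would raise and end the episode with an error. A tolerant parser (hypothetical, not part of this commit) might look like:

```python
# Hypothetical hardening for model replies that wrap JSON in ``` fences.
import json
import re

def parse_action(text: str) -> dict:
    """Extract the first JSON object from a model reply, tolerating ```json fences."""
    # Prefer the contents of a fenced block if one is present.
    match = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
    candidate = match.group(1) if match else text
    return json.loads(candidate)

fenced = '```json\n{"commit_diagnosis": false, "final_diagnosis": null}\n```'
print(parse_action(fenced)["commit_diagnosis"])  # → False
```

Swapping this in for the bare `json.loads` call would turn a formatting hiccup into a recoverable parse instead of a terminal episode error.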
openenv.yaml ADDED
@@ -0,0 +1,40 @@
+ name: pytorch-debug-env
+ version: 1.0.0
+ description: Multi-step OpenEnv environment for diagnosing broken PyTorch training jobs.
+ author: Priyansh Saxena
+
+ client:
+   class_name: PyTorchDebugEnv
+   module: src.pytorch_debug_env.environment
+
+ action:
+   class_name: PyTorchDebugAction
+   module: src.pytorch_debug_env.models
+
+ observation:
+   class_name: PyTorchDebugObservation
+   module: src.pytorch_debug_env.models
+
+ default_image: pytorch-debug-env:latest
+ spec_version: 1
+
+ tags:
+   - openenv
+   - pytorch
+   - debugging
+   - reinforcement-learning
+
+ tasks:
+   - id: easy
+     name: Single-file bug detection
+     difficulty: easy
+   - id: medium
+     name: Multi-file root cause analysis
+     difficulty: medium
+   - id: hard
+     name: Silent failure diagnosis
+     difficulty: hard
+
+ runtime:
+   framework: fastapi
+   container_port: 7860
pytest.ini ADDED
@@ -0,0 +1,2 @@
+ [pytest]
+ asyncio_mode = auto
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ fastapi==0.115.0
+ uvicorn[standard]==0.30.6
+ pydantic==2.9.2
+ numpy==2.1.1
+ openai==1.51.0
+ httpx==0.27.2
+ pytest==8.3.3
+ pytest-asyncio==0.24.0
+ openenv>=0.1.0