Priyansh Saxena committed on
Commit ·
0feb6c8
unverified ·
0
Parent(s):
Add files via upload
Browse files- .gitignore +7 -0
- Dockerfile +15 -0
- README.md +81 -0
- inference.py +126 -0
- openenv.yaml +40 -0
- pytest.ini +2 -0
- requirements.txt +9 -0
.gitignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
.env
|
| 5 |
+
.pytest_cache/
|
| 6 |
+
dist/
|
| 7 |
+
*.egg-info/
|
Dockerfile
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim

# Don't write .pyc files, don't buffer stdout/stderr (so logs stream live),
# and default to the port Hugging Face Spaces expects (7860).
ENV PYTHONDONTWRITEBYTECODE=1 \
    PYTHONUNBUFFERED=1 \
    PORT=7860

WORKDIR /app

# Install dependencies in their own layer so Docker caching survives code edits.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 7860
# Serve the FastAPI app; bind 0.0.0.0 so the container port is reachable.
CMD ["uvicorn", "src.pytorch_debug_env.server:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: PyTorch Debug Env
|
| 3 |
+
emoji: 🔥
|
| 4 |
+
colorFrom: orange
|
| 5 |
+
colorTo: red
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
short_description: Multi-step RL environment for diagnosing broken PyTorch training jobs
|
| 9 |
+
tags:
|
| 10 |
+
- openenv
|
| 11 |
+
- pytorch
|
| 12 |
+
- reinforcement-learning
|
| 13 |
+
- debugging
|
| 14 |
+
- ml-training
|
| 15 |
+
- agent
|
| 16 |
+
pinned: true
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
# PyTorch Debug Env 🔥
|
| 20 |
+
|
| 21 |
+
A complete [OpenEnv](https://meta-pytorch.org/OpenEnv/) environment for the **Meta PyTorch Hackathon** where an AI agent investigates and diagnoses broken PyTorch training jobs.
|
| 22 |
+
|
| 23 |
+
## Quick Start
|
| 24 |
+
|
| 25 |
+
```python
|
| 26 |
+
from openenv import AutoEnv, AutoAction
|
| 27 |
+
|
| 28 |
+
env = AutoEnv.from_env("ArchCoder/pytorch-debug-env")
|
| 29 |
+
Action = AutoAction.from_env("ArchCoder/pytorch-debug-env")
|
| 30 |
+
|
| 31 |
+
with env.sync() as client:
|
| 32 |
+
result = client.reset(task_id="easy")
|
| 33 |
+
action = Action(
|
| 34 |
+
current_hypothesis={
|
| 35 |
+
"bug_type": "missing_zero_grad",
|
| 36 |
+
"affected_file": "train.py",
|
| 37 |
+
"confidence": 0.7
|
| 38 |
+
},
|
| 39 |
+
commit_diagnosis=False
|
| 40 |
+
)
|
| 41 |
+
step_result = client.step(action)
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## API Endpoints
|
| 45 |
+
|
| 46 |
+
| Endpoint | Method | Description |
|
| 47 |
+
|----------|--------|-------------|
|
| 48 |
+
| `/` | GET | Environment info |
|
| 49 |
+
| `/health` | GET | Health check |
|
| 50 |
+
| `/reset?task_id=easy` | POST | Start new episode |
|
| 51 |
+
| `/step` | POST | Submit hypothesis + action |
|
| 52 |
+
| `/state` | GET | Current episode state |
|
| 53 |
+
|
| 54 |
+
## Tasks
|
| 55 |
+
|
| 56 |
+
| Task | Difficulty | Description |
|
| 57 |
+
|------|-----------|-------------|
|
| 58 |
+
| `easy` | ⭐ | Single-file bug — missing `zero_grad`, wrong loss |
|
| 59 |
+
| `medium` | ⭐⭐ | Multi-file root cause — data leakage, scheduler mismatch |
|
| 60 |
+
| `hard` | ⭐⭐⭐ | Silent failure — memory leak, AMP overflow, red herrings |
|
| 61 |
+
|
| 62 |
+
## Reward Structure
|
| 63 |
+
|
| 64 |
+
- **Hypothesis delta** (60%) — reward for improving your bug hypothesis each step
|
| 65 |
+
- **Investigation** (20%) — reward for inspecting the right files
|
| 66 |
+
- **Final diagnosis** (20%) — accuracy of committed diagnosis vs ground truth
|
| 67 |
+
|
| 68 |
+
Scores range from `0.0` to `1.0`. Partial credit for correct bug category on hard tasks.
|
| 69 |
+
|
| 70 |
+
## Environment State
|
| 71 |
+
|
| 72 |
+
Each episode provides a synthetic PyTorch repo with:
|
| 73 |
+
- Source files (`train.py`, `model/`, `data/`, `config/`)
|
| 74 |
+
- Loss curves and GPU memory profiles
|
| 75 |
+
- Training logs with realistic noise and red herrings
|
| 76 |
+
|
| 77 |
+
The agent reveals files progressively across up to 5–6 steps, refining its hypothesis before committing a final diagnosis.
|
| 78 |
+
|
| 79 |
+
## Author
|
| 80 |
+
|
| 81 |
+
**Priyansh Saxena** — IIIT Gwalior
|
inference.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inference.py
#
# Baseline agent driver: connects an OpenAI-compatible chat model to the
# pytorch-debug-env HTTP environment and runs one debugging episode,
# emitting structured START/STEP/END JSON records on stdout.
import asyncio
import json
import os
from typing import List

from openai import OpenAI
import httpx

# All configuration comes from environment variables so the script runs
# unchanged in CI / Docker. Defaults target a local environment server.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")  # OpenAI-compatible endpoint
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
API_KEY = os.environ.get("OPENAI_API_KEY", "dummy")  # "dummy" suffices for local OpenAI-compatible servers
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")  # pytorch-debug-env server base URL
TASK_NAME = os.environ.get("TASK_NAME", "easy")  # one of: easy | medium | hard
MAX_STEPS = int(os.environ.get("MAX_STEPS", "5"))  # episode step budget
SUCCESS_SCORE_THRESHOLD = float(os.environ.get("SUCCESS_SCORE_THRESHOLD", "0.7"))  # score >= this counts as success
MAX_TOTAL_REWARD = float(os.environ.get("MAX_TOTAL_REWARD", "1.0"))  # NOTE(review): defined but unused below — confirm intent
| 20 |
+
def log_start(task, env, model):
    """Emit a one-line structured START record for this episode on stdout."""
    record = {"type": "START", "task": task, "env": env, "model": model}
    print(json.dumps(record), flush=True)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def log_step(step, action, reward, done, error):
    """Emit a one-line structured STEP record on stdout.

    reward is coerced to float and done to bool so the JSON types are stable
    regardless of what the environment returned.
    """
    record = {
        "type": "STEP",
        "step": step,
        "action": action,
        "reward": float(reward),
        "done": bool(done),
        "error": error,
    }
    print(json.dumps(record), flush=True)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def log_end(success, steps, score, rewards):
    """Emit the final one-line structured END record summarising the episode."""
    record = {
        "type": "END",
        "success": bool(success),
        "steps": steps,
        "score": float(score),
        # Normalise every per-step reward to float for stable JSON output.
        "rewards": [float(r) for r in rewards],
    }
    print(json.dumps(record), flush=True)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_model_message(client: OpenAI, observation: dict, history: List[str]) -> str:
    """Ask the chat model for the next debugging action.

    Builds a single-user-turn prompt containing the JSON schema the
    environment expects, the current observation (truncated to 8000
    chars of serialized JSON), and the reward history, then returns the
    model's raw text reply stripped of surrounding whitespace.
    """
    prompt = f"""
You are debugging a PyTorch training job. Respond ONLY with valid JSON matching this exact schema:
{{
"current_hypothesis": {{"bug_type": "<string>", "affected_file": "<string>", "confidence": <0.0-1.0>}},
"investigation_action": {{"action": "reveal_file", "target": "<filename>"}},
"commit_diagnosis": false,
"final_diagnosis": null
}}

Valid action types: reveal_file, extend_loss_curve, extend_gpu_profile, reveal_log_chunk, run_diagnostic
Valid bug types: missing_zero_grad, data_leakage, memory_leak, learning_rate_too_high, gradient_explosion

Observation:
{json.dumps(observation)[:8000]}
History: {history}
"""
    # temperature=0 for determinism; 500 tokens is ample for the JSON reply.
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "user", "content": prompt}],
        temperature=0,
        max_tokens=500,
    )
    # content can be None (e.g. refusals); guard before stripping.
    return (completion.choices[0].message.content or "").strip()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
async def main():
    """Run one episode against the debug environment and log the outcome.

    Flow: reset the environment for TASK_NAME, then loop up to MAX_STEPS
    times asking the model for an action and POSTing it to /step. Emits
    structured START/STEP/END JSON records on stdout. Configuration comes
    from the module-level env-var constants.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    rewards = []
    history = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=TASK_NAME, env="pytorch-debug-env", model=MODEL_NAME)

    async with httpx.AsyncClient(timeout=60.0) as session:
        # Start a fresh episode for the configured task.
        reset_resp = await session.post(f"{ENV_URL}/reset", params={"task_id": TASK_NAME})
        reset_resp.raise_for_status()
        result = reset_resp.json()
        session_id = result.get("session_id")
        observation = result["observation"]

        for step in range(1, MAX_STEPS + 1):
            if result.get("done"):
                break

            action_text = get_model_message(client, observation, history)
            try:
                # Fix: models frequently wrap their JSON in markdown ```
                # fences despite the prompt instructions; previously that
                # made json.loads fail and aborted the episode with zero
                # reward. Strip any fence before parsing.
                action_json = json.loads(_strip_code_fences(action_text))
                step_resp = await session.post(
                    f"{ENV_URL}/step", params={"session_id": session_id}, json=action_json
                )
                step_resp.raise_for_status()
                result = step_resp.json()
                reward = result.get("reward", 0.0)
                done = result.get("done", False)
                error = None
                observation = result["observation"]
            except Exception as exc:
                # Any failure (unparseable JSON, HTTP error, missing
                # observation) ends the episode; this step scores zero.
                reward = 0.0
                done = True
                error = str(exc)

            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_text, reward=reward, done=done, error=error)
            history.append(f"step={step} reward={reward:.3f}")

            if done:
                break

    # Final score is the last step's reward, clamped to [0.0, 1.0].
    score = min(max(rewards[-1] if rewards else 0.0, 0.0), 1.0)
    success = score >= SUCCESS_SCORE_THRESHOLD
    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)


def _strip_code_fences(text: str) -> str:
    """Return *text* with a surrounding markdown code fence removed, if any.

    Handles both ``` and ```json opening fences; leaves un-fenced text
    untouched (apart from outer whitespace stripping).
    """
    stripped = text.strip()
    if stripped.startswith("```"):
        # Drop the opening fence line (which may carry a language tag).
        stripped = stripped.split("\n", 1)[1] if "\n" in stripped else ""
        if stripped.rstrip().endswith("```"):
            stripped = stripped.rstrip()[:-3]
    return stripped.strip()
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
    # Script entry point: run a single episode.
    asyncio.run(main())
|
openenv.yaml
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: pytorch-debug-env
|
| 2 |
+
version: 1.0.0
|
| 3 |
+
description: Multi-step OpenEnv environment for diagnosing broken PyTorch training jobs.
|
| 4 |
+
author: Priyansh Saxena
|
| 5 |
+
|
| 6 |
+
client:
|
| 7 |
+
class_name: PyTorchDebugEnv
|
| 8 |
+
module: src.pytorch_debug_env.environment
|
| 9 |
+
|
| 10 |
+
action:
|
| 11 |
+
class_name: PyTorchDebugAction
|
| 12 |
+
module: src.pytorch_debug_env.models
|
| 13 |
+
|
| 14 |
+
observation:
|
| 15 |
+
class_name: PyTorchDebugObservation
|
| 16 |
+
module: src.pytorch_debug_env.models
|
| 17 |
+
|
| 18 |
+
default_image: pytorch-debug-env:latest
|
| 19 |
+
spec_version: 1
|
| 20 |
+
|
| 21 |
+
tags:
|
| 22 |
+
- openenv
|
| 23 |
+
- pytorch
|
| 24 |
+
- debugging
|
| 25 |
+
- reinforcement-learning
|
| 26 |
+
|
| 27 |
+
tasks:
|
| 28 |
+
- id: easy
|
| 29 |
+
name: Single-file bug detection
|
| 30 |
+
difficulty: easy
|
| 31 |
+
- id: medium
|
| 32 |
+
name: Multi-file root cause analysis
|
| 33 |
+
difficulty: medium
|
| 34 |
+
- id: hard
|
| 35 |
+
name: Silent failure diagnosis
|
| 36 |
+
difficulty: hard
|
| 37 |
+
|
| 38 |
+
runtime:
|
| 39 |
+
framework: fastapi
|
| 40 |
+
container_port: 7860
|
pytest.ini
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[pytest]
|
| 2 |
+
asyncio_mode = auto
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.0
|
| 2 |
+
uvicorn[standard]==0.30.6
|
| 3 |
+
pydantic==2.9.2
|
| 4 |
+
numpy==2.1.1
|
| 5 |
+
openai==1.51.0
|
| 6 |
+
httpx==0.27.2
|
| 7 |
+
pytest==8.3.3
|
| 8 |
+
pytest-asyncio==0.24.0
|
| 9 |
+
openenv>=0.1.0
|