| |
| from fastapi import FastAPI, Query |
| from uuid import uuid4 |
|
|
| from .environment import PyTorchDebugEnv |
| from .models import PyTorchDebugAction |
| from .scenario_generator import ScenarioGenerator |
| from .bug_library import BUG_TEMPLATES |
|
|
| app = FastAPI(title="PyTorch Debug Env") |
|
|
| sessions = {} |
| latest_session_id = None |
|
|
| @app.get("/") |
| async def root(): |
| return { |
| "name": "pytorch-debug-env", |
| "version": "1.0.0", |
| "endpoints": ["/reset", "/step", "/state", "/health"], |
| "tasks": ["easy", "medium", "hard"] |
| } |
|
|
| @app.get("/health") |
| async def health(): |
| return {"status": "ok"} |
|
|
|
|
| @app.post("/reset") |
| async def reset(task_id: str = "easy", seed: int | None = None): |
| global latest_session_id |
| session_id = str(uuid4()) |
| env = PyTorchDebugEnv(generator=ScenarioGenerator(BUG_TEMPLATES)) |
| sessions[session_id] = env |
| latest_session_id = session_id |
| obs = await env.reset(task_id=task_id, seed=seed) |
| return {"session_id": session_id, "observation": obs, "done": False} |
|
|
|
|
| @app.post("/step") |
| async def step(action: PyTorchDebugAction, session_id: str = Query(None)): |
| sid = session_id or latest_session_id |
| env = sessions.get(sid) |
| if not env: |
| return {"error": "Invalid session_id"} |
| return await env.step(action) |
|
|
|
|
| @app.get("/state") |
| async def state(session_id: str = Query(None)): |
| sid = session_id or latest_session_id |
| env = sessions.get(sid) |
| if not env: |
| return {"error": "Invalid session_id"} |
| return await env.state() |
|
|