"""
server.py -- AuditRepairEnv++ OpenEnv Server
=============================================
FastAPI server exposing /reset, /step, /state, and /health.
OpenEnv-compliant, HuggingFace-ready; serves on port 7860.
"""
|
|
| import os |
| import time |
| import uuid |
| from typing import Any, Dict, List, Optional |
|
|
| from fastapi import FastAPI, HTTPException, Request |
| from fastapi.middleware.cors import CORSMiddleware |
| from pydantic import BaseModel, Field |
|
|
| from tasks import TASK_CONFIGS, TASK_IDS, LedgerEnvironment, AuditObservation |
|
|
|
|
| |
| |
| |
|
|
# Request body for POST /reset.
# (Documented with comments rather than a docstring so the generated
# JSON/OpenAPI schema for this model stays unchanged.)
class ResetRequest(BaseModel):
    # Identifier of the task to start; when omitted, the reset helper
    # falls back to "easy".
    task_id: Optional[str] = Field(
        default=None,
        description="easy | medium | hard",
    )
|
|
# Request body for POST /step.
# (Documented with comments rather than a docstring so the generated
# JSON/OpenAPI schema for this model stays unchanged.)
class StepAction(BaseModel):
    # Required free-text action; it is handed to the environment as-is.
    message: str = Field(
        ...,
        description="Agent action text, e.g. 'FIX_ENTRY 1'",
    )
|
|
# Response payload produced by POST /step.
# (Documented with comments rather than a docstring so the generated
# JSON/OpenAPI schema for this model stays unchanged.)
class StepResponse(BaseModel):
    # Observation returned by the environment for this step.
    observation: AuditObservation
    # Reward value reported to the caller for this step.
    reward: float
    # True once the episode has ended.
    done: bool
    # Free-form extras; defaults to a fresh empty dict per instance.
    info: Dict[str, Any] = Field(default_factory=dict)
    # Error reported for the last action, or None when there was none.
    last_action_error: Optional[str] = None
|
|
# Response payload produced by GET /state: a snapshot of the active episode.
# (Documented with comments rather than a docstring so the generated
# JSON/OpenAPI schema for this model stays unchanged.)
class StateResponse(BaseModel):
    # --- episode identity -------------------------------------------------
    episode_id: str
    task_id: str

    # --- progress ---------------------------------------------------------
    step: int
    max_steps: int
    total_reward: float
    done: bool

    # --- environment bookkeeping -------------------------------------------
    remaining_budget: int
    initial_budget: int
    errors_count: int

    # --- per-step log and start timestamp (seconds, from time.time()) ------
    history: List[Dict[str, Any]]
    started_at: float
|
|
|
|
| |
| |
| |
|
|
class EpisodeState:
    """Server-side bookkeeping for one running episode.

    Wraps the environment instance together with a random episode id,
    the running reward figure, the per-step history log, and the
    wall-clock start time.
    """

    def __init__(self, env: LedgerEnvironment):
        fresh_id = uuid.uuid4()
        self.episode_id = str(fresh_id)
        self.env = env
        # Both start empty and are filled in by the /step handler.
        self.total_reward = 0.0
        self.history: List[Dict[str, Any]] = []
        # Epoch seconds at which this episode object was created.
        self.started_at = time.time()
|
|
|
|
# The module keeps a single active episode; the reset helper replaces it
# and the /step and /state handlers read it.
_current_episode: Optional[EpisodeState] = None


app = FastAPI(title="AuditRepairEnv++", version="1.0.0")

# CORS is left fully open: any origin, any method, any header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
|
|
@app.get("/", include_in_schema=False)
async def root():
    """Return a small landing payload confirming the API is up."""
    return dict(
        name="AuditRepairEnv++",
        status="running",
        docs="/docs",
        message="API is live.",
    )
|
|
|
|
| |
| |
| |
|
|
async def _do_reset(task_id: Optional[str] = None):
    """Start a fresh episode and return its first observation.

    Shared by the POST and GET /reset handlers.

    Args:
        task_id: Key into ``TASK_CONFIGS``. A missing or empty value
            selects ``"easy"``.

    Returns:
        The environment's observation, serialized with ``model_dump()``.

    Raises:
        HTTPException: 400 when ``task_id`` is not a known task.
    """
    global _current_episode

    chosen_task = task_id or "easy"
    if chosen_task not in TASK_CONFIGS:
        raise HTTPException(
            status_code=400,
            detail=f"Unknown task '{chosen_task}'. Available: {TASK_IDS}",
        )

    task_config = TASK_CONFIGS[chosen_task]
    episode = EpisodeState(task_config.create_env())
    # Replace whatever episode was active before; there is only one slot.
    _current_episode = episode

    banner = f"Environment reset. Task: {task_config.name}"
    first_observation = episode.env.get_observation(echoed_message=banner)
    return first_observation.model_dump()
|
|
|
|
@app.post("/reset")
async def reset_post(request: ResetRequest = ResetRequest()):
    # POST /reset: start a new episode for the task named in the body.
    # The default ResetRequest() lets callers omit the body entirely.
    # (Comment rather than docstring so the OpenAPI description is unchanged.)
    requested_task = request.task_id
    return await _do_reset(requested_task)
|
|
|
|
@app.get("/reset")
async def reset_get(task_id: Optional[str] = None):
    # GET /reset: same as the POST variant, with the task chosen through
    # the optional ``task_id`` query parameter.
    # (Comment rather than docstring so the OpenAPI description is unchanged.)
    observation_payload = await _do_reset(task_id)
    return observation_payload
|
|
|
|
@app.post("/step")
async def step(action: StepAction):
    # POST /step: apply one agent action to the active episode.
    # Responds with 400 when no episode exists or the episode is already done.
    # (Comment rather than docstring so the OpenAPI description is unchanged.)
    ep = _current_episode
    if ep is None:
        raise HTTPException(
            status_code=400,
            detail="No active episode. Call /reset first.",
        )
    if ep.env.done:
        raise HTTPException(
            status_code=400,
            detail="Episode finished. Call /reset to start a new one.",
        )

    outcome = ep.env.step_with_message(action.message)

    step_reward = float(outcome.get("reward", 0))
    finished = bool(outcome.get("done", False))
    action_error = outcome.get("error")
    result_text = outcome.get("result", "")

    # The episode's total_reward mirrors the environment's current score
    # rather than a running sum of per-step rewards.
    score = ep.env.compute_final_score()
    ep.total_reward = score

    # The raw per-step reward is recorded only here, in the history log.
    history_entry = {
        "step": ep.env.step,
        "action": action.message[:200],  # cap stored action text
        "reward": step_reward,
        "step_score": score,
        "done": finished,
        "info": result_text,
    }
    ep.history.append(history_entry)

    # final_score is only reported once the episode has ended.
    if finished:
        final_score = score
    else:
        final_score = None

    # NOTE(review): the response's ``reward`` carries the current score,
    # not the per-step reward -- confirm this is what clients expect.
    response = StepResponse(
        observation=outcome["observation"],
        reward=score,
        done=finished,
        info={
            "total_reward": ep.total_reward,
            "episode_id": ep.episode_id,
            "result": result_text,
            "final_score": final_score,
        },
        last_action_error=action_error,
    )
    return response.model_dump()
|
|
|
|
@app.get("/state")
async def state():
    # GET /state: report a snapshot of the active episode without advancing it.
    # Responds with 400 when /reset has not been called yet.
    # (Comment rather than docstring so the OpenAPI description is unchanged.)
    ep = _current_episode
    if ep is None:
        raise HTTPException(
            status_code=400,
            detail="No active episode. Call /reset first.",
        )

    env = ep.env
    snapshot = StateResponse(
        episode_id=ep.episode_id,
        task_id=env.task_id,
        step=env.step,
        max_steps=env.max_steps,
        total_reward=ep.total_reward,
        done=env.done,
        remaining_budget=env.remaining_budget,
        initial_budget=env.initial_budget,
        errors_count=len(env.get_errors()),
        history=ep.history,
        started_at=ep.started_at,
    )
    return snapshot.model_dump()
|
|
|
|
@app.get("/health")
async def health():
    # GET /health: liveness probe that also lists the available task ids.
    # (Comment rather than docstring so the OpenAPI description is unchanged.)
    payload = {
        "status": "ok",
        "environment": "AuditRepairEnv++",
        "tasks": TASK_IDS,
    }
    return payload
|
|
|
|
if __name__ == "__main__":
    # uvicorn is only needed when this file is run directly as a script.
    import uvicorn

    # Bind on all interfaces, port 7860.
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)