hiitsesh's picture
fix: improve grading logic to handle NaN values and enforce score boundaries
ffda6a0
from fastapi import FastAPI, HTTPException, Body
from src.models import Action, TaskConfig, ResetRequest
from src.env import DesalEnv
from src.tasks import TASKS
import subprocess
from typing import Optional
app = FastAPI(title="Advanced Municipal Desalination Plant Env")
env = DesalEnv()
@app.get("/")
def health_check():
return {"status": "ok", "message": "Advanced DesalEnv is running", "features": ["weather", "salinity", "mechanics"]}
@app.post("/reset")
def reset_env(task_id: str = "easy_spring", req: Optional[ResetRequest] = None):
# Support both GET query params and POST JSON body for task_id
if req and req.task_id != "easy_spring":
task_id = req.task_id
if task_id not in TASKS:
raise HTTPException(status_code=404, detail="Task not found")
obs = env.reset(TASKS[task_id])
return {"observation": obs.dict()}
@app.post("/step")
def step_env(action: Action):
try:
result = env.step(action)
return result.dict()
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/state")
def get_state():
if env.state is None:
raise HTTPException(status_code=400, detail="Environment not initialized")
return {"observation": env.state.dict()}
@app.get("/tasks")
def list_tasks():
return {"tasks": list(TASKS.keys()), "action_schema": Action.schema()}
@app.get("/grader")
def grader():
if env.state is None:
return {"score": 0.001}
# Grade relative to typical maximum and minimum returns to generate a 0.0-1.0 range
baseline_offset = env.config.max_steps * 1000.0 # Compensate for penalties
scale_factor = env.config.max_steps * 1500.0
try:
raw_score = float(env.total_reward + baseline_offset) / scale_factor
import math
if math.isnan(raw_score):
score = 0.001
else:
score = float(max(0.001, min(0.999, raw_score)))
except:
score = 0.001
if score >= 1.0:
score = 0.999
elif score <= 0.0:
score = 0.001
return {"score": score}
@app.post("/baseline")
def run_baseline():
result = subprocess.run(["python", "src/baseline.py"], capture_output=True, text=True)
return {"output": result.stdout}
def main():
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)
if __name__ == "__main__":
main()