Update inference.py
Browse files- inference.py +68 -12
inference.py
CHANGED
|
@@ -18,9 +18,12 @@ stdout log format (parsed by the OpenEnv validator)
|
|
| 18 |
HTTP endpoints (OpenEnv spec: reset / step / state)
|
| 19 |
----------------------------------------------------
|
| 20 |
GET / — UI
|
| 21 |
-
GET /health — liveness probe
|
| 22 |
-
GET /
|
| 23 |
-
GET /
|
|
|
|
|
|
|
|
|
|
| 24 |
POST /reset — start new episode
|
| 25 |
POST /step — advance one step
|
| 26 |
POST /auto_step — agent picks + steps
|
|
@@ -119,10 +122,60 @@ def root() -> str:
|
|
| 119 |
return fh.read()
|
| 120 |
|
| 121 |
|
|
|
|
| 122 |
@app.get("/health")
|
| 123 |
def health() -> dict:
|
| 124 |
-
"""Liveness probe —
|
| 125 |
-
return {"status": "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
@app.get("/tasks")
|
|
@@ -171,7 +224,7 @@ class Action(BaseModel):
|
|
| 171 |
@app.post("/step")
|
| 172 |
def step_env(data: Action) -> dict:
|
| 173 |
state, reward, done, info = _env.step(data.action)
|
| 174 |
-
score = round(max(
|
| 175 |
return {"state": state, "reward": reward, "score": score, "done": done, "info": info}
|
| 176 |
|
| 177 |
|
|
@@ -191,7 +244,7 @@ def auto_step() -> dict:
|
|
| 191 |
state_dict = _env.get_state()
|
| 192 |
action = _agent.select_action(state_dict)
|
| 193 |
state, reward, done, info = _env.step(action)
|
| 194 |
-
score = round(max(
|
| 195 |
return {"state": state, "reward": reward, "score": score,
|
| 196 |
"done": done, "info": info, "action_taken": action}
|
| 197 |
|
|
@@ -201,7 +254,6 @@ def grader() -> dict:
|
|
| 201 |
"""
|
| 202 |
Run the rule-based baseline on all 3 tasks and return per-task scores
|
| 203 |
normalised to open interval (0, 1) as required by the validator.
|
| 204 |
-
Returns flat structure: {"easy": score, "medium": score, "hard": score}
|
| 205 |
"""
|
| 206 |
results: dict = {}
|
| 207 |
for task_id in ("easy", "medium", "hard"):
|
|
@@ -222,9 +274,13 @@ def grader() -> dict:
|
|
| 222 |
steps += 1
|
| 223 |
|
| 224 |
mean_reward = total_reward / max(1, steps)
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
return results
|
| 229 |
|
| 230 |
|
|
@@ -260,7 +316,7 @@ if __name__ == "__main__":
|
|
| 260 |
total_reward += reward
|
| 261 |
|
| 262 |
# score: reward normalised to open interval (0, 1)
|
| 263 |
-
score = round(max(
|
| 264 |
|
| 265 |
print(
|
| 266 |
f"[STEP] step={step_idx}, score={score}, "
|
|
|
|
| 18 |
HTTP endpoints (OpenEnv spec: reset / step / state)
|
| 19 |
----------------------------------------------------
|
| 20 |
GET / — UI
|
| 21 |
+
GET /health — liveness probe ← returns {"status": "healthy"}
|
| 22 |
+
GET /metadata — env name/description ← required by validator
|
| 23 |
+
GET /schema — action/obs/state ← required by validator
|
| 24 |
+
POST /mcp — JSON-RPC 2.0 stub ← required by validator
|
| 25 |
+
GET /state — current env state (required by OpenEnv spec)
|
| 26 |
+
GET /tasks — enumerate tasks (required by validator)
|
| 27 |
POST /reset — start new episode
|
| 28 |
POST /step — advance one step
|
| 29 |
POST /auto_step — agent picks + steps
|
|
|
|
| 122 |
return fh.read()
|
| 123 |
|
| 124 |
|
| 125 |
+
# ── FIX 1: /health must return "healthy", not "ok" ──────────────────────────
|
| 126 |
@app.get("/health")
|
| 127 |
def health() -> dict:
|
| 128 |
+
"""Liveness probe — validator strictly checks status == 'healthy'."""
|
| 129 |
+
return {"status": "healthy"}
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# ── FIX 2: /metadata endpoint (required by openenv-core validator) ───────────
|
| 133 |
+
@app.get("/metadata")
|
| 134 |
+
def metadata() -> dict:
|
| 135 |
+
"""Environment metadata — validator checks for 'name' and 'description' fields."""
|
| 136 |
+
return {
|
| 137 |
+
"name": "TrafficSignalOptimization-v1",
|
| 138 |
+
"description": (
|
| 139 |
+
"AI-driven Traffic Signal Optimization for a 4-way urban intersection. "
|
| 140 |
+
"An RL environment that minimises congestion, reduces average waiting time, "
|
| 141 |
+
"responds to emergency vehicles, and maintains signal stability across "
|
| 142 |
+
"three difficulty tiers: easy, medium, and hard."
|
| 143 |
+
),
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# ── FIX 3: /schema endpoint (required by openenv-core validator) ─────────────
|
| 148 |
+
@app.get("/schema")
|
| 149 |
+
def schema() -> dict:
|
| 150 |
+
"""Action / observation / state schemas — all three keys required by validator."""
|
| 151 |
+
return {
|
| 152 |
+
"action": {
|
| 153 |
+
"type": "Discrete",
|
| 154 |
+
"n": 2,
|
| 155 |
+
"description": "0 = keep current phase, 1 = switch phase",
|
| 156 |
+
},
|
| 157 |
+
"observation": {
|
| 158 |
+
"type": "Dict",
|
| 159 |
+
"keys": [
|
| 160 |
+
"north_cars", "south_cars", "east_cars", "west_cars",
|
| 161 |
+
"waiting_times", "phase", "emergency_flags", "step_count",
|
| 162 |
+
],
|
| 163 |
+
},
|
| 164 |
+
"state": {
|
| 165 |
+
"type": "Dict",
|
| 166 |
+
"keys": [
|
| 167 |
+
"north_cars", "south_cars", "east_cars", "west_cars",
|
| 168 |
+
"waiting_times", "phase", "emergency_flags", "step_count",
|
| 169 |
+
],
|
| 170 |
+
},
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ── FIX 4: /mcp endpoint (required by openenv-core validator) ────────────────
|
| 175 |
+
@app.post("/mcp")
|
| 176 |
+
def mcp(request: dict = {}) -> dict:
|
| 177 |
+
"""JSON-RPC 2.0 stub — validator checks jsonrpc == '2.0'."""
|
| 178 |
+
return {"jsonrpc": "2.0", "id": None, "result": {"status": "ok"}}
|
| 179 |
|
| 180 |
|
| 181 |
@app.get("/tasks")
|
|
|
|
| 224 |
@app.post("/step")
|
| 225 |
def step_env(data: Action) -> dict:
|
| 226 |
state, reward, done, info = _env.step(data.action)
|
| 227 |
+
score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6)
|
| 228 |
return {"state": state, "reward": reward, "score": score, "done": done, "info": info}
|
| 229 |
|
| 230 |
|
|
|
|
| 244 |
state_dict = _env.get_state()
|
| 245 |
action = _agent.select_action(state_dict)
|
| 246 |
state, reward, done, info = _env.step(action)
|
| 247 |
+
score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6)
|
| 248 |
return {"state": state, "reward": reward, "score": score,
|
| 249 |
"done": done, "info": info, "action_taken": action}
|
| 250 |
|
|
|
|
| 254 |
"""
|
| 255 |
Run the rule-based baseline on all 3 tasks and return per-task scores
|
| 256 |
normalised to open interval (0, 1) as required by the validator.
|
|
|
|
| 257 |
"""
|
| 258 |
results: dict = {}
|
| 259 |
for task_id in ("easy", "medium", "hard"):
|
|
|
|
| 274 |
steps += 1
|
| 275 |
|
| 276 |
mean_reward = total_reward / max(1, steps)
|
| 277 |
+
score = round(max(0.001, min(0.999, (mean_reward + 1.0) / 2.0)), 6)
|
| 278 |
+
results[task_id] = {
|
| 279 |
+
"score": score,
|
| 280 |
+
"steps": steps,
|
| 281 |
+
"total_reward": round(total_reward, 4),
|
| 282 |
+
"info": info,
|
| 283 |
+
}
|
| 284 |
return results
|
| 285 |
|
| 286 |
|
|
|
|
| 316 |
total_reward += reward
|
| 317 |
|
| 318 |
# score: reward normalised to open interval (0, 1)
|
| 319 |
+
score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6)
|
| 320 |
|
| 321 |
print(
|
| 322 |
f"[STEP] step={step_idx}, score={score}, "
|