| """ |
| inference.py — Traffic Signal Optimization · OpenEnv Hackathon Submission |
| ============================================================================ |
| |
| Env variables expected by the evaluator |
| ---------------------------------------- |
| API_BASE_URL Base URL of the LLM endpoint (e.g. https://router.huggingface.co/v1) |
| MODEL_NAME Model identifier (e.g. meta-llama/Llama-3.2-3B-Instruct) |
| HF_TOKEN HuggingFace / API key |
| |
| stdout log format (parsed by the OpenEnv validator) |
| ----------------------------------------------------- |
| [START] |
| [STEP] step=0, score=0.512300, reward=0.024600, done=False |
| ... |
| [END] |
| |
| HTTP endpoints (OpenEnv spec: reset / step / state) |
| ---------------------------------------------------- |
| GET / — UI |
| GET /health — liveness probe ← returns {"status": "healthy"} |
| GET /metadata — env name/description ← required by validator |
| GET /schema — action/obs/state ← required by validator |
| POST /mcp — JSON-RPC 2.0 stub ← required by validator |
| GET /state — current env state (required by OpenEnv spec) |
| GET /tasks — enumerate tasks (required by validator) |
| POST /reset — start new episode |
| POST /step — advance one step |
| POST /auto_step — agent picks + steps |
| POST /grader — run baseline on all tasks, return scores |
| """ |
|
|
| import os |
| import sys |
|
|
| from fastapi import FastAPI |
| from fastapi.responses import HTMLResponse |
| from pydantic import BaseModel |
| from env import TrafficEnv |
| from tasks import get_config |
| from baseline_agent import RuleBasedAgent |
| import openai |
|
|
|
|
| |
| |
| |
|
|
class LLMAgent:
    """
    OpenAI-compatible LLM agent with a rule-based fallback.

    Reads API_BASE_URL / MODEL_NAME / HF_TOKEN from the environment.
    If no endpoint is configured, or any API call fails, decisions are
    delegated to the RuleBasedAgent fallback.
    """

    def __init__(self) -> None:
        api_base = os.environ.get("API_BASE_URL", "").strip()
        api_key = os.environ.get("HF_TOKEN", "not-needed")
        self.model = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")

        # Client stays None when no base URL is configured or construction
        # fails; select_action() then always uses the rule-based fallback.
        self.client = None
        if api_base:
            try:
                self.client = openai.OpenAI(base_url=api_base, api_key=api_key)
            except Exception:
                self.client = None

        self.fallback = RuleBasedAgent()

    def select_action(self, state: dict) -> int:
        """Return 0 (keep current phase) or 1 (switch phase) for *state*."""
        if self.client is not None:
            prompt = (
                f"Traffic intersection state:\n{state}\n\n"
                "You control the traffic signal. Reply with ONLY 0 or 1.\n"
                "0 = keep current green phase\n"
                "1 = switch to the other phase"
            )
            try:
                resp = self.client.chat.completions.create(
                    model=self.model,
                    messages=[
                        {"role": "system", "content": "You are a traffic signal controller. Output only 0 or 1."},
                        {"role": "user", "content": prompt},
                    ],
                    max_tokens=5,
                    temperature=0.0,
                )
                # `content` may legitimately be None on some providers; the
                # old `.strip()` call relied on the except to swallow that.
                content = resp.choices[0].message.content or ""
                # Advance the fallback's internal state in lock-step so it is
                # consistent if a later LLM call fails; also serves as the
                # answer when the reply contains no parsable digit.
                fallback_action = self.fallback.select_action(state)
                # BUGFIX: the previous `return 1 if "1" in content else 0`
                # returned 1 whenever a verbose reply merely mentioned "1"
                # (e.g. "I choose 0, not 1"). Use the FIRST 0/1 digit instead.
                for ch in content:
                    if ch in ("0", "1"):
                        return int(ch)
                return fallback_action
            except Exception:
                # Any API failure degrades gracefully to the rule-based policy.
                pass
        return self.fallback.select_action(state)

    def reset(self) -> None:
        """Reset the fallback agent's per-episode state."""
        self.fallback.reset()
|
|
|
|
| |
| |
| |
|
|
# Module-level singletons shared by every HTTP handler below: one live
# environment (medium difficulty by default) and one agent instance.
_env = TrafficEnv(get_config("medium"))
_agent = LLMAgent()
|
|
|
|
| |
| |
| |
|
|
# FastAPI application object; all endpoint decorators below attach to it.
app = FastAPI(
    title="Traffic Signal Optimization — OpenEnv",
    description="4-way intersection RL environment · Meta × PyTorch OpenEnv Hackathon",
    version="1.0.0",
)
|
|
|
|
| |
|
|
| @app.get("/", response_class=HTMLResponse) |
| def root() -> str: |
| with open("index.html", "r", encoding="utf-8") as fh: |
| return fh.read() |
|
|
|
|
| |
| @app.get("/health") |
| def health() -> dict: |
| """Liveness probe — validator strictly checks status == 'healthy'.""" |
| return {"status": "healthy"} |
|
|
|
|
| |
| @app.get("/metadata") |
| def metadata() -> dict: |
| """Environment metadata — validator checks for 'name' and 'description' fields.""" |
| return { |
| "name": "TrafficSignalOptimization-v1", |
| "description": ( |
| "AI-driven Traffic Signal Optimization for a 4-way urban intersection. " |
| "An RL environment that minimises congestion, reduces average waiting time, " |
| "responds to emergency vehicles, and maintains signal stability across " |
| "three difficulty tiers: easy, medium, and hard." |
| ), |
| } |
|
|
|
|
| |
| @app.get("/schema") |
| def schema() -> dict: |
| """Action / observation / state schemas — all three keys required by validator.""" |
| return { |
| "action": { |
| "type": "Discrete", |
| "n": 2, |
| "description": "0 = keep current phase, 1 = switch phase", |
| }, |
| "observation": { |
| "type": "Dict", |
| "keys": [ |
| "north_cars", "south_cars", "east_cars", "west_cars", |
| "waiting_times", "phase", "emergency_flags", "step_count", |
| ], |
| }, |
| "state": { |
| "type": "Dict", |
| "keys": [ |
| "north_cars", "south_cars", "east_cars", "west_cars", |
| "waiting_times", "phase", "emergency_flags", "step_count", |
| ], |
| }, |
| } |
|
|
|
|
| |
| @app.post("/mcp") |
| def mcp(request: dict = {}) -> dict: |
| """JSON-RPC 2.0 stub — validator checks jsonrpc == '2.0'.""" |
| return {"jsonrpc": "2.0", "id": None, "result": {"status": "ok"}} |
|
|
|
|
| @app.get("/tasks") |
| def list_tasks() -> dict: |
| """Enumerate the 3 difficulty tasks for the validator.""" |
| return { |
| "tasks": [ |
| { |
| "id": "easy", |
| "description": "Stable low-volume traffic, rare emergencies (1%)", |
| "max_steps": 50, |
| "arrival_rate": [0, 1], |
| "emergency_prob": 0.01, |
| }, |
| { |
| "id": "medium", |
| "description": "Moderate traffic with 10% burst events, 5% emergency", |
| "max_steps": 100, |
| "arrival_rate": [1, 3], |
| "emergency_prob": 0.05, |
| }, |
| { |
| "id": "hard", |
| "description": "High-intensity traffic, 20% bursts, 15% emergency, strict fairness", |
| "max_steps": 200, |
| "arrival_rate": [2, 5], |
| "emergency_prob": 0.15, |
| }, |
| ] |
| } |
|
|
|
|
| |
|
|
| @app.post("/reset") |
| def reset_env() -> dict: |
| state = _env.reset() |
| _agent.reset() |
| return {"state": state} |
|
|
|
|
class Action(BaseModel):
    """Request body for POST /step: the discrete signal action to apply."""

    # 0 = keep current phase, 1 = switch phase (matches /schema's Discrete n=2)
    action: int
|
|
|
|
| @app.post("/step") |
| def step_env(data: Action) -> dict: |
| state, reward, done, info = _env.step(data.action) |
| score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6) |
| return {"state": state, "reward": reward, "score": score, "done": done, "info": info} |
|
|
|
|
| @app.get("/state") |
| def get_state() -> dict: |
| """ |
| Return current environment state. |
| Required by OpenEnv spec (the reset / step / state triple). |
| """ |
| return {"state": _env.get_state()} |
|
|
|
|
| |
|
|
| @app.post("/auto_step") |
| def auto_step() -> dict: |
| state_dict = _env.get_state() |
| action = _agent.select_action(state_dict) |
| state, reward, done, info = _env.step(action) |
| score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6) |
| return {"state": state, "reward": reward, "score": score, |
| "done": done, "info": info, "action_taken": action} |
|
|
|
|
| @app.post("/grader") |
| def grader() -> dict: |
| """ |
| Run the rule-based baseline on all 3 tasks and return per-task scores |
| normalised to open interval (0, 1) as required by the validator. |
| """ |
| results: dict = {} |
| for task_id in ("easy", "medium", "hard"): |
| cfg = get_config(task_id) |
| eval_env = TrafficEnv(cfg) |
| agent = RuleBasedAgent() |
| state = eval_env.reset() |
| agent.reset() |
|
|
| total_reward = 0.0 |
| steps = 0 |
| done = False |
|
|
| while not done: |
| action = agent.select_action(state) |
| state, reward, done, info = eval_env.step(action) |
| total_reward += reward |
| steps += 1 |
|
|
| mean_reward = total_reward / max(1, steps) |
| score = round(max(0.001, min(0.999, (mean_reward + 1.0) / 2.0)), 6) |
| results[task_id] = { |
| "score": score, |
| "steps": steps, |
| "total_reward": round(total_reward, 4), |
| "info": info, |
| } |
| return results |
|
|
|
|
| |
| |
| |
|
|
| if __name__ == "__main__": |
| tasks_to_run = ["easy", "medium", "hard"] |
|
|
| if len(sys.argv) > 1: |
| raw = sys.argv[1].replace("--task=", "").replace("--task", "").strip() |
| if raw in tasks_to_run: |
| tasks_to_run = [raw] |
|
|
| for task_name in tasks_to_run: |
| config = get_config(task_name) |
| eval_env = TrafficEnv(config) |
| eval_agent = LLMAgent() |
|
|
| state = eval_env.reset() |
| eval_agent.reset() |
|
|
| print("[START]", flush=True) |
|
|
| done = False |
| step_idx = 0 |
| total_reward = 0.0 |
|
|
| while not done: |
| action = eval_agent.select_action(state) |
| state, reward, done, info = eval_env.step(action) |
| total_reward += reward |
|
|
| |
| score = round(max(0.001, min(0.999, (reward + 1.0) / 2.0)), 6) |
|
|
| print( |
| f"[STEP] step={step_idx}, score={score}, " |
| f"reward={round(reward, 6)}, done={done}", |
| flush=True, |
| ) |
| step_idx += 1 |
|
|
| print("[END]", flush=True) |
|
|