arrow072 committed on
Commit
d41e4fb
·
verified ·
1 Parent(s): 22fc9d3

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +124 -126
inference.py CHANGED
@@ -1,130 +1,128 @@
1
- from fastapi import FastAPI
2
- from fastapi.responses import HTMLResponse
3
- from pydantic import BaseModel
4
- from env import TrafficEnv
5
- from tasks import get_config
6
- from baseline_agent import RuleBasedAgent
7
  import os
8
- import openai
9
-
10
class LLMAgent:
    """Traffic-signal agent that queries a chat model, with a rule-based fallback.

    The fallback agent decides whenever the LLM client is unavailable or a
    request fails, and is also stepped on every successful LLM decision so
    its internal step counter stays aligned with the episode.
    """

    def __init__(self):
        # Credentials come from the environment; any failure (missing vars,
        # bad config) leaves the client as None and we run purely rule-based.
        try:
            self.client = openai.OpenAI(
                base_url=os.environ["API_BASE_URL"],
                api_key=os.environ["API_KEY"]
            )
        except Exception:
            self.client = None
        self.fallback = RuleBasedAgent()

    def select_action(self, state):
        """Return 1 to switch the signal phase, 0 to keep it."""
        # Explicit guard: without a client, decide (and advance) via the
        # rule-based agent directly instead of relying on the AttributeError
        # the call below would raise.
        if self.client is None:
            return self.fallback.select_action(state)

        prompt = f"Traffic state: {state}. Reply with 1 to switch phase, 0 to keep phase. Output only 1 or 0."
        try:
            response = self.client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "You are a traffic signal controller."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=5,
                temperature=0.0
            )
            content = response.choices[0].message.content.strip()
            # Still call fallback to maintain its internal step counter
            self.fallback.select_action(state)
            return 1 if "1" in content else 0
        except Exception:
            # Any API failure degrades gracefully to the rule-based policy.
            return self.fallback.select_action(state)

    def reset(self):
        """Reset the fallback agent at the start of an episode."""
        self.fallback.reset()
46
-
47
# Application wiring: one FastAPI app, a single shared environment on the
# "medium" task configuration, and one LLM-backed agent instance.
app = FastAPI()
env = TrafficEnv(get_config("medium"))
agent = LLMAgent()


class Action(BaseModel):
    """Request body for /step: 1 = switch signal phase, 0 = keep phase."""
    action: int
53
-
54
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the front-end page bundled next to this module."""
    with open("index.html", "r", encoding="utf-8") as page:
        html = page.read()
    return html
58
-
59
@app.post("/reset")
def reset():
    """Reset the shared environment and agent; return the initial state."""
    state = env.reset()
    # numpy arrays are not JSON-serializable; only AttributeError (state has
    # no .tolist) is expected here — a bare except would hide real bugs.
    try:
        state = state.tolist()
    except AttributeError:
        pass
    agent.reset()
    return {"state": state}
68
-
69
@app.post("/step")
def step(data: Action):
    """Apply a caller-chosen action to the shared environment.

    Returns the transition tuple (state, reward, done, info) as JSON.
    """
    state, reward, done, info = env.step(data.action)
    # Only AttributeError (state lacks .tolist) is expected; narrower than
    # the original bare except, which swallowed every error silently.
    try:
        state = state.tolist()
    except AttributeError:
        pass
    return {
        "state": state,
        "reward": reward,
        "done": done,
        "info": info
    }
82
-
83
@app.post("/auto_step")
def auto_step():
    """Let the agent choose the next action, apply it, and report the result."""
    state_dict = env.get_state()
    action = agent.select_action(state_dict)
    state, reward, done, info = env.step(action)
    # Narrowed from a bare except: only the missing-.tolist case is expected.
    try:
        state = state.tolist()
    except AttributeError:
        pass
    return {
        "state": state,
        "reward": reward,
        "done": done,
        "info": info,
        "action_taken": action
    }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
if __name__ == "__main__":
    import sys

    # Default to evaluating every difficulty; an optional CLI argument
    # (bare task name or --task=NAME form) narrows the run to one task.
    tasks_to_run = ["easy", "medium", "hard"]
    if len(sys.argv) > 1:
        requested = sys.argv[1].replace("--task=", "").replace("--task", "")
        if requested in tasks_to_run:
            tasks_to_run = [requested]

    for task_name in tasks_to_run:
        eval_env = TrafficEnv(get_config(task_name))
        eval_agent = LLMAgent()

        state = eval_env.reset()
        eval_agent.reset()

        print("[START]", flush=True)

        done = False
        step_idx = 0
        total_reward = 0.0

        while not done:
            action = eval_agent.select_action(state)
            state, reward, done, info = eval_env.step(action)
            print(f"[STEP] step={step_idx}, reward={reward}, done={done}", flush=True)
            step_idx += 1
            total_reward += reward  # accumulated but not reported in this version

        print("[END]", flush=True)
 
 
 
 
 
 
 
1
  import os
2
+ from openai import OpenAI
3
+ from env import TrafficEnv
4
+
5
# Per-difficulty environment settings. All tasks share episode length,
# queue capacity, and discharge rate; difficulty scales arrivals,
# emergencies, starvation pressure, and traffic bursts.
EASY_CONFIG = dict(
    max_steps=20,
    max_queue=20,
    arrival_rate=(0, 2),        # light traffic
    discharge_rate=(3, 5),
    emergency_prob=0.01,
    switch_penalty=0.2,
    starvation_threshold=10,
    burst_prob=0.0,             # no bursts on easy
    burst_multiplier=1.0,
)

MEDIUM_CONFIG = dict(
    max_steps=20,
    max_queue=20,
    arrival_rate=(1, 3),        # moderate traffic
    discharge_rate=(3, 5),
    emergency_prob=0.03,
    switch_penalty=0.2,
    starvation_threshold=10,
    burst_prob=0.2,
    burst_multiplier=1.5,
)

HARD_CONFIG = dict(
    max_steps=20,
    max_queue=20,
    arrival_rate=(2, 4),        # heavy traffic
    discharge_rate=(3, 5),
    emergency_prob=0.05,
    switch_penalty=0.2,
    starvation_threshold=8,     # starvation triggers sooner than other tasks
    burst_prob=0.35,
    burst_multiplier=2.0,
)
40
+
41
def strict_score(x: float) -> float:
    """Map a reward in [-1, 1] onto (0, 1), clamped to [0.001, 0.999].

    Inputs outside [-1, 1] are clipped to the nearest bound after scaling.
    """
    scaled = 0.5 * (float(x) + 1.0)
    if scaled < 0.001:
        return 0.001
    if scaled > 0.999:
        return 0.999
    return scaled
44
+
45
def build_client():
    """Construct the chat client from environment variables.

    Returns:
        Tuple (client, model_name, use_llm). ``client`` is an OpenAI
        instance when both API_BASE_URL and API_KEY are set, otherwise
        None; ``use_llm`` mirrors whether a client was created.
    """
    base = os.environ.get("API_BASE_URL")
    key = os.environ.get("API_KEY")
    model = os.environ.get("MODEL_NAME", "gpt-4o-mini")

    # Missing credentials: run without an LLM rather than failing at import.
    if not (base and key):
        return None, model, False

    return OpenAI(base_url=base, api_key=key), model, True
55
+
56
def choose_action(client, model_name, state):
    """Ask the chat model to pick a signal action for the current state.

    Args:
        client: an OpenAI-compatible chat client.
        model_name: model identifier to query.
        state: environment state, rendered verbatim into the prompt.

    Returns:
        0 (keep phase) or 1 (switch phase). Defaults to 0 whenever the
        reply contains neither digit.
    """
    prompt = f"""
You are controlling a traffic signal at a 4-way intersection.

Current state:
{state}

Available actions:
0 = keep current signal phase
1 = switch signal phase

Reply with only one number: 0 or 1
""".strip()

    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": "Reply with only 0 or 1."},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
        max_tokens=5,  # the answer is a single digit; cap decoding cost
    )

    content = response.choices[0].message.content.strip()

    # Models often decorate the digit ("1.", "`0`", "Action: 1"); a strict
    # int(content) would reject those and silently force action 0, so take
    # the first explicit 0/1 in the reply instead.
    for ch in content:
        if ch in ("0", "1"):
            return int(ch)
    return 0
89
+
90
def run_task(task_name, config, client, model_name, use_llm):
    """Run one full episode on ``config``, logging per-step and final scores.

    When ``use_llm`` is False every step keeps the current phase (action 0).
    """
    task_env = TrafficEnv(config)
    state = task_env.reset()

    print("[START]", flush=True)

    done = False
    steps_taken = 0
    reward_sum = 0.0

    while not done:
        if use_llm:
            action = choose_action(client, model_name, state)
        else:
            action = 0
        state, reward, done, info = task_env.step(action)

        print(
            f"[STEP] task={task_name}, step={steps_taken}, score={strict_score(reward):.3f}, done={done}",
            flush=True,
        )

        reward_sum += reward
        steps_taken += 1

    # max(1, ...) guards against a zero-step episode.
    mean_reward = reward_sum / max(1, steps_taken)
    print(f"[END] task={task_name}, score={strict_score(mean_reward):.3f}", flush=True)
117
 
118
if __name__ == "__main__":
    client, model_name, use_llm = build_client()

    # Evaluate every difficulty in ascending order.
    for name, cfg in (
        ("easy", EASY_CONFIG),
        ("medium", MEDIUM_CONFIG),
        ("hard", HARD_CONFIG),
    ):
        run_task(name, cfg, client, model_name, use_llm)