arrow072 committed on
Commit
0a01302
·
verified ·
1 Parent(s): d41e4fb

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +126 -124
inference.py CHANGED
@@ -1,128 +1,130 @@
1
- import os
2
- from openai import OpenAI
 
3
  from env import TrafficEnv
4
-
5
def _tier(arrival_rate, emergency_prob, starvation_threshold, burst_prob, burst_multiplier):
    """Build one TrafficEnv difficulty config; fields common to every tier are fixed here."""
    return {
        "max_steps": 20,
        "max_queue": 20,
        "arrival_rate": arrival_rate,
        "discharge_rate": (3, 5),
        "emergency_prob": emergency_prob,
        "switch_penalty": 0.2,
        "starvation_threshold": starvation_threshold,
        "burst_prob": burst_prob,
        "burst_multiplier": burst_multiplier,
    }


# Difficulty presets: each tier increases arrivals, emergencies, and traffic bursts.
EASY_CONFIG = _tier((0, 2), 0.01, 10, 0.0, 1.0)
MEDIUM_CONFIG = _tier((1, 3), 0.03, 10, 0.2, 1.5)
HARD_CONFIG = _tier((2, 4), 0.05, 8, 0.35, 2.0)
40
-
41
def strict_score(x: float) -> float:
    """Map a reward (nominally in [-1, 1]) to a score clamped to [0.001, 0.999]."""
    scaled = (float(x) + 1.0) / 2.0
    if scaled < 0.001:
        return 0.001
    if scaled > 0.999:
        return 0.999
    return scaled
44
-
45
def build_client():
    """Create an OpenAI-compatible client from environment variables.

    Returns a ``(client, model_name, use_llm)`` triple. ``client`` is ``None``
    and ``use_llm`` is ``False`` unless both API_BASE_URL and API_KEY are set.
    """
    base_url = os.environ.get("API_BASE_URL")
    key = os.environ.get("API_KEY")
    model = os.environ.get("MODEL_NAME", "gpt-4o-mini")

    # Guard clause: without full credentials we run in non-LLM mode.
    if not (base_url and key):
        return None, model, False

    return OpenAI(base_url=base_url, api_key=key), model, True
55
-
56
def choose_action(client, model_name, state):
    """Ask the chat model for a signal action; any unusable reply becomes 0.

    Returns 0 (keep phase) or 1 (switch phase).
    """
    prompt = f"""
You are controlling a traffic signal at a 4-way intersection.

Current state:
{state}

Available actions:
0 = keep current signal phase
1 = switch signal phase

Reply with only one number: 0 or 1
""".strip()

    response = client.chat.completions.create(
        model=model_name,
        messages=[
            {"role": "system", "content": "Reply with only 0 or 1."},
            {"role": "user", "content": prompt},
        ],
        temperature=0,
    )

    raw = response.choices[0].message.content.strip()

    # Non-numeric or out-of-range replies default to the safe "keep" action.
    try:
        parsed = int(raw)
    except Exception:
        return 0
    return parsed if parsed in (0, 1) else 0
89
-
90
def run_task(task_name, config, client, model_name, use_llm):
    """Roll out one TrafficEnv episode under `config`, logging harness markers.

    Prints "[START]", one "[STEP]" line per tick, and a final "[END]" line
    whose score is the strict-scored mean per-step reward.
    """
    env = TrafficEnv(config)
    state = env.reset()

    print("[START]", flush=True)

    done = False
    steps_taken = 0
    reward_sum = 0.0

    while not done:
        # Without an LLM client the baseline policy is "always keep phase".
        if use_llm:
            action = choose_action(client, model_name, state)
        else:
            action = 0
        state, reward, done, info = env.step(action)

        print(
            f"[STEP] task={task_name}, step={steps_taken}, score={strict_score(reward):.3f}, done={done}",
            flush=True,
        )

        reward_sum += reward
        steps_taken += 1

    # max(1, ...) guards against a zero-step episode.
    final_score = strict_score(reward_sum / max(1, steps_taken))

    print(f"[END] task={task_name}, score={final_score:.3f}", flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
if __name__ == "__main__":
    # Evaluate every difficulty tier with a single shared client.
    client, model_name, use_llm = build_client()

    for task_name, config in (
        ("easy", EASY_CONFIG),
        ("medium", MEDIUM_CONFIG),
        ("hard", HARD_CONFIG),
    ):
        run_task(task_name, config, client, model_name, use_llm)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from fastapi.responses import HTMLResponse
3
+ from pydantic import BaseModel
4
  from env import TrafficEnv
5
+ from tasks import get_config
6
+ from baseline_agent import RuleBasedAgent
7
+ import os
8
+ import openai
9
+
10
class LLMAgent:
    """Traffic-signal agent backed by an OpenAI-compatible chat model.

    Falls back to a RuleBasedAgent whenever the API client cannot be built
    or a request fails, so the agent always yields a valid action.
    """

    def __init__(self):
        # Best-effort client construction: missing env vars (KeyError) or any
        # openai constructor error leaves us permanently in fallback mode.
        try:
            self.client = openai.OpenAI(
                base_url=os.environ["API_BASE_URL"],
                api_key=os.environ["API_KEY"]
            )
        except Exception:
            self.client = None
        # Model is configurable; the default preserves the previous hardcoded value.
        self.model_name = os.environ.get("MODEL_NAME", "gpt-3.5-turbo")
        self.fallback = RuleBasedAgent()

    def select_action(self, state):
        """Return 0 (keep phase) or 1 (switch phase) for the given state."""
        if self.client is None:
            # No API configured: skip the doomed network call instead of
            # raising and catching an AttributeError on every step.
            return self.fallback.select_action(state)
        prompt = f"Traffic state: {state}. Reply with 1 to switch phase, 0 to keep phase. Output only 1 or 0."
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=[
                    {"role": "system", "content": "You are a traffic signal controller."},
                    {"role": "user", "content": prompt}
                ],
                max_tokens=5,
                temperature=0.0
            )
            content = response.choices[0].message.content.strip()
            # Still call fallback to maintain its internal step counter
            self.fallback.select_action(state)

            # NOTE: substring check also accepts wrappers like "Action: 1".
            if "1" in content:
                return 1
            else:
                return 0
        except Exception:
            # Any API or parsing failure defers to the rule-based policy.
            return self.fallback.select_action(state)

    def reset(self):
        """Reset per-episode state (delegates to the fallback agent)."""
        self.fallback.reset()
46
+
47
# ---- FastAPI wiring: one shared env/agent pair serves every request ----
app = FastAPI()
env = TrafficEnv(get_config("medium"))  # demo server always runs the medium task
agent = LLMAgent()


class Action(BaseModel):
    # Request body for /step: 0 = keep phase, 1 = switch phase.
    action: int
53
+
54
@app.get("/", response_class=HTMLResponse)
def root():
    """Serve the single-page UI from index.html."""
    with open("index.html", "r", encoding="utf-8") as page:
        html = page.read()
    return html
58
+
59
@app.post("/reset")
def reset():
    """Start a new episode; returns the initial state as JSON."""
    state = env.reset()
    # numpy arrays aren't JSON-serializable; anything without .tolist()
    # (plain lists/dicts) passes through unchanged. A bare except here
    # previously masked real errors inside tolist().
    try:
        state = state.tolist()
    except AttributeError:
        pass
    agent.reset()
    return {"state": state}
68
+
69
@app.post("/step")
def step(data: Action):
    """Advance the environment one tick with the client-supplied action."""
    state, reward, done, info = env.step(data.action)
    # numpy arrays aren't JSON-serializable; anything without .tolist()
    # passes through unchanged (bare except replaced with the one expected error).
    try:
        state = state.tolist()
    except AttributeError:
        pass
    return {
        "state": state,
        "reward": reward,
        "done": done,
        "info": info
    }
82
+
83
@app.post("/auto_step")
def auto_step():
    """Let the server-side agent choose the action, then advance one tick."""
    state_dict = env.get_state()
    action = agent.select_action(state_dict)
    state, reward, done, info = env.step(action)
    # numpy arrays aren't JSON-serializable; anything without .tolist()
    # passes through unchanged (bare except replaced with the one expected error).
    try:
        state = state.tolist()
    except AttributeError:
        pass
    return {
        "state": state,
        "reward": reward,
        "done": done,
        "info": info,
        "action_taken": action
    }
99
 
100
if __name__ == "__main__":
    import sys

    # Run every difficulty unless the validator narrowed the run to a single
    # task via an argument such as "hard" or "--task=hard".
    tasks_to_run = ["easy", "medium", "hard"]
    if len(sys.argv) > 1:
        task_arg = sys.argv[1].replace("--task=", "").replace("--task", "")
        if task_arg in tasks_to_run:
            tasks_to_run = [task_arg]

    for task_name in tasks_to_run:
        eval_env = TrafficEnv(get_config(task_name))
        eval_agent = LLMAgent()

        state = eval_env.reset()
        eval_agent.reset()

        print("[START]", flush=True)

        done = False
        step_idx = 0
        total_reward = 0.0

        # One full episode per task; the env signals termination via `done`.
        while not done:
            action = eval_agent.select_action(state)
            state, reward, done, info = eval_env.step(action)
            print(f"[STEP] step={step_idx}, reward={reward}, done={done}", flush=True)
            step_idx += 1
            total_reward += reward

        print("[END]", flush=True)