Omkar1806 committed on
Commit
b37fd8a
·
verified ·
1 Parent(s): c5fdb67

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +20 -15
inference.py CHANGED
@@ -7,17 +7,19 @@ from env import EmailTriageEnv
7
  from app import smart_agent_logic
8
 
9
 
10
 - # ✅ MUST use these EXACT env vars
11
  API_BASE_URL = os.environ.get("API_BASE_URL")
12
  API_KEY = os.environ.get("API_KEY")
13
  MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
14
 
15
- TASK_NAME = os.getenv("MY_ENV_V4_TASK", "easy")
16
  BENCHMARK = "email_triage_env"
17
 
18
  MAX_STEPS = 20
19
  SUCCESS_SCORE_THRESHOLD = 0.5
20
 
 
 
 
21
 
22
  def log_start(task: str, env: str, model: str) -> None:
23
  print(f"[START] task={task} env={env} model={model}", flush=True)
@@ -42,14 +44,7 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
42
  )
43
 
44
 
45
- def main():
46
 - # ✅ REQUIRED: Initialize OpenAI client with provided proxy
47
- try:
48
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
49
- except Exception as e:
50
- print(f"[DEBUG] OpenAI init failed: {e}", flush=True)
51
- client = None
52
-
53
  env = EmailTriageEnv(task=TASK_NAME)
54
 
55
  rewards: List[float] = []
@@ -69,9 +64,9 @@ def main():
69
  try:
70
  desc = state["description"]
71
 
72
 - # ✅ 🔥 LLM CALL (MANDATORY)
73
  action_list = None
74
 
 
75
  if client:
76
  try:
77
  response = client.chat.completions.create(
@@ -91,17 +86,15 @@ def main():
91
  )
92
 
93
  text = response.choices[0].message.content.strip()
94
-
95
- # Parse response
96
  action_list = [int(x) for x in text.replace(",", " ").split()[:3]]
97
 
98
  if len(action_list) != 3:
99
- raise ValueError("Invalid LLM output")
100
 
101
  except Exception as llm_error:
102
  print(f"[DEBUG] LLM failed: {llm_error}", flush=True)
103
 
104
 - # ✅ fallback if LLM fails
105
  if not action_list:
106
  action_list = smart_agent_logic(desc)
107
 
@@ -129,5 +122,17 @@ def main():
129
  log_end(success, steps_taken, score, rewards)
130
 
131
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  if __name__ == "__main__":
133
  main()
 
7
  from app import smart_agent_logic
8
 
9
 
10
 + # ✅ REQUIRED env vars
11
  API_BASE_URL = os.environ.get("API_BASE_URL")
12
  API_KEY = os.environ.get("API_KEY")
13
  MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
14
 
 
15
  BENCHMARK = "email_triage_env"
16
 
17
  MAX_STEPS = 20
18
  SUCCESS_SCORE_THRESHOLD = 0.5
19
 
20
 + # ✅ RUN ALL TASKS
21
+ TASKS = ["easy", "medium", "hard"]
22
+
23
 
24
  def log_start(task: str, env: str, model: str) -> None:
25
  print(f"[START] task={task} env={env} model={model}", flush=True)
 
44
  )
45
 
46
 
47
+ def run_task(client, TASK_NAME):
 
 
 
 
 
 
 
48
  env = EmailTriageEnv(task=TASK_NAME)
49
 
50
  rewards: List[float] = []
 
64
  try:
65
  desc = state["description"]
66
 
 
67
  action_list = None
68
 
69
 + # ✅ LLM CALL
70
  if client:
71
  try:
72
  response = client.chat.completions.create(
 
86
  )
87
 
88
  text = response.choices[0].message.content.strip()
 
 
89
  action_list = [int(x) for x in text.replace(",", " ").split()[:3]]
90
 
91
  if len(action_list) != 3:
92
+ raise ValueError()
93
 
94
  except Exception as llm_error:
95
  print(f"[DEBUG] LLM failed: {llm_error}", flush=True)
96
 
97
+ # fallback
98
  if not action_list:
99
  action_list = smart_agent_logic(desc)
100
 
 
122
  log_end(success, steps_taken, score, rewards)
123
 
124
 
125
+ def main():
126
+ try:
127
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
128
+ except Exception as e:
129
+ print(f"[DEBUG] OpenAI init failed: {e}", flush=True)
130
+ client = None
131
+
132
 + # ✅ RUN ALL TASKS
133
+ for task in TASKS:
134
+ run_task(client, task)
135
+
136
+
137
  if __name__ == "__main__":
138
  main()