Rithwik Ravi commited on
Commit
8b003d5
·
1 Parent(s): 7d84930

Formatting Update

Browse files
Files changed (1) hide show
  1. inference.py +53 -46
inference.py CHANGED
@@ -5,9 +5,10 @@ from server.models import Action, BrowserGymAction # using our local Action mode
5
  from server.app import env_instance as env
6
 
7
  # Environment Configuration
8
- API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
9
- MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o")
10
- API_KEY = os.environ.get("OPENAI_API_KEY") or os.environ.get("HF_TOKEN", "")
 
11
 
12
  MAX_STEPS = 15
13
  TEMPERATURE = 0.2
@@ -47,7 +48,6 @@ def build_user_prompt(step: int, observation, history: list) -> str:
47
  return prompt
48
 
49
  def parse_model_action(response_text: str) -> str:
50
- # Try to parse JSON
51
  text = response_text.strip()
52
  if text.startswith("```json"): text = text[7:]
53
  if text.startswith("```"): text = text[3:]
@@ -58,35 +58,41 @@ def parse_model_action(response_text: str) -> str:
58
  data = json.loads(text)
59
  return data.get("action_str", "SELECT 1;")
60
  except json.JSONDecodeError:
61
- # Fallback if model doesn't follow json format correctly
62
  return text
63
 
64
- def run_task(task_id: int):
65
- print(f"\n{'='*50}\nStarting Task {task_id}\n{'='*50}")
66
-
67
  client = OpenAI(
68
  base_url=API_BASE_URL,
69
- api_key=API_KEY
70
  )
71
 
72
  history = []
 
73
 
74
- # Using the local env object wrapper
75
- result = env.reset(task_id=task_id)
76
- observation = result.observation
77
- print(f"Episode goal: {observation.goal}\n")
 
 
 
 
 
 
 
 
 
78
 
79
  for step in range(1, MAX_STEPS + 1):
80
- # We handle done from the step result, but for initial step we check just in case
81
  user_prompt = build_user_prompt(step, observation, history)
82
 
83
- # print("PROMPT:", user_prompt)
84
-
85
  messages = [
86
  {"role": "system", "content": SYSTEM_PROMPT},
87
  {"role": "user", "content": user_prompt},
88
  ]
89
 
 
90
  try:
91
  completion = client.chat.completions.create(
92
  model=MODEL_NAME,
@@ -94,51 +100,52 @@ def run_task(task_id: int):
94
  temperature=TEMPERATURE,
95
  max_tokens=MAX_TOKENS,
96
  stream=False,
97
- response_format={"type": "json_object"} # enforce json output
98
  )
99
  response_text = completion.choices[0].message.content or ""
100
  action_str = parse_model_action(response_text)
101
  except Exception as exc:
102
- failure_msg = f"Model request failed ({exc}). Using fallback action."
103
- print(failure_msg)
104
  action_str = "SELECT 1;"
105
 
106
- print(f"Step {step}: model suggested -> {action_str[:100]}...")
107
-
108
- # Step the environment
109
- step_result = env.step(BrowserGymAction(action_str=action_str))
110
- observation = step_result.observation
111
- reward = step_result.reward
112
- done = step_result.done
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
 
114
- error_flag = " ERROR" if observation.last_action_error else ""
115
- history_line = f"Step {step}: {action_str[:50]}... -> reward {reward:+.2f}{error_flag}"
116
  history.append(history_line)
117
-
118
- print(f" Reward: {reward:+.2f} | Done: {done} | Last action error: {observation.last_action_error}")
119
 
120
  if done:
121
- final_score = step_result.info.get("current_score", 0.0)
122
- print(f"\nEpisode complete! Final Score: {final_score}/1.0")
123
  break
124
- else:
125
- final_score = env.state().get("current_score", 0.0)
126
- print(f"\nReached max steps ({MAX_STEPS}). Final Score: {final_score}/1.0")
127
-
128
- return final_score
129
 
130
- def main():
131
- print("Testing OpenEnv Data Engineer Inference Baseline")
 
132
 
133
- if not API_KEY:
134
- print("Warning: API_KEY/HF_TOKEN not set. Will likely fail unless local LLM.")
135
 
136
- scores = {}
137
  for task_id in [1, 2, 3]:
138
- score = run_task(task_id)
139
- scores[f"Task_{task_id}"] = score
140
-
141
- print(f"\n{'*'*50}\nEVALUATION COMPLETE\n{scores}\n{'*'*50}")
142
 
143
  if __name__ == "__main__":
144
  main()
 
5
  from server.app import env_instance as env
6
 
7
  # Environment Configuration
8
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
9
+ MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o")
10
+ HF_TOKEN = os.getenv("HF_TOKEN")
11
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
12
 
13
  MAX_STEPS = 15
14
  TEMPERATURE = 0.2
 
48
  return prompt
49
 
50
  def parse_model_action(response_text: str) -> str:
 
51
  text = response_text.strip()
52
  if text.startswith("```json"): text = text[7:]
53
  if text.startswith("```"): text = text[3:]
 
58
  data = json.loads(text)
59
  return data.get("action_str", "SELECT 1;")
60
  except json.JSONDecodeError:
 
61
  return text
62
 
63
+ def run_task(task_id: int):
 
 
64
  client = OpenAI(
65
  base_url=API_BASE_URL,
66
+ api_key=HF_TOKEN
67
  )
68
 
69
  history = []
70
+ rewards = []
71
 
72
+ try:
73
+ result = env.reset(task_id=task_id)
74
+ observation = result.observation
75
+ final_score = result.info.get("initial_score", 0.0)
76
+ except Exception as e:
77
+ print(f"[START] task={task_id} env=sql-data-engineer-env model={MODEL_NAME}")
78
+ print(f"[END] success=false steps=0 score=0.00 rewards=")
79
+ return 0.0
80
+
81
+ print(f"[START] task={task_id} env=sql-data-engineer-env model={MODEL_NAME}")
82
+
83
+ done = False
84
+ step_count = 0
85
 
86
  for step in range(1, MAX_STEPS + 1):
87
+ step_count = step
88
  user_prompt = build_user_prompt(step, observation, history)
89
 
 
 
90
  messages = [
91
  {"role": "system", "content": SYSTEM_PROMPT},
92
  {"role": "user", "content": user_prompt},
93
  ]
94
 
95
+ action_str = ""
96
  try:
97
  completion = client.chat.completions.create(
98
  model=MODEL_NAME,
 
100
  temperature=TEMPERATURE,
101
  max_tokens=MAX_TOKENS,
102
  stream=False,
103
+ response_format={"type": "json_object"}
104
  )
105
  response_text = completion.choices[0].message.content or ""
106
  action_str = parse_model_action(response_text)
107
  except Exception as exc:
 
 
108
  action_str = "SELECT 1;"
109
 
110
+ try:
111
+ step_result = env.step(BrowserGymAction(action_str=action_str))
112
+ observation = step_result.observation
113
+ reward = step_result.reward
114
+ done = step_result.done
115
+ final_score = step_result.info.get("current_score", 0.0)
116
+
117
+ if observation.last_action_error:
118
+ error_msg = observation.result.replace('\n', ' ')
119
+ else:
120
+ error_msg = "null"
121
+ except Exception as e:
122
+ reward = 0.0
123
+ done = True
124
+ error_msg = str(e).replace('\n', ' ')
125
+
126
+ rewards.append(f"{reward:.2f}")
127
+
128
+ done_str = "true" if done else "false"
129
+ safe_action = action_str.replace('\n', ' ')
130
+ err_out = f'"{error_msg}"' if error_msg != "null" else "null"
131
+
132
+ print(f"[STEP] step={step} action=\"{safe_action}\" reward={reward:.2f} done={done_str} error={err_out}")
133
 
134
+ history_line = f"Step {step}: {safe_action[:50]}... -> reward {reward:+.2f}"
 
135
  history.append(history_line)
 
 
136
 
137
  if done:
 
 
138
  break
 
 
 
 
 
139
 
140
+ success_str = "true" if final_score >= 1.0 else "false"
141
+ rewards_str = ",".join(rewards)
142
+ print(f"[END] success={success_str} steps={step_count} score={final_score:.2f} rewards={rewards_str}")
143
 
144
+ return final_score
 
145
 
146
+ def main():
147
  for task_id in [1, 2, 3]:
148
+ run_task(task_id)
 
 
 
149
 
150
  if __name__ == "__main__":
151
  main()