arjeet committed on
Commit
2f02c40
·
1 Parent(s): cebc6e3

inference update v3

Browse files
Files changed (1) hide show
  1. inference.py +21 -14
inference.py CHANGED
@@ -5,21 +5,20 @@ from openai import OpenAI
5
  from server.cust_env_environment import DocSweeperEnvironment
6
  from models import DocAction
7
 
8
-
9
  IMAGE_NAME = os.getenv("IMAGE_NAME")
10
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
11
 
12
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
13
  MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
 
14
 
15
  def run_inference(task_name: str):
16
-
17
  api_base_url = os.environ.get("API_BASE_URL") or API_BASE_URL
18
  model_name = os.environ.get("MODEL_NAME") or MODEL_NAME
19
  hf_token = os.environ.get("HF_TOKEN") or API_KEY
20
 
21
  if not api_base_url:
22
- raise ValueError("Missinh api base url")
23
  if not model_name:
24
  raise ValueError("Missing model name")
25
  if not hf_token:
@@ -32,12 +31,13 @@ def run_inference(task_name: str):
32
 
33
  env = DocSweeperEnvironment(task=task_name)
34
  obs = env.reset()
 
35
  done = False
36
  total_reward = 0.0
37
  step_count = 0
38
-
39
 
40
- print(f"[START] task={task_name} model={model_name}")
41
 
42
  system_prompt = f"""
43
  You are an elite, systematic documentation engineer. You interact with a virtual file system via JSON tool calls.
@@ -102,22 +102,29 @@ def run_inference(task_name: str):
102
 
103
  action = DocAction(**safe_kwargs)
104
  obs = env.step(action)
 
105
  total_reward += obs.reward
 
106
  done = obs.done
107
 
108
-
109
- print(f"[STEP] step={step_count} action={action.tool_name} reward={obs.reward:.2f} done={done} thought=\"{thought[:100]}...\"")
 
110
 
111
  except Exception as e:
112
- obs.terminal_feedback = f"SYSTEM ERROR: {str(e)}. Review the schema rules."
113
- print(f"[STEP] step={step_count} action=error reward=0.0 done={done} error=\"{str(e)}\"")
114
-
115
- runtime = time.time() - start_time
116
-
 
117
 
118
  final_score = max(0.0, min(1.0, total_reward))
119
-
120
- print(f"[END] task={task_name} score={final_score:.2f} total_steps={step_count} runtime_seconds={runtime:.1f}")
 
 
 
121
 
122
 
123
  if __name__ == "__main__":
 
5
  from server.cust_env_environment import DocSweeperEnvironment
6
  from models import DocAction
7
 
 
8
  IMAGE_NAME = os.getenv("IMAGE_NAME")
9
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
10
 
11
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
12
  MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
13
+ BENCHMARK_NAME = "doc_sweeper"
14
 
15
  def run_inference(task_name: str):
 
16
  api_base_url = os.environ.get("API_BASE_URL") or API_BASE_URL
17
  model_name = os.environ.get("MODEL_NAME") or MODEL_NAME
18
  hf_token = os.environ.get("HF_TOKEN") or API_KEY
19
 
20
  if not api_base_url:
21
+ raise ValueError("Missing API base url")
22
  if not model_name:
23
  raise ValueError("Missing model name")
24
  if not hf_token:
 
31
 
32
  env = DocSweeperEnvironment(task=task_name)
33
  obs = env.reset()
34
+
35
  done = False
36
  total_reward = 0.0
37
  step_count = 0
38
+ rewards_history = []
39
 
40
+ print(f"[START] task={task_name} env={BENCHMARK_NAME} model={model_name}", flush=True)
41
 
42
  system_prompt = f"""
43
  You are an elite, systematic documentation engineer. You interact with a virtual file system via JSON tool calls.
 
102
 
103
  action = DocAction(**safe_kwargs)
104
  obs = env.step(action)
105
+
106
  total_reward += obs.reward
107
+ rewards_history.append(obs.reward)
108
  done = obs.done
109
 
110
+ action_str = f"{action.tool_name}"
111
+ done_str = str(done).lower()
112
+ print(f"[STEP] step={step_count} action={action_str} reward={obs.reward:.2f} done={done_str} error=null", flush=True)
113
 
114
  except Exception as e:
115
+ error_msg = str(e).replace('\n', ' ')
116
+ obs.terminal_feedback = f"SYSTEM ERROR: {error_msg}. Review the schema rules."
117
+ rewards_history.append(0.0)
118
+
119
+ done_str = str(done).lower()
120
+ print(f"[STEP] step={step_count} action=error reward=0.00 done={done_str} error=\"{error_msg}\"", flush=True)
121
 
122
  final_score = max(0.0, min(1.0, total_reward))
123
+ success = final_score > 0.0 # Define what success means for your environment
124
+ success_str = str(success).lower()
125
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards_history)
126
+
127
+ print(f"[END] success={success_str} steps={step_count} score={final_score:.2f} rewards={rewards_str}", flush=True)
128
 
129
 
130
  if __name__ == "__main__":