arjeet commited on
Commit
d246804
·
1 Parent(s): a8dd45c

inference update v4

Browse files
Files changed (1) hide show
  1. inference.py +3 -5
inference.py CHANGED
@@ -8,7 +8,6 @@ from models import DocAction
8
  IMAGE_NAME = os.getenv("IMAGE_NAME")
9
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
10
 
11
- # Swapped back to OpenAI defaults
12
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
13
  MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
14
  BENCHMARK_NAME = "doc_sweeper"
@@ -25,12 +24,11 @@ def run_inference(task_name: str):
25
  if not hf_token:
26
  raise ValueError("Missing hf_token")
27
 
28
- # Replaced Groq with OpenAI, keeping the timeout fixes!
29
  client = OpenAI(
30
  api_key=hf_token,
31
  base_url=api_base_url,
32
- timeout=15.0, # Max 15 seconds per request
33
- max_retries=1 # Do not get stuck in infinite backoff loops
34
  )
35
 
36
  env = DocSweeperEnvironment(task=task_name)
@@ -40,7 +38,7 @@ def run_inference(task_name: str):
40
  total_reward = 0.0
41
  step_count = 0
42
  rewards_history = []
43
- MAX_STEPS = 20 # Hard step limit failsafe
44
 
45
  print(f"[START] task={task_name} env={BENCHMARK_NAME} model={model_name}", flush=True)
46
 
 
8
  IMAGE_NAME = os.getenv("IMAGE_NAME")
9
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
10
 
 
11
  API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
12
  MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
13
  BENCHMARK_NAME = "doc_sweeper"
 
24
  if not hf_token:
25
  raise ValueError("Missing hf_token")
26
 
 
27
  client = OpenAI(
28
  api_key=hf_token,
29
  base_url=api_base_url,
30
+ timeout=15.0,
31
+ max_retries=1
32
  )
33
 
34
  env = DocSweeperEnvironment(task=task_name)
 
38
  total_reward = 0.0
39
  step_count = 0
40
  rewards_history = []
41
+ MAX_STEPS = 20
42
 
43
  print(f"[START] task={task_name} env={BENCHMARK_NAME} model={model_name}", flush=True)
44