Spaces:
Sleeping
Sleeping
arjeet commited on
Commit ·
d246804
1
Parent(s): a8dd45c
inference update v4
Browse files- inference.py +3 -5
inference.py
CHANGED
|
@@ -8,7 +8,6 @@ from models import DocAction
|
|
| 8 |
IMAGE_NAME = os.getenv("IMAGE_NAME")
|
| 9 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 10 |
|
| 11 |
-
# Swapped back to OpenAI defaults
|
| 12 |
API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
|
| 13 |
MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
|
| 14 |
BENCHMARK_NAME = "doc_sweeper"
|
|
@@ -25,12 +24,11 @@ def run_inference(task_name: str):
|
|
| 25 |
if not hf_token:
|
| 26 |
raise ValueError("Missing hf_token")
|
| 27 |
|
| 28 |
-
# Replaced Groq with OpenAI, keeping the timeout fixes!
|
| 29 |
client = OpenAI(
|
| 30 |
api_key=hf_token,
|
| 31 |
base_url=api_base_url,
|
| 32 |
-
timeout=15.0,
|
| 33 |
-
max_retries=1
|
| 34 |
)
|
| 35 |
|
| 36 |
env = DocSweeperEnvironment(task=task_name)
|
|
@@ -40,7 +38,7 @@ def run_inference(task_name: str):
|
|
| 40 |
total_reward = 0.0
|
| 41 |
step_count = 0
|
| 42 |
rewards_history = []
|
| 43 |
-
MAX_STEPS = 20
|
| 44 |
|
| 45 |
print(f"[START] task={task_name} env={BENCHMARK_NAME} model={model_name}", flush=True)
|
| 46 |
|
|
|
|
| 8 |
IMAGE_NAME = os.getenv("IMAGE_NAME")
|
| 9 |
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 10 |
|
|
|
|
| 11 |
API_BASE_URL = os.getenv("API_BASE_URL") or "https://api.openai.com/v1"
|
| 12 |
MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
|
| 13 |
BENCHMARK_NAME = "doc_sweeper"
|
|
|
|
| 24 |
if not hf_token:
|
| 25 |
raise ValueError("Missing hf_token")
|
| 26 |
|
|
|
|
| 27 |
client = OpenAI(
|
| 28 |
api_key=hf_token,
|
| 29 |
base_url=api_base_url,
|
| 30 |
+
timeout=15.0,
|
| 31 |
+
max_retries=1
|
| 32 |
)
|
| 33 |
|
| 34 |
env = DocSweeperEnvironment(task=task_name)
|
|
|
|
| 38 |
total_reward = 0.0
|
| 39 |
step_count = 0
|
| 40 |
rewards_history = []
|
| 41 |
+
MAX_STEPS = 20
|
| 42 |
|
| 43 |
print(f"[START] task={task_name} env={BENCHMARK_NAME} model={model_name}", flush=True)
|
| 44 |
|