eeshwar143 committed on
Commit
deed111
·
1 Parent(s): 522cbe5

Support both validator proxy key env vars

Browse files
Files changed (2) hide show
  1. .env.example +1 -0
  2. inference.py +39 -5
.env.example CHANGED
@@ -1,5 +1,6 @@
1
  API_BASE_URL=https://api.openai.com/v1
2
  MODEL_NAME=gpt-4o-mini
 
3
  HF_TOKEN=
4
  LOCAL_IMAGE_NAME=
5
  ENV_BASE_URL=http://127.0.0.1:8000
 
1
  API_BASE_URL=https://api.openai.com/v1
2
  MODEL_NAME=gpt-4o-mini
3
+ API_KEY=
4
  HF_TOKEN=
5
  LOCAL_IMAGE_NAME=
6
  ENV_BASE_URL=http://127.0.0.1:8000
inference.py CHANGED
@@ -41,6 +41,7 @@ API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
41
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
42
  API_KEY = os.getenv("API_KEY")
43
  HF_TOKEN = os.getenv("HF_TOKEN")
 
44
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
45
  ENV_BASE_URL = os.getenv("ENV_BASE_URL")
46
 
@@ -69,19 +70,51 @@ def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> No
69
 
70
 
71
  def create_openai_client() -> Any:
72
- # The validator checks that model traffic goes through its injected proxy key.
73
- # Keep HF_TOKEN defined for environment compatibility, but do not use it here.
74
- if not API_KEY:
75
  return None
76
 
77
  if OpenAI is not None:
78
- return OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
79
 
80
  openai_module.api_base = API_BASE_URL
81
- openai_module.api_key = API_KEY
82
  return openai_module
83
 
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  def get_model_message(
86
  client: Any,
87
  step: int,
@@ -328,6 +361,7 @@ async def main() -> None:
328
  env: SupportQueueEnv | None = None
329
 
330
  try:
 
331
  env = await build_env()
332
  for task in tasks:
333
  results.append(await run_task(client, env, task))
 
41
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
42
  API_KEY = os.getenv("API_KEY")
43
  HF_TOKEN = os.getenv("HF_TOKEN")
44
+ PROXY_API_KEY = API_KEY or HF_TOKEN
45
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
46
  ENV_BASE_URL = os.getenv("ENV_BASE_URL")
47
 
 
70
 
71
 
72
def create_openai_client() -> Any:
    """Build a chat client routed through the proxy base URL, or None if no key is set.

    Supports both the newer API_KEY contract and the earlier HF_TOKEN contract
    (whichever produced PROXY_API_KEY); traffic always goes via API_BASE_URL.
    """
    key = PROXY_API_KEY
    if not key:
        return None

    # Modern SDK present: hand back a dedicated client instance.
    if OpenAI is not None:
        return OpenAI(base_url=API_BASE_URL, api_key=key)

    # Legacy SDK: configure the module-level settings and return the module itself.
    openai_module.api_base = API_BASE_URL
    openai_module.api_key = key
    return openai_module
84
 
85
 
86
def warmup_model_client(client: Any) -> None:
    """Send one tiny best-effort chat completion so later calls start warm.

    A missing client (no API_KEY/HF_TOKEN configured) is logged and skipped.
    Any failure during the warmup call itself is swallowed after a debug print —
    warmup must never block the actual run.
    """
    if client is None:
        print("[DEBUG] No API_KEY/HF_TOKEN found; skipping model warmup.", flush=True)
        return

    try:
        # Identical request for either SDK flavor; built inside the try so that
        # any lookup error is handled the same way as a failed API call.
        request = dict(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": "Reply with ok."},
                {"role": "user", "content": "ok"},
            ],
            temperature=0.0,
            max_tokens=2,
            stream=False,
        )
        # New-style SDK exposes client.chat.completions; the legacy module
        # exposes ChatCompletion instead.
        if hasattr(client, "chat") and hasattr(client.chat, "completions"):
            client.chat.completions.create(**request)
        else:
            client.ChatCompletion.create(**request)
    except Exception as exc:
        print(f"[DEBUG] Model warmup failed: {exc}", flush=True)
116
+
117
+
118
  def get_model_message(
119
  client: Any,
120
  step: int,
 
361
  env: SupportQueueEnv | None = None
362
 
363
  try:
364
+ warmup_model_client(client)
365
  env = await build_env()
366
  for task in tasks:
367
  results.append(await run_task(client, env, task))