hiitsesh committed on
Commit
0287ccf
·
1 Parent(s): 36ac8be

fix: refactor OpenAI client initialization and update API request handling

Browse files
Files changed (1) hide show
  1. inference.py +14 -21
inference.py CHANGED
@@ -2,26 +2,13 @@ import os
2
  import json
3
  import re
4
  import requests
5
- from openai import OpenAI
6
 
7
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
8
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
9
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
10
 
11
  ENV_BASE_URL = "http://localhost:7860"
12
 
13
- # Initialize OpenAI client
14
- client = None
15
- if API_BASE_URL and API_KEY:
16
- client = OpenAI(
17
- base_url=API_BASE_URL,
18
- api_key=API_KEY
19
- )
20
- elif API_KEY:
21
- client = OpenAI(api_key=API_KEY)
22
- else:
23
- client = OpenAI(api_key="dummy_key")
24
-
25
  SYSTEM_PROMPT = """You are an elite AI agent controlling an industrial reverse-osmosis desalination plant.
26
  Your objective: Manage the trade-offs of fresh water production against energy costs and membrane degradation, while ensuring water_salinity NEVER exceeds 450 PPM and reservoir NEVER dries out.
27
  IMPORTANT: You MUST respond ONLY with valid JSON holding exactly two keys: "production_rate" (float 0.0 to 50.0) and "run_cleaning" (boolean).
@@ -134,16 +121,22 @@ def evaluate_baseline(task_id):
134
 
135
  error_msg = "null"
136
  try:
137
- response = client.chat.completions.create(
138
- model=MODEL_NAME,
139
- messages=[
 
 
 
 
140
  {"role": "system", "content": SYSTEM_PROMPT},
141
  {"role": "user", "content": prompt}
142
  ],
143
- temperature=0.0,
144
- max_tokens=150
145
- )
146
- llm_content = response.choices[0].message.content
 
 
147
  action = parse_action(llm_content)
148
  except Exception as e:
149
  error_msg = f"'{str(e)}'"
 
2
  import json
3
  import re
4
  import requests
 
5
 
6
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
7
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or "dummy_key"
8
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
9
 
10
  ENV_BASE_URL = "http://localhost:7860"
11
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  SYSTEM_PROMPT = """You are an elite AI agent controlling an industrial reverse-osmosis desalination plant.
13
  Your objective: Manage the trade-offs of fresh water production against energy costs and membrane degradation, while ensuring water_salinity NEVER exceeds 450 PPM and reservoir NEVER dries out.
14
  IMPORTANT: You MUST respond ONLY with valid JSON holding exactly two keys: "production_rate" (float 0.0 to 50.0) and "run_cleaning" (boolean).
 
121
 
122
  error_msg = "null"
123
  try:
124
+ headers = {
125
+ "Authorization": f"Bearer {API_KEY}",
126
+ "Content-Type": "application/json"
127
+ }
128
+ payload = {
129
+ "model": MODEL_NAME,
130
+ "messages": [
131
  {"role": "system", "content": SYSTEM_PROMPT},
132
  {"role": "user", "content": prompt}
133
  ],
134
+ "temperature": 0.0,
135
+ "max_tokens": 150
136
+ }
137
+ response = requests.post(f"{API_BASE_URL.rstrip('/')}/chat/completions", headers=headers, json=payload, timeout=30)
138
+ response.raise_for_status()
139
+ llm_content = response.json()["choices"][0]["message"]["content"]
140
  action = parse_action(llm_content)
141
  except Exception as e:
142
  error_msg = f"'{str(e)}'"