Added Salesforce endpoint

gen_api_answer.py  +51 -11  CHANGED
@@ -23,7 +23,7 @@ together_client = Together()
 hf_api_key = os.getenv("HF_API_KEY")
 flow_judge_api_key = os.getenv("FLOW_JUDGE_API_KEY")
 cohere_client = cohere.ClientV2(os.getenv("CO_API_KEY"))
-
+salesforce_api_key = os.getenv("SALESFORCE_API_KEY")
 def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
     """Get response from OpenAI API"""
     try:
@@ -195,6 +195,36 @@ def get_cohere_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, m
     except Exception as e:
         return f"Error with Cohere model {model_name}: {str(e)}"
 
+def get_salesforce_response(model_name, prompt, system_prompt=None, max_tokens=2048, temperature=0):
+    """Get response from Salesforce Research API"""
+    try:
+        headers = {
+            'accept': 'application/json',
+            "content-type": "application/json",
+            "X-Api-Key": salesforce_api_key,
+        }
+
+        # Create messages list
+        messages = []
+        messages.append({"role": "user", "content": prompt})
+
+        json_data = {
+            "prompts": messages,
+            "temperature": temperature,
+            "top_p": 1,
+            "max_tokens": max_tokens,
+        }
+
+        response = requests.post(
+            'https://gateway.salesforceresearch.ai/sfr-judge/process',
+            headers=headers,
+            json=json_data
+        )
+        response.raise_for_status()
+        return response.json()['result'][0]
+    except Exception as e:
+        return f"Error with Salesforce model {model_name}: {str(e)}"
+
 def get_model_response(
     model_name,
     model_info,
@@ -210,24 +240,25 @@ def get_model_response(
     api_model = model_info["api_model"]
     organization = model_info["organization"]
 
-    # Determine if model is Prometheus
+    # Determine if model is Prometheus, Atla, Flow Judge, or Salesforce
     is_prometheus = (organization == "Prometheus")
     is_atla = (organization == "Atla")
     is_flow_judge = (organization == "Flow AI")
-
-
+    is_salesforce = (organization == "Salesforce")
+
+    # For non-Prometheus/Atla/Flow Judge/Salesforce models, use the Judge system prompt
+    system_prompt = None if (is_prometheus or is_atla or is_flow_judge or is_salesforce) else JUDGE_SYSTEM_PROMPT
 
     # Select the appropriate base prompt
-
-    if is_atla:
+    if is_atla or is_salesforce:  # Use same prompt for Atla and Salesforce
         base_prompt = ATLA_PROMPT_WITH_REFERENCE if use_reference else ATLA_PROMPT
     elif is_flow_judge:
         base_prompt = FLOW_JUDGE_PROMPT
     else:
         base_prompt = PROMETHEUS_PROMPT_WITH_REFERENCE if use_reference else PROMETHEUS_PROMPT
 
-    # For non-Prometheus/non-Atla models, replace the
-    if not (is_prometheus or is_atla or is_flow_judge):
+    # For non-Prometheus/non-Atla/non-Salesforce models, use Prometheus but replace the output format with JSON
+    if not (is_prometheus or is_atla or is_flow_judge or is_salesforce):
         base_prompt = base_prompt.replace(
             '3. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)"',
             '3. Your output format should strictly adhere to JSON as follows: {{"feedback": "<write feedback>", "result": <numerical score>}}. Ensure the output is valid JSON, without additional formatting or explanations.'
@@ -247,7 +278,6 @@ def get_model_response(
             score4_desc=prompt_data['score4_desc'],
             score5_desc=prompt_data['score5_desc']
         )
-
     else:
         human_input = f"<user_input>\n{prompt_data['human_input']}\n</user_input>"
         ai_response = f"<response>\n{prompt_data['ai_response']}\n</response>"
@@ -300,8 +330,13 @@ def get_model_response(
         )
     elif organization == "Flow AI":
        return get_flow_judge_response(
-            api_model, final_prompt
+            api_model, final_prompt
        )
+    elif organization == "Salesforce":
+        response = get_salesforce_response(
+            api_model, final_prompt, system_prompt, max_tokens, temperature
+        )
+        return response
    else:
        # All other organizations use Together API
        return get_together_response(
@@ -324,7 +359,12 @@ def parse_model_response(response):
         data = json.loads(response)
         return str(data.get("result", "N/A")), data.get("feedback", "N/A")
     except json.JSONDecodeError:
-        # If that fails
+        # If that fails, check if this is a Salesforce response (which uses ATLA format)
+        if "**Reasoning:**" in response or "**Result:**" in response:
+            # Use ATLA parser for Salesforce responses
+            return atla_parse_model_response(response)
+
+        # Otherwise try to find JSON within the response
         json_match = re.search(r"{.*}", response, re.DOTALL)
         if json_match:
             data = json.loads(json_match.group(0))
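For reference, here is roughly the raw request that the new get_salesforce_response helper wraps, reconstructed only from what the diff shows (the gateway URL, the headers, the payload keys, and the ['result'][0] indexing). The response schema beyond that first result element, and the timeout, are assumptions.

import os
import requests

# Minimal sketch of the gateway call made by get_salesforce_response.
# Assumes SALESFORCE_API_KEY is set and that the gateway returns JSON
# shaped like {"result": ["<judge output>", ...]}, as implied by
# response.json()['result'][0] in the diff.
api_key = os.getenv("SALESFORCE_API_KEY")

headers = {
    "accept": "application/json",
    "content-type": "application/json",
    "X-Api-Key": api_key,
}

payload = {
    "prompts": [{"role": "user", "content": "Evaluate the response below ..."}],
    "temperature": 0,
    "top_p": 1,
    "max_tokens": 2048,
}

resp = requests.post(
    "https://gateway.salesforceresearch.ai/sfr-judge/process",
    headers=headers,
    json=payload,
    timeout=60,  # not in the diff; added so the sketch fails fast
)
resp.raise_for_status()
print(resp.json()["result"][0])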
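The organization checks added near the top of get_model_response decide both which prompt template is used and whether a system prompt is sent at all. Below is a self-contained sketch of just that branching; the prompt constants are placeholders (the real templates live elsewhere in the repo and are not part of this diff), and the reference-based prompt variants are omitted for brevity.

# Placeholders; the real values are defined elsewhere in the repo.
JUDGE_SYSTEM_PROMPT = "You are a judge ..."
ATLA_PROMPT = "<atla template>"
FLOW_JUDGE_PROMPT = "<flow judge template>"
PROMETHEUS_PROMPT = "<prometheus template>"

def pick_prompt_and_system(organization):
    """Mirrors the organization branching this commit extends."""
    is_prometheus = organization == "Prometheus"
    is_atla = organization == "Atla"
    is_flow_judge = organization == "Flow AI"
    is_salesforce = organization == "Salesforce"

    # Dedicated judge models get no extra system prompt; generic chat models do.
    system_prompt = None if (is_prometheus or is_atla or is_flow_judge or is_salesforce) else JUDGE_SYSTEM_PROMPT

    if is_atla or is_salesforce:  # Salesforce reuses the Atla template
        base_prompt = ATLA_PROMPT
    elif is_flow_judge:
        base_prompt = FLOW_JUDGE_PROMPT
    else:
        base_prompt = PROMETHEUS_PROMPT
    return base_prompt, system_prompt

print(pick_prompt_and_system("Salesforce"))  # ('<atla template>', None)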
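The new JSONDecodeError fallback in parse_model_response assumes Salesforce judges answer in the same Markdown layout Atla models use ("**Reasoning:** ... **Result:** <score>") rather than JSON. atla_parse_model_response itself is not part of this diff, so the stand-in below only guesses at its behavior; the routing order mirrors the added lines, and the final fallback return is likewise a guess, since the diff does not show that branch.

import json
import re

def atla_parse_model_response_sketch(text):
    """Hypothetical stand-in for atla_parse_model_response (not shown in
    this diff); assumes '**Reasoning:** ... **Result:** <score>'."""
    reasoning = re.search(r"\*\*Reasoning:\*\*\s*(.*?)\s*\*\*Result:\*\*", text, re.DOTALL)
    result = re.search(r"\*\*Result:\*\*\s*(\d+)", text)
    return (result.group(1) if result else "N/A",
            reasoning.group(1).strip() if reasoning else "N/A")

def parse_sketch(response):
    """Mirrors the fallback order added to parse_model_response."""
    try:
        data = json.loads(response)
        return str(data.get("result", "N/A")), data.get("feedback", "N/A")
    except json.JSONDecodeError:
        # Salesforce (and Atla) judges reply in Markdown, not JSON.
        if "**Reasoning:**" in response or "**Result:**" in response:
            return atla_parse_model_response_sketch(response)
        # Otherwise look for a JSON object embedded in the text.
        match = re.search(r"{.*}", response, re.DOTALL)
        if match:
            data = json.loads(match.group(0))
            return str(data.get("result", "N/A")), data.get("feedback", "N/A")
        return "N/A", response  # assumed last resort; not shown in the diff

print(parse_sketch("**Reasoning:** Clear and accurate. **Result:** 5"))
# -> ('5', 'Clear and accurate.')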