kaikaidai committed on
Commit
c7a9dfe
1 Parent(s): 65ba9f3

Organise prompts

Browse files
Files changed (1) hide show
  1. gen_api_answer.py +46 -16
gen_api_answer.py CHANGED
@@ -6,6 +6,11 @@ import json
6
  import re
7
  import os
8
  import requests
 
 
 
 
 
9
 
10
  # Initialize clients
11
  anthropic_client = anthropic.Anthropic()
@@ -18,10 +23,6 @@ huggingface_client = OpenAI(
18
  api_key=hf_api_key
19
  )
20
 
21
- JUDGE_SYSTEM_PROMPT = """Please act as an impartial judge and evaluate based on the user's instruction. Your output format should strictly adhere to JSON as follows: {"feedback": "<write feedback>", "result": <numerical score>}. Ensure the output is valid JSON, without additional formatting or explanations."""
22
-
23
- ALTERNATIVE_JUDGE_SYSTEM_PROMPT = """Please act as an impartial judge and evaluate based on the user's instruction."""
24
-
25
  def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
26
  """Get response from OpenAI API"""
27
  try:
@@ -119,8 +120,8 @@ def get_cohere_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, m
119
  def get_model_response(
120
  model_name,
121
  model_info,
122
- prompt,
123
- use_alternative_prompt=False,
124
  max_tokens=500,
125
  temperature=0
126
  ):
@@ -131,33 +132,62 @@ def get_model_response(
131
  api_model = model_info["api_model"]
132
  organization = model_info["organization"]
133
 
134
- # Select the appropriate system prompt
135
- if use_alternative_prompt:
136
- system_prompt = ALTERNATIVE_JUDGE_SYSTEM_PROMPT
 
 
 
 
 
 
137
  else:
138
- system_prompt = JUDGE_SYSTEM_PROMPT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  try:
141
  if organization == "OpenAI":
142
  return get_openai_response(
143
- api_model, prompt, system_prompt, max_tokens, temperature
144
  )
145
  elif organization == "Anthropic":
146
  return get_anthropic_response(
147
- api_model, prompt, system_prompt, max_tokens, temperature
148
  )
149
  elif organization == "Prometheus":
150
  return get_hf_response(
151
- api_model, prompt, max_tokens
152
  )
153
  elif organization == "Cohere":
154
  return get_cohere_response(
155
- api_model, prompt, system_prompt, max_tokens, temperature
156
  )
157
  else:
158
  # All other organizations use Together API
159
  return get_together_response(
160
- api_model, prompt, system_prompt, max_tokens, temperature
161
  )
162
  except Exception as e:
163
  return f"Error with {organization} model {model_name}: {str(e)}"
@@ -185,7 +215,7 @@ def parse_model_response(response):
185
  print(f"Failed to parse response: {str(e)}")
186
  return "Error", f"Failed to parse response: {response}"
187
 
188
- def alternative_parse_model_response(output):
189
  try:
190
  print(f"Raw model response: {output}")
191
  output = output.strip()
 
6
  import re
7
  import os
8
  import requests
9
+ from prompts import (
10
+ JUDGE_SYSTEM_PROMPT,
11
+ PROMETHEUS_PROMPT,
12
+ PROMETHEUS_PROMPT_WITH_REFERENCE,
13
+ )
14
 
15
  # Initialize clients
16
  anthropic_client = anthropic.Anthropic()
 
23
  api_key=hf_api_key
24
  )
25
 
 
 
 
 
26
  def get_openai_response(model_name, prompt, system_prompt=JUDGE_SYSTEM_PROMPT, max_tokens=500, temperature=0):
27
  """Get response from OpenAI API"""
28
  try:
 
120
  def get_model_response(
121
  model_name,
122
  model_info,
123
+ prompt_data,
124
+ use_reference=False,
125
  max_tokens=500,
126
  temperature=0
127
  ):
 
132
  api_model = model_info["api_model"]
133
  organization = model_info["organization"]
134
 
135
+ # Determine if model is Prometheus
136
+ is_prometheus = (organization == "Prometheus")
137
+
138
+ # For non-Prometheus models, use the Judge system prompt
139
+ system_prompt = None if is_prometheus else JUDGE_SYSTEM_PROMPT
140
+
141
+ # Select the appropriate base prompt
142
+ if use_reference:
143
+ base_prompt = PROMETHEUS_PROMPT_WITH_REFERENCE
144
  else:
145
+ base_prompt = PROMETHEUS_PROMPT
146
+
147
+ # For non-Prometheus models, replace the specific instruction
148
+ if not is_prometheus:
149
+ base_prompt = base_prompt.replace(
150
+ '3. The output format should look as follows: "Feedback: (write a feedback for criteria) [RESULT] (an integer number between 1 and 5)"',
151
+ '3. Your output format should strictly adhere to JSON as follows: {{"feedback": "<write feedback>", "result": <numerical score>}}. Ensure the output is valid JSON, without additional formatting or explanations.'
152
+ )
153
+
154
+ try:
155
+ # Format the prompt with the provided data, only using available keys
156
+ final_prompt = base_prompt.format(
157
+ human_input=prompt_data['human_input'],
158
+ ai_response=prompt_data['ai_response'],
159
+ ground_truth_input=prompt_data.get('ground_truth_input', ''),
160
+ eval_criteria=prompt_data['eval_criteria'],
161
+ score1_desc=prompt_data['score1_desc'],
162
+ score2_desc=prompt_data['score2_desc'],
163
+ score3_desc=prompt_data['score3_desc'],
164
+ score4_desc=prompt_data['score4_desc'],
165
+ score5_desc=prompt_data['score5_desc']
166
+ )
167
+ except KeyError as e:
168
+ return f"Error formatting prompt: Missing required field {str(e)}"
169
 
170
  try:
171
  if organization == "OpenAI":
172
  return get_openai_response(
173
+ api_model, final_prompt, system_prompt, max_tokens, temperature
174
  )
175
  elif organization == "Anthropic":
176
  return get_anthropic_response(
177
+ api_model, final_prompt, system_prompt, max_tokens, temperature
178
  )
179
  elif organization == "Prometheus":
180
  return get_hf_response(
181
+ api_model, final_prompt, max_tokens
182
  )
183
  elif organization == "Cohere":
184
  return get_cohere_response(
185
+ api_model, final_prompt, system_prompt, max_tokens, temperature
186
  )
187
  else:
188
  # All other organizations use Together API
189
  return get_together_response(
190
+ api_model, final_prompt, system_prompt, max_tokens, temperature
191
  )
192
  except Exception as e:
193
  return f"Error with {organization} model {model_name}: {str(e)}"
 
215
  print(f"Failed to parse response: {str(e)}")
216
  return "Error", f"Failed to parse response: {response}"
217
 
218
+ def prometheus_parse_model_response(output):
219
  try:
220
  print(f"Raw model response: {output}")
221
  output = output.strip()