Miaoran000 committed on
Commit
2aa9a75
1 Parent(s): 6632750
src/backend/model_operations.py CHANGED
@@ -19,7 +19,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 import torch
 import cohere
 from openai import OpenAI
-
+import google.generativeai as genai
 
 import src.backend.util as util
 import src.envs as envs
@@ -131,6 +131,10 @@ class SummaryGenerator:
                     wait_time = 200
                     print(f"Model is loading, wait for {wait_time}")
                     time.sleep(wait_time)
+                elif '429 Resource has been exhausted' in str(e): # for gemini models
+                    wait_time = 60
+                    print(f"Quota has reached, wait for {wait_time}")
+                    time.sleep(wait_time)
                 else:
                     print(f"Error at index {index}: {e}")
                     _summary = ""
@@ -161,8 +165,16 @@ class SummaryGenerator:
 
     def generate_summary(self, system_prompt: str, user_prompt: str):
         # Using Together AI API
-        if 'mixtral' in self.model_id.lower() or 'dbrx' in self.model_id.lower() or 'wizardlm' in self.model_id.lower(): # For mixtral and dbrx models, use Together AI API
-            suffix = "completions" if ('mixtral' in self.model_id.lower() or 'base' in self.model_id.lower()) else "chat/completions"
+        using_together_api = False
+        together_ai_api_models = ['mixtral', 'dbrx', 'wizardlm', 'llama-3']
+        for together_ai_api_model in together_ai_api_models:
+            if together_ai_api_model in self.model_id.lower():
+                using_together_api = True
+                break
+        # if 'mixtral' in self.model_id.lower() or 'dbrx' in self.model_id.lower() or 'wizardlm' in self.model_id.lower(): # For mixtral and dbrx models, use Together AI API
+        if using_together_api:
+            # suffix = "completions" if ('mixtral' in self.model_id.lower() or 'base' in self.model_id.lower()) else "chat/completions"
+            suffix = "chat/completions"
             url = f"https://api.together.xyz/v1/{suffix}"
 
             payload = {
@@ -170,15 +182,17 @@ class SummaryGenerator:
                 # "max_tokens": 4096,
                 'max_new_tokens': 250,
                 "temperature": 0.0,
-                'repetition_penalty': 1.1 if 'mixtral' in self.model_id.lower() else 1
+                # 'repetition_penalty': 1.1 if 'mixtral' in self.model_id.lower() else 1
             }
-            if 'mixtral' in self.model_id.lower():
-                # payload['prompt'] = user_prompt
-                # payload['prompt'] = "Write a summary of the following passage:\nPassage:\n" + user_prompt.split('Passage:\n')[-1] + '\n\nSummary:'
-                payload['prompt'] = 'You must stick to the passage provided. Provide a concise summary of the following passage, covering the core pieces of information described:\nPassage:\n' + user_prompt.split('Passage:\n')[-1] + '\n\nSummary:'
-                print(payload)
-            else:
-                payload['messages'] = [{"role": "system", "content": system_prompt},
+            # if 'mixtral' in self.model_id.lower():
+            #     # payload['prompt'] = user_prompt
+            #     # payload['prompt'] = "Write a summary of the following passage:\nPassage:\n" + user_prompt.split('Passage:\n')[-1] + '\n\nSummary:'
+            #     payload['prompt'] = 'You must stick to the passage provided. Provide a concise summary of the following passage, covering the core pieces of information described:\nPassage:\n' + user_prompt.split('Passage:\n')[-1] + '\n\nSummary:'
+            #     print(payload)
+            # else:
+            #     payload['messages'] = [{"role": "system", "content": system_prompt},
+            #                            {"role": "user", "content": user_prompt}]
+            payload['messages'] = [{"role": "system", "content": system_prompt},
                                    {"role": "user", "content": user_prompt}]
             headers = {
                 "accept": "application/json",
@@ -216,8 +230,47 @@ class SummaryGenerator:
             print(result)
             return result
 
+        # Using Google AI API for Gemini models
+        elif 'gemini' in self.model_id.lower():
+            genai.configure(api_key=os.getenv('GOOGLE_AI_API_KEY'))
+            generation_config = {
+                "temperature": 0,
+                "top_p": 0.95, # cannot change
+                "top_k": 0,
+                "max_output_tokens": 250,
+                # "response_mime_type": "application/json",
+            }
+            safety_settings = [
+                {
+                    "category": "HARM_CATEGORY_HARASSMENT",
+                    "threshold": "BLOCK_NONE"
+                },
+                {
+                    "category": "HARM_CATEGORY_HATE_SPEECH",
+                    "threshold": "BLOCK_NONE"
+                },
+                {
+                    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                    "threshold": "BLOCK_NONE"
+                },
+                {
+                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                    "threshold": "BLOCK_NONE"
+                },
+            ]
+            model = genai.GenerativeModel(model_name="gemini-1.5-pro-latest" if "gemini-1.5-pro" in self.model_id.lower() else self.model_id.lower().split('google/')[-1],
+                                          generation_config=generation_config,
+                                          system_instruction=system_prompt,
+                                          safety_settings=safety_settings)
+            convo = model.start_chat(history=[])
+            convo.send_message(user_prompt)
+            # print(convo.last)
+            result = convo.last.text
+            print(result)
+            return result
+
         # Using HF API or download checkpoints
-        if self.local_model is None:
+        elif self.local_model is None:
             try: # try use HuggingFace API
 
                 response = litellm.completion(
@@ -229,6 +282,7 @@ class SummaryGenerator:
                     api_base=self.api_base,
                 )
                 result = response['choices'][0]['message']['content']
+                return result
            except: # fail to call api. run it locally.
                 self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
                 print("Tokenizer loaded")
@@ -249,8 +303,7 @@ class SummaryGenerator:
                 result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
                 result = result.replace(prompt[0], '')
                 print(result)
-
-        return result
+                return result
 
     def _compute_avg_length(self):
         """
src/backend/run_eval_suite.py CHANGED
@@ -48,7 +48,8 @@ def run_evaluation(eval_request: EvalRequest, batch_size, device,
                               batch_size, device, no_cache, limit, write_out=True,
                               output_base_path='logs')
         results = evaluator.evaluate()
-        evaluator.write_results()
+        if write_results:
+            evaluator.write_results()
     except Exception as e:
         logging.error(f"Error during evaluation: {e}")
         raise
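Note: write_results() is now gated behind a write_results flag. Where the flag comes from is outside this hunk; presumably run_evaluation gains a corresponding parameter. A hypothetical, self-contained illustration of the gating follows; the names, defaults, and stub class are assumptions for illustration, not the repository's actual code.

import logging

class _StubEvaluator:
    """Stand-in for the evaluator constructed in run_eval_suite.py."""
    def evaluate(self):
        return {"dummy_metric": 1.0}
    def write_results(self):
        print("writing results to disk")

def run_evaluation_sketch(evaluator, write_results=True):
    try:
        results = evaluator.evaluate()
        if write_results:  # new behaviour: persist results only when requested
            evaluator.write_results()
    except Exception as e:
        logging.error(f"Error during evaluation: {e}")
        raise
    return results

# Example: run an evaluation without writing results (e.g. a dry run).
run_evaluation_sketch(_StubEvaluator(), write_results=False)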