Miaoran000 committed on
Commit
3cf286c
1 Parent(s): 2aa9a75
Files changed (1)
  1. src/backend/model_operations.py +30 -13
src/backend/model_operations.py CHANGED
@@ -13,7 +13,7 @@ from sentence_transformers import CrossEncoder
 import litellm
 # from litellm import completion
 from tqdm import tqdm
-from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig, pipeline
 # from accelerate import PartialState
 # from accelerate.inference import prepare_pippy
 import torch
@@ -272,9 +272,8 @@ class SummaryGenerator:
         # Using HF API or download checkpoints
         elif self.local_model is None:
             try: # try use HuggingFace API
-
                 response = litellm.completion(
-                    model='command-r-plus' if 'command' in self.model else self.model,
+                    model='command-r-plus' if 'command' in self.model_id else self.model_id,
                     messages=[{"role": "system", "content": system_prompt},
                               {"role": "user", "content": user_prompt}],
                     temperature=0.0,
@@ -286,7 +285,7 @@ class SummaryGenerator:
             except: # fail to call api. run it locally.
                 self.tokenizer = AutoTokenizer.from_pretrained(self.model_id, trust_remote_code=True)
                 print("Tokenizer loaded")
-                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto")
+                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto", cache_dir='/home/paperspace/cache')
                 print("Local model loaded")
 
         # Using local model
@@ -294,15 +293,33 @@ class SummaryGenerator:
             messages=[
                 {"role": "system", "content": system_prompt}, # gemma-1.1 does not accept system role
                 {"role": "user", "content": user_prompt}
-            ],
-            prompt = self.tokenizer.apply_chat_template(messages,add_generation_prompt=True, tokenize=False)
-            print(prompt)
-            input_ids = self.tokenizer(prompt, return_tensors="pt").to('cuda')
-            with torch.no_grad():
-                outputs = self.local_model.generate(**input_ids, max_new_tokens=250, do_sample=True, temperature=0.01, pad_token_id=self.tokenizer.eos_token_id)
-            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
-            result = result.replace(prompt[0], '')
-            print(result)
+            ]
+            try: # some models support pipeline
+                pipe = pipeline(
+                    "text-generation",
+                    model=self.local_model,
+                    tokenizer=self.tokenizer,
+                )
+
+                generation_args = {
+                    "max_new_tokens": 250,
+                    "return_full_text": False,
+                    "temperature": 0.0,
+                    "do_sample": False,
+                }
+
+                output = pipe(messages, **generation_args)
+                result = output[0]['generated_text']
+                print(result)
+            except:
+                prompt = self.tokenizer.apply_chat_template(messages,add_generation_prompt=True, tokenize=False)
+                print(prompt)
+                input_ids = self.tokenizer(prompt, return_tensors="pt").to('cuda')
+                with torch.no_grad():
+                    outputs = self.local_model.generate(**input_ids, max_new_tokens=250, do_sample=True, temperature=0.01, pad_token_id=self.tokenizer.eos_token_id)
+                result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+                result = result.replace(prompt[0], '')
+                print(result)
             return result
 
     def _compute_avg_length(self):
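
For reference, the new local-generation path tries the transformers text-generation pipeline with chat-style messages first and only falls back to a manual apply_chat_template plus generate() call when the pipeline raises. Below is a minimal, self-contained sketch of that same pattern; the model id and the prompt are illustrative placeholders, not values taken from this repository, and the fallback strips the prompt by token slicing rather than the string replacement used in the commit.

# Minimal sketch of the generation pattern introduced in this commit: try the
# chat-aware text-generation pipeline first, then fall back to a manual
# chat-template + generate() call. Model id and prompt are placeholders.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "google/gemma-1.1-2b-it"  # placeholder, not tied to this repo
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto"
)

messages = [
    {"role": "user", "content": "Summarize: The quick brown fox jumps over the lazy dog."},
]

try:
    # Recent transformers releases let text-generation pipelines consume
    # chat-style message lists directly.
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
    output = pipe(messages, max_new_tokens=250, return_full_text=False, do_sample=False)
    result = output[0]["generated_text"]
except Exception:
    # Fallback: render the chat template by hand and call generate().
    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=250,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Keep only the newly generated tokens instead of string-replacing the prompt.
    result = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)

print(result)

Note that the committed fallback calls result.replace(prompt[0], ''), which removes occurrences of the prompt's first character rather than the prompt itself; slicing off the prompt tokens, as in the sketch above, is one common way to keep only the new completion.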