Miaoran000 committed on
Commit 1557ad2 · 1 Parent(s): e071b26

minor update for src/model_operations.py

Files changed (1)
  1. src/backend/model_operations.py +14 -5
src/backend/model_operations.py CHANGED
@@ -162,7 +162,7 @@ class SummaryGenerator:
        using_replicate_api = False
        replicate_api_models = ['snowflake', 'llama-3.1-405b']
        using_pipeline = False
-       pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b']
+       pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b', 'phi-3.5']

        for replicate_api_model in replicate_api_models:
            if replicate_api_model in self.model_id.lower():
@@ -375,12 +375,19 @@ class SummaryGenerator:
                model=self.model_id,
                model_kwargs={"torch_dtype": torch.bfloat16},
                device_map="auto",
+               trust_remote_code=True
            )
        else:
            self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
            print("Tokenizer loaded")
-           self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto")
-           print(self.local_model.device)
+           if 'jamba' in self.model_id.lower():
+               self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id,
+                                                                       torch_dtype=torch.bfloat16,
+                                                                       attn_implementation="flash_attention_2",
+                                                                       device_map="auto")
+           else:
+               self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto")
+           # print(self.local_model.device)
            print("Local model loaded")


@@ -394,6 +401,8 @@ class SummaryGenerator:
            outputs = self.local_pipeline(
                messages,
                max_new_tokens=250,
+               temperature=0.0,
+               do_sample=False
            )
            result = outputs[0]["generated_text"][-1]['content']
            print(result)
@@ -435,8 +444,8 @@ class SummaryGenerator:
                result = result.split("### Assistant:\n")[-1]

            else:
-               print(prompt)
-               print('-'*50)
+               # print(prompt)
+               # print('-'*50)
                result = result.replace(prompt.strip(), '')

            print(result)
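
For reference, below is a minimal sketch of how the loading and generation path looks after this commit. It is not code from the repository: the function names (load_local_model, generate_summary), the placeholder model IDs, and the example messages are illustrative assumptions; only the transformers calls themselves (AutoTokenizer / AutoModelForCausalLM.from_pretrained with device_map="auto", and a chat-style text-generation pipeline run with do_sample=False) mirror what the diff uses.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    def load_local_model(model_id):
        # Sketch of the post-commit branching: Jamba checkpoints get bfloat16 +
        # FlashAttention-2, everything else keeps the generic trust_remote_code path.
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        if 'jamba' in model_id.lower():
            model = AutoModelForCausalLM.from_pretrained(model_id,
                                                         torch_dtype=torch.bfloat16,
                                                         attn_implementation="flash_attention_2",
                                                         device_map="auto")
        else:
            model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True,
                                                         device_map="auto", torch_dtype="auto")
        return tokenizer, model

    def generate_summary(model_id, messages):
        # Sketch of the pipeline path: greedy decoding (do_sample=False), as set in the diff;
        # temperature is omitted here because it has no effect when sampling is disabled.
        generator = pipeline("text-generation",
                             model=model_id,
                             model_kwargs={"torch_dtype": torch.bfloat16},
                             device_map="auto",
                             trust_remote_code=True)
        outputs = generator(messages, max_new_tokens=250, do_sample=False)
        # With chat-style input, generated_text is the message list with the new
        # assistant turn appended; the diff reads the last message's 'content'.
        return outputs[0]["generated_text"][-1]['content']

    # Example usage (hypothetical model ID and prompt, not from the commit):
    # messages = [{"role": "user", "content": "Summarize the following passage: ..."}]
    # print(generate_summary("microsoft/Phi-3.5-mini-instruct", messages))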