Spaces: Running on CPU Upgrade
Miaoran000 committed
Commit · 1557ad2
Parent(s): e071b26
minor update for src/model_operations.py
src/backend/model_operations.py
CHANGED
@@ -162,7 +162,7 @@ class SummaryGenerator:
         using_replicate_api = False
         replicate_api_models = ['snowflake', 'llama-3.1-405b']
         using_pipeline = False
-        pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b']
+        pipeline_models = ['llama-3.1', 'phi-3-mini','falcon-7b', 'phi-3.5']
 
         for replicate_api_model in replicate_api_models:
             if replicate_api_model in self.model_id.lower():
@@ -375,12 +375,19 @@ class SummaryGenerator:
                 model=self.model_id,
                 model_kwargs={"torch_dtype": torch.bfloat16},
                 device_map="auto",
+                trust_remote_code=True
             )
         else:
             self.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf" if 'openelm' in self.model_id.lower() else self.model_id, trust_remote_code=True)
             print("Tokenizer loaded")
-
-
+            if 'jamba' in self.model_id.lower():
+                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id,
+                                                                        torch_dtype=torch.bfloat16,
+                                                                        attn_implementation="flash_attention_2",
+                                                                        device_map="auto")
+            else:
+                self.local_model = AutoModelForCausalLM.from_pretrained(self.model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto")
+            # print(self.local_model.device)
             print("Local model loaded")
 
 
@@ -394,6 +401,8 @@ class SummaryGenerator:
             outputs = self.local_pipeline(
                 messages,
                 max_new_tokens=250,
+                temperature=0.0,
+                do_sample=False
             )
             result = outputs[0]["generated_text"][-1]['content']
             print(result)
@@ -435,8 +444,8 @@ class SummaryGenerator:
             result = result.split("### Assistant:\n")[-1]
 
         else:
-            print(prompt)
-            print('-'*50)
+            # print(prompt)
+            # print('-'*50)
             result = result.replace(prompt.strip(), '')
 
             print(result)
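For orientation, here is a minimal, self-contained sketch of the loading and decoding behavior the file converges on after this commit, assuming the standard Hugging Face transformers API. The model ID and prompt below are placeholders for illustration, not values taken from the Space.

# Sketch of the post-commit branching (placeholder model_id, not from the diff).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

model_id = "microsoft/Phi-3.5-mini-instruct"  # hypothetical example
pipeline_models = ['llama-3.1', 'phi-3-mini', 'falcon-7b', 'phi-3.5']

if any(name in model_id.lower() for name in pipeline_models):
    # Pipeline-backed models: chat messages go straight to a text-generation pipeline.
    local_pipeline = pipeline(
        "text-generation",
        model=model_id,
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
        trust_remote_code=True,
    )
    messages = [{"role": "user", "content": "Summarize the passage in one sentence: ..."}]
    outputs = local_pipeline(messages, max_new_tokens=250, temperature=0.0, do_sample=False)
    result = outputs[0]["generated_text"][-1]["content"]
else:
    # Other models are loaded directly; Jamba gets bfloat16 + flash-attention-2.
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    if "jamba" in model_id.lower():
        local_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2",
            device_map="auto",
        )
    else:
        local_model = AutoModelForCausalLM.from_pretrained(
            model_id, trust_remote_code=True, device_map="auto", torch_dtype="auto"
        )

With do_sample=False the pipeline decodes greedily (the temperature value is then ignored apart from a warning), so repeated runs over the same prompts yield identical summaries.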