Spaces:
Sleeping
Sleeping
Update src/demo.py
Browse files: src/demo.py (+2 -2)
src/demo.py
CHANGED
@@ -22,7 +22,7 @@ type2dataset = {
|
|
22 |
|
23 |
model_id = "meta-llama/Llama-2-7b-chat-hf"
|
24 |
tokenizer = AutoTokenizer.from_pretrained(model_id, token=TOKEN)
|
25 |
-
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, token=TOKEN).eval()
|
26 |
|
27 |
# type2dataset = {}
|
28 |
|
@@ -35,7 +35,7 @@ def generate(input_text, sys_prompt) -> str:
|
|
35 |
'''
|
36 |
input_str = sys_prompt + input_text + " [/INST]"
|
37 |
|
38 |
-
input_ids = tokenizer(input_str, return_tensors="pt").input_ids
|
39 |
outputs = model.generate(input_ids, max_length=512)
|
40 |
|
41 |
result = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|
|
|
22 |
|
23 |
model_id = "meta-llama/Llama-2-7b-chat-hf"
|
24 |
tokenizer = AutoTokenizer.from_pretrained(model_id, token=TOKEN)
|
25 |
+
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, token=TOKEN, device_map="auto").eval()
|
26 |
|
27 |
# type2dataset = {}
|
28 |
|
|
|
35 |
'''
|
36 |
input_str = sys_prompt + input_text + " [/INST]"
|
37 |
|
38 |
+
input_ids = tokenizer(input_str, return_tensors="pt").input_ids.to('cuda')
|
39 |
outputs = model.generate(input_ids, max_length=512)
|
40 |
|
41 |
result = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
|