File size: 1,018 Bytes
0e46a0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# Greedy text generation with google/gemma-2b plus a PEFT adapter.
#
# Usage:
#     python <this script> {gpu|cpu} HF_ACCESS_TOKEN [PROMPT WORDS ...]
#
#   argv[1]  "gpu" selects device "xpu:0"; any other value runs on "cpu:0".
#   argv[2]  Hugging Face access token passed to from_pretrained for
#            google/gemma-2b. NOTE(review): a token given on the command line
#            is visible in shell history / process listings.
#   argv[3:] prompt words, joined with single spaces (may be empty).
import sys
import warnings

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Suppress every Python warning so the output is just the prompt and the
# generated text. This is process-wide and also hides useful library warnings;
# remove it when debugging.
warnings.filterwarnings("ignore")

# Fail with a usage message instead of an IndexError traceback.
if len(sys.argv) < 3:
    sys.exit(f"usage: {sys.argv[0]} {{gpu|cpu}} HF_ACCESS_TOKEN [PROMPT ...]")

access_token = sys.argv[2]
# Exact, case-sensitive match: anything other than "gpu" falls back to CPU.
device = "xpu:0" if sys.argv[1] == "gpu" else "cpu:0"
prompt = " ".join(sys.argv[3:])

tokenizer = AutoTokenizer.from_pretrained("./tokenizer/")
# These padding settings only matter for padded batches; a single prompt is
# never padded. NOTE(review): presumably they mirror the fine-tuning setup —
# confirm before relying on them for batched inference.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    token=access_token,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,  # load weights in bfloat16 rather than float32
)

# Wrap the base model with the adapter weights from "adapter_model", then move
# the combined model to the selected device.
model = PeftModel.from_pretrained(base_model, "adapter_model")
model = model.to(device)

print("Prompt:", prompt)

inputs = tokenizer(prompt, return_tensors="pt").to(device)
# Greedy decoding (do_sample=False). Sampling knobs such as top_k and
# temperature are deliberately not passed: they are ignored in greedy mode.
outputs = model.generate(
    **inputs,
    max_new_tokens=200,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
)
# For a causal LM, outputs[0] is the prompt tokens followed by the newly
# generated tokens, so the decoded text repeats the prompt first.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))