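# Generate text with a fine-tuned Gemma-2B model: the base weights are pulled
# from the Hugging Face Hub and a local PEFT adapter (e.g. LoRA) is applied
# on top.
#
# Usage (assuming this script is saved as generate.py; the filename is
# illustrative):
#   python generate.py {gpu|cpu} <HF_ACCESS_TOKEN> <prompt words...>
# e.g.:
#   python generate.py gpu hf_xxxxx What is an Intel XPU?
#
# Note: the "xpu" device requires a PyTorch build with Intel GPU support
# (native in PyTorch >= 2.4, or via intel_extension_for_pytorch).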
import sys
import warnings

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Silence library warnings so the CLI output stays readable.
warnings.filterwarnings("ignore")
# argv[1] picks the device ("gpu" -> Intel XPU, else CPU); argv[2] is the
# Hugging Face token for the gated google/gemma-2b weights.
access_token = sys.argv[2]
device = "xpu:0" if sys.argv[1] == "gpu" else "cpu:0"
# Load the tokenizer saved with the fine-tuned model and configure padding.
tokenizer = AutoTokenizer.from_pretrained("./tokenizer/")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    token=access_token,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
)
# Apply the fine-tuned adapter on top of the base model, then move to the device.
model = PeftModel.from_pretrained(base_model, "adapter_model")
model = model.to(device)
print("Prompt:", " ".join(sys.argv[3:]))
inputs = tokenizer(" ".join(sys.argv[3:]), return_tensors="pt").to(device)
# Greedy decoding: with do_sample=False, transformers ignores top_k and
# temperature, so those knobs are dropped here.
outputs = model.generate(**inputs, max_new_tokens=200, do_sample=False,
                         eos_token_id=tokenizer.eos_token_id)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))