# Provenance: Hugging Face repository file, commit 0e46a0b
# ("add easy inference" by inkyant), ~1.02 kB.
"""Greedy-decode a prompt with google/gemma-2b plus a local LoRA adapter.

Usage: python <script> {gpu|cpu} <hf_access_token> <prompt words...>

argv[1]  "gpu" selects the Intel XPU device ("xpu:0"); anything else -> CPU.
argv[2]  Hugging Face access token (google/gemma-2b is a gated model).
argv[3:] the prompt, given as one or more words.
"""
import sys
import warnings

import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

# Model loading and generation emit many benign third-party warnings;
# silence them so only the prompt and the response reach stdout.
warnings.filterwarnings("ignore")

# Fail early with a usage message instead of an IndexError on sys.argv.
if len(sys.argv) < 4:
    sys.exit(f"usage: {sys.argv[0]} gpu|cpu <hf_access_token> <prompt words...>")

access_token = sys.argv[2]
device = "xpu:0" if sys.argv[1] == "gpu" else "cpu:0"

# Tokenizer comes from a local directory; Gemma's tokenizer has no dedicated
# pad token, so reuse EOS for padding.
tokenizer = AutoTokenizer.from_pretrained("./tokenizer/")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Base model in bfloat16 to halve memory, then attach the LoRA adapter
# weights stored in ./adapter_model.
base_model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b",
    token=access_token,
    low_cpu_mem_usage=True,
    return_dict=True,
    torch_dtype=torch.bfloat16,
)
model = PeftModel.from_pretrained(base_model, "adapter_model")
model = model.to(device)
model.eval()  # inference only; disables dropout etc.

prompt = " ".join(sys.argv[3:])
print("Prompt:", prompt)
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# do_sample=False means greedy decoding. The original also passed
# top_k=100 and temperature=0.1, but transformers ignores both when
# sampling is off (their warning was hidden by filterwarnings above),
# so they are dropped here — the generated output is unchanged.
outputs = model.generate(
    **inputs,
    max_new_tokens=200,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))