# QLORA-Phi2 / inference.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

# Path to the QLoRA fine-tuned checkpoint
model_name = "trained_model/content/results/checkpoint-500"

# 4-bit NF4 quantization config for loading the model with bitsandbytes
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",  # place the quantized weights on the available GPU(s)
    trust_remote_code=True,
)
# Re-enable the KV cache for generation; it is typically disabled only during training
model.config.use_cache = True

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# The tokenizer defines no pad token by default, so reuse the EOS token
tokenizer.pad_token = tokenizer.eos_token

# Run a text-generation pipeline with the fine-tuned model
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=200)

def run_inference(prompt):
    """Wrap the prompt in the instruction template used during fine-tuning and generate a reply."""
    result = pipe(f"<s>[INST] {prompt} [/INST]")
    return result[0]["generated_text"]
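
# A minimal usage sketch; the prompt below is illustrative, not part of the original script
if __name__ == "__main__":
    print(run_inference("Explain what QLoRA fine-tuning does in one paragraph."))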