# QLORA-Phi2 / inference.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

# Local directory containing the fine-tuned model weights
model_name = "model2/"
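
# Optional 4-bit quantization config (the NF4 setup used for QLoRA); uncomment
# this block and pass quantization_config=bnb_config to from_pretrained() below
# to load the model quantized instead of in full precision.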
# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.float16,
# )
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    # quantization_config=bnb_config,
    trust_remote_code=True,
)
# use_cache=False is only needed during training with gradient checkpointing;
# for inference, enable the KV cache so generation is not needlessly slow.
model.config.use_cache = True
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
# The tokenizer ships without a pad token, so reuse the EOS token for padding
tokenizer.pad_token = tokenizer.eos_token
# Build a text-generation pipeline around the fine-tuned model
pipe = pipeline(task="text-generation", model=model, tokenizer=tokenizer, max_length=200)
def run_inference(prompt):
    # Wrap the prompt in the [INST] instruction template the model was fine-tuned with
    result = pipe(f"<s>[INST] {prompt} [/INST]")
    return result[0]["generated_text"]
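
# Minimal usage sketch: running the script directly prints one generation from
# the fine-tuned model. The prompt below is only an illustrative example.
if __name__ == "__main__":
    print(run_inference("Explain what QLoRA is in one paragraph."))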