import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

# Run on GPU when available; generation inputs are moved to this device below.
device = "cuda" if torch.cuda.is_available() else "cpu"
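# Quantize the base model to 4-bit NF4 so gemma-7b fits in modest GPU memory;
# bfloat16 is used as the compute dtype during inference.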
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
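# Load the gated gemma-7b base model with the 4-bit config; `token=True` reads
# the Hugging Face access token from the local cache/environment.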
base_model_id = "google/gemma-7b"
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    token=True,  # replaces the deprecated use_auth_token argument
    offload_folder="offload/",
)
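# Prepend BOS so prompts match the format the model saw during fine-tuning.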
tokenizer = AutoTokenizer.from_pretrained(base_model_id, add_bos_token=True, trust_remote_code=True)
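# Attach the fine-tuned PEFT adapter (trained on an Alpaca-style
# "correcting mistakes" dataset) on top of the quantized base model.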
ft_model = PeftModel.from_pretrained(
    base_model,
    "msinghy/gemma-7b-ft-80row-alpaca-correcting-mistakes",
    offload_folder="offload/",
)
def respond(query):
    # Wrap the user query in the same ###Input/###Output template used at fine-tuning time.
    eval_prompt = "###Input: " + query + "\n\n###Output: "
    model_input = tokenizer(eval_prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = ft_model.generate(
            input_ids=model_input["input_ids"],
            attention_mask=model_input["attention_mask"],
            max_new_tokens=500,
        )
    # Decode the full sequence and strip the echoed prompt from the answer.
    result = tokenizer.decode(output[0], skip_special_tokens=True).replace(eval_prompt, "")
    return result
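# Gradio's ChatInterface passes (message, history); the history is ignored here,
# so each turn is answered independently.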
def chat_response(message, history):
    return respond(message)
demo = gr.ChatInterface(chat_response)
demo.launch()