# app.py (Hugging Face Space by msinghy) - last change: "Update app.py",
# commit 02ebcf9; file size 1.46 kB; passed the Hub's virus scan.
import os
import torch
import gradio as gr
import transformers
import accelerate
import huggingface_hub
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
# Authenticate with the Hugging Face Hub so the from_pretrained calls below
# can use the stored credentials (token=True).
_hf_token = os.environ.get("HF_TOKEN")
if not _hf_token:
    # Fail fast with an actionable message instead of a bare KeyError.
    raise RuntimeError(
        "The HF_TOKEN environment variable is not set; it is required to "
        "authenticate with the Hugging Face Hub before loading the models."
    )
huggingface_hub.login(token=_hf_token)

# 4-bit NF4 quantization settings with bfloat16 compute. Currently unused:
# quantization_config is not passed below, so the base model loads unquantized.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model_id = "google/gemma-7b"
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    # To load in 4-bit, add: quantization_config=bnb_config,
    device_map="auto",  # let accelerate place layers across available devices
    token=True,  # reuse the token stored by huggingface_hub.login above
    offload_folder="offload/",  # used if the device map spills weights to disk
)
tokenizer = AutoTokenizer.from_pretrained(
    base_model_id,
    add_bos_token=True,
    token=True,
)

# Apply the fine-tuned PEFT adapter on top of the base model.
adapter_model_id = "msinghy/gemma-7b-ft-80row-alpaca-correcting-mistakes"
ft_model = PeftModel.from_pretrained(
    base_model,
    adapter_model_id,
    offload_folder="offload/",
)
def respond(query, max_new_tokens=500):
    """Generate the fine-tuned model's answer to a single user query.

    The query is wrapped in a "###Input: ...\\n\\n###Output: " prompt
    (presumably the template the adapter was fine-tuned on - confirm) and
    passed to ``ft_model.generate``.

    Args:
        query: The user's text.
        max_new_tokens: Maximum number of tokens to generate. Defaults to
            500, the value previously hard-coded here.

    Returns:
        Only the newly generated text (the prompt is not included), decoded
        with special tokens removed.
    """
    eval_prompt = "###Input: " + query + "\n\n###Output: "
    # NOTE(review): the input tensors are not moved explicitly; presumably the
    # accelerate hooks installed by device_map="auto" move them to the model's
    # device - confirm before adding an explicit .to(...).
    model_input = tokenizer(eval_prompt, return_tensors="pt")
    input_ids = model_input["input_ids"]
    output = ft_model.generate(
        input_ids=input_ids,
        attention_mask=model_input["attention_mask"],
        max_new_tokens=max_new_tokens,
    )
    # For a causal LM, generate() returns prompt + continuation. Slice off the
    # prompt tokens instead of string-replacing the prompt text. Replacing
    # fails whenever the decoded prompt differs from eval_prompt (e.g. the
    # query contains text that tokenizes to a special token, which
    # skip_special_tokens drops). It would also erase any later copy of the
    # prompt from the answer.
    prompt_length = input_ids.shape[-1]
    return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
def chat_response(message, history):
    """Gradio ChatInterface callback: answer the latest user message.

    ``history`` (the earlier turns supplied by Gradio) is accepted but not
    used, so every message is answered independently of the conversation.
    """
    reply = respond(message)
    return reply
# Wrap chat_response in Gradio's chat UI and start serving it.
demo = gr.ChatInterface(fn=chat_response)
demo.launch()