thesven's picture
update
29655fa
raw
history blame
2.46 kB
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import threading
# Hub repository id of the checkpoint that load_model() downloads and loads.
model_to_use = "thesven/Llama3-8B-SFT-code_bagel-bnb-4bit"

# Shared handles filled in by load_model(); both remain None until then.
tokenizer = model = None
@spaces.GPU
def load_model():
    """Load the tokenizer and the 4-bit quantized model into the module globals.

    The module-level ``tokenizer`` and ``model`` are assigned only after the
    model has been fully configured, so code that checks them for ``None``
    never observes a half-initialised pair.

    Returns:
        str: The status message ``"Model loaded and ready!"``.
    """
    global tokenizer, model

    model_name_or_path = model_to_use

    # BitsAndBytesConfig for loading the model in 4-bit precision
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype="bfloat16",
    )

    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        # NOTE(review): trust_remote_code allows Python code shipped in the
        # model repository to run locally; confirm it is required for this
        # checkpoint before keeping it enabled.
        trust_remote_code=True,
        quantization_config=bnb_config,
    )

    # Reuse EOS as the padding token. The id must be stored on the generation
    # config (and config); an attribute set directly on the model object is
    # not part of either configuration. eos_token_id may be a list of ids, in
    # which case the first one is used.
    eos_token_id = loaded_model.config.eos_token_id
    if isinstance(eos_token_id, (list, tuple)):
        eos_token_id = eos_token_id[0] if eos_token_id else None
    if eos_token_id is not None:
        loaded_model.config.pad_token_id = eos_token_id
        loaded_model.generation_config.pad_token_id = eos_token_id

    # Publish the handles last, once everything is configured.
    tokenizer = loaded_tokenizer
    model = loaded_model
    return "Model loaded and ready!"
def send_message(message, history):
    """Generate a reply to ``message`` and add the exchange to the chat history.

    Reads the module-level ``tokenizer`` and ``model`` set by ``load_model``.

    Args:
        message: Text entered by the user.
        history: Conversation so far as a list of ``(user_text, bot_text)``
            pairs (the value of the ``gr.Chatbot`` component). ``None`` is
            treated as an empty conversation.

    Returns:
        The chat history as a list of ``(user_text, bot_text)`` pairs. It is
        returned unchanged when the model is not loaded yet or the message is
        blank; otherwise a new list with one additional pair is returned.
    """
    history = [] if history is None else history

    if tokenizer is None or model is None:
        return history  # Return the existing history if the model is not loaded

    if not message or not message.strip():
        return history  # Nothing to send

    # Build the prompt from every earlier turn followed by the new message,
    # joined with single spaces.
    turns = [
        text
        for pair in history
        for text in pair
        if isinstance(text, str) and text
    ]
    turns.append(message)
    prompt = " ".join(turns)

    # Keep the attention mask and place the tensors on the model's own device
    # instead of assuming CUDA.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=50)

    # The generated sequence starts with the prompt tokens; decode only the
    # newly generated tail so the reply does not echo the whole conversation.
    prompt_length = inputs["input_ids"].shape[-1]
    reply = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

    # One (user, bot) pair per exchange, which is the shape gr.Chatbot renders.
    return [*history, (message, reply)]
def initialize():
    """Start ``load_model`` on a separate thread and return without waiting."""
    loader_thread = threading.Thread(target=load_model)
    loader_thread.start()
with gr.Blocks() as demo:
    gr.Markdown("# Chat with the Model")
    status_text = gr.Textbox(label="Status", value="Loading model, please wait...")
    send_button = gr.Button("Send", interactive=False)  # Disabled until the model is loaded
    chatbot = gr.Chatbot()
    message = gr.Textbox(label="Your Message")

    def enable_send_button():
        """Report the model status and enable the Send button once it is loaded.

        Returns:
            A ``(status_text, send_button_update)`` pair matching the
            ``outputs`` of the load event below.
        """
        if tokenizer is None or model is None:
            return "Loading model, please wait...", gr.update(interactive=False)
        return "Model loaded and ready!", gr.update(interactive=True)

    # Python functions cannot be invoked through the `_js` hook (it runs
    # JavaScript in the browser), and mutating component attributes after the
    # layout is built does not update the page. Poll from a regular load event
    # and return component updates instead.
    demo.load(
        enable_send_button,
        inputs=None,
        outputs=[status_text, send_button],
        every=2,
    )
    send_button.click(send_message, inputs=[message, chatbot], outputs=chatbot)

# Periodic (`every=`) events require the queue to be enabled.
demo.queue()
# Start model initialization on app load; initialize() hands the work to a
# separate thread, so this call returns before the model is ready.
initialize()
# Launch the Gradio app defined in the `demo` Blocks above.
demo.launch()