# Hugging Face Space app (the Space page reported a runtime error at scrape time).
import threading

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Hub repo of a 4-bit (bitsandbytes) quantized Llama3-8B code fine-tune.
model_to_use = "thesven/Llama3-8B-SFT-code_bagel-bnb-4bit"

# Populated by load_model() on a background thread; None until loading finishes.
tokenizer = None
model = None
def load_model():
    """Load the 4-bit quantized tokenizer and model into the module globals.

    Returns:
        A status string once both tokenizer and model are ready.
    """
    global tokenizer, model
    model_name_or_path = model_to_use
    # NF4 4-bit quantization with bfloat16 compute so an 8B model fits in
    # modest VRAM. Pass a real torch dtype, not the string "bfloat16".
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        trust_remote_code=True,
        quantization_config=bnb_config,
    )
    # Llama 3 ships without a pad token; reuse EOS so generate() can pad.
    # The original assigned to model.pad_token_id (a plain attribute on the
    # module object) — the config is what generate() actually consults.
    model.config.pad_token_id = model.config.eos_token_id
    return "Model loaded and ready!"
def send_message(message, history):
    """Append one chat turn to *history* and return the updated history.

    Args:
        message: The user's new message.
        history: Gradio chatbot history — a list of (user_msg, bot_msg) pairs.

    Returns:
        The history with the new (message, reply) pair appended, or the
        history unchanged if the model has not finished loading yet.
    """
    global tokenizer, model
    if tokenizer is None or model is None:
        return history  # model still loading; leave the chat untouched

    history = history or []

    # Flatten prior (user, bot) pairs plus the new message into the prompt.
    # The original appended ("User", msg) / ("Bot", text) tuples, which is not
    # the (user, bot) pair format gr.Chatbot renders.
    parts = []
    for user_msg, bot_msg in history:
        parts.append(user_msg)
        if bot_msg:
            parts.append(bot_msg)
    parts.append(message)
    input_text = " ".join(parts)

    # Send inputs to wherever device_map placed the model instead of a
    # hard-coded .cuda() call.
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)
    output = model.generate(inputs=input_ids, max_new_tokens=50)

    # Decode only the newly generated tokens; decoding output[0] in full would
    # echo the entire prompt back as the bot reply.
    generated_text = tokenizer.decode(
        output[0][input_ids.shape[-1]:], skip_special_tokens=True
    )

    history.append((message, generated_text))
    return history
def initialize():
    """Kick off model loading on a background thread so startup is not blocked."""
    loader = threading.Thread(target=load_model)
    loader.start()
with gr.Blocks() as demo:
    gr.Markdown("# Chat with the Model")
    status_text = gr.Textbox(label="Status", value="Loading model, please wait...")
    send_button = gr.Button("Send", interactive=False)  # enabled once the model is up
    chatbot = gr.Chatbot()
    message = gr.Textbox(label="Your Message")

    def _on_app_load():
        # Runs when the page opens: load the model, then report status and
        # enable the Send button. Mutating component attributes after render
        # (the original enable_send_button approach) has no effect in Gradio,
        # and demo.load(_js=...) pointed at JavaScript functions that do not
        # exist — updates must be *returned* from a Python event handler.
        status = load_model()
        return gr.update(value=status), gr.update(interactive=True)

    demo.load(_on_app_load, inputs=None, outputs=[status_text, send_button])
    send_button.click(send_message, inputs=[message, chatbot], outputs=chatbot)

demo.launch()