"""Gradio chat demo around a 4-bit quantized Llama 3 code model.

The model is loaded on a background thread at start-up so the UI can be
served immediately; the Send button is enabled once loading has finished.
"""

import logging
import threading

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

logger = logging.getLogger(__name__)

model_to_use = "thesven/Llama3-8B-SFT-code_bagel-bnb-4bit"

# Initialize global variables for the tokenizer and model
tokenizer = None
model = None

# State shared between the background loader thread and the UI callbacks.
_load_finished = threading.Event()
_load_succeeded = False
_load_status = "Loading model, please wait..."


@spaces.GPU
def load_model():
    """Load the tokenizer and the 4-bit quantized model into the module globals.

    Returns:
        A human-readable status string once loading has completed.

    Raises:
        Whatever ``from_pretrained`` raises when the download or the
        quantized load fails.
    """
    global tokenizer, model
    model_name_or_path = model_to_use

    # BitsAndBytesConfig for loading the model in 4-bit precision
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype="bfloat16",
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    # NOTE(review): trust_remote_code=True executes Python shipped in the model
    # repository -- keep only if that repository is trusted.
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        trust_remote_code=True,
        quantization_config=bnb_config,
    )

    # generate() takes its padding id from the generation config; an attribute
    # set directly on the model object is never consulted.
    eos_token_id = model.config.eos_token_id
    if isinstance(eos_token_id, (list, tuple)):
        eos_token_id = eos_token_id[0]
    model.generation_config.pad_token_id = eos_token_id

    return "Model loaded and ready!"


def send_message(message, history):
    """Generate a reply to ``message`` and return the updated chat history.

    Args:
        message: Text typed by the user.
        history: Chatbot history as a list of ``(user_message, bot_message)``
            pairs; ``None`` is treated as an empty history.

    Returns:
        The history with one new ``(message, reply)`` pair appended, or the
        history unchanged when the model is not loaded yet or the message is
        blank.
    """
    # NOTE(review): on ZeroGPU Spaces this function may also need the
    # @spaces.GPU decorator to get a device at call time -- confirm on target.
    history = list(history or [])
    if tokenizer is None or model is None:
        return history  # Model is not loaded yet; leave the chat untouched.
    if not message or not message.strip():
        return history

    # The prompt is every earlier message (user and bot, in order) followed by
    # the new user message.
    earlier_texts = [
        text for turn in history for text in turn if isinstance(text, str) and text
    ]
    prompt = " ".join([*earlier_texts, message])

    # Follow the device chosen by device_map="auto" instead of assuming CUDA.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        output = model.generate(**encoded, max_new_tokens=50)

    # generate() returns prompt + continuation; keep only the new tokens so
    # the reply does not echo the whole conversation.
    prompt_length = encoded["input_ids"].shape[-1]
    reply = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)

    history.append((message, reply))
    return history


def _load_in_background():
    """Run ``load_model`` and record the outcome for the UI callbacks."""
    global _load_succeeded, _load_status
    try:
        _load_status = load_model()
        _load_succeeded = True
    except Exception as exc:  # thread boundary: report instead of dying silently
        logger.exception("Model loading failed")
        _load_status = f"Model failed to load: {exc}"
    finally:
        _load_finished.set()


def initialize():
    """Start loading the model on a separate thread and return immediately."""
    threading.Thread(target=_load_in_background).start()


def enable_send_button():
    """Wait for the background load, then return updates for the UI.

    Returns:
        A ``(status_text, send_button_update)`` pair: the final status string
        and an update that enables the Send button only if loading succeeded.
    """
    _load_finished.wait()
    return _load_status, gr.update(interactive=_load_succeeded)


with gr.Blocks() as demo:
    gr.Markdown("# Chat with the Model")

    status_text = gr.Textbox(label="Status", value="Loading model, please wait...")
    send_button = gr.Button("Send", interactive=False)  # Disabled until the model is ready

    chatbot = gr.Chatbot()
    message = gr.Textbox(label="Your Message")

    # Component changes must be returned from a Python event handler; browser
    # JavaScript cannot call Python functions, and mutating component
    # attributes after render never reaches the frontend.
    demo.load(enable_send_button, inputs=None, outputs=[status_text, send_button])

    send_button.click(send_message, inputs=[message, chatbot], outputs=chatbot)

initialize()  # Start model initialization on app load
demo.launch()