|
import gradio as gr |
|
import spaces |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig |
|
import threading |
|
|
|
# Hugging Face Hub id of the checkpoint served by this app (the name suggests a
# bitsandbytes 4-bit checkpoint).
model_to_use = "thesven/Llama3-8B-SFT-code_bagel-bnb-4bit"

# Assigned by load_model() once loading finishes. Both stay None until then,
# and send_message() checks for None before generating.
tokenizer = model = None
|
|
|
def load_model():
    """Load the tokenizer and the 4-bit quantized model into the module globals.

    Meant to run on the background thread started by ``initialize()``. The
    globals ``tokenizer`` and ``model`` are assigned only after both objects are
    fully configured. A reader that sees them as non-None therefore never sees a
    half-set-up model.

    Not decorated with ``@spaces.GPU``: a ZeroGPU-decorated call runs in a
    separate worker, so globals assigned there would not be visible to the app.
    The decorator belongs on the inference function instead.

    Returns:
        The status string ``"Model loaded and ready!"``.
    """
    global tokenizer, model

    model_name_or_path = model_to_use

    # NF4 4-bit weights with bfloat16 compute.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        # NOTE(review): Llama is supported natively by transformers, so this is
        # probably unnecessary. It executes code from the model repo; consider
        # removing it.
        trust_remote_code=True,
        quantization_config=bnb_config,
    )

    # Llama tokenizers typically ship without a pad token. generate() reads
    # pad_token_id from generation_config, not from a plain model attribute.
    # Use the tokenizer's single eos id, because config.eos_token_id can be a
    # list for Llama 3.
    if loaded_tokenizer.pad_token is None:
        loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
    loaded_model.generation_config.pad_token_id = loaded_tokenizer.pad_token_id

    # Publish only when fully configured (readers check both for None).
    tokenizer = loaded_tokenizer
    model = loaded_model

    return "Model loaded and ready!"
|
|
|
@spaces.GPU
def send_message(message, history):
    """Generate the model's reply to ``message`` and append the exchange to the chat.

    Args:
        message: Text the user typed into the "Your Message" box.
        history: Current ``gr.Chatbot`` value, a list of ``(user, bot)`` pairs.
            May be ``None`` or empty for a fresh chat.

    Returns:
        A new history list with ``(message, reply)`` appended, used as the
        updated ``gr.Chatbot`` value. The history is returned unchanged when
        the message is blank or the model has not finished loading.
    """
    history = list(history or [])
    if not message or not message.strip():
        return history
    if tokenizer is None or model is None:
        # Background load not finished yet; the UI keeps Send disabled until then.
        return history

    # Rebuild the conversation in role/content form. gr.Chatbot stores each turn
    # as a (user, bot) pair, so user text lives in slot 0 and the reply in slot 1.
    conversation = []
    for user_turn, bot_turn in history:
        if user_turn is not None:
            conversation.append({"role": "user", "content": user_turn})
        if bot_turn is not None:
            conversation.append({"role": "assistant", "content": bot_turn})
    conversation.append({"role": "user", "content": message})

    if getattr(tokenizer, "chat_template", None):
        input_ids = tokenizer.apply_chat_template(
            conversation, add_generation_prompt=True, return_tensors="pt"
        )
    else:
        # The tokenizer ships no chat template: fall back to plain concatenation.
        prompt = "\n".join(turn["content"] for turn in conversation)
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids

    # The model was loaded with device_map="auto", so follow its device instead
    # of hard-coding .cuda().
    input_ids = input_ids.to(model.device)
    output = model.generate(
        inputs=input_ids,
        # Single unpadded sequence: every position is real input.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=50,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() returns prompt + continuation; decode only the new tokens so the
    # reply does not echo the whole conversation.
    reply = tokenizer.decode(
        output[0, input_ids.shape[-1]:], skip_special_tokens=True
    ).strip()

    history.append((message, reply))
    return history
|
|
|
def initialize():
    """Start ``load_model`` on a background thread and return that thread.

    Loading happens off the main thread so the script can go on to
    ``demo.launch()`` and serve the UI while the model loads.

    The thread is a daemon, so shutting the server down (e.g. Ctrl-C) is not
    blocked by an in-progress load. Returning the thread lets callers
    ``join()`` it if they need to wait; earlier callers that ignore the return
    value are unaffected.
    """
    loader = threading.Thread(target=load_model, name="model-loader", daemon=True)
    loader.start()
    return loader
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("# Chat with the Model")

    # Read-only status line; enable_send_button() updates it when loading ends.
    status_text = gr.Textbox(
        label="Status", value="Loading model, please wait...", interactive=False
    )
    # Disabled until the model is ready.
    send_button = gr.Button("Send", interactive=False)

    chatbot = gr.Chatbot()
    message = gr.Textbox(label="Your Message")

    def enable_send_button():
        """Wait for the background model load, then report status and unlock Send.

        Runs server-side when the page loads. It polls the module-level
        ``tokenizer``/``model`` globals assigned by ``load_model``. It returns
        ``gr.update(...)`` because that is how Gradio changes a rendered
        component; assigning attributes on the Python object does not reach
        the browser.

        Returns:
            ``(status message, send-button update)`` for
            ``[status_text, send_button]``.
        """
        import time

        # Generous cap so a failed load (exception on the loader thread) does
        # not leave this handler waiting forever.
        deadline = time.monotonic() + 30 * 60
        while tokenizer is None or model is None:
            if time.monotonic() >= deadline:
                return (
                    "Model failed to load; check the server logs.",
                    gr.update(interactive=False),
                )
            time.sleep(1.0)
        return "Model loaded and ready!", gr.update(interactive=True)

    # Must be a Python callable: initialize()/enable_send_button() are Python
    # functions and do not exist in the browser's JavaScript scope.
    demo.load(enable_send_button, inputs=None, outputs=[status_text, send_button])

    send_button.click(send_message, inputs=[message, chatbot], outputs=chatbot)

initialize()

# Route events through Gradio's queue so long-running handlers (waiting for the
# model, generation) are not tied to a single HTTP request.
demo.queue()

demo.launch()
|
|