import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
# Initialize cache for models and tokenizers
model_cache = {}
tokenizer_cache = {}
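# Caches are keyed by hub model ID and live for the process lifetime, so
# switching models in the UI pays the load cost only once per model.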
def load_model_and_tokenizer(model_name):
    """Load model and tokenizer with caching to avoid reloading the same model"""
    if model_name not in model_cache:
        print(f"Loading model: {model_name}")
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto",
            torch_dtype=torch.float16,
        )
        model_cache[model_name] = model
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Set pad token if missing
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Define a custom chat template if one is not available
        if tokenizer.chat_template is None:
            # Basic ChatML-style template
            tokenizer.chat_template = "{% for message in messages %}\n{% if message['role'] == 'system' %}<|system|>\n{{ message['content'] }}\n{% elif message['role'] == 'user' %}<|user|>\n{{ message['content'] }}\n{% elif message['role'] == 'assistant' %}<|assistant|>\n{{ message['content'] }}\n{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}<|assistant|>\n{% endif %}"
        tokenizer_cache[model_name] = tokenizer
    return model_cache[model_name], tokenizer_cache[model_name]
# Define available models
available_models = [
    "GoofyLM/BrainrotLM-Assistant-362M",
    "GoofyLM/BrainrotLM2-Assistant-362M",
]
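# Any other causal-LM hub ID could be added here; each entry is loaded and
# cached on first selection by load_model_and_tokenizer.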
def respond(message, chat_history, model_choice, system_message, max_tokens, temperature, top_p):
    # Load selected model and tokenizer
    model, tokenizer = load_model_and_tokenizer(model_choice)
    # Build conversation messages
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in chat_history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:  # This might be None during streaming
            messages.append({"role": "assistant", "content": assistant_msg})
    # Add the current message
    messages.append({"role": "user", "content": message})
    # Format prompt using chat template
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
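    # add_generation_prompt=True leaves a trailing assistant marker (the bare
    # <|assistant|> in the fallback template above), cueing the model to reply.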
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Set up streaming
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
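    # TextIteratorStreamer exposes generated text as an iterator of decoded
    # chunks; skip_prompt=True keeps the echoed input out of the stream.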
    # Configure generation parameters
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=(temperature > 0 or top_p < 1.0),
        pad_token_id=tokenizer.pad_token_id,
    )
    # Start generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
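    # model.generate() blocks until done, so it runs on a worker thread while
    # this (generator) function drains the streamer and yields partial output.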
    # Stream the response, updating the last chat turn as tokens arrive
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield chat_history + [(message, partial_message)]
    # Ensure the generation thread has fully finished before returning
    thread.join()
# Create the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# BrainrotLM Chat Interface")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=600)
            with gr.Row():
                msg = gr.Textbox(
                    label="Message",
                    placeholder="Type your message here...",
                    lines=3,
                    show_label=False
                )
                submit = gr.Button("Send", variant="primary")
            clear = gr.Button("Clear Conversation")
        with gr.Column(scale=1):
            model_dropdown = gr.Dropdown(
                choices=available_models,
                value=available_models[0],
                label="Select Model"
            )
            system_message = gr.Textbox(
                value="Your name is BrainrotLM, an AI assistant trained by GoofyLM.",
                label="System message",
                lines=4
            )
            max_tokens = gr.Slider(1, 512, value=144, label="Max new tokens")
            temperature = gr.Slider(0.1, 2.0, value=0.67, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p (nucleus sampling)")
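            # These feed respond(); with temperature > 0 or top_p < 1 the
            # model samples rather than decoding greedily.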
    # Set up event handlers: submitting the textbox or clicking Send both
    # stream respond() output into the chatbot
    submit_event = msg.submit(
        respond,
        inputs=[msg, chatbot, model_dropdown, system_message, max_tokens, temperature, top_p],
        outputs=chatbot
    )
    submit_click = submit.click(
        respond,
        inputs=[msg, chatbot, model_dropdown, system_message, max_tokens, temperature, top_p],
        outputs=chatbot
    )
    # Clear the message box once the response has finished streaming
    submit_event.then(lambda: "", None, msg)
    submit_click.then(lambda: "", None, msg)
    # Clear conversation button resets the chatbot to empty
    clear.click(lambda: None, None, chatbot)
if __name__ == "__main__":
    # queue() is needed for generator (streaming) outputs on older Gradio
    # releases and is harmless where queuing is already on by default.
    demo.queue()
    demo.launch()