Spaces:
Running
on
TPU v5e
Running
on
TPU v5e
import os | |
os.environ["KERAS_BACKEND"] = "jax" | |
import gradio as gr | |
from gradio import ChatMessage | |
import keras_hub | |
from chatstate import ChatState | |
from models import ( | |
model_presets, | |
load_model, | |
model_labels, | |
preset_to_website_url, | |
get_appropriate_chat_template, | |
) | |
model_labels_list = list(model_labels)

# Load and warm up (compile) every model once at startup by sending it a short
# greeting; presumably this moves compilation cost out of the first real chat
# request. After the loop, models[i] is the loaded model for model_presets[i].
models = []
for preset in model_presets:
    model = load_model(preset)
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, "", chat_template)
    prompt, response = chat_state.send_message("Hello")
    print(f"model {preset} loaded and initialized.")
    print(f"The model responded: {response}")
    models.append(model)
# For local debugging | |
# model = keras_hub.models.Llama3CausalLM.from_preset( | |
# "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16" | |
# ) | |
# models = [model, model] | |
def _role_and_content(msg):
    """Return ``(role, content)`` for one chat-history entry.

    Entries may arrive as plain dicts (how a "messages" chatbot hands history
    to a handler) or as ``ChatMessage`` objects (what this file appends), so
    both forms are accepted; missing fields come back as ``None``.
    """
    if isinstance(msg, dict):
        return msg.get("role"), msg.get("content")
    return getattr(msg, "role", None), getattr(msg, "content", None)


def chat_turn_assistant_1(
    model,
    message,
    history,
    system_message,
    preset,
    # max_tokens,
    # temperature,
    # top_p,
):
    """Generate one assistant reply from ``model`` and append it to ``history``.

    The visible conversation is replayed into a fresh ``ChatState`` (built with
    the preset's chat template and ``system_message``), then the pending user
    message is sent to the model.

    In this app the submit handler appends the user's text to ``history`` and
    clears the textbox before this runs, so ``message`` usually arrives empty.
    When the last history entry is a user turn and ``message`` is empty or
    identical to it, that entry is sent as the message instead of also being
    replayed. Otherwise ``send_message`` (which presumably adds its argument as
    a user turn) would see the user's text twice, or followed by an empty turn.

    Args:
        model: the loaded model to query.
        message: the user's new message; may be empty (see above).
        history: list of prior turns, as dicts or ``ChatMessage`` objects.
        system_message: system prompt passed to ``ChatState``.
        preset: preset name used to choose the chat template.

    Returns:
        The same ``history`` list, with the reply appended as an assistant
        ``ChatMessage``.
    """
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, system_message, chat_template)

    turns = [_role_and_content(entry) for entry in history]
    # Take a trailing user turn as the message to send (see docstring).
    if turns and turns[-1][0] == "user" and (not message or turns[-1][1] == message):
        message = turns.pop()[1]

    for role, content in turns:
        if role == "user":
            chat_state.add_to_history_as_user(content)
        elif role == "assistant":
            chat_state.add_to_history_as_model(content)
        # Other roles (e.g. "system") are not replayed.

    _prompt, response = chat_state.send_message(message)
    history.append(ChatMessage(role="assistant", content=response))
    return history
def chat_turn_assistant(
    message,
    sel1,
    history1,
    sel2,
    history2,
    system_message,
    # max_tokens,
    # temperature,
    # top_p,
):
    """Have both selected models answer and return the updated chat panes.

    ``sel1`` and ``sel2`` index into ``models`` and ``model_presets``; each
    pane's history goes to ``chat_turn_assistant_1`` with its model, first
    pane first.

    Returns:
        ``("", history1, history2)``: an empty string for the message box,
        then the two updated histories.
    """
    replies = []
    for selection, pane_history in ((sel1, history1), (sel2, history2)):
        replies.append(
            chat_turn_assistant_1(
                models[selection],
                message,
                pane_history,
                system_message,
                model_presets[selection],
            )
        )
    return ("", *replies)
def chat_turn_user_1(message, history):
    """Add ``message`` to ``history`` as a user turn.

    The list is modified in place and also returned, so it can be wired
    directly to a chatbot output.
    """
    user_turn = ChatMessage(role="user", content=message)
    history.append(user_turn)
    return history
def chat_turn_user(message, history1, history2):
    """Record the user's message in both chat panes and clear the textbox.

    Returns:
        ``("", history1, history2)``: an empty string that resets the
        message box, then both histories with the user turn appended.
    """
    updated_histories = tuple(
        chat_turn_user_1(message, pane_history)
        for pane_history in (history1, history2)
    )
    return ("",) + updated_histories
def bot_icon_select(model_name):
    """Return the avatar image path for a model preset name.

    The first family keyword found as a (case-sensitive) substring of
    ``model_name`` wins, checked in the order gemma, llama, vicuna, mistral.
    Any other name gets the generic bot icon.
    """
    family_icons = (
        ("gemma", "img/gemma.png"),
        ("llama", "img/llama.png"),
        ("vicuna", "img/vicuna.png"),
        ("mistral", "img/mistral.png"),
    )
    return next(
        (icon for family, icon in family_icons if family in model_name),
        "img/bot.png",
    )
def instantiate_chatbots(sel1, sel2):
    """Build the two chat panes for the models at indices ``sel1`` and ``sel2``.

    Both use the "messages" history format, hide their label, and show the
    user avatar next to a bot avatar chosen from the model's preset name.
    """

    def make_chatbot(selection):
        bot_avatar = bot_icon_select(model_presets[selection])
        return gr.Chatbot(
            type="messages",
            show_label=False,
            avatar_images=("img/usr.png", bot_avatar),
        )

    return make_chatbot(sel1), make_chatbot(sel2)
def instantiate_select_boxes(sel1, sel2, model_labels):
    """Build the two model-picker dropdowns.

    Args:
        sel1: index of the model selected in the first dropdown.
        sel2: index of the model selected in the second dropdown.
        model_labels: display labels in ``model_presets`` order; each
            dropdown choice's value is the label's index.

    Returns:
        ``(dropdown1, dropdown2)``. Each dropdown's ``info`` line links to the
        web page of its selected preset.
    """

    def model_dropdown(number, selected):
        # Resolve the URL once; it is used as both the link target and its text.
        url = preset_to_website_url(model_presets[selected])
        return gr.Dropdown(
            choices=[(name, i) for i, name in enumerate(model_labels)],
            show_label=False,
            info=(
                f"<span style='color:black'>Selected model {number}:</span> "
                f"<a href='{url}'>{url}</a>"
            ),
            value=selected,
        )

    return model_dropdown(1, sel1), model_dropdown(2, sel2)
def instantiate_chatbots_and_select_boxes(sel1, sel2, model_labels):
    """Rebuild both chat panes and both model dropdowns for a selection.

    Returns:
        ``(dropdown1, chatbot1, dropdown2, chatbot2)``, interleaved per pane.
    """
    bot_a, bot_b = instantiate_chatbots(sel1, sel2)
    box_a, box_b = instantiate_select_boxes(sel1, sel2, model_labels)
    return box_a, bot_a, box_b, bot_b
with gr.Blocks(fill_width=True, title="Keras demo") as demo:
    # Header: Keras logo beside a short description of the demo.
    with gr.Row():
        gr.Image(
            "img/keras_logo_k.png",
            width=80,
            height=80,
            min_width=80,
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
            interactive=False,
            # `scale` takes an integer; 0 keeps the logo from growing beyond
            # its min_width inside the row.
            scale=0,
            container=False,
        )
        gr.HTML(
            "<H2> Battle of the Keras chatbots on TPU</H2>"
            + "All the models are loaded into the TPU memory. "
            + "You can call them at will and compare their answers. <br/>"
            + "The entire chat history is fed to the models at every submission. "
            + "This demo is running on a Google TPU v5e 2x4 (8 cores).",
        )

    # Side-by-side model pickers (indices into model_presets) and chat panes.
    with gr.Row():
        sel1, sel2 = instantiate_select_boxes(0, 1, model_labels_list)
    with gr.Row():
        chatbot1, chatbot2 = instantiate_chatbots(sel1.value, sel2.value)

    msg = gr.Textbox(
        label="Your message:",
    )
    with gr.Row():
        gr.ClearButton([msg, chatbot1, chatbot2])
    with gr.Accordion("Additional settings", open=False):
        system_message = gr.Textbox(
            label="System prompt",
            value="You are a helpful assistant and your name is Eliza.",
        )

    def on_model_select(choice1, choice2):
        # Rebuild both dropdowns and both chat panes for the new selection.
        return instantiate_chatbots_and_select_boxes(
            choice1, choice2, model_labels_list
        )

    sel1.select(
        on_model_select,
        inputs=[sel1, sel2],
        outputs=[sel1, chatbot1, sel2, chatbot2],
    )
    sel2.select(
        on_model_select,
        inputs=[sel1, sel2],
        outputs=[sel1, chatbot1, sel2, chatbot2],
    )

    # chat_turn_user appends the user turn to both histories and clears `msg`.
    # The assistant step therefore reads an empty `msg`; the user's text
    # reaches it as the last entry of each history.
    msg.submit(
        chat_turn_user,
        inputs=[msg, chatbot1, chatbot2],
        outputs=[msg, chatbot1, chatbot2],
    ).then(
        chat_turn_assistant,
        [msg, sel1, chatbot1, sel2, chatbot2, system_message],
        outputs=[msg, chatbot1, chatbot2],
    )

if __name__ == "__main__":
    demo.launch()