# Source page header (scraped, not code): author martin-gorner,
# commit b637f0b "initial commit", file size 6.27 kB.
import os
os.environ["KERAS_BACKEND"] = "jax"
import gradio as gr
from gradio import ChatMessage
import keras_hub
from chatstate import ChatState
from models import (
model_presets,
load_model,
model_labels,
preset_to_website_url,
get_appropriate_chat_template,
)
model_labels_list = list(model_labels)

# Load and warm up (compile) all the models at startup: each one is sent a
# short "Hello" message once, presumably so that compilation happens here
# rather than on the first real user request.
models = []
for preset in model_presets:
    model = load_model(preset)
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, "", chat_template)
    _prompt, response = chat_state.send_message("Hello")
    print(f"model {preset} loaded and initialized.")
    print(f"The model responded: {response}")
    models.append(model)
# For local debugging
# model = keras_hub.models.Llama3CausalLM.from_preset(
# "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
# )
# models = [model, model]
def _message_role_and_content(msg):
    """Return ``(role, content)`` for one chat-history entry.

    Entries are expected to be dicts in Gradio's ``type="messages"`` format;
    objects exposing ``role``/``content`` attributes (e.g. ``ChatMessage``)
    are accepted too, and extra keys such as ``metadata`` are ignored.
    """
    if isinstance(msg, dict):
        return msg["role"], msg["content"]
    return msg.role, msg.content


def chat_turn_assistant_1(
    model,
    message,
    history,
    system_message,
    preset,
    # max_tokens,
    # temperature,
    # top_p,
):
    """Generate one assistant reply for a single chatbot pane.

    A fresh ``ChatState`` is rebuilt from ``system_message`` and ``history``
    on every call, the pending user turn is sent to ``model``, and the reply
    is appended to ``history``.

    Args:
        model: the loaded model backing this pane (passed to ``ChatState``).
        message: the text from the message box. In this app's wiring the
            handler runs after ``chat_turn_user`` has already appended the
            message to ``history`` and cleared the textbox, so this is
            normally ``""``; it is only sent when ``history`` does not end
            with a user turn.
        history: this pane's list of chat messages; mutated in place.
        system_message: system prompt for the conversation.
        preset: model preset name, used to select the chat template.

    Returns:
        ``history`` with the assistant's reply appended.
    """
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, system_message, chat_template)

    turns = [_message_role_and_content(entry) for entry in history]

    # The newest user turn is already the last history entry. Send it via
    # send_message (which presumably records it as the new user turn —
    # ChatState is not visible here) instead of replaying it as past history,
    # otherwise the model would see it followed by an extra empty user turn.
    if turns and turns[-1][0] == "user":
        message = turns.pop()[1]

    for role, content in turns:
        if role == "user":
            chat_state.add_to_history_as_user(content)
        elif role == "assistant":
            chat_state.add_to_history_as_model(content)

    _prompt, response = chat_state.send_message(message)
    history.append(ChatMessage(role="assistant", content=response))
    return history
def chat_turn_assistant(
    message,
    sel1,
    history1,
    sel2,
    history2,
    system_message,
    # max_tokens,
    # temperature,
    # top_p,
):
    """Get a reply from both selected models for the same user message.

    Args:
        message: text from the message box, forwarded to each pane.
        sel1, sel2: indices into ``models`` / ``model_presets`` for the
            left and right panes.
        history1, history2: the chat histories of the two panes.
        system_message: system prompt shared by both panes.

    Returns:
        ``("", history1, history2)``: the empty string clears the message
        box, followed by each pane's updated history.
    """
    panes = ((sel1, history1), (sel2, history2))
    # Left pane is answered first, then the right pane (same order as before).
    replies = [
        chat_turn_assistant_1(
            models[index], message, pane_history, system_message, model_presets[index]
        )
        for index, pane_history in panes
    ]
    return ("", *replies)
def chat_turn_user_1(message, history):
    """Append ``message`` as a user turn to one pane's history and return it."""
    user_turn = ChatMessage(role="user", content=message)
    history.append(user_turn)
    return history
def chat_turn_user(message, history1, history2):
    """Record the user's message in both panes.

    Returns:
        ``("", history1, history2)``: the empty string clears the message
        box, followed by both updated histories.
    """
    updated = [chat_turn_user_1(message, pane) for pane in (history1, history2)]
    return ("", *updated)
def bot_icon_select(model_name):
    """Return the avatar image path for a model, chosen from its preset name.

    Keywords are checked in the order listed (case-sensitive substring
    match); the first hit wins. Unrecognized models get the generic bot icon.
    """
    icon_by_keyword = (
        ("gemma", "img/gemma.png"),
        ("llama", "img/llama.png"),
        ("vicuna", "img/vicuna.png"),
        ("mistral", "img/mistral.png"),
    )
    return next(
        (icon for keyword, icon in icon_by_keyword if keyword in model_name),
        "img/bot.png",  # default
    )
def instantiate_chatbots(sel1, sel2):
    """Create the two chatbot panes, each showing its model's avatar.

    Args:
        sel1, sel2: indices into ``model_presets`` for the left/right panes.

    Returns:
        A ``(chatbot1, chatbot2)`` pair of ``gr.Chatbot`` components.
    """

    def _make_chatbot(selection):
        # Avatar pair is (user icon, bot icon picked from the preset name).
        bot_avatar = bot_icon_select(model_presets[selection])
        return gr.Chatbot(
            type="messages",
            show_label=False,
            avatar_images=("img/usr.png", bot_avatar),
        )

    left_pane = _make_chatbot(sel1)
    right_pane = _make_chatbot(sel2)
    return left_pane, right_pane
def instantiate_select_boxes(sel1, sel2, model_labels):
    """Create the two model-selection dropdowns.

    Each dropdown lists every label in ``model_labels`` with its index as the
    value, pre-selects the given index, and shows a link to the selected
    preset's website in its info line.

    Args:
        sel1, sel2: indices into ``model_presets`` to pre-select.
        model_labels: display labels, in the same order as ``model_presets``
            (note: this parameter shadows the module-level import of the
            same name).

    Returns:
        A ``(dropdown1, dropdown2)`` pair of ``gr.Dropdown`` components.
    """

    def _make_dropdown(selected, heading):
        url = preset_to_website_url(model_presets[selected])
        return gr.Dropdown(
            choices=[(label, index) for index, label in enumerate(model_labels)],
            show_label=False,
            info=heading + "<a href='" + url + "'>" + url + "</a>",
            value=selected,
        )

    first = _make_dropdown(
        sel1, "<span style='color:black'>Selected model 1:</span> "
    )
    second = _make_dropdown(
        sel2, "<span style='color:black'>Selected model 2:</span> "
    )
    return first, second
def instantiate_chatbots_and_select_boxes(sel1, sel2, model_labels):
    """Rebuild both chatbot panes and both dropdowns for a new selection.

    Chatbots are created before dropdowns. The result is ordered to match
    the event outputs ``[sel1, chatbot1, sel2, chatbot2]``.
    """
    bot_left, bot_right = instantiate_chatbots(sel1, sel2)
    box_left, box_right = instantiate_select_boxes(sel1, sel2, model_labels)
    return box_left, bot_left, box_right, bot_right
def _on_model_selected(sel1, sel2):
    """Rebuild both dropdowns and chatbots when either model selection changes."""
    return instantiate_chatbots_and_select_boxes(sel1, sel2, model_labels_list)


# UI layout: header, two model selectors, two side-by-side chat panes, a
# shared message box, a clear button and a collapsible system-prompt setting.
with gr.Blocks(fill_width=True, title="Keras demo") as demo:
    with gr.Row():
        gr.Image(
            "img/keras_logo_k.png",
            width=80,
            height=80,
            min_width=80,
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
            interactive=False,
            # Gradio expects an integer scale; 0 keeps the logo at its
            # min_width instead of growing to share the row.
            scale=0,
            container=False,
        )
        gr.HTML(
            "<H2> Battle of the Keras chatbots on TPU</H2>"
            + "All the models are loaded into the TPU memory. "
            + "You can call them at will and compare their answers. <br/>"
            + "The entire chat history is fed to the models at every submission. "
            + "This demo is running on a Google TPU v5e 2x4 (8 cores).",
        )

    with gr.Row():
        sel1, sel2 = instantiate_select_boxes(0, 1, model_labels_list)

    with gr.Row():
        chatbot1, chatbot2 = instantiate_chatbots(sel1.value, sel2.value)

    msg = gr.Textbox(
        label="Your message:",
    )

    with gr.Row():
        gr.ClearButton([msg, chatbot1, chatbot2])

    with gr.Accordion("Additional settings", open=False):
        system_message = gr.Textbox(
            label="System prompt",
            value="You are a helpful assistant and your name is Eliza.",
        )

    # Either dropdown triggers the same rebuild of both panes and selectors.
    for selector in (sel1, sel2):
        selector.select(
            _on_model_selected,
            inputs=[sel1, sel2],
            outputs=[sel1, chatbot1, sel2, chatbot2],
        )

    # First record the user's turn (and clear the box), then generate replies.
    msg.submit(
        chat_turn_user,
        inputs=[msg, chatbot1, chatbot2],
        outputs=[msg, chatbot1, chatbot2],
    ).then(
        chat_turn_assistant,
        [msg, sel1, chatbot1, sel2, chatbot2, system_message],
        outputs=[msg, chatbot1, chatbot2],
    )

if __name__ == "__main__":
    demo.launch()