# Source snapshot: martin-gorner, commit 64919f8 ("fixed title").
import os
# Questions for Gradio
# - Chat share button is enabled by default but throws an error when clicked.
# - How to add local images in HTML? (https://github.com/gradio-app/gradio/issues/884)
# - How to allow Chatbot to fill the vertical space? (https://github.com/gradio-app/gradio/issues/4001)
# TODO:
# - Add the 1MB models, keras/gemma_1.1_instruct_7b_en
# - Add retry button, for each model individually
# - Add ability to route a message to a single model only.
# - log_applied_layout_map: make it work for Llama3CausalLM and LlamaCausalLM (vicuna)
# - display context length
os.environ["KERAS_BACKEND"] = "jax"
import gradio as gr
from gradio import ChatMessage
import keras_hub
from chatstate import ChatState
from enum import Enum
from models import (
model_presets,
load_model,
model_labels,
preset_to_website_url,
get_appropriate_chat_template,
)
class TextRoute(Enum):
    """Which chat pane(s) a submitted message is sent to."""

    LEFT = 0  # only the left-hand chatbot
    RIGHT = 1  # only the right-hand chatbot
    BOTH = 2  # both chatbots
model_labels_list = list(model_labels)

# Load and warm up (compile) all the models. One short "Hello" exchange per
# model pays the compilation cost at startup instead of on the first user
# message. `models[i]` corresponds to `model_presets[i]`.
models = []
for preset in model_presets:
    model = load_model(preset)
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, "", chat_template)
    _prompt, response = chat_state.send_message("Hello")  # prompt not needed
    print(f"model {preset} loaded and initialized.")
    print(f"The model responded: {response}")
    models.append(model)

# For local debugging
# model = keras_hub.models.Llama3CausalLM.from_preset(
#     # "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
#     "../misc-code/ari_tiny_llama3"
# )
# models = [model, model, model, model, model]
def chat_turn_assistant(
    message,
    sel,
    history,
    system_message,
    # max_tokens,
    # temperature,
    # top_p,
):
    """Generate the selected model's reply to ``message`` and append it to ``history``.

    Args:
        message: The user's latest message text.
        sel: Index into ``models`` / ``model_presets`` selecting the model.
        history: Chatbot history in Gradio "messages" format (role/content
            dicts). In this app ``chat_turn_user`` runs first and has already
            appended ``message`` as the final user entry.
        system_message: System prompt handed to ``ChatState``.

    Returns:
        The same ``history`` list, with the assistant reply appended.
    """
    model = models[sel]
    preset = model_presets[sel]
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, system_message, chat_template)

    turns = [ChatMessage(**msg) for msg in history]
    # The pending user message is already the last history entry, and
    # send_message(message) below submits it again. Replaying it here would
    # put the same user turn into the prompt twice, so drop it.
    if turns and turns[-1].role == "user" and turns[-1].content == message:
        turns = turns[:-1]

    for turn in turns:
        if turn.role == "user":
            chat_state.add_to_history_as_user(turn.content)
        elif turn.role == "assistant":
            chat_state.add_to_history_as_model(turn.content)

    _prompt, response = chat_state.send_message(message)
    history.append(ChatMessage(role="assistant", content=response))
    return history
def chat_turn_both_assistant(
    message, sel1, sel2, history1, history2, system_message
):
    """Ask both selected models to answer ``message``.

    The left model (``sel1`` / ``history1``) runs first, then the right one
    (``sel2`` / ``history2``); each history is extended by
    ``chat_turn_assistant``.

    Returns:
        A ``(history1, history2)`` tuple with each model's reply appended.
    """
    left_history = chat_turn_assistant(message, sel1, history1, system_message)
    right_history = chat_turn_assistant(message, sel2, history2, system_message)
    return left_history, right_history
def chat_turn_user(message, history):
    """Append ``message`` to ``history`` as a user turn; return the same list."""
    user_turn = ChatMessage(role="user", content=message)
    history.append(user_turn)
    return history
def chat_turn_both_user(message, history1, history2):
    """Append ``message`` as a user turn to both chat histories.

    Returns:
        A ``(history1, history2)`` tuple (the same list objects, extended).
    """
    left_history = chat_turn_user(message, history1)
    right_history = chat_turn_user(message, history2)
    return left_history, right_history
def bot_icon_select(model_name):
    """Return the avatar image path for the bot side of a chat.

    ``model_name`` (typically a preset name) is matched case-insensitively
    against known model families. The families are checked in order and the
    first substring match wins; ``"img/bot.png"`` is the fallback.

    Args:
        model_name: Model / preset name string.

    Returns:
        Path of the avatar image to show next to the model's messages.
    """
    # Ordered (family substring, icon) pairs; order matters on overlaps.
    family_icons = (
        ("gemma", "img/gemma.png"),
        ("llama", "img/meta.png"),
        ("vicuna", "img/vicuna.png"),
        ("mistral", "img/mistral.png"),
    )
    # Case-insensitive so names such as "Gemma-2B" or "Mistral-7B" also match.
    lowered = model_name.lower()
    return next(
        (icon for family, icon in family_icons if family in lowered),
        "img/bot.png",  # default
    )
def instantiate_select_box(sel, model_labels):
    """Build the model-picker dropdown with model index ``sel`` preselected.

    Each choice is a ``(label, index)`` pair, so the dropdown's value is the
    model index. The ``info`` line links to the selected preset's website.

    Args:
        sel: Index of the preselected model in ``model_presets``.
        model_labels: Display labels, in the same order as ``model_presets``.
    """
    url = preset_to_website_url(model_presets[sel])
    info_html = "".join(
        (
            "<span style='color:black'>Selected model:</span> <a href='",
            url,
            "'>",
            url,
            "</a>",
        )
    )
    return gr.Dropdown(
        choices=[(label, idx) for idx, label in enumerate(model_labels)],
        show_label=False,
        value=sel,
        info=info_html,
    )
def instantiate_chatbot(sel, key):
    """Build a messages-format chat pane for model index ``sel``.

    The bot avatar is picked from the preset name via ``bot_icon_select``.
    The share button is disabled (see the Gradio note at the top of the
    file); the copy-all button is enabled.
    """
    bot_avatar = bot_icon_select(model_presets[sel])
    return gr.Chatbot(
        type="messages",
        key=key,
        show_label=False,
        show_share_button=False,
        show_copy_all_button=True,
        avatar_images=("img/usr.png", bot_avatar),
    )
def instantiate_arrow_button(route, text_route):
    """Build an icon button that switches the message routing to ``route``.

    Clicking it writes ``route`` into the ``text_route`` state component.

    Args:
        route: The ``TextRoute`` this button selects; also picks the icon.
        text_route: ``gr.State`` holding the current routing.
    """
    arrow_icon = {
        TextRoute.LEFT: "img/arrowL.png",
        TextRoute.RIGHT: "img/arrowR.png",
        TextRoute.BOTH: "img/arrowRL.png",
    }[route]
    arrow = gr.Button("", size="sm", scale=0, min_width=40, icon=arrow_icon)
    arrow.click(lambda: route, outputs=[text_route])
    return arrow
def instantiate_retry_button(route):
    """Build the small "retry" icon button.

    ``route`` is accepted for call-site symmetry but is not used here; the
    caller wires up the click handler.
    """
    button_options = dict(size="sm", scale=0, min_width=40, icon="img/retry.png")
    return gr.Button("", **button_options)
def instantiate_trash_button():
    """Build the small "trash" (clear conversation) icon button.

    The caller wires up the click handler.
    """
    button_options = dict(size="sm", scale=0, min_width=40, icon="img/trash.png")
    return gr.Button("", **button_options)
def instantiate_text_box():
    """Build the message input box (with its own submit button)."""
    text_box = gr.Textbox(key="msg", label="Your message:", submit_btn=True)
    return text_box
def instantiate_additional_settings():
    """Render the collapsed "Additional settings" accordion.

    Returns:
        The system-prompt ``gr.Textbox`` created inside the accordion, so the
        caller can feed its value to the chat handlers.
    """
    with gr.Accordion("Additional settings", open=False):
        system_message = gr.Textbox(
            label="System prompt",  # was misspelled "Sytem prompt"
            key="system_prompt",
            value="You are a helpful assistant and your name is Eliza.",
        )
    return system_message
def retry_fn(history):
    """Undo the latest exchange so the user can edit and resend it.

    Removes the most recent user message, and everything after it (normally
    the assistant's reply), from ``history`` in place.

    Args:
        history: Chatbot history in Gradio "messages" format: a list of dicts
            with at least "role" and "content" keys.

    Returns:
        ``(removed_user_message_content, history)`` on success. If ``history``
        contains no user message, returns ``(gr.skip(), gr.skip())`` so the
        Gradio event leaves its outputs unchanged.
    """
    # Walk back to the last user turn instead of blindly popping two entries.
    # If a generation failed, history ends with a user message, and popping
    # two would hand back the previous *assistant* reply as the text to resend.
    for index in range(len(history) - 1, -1, -1):
        if history[index]["role"] == "user":
            content = history[index]["content"]
            del history[index:]
            return content, history
    return gr.skip(), gr.skip()
def retry_fn_both(history1, history2):
    """Apply ``retry_fn`` to both chat histories and merge the recovered text.

    When both sides recover the same message, it is returned once; different
    messages are joined as ``"left / right"``; if only one side recovered
    text, that text is used. If neither did, the left side's skip sentinel
    is returned.

    Returns:
        ``(message, history1, history2)``.
    """
    msg1, history1 = retry_fn(history1)
    msg2, history2 = retry_fn(history2)
    recovered = [text for text in (msg1, msg2) if isinstance(text, str)]
    if recovered:
        # dict.fromkeys drops an exact duplicate while keeping left-right order.
        msg = " / ".join(dict.fromkeys(recovered))
    else:
        msg = msg1  # nothing to retry on either side: propagate the skip
    return msg, history1, history2
# Components shared by the layout and the event handlers. They are created
# up front, outside gr.Blocks, and placed into the layout with .render().
# The left picker defaults to model 0, the right picker to model 1.
sel1, sel2 = (instantiate_select_box(idx, model_labels_list) for idx in (0, 1))
chatbot1 = instantiate_chatbot(sel1.value, "chat1")
chatbot2 = instantiate_chatbot(sel2.value, "chat2")

# to correctly align the left/right arrows
CSS = ".stick-to-the-right {align-items: end; justify-content: end}"
with gr.Blocks(fill_width=True, title="Keras demo", css=CSS) as demo:
    # Where do messages go (LEFT, RIGHT or BOTH). Changing this state re-runs
    # render_text_area below, which rebuilds the input row.
    text_route = gr.State(TextRoute.BOTH)

    # Header: Keras logo and a short description of the demo.
    with gr.Row():
        gr.Image(
            "img/keras_logo_k.png",
            width=80,
            height=80,
            min_width=80,
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
            show_share_button=False,
            interactive=False,
            scale=0,
            container=False,
        )
        gr.HTML(
            "<H2>Keras chatbot arena - running with JAX on TPU</H2>"
            + "All the models are loaded into the TPU memory. "
            + "You can call any of them and compare their answers. "
            + "The entire chat<br/>history is fed to the models at every submission. "
            + "This demo is running on a Google TPU v5e 2x4 (8 cores) in bfloat16 precision."
        )

    # Model pickers, then the two chat panes side by side. These components
    # were created at module level and are placed into the layout here.
    with gr.Row():
        sel1.render()
        sel2.render()
    with gr.Row():
        chatbot1.render()
        chatbot2.render()

    @gr.render(inputs=text_route)
    def render_text_area(route):
        """Build the message box, route/retry/trash buttons and events for ``route``."""
        if route == TextRoute.BOTH:
            # Text box, then a narrow column: [LEFT, retry] over [RIGHT, trash].
            with gr.Row():
                msg = instantiate_text_box()
                with gr.Column(scale=0, min_width=100):
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.LEFT, text_route)
                        retry = instantiate_retry_button(route)
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.RIGHT, text_route)
                        trash = instantiate_trash_button()
            retry.click(
                retry_fn_both,
                inputs=[chatbot1, chatbot2],
                outputs=[msg, chatbot1, chatbot2],
            )
            trash.click(lambda: ("", [], []), outputs=[msg, chatbot1, chatbot2])
        elif route == TextRoute.LEFT:
            # Text box under the left chat; buttons on the right half.
            with gr.Row():
                with gr.Column(scale=1):
                    msg = instantiate_text_box()
                with gr.Column(scale=1):
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.RIGHT, text_route)
                        retry = instantiate_retry_button(route)
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.BOTH, text_route)
                        trash = instantiate_trash_button()
            retry.click(retry_fn, inputs=[chatbot1], outputs=[msg, chatbot1])
            trash.click(lambda: ("", []), outputs=[msg, chatbot1])
        elif route == TextRoute.RIGHT:
            # Buttons on the left half (pushed right via CSS); text box under
            # the right chat.
            with gr.Row():
                with gr.Column(scale=1, elem_classes="stick-to-the-right"):
                    with gr.Row(elem_classes="stick-to-the-right"):
                        retry = instantiate_retry_button(route)
                        instantiate_arrow_button(TextRoute.LEFT, text_route)
                    with gr.Row(elem_classes="stick-to-the-right"):
                        trash = instantiate_trash_button()
                        instantiate_arrow_button(TextRoute.BOTH, text_route)
                with gr.Column(scale=1):
                    msg = instantiate_text_box()
            retry.click(retry_fn, inputs=[chatbot2], outputs=[msg, chatbot2])
            trash.click(lambda: ("", []), outputs=[msg, chatbot2])

        system_message = instantiate_additional_settings()

        # Route the submitted message to the left, right or both chatbots:
        # first echo the user turn into the pane(s), then generate replies.
        if route == TextRoute.LEFT:
            submission = msg.submit(
                chat_turn_user, inputs=[msg, chatbot1], outputs=[chatbot1]
            ).then(
                chat_turn_assistant,
                [msg, sel1, chatbot1, system_message],
                outputs=[chatbot1],
            )
        elif route == TextRoute.RIGHT:
            submission = msg.submit(
                chat_turn_user, inputs=[msg, chatbot2], outputs=[chatbot2]
            ).then(
                chat_turn_assistant,
                [msg, sel2, chatbot2, system_message],
                outputs=[chatbot2],
            )
        elif route == TextRoute.BOTH:
            submission = msg.submit(
                chat_turn_both_user,
                inputs=[msg, chatbot1, chatbot2],
                outputs=[chatbot1, chatbot2],
            ).then(
                chat_turn_both_assistant,
                [msg, sel1, sel2, chatbot1, chatbot2, system_message],
                outputs=[chatbot1, chatbot2],
            )

        # In all cases reset text box after submission
        submission.then(lambda: "", outputs=msg)

    # Picking a model rebuilds that side's chatbot for the new model (e.g. its
    # bot avatar), then rebuilds the dropdown so its info line links to the
    # newly selected preset.
    sel1.select(
        lambda sel: instantiate_chatbot(sel, "chat1"),
        inputs=[sel1],
        outputs=[chatbot1],
    ).then(
        lambda sel: instantiate_select_box(sel, model_labels_list),
        inputs=[sel1],
        outputs=[sel1],
    )
    sel2.select(
        lambda sel: instantiate_chatbot(sel, "chat2"),
        inputs=[sel2],
        outputs=[chatbot2],
    ).then(
        lambda sel: instantiate_select_box(sel, model_labels_list),
        inputs=[sel2],
        outputs=[sel2],
    )
# Launch the Gradio app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()