Spaces:
Running
on
TPU v5e
Running
on
TPU v5e
import os | |
# Questions for Gradio | |
# - Chat share button is enabled by default but throws an error when clicked.
# - How to add local images in HTML? (https://github.com/gradio-app/gradio/issues/884) | |
# - How to allow Chatbot to fill the vertical space? (https://github.com/gradio-app/gradio/issues/4001) | |
# TODO: | |
# - Add the 1MB models, keras/gemma_1.1_instruct_7b_en | |
# - Add retry button, for each model individually | |
# - Add ability to route a message to a single model only. | |
# - log_applied_layout_map: make it work for Llama3CausalLM and LlamaCausalLM (vicuna) | |
# - display context length | |
os.environ["KERAS_BACKEND"] = "jax" | |
import gradio as gr | |
from gradio import ChatMessage | |
import keras_hub | |
from chatstate import ChatState | |
from enum import Enum | |
from models import ( | |
model_presets, | |
load_model, | |
model_labels, | |
preset_to_website_url, | |
get_appropriate_chat_template, | |
) | |
class TextRoute(Enum):
    """Which chatbot(s) a submitted message is sent to."""

    LEFT = 0  # routed to chatbot1 only
    RIGHT = 1  # routed to chatbot2 only
    BOTH = 2  # routed to chatbot1 and chatbot2
# Materialize the labels as a list; the dropdowns address models by position.
model_labels_list = list(model_labels)

# Load every preset up front and push one throw-away "Hello" turn through each
# model so that it is warmed up (compiled) before the UI is built.
models = []
for preset in model_presets:
    model = load_model(preset)
    warmup_chat = ChatState(model, "", get_appropriate_chat_template(preset))
    _, response = warmup_chat.send_message("Hello")
    print(f"model {preset} loaded and initialized.")
    print(f"The model responded: {response}")
    models.append(model)
# For local debugging | |
# model = keras_hub.models.Llama3CausalLM.from_preset( | |
# # "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16" | |
# "../misc-code/ari_tiny_llama3" | |
# ) | |
# models = [model, model, model, model, model] | |
def chat_turn_assistant(
    message,
    sel,
    history,
    system_message,
    # max_tokens,
    # temperature,
    # top_p,
):
    """Generate the assistant's reply to `message` and append it to `history`.

    A fresh ChatState is created for the selected model, the earlier turns in
    `history` are replayed into it, and `message` is then sent as the new
    user turn.

    Args:
        message: text of the user's latest message.
        sel: index into `models` / `model_presets` of the model to query.
        history: list of chat messages (dicts with `role` and `content`, or
            ChatMessage objects). It is modified in place.
        system_message: system prompt passed to ChatState.

    Returns:
        The same `history` list with the assistant's reply appended.
    """
    model = models[sel]
    chat_template = get_appropriate_chat_template(model_presets[sel])
    chat_state = ChatState(model, system_message, chat_template)

    turns = [
        entry if isinstance(entry, ChatMessage) else ChatMessage(**entry)
        for entry in history
    ]

    # In the submit chain chat_turn_user runs first, so `history` already ends
    # with the message that is about to be sent. Replaying that entry and then
    # calling send_message(message) would feed the model the same user turn
    # twice, so the trailing copy is left out of the replay.
    # NOTE(review): assumes ChatState.send_message records `message` as the
    # new user turn itself -- confirm against chatstate.ChatState.
    if turns and turns[-1].role == "user" and turns[-1].content == message:
        turns.pop()

    for turn in turns:
        if turn.role == "user":
            chat_state.add_to_history_as_user(turn.content)
        elif turn.role == "assistant":
            chat_state.add_to_history_as_model(turn.content)

    _, response = chat_state.send_message(message)
    history.append(ChatMessage(role="assistant", content=response))
    return history
def chat_turn_both_assistant(
    message, sel1, sel2, history1, history2, system_message
):
    """Run one assistant turn for each of the two model/history pairs.

    Returns:
        A 2-tuple: the result of chat_turn_assistant for (sel1, history1)
        followed by the result for (sel2, history2).
    """
    first_history = chat_turn_assistant(message, sel1, history1, system_message)
    second_history = chat_turn_assistant(message, sel2, history2, system_message)
    return first_history, second_history
def chat_turn_user(message, history):
    """Append `message` to `history` as a user turn and return `history`."""
    user_turn = ChatMessage(role="user", content=message)
    history.append(user_turn)
    return history
def chat_turn_both_user(message, history1, history2):
    """Record `message` as a user turn in both histories.

    Returns:
        A 2-tuple with the updated `history1` and `history2`.
    """
    first_history = chat_turn_user(message, history1)
    second_history = chat_turn_user(message, history2)
    return first_history, second_history
def bot_icon_select(model_name):
    """Return the avatar image path for the model family named in `model_name`.

    The family is detected by case-insensitive substring match, so both
    "meta-llama/..." and "Llama-3.2-1B" select the same icon. Families are
    tried in the order gemma, llama, vicuna, mistral; the first match wins.

    Args:
        model_name: model or preset name to inspect.

    Returns:
        Path of the family icon, or "img/bot.png" when no family matches.
    """
    icons_by_family = (
        ("gemma", "img/gemma.png"),
        ("llama", "img/meta.png"),
        ("vicuna", "img/vicuna.png"),
        ("mistral", "img/mistral.png"),
    )
    lowered_name = model_name.lower()
    for family, icon_path in icons_by_family:
        if family in lowered_name:
            return icon_path
    # default
    return "img/bot.png"
def instantiate_select_box(sel, model_labels):
    """Build the model-selection dropdown with entry `sel` selected.

    Args:
        sel: index of the selected model; it is both the dropdown value and
            the index used to look up the preset's website URL.
        model_labels: display labels; each label is paired with its position
            as the choice value.

    Returns:
        A gr.Dropdown whose info text links to the selected preset's website.
    """
    url = preset_to_website_url(model_presets[sel])
    choices = [(label, index) for index, label in enumerate(model_labels)]
    info = (
        "<span style='color:black'>Selected model:</span> "
        f"<a href='{url}'>{url}</a>"
    )
    return gr.Dropdown(
        choices=choices,
        show_label=False,
        value=sel,
        info=info,
    )
def instantiate_chatbot(sel, key):
    """Build a messages-type gr.Chatbot for the model at index `sel`.

    Args:
        sel: index into `model_presets`; the preset name picks the bot avatar.
        key: Gradio component key passed through to gr.Chatbot.
    """
    bot_avatar = bot_icon_select(model_presets[sel])
    avatars = ("img/usr.png", bot_avatar)
    return gr.Chatbot(
        type="messages",
        key=key,
        show_label=False,
        show_share_button=False,
        show_copy_all_button=True,
        avatar_images=avatars,
    )
def instantiate_arrow_button(route, text_route):
    """Build an arrow button that, when clicked, writes `route` to `text_route`.

    Args:
        route: the TextRoute this button selects; also picks the arrow icon.
        text_route: the output component that receives `route` on click.

    Returns:
        The created gr.Button.
    """
    icon_by_route = {
        TextRoute.LEFT: "img/arrowL.png",
        TextRoute.RIGHT: "img/arrowR.png",
        TextRoute.BOTH: "img/arrowRL.png",
    }
    arrow = gr.Button(
        "",
        size="sm",
        scale=0,
        min_width=40,
        icon=icon_by_route[route],
    )

    def select_route():
        return route

    arrow.click(select_route, outputs=[text_route])
    return arrow
def instantiate_retry_button(route):
    """Build the small icon-only retry button.

    `route` is accepted but not used by this function.
    """
    retry_button = gr.Button(
        "",
        icon="img/retry.png",
        size="sm",
        scale=0,
        min_width=40,
    )
    return retry_button
def instantiate_trash_button():
    """Build the small icon-only trash button."""
    trash_button = gr.Button(
        "",
        icon="img/trash.png",
        size="sm",
        scale=0,
        min_width=40,
    )
    return trash_button
def instantiate_text_box():
    """Build the message input box (keyed "msg") with a submit button."""
    message_box = gr.Textbox(
        label="Your message:",
        submit_btn=True,
        key="msg",
    )
    return message_box
def instantiate_additional_settings():
    """Build the collapsed "Additional settings" accordion.

    Returns:
        The system-prompt gr.Textbox placed inside the accordion.
    """
    with gr.Accordion("Additional settings", open=False):
        system_message = gr.Textbox(
            label="System prompt",  # was misspelled "Sytem prompt"
            key="system_prompt",
            value="You are a helpful assistant and your name is Eliza.",
        )
    return system_message
def retry_fn(history):
    """Undo the last exchange so the user can resubmit their message.

    The most recent user message and everything after it (normally the
    assistant's reply) are removed from `history` in place. Searching for the
    user turn, rather than popping two entries unconditionally, keeps this
    correct when the last user message never received a reply.

    Args:
        history: list of message dicts with `role` and `content` keys.

    Returns:
        `(content, history)`: the removed user message's content and the
        truncated history. When `history` holds no user message, returns
        `(gr.skip(), gr.skip())` and `history` is left untouched.
    """
    last_user_index = next(
        (
            index
            for index in range(len(history) - 1, -1, -1)
            if history[index].get("role") == "user"
        ),
        None,
    )
    if last_user_index is None:
        return gr.skip(), gr.skip()

    user_content = history[last_user_index]["content"]
    del history[last_user_index:]  # drops the user turn and any reply after it
    return user_content, history
def retry_fn_both(history1, history2):
    """Apply retry_fn to both histories and merge the recovered messages.

    Returns:
        `(msg, history1, history2)` where `msg` is:
        - the common text when both retries return the same string,
        - "<msg1> / <msg2>" when they return different strings,
        - the only string when just one retry returns one,
        - the first retry's (non-string) result otherwise.
    """
    msg1, history1 = retry_fn(history1)
    msg2, history2 = retry_fn(history2)

    texts = [candidate for candidate in (msg1, msg2) if isinstance(candidate, str)]
    if len(texts) == 2 and texts[0] != texts[1]:
        msg = texts[0] + " / " + texts[1]
    elif texts:
        msg = texts[0]
    else:
        msg = msg1
    return msg, history1, history2
# The dropdowns and chatbots are created here, ahead of the layout, and are
# placed later by calling .render() on them inside the gr.Blocks context.
# The left pane starts on model 0 and the right pane on model 1.
sel1, sel2 = (
    instantiate_select_box(initial_index, model_labels_list)
    for initial_index in (0, 1)
)
chatbot1 = instantiate_chatbot(sel1.value, "chat1")
chatbot2 = instantiate_chatbot(sel2.value, "chat2")

# to correctly align the left/right arrows
CSS = ".stick-to-the-right {align-items: end; justify-content: end}"
# Page layout and event wiring for the two-pane chat arena.
with gr.Blocks(fill_width=True, title="Keras demo", css=CSS) as demo:
    # Where do messages go
    text_route = gr.State(TextRoute.BOTH)

    # Header: logo and description.
    with gr.Row():
        gr.Image(
            "img/keras_logo_k.png",
            width=80,
            height=80,
            min_width=80,
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
            show_share_button=False,
            interactive=False,
            scale=0,
            container=False,
        )
        gr.HTML(
            "<H2>Keras chatbot arena - running with JAX on TPU</H2>"
            + "All the models are loaded into the TPU memory. "
            + "You can call any of them and compare their answers. "
            + "The entire chat<br/>history is fed to the models at every submission. "
            + "This demo is running on a Google TPU v5e 2x4 (8 cores) in bfloat16 precision."
        )

    # Model selectors and chat panes (created at module level, placed here).
    with gr.Row():
        sel1.render()
        sel2.render()

    with gr.Row():
        chatbot1.render()
        chatbot2.render()

    # The input area depends on the current route, so it is registered as a
    # dynamic render function keyed on `text_route`. Without this decorator
    # the function is never invoked and no text box, buttons or submit
    # handlers are ever created.
    @gr.render(inputs=text_route)
    def render_text_area(route):
        """Build the text box, route/retry/trash buttons and submit wiring."""
        if route == TextRoute.BOTH:
            with gr.Row():
                msg = instantiate_text_box()
                with gr.Column(scale=0, min_width=100):
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.LEFT, text_route)
                        retry = instantiate_retry_button(route)
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.RIGHT, text_route)
                        trash = instantiate_trash_button()
            retry.click(
                retry_fn_both,
                inputs=[chatbot1, chatbot2],
                outputs=[msg, chatbot1, chatbot2],
            )
            trash.click(lambda: ("", [], []), outputs=[msg, chatbot1, chatbot2])
        elif route == TextRoute.LEFT:
            with gr.Row():
                with gr.Column(scale=1):
                    msg = instantiate_text_box()
                with gr.Column(scale=1):
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.RIGHT, text_route)
                        retry = instantiate_retry_button(route)
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.BOTH, text_route)
                        trash = instantiate_trash_button()
            retry.click(retry_fn, inputs=[chatbot1], outputs=[msg, chatbot1])
            trash.click(lambda: ("", []), outputs=[msg, chatbot1])
        elif route == TextRoute.RIGHT:
            with gr.Row():
                with gr.Column(scale=1, elem_classes="stick-to-the-right"):
                    with gr.Row(elem_classes="stick-to-the-right"):
                        retry = instantiate_retry_button(route)
                        instantiate_arrow_button(TextRoute.LEFT, text_route)
                    with gr.Row(elem_classes="stick-to-the-right"):
                        trash = instantiate_trash_button()
                        instantiate_arrow_button(TextRoute.BOTH, text_route)
                with gr.Column(scale=1):
                    msg = instantiate_text_box()
            retry.click(retry_fn, inputs=[chatbot2], outputs=[msg, chatbot2])
            trash.click(lambda: ("", []), outputs=[msg, chatbot2])

        system_message = instantiate_additional_settings()

        # Route the submitted message to the left, right or both chatbots.
        # Each chain first records the user turn, then asks the model(s).
        if route == TextRoute.LEFT:
            submission = msg.submit(
                chat_turn_user, inputs=[msg, chatbot1], outputs=[chatbot1]
            ).then(
                chat_turn_assistant,
                [msg, sel1, chatbot1, system_message],
                outputs=[chatbot1],
            )
        elif route == TextRoute.RIGHT:
            submission = msg.submit(
                chat_turn_user, inputs=[msg, chatbot2], outputs=[chatbot2]
            ).then(
                chat_turn_assistant,
                [msg, sel2, chatbot2, system_message],
                outputs=[chatbot2],
            )
        elif route == TextRoute.BOTH:
            submission = msg.submit(
                chat_turn_both_user,
                inputs=[msg, chatbot1, chatbot2],
                outputs=[chatbot1, chatbot2],
            ).then(
                chat_turn_both_assistant,
                [msg, sel1, sel2, chatbot1, chatbot2, system_message],
                outputs=[chatbot1, chatbot2],
            )

        # In all cases reset text box after submission
        submission.then(lambda: "", outputs=msg)

    # Selecting another model replaces that side's chatbot (new avatar) and
    # then rebuilds the dropdown so its info link points at the new preset.
    sel1.select(
        lambda sel: instantiate_chatbot(sel, "chat1"),
        inputs=[sel1],
        outputs=[chatbot1],
    ).then(
        lambda sel: instantiate_select_box(sel, model_labels_list),
        inputs=[sel1],
        outputs=[sel1],
    )

    sel2.select(
        lambda sel: instantiate_chatbot(sel, "chat2"),
        inputs=[sel2],
        outputs=[chatbot2],
    ).then(
        lambda sel: instantiate_select_box(sel, model_labels_list),
        inputs=[sel2],
        outputs=[sel2],
    )
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()