"""Gradio demo that chats with two Keras Hub models side by side.

Every model listed in ``models.model_presets`` is loaded and warmed up at
import time; the UI then lets the user pick two of them and sends each
message to both.
"""

import os

# Select the JAX backend; this must happen before Keras / keras_hub is imported.
os.environ["KERAS_BACKEND"] = "jax"

import gradio as gr
from gradio import ChatMessage
import keras_hub

from chatstate import ChatState
from models import (
    model_presets,
    load_model,
    model_labels,
    preset_to_website_url,
    get_appropriate_chat_template,
)

model_labels_list = list(model_labels)


def _load_and_warm_up(preset):
    """Load the model for ``preset`` and run one chat turn so it gets compiled.

    Returns the loaded model.
    """
    model = load_model(preset)
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, "", chat_template)
    _, response = chat_state.send_message("Hello")
    print("model " + preset + " loaded and initialized.")
    print("The model responded: " + response)
    return model


# Load and warm up (compile) all the models, in the same order as model_presets
# so that a dropdown index selects both a preset and its model.
models = [_load_and_warm_up(preset) for preset in model_presets]

# For local debugging
# model = keras_hub.models.Llama3CausalLM.from_preset(
#     "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
# )
# models = [model, model]


def _role_and_content(message):
    """Return ``(role, content)`` for a chat message given as a dict or ChatMessage."""
    if isinstance(message, dict):
        return message["role"], message["content"]
    return message.role, message.content


def chat_turn_assistant_1(
    model,
    message,
    history,
    system_message,
    preset,
    # max_tokens,
    # temperature,
    # top_p,
):
    """Generate one assistant reply from ``model`` and append it to ``history``.

    Args:
        model: The loaded model to query.
        message: The new user message. May be empty, in which case the
            trailing user turn of ``history`` is used as the message.
        history: Chat history (dicts or ChatMessage objects with ``role`` and
            ``content``). Mutated in place: the assistant reply is appended.
        system_message: System prompt passed to ChatState.
        preset: Preset name, used to pick the chat template.

    Returns:
        The same ``history`` list, with the assistant reply appended.
    """
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, system_message, chat_template)

    turns = [_role_and_content(msg) for msg in history]

    # In the UI wiring, chat_turn_user appends the user's text to the history
    # and clears the textbox before this step runs. So ``message`` arrives
    # empty and the real prompt is the last user turn. Take it out of the
    # replayed history so it is sent exactly once, via send_message.
    if not message and turns and turns[-1][0] == "user":
        message = turns.pop()[1]

    for role, content in turns:
        if role == "user":
            chat_state.add_to_history_as_user(content)
        elif role == "assistant":
            chat_state.add_to_history_as_model(content)

    _, response = chat_state.send_message(message)
    history.append(ChatMessage(role="assistant", content=response))
    return history


def chat_turn_assistant(
    message,
    sel1,
    history1,
    sel2,
    history2,
    system_message,
    # max_tokens,
    # temperature,
    # top_p,
):
    """Get a reply from both selected models.

    ``sel1`` and ``sel2`` index into ``models`` / ``model_presets``.

    Returns ``("", history1, history2)``; the empty string clears the textbox.
    """
    history1 = chat_turn_assistant_1(
        models[sel1], message, history1, system_message, model_presets[sel1]
    )
    history2 = chat_turn_assistant_1(
        models[sel2], message, history2, system_message, model_presets[sel2]
    )
    return "", history1, history2


def chat_turn_user_1(message, history):
    """Append ``message`` to ``history`` as a user turn and return ``history``."""
    history.append(ChatMessage(role="user", content=message))
    return history


def chat_turn_user(message, history1, history2):
    """Record the user's message in both histories.

    Returns ``("", history1, history2)``; the empty string clears the textbox.
    """
    history1 = chat_turn_user_1(message, history1)
    history2 = chat_turn_user_1(message, history2)
    return "", history1, history2


# (substring of the preset name, avatar image), checked in order.
_BOT_ICONS = (
    ("gemma", "img/gemma.png"),
    ("llama", "img/llama.png"),
    ("vicuna", "img/vicuna.png"),
    ("mistral", "img/mistral.png"),
)


def bot_icon_select(model_name):
    """Return the avatar image path for a model, from its (case-insensitive) name."""
    lowered = model_name.lower()
    for family, icon in _BOT_ICONS:
        if family in lowered:
            return icon
    # default
    return "img/bot.png"


def _make_chatbot(sel):
    """Build a Chatbot whose bot avatar matches the model at index ``sel``."""
    return gr.Chatbot(
        type="messages",
        show_label=False,
        avatar_images=("img/usr.png", bot_icon_select(model_presets[sel])),
    )


def instantiate_chatbots(sel1, sel2):
    """Build the two Chatbot components for the selected model indices."""
    return _make_chatbot(sel1), _make_chatbot(sel2)


def _make_select_box(position, selected, model_labels):
    """Build the model dropdown for slot ``position``, with index ``selected`` chosen."""
    return gr.Dropdown(
        choices=[(name, i) for i, name in enumerate(model_labels)],
        show_label=False,
        info=f"Selected model {position}: "
        + preset_to_website_url(model_presets[selected]),
        value=selected,
    )


def instantiate_select_boxes(sel1, sel2, model_labels):
    """Build the two model dropdowns, pre-selecting indices ``sel1`` and ``sel2``."""
    return (
        _make_select_box(1, sel1, model_labels),
        _make_select_box(2, sel2, model_labels),
    )


def instantiate_chatbots_and_select_boxes(sel1, sel2, model_labels):
    """Rebuild both dropdowns and chatbots.

    Returns the components in the order ``(sel1, chatbot1, sel2, chatbot2)``.
    """
    chatbot1, chatbot2 = instantiate_chatbots(sel1, sel2)
    box1, box2 = instantiate_select_boxes(sel1, sel2, model_labels)
    return box1, chatbot1, box2, chatbot2


def _on_model_selected(sel1, sel2):
    """Event handler shared by both dropdowns: refresh boxes and chatbots."""
    return instantiate_chatbots_and_select_boxes(sel1, sel2, model_labels_list)


with gr.Blocks(fill_width=True, title="Keras demo") as demo:
    with gr.Row():
        gr.Image(
            "img/keras_logo_k.png",
            width=80,
            height=80,
            min_width=80,
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
            interactive=False,
            scale=0.01,
            container=False,
        )
        # NOTE(review): heading / line-break markup reconstructed; confirm
        # against the intended header layout.
        gr.HTML(
            "<h2>Battle of the Keras chatbots on TPU</h2>"
            + "All the models are loaded into the TPU memory. "
            + "You can call them at will and compare their answers.<br/>"
            + "The entire chat history is fed to the models at every submission. "
            + "This demo is running on a Google TPU v5e 2x4 (8 cores).",
        )

    with gr.Row():
        sel1, sel2 = instantiate_select_boxes(0, 1, model_labels_list)

    with gr.Row():
        chatbot1, chatbot2 = instantiate_chatbots(sel1.value, sel2.value)

    msg = gr.Textbox(
        label="Your message:",
    )

    with gr.Row():
        gr.ClearButton([msg, chatbot1, chatbot2])

    with gr.Accordion("Additional settings", open=False):
        system_message = gr.Textbox(
            label="System prompt",
            value="You are a helpful assistant and your name is Eliza.",
        )

    # Either dropdown changing rebuilds both dropdowns and both chatbots.
    for selector in (sel1, sel2):
        selector.select(
            _on_model_selected,
            inputs=[sel1, sel2],
            outputs=[sel1, chatbot1, sel2, chatbot2],
        )

    # Two steps: first show the user's message in both chats (and clear the
    # textbox), then generate both replies.
    msg.submit(
        chat_turn_user,
        inputs=[msg, chatbot1, chatbot2],
        outputs=[msg, chatbot1, chatbot2],
    ).then(
        chat_turn_assistant,
        [msg, sel1, chatbot1, sel2, chatbot2, system_message],
        outputs=[msg, chatbot1, chatbot2],
    )

if __name__ == "__main__":
    demo.launch()