# pylint: disable=no-member, missing-function-docstring, missing-class-docstring, not-context-manager
# type: ignore
import json
from datetime import datetime
from typing import Any, Literal

import gradio as gr  # type: ignore
from loguru import logger
from openai import OpenAI

Provider = Literal["openai", "anthropic", "gemini", "groq", "deepseek", "yi", "aliyun"]


class Model:
    name: str
    provider: Provider
    enabled: bool

    def __init__(self, name: str, provider: Provider, enabled: bool):
        self.name = name
        self.provider = provider
        self.enabled = enabled


# Each model gets its own column in the UI; `enabled` controls the default selection.
models = [
    Model("gpt-4o", "openai", True),
    Model("claude-3.5-sonnet", "anthropic", False),
    Model("gemini-1.5-pro", "gemini", False),
    Model("llama3.1-70b", "groq", True),
    Model("qwen2-72b", "aliyun", True),
    Model("deepseek-chat", "deepseek", False),
    Model("yi-large-rag", "yi", False),
]


def prepare_openai_messages(user_msg: str, sys_msg: str, history: list[list[str]]):
    # Convert Gradio's [[user, assistant], ...] history into OpenAI chat messages.
    messages = list[Any]()
    if sys_msg:
        messages.append({"role": "system", "content": sys_msg})
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": user_msg})
    return messages


def predict(
    user_msg: str,
    sys_msg: str,
    temperature: float,
    history: list[list[str]],
    model: Model,
    base_url: str,
    api_key: str,
    session_id: str,
):
    # Stream the completion, yielding the history with the growing partial reply
    # appended so the Chatbot component updates token by token.
    history = history or []
    client = OpenAI(api_key=api_key, base_url=base_url)
    messages = prepare_openai_messages(user_msg, sys_msg, history)
    response = client.chat.completions.create(
        model=model.name,
        messages=messages,
        temperature=temperature,
        stream=True,
        max_tokens=4096,
        extra_headers={
            # Custom headers picked up by a Langfuse-aware LiteLLM proxy for tracing.
            "langfuse_trace_name": "lite-chat",
            "langfuse_session_id": f"lite-chat-{session_id}",
            "langfuse_trace_id": f"lite-chat-{session_id}-{model.name}",
        },
    )
    partial_msg = ""
    for chunk in response:
        if chunk.choices[0].delta.content is not None:
            partial_msg += chunk.choices[0].delta.content
            yield history + [[user_msg, partial_msg]]


# Components are created up front and rendered later inside the Blocks layout.
msg_input = gr.Text(show_label=False, placeholder="Ask AI ...", scale=8)
submit_button = gr.Button("Submit")
clear_button = gr.Button("Clear")
system_msg_input = gr.Text(placeholder="System Prompt", show_label=False, scale=4)
temperature_input = gr.Slider(0, 1.0, 0.1, step=0.1, label="Temperature")
chat_history_markdown = gr.Markdown()
model_sel = gr.CheckboxGroup(
    [model.name for model in models],
    show_label=False,
    value=[m.name for m in models if m.enabled],
)


def update_chats(enabled_models: list[str]):
    # Show or hide each model's column to match the checkbox selection.
    result: list[Any] = []
    for model in models:
        result.append(gr.Column(visible=model.name in enabled_models))
    return result


css = """
.app.svelte-182fdeq.svelte-182fdeq {
    max-width: 100%;
}
"""


def llm_chat(
    message: str,
    system_message: str,
    temperature: float,
    history: list[list[str]],
    model_name: str,
    enabled_models: list[str],
    base_url: str,
    api_key: str,
    session_id: str,
    request: gr.Request,
):
    # Every model column shares this submit handler; columns whose model is
    # currently disabled yield nothing, leaving their Chatbot unchanged.
    if model_name not in enabled_models:
        return
    assert request.request
    user = request.request.client
    assert user
    logger.info("chat with model {} from {}", model_name, user.host)
    model = next((m for m in models if m.name == model_name), None)
    if not model:
        yield history + [[message, "Unknown model"]]
        return
    yield from predict(
        message, system_message, temperature, history, model, base_url, api_key, session_id
    )


with gr.Blocks(css=css) as demo:
    with gr.Accordion("Settings", open=False):
        # Hidden textboxes that shuttle the JSON config between Python and the
        # browser's localStorage via the js hooks below.
        load_config_textbox = gr.Text(visible=False)
        dump_config_textbox = gr.Text(visible=False)

        def dump_config(url: str, key: str):
            logger.info("dump config")
            conf = {"url": url, "key": key}
            return json.dumps(conf, indent=2)

        def load_config(config_str: str):
            logger.info("loaded config")
            config: dict[str, str] = json.loads(config_str) if config_str else {}
            return [config.get("url", ""), config.get("key", "")]

        with gr.Tab("Config"):
            url_textbox = gr.Text("https://api.openai.com/v1", label="LiteLLM base url")
            key_textbox = gr.Text("", label="LiteLLM key")
            gr.Button("Save Config").click(  # type: ignore
                dump_config,
                inputs=[url_textbox, key_textbox],
                outputs=[dump_config_textbox],
            )

        # On page load, read the saved config from localStorage ...
        demo.load(  # type: ignore
            None,
            inputs=None,
            outputs=[load_config_textbox],
            js="() => localStorage.getItem('config')",
        )
        # ... then parse it into the URL and key fields.
        load_config_textbox.change(  # type: ignore
            load_config,
            inputs=[load_config_textbox],
            outputs=[url_textbox, key_textbox],
        )
        # Whenever a new config is saved, persist it back to localStorage.
        dump_config_textbox.change(  # type: ignore
            None,
            inputs=[dump_config_textbox],
            js="(v) => localStorage.setItem('config', v)",
        )

    # One session id per server start, used to group Langfuse traces.
    session_id_state = gr.State(datetime.now().strftime("%Y%m%d%H%M%S"))

    chats: list[Any] = []
    bots: list[Any] = []
    with gr.Row():
        # One column per model, each with its own Chatbot; the shared submit
        # button fans the prompt out to every enabled model.
        for m in models:
            with gr.Column(visible=m.enabled) as col:
                gr.Markdown(f"**{m.name}**")
                bot = gr.Chatbot(height="800px")
                bots.append(bot)
                model_name_state = gr.State(m.name)
                submit_button.click(  # type: ignore
                    fn=llm_chat,
                    inputs=[
                        msg_input,
                        system_msg_input,
                        temperature_input,
                        bot,
                        model_name_state,
                        model_sel,
                        url_textbox,
                        key_textbox,
                        session_id_state,
                    ],
                    outputs=[bot],
                    concurrency_limit=None,
                )
            chats.append(col)

    with gr.Row():
        msg_input.render()
        submit_button.render()
        clear_button.render()
    with gr.Row():
        system_msg_input.render()
        temperature_input.render()
        model_sel.render()
    with gr.Accordion("Chat History", open=True):
        chat_history_markdown.render()

    # The system prompt is persisted to localStorage as well.
    demo.load(  # type: ignore
        None,
        inputs=None,
        outputs=[system_msg_input],
        js="() => localStorage.getItem('system_message')",
    )
    system_msg_input.change(  # type: ignore
        None,
        inputs=[system_msg_input],
        js="(v) => localStorage.setItem('system_message', v)",
    )

    def format_history_content(content: list[list[str]], model_name: str) -> str:
        formatted = f"## {model_name}\n\n"
        for user, assistant in content:
            formatted += f"**User:** {user}\n\n**Assistant:** {assistant}\n\n"
        return formatted

    def clear_chat(history: str, *all_bots: list[list[str]]):
        # Archive every bot's transcript into the history pane, then clear the bots.
        logger.info("Clear and save history called")
        new_history = history or ""
        for content, model in zip(all_bots, models):
            if content:
                new_history += f"\n\n{format_history_content(content, model.name)}"
        # Blank line before the rule so Markdown renders it as a separator.
        new_history += "\n\n---"
        return [new_history] + [[] for _ in all_bots]

    clear_button.click(  # type: ignore
        fn=clear_chat,
        inputs=[chat_history_markdown, *bots],
        outputs=[chat_history_markdown, *bots],
    )
    # Clear the input box after each submit.
    submit_button.click(lambda: "", outputs=msg_input)  # type: ignore

    # Sync column visibility on load and whenever the model selection changes.
    gr.on(  # type: ignore
        [demo.load, model_sel.change],  # type: ignore
        update_chats,
        inputs=[model_sel],
        outputs=chats,
    )

demo.launch(favicon_path="favicon.ico")  # type: ignore
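# Usage sketch (assumptions, not from the original source: the script is saved
# as app.py, an OpenAI-compatible endpoint such as a LiteLLM proxy is reachable
# at the configured base URL, and favicon.ico exists next to the script):
#
#   pip install gradio openai loguru
#   python app.py
#
# Then open the printed URL, set the base URL and key under "Settings" ->
# "Config" (both persist in the browser's localStorage), tick the models to
# compare in the checkbox group, and submit one prompt to all of them at once.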