from __future__ import annotations

import os
from typing import Generator

import gradio as gr
from litellm import completion
from litellm.utils import get_valid_models

# Create static directory if it doesn't exist
os.makedirs("static", exist_ok=True)


def get_available_models(
    provider: str,
    api_key: str | None = None,
) -> list[str]:
    """Get available models from LiteLLM for the specified provider."""
    try:
        if api_key:
            os.environ[f"{provider.upper()}_API_KEY"] = api_key

        # get_valid_models() returns the model names LiteLLM considers usable
        # given the API keys currently set in the environment; filter it down
        # to the selected provider and strip the "provider/" prefix.
        valid_models = get_valid_models()
        provider_models = [
            model.split("/")[-1] if "/" in model else model
            for model in valid_models
            if model.startswith(f"{provider}/") or model.startswith(provider)
        ]
        return provider_models if provider_models else ["gpt-3.5-turbo"]
    except Exception as e:
        print(f"Error fetching models: {e!s}")
        return ["gpt-3.5-turbo"]  # Fallback on error


def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    provider: str,
    model: str,
    api_key: str,
) -> Generator[list[tuple[str, str]], None, None]:
    """Generate chat responses using the specified model and provider."""
    messages = [{"role": "system", "content": system_message}]

    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": message})

    response = ""

    # Set API key if provided
    if api_key:
        os.environ[f"{provider.upper()}_API_KEY"] = api_key

    try:
        # Construct the full "provider/model" name if needed
        model_name = model if "/" in model else f"{provider}/{model}"

        for chunk in completion(
            model=model_name,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True,
        ):
            token = chunk.choices[0].delta.content
            if token:
                response += token
                # gr.Chatbot expects the full history of (user, assistant)
                # pairs, not a bare string, so yield the updated history.
                yield history + [(message, response)]
    except Exception as e:
        yield history + [(message, f"Error: {e!s}")]


def update_model_list(provider: str, api_key: str) -> gr.Dropdown:
    """Update the model dropdown based on provider and API key."""
    models = get_available_models(provider, api_key)
    return gr.Dropdown(choices=models, value=models[0] if models else None)


def clear_click() -> None:
    """Clear the chat history."""
    return None


def clear_input() -> str:
    """Clear the input textbox."""
    return ""


# Get available providers from LiteLLM
valid_models = get_valid_models()
providers = sorted({model.split("/")[0] for model in valid_models if "/" in model})

# Initial provider/model choices for the dropdowns below
initial_provider = providers[0] if providers else "openai"
initial_models = get_available_models(initial_provider)

# Create the chat interface with enhanced styling
with gr.Blocks(
    css="static/styles.css",
    title="AI Chat Assistant",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="blue",
        neutral_hue="slate",
        radius_size=gr.themes.sizes.radius_sm,
    ),
) as demo:
    with gr.Column(elem_classes="chat-container"):
        chatbot = gr.Chatbot(
            label="Chat History",
            bubble_full_width=False,
            show_label=False,
            elem_classes=["chat-history"],
            height=500,
        )
        msg = gr.Textbox(
            label="Type your message",
            placeholder="Enter your message here...",
            show_label=False,
            container=False,
            scale=8,
        )
        with gr.Row():
            submit = gr.Button("Send", variant="primary", scale=1)
            clear = gr.Button("Clear", variant="secondary", scale=1)

    with gr.Accordion("Model Settings", open=True, elem_classes="additional-inputs"):
        with gr.Row():
            provider = gr.Dropdown(
                choices=providers,
                value=initial_provider,
                label="Provider",
                info="Select the AI provider",
            )
            api_key = gr.Textbox(
                value="",
                label="API Key",
                info="Enter your API key",
                type="password",
            )
            model = gr.Dropdown(
                choices=initial_models,
                value=initial_models[0] if initial_models else "gpt-3.5-turbo",
                label="Model",
                info="Select the model to use",
            )

        # Update model list when provider or API key changes
        provider.change(
            update_model_list,
            inputs=[provider, api_key],
            outputs=model,
        )
        api_key.change(
            update_model_list,
            inputs=[provider, api_key],
            outputs=model,
        )

    with gr.Accordion("Chat Settings", open=False, elem_classes="additional-inputs"):
        system_message = gr.Textbox(
            value="You are a friendly and helpful AI assistant.",
            label="System Message",
            info="Set the AI's personality and behavior",
        )
        with gr.Row():
            with gr.Column():
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher values make responses more creative but less focused",
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (nucleus sampling)",
                    info="Controls response diversity",
                )
            with gr.Column():
                max_tokens = gr.Slider(
                    minimum=1,
                    maximum=32767,
                    value=512,
                    step=1,
                    label="Max Tokens",
                    info="Maximum length of the response",
                )

    # Set up chat functionality
    msg_submit_trigger = msg.submit(
        respond,
        [msg, chatbot, system_message, max_tokens, temperature, top_p, provider, model, api_key],
        [chatbot],
        api_name="chat",
    )
    submit_click_trigger = submit.click(
        respond,
        [msg, chatbot, system_message, max_tokens, temperature, top_p, provider, model, api_key],
        [chatbot],
    )
    clear.click(clear_click, None, chatbot, queue=False)

    # Clear input after sending
    msg_submit_trigger.then(clear_input, None, msg)
    submit_click_trigger.then(clear_input, None, msg)

if __name__ == "__main__":
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_api=False,
        allowed_paths=["static"],
    )