# chatbot-test / app.py
from __future__ import annotations
import os
from typing import Generator
import gradio as gr
from litellm import completion
from litellm.utils import get_valid_models
# Create static directory if it doesn't exist
os.makedirs("static", exist_ok=True)
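# styles.css is loaded from this directory by gr.Blocks below, and
# allowed_paths in demo.launch() lets Gradio serve files from it.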


def get_available_models(
provider: str,
api_key: str | None = None,
) -> list[str]:
"""Get available models from LiteLLM for the specified provider"""
try:
if api_key:
os.environ[f"{provider.upper()}_API_KEY"] = api_key
try:
# Try to get models from API
models = model_list(provider)
return [model["id"] for model in models]
except Exception:
# Fallback to LiteLLM's valid models for the provider
valid_models = get_valid_models()
provider_models = [
model.split("/")[-1] if "/" in model else model
for model in valid_models
if model.startswith(f"{provider}/") or model.startswith(provider)
]
return provider_models if provider_models else ["gpt-3.5-turbo"]
return ["gpt-3.5-turbo"] # Default fallback
except Exception as e:
print(f"Error fetching models: {e!s}")
return ["gpt-3.5-turbo"] # Fallback on error


def respond(
message: str,
history: list[tuple[str, str]],
system_message: str,
max_tokens: int,
temperature: float,
top_p: float,
provider: str,
model: str,
api_key: str,
) -> Generator[list[tuple[str, str]], None, None]:
    """Stream a chat response via LiteLLM, yielding the updated chat history."""
messages = [{"role": "system", "content": system_message}]
for user_msg, assistant_msg in history:
if user_msg:
messages.append({"role": "user", "content": user_msg})
if assistant_msg:
messages.append({"role": "assistant", "content": assistant_msg})
messages.append({"role": "user", "content": message})
response = ""
# Set API key if provided
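    # LiteLLM reads credentials from environment variables such as
    # OPENAI_API_KEY or ANTHROPIC_API_KEY, so a <PROVIDER>_API_KEY export
    # covers most providers.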
if api_key:
os.environ[f"{provider.upper()}_API_KEY"] = api_key
try:
# Construct full model name if needed
model_name = model if "/" in model else f"{provider}/{model}"
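        # LiteLLM routes the request by the "provider/model" prefix,
        # e.g. "openai/gpt-3.5-turbo".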
for chunk in completion(
model=model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
stream=True,
):
            token = chunk.choices[0].delta.content
            if token:
                response += token
                # Yield the full updated history so the Chatbot output renders the stream
                yield history + [(message, response)]
    except Exception as e:
        yield history + [(message, f"Error: {e!s}")]


def update_model_list(provider: str, api_key: str) -> gr.Dropdown:
"""Update the model dropdown based on provider and API key"""
models = get_available_models(provider, api_key)
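    # Returning a fresh gr.Dropdown acts as a component update in Gradio 4,
    # swapping in the refreshed choices for the existing dropdown.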
return gr.Dropdown(choices=models, value=models[0] if models else None)


def clear_click() -> None:
"""Clear the chat history"""
return None


def clear_input() -> str:
"""Clear the input textbox"""
return ""


# Get available providers from LiteLLM
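# get_valid_models() reports "provider/model" entries for every provider whose
# API key is configured; the dropdowns below fall back to "openai" when none are found.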
valid_models = get_valid_models()
providers = sorted({model.split("/")[0] for model in valid_models if "/" in model})


# Create the chat interface with enhanced styling
with gr.Blocks(
css="static/styles.css",
title="AI Chat Assistant",
theme=gr.themes.Soft(
primary_hue="blue",
secondary_hue="blue",
neutral_hue="slate",
radius_size=gr.themes.sizes.radius_sm,
),
) as demo:
with gr.Column(elem_classes="chat-container"):
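        # With the default tuple format, the Chatbot stores history as
        # (user, assistant) pairs, the same shape respond() consumes and yields.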
chatbot = gr.Chatbot(
label="Chat History",
bubble_full_width=False,
show_label=False,
elem_classes=["chat-history"],
height=500,
)
msg = gr.Textbox(
label="Type your message",
placeholder="Enter your message here...",
show_label=False,
container=False,
scale=8,
)
with gr.Row():
submit = gr.Button("Send", variant="primary", scale=1)
clear = gr.Button("Clear", variant="secondary", scale=1)
with gr.Accordion("Model Settings", open=True, elem_classes="additional-inputs"):
with gr.Row():
provider = gr.Dropdown(
choices=providers,
value=providers[0] if providers else "openai",
label="Provider",
info="Select the AI provider",
)
api_key = gr.Textbox(
value="",
label="API Key",
info="Enter your API key",
type="password",
)
model = gr.Dropdown(
choices=get_available_models(providers[0] if providers else "openai"),
value="gpt-3.5-turbo",
label="Model",
info="Select the model to use",
)
# Update model list when provider or API key changes
provider.change(
update_model_list,
inputs=[provider, api_key],
outputs=model,
)
api_key.change(
update_model_list,
inputs=[provider, api_key],
outputs=model,
)
with gr.Accordion("Chat Settings", open=False, elem_classes="additional-inputs"):
system_message = gr.Textbox(
value="You are a friendly and helpful AI assistant.",
label="System Message",
info="Set the AI's personality and behavior",
)
with gr.Row():
with gr.Column():
temperature = gr.Slider(
minimum=0.1,
maximum=2.0,
value=0.7,
step=0.1,
label="Temperature",
info="Higher values make responses more creative but less focused",
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
info="Controls response diversity",
)
with gr.Column():
max_tokens = gr.Slider(
minimum=1,
                    maximum=32768,
value=512,
step=1,
label="Max Tokens",
info="Maximum length of the response",
)
# Set up chat functionality
msg_submit_trigger = msg.submit(
respond,
[msg, chatbot, system_message, max_tokens, temperature, top_p, provider, model, api_key],
[chatbot],
api_name="chat",
)
submit_click_trigger = submit.click(
respond,
[msg, chatbot, system_message, max_tokens, temperature, top_p, provider, model, api_key],
[chatbot],
api_name="chat",
)
clear.click(clear_click, None, chatbot, queue=False)
# Clear input after sending
msg_submit_trigger.then(clear_input, None, msg)
submit_click_trigger.then(clear_input, None, msg)
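    # .then() fires after the streaming respond() event finishes, so the
    # textbox clears once the reply is complete.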


if __name__ == "__main__":
demo.launch(
share=True,
server_name="0.0.0.0",
server_port=7860,
show_api=False,
allowed_paths=["static"],
)
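
# Assumed local usage: export a provider key and start the app, e.g.
#   OPENAI_API_KEY=sk-... python app.py
# then open http://localhost:7860 in a browser.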