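"""Gradio chat app that streams responses from LiteLLM-supported providers.

Users choose a provider, model, and API key in the browser UI; replies are
streamed back into the chat window. Assumed dependencies: `gradio` and
`litellm` (e.g. ``pip install gradio litellm``).
"""
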
from __future__ import annotations

import os
from typing import Generator

import gradio as gr
from litellm import completion
from litellm.utils import get_valid_models

# Create static directory if it doesn't exist
os.makedirs("static", exist_ok=True)

def get_available_models(
    provider: str,
    api_key: str | None = None,
) -> list[str]:
    """Get available models from LiteLLM for the specified provider."""
    try:
        if api_key:
            os.environ[f"{provider.upper()}_API_KEY"] = api_key
        # get_valid_models() lists models for providers whose API keys are set
        valid_models = get_valid_models()
        provider_models = [
            model.split("/")[-1] if "/" in model else model
            for model in valid_models
            if model.startswith(f"{provider}/") or model.startswith(provider)
        ]
        return provider_models if provider_models else ["gpt-3.5-turbo"]
    except Exception as e:
        print(f"Error fetching models: {e!s}")
        return ["gpt-3.5-turbo"]  # Fallback on error

def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    provider: str,
    model: str,
    api_key: str,
) -> Generator[list[tuple[str, str]], None, None]:
    """Generate chat responses using the specified model and provider."""
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    response = ""

    # Set API key if provided
    if api_key:
        os.environ[f"{provider.upper()}_API_KEY"] = api_key

    try:
        # Construct the full "provider/model" name LiteLLM expects, if needed
        model_name = model if "/" in model else f"{provider}/{model}"
        for chunk in completion(
            model=model_name,
            messages=messages,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            stream=True,
        ):
            token = chunk.choices[0].delta.content
            if token:
                response += token
                # Yield the updated history so the Chatbot component renders the stream
                yield history + [(message, response)]
    except Exception as e:
        yield history + [(message, f"Error: {e!s}")]

def update_model_list(provider: str, api_key: str) -> gr.Dropdown:
    """Update the model dropdown based on provider and API key."""
    models = get_available_models(provider, api_key)
    return gr.Dropdown(choices=models, value=models[0] if models else None)


def clear_click() -> None:
    """Clear the chat history."""
    return None


def clear_input() -> str:
    """Clear the input textbox."""
    return ""

# Get available providers from LiteLLM
valid_models = get_valid_models()
providers = sorted({model.split("/")[0] for model in valid_models if "/" in model})
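# Note: get_valid_models() only reports models for providers whose API keys are
# already present in the environment, so this list may be empty at startup.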
# Create the chat interface with enhanced styling
with gr.Blocks(
    css="static/styles.css",
    title="AI Chat Assistant",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="blue",
        neutral_hue="slate",
        radius_size=gr.themes.sizes.radius_sm,
    ),
) as demo:
    with gr.Column(elem_classes="chat-container"):
        chatbot = gr.Chatbot(
            label="Chat History",
            bubble_full_width=False,
            show_label=False,
            elem_classes=["chat-history"],
            height=500,
        )
        msg = gr.Textbox(
            label="Type your message",
            placeholder="Enter your message here...",
            show_label=False,
            container=False,
            scale=8,
        )
        with gr.Row():
            submit = gr.Button("Send", variant="primary", scale=1)
            clear = gr.Button("Clear", variant="secondary", scale=1)

    with gr.Accordion("Model Settings", open=True, elem_classes="additional-inputs"):
        with gr.Row():
            provider = gr.Dropdown(
                choices=providers,
                value=providers[0] if providers else "openai",
                label="Provider",
                info="Select the AI provider",
            )
            api_key = gr.Textbox(
                value="",
                label="API Key",
                info="Enter your API key",
                type="password",
            )
            model = gr.Dropdown(
                choices=get_available_models(providers[0] if providers else "openai"),
                value="gpt-3.5-turbo",
                label="Model",
                info="Select the model to use",
            )

    # Update model list when provider or API key changes
    provider.change(
        update_model_list,
        inputs=[provider, api_key],
        outputs=model,
    )
    api_key.change(
        update_model_list,
        inputs=[provider, api_key],
        outputs=model,
    )

    with gr.Accordion("Chat Settings", open=False, elem_classes="additional-inputs"):
        system_message = gr.Textbox(
            value="You are a friendly and helpful AI assistant.",
            label="System Message",
            info="Set the AI's personality and behavior",
        )
        with gr.Row():
            with gr.Column():
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Higher values make responses more creative but less focused",
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (nucleus sampling)",
                    info="Controls response diversity",
                )
            with gr.Column():
                max_tokens = gr.Slider(
                    minimum=1,
                    maximum=327670,
                    value=512,
                    step=1,
                    label="Max Tokens",
                    info="Maximum length of the response",
                )

    # Set up chat functionality
    msg_submit_trigger = msg.submit(
        respond,
        [msg, chatbot, system_message, max_tokens, temperature, top_p, provider, model, api_key],
        [chatbot],
        api_name="chat",
    )
    submit_click_trigger = submit.click(
        respond,
        [msg, chatbot, system_message, max_tokens, temperature, top_p, provider, model, api_key],
        [chatbot],
        api_name="chat",
    )
    clear.click(clear_click, None, chatbot, queue=False)

    # Clear input after sending
    msg_submit_trigger.then(clear_input, None, msg)
    submit_click_trigger.then(clear_input, None, msg)

if __name__ == "__main__":
    demo.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_api=False,
        allowed_paths=["static"],
    )