# Apriel-Chat / app.py
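# Gradio chat UI for the Apriel models, served through an OpenAI-compatible vLLM endpoint.
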
import datetime
from openai import OpenAI
import gradio as gr
from theme import apriel
from utils import COMMUNITY_POSTFIX_URL, get_model_config, log_message, check_format, models_config
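
# Expected shape of the config helpers (assumed from how they are used below):
#   models_config: dict keyed by model display name
#   get_model_config(name) -> dict with MODEL_NAME, VLLM_API_URL, AUTH_TOKEN,
#                             MODEL_HF_URL, and a REASONING flag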
MODEL_TEMPERATURE = 0.8
BUTTON_WIDTH = 160
DEFAULT_MODEL_NAME = "Apriel-Nemotron-15b-Thinker"
# DEFAULT_MODEL_NAME = "Apriel-5b"
print(f"Gradio version: {gr.__version__}")
chat_start_count = 0
model_config = {}
openai_client = None
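

# Point the module-level OpenAI client at the selected model's vLLM endpoint and
# return the feedback-link blurb shown next to the model dropdown (None on the initial call).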
def setup_model(model_name, initial=False):
    global model_config, openai_client
    model_config = get_model_config(model_name)
    log_message(f"setup_model() --> Model config: {model_config}")
openai_client = OpenAI(
api_key=model_config.get('AUTH_TOKEN'),
base_url=model_config.get('VLLM_API_URL')
)
_model_hf_name = model_config.get("MODEL_HF_URL").split('https://huggingface.co/')[1]
_link = f"<a href='{model_config.get('MODEL_HF_URL')}{COMMUNITY_POSTFIX_URL}' target='_blank'>{_model_hf_name}</a>"
_description = f"We'd love to hear your thoughts on the model. Click here to provide feedback - {_link}"
print(f"Switched to model {_model_hf_name}")
    return None if initial else _description
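

# Streaming chat handler for the gr.Chatbot (type="messages"): yields the updated
# history after each chunk so tokens render as they arrive.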
def chat_fn(message, history):
log_message(f"{'-' * 80}")
log_message(f"chat_fn() --> Message: {message}")
log_message(f"chat_fn() --> History: {history}")
# Check if the message is empty
if not message.strip():
gr.Warning("Please enter a message before sending.")
yield history
return
global chat_start_count
    chat_start_count += 1
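    # The turn estimate divides by 3 on the assumption that each completed turn
    # contributes three messages: user, thought, and final answer.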
print(
f"{datetime.datetime.now()}: chat_start_count: {chat_start_count}, turns: {int(len(history if history else []) / 3)}")
is_reasoning = model_config.get("REASONING")
    # For multi-turn requests, assistant "thought" messages (those carrying metadata)
    # are stripped from the history sent to the API (see history_no_thoughts below).
log_message(f"Initial History: {history}")
check_format(history, "messages")
history.append({"role": "user", "content": message})
log_message(f"History with user message: {history}")
check_format(history, "messages")
# Create the streaming response
try:
history_no_thoughts = [item for item in history if
not (isinstance(item, dict) and
item.get("role") == "assistant" and
isinstance(item.get("metadata"), dict) and
item.get("metadata", {}).get("title") is not None)]
log_message(f"Updated History: {history_no_thoughts}")
check_format(history_no_thoughts, "messages")
log_message(f"history_no_thoughts with user message: {history_no_thoughts}")
stream = openai_client.chat.completions.create(
model=model_config.get('MODEL_NAME'),
messages=history_no_thoughts,
temperature=MODEL_TEMPERATURE,
stream=True
)
    except Exception as e:
        print(f"Error: {e}")
        # Keep the conversation intact and surface the error as an extra assistant message.
        yield history + [{"role": "assistant",
                          "content": "😔 The model is unavailable at the moment. Please try again later."}]
        return
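
    # Seed the assistant's reply: reasoning models get a collapsible "Thought" bubble
    # (via the metadata title); other models start from an empty assistant message.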
if is_reasoning:
history.append(gr.ChatMessage(
role="assistant",
content="Thinking...",
metadata={"title": "🧠 Thought"}
))
log_message(f"History added thinking: {history}")
check_format(history, "messages")
else:
history.append(gr.ChatMessage(
role="assistant",
content="",
))
log_message(f"History added empty assistant: {history}")
check_format(history, "messages")
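
    # Accumulate streamed tokens. Reasoning models wrap the answer in
    # [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE] sentinels, parsed below.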
output = ""
completion_started = False
    for chunk in stream:
        # Extract the new content from the delta field; it can be None on the
        # final chunk, so coalesce to "" before appending.
        content = chunk.choices[0].delta.content or ""
        output += content
        if is_reasoning:
            # parts[0] holds the chain of thought; parts[1], once the sentinel
            # has arrived, holds the final answer.
            parts = output.split("[BEGIN FINAL RESPONSE]")
            if len(parts) > 1:
                # Trim the closing sentinel and any trailing end-of-text token.
                for marker in ("[END FINAL RESPONSE]\n<|end|>", "[END FINAL RESPONSE]", "<|end|>"):
                    if parts[1].endswith(marker):
                        parts[1] = parts[1][:-len(marker)]
                        break
            # The thought bubble sits one slot earlier once the answer message exists.
            history[-1 if not completion_started else -2] = gr.ChatMessage(
                role="assistant",
                content=parts[0],
                metadata={"title": "🧠 Thought"}
            )
            if completion_started:
                # Keep streaming the final answer into its own message.
                history[-1] = gr.ChatMessage(
                    role="assistant",
                    content=parts[1]
                )
            elif len(parts) > 1:
                # First sight of the final answer: open a separate assistant message for it.
                completion_started = True
                history.append(gr.ChatMessage(
                    role="assistant",
                    content=parts[1]
                ))
else:
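            # Non-reasoning models: strip the end-of-text token and stream the answer directly.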
if output.endswith("<|end|>"):
output = output.replace("<|end|>", "")
history[-1] = gr.ChatMessage(
role="assistant",
content=output
)
# log_message(f"Yielding messages: {history}")
yield history
log_message(f"Final History: {history}")
check_format(history, "messages")
title = None
description = None
# theme = gr.themes.Default(primary_hue="green")
# theme = gr.themes.Soft(primary_hue="gray", secondary_hue="slate", neutral_hue="slate",
# text_size=gr.themes.sizes.text_lg, font=[gr.themes.GoogleFont("Inconsolata"), "Arial", "sans-serif"])
# theme = gr.Theme.from_hub("earneleh/paris")
theme = apriel
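
# Layout: model selector and feedback link on top, the chat pane, then a text box
# with Send / New Chat buttons underneath.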
with gr.Blocks(theme=theme) as demo:
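    # Inline CSS: responsive tweaks for narrow screens, plus fixed-width action buttons.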
gr.HTML("""
<style>
.html-container:has(.css-styles) {
padding: 0;
margin: 0;
}
.css-styles { height: 0; }
.model-message {
text-align: end;
}
.model-dropdown-container {
display: flex;
align-items: center;
gap: 10px;
padding: 0;
}
.chatbot {
max-height: 1400px;
}
@media (max-width: 800px) {
.responsive-row {
flex-direction: column;
}
.model-message {
text-align: start;
font-size: 10px !important;
}
.model-dropdown-container {
flex-direction: column;
align-items: flex-start;
}
.chatbot {
max-height: 850px;
}
}
@media (max-width: 400px) {
.responsive-row {
flex-direction: column;
}
.model-message {
text-align: start;
font-size: 10px !important;
}
.model-dropdown-container {
flex-direction: column;
align-items: flex-start;
}
.chatbot {
max-height: 400px;
}
}
""" + f"""
@media (min-width: 1024px) {{
.send-button-container, .clear-button-container {{
max-width: {BUTTON_WIDTH}px;
}}
}}
</style>
""", elem_classes="css-styles")
with gr.Row(variant="panel", elem_classes="responsive-row"):
with gr.Column(scale=1, min_width=400, elem_classes="model-dropdown-container"):
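            # Model selector; options come from models_config (utils.py).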
model_dropdown = gr.Dropdown(
choices=[f"Model: {model}" for model in models_config.keys()],
value=f"Model: {DEFAULT_MODEL_NAME}",
label=None,
interactive=True,
container=False,
scale=0,
min_width=400
)
with gr.Column(scale=4, min_width=0):
description_html = gr.HTML(description, elem_classes="model-message")
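
    # Chat pane; the viewport-relative height keeps the input row on screen.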
chatbot = gr.Chatbot(
type="messages",
height="calc(100dvh - 280px)",
elem_classes="chatbot",
)
# chat_interface = gr.ChatInterface(
# chat_fn,
# description="",
# type="messages",
# chatbot=chatbot,
# fill_height=True,
# )
with gr.Row():
with gr.Column(scale=10, min_width=400, elem_classes="user-input-container"):
user_input = gr.Textbox(
show_label=False,
placeholder="Type your message here and press Enter",
container=False,
)
with gr.Column(scale=1, min_width=BUTTON_WIDTH * 2 + 20):
with gr.Row():
with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="send-button-container"):
send_btn = gr.Button("Send", variant="primary")
with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="clear-button-container"):
clear_btn = gr.ClearButton(chatbot, value="New Chat", variant="secondary")
# on Enter: stream into the chatbot, then clear the textbox
user_input.submit(
fn=chat_fn,
inputs=[user_input, chatbot],
outputs=[chatbot]
)
user_input.submit(lambda: "", None, user_input, queue=False)
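    # The Send button mirrors the Enter-to-submit behavior.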
send_btn.click(
fn=chat_fn,
inputs=[user_input, chatbot],
outputs=[chatbot]
)
send_btn.click(lambda: "", None, user_input, queue=False)
# Ensure the model is reset to default on page reload
    demo.load(lambda: setup_model(DEFAULT_MODEL_NAME, initial=False), [], [description_html])
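
    # Switching models swaps the endpoint via setup_model() and clears the conversation.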
def update_model_and_clear(model_name):
actual_model_name = model_name.replace("Model: ", "")
desc = setup_model(actual_model_name)
return desc, []
model_dropdown.change(
fn=update_model_and_clear,
inputs=[model_dropdown],
outputs=[description_html, chatbot]
)
demo.launch(ssr_mode=False, show_api=False)