import datetime

from openai import OpenAI
import gradio as gr

from utils import COMMUNITY_POSTFIX_URL, get_model_config, log_message, check_format, models_config
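
# NOTE: the helpers imported from utils are not shown in this file. From the way
# they are used below, each entry in models_config is assumed to look roughly
# like this hypothetical sketch -- the keys are the ones this module reads:
#
#   models_config = {
#       "Apriel-Nemotron-15b-Thinker": {
#           "MODEL_NAME": "...",        # model id passed to the vLLM endpoint
#           "VLLM_API_URL": "...",      # OpenAI-compatible base URL
#           "AUTH_TOKEN": "...",        # API key for that endpoint
#           "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/...",
#           "REASONING": True,          # model emits [BEGIN FINAL RESPONSE] markers
#       },
#   }
#
# get_model_config(name) is assumed to return one such dict, log_message() to be
# a logging wrapper, and check_format() to validate Gradio's "messages" format.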
print(f"Gradio version: {gr.__version__}") | |
DEFAULT_MODEL_NAME = "Apriel-Nemotron-15b-Thinker" | |
chat_start_count = 0 | |
model_config = None | |
client = None | |


def setup_model(model_name, initial=False):
    global model_config, client
    model_config = get_model_config(model_name)
    log_message(f"setup_model() --> Model config: {model_config}")
    client = OpenAI(
        api_key=model_config.get('AUTH_TOKEN'),
        base_url=model_config.get('VLLM_API_URL')
    )

    _model_hf_name = model_config.get("MODEL_HF_URL").split('https://huggingface.co/')[1]
    _link = f"<a href='{model_config.get('MODEL_HF_URL')}{COMMUNITY_POSTFIX_URL}' target='_blank'>{_model_hf_name}</a>"
    _description = f"Please use the community section on this space to provide feedback! {_link}"
    print(f"Switched to model {_model_hf_name}")

    # On the initial call there is no description component to update yet
    if initial:
        return
    return _description
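
# Example usage (a sketch; assumes DEFAULT_MODEL_NAME exists in models_config):
#   setup_model(DEFAULT_MODEL_NAME, initial=True)   # only configures the client
#   desc = setup_model(DEFAULT_MODEL_NAME)          # also returns the HTML description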


def chat_fn(message, history):
    log_message(f"{'-' * 80}")
    log_message(f"chat_fn() --> Message: {message}")
    log_message(f"chat_fn() --> History: {history}")

    global chat_start_count
    chat_start_count += 1
    # A completed reasoning turn adds three messages to the history
    # (user prompt, thought, final answer), hence the division by 3
    print(
        f"{datetime.datetime.now()}: chat_start_count: {chat_start_count}, "
        f"turns: {len(history or []) // 3}")
is_reasoning = model_config.get("REASONING") | |
# Remove any assistant messages with metadata from history for multiple turns | |
log_message(f"Original History: {history}") | |
check_format(history, "messages") | |
history = [item for item in history if | |
not (isinstance(item, dict) and | |
item.get("role") == "assistant" and | |
isinstance(item.get("metadata"), dict) and | |
item.get("metadata", {}).get("title") is not None)] | |
log_message(f"Updated History: {history}") | |
check_format(history, "messages") | |
history.append({"role": "user", "content": message}) | |
log_message(f"History with user message: {history}") | |
check_format(history, "messages") | |

    # Create the streaming response
    try:
        stream = client.chat.completions.create(
            model=model_config.get('MODEL_NAME'),
            messages=history,
            temperature=0.8,
            stream=True
        )
    except Exception as e:
        print(f"Error: {e}")
        yield gr.ChatMessage(
            role="assistant",
            content="😔 The model is unavailable at the moment. Please try again later.",
        )
        return

    if is_reasoning:
        # Placeholder that the loop below keeps overwriting with partial thoughts
        history.append(gr.ChatMessage(
            role="assistant",
            content="Thinking...",
            metadata={"title": "🧠 Thought"}
        ))
        log_message(f"History added thinking: {history}")
    else:
        # Placeholder for the streamed answer; without it the history[-1]
        # assignment in the loop below would overwrite the user message
        history.append(gr.ChatMessage(role="assistant", content=""))
    check_format(history, "messages")
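
    # Reasoning models are expected to stream text of the form
    #   <chain of thought>[BEGIN FINAL RESPONSE]<answer>[END FINAL RESPONSE]<|end|>
    # The loop below splits on the BEGIN marker: everything before it is shown
    # as a collapsible "🧠 Thought" message, everything after it as the answer.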
output = "" | |
completion_started = False | |
for chunk in stream: | |
# Extract the new content from the delta field | |
content = getattr(chunk.choices[0].delta, "content", "") | |
output += content | |
if is_reasoning: | |
parts = output.split("[BEGIN FINAL RESPONSE]") | |
if len(parts) > 1: | |
if parts[1].endswith("[END FINAL RESPONSE]"): | |
parts[1] = parts[1].replace("[END FINAL RESPONSE]", "") | |
if parts[1].endswith("[END FINAL RESPONSE]\n<|end|>"): | |
parts[1] = parts[1].replace("[END FINAL RESPONSE]\n<|end|>", "") | |
if parts[1].endswith("<|end|>"): | |
parts[1] = parts[1].replace("<|end|>", "") | |
history[-1 if not completion_started else -2] = gr.ChatMessage( | |
role="assistant", | |
content=parts[0], | |
metadata={"title": "🧠 Thought"} | |
) | |
if completion_started: | |
history[-1] = gr.ChatMessage( | |
role="assistant", | |
content=parts[1] | |
) | |
elif len(parts) > 1 and not completion_started: | |
completion_started = True | |
history.append(gr.ChatMessage( | |
role="assistant", | |
content=parts[1] | |
)) | |
else: | |
if output.endswith("<|end|>"): | |
output = output.replace("<|end|>", "") | |
history[-1] = gr.ChatMessage( | |
role="assistant", | |
content=output | |
) | |
# only yield the most recent assistant messages | |
messages_to_yield = history[-1:] if not completion_started else history[-2:] | |
# check_format(messages_to_yield, "messages") | |
# log_message(f"Yielding messages: {messages_to_yield}") | |
yield messages_to_yield | |
log_message(f"Final History: {history}") | |
check_format(history, "messages") | |


title = None
description = None

with gr.Blocks(theme=gr.themes.Default(primary_hue="green")) as demo:
    gr.HTML("""
    <style>
        .model-message {
            text-align: end;
        }

        .model-dropdown-container {
            display: flex;
            align-items: center;
            gap: 10px;
            padding: 0;
        }

        .chatbot {
            max-height: 1400px;
        }

        @media (max-width: 800px) {
            .responsive-row {
                flex-direction: column;
            }

            .model-message {
                text-align: start;
            }

            .model-dropdown-container {
                flex-direction: column;
                align-items: flex-start;
            }

            .chatbot {
                max-height: 900px;
            }
        }
    </style>
    """)
with gr.Row(variant="panel", elem_classes="responsive-row"): | |
with gr.Column(scale=1, min_width=400, elem_classes="model-dropdown-container"): | |
model_dropdown = gr.Dropdown( | |
choices=[f"Model: {model}" for model in models_config.keys()], | |
value=f"Model: {DEFAULT_MODEL_NAME}", | |
label=None, | |
interactive=True, | |
container=False, | |
scale=0, | |
min_width=400 | |
) | |
with gr.Column(scale=4, min_width=0): | |
description_html = gr.HTML(description, elem_classes="model-message") | |
chatbot = gr.Chatbot( | |
type="messages", | |
height="calc(100dvh - 280px)", | |
elem_classes="chatbot", | |
) | |

    chat_interface = gr.ChatInterface(
        chat_fn,
        description="",
        type="messages",
        chatbot=chatbot,
        fill_height=True,
    )
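
    # ChatInterface runs chat_fn as a generator: each yielded list of
    # gr.ChatMessage objects replaces the in-progress assistant response in
    # the UI, which is what produces the live streaming effect.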

    # Reset the model to the default (and refresh the description) on page load/reload
    demo.load(lambda: setup_model(DEFAULT_MODEL_NAME, initial=False), [], [description_html])

    def update_model_and_clear(model_name):
        # Remove the "Model: " prefix to get the actual model name
        actual_model_name = model_name.replace("Model: ", "")
        desc = setup_model(actual_model_name)
        chatbot.clear()  # Critical: drop the old conversation when switching models
        return desc

    model_dropdown.change(
        fn=update_model_and_clear,
        inputs=[model_dropdown],
        outputs=[description_html]
    )
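
# ssr_mode=False opts out of Gradio's server-side rendering; presumably this is
# done here to avoid SSR-related rendering issues on Hugging Face Spaces.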
demo.launch(ssr_mode=False)