Spaces:
Running
Running
import os | |
import sys | |
from typing import Any, Literal | |
from gradio import ChatMessage | |
from gradio.components.chatbot import Message | |
# Path suffix for a discussions page; presumably appended to a model's
# MODEL_HF_URL by code outside this section -- confirm against callers.
COMMUNITY_POSTFIX_URL = "/discussions"

# Debug switch: on only when the DEBUG_MODE environment variable is exactly
# the string "True" (any other value, or an unset variable, leaves it off).
DEBUG_MODE = os.environ.get("DEBUG_MODE") == "True"

# Per-model settings keyed by the model identifier.
# MODEL_NAME, VLLM_API_URL and AUTH_TOKEN are read from the environment at
# import time and are None when the corresponding variable is unset.
# Both entries share the same AUTH_TOKEN environment variable.
models_config = {
    "Apriel-Nemotron-15b-Thinker": {
        "MODEL_DISPLAY_NAME": "Apriel-Nemotron-15b-Thinker",
        "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-Nemotron-15b-Thinker",
        "MODEL_NAME": os.environ.get("MODEL_NAME_NEMO_15B"),
        "VLLM_API_URL": os.environ.get("VLLM_API_URL_NEMO_15B"),
        "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
        "REASONING": True,
    },
    "Apriel-5b": {
        "MODEL_DISPLAY_NAME": "Apriel-5b",
        "MODEL_HF_URL": "https://huggingface.co/ServiceNow-AI/Apriel-5B-Instruct",
        "MODEL_NAME": os.environ.get("MODEL_NAME_5B"),
        "VLLM_API_URL": os.environ.get("VLLM_API_URL_5B"),
        "AUTH_TOKEN": os.environ.get("AUTH_TOKEN"),
        "REASONING": False,
    },
}
def get_model_config(model_name: str) -> dict:
    """Return the ``models_config`` entry for *model_name* after validating it.

    Args:
        model_name: Key to look up in the module-level ``models_config``.

    Returns:
        The stored configuration dict itself (not a copy).

    Raises:
        ValueError: If there is no entry (or only an empty one) for
            *model_name*, or if the entry has no truthy ``MODEL_NAME`` or
            ``VLLM_API_URL`` value.
    """
    selected = models_config.get(model_name)
    if not selected:
        raise ValueError(f"Model {model_name} not found in models_config")

    # Required fields, checked in this order; the first one missing wins.
    missing_field_errors = {
        "MODEL_NAME": f"Model name not found in config for {model_name}",
        "VLLM_API_URL": f"VLLM API URL not found in config for {model_name}",
    }
    for field, error_text in missing_field_errors.items():
        if not selected.get(field):
            raise ValueError(error_text)

    return selected
def log_message(message):
    """Print *message* to stdout behind a ``≫≫≫`` marker when debugging.

    Nothing is printed unless the module-level ``DEBUG_MODE`` is exactly
    ``True``.
    """
    if DEBUG_MODE is not True:
        return
    print(f"≫≫≫ {message}")
# Gradio 5.0.1 had issues with checking the message formats. 5.29.0 does not!
def check_format(messages: Any, type: Literal["messages", "tuples"] = "messages") -> None:
    """Validate the shape of chat history; a no-op unless ``DEBUG_MODE`` is on.

    Args:
        messages: Iterable of chat entries to validate.
        type: ``"messages"`` expects each entry to be a dict containing both
            ``"role"`` and ``"content"``, or a ``ChatMessage`` / ``Message``
            instance. Any other value is handled as the tuples format, where
            each entry must be a tuple or list of length 2.

    Raises:
        Exception: If any entry does not match the selected format. For the
            messages format, the first offending entry is printed to stderr
            before raising.
    """
    if not DEBUG_MODE:
        return

    if type != "messages":
        pairs_ok = all(
            isinstance(entry, (tuple, list)) and len(entry) == 2
            for entry in messages
        )
        if not pairs_ok:
            raise Exception(
                "Data incompatible with tuples format. Each message should be a list of length 2."
            )
        return

    def is_well_formed(candidate: Any) -> bool:
        # A dict carrying both keys is accepted without consulting the
        # ChatMessage / Message classes at all.
        if isinstance(candidate, dict) and "role" in candidate and "content" in candidate:
            return True
        return isinstance(candidate, ChatMessage | Message)

    if all(is_well_formed(entry) for entry in messages):
        return

    # Report only the first offending entry, then fail.
    for i, message in enumerate(messages):
        if not is_well_formed(message):
            print(f"_check_format() --> Invalid message at index {i}: {message}\n", file=sys.stderr)
            break
    raise Exception(
        "Data incompatible with messages format. Each message should be a dictionary with 'role' and 'content' keys or a ChatMessage object."
    )