import base64
import copy
import mimetypes
import os
import random
from collections.abc import Mapping
from uuid import uuid4

import gradio as gr
from openai import OpenAI

from log_chat import log_chat
from theme import apriel
from utils import COMMUNITY_POSTFIX_URL, get_model_config, check_format, models_config, \
    logged_event_handler, DEBUG_MODE, DEBUG_MODEL, log_debug, log_info, log_error, log_warning
MODEL_TEMPERATURE = 0.8
BUTTON_WIDTH = 160
DEFAULT_OPT_OUT_VALUE = DEBUG_MODE

# If DEBUG_MODEL is True, use an alternative model (without reasoning) for testing.
# NOTE: both branches currently resolve to the same model; swap the debug branch
# for a lighter model (e.g. "Apriel-5b") when one is available.
DEFAULT_MODEL_NAME = "Apriel-1.5-15B-thinker" if not DEBUG_MODEL else "Apriel-1.5-15B-thinker"
BUTTON_ENABLED = gr.update(interactive=True)
BUTTON_DISABLED = gr.update(interactive=False)
INPUT_ENABLED = gr.update(interactive=True)
INPUT_DISABLED = gr.update(interactive=False)
DROPDOWN_ENABLED = gr.update(interactive=True)
DROPDOWN_DISABLED = gr.update(interactive=False)
SEND_BUTTON_ENABLED = gr.update(interactive=True, visible=True)
SEND_BUTTON_DISABLED = gr.update(interactive=True, visible=False)
STOP_BUTTON_ENABLED = gr.update(interactive=True, visible=True)
STOP_BUTTON_DISABLED = gr.update(interactive=True, visible=False)
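# Handlers return these gr.update() partial updates to toggle widget state without
# re-creating components. Send/Stop are swapped by visibility rather than
# interactivity (while streaming, "Send" is hidden and "Stop" is shown in its
# place), so `interactive=True` on the hidden states is deliberate.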
chat_start_count = 0
model_config = {}
openai_client = None
USE_RANDOM_ENDPOINT = False
endpoint_rotation_count = 0
def app_loaded(state, request: gr.Request):
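    """Initialize per-browser session state and set up the default model on page load."""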
    message_html = setup_model(DEFAULT_MODEL_NAME, initial=False)
    state['session'] = request.session_hash if request else uuid4().hex
    log_debug(f"app_loaded() --> Session: {state['session']}")
    return state, message_html
def update_model_and_clear_chat(model_name):
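    """Switch to the model selected in the dropdown and clear the chat history."""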
    actual_model_name = model_name.replace("Model: ", "")
    desc = setup_model(actual_model_name)
    return desc, []
def setup_model(model_key, initial=False):
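    """Point the global OpenAI client at the endpoint for the selected model.

    The endpoint comes from the model's VLLM_API_URL_LIST: a random entry when
    USE_RANDOM_ENDPOINT is set, otherwise round-robin via endpoint_rotation_count,
    falling back to VLLM_API_URL when no list is configured. Returns the feedback
    description HTML unless called with initial=True.
    """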
    global model_config, openai_client, endpoint_rotation_count
    model_config = get_model_config(model_key)
    log_debug(f"setup_model() --> Model config: {model_config}")

    # Filter out empty entries so a missing/blank VLLM_API_URL_LIST actually
    # falls back to VLLM_API_URL (splitting "" would otherwise yield [""])
    url_list = [u.strip() for u in (model_config.get('VLLM_API_URL_LIST') or "").split(",") if u.strip()]
    if not url_list:
        base_url = model_config.get('VLLM_API_URL')
    elif USE_RANDOM_ENDPOINT:
        base_url = random.choice(url_list)
    else:
        base_url = url_list[endpoint_rotation_count % len(url_list)]
        endpoint_rotation_count += 1

    openai_client = OpenAI(
        api_key=model_config.get('AUTH_TOKEN'),
        base_url=base_url
    )
    model_config['base_url'] = base_url
    log_debug(f"Switched to model {model_key} using endpoint {base_url}")

    _model_hf_name = model_config.get("MODEL_HF_URL").split('https://huggingface.co/')[1]
    _link = f"<a href='{model_config.get('MODEL_HF_URL')}{COMMUNITY_POSTFIX_URL}' target='_blank'>{_model_hf_name}</a>"
    _description = f"We'd love to hear your thoughts on the model. Click here to provide feedback - {_link}"
    if initial:
        return
    return _description
def chat_started():
    # outputs: model_dropdown, user_input, send_btn, stop_btn, clear_btn
    return (DROPDOWN_DISABLED, gr.update(value="", interactive=False),
            SEND_BUTTON_DISABLED, STOP_BUTTON_ENABLED, BUTTON_DISABLED)
def chat_finished():
    # outputs: model_dropdown, user_input, send_btn, stop_btn, clear_btn
    return DROPDOWN_ENABLED, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED
def stop_chat(state):
    state["stop_flag"] = True
    gr.Info("Chat stopped")
    return state
def toggle_opt_out(state, checkbox):
    state["opt_out"] = checkbox
    return state
def run_chat_inference(history, message, state):
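    """Stream a chat completion into the Gradio history.

    Generator used as a Gradio event handler: each yield is the tuple
    (chatbot history, user_input update, send_btn update, stop_btn update,
    clear_btn update, session state). Uploaded image files are converted to
    base64 data-URL `image_url` parts for the OpenAI-compatible API.
    """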
    global chat_start_count
    state["is_streaming"] = True
    state["stop_flag"] = False
    error = None
    model_name = model_config.get('MODEL_NAME')

    # Re-select an endpoint and reinitialize the OpenAI client (random or round-robin; see setup_model)
    setup_model(model_config.get('MODEL_KEY'))
    log_info(f"Using model {model_name} with endpoint {model_config.get('base_url')}")

    if len(history) == 0:
        state["chat_id"] = uuid4().hex

    if openai_client is None:
        log_info("Client UI is stale, letting user know to refresh the page")
        gr.Warning("Client UI is stale, please refresh the page")
        # Yield the final UI state before returning; a generator's return value is discarded
        yield history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
        return
    # files will be the newly added files from the user
    files = []

    # outputs: chatbot, user_input, send_btn, stop_btn, clear_btn, session_state
    log_debug('-' * 80)
    log_debug(f"chat_fn() --> Message: {message}")
    log_debug(f"chat_fn() --> History: {history}")
    # We have multimodal input in this case
    if isinstance(message, Mapping):
        files = message.get("files") or []
        message = message.get("text") or ""
        log_debug(f"chat_fn() --> Message (text only): {message}")
        log_debug(f"chat_fn() --> Files: {files}")
    # Validate that any uploaded files are images
    if len(files) > 0:
        invalid_files = []
        for path in files:
            try:
                mime, _ = mimetypes.guess_type(path)
                mime = mime or ""
                if not mime.startswith("image/"):
                    invalid_files.append((os.path.basename(path), mime or "unknown"))
            except Exception as e:
                log_error(f"Failed to inspect file '{path}': {e}")
                invalid_files.append((os.path.basename(path), "unknown"))
        if invalid_files:
            msg = "Only image files are allowed. Invalid uploads: " + \
                  ", ".join(f"{p} (type: {m})" for p, m in invalid_files)
            log_warning(msg)
            gr.Warning(msg)
            yield history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
            return
    try:
        # Check if the message is empty
        if not message.strip() and len(files) == 0:
            gr.Info("Please enter a message before sending")
            yield history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
            return

        chat_start_count += 1
        user_messages_count = sum(1 for item in history if isinstance(item, dict) and item.get("role") == "user")
        log_info(f"chat_start_count: {chat_start_count}, turns: {user_messages_count}, model: {model_name}")
        is_reasoning = model_config.get("REASONING")
        # Remove any assistant messages with metadata from history for multiple turns
        log_debug(f"Initial History: {history}")
        check_format(history, "messages")

        # Build UI history: add text (if any) and per-file image placeholders {"path": ...}.
        # API parts are built separately later to avoid Gradio issues with arrays in content.
        if len(files) == 0:
            history.append({"role": "user", "content": message})
        else:
            if message.strip():
                history.append({"role": "user", "content": message})
            for path in files:
                history.append({"role": "user", "content": {"path": path}})
        log_debug(f"History with user message: {history}")
        check_format(history, "messages")
        # Create the streaming response
        try:
            history_no_thoughts = [item for item in history if
                                   not (isinstance(item, dict) and
                                        item.get("role") == "assistant" and
                                        isinstance(item.get("metadata"), dict) and
                                        item.get("metadata", {}).get("title") is not None)]
            log_debug(f"Updated History: {history_no_thoughts}")
            check_format(history_no_thoughts, "messages")
            log_debug(f"history_no_thoughts with user message: {history_no_thoughts}")
            # Build API-specific messages:
            # - Convert any UI image placeholders {"path": ...} to image_url parts
            # - Convert any user string content that is a valid file path to image_url parts
            # - Coalesce consecutive image paths into a single image-only user message
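            # Illustrative example of the transformation below (hypothetical paths):
            #   UI history:   {"role": "user", "content": "describe these"},
            #                 {"role": "user", "content": {"path": "/tmp/a.png"}},
            #                 {"role": "user", "content": {"path": "/tmp/b.png"}}
            #   API messages: {"role": "user", "content": "describe these"},
            #                 {"role": "user", "content": [
            #                     {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
            #                     {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]}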
            api_messages = []
            image_parts_buffer = []
            def flush_image_buffer():
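                """Emit any buffered image parts as a single image-only user message."""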
                if len(image_parts_buffer) > 0:
                    api_messages.append({"role": "user", "content": list(image_parts_buffer)})
                    image_parts_buffer.clear()
            def to_image_part(path: str):
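                """Read an image file and return an OpenAI-style image_url content part
                (a base64 data URL), or None if the file cannot be read, e.g.
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}.
                """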
                try:
                    mime, _ = mimetypes.guess_type(path)
                    mime = mime or "application/octet-stream"
                    with open(path, "rb") as f:
                        b64 = base64.b64encode(f.read()).decode("utf-8")
                    data_url = f"data:{mime};base64,{b64}"
                    return {"type": "image_url", "image_url": {"url": data_url}}
                except Exception as e:
                    log_error(f"Failed to load file '{path}': {e}")
                    return None
            def normalize_msg(msg):
                """Return (role, content, as_dict), where as_dict is a message dict
                suitable to pass through when unmodified."""
                if isinstance(msg, dict):
                    return msg.get("role"), msg.get("content"), msg
                # Gradio ChatMessage-like object
                role = getattr(msg, "role", None)
                content = getattr(msg, "content", None)
                if role is not None:
                    return role, content, {"role": role, "content": content}
                return None, None, msg
            for m in copy.deepcopy(history_no_thoughts):
                role, content, as_dict = normalize_msg(m)

                # Unknown structure: pass through
                if role is None:
                    flush_image_buffer()
                    api_messages.append(as_dict)
                    continue

                # Assistant messages pass through as-is
                if role == "assistant":
                    flush_image_buffer()
                    api_messages.append(as_dict)
                    continue

                # Only user messages have potential image paths to convert
                if role == "user":
                    # Case A: {'path': ...}
                    if isinstance(content, dict) and isinstance(content.get("path"), str):
                        p = content["path"]
                        part = to_image_part(p) if os.path.isfile(p) else None
                        if part:
                            image_parts_buffer.append(part)
                        else:
                            flush_image_buffer()
                            api_messages.append({"role": "user", "content": str(content)})
                        continue

                    # Case B: string or tuple content that may be a file path
                    if isinstance(content, str):
                        if os.path.isfile(content):
                            part = to_image_part(content)
                            if part:
                                image_parts_buffer.append(part)
                                continue
                        # Not a file path: pass through as text
                        flush_image_buffer()
                        api_messages.append({"role": "user", "content": content})
                        continue

                    if isinstance(content, tuple):
                        # Common case: a single-element tuple containing a path string
                        tmp_parts = []
                        text_accum = []
                        for item in content:
                            if isinstance(item, str) and os.path.isfile(item):
                                part = to_image_part(item)
                                if part:
                                    tmp_parts.append(part)
                                else:
                                    text_accum.append(item)
                            else:
                                text_accum.append(str(item))
                        if tmp_parts:
                            flush_image_buffer()
                            api_messages.append({"role": "user", "content": tmp_parts})
                        if text_accum:
                            flush_image_buffer()
                            api_messages.append({"role": "user", "content": "\n".join(text_accum)})
                        continue

                    # Case C: list content
                    if isinstance(content, list):
                        # If it's already a list of parts, let it pass through
                        if all(isinstance(c, dict) for c in content):
                            flush_image_buffer()
                            api_messages.append({"role": "user", "content": content})
                            continue
                        # It might be a list of strings (paths/text). Convert string paths
                        # to image parts and accumulate everything else as text.
                        tmp_parts = []
                        text_accum = []
                        for item in content:
                            if isinstance(item, str) and os.path.isfile(item):
                                part = to_image_part(item)
                                if part:
                                    tmp_parts.append(part)
                                else:
                                    text_accum.append(item)
                            else:
                                text_accum.append(str(item))
                        if tmp_parts:
                            flush_image_buffer()
                            api_messages.append({"role": "user", "content": tmp_parts})
                        if text_accum:
                            api_messages.append({"role": "user", "content": "\n".join(text_accum)})
                        continue

                    # Fallback: pass through
                    flush_image_buffer()
                    api_messages.append(as_dict)
                    continue

                # Other roles
                flush_image_buffer()
                api_messages.append(as_dict)

            # Flush any trailing images
            flush_image_buffer()
| log_debug(f"sending api_messages to model {model_name}: {api_messages}") | |
| stream = openai_client.chat.completions.create( | |
| model=model_name, | |
| messages=api_messages, | |
| temperature=MODEL_TEMPERATURE, | |
| stream=True | |
| ) | |
        except Exception as e:
            log_error(f"Error:\n\t{e}\n\tInference failed for model {model_name} and endpoint {model_config['base_url']}")
            error = str(e)
            yield ([{"role": "assistant",
                     "content": "😔 The model is unavailable at the moment. Please try again later."}],
                   INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state)
            if state["opt_out"] is not True:
                log_chat(chat_id=state["chat_id"],
                         session_id=state["session"],
                         model_name=model_name,
                         prompt=message,
                         history=history,
                         info={"is_reasoning": model_config.get("REASONING"), "temperature": MODEL_TEMPERATURE,
                               "stopped": True, "error": str(e)},
                         )
            else:
                log_info(f"User opted out of chat history. Not logging chat. model: {model_name}")
            return
        if is_reasoning:
            history.append(gr.ChatMessage(
                role="assistant",
                content="Thinking...",
                metadata={"title": "🧠 Thought"}
            ))
            log_debug(f"History added thinking: {history}")
            check_format(history, "messages")
        else:
            history.append(gr.ChatMessage(
                role="assistant",
                content="",
            ))
            log_debug(f"History added empty assistant: {history}")
            check_format(history, "messages")
| output = "" | |
| completion_started = False | |
| for chunk in stream: | |
| if state["stop_flag"]: | |
| log_debug(f"chat_fn() --> Stopping streaming...") | |
| break # Exit the loop if the stop flag is set | |
| # Extract the new content from the delta field | |
| content = getattr(chunk.choices[0].delta, "content", "") or "" | |
| reasoning_content = getattr(chunk.choices[0].delta, "reasoning_content", "") or "" | |
| output += reasoning_content + content | |
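            # The reasoning model streams its chain of thought first and is expected
            # to wrap the final answer in [BEGIN FINAL RESPONSE] ... [END FINAL RESPONSE]
            # markers, optionally followed by an <|end|> token. The thought goes into a
            # collapsible "🧠 Thought" bubble; the answer becomes a separate assistant message.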
            if is_reasoning:
                parts = output.split("[BEGIN FINAL RESPONSE]")
                if len(parts) > 1:
                    for marker in ("[END FINAL RESPONSE]", "[END FINAL RESPONSE]\n<|end|>",
                                   "[END FINAL RESPONSE]\n<|end|>\n", "<|end|>", "<|end|>\n"):
                        if parts[1].endswith(marker):
                            parts[1] = parts[1].replace(marker, "")
                history[-1 if not completion_started else -2] = gr.ChatMessage(
                    role="assistant",
                    content=parts[0],
                    metadata={"title": "🧠 Thought"}
                )
                if completion_started:
                    history[-1] = gr.ChatMessage(
                        role="assistant",
                        content=parts[1]
                    )
                elif len(parts) > 1 and not completion_started:
                    completion_started = True
                    history.append(gr.ChatMessage(
                        role="assistant",
                        content=parts[1]
                    ))
            else:
                for marker in ("<|end|>", "<|end|>\n"):
                    if output.endswith(marker):
                        output = output.replace(marker, "")
                history[-1] = gr.ChatMessage(
                    role="assistant",
                    content=output
                )
            # log_message(f"Yielding messages: {history}")
            yield history, INPUT_DISABLED, SEND_BUTTON_DISABLED, STOP_BUTTON_ENABLED, BUTTON_DISABLED, state

        log_debug(f"Final History: {history}")
        check_format(history, "messages")
        yield history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state
    finally:
        if error is None:
            log_debug(f"chat_fn() --> Finished streaming. {chat_start_count} chats started.")
            if state["opt_out"] is not True:
                log_chat(chat_id=state["chat_id"],
                         session_id=state["session"],
                         model_name=model_name,
                         prompt=message,
                         history=history,
                         info={"is_reasoning": model_config.get("REASONING"), "temperature": MODEL_TEMPERATURE,
                               "stopped": state["stop_flag"]},
                         )
            else:
                log_info(f"User opted out of chat history. Not logging chat. model: {model_name}")
        state["is_streaming"] = False
        state["stop_flag"] = False
        return
| log_info(f"Gradio version: {gr.__version__}") | |
| title = None | |
| description = None | |
| theme = apriel | |
| with open('styles.css', 'r') as f: | |
| custom_css = f.read() | |
with gr.Blocks(theme=theme, css=custom_css) as demo:
    session_state = gr.State(value={
        "is_streaming": False,
        "stop_flag": False,
        "chat_id": None,
        "session": None,
        "opt_out": DEFAULT_OPT_OUT_VALUE,
    })  # Store session state as a dictionary

    gr.HTML(f"""
        <style>
            @media (min-width: 1024px) {{
                .send-button-container, .clear-button-container {{
                    max-width: {BUTTON_WIDTH}px;
                }}
            }}
        </style>
    """, elem_classes="css-styles")
    with gr.Row(variant="panel", elem_classes="responsive-row"):
        with gr.Column(scale=1, min_width=400, elem_classes="model-dropdown-container"):
            model_dropdown = gr.Dropdown(
                choices=[f"Model: {model}" for model in models_config.keys()],
                value=f"Model: {DEFAULT_MODEL_NAME}",
                label=None,
                interactive=True,
                container=False,
                scale=0,
                min_width=400
            )
        with gr.Column(scale=4, min_width=0):
            feedback_message_html = gr.HTML(description, elem_classes="model-message")

    chatbot = gr.Chatbot(
        type="messages",
        height="calc(100dvh - 310px)",
        elem_classes="chatbot",
    )
    with gr.Row():
        with gr.Column(scale=10, min_width=400, elem_classes="user-input-container"):
            with gr.Row():
                user_input = gr.MultimodalTextbox(
                    interactive=True,
                    container=False,
                    file_count="multiple",
                    placeholder="Type your message here and press Enter, or upload a file...",
                    show_label=False,
                    sources=["upload"]
                )
                # Original text-only input
                # user_input = gr.Textbox(
                #     show_label=False,
                #     placeholder="Type your message here and press Enter",
                #     container=False
                # )
        with gr.Column(scale=1, min_width=BUTTON_WIDTH * 2 + 20):
            with gr.Row():
                with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="send-button-container"):
                    send_btn = gr.Button("Send", variant="primary", elem_classes="control-button")
                    stop_btn = gr.Button("Stop", variant="stop", elem_classes="control-button", visible=False)
                with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="clear-button-container"):
                    clear_btn = gr.ClearButton(chatbot, value="New Chat", variant="secondary", elem_classes="control-button")
    with gr.Row():
        with gr.Column(min_width=400, elem_classes="opt-out-container"):
            with gr.Row():
                gr.HTML(
                    "We may use your chats to improve our AI. You may opt out if you don’t want your conversations saved.",
                    elem_classes="opt-out-message")
            with gr.Row():
                opt_out_checkbox = gr.Checkbox(
                    label="Don’t save my chat history for improvements or training",
                    value=DEFAULT_OPT_OUT_VALUE,
                    elem_classes="opt-out-checkbox",
                    interactive=True,
                    container=False
                )
    gr.on(
        triggers=[send_btn.click, user_input.submit],
        fn=run_chat_inference,  # this generator streams results; do not use the logged_event_handler wrapper
        inputs=[chatbot, user_input, session_state],
        outputs=[chatbot, user_input, send_btn, stop_btn, clear_btn, session_state],
        concurrency_limit=4,
        api_name=False
    ).then(
        fn=chat_finished, inputs=None, outputs=[model_dropdown, user_input, send_btn, stop_btn, clear_btn], queue=False)

    # In parallel, disable or update the UI controls
    gr.on(
        triggers=[send_btn.click, user_input.submit],
        fn=chat_started,
        inputs=None,
        outputs=[model_dropdown, user_input, send_btn, stop_btn, clear_btn],
        queue=False,
        show_progress='hidden',
        api_name=False
    )
    stop_btn.click(
        fn=stop_chat,
        inputs=[session_state],
        outputs=[session_state],
        api_name=False
    )

    opt_out_checkbox.change(fn=toggle_opt_out, inputs=[session_state, opt_out_checkbox], outputs=[session_state])
    # Ensure the model is reset to default on page reload
    demo.load(
        fn=logged_event_handler(
            log_msg="Browser session started",
            event_handler=app_loaded
        ),
        inputs=[session_state],
        outputs=[session_state, feedback_message_html],
        queue=True,
        api_name=False
    )
    model_dropdown.change(
        fn=update_model_and_clear_chat,
        inputs=[model_dropdown],
        outputs=[feedback_message_html, chatbot],
        api_name=False
    )
demo.queue(default_concurrency_limit=2).launch(ssr_mode=False, show_api=False)
log_info("Gradio app launched")