| import time |
| import logging |
| import sys |
| import os |
| import base64 |
|
|
| import asyncio |
| from aiocache import cached |
| from typing import Any, Optional |
| import random |
| import json |
| import html |
| import inspect |
| import re |
| import ast |
|
|
| from uuid import uuid4 |
| from concurrent.futures import ThreadPoolExecutor |
|
|
|
|
| from fastapi import Request, HTTPException |
| from starlette.responses import Response, StreamingResponse |
|
|
|
|
| from open_webui.models.chats import Chats |
| from open_webui.models.users import Users |
| from open_webui.socket.main import ( |
| get_event_call, |
| get_event_emitter, |
| get_active_status_by_user_id, |
| ) |
| from open_webui.routers.tasks import ( |
| generate_queries, |
| generate_title, |
| generate_image_prompt, |
| generate_chat_tags, |
| ) |
| from open_webui.routers.retrieval import process_web_search, SearchForm |
| from open_webui.routers.images import image_generations, GenerateImageForm |
| from open_webui.routers.pipelines import ( |
| process_pipeline_inlet_filter, |
| process_pipeline_outlet_filter, |
| ) |
|
|
| from open_webui.utils.webhook import post_webhook |
|
|
|
|
| from open_webui.models.users import UserModel |
| from open_webui.models.functions import Functions |
| from open_webui.models.models import Models |
|
|
| from open_webui.retrieval.utils import get_sources_from_files |
|
|
|
|
| from open_webui.utils.chat import generate_chat_completion |
| from open_webui.utils.task import ( |
| get_task_model_id, |
| rag_template, |
| tools_function_calling_generation_template, |
| ) |
| from open_webui.utils.misc import ( |
| deep_update, |
| get_message_list, |
| add_or_update_system_message, |
| add_or_update_user_message, |
| get_last_user_message, |
| get_last_assistant_message, |
| prepend_to_first_user_message_content, |
| convert_logit_bias_input_to_json, |
| ) |
| from open_webui.utils.tools import get_tools |
| from open_webui.utils.plugin import load_function_module_by_id |
| from open_webui.utils.filter import ( |
| get_sorted_filter_ids, |
| process_filter_functions, |
| ) |
| from open_webui.utils.code_interpreter import execute_code_jupyter |
|
|
| from open_webui.tasks import create_task |
|
|
| from open_webui.config import ( |
| CACHE_DIR, |
| DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, |
| DEFAULT_CODE_INTERPRETER_PROMPT, |
| ) |
| from open_webui.env import ( |
| SRC_LOG_LEVELS, |
| GLOBAL_LOG_LEVEL, |
| BYPASS_MODEL_ACCESS_CONTROL, |
| ENABLE_REALTIME_CHAT_SAVE, |
| ) |
| from open_webui.constants import TASKS |
|
|
|
|
# Send root logging to stdout at the global level, then give this module's
# logger the per-source level configured under the "MAIN" key.
logging.basicConfig(level=GLOBAL_LOG_LEVEL, stream=sys.stdout)
log = logging.getLogger(name=__name__)
log.setLevel(level=SRC_LOG_LEVELS["MAIN"])
|
|
|
|
async def chat_completion_tools_handler(
    request: Request, body: dict, extra_params: dict, user: UserModel, models, tools
) -> tuple[dict, dict]:
    """Let the task model choose tools (prompt-based function calling) and run them.

    The task model is prompted with the tools function-calling template and the
    JSON-encoded tool specs. Its reply is expected to contain a JSON object that
    is either a single call (``{"name": ..., "parameters": ...}``) or a list of
    calls under ``"tool_calls"``. Each call naming a known tool is executed. Its
    output is either collected as a source (citation-enabled or direct tools) or
    appended to the last user message.

    Args:
        request: Incoming request; task-model config is read from
            ``request.app.state.config``.
        body: Chat completion payload. ``body["messages"]`` may be rewritten and
            ``body["metadata"]["files"]`` may be deleted.
        extra_params: Must contain ``__event_call__`` and ``__metadata__``.
        user: User passed to ``generate_chat_completion``.
        models: Model registry passed to ``get_task_model_id``.
        tools: Mapping of tool name -> tool dict. Each dict has a ``spec`` and
            either a ``callable`` or ``direct``/``server`` entries.

    Returns:
        ``(body, {"sources": [...]})``, or ``(body, {})`` when the task model
        returned no content. Errors while calling the task model or parsing its
        reply are logged at debug level and swallowed.
    """

    async def get_content_from_response(response) -> Optional[str]:
        """Extract the assistant message content from a completion response."""
        content = None
        if hasattr(response, "body_iterator"):
            # Response object: each body chunk is parsed as a JSON completion.
            async for chunk in response.body_iterator:
                data = json.loads(chunk.decode("utf-8"))
                content = data["choices"][0]["message"]["content"]

            # Run the response's background task, if one is attached.
            if response.background is not None:
                await response.background()
        else:
            content = response["choices"][0]["message"]["content"]
        return content

    def get_tools_function_calling_payload(messages, task_model_id, content):
        """Build the non-streaming task-model request that asks which tool to call."""
        user_message = get_last_user_message(messages)
        # Up to the last four messages, newest first.
        history = "\n".join(
            f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
            for message in messages[::-1][:4]
        )

        prompt = f"History:\n{history}\nQuery: {user_message}"

        return {
            "model": task_model_id,
            "messages": [
                {"role": "system", "content": content},
                {"role": "user", "content": f"Query: {prompt}"},
            ],
            "stream": False,
            "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
        }

    event_caller = extra_params["__event_call__"]
    metadata = extra_params["__metadata__"]

    task_model_id = get_task_model_id(
        body["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    skip_files = False
    sources = []

    specs = [tool["spec"] for tool in tools.values()]
    tools_specs = json.dumps(specs)

    if request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE != "":
        template = request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    else:
        template = DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

    tools_function_calling_prompt = tools_function_calling_generation_template(
        template, tools_specs
    )
    payload = get_tools_function_calling_payload(
        body["messages"], task_model_id, tools_function_calling_prompt
    )

    try:
        response = await generate_chat_completion(request, form_data=payload, user=user)
        log.debug(f"{response=}")
        content = await get_content_from_response(response)
        log.debug(f"{content=}")

        if not content:
            return body, {}

        try:
            # Keep only the span from the first "{" to the last "}".
            content = content[content.find("{") : content.rfind("}") + 1]
            if not content:
                raise Exception("No JSON object found in the response")

            result = json.loads(content)

            async def tool_call_handler(tool_call):
                """Run one requested tool call and record its output."""
                nonlocal skip_files

                log.debug(f"{tool_call=}")

                tool_function_name = tool_call.get("name", None)
                if tool_function_name not in tools:
                    # Unknown tool name: nothing to run (callers ignore the result).
                    return

                tool_function_params = tool_call.get("parameters", {})

                try:
                    tool = tools[tool_function_name]

                    # Drop any argument the tool's spec does not declare.
                    spec = tool.get("spec", {})
                    allowed_params = (
                        spec.get("parameters", {}).get("properties", {}).keys()
                    )
                    tool_function_params = {
                        k: v
                        for k, v in tool_function_params.items()
                        if k in allowed_params
                    }

                    if tool.get("direct", False):
                        # Direct (tool-server) tools are executed through the
                        # event caller instead of a local callable.
                        tool_result = await event_caller(
                            {
                                "type": "execute:tool",
                                "data": {
                                    "id": str(uuid4()),
                                    "name": tool_function_name,
                                    "params": tool_function_params,
                                    "server": tool.get("server", {}),
                                    "session_id": metadata.get("session_id", None),
                                },
                            }
                        )
                    else:
                        tool_function = tool["callable"]
                        tool_result = await tool_function(**tool_function_params)

                except Exception as e:
                    # The error text becomes the tool's output.
                    tool_result = str(e)

                tool_result_files = []
                if isinstance(tool_result, list):
                    # Split "data:" URI strings out of the result. New lists are
                    # built here because calling list.remove() while iterating
                    # skips the element that follows each removed one.
                    # NOTE(review): tool_result_files is collected but not
                    # forwarded anywhere in this handler.
                    tool_result_files = [
                        item
                        for item in tool_result
                        if isinstance(item, str) and item.startswith("data:")
                    ]
                    tool_result = [
                        item
                        for item in tool_result
                        if not (isinstance(item, str) and item.startswith("data:"))
                    ]

                if isinstance(tool_result, (dict, list)):
                    tool_result = json.dumps(tool_result, indent=2)

                if isinstance(tool_result, str):
                    tool = tools[tool_function_name]
                    tool_id = tool.get("tool_id", "")

                    tool_name = (
                        f"{tool_id}/{tool_function_name}"
                        if tool_id
                        else f"{tool_function_name}"
                    )
                    if tool.get("metadata", {}).get("citation", False) or tool.get(
                        "direct", False
                    ):
                        # Citation-enabled or direct tool: expose the output
                        # as a source document.
                        sources.append(
                            {
                                "source": {
                                    "name": (f"TOOL:{tool_name}"),
                                },
                                "document": [tool_result],
                                "metadata": [{"source": (f"TOOL:{tool_name}")}],
                            }
                        )
                    else:
                        # Otherwise append the output to the last user message.
                        body["messages"] = add_or_update_user_message(
                            f"\nTool `{tool_name}` Output: {tool_result}",
                            body["messages"],
                        )

                    if (
                        tools[tool_function_name]
                        .get("metadata", {})
                        .get("file_handler", False)
                    ):
                        # A file_handler tool causes metadata files to be
                        # removed from the body before returning.
                        skip_files = True

            # The reply may request several calls ("tool_calls") or just one.
            if result.get("tool_calls"):
                for tool_call in result.get("tool_calls"):
                    await tool_call_handler(tool_call)
            else:
                await tool_call_handler(result)

        except Exception as e:
            log.debug(f"Error: {e}")
    except Exception as e:
        log.debug(f"Error: {e}")

    log.debug(f"tool_contexts: {sources}")

    if skip_files and "files" in body.get("metadata", {}):
        del body["metadata"]["files"]

    return body, {"sources": sources}
|
|
|
|
async def chat_web_search_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Run web searches for the chat and attach the results as files.

    Queries are generated by the task model (``generate_queries`` with
    ``type="web_search"``). If that call fails, the last user message is used.
    If the reply contains no parsable JSON object, the raw reply text is used as
    the single query. Each query is searched with ``process_web_search``, and the
    results are appended to ``form_data["files"]``. Progress and failures are
    reported as ``status`` events through ``extra_params["__event_emitter__"]``.

    Returns:
        ``form_data`` (modified in place).
    """
    event_emitter = extra_params["__event_emitter__"]
    await event_emitter(
        {
            "type": "status",
            "data": {
                "action": "web_search",
                "description": "Generating search query",
                "done": False,
            },
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    queries = []
    try:
        res = await generate_queries(
            request,
            {
                "model": form_data["model"],
                "messages": messages,
                "prompt": user_message,
                "type": "web_search",
            },
            user,
        )

        response = res["choices"][0]["message"]["content"]

        try:
            bracket_start = response.find("{")
            bracket_end = response.rfind("}") + 1

            # rfind() returns -1 when "}" is absent, which makes bracket_end 0.
            if bracket_start == -1 or bracket_end == 0:
                raise Exception("No JSON object found in the response")

            # Parse the slice without overwriting `response`, so the fallback
            # below still sees the model's full reply.
            queries = json.loads(response[bracket_start:bracket_end])
            queries = queries.get("queries", [])
        except Exception as e:
            queries = [response]

    except Exception as e:
        log.exception(e)
        queries = [user_message]

    if len(queries) == 0:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": "No search query generated",
                    "done": True,
                },
            }
        )
        return form_data

    all_results = []

    for searchQuery in queries:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": 'Searching "{{searchQuery}}"',
                    "query": searchQuery,
                    "done": False,
                },
            }
        )

        try:
            results = await process_web_search(
                request,
                SearchForm(
                    **{
                        "query": searchQuery,
                    }
                ),
                user=user,
            )

            if results:
                all_results.append(results)
                # "files" may be present but None; start from a new list then.
                files = form_data.get("files") or []

                if results.get("collection_names"):
                    for col_idx, collection_name in enumerate(
                        results.get("collection_names")
                    ):
                        files.append(
                            {
                                "collection_name": collection_name,
                                "name": searchQuery,
                                "type": "web_search",
                                "urls": [results["filenames"][col_idx]],
                            }
                        )
                elif results.get("docs"):
                    # Raw documents were returned instead of collections.
                    docs = results["docs"]

                    if len(docs) == len(results["filenames"]):
                        # One URL per document: attach each separately.
                        for doc_idx, doc in enumerate(docs):
                            files.append(
                                {
                                    "docs": [doc],
                                    "name": searchQuery,
                                    "type": "web_search",
                                    "urls": [results["filenames"][doc_idx]],
                                }
                            )
                    else:
                        # Counts differ: attach all docs and URLs as one file.
                        files.append(
                            {
                                "docs": results.get("docs", []),
                                "name": searchQuery,
                                "type": "web_search",
                                "urls": results["filenames"],
                            }
                        )

                form_data["files"] = files
        except Exception as e:
            log.exception(e)
            await event_emitter(
                {
                    "type": "status",
                    "data": {
                        "action": "web_search",
                        "description": 'Error searching "{{searchQuery}}"',
                        "query": searchQuery,
                        "done": True,
                        "error": True,
                    },
                }
            )

    if all_results:
        urls = []
        for results in all_results:
            if "filenames" in results:
                urls.extend(results["filenames"])

        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": "Searched {{count}} sites",
                    "urls": urls,
                    "done": True,
                },
            }
        )
    else:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "web_search",
                    "description": "No search results found",
                    "done": True,
                    "error": True,
                },
            }
        )

    return form_data
|
|
|
|
async def chat_image_generation_handler(
    request: Request, form_data: dict, extra_params: dict, user
):
    """Generate an image for the chat and tell the model about the outcome.

    The prompt defaults to the last user message. When
    ``ENABLE_IMAGE_PROMPT_GENERATION`` is on, ``generate_image_prompt`` is asked
    for a JSON object whose ``"prompt"`` field replaces it; a missing, empty or
    unparsable answer keeps the user message. The image is generated with
    ``image_generations`` and sent to the client as a ``files`` event. A system
    message stating success or failure is then added to ``form_data["messages"]``.

    Returns:
        ``form_data`` (modified in place).
    """
    __event_emitter__ = extra_params["__event_emitter__"]
    await __event_emitter__(
        {
            "type": "status",
            "data": {"description": "Generating an image", "done": False},
        }
    )

    messages = form_data["messages"]
    user_message = get_last_user_message(messages)

    prompt = user_message

    if request.app.state.config.ENABLE_IMAGE_PROMPT_GENERATION:
        try:
            res = await generate_image_prompt(
                request,
                {
                    "model": form_data["model"],
                    "messages": messages,
                },
                user,
            )

            response = res["choices"][0]["message"]["content"]

            try:
                bracket_start = response.find("{")
                bracket_end = response.rfind("}") + 1

                # rfind() returns -1 when "}" is absent, which makes bracket_end 0.
                if bracket_start == -1 or bracket_end == 0:
                    raise Exception("No JSON object found in the response")

                parsed = json.loads(response[bracket_start:bracket_end])
                # A missing or empty "prompt" falls back to the user's message
                # rather than an empty list.
                prompt = parsed.get("prompt") or user_message
            except Exception as e:
                prompt = user_message

        except Exception as e:
            log.exception(e)
            prompt = user_message

    system_message_content = ""

    try:
        images = await image_generations(
            request=request,
            form_data=GenerateImageForm(**{"prompt": prompt}),
            user=user,
        )

        await __event_emitter__(
            {
                "type": "status",
                "data": {"description": "Generated an image", "done": True},
            }
        )

        await __event_emitter__(
            {
                "type": "files",
                "data": {
                    "files": [
                        {
                            "type": "image",
                            "url": image["url"],
                        }
                        for image in images
                    ]
                },
            }
        )

        system_message_content = "<context>User is shown the generated image, tell the user that the image has been generated</context>"
    except Exception as e:
        log.exception(e)
        await __event_emitter__(
            {
                "type": "status",
                "data": {
                    "description": f"An error occurred while generating an image",
                    "done": True,
                },
            }
        )

        system_message_content = "<context>Unable to generate an image, tell the user that an error occurred</context>"

    if system_message_content:
        form_data["messages"] = add_or_update_system_message(
            system_message_content, form_data["messages"]
        )

    return form_data
|
|
|
|
async def chat_completion_files_handler(
    request: Request, body: dict, user: UserModel
) -> tuple[dict, dict[str, list]]:
    """Retrieve context sources for the files attached to the request.

    Runs only when ``body["metadata"]["files"]`` is non-empty. Retrieval queries
    are generated by the task model (``generate_queries`` with
    ``type="retrieval"``). A reply without a parsable JSON object is used as a
    single query. If query generation fails or yields no queries, the last user
    message is used instead. ``get_sources_from_files`` then runs in a worker
    thread with the retrieval settings from ``request.app.state.config``.

    Returns:
        ``(body, {"sources": sources})``. ``body`` is unchanged, and
        ``sources`` is empty when there are no files or retrieval fails (the
        failure is logged).
    """
    sources = []

    if files := body.get("metadata", {}).get("files", None):
        queries = []
        try:
            queries_response = await generate_queries(
                request,
                {
                    "model": body["model"],
                    "messages": body["messages"],
                    "type": "retrieval",
                },
                user,
            )
            queries_response = queries_response["choices"][0]["message"]["content"]

            try:
                bracket_start = queries_response.find("{")
                bracket_end = queries_response.rfind("}") + 1

                # rfind() returns -1 when "}" is absent, which makes bracket_end 0.
                if bracket_start == -1 or bracket_end == 0:
                    raise Exception("No JSON object found in the response")

                # Parse the slice without overwriting the reply, so the fallback
                # below uses the model's full text rather than a fragment.
                queries_response = json.loads(
                    queries_response[bracket_start:bracket_end]
                )
            except Exception as e:
                queries_response = {"queries": [queries_response]}

            queries = queries_response.get("queries", [])
        except Exception as e:
            # Best effort: fall back to the last user message below. Catching
            # Exception (not a bare except) lets task cancellation propagate.
            log.debug(f"Error generating retrieval queries: {e}")

        if len(queries) == 0:
            queries = [get_last_user_message(body["messages"])]

        try:
            # Retrieval is blocking; run it off the event loop.
            loop = asyncio.get_running_loop()
            with ThreadPoolExecutor() as executor:
                sources = await loop.run_in_executor(
                    executor,
                    lambda: get_sources_from_files(
                        request=request,
                        files=files,
                        queries=queries,
                        embedding_function=lambda query, prefix: request.app.state.EMBEDDING_FUNCTION(
                            query, prefix=prefix, user=user
                        ),
                        k=request.app.state.config.TOP_K,
                        reranking_function=request.app.state.rf,
                        k_reranker=request.app.state.config.TOP_K_RERANKER,
                        r=request.app.state.config.RELEVANCE_THRESHOLD,
                        hybrid_search=request.app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                        full_context=request.app.state.config.RAG_FULL_CONTEXT,
                    ),
                )
        except Exception as e:
            log.exception(e)

        log.debug(f"rag_contexts:sources: {sources}")

    return body, {"sources": sources}
|
|
|
|
def apply_params_to_form_data(form_data, model):
    """Move the request's ``params`` into the fields the backend expects.

    ``form_data["params"]`` is always removed. A missing or ``None`` value is
    treated as an empty dict.

    * Ollama models (``model["ollama"]`` truthy): the whole params dict becomes
      ``form_data["options"]``, and ``format`` / ``keep_alive`` are also copied
      to the top level when present.
    * Other models: ``seed``, ``stop``, ``temperature``, ``max_tokens``,
      ``top_p``, ``frequency_penalty`` and ``reasoning_effort`` are copied to
      the top level when not ``None``. ``logit_bias`` is converted with
      ``convert_logit_bias_input_to_json``; if conversion fails, a warning is
      logged and the key is skipped.

    Args:
        form_data: Request payload dict; modified in place.
        model: Model dict.

    Returns:
        The same ``form_data`` dict.
    """
    # `or {}` treats an explicit "params": None like a missing key.
    params = form_data.pop("params", None) or {}

    if model.get("ollama"):
        form_data["options"] = params

        for key in ("format", "keep_alive"):
            if key in params:
                form_data[key] = params[key]
    else:
        for key in (
            "seed",
            "stop",
            "temperature",
            "max_tokens",
            "top_p",
            "frequency_penalty",
            "reasoning_effort",
        ):
            if params.get(key) is not None:
                form_data[key] = params[key]

        if params.get("logit_bias") is not None:
            try:
                form_data["logit_bias"] = json.loads(
                    convert_logit_bias_input_to_json(params["logit_bias"])
                )
            except Exception as e:
                log.warning(f"Error parsing logit_bias: {e}")

    return form_data
|
|
|
|
async def process_chat_payload(request, form_data, user, metadata, model):
    """Prepare a chat completion payload before it is sent to the model.

    Steps, in order:

    1. Apply model params.
    2. Attach the model's knowledge as files.
    3. Run the pipeline and function inlet filters.
    4. Handle the ``features`` flags (web search, image generation, code
       interpreter).
    5. De-duplicate files.
    6. Resolve tools: native function calling, or prompt-based via
       ``chat_completion_tools_handler``.
    7. Retrieve file sources.
    8. Inject collected sources into the messages with the RAG template.

    Args:
        request: Incoming request; config, models and state are read from it.
        form_data: Chat completion payload. ``params``, ``variables``,
            ``features``, ``tool_ids`` and ``files`` are consumed from it.
        user: Requesting user.
        metadata: Request metadata dict. A new dict extended with ``tool_ids``
            and ``files`` is returned and stored in ``form_data["metadata"]``.
        model: Selected model dict.

    Returns:
        ``(form_data, metadata, events)``. ``events`` contains a
        ``{"sources": [...]}`` entry when any named sources were collected.

    Raises:
        Exception: if a pipeline or function inlet filter fails, or if sources
            were collected but no user message exists to attach them to.
    """
    form_data = apply_params_to_form_data(form_data, model)
    log.debug(f"form_data: {form_data}")

    event_emitter = get_event_emitter(metadata)
    event_call = get_event_call(metadata)

    extra_params = {
        "__event_emitter__": event_emitter,
        "__event_call__": event_call,
        "__user__": {
            "id": user.id,
            "email": user.email,
            "name": user.name,
            "role": user.role,
        },
        "__metadata__": metadata,
        "__request__": request,
        "__model__": model,
    }

    # A direct connection carries its single model on request.state; otherwise
    # the app-wide model registry is used.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    task_model_id = get_task_model_id(
        form_data["model"],
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    events = []
    sources = []

    user_message = get_last_user_message(form_data["messages"])
    model_knowledge = model.get("info", {}).get("meta", {}).get("knowledge", False)

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": False,
                },
            }
        )

        # Normalize knowledge entries into file descriptors; entries with
        # collection_name(s) are marked as legacy.
        knowledge_files = []
        for item in model_knowledge:
            if item.get("collection_name"):
                knowledge_files.append(
                    {
                        "id": item.get("collection_name"),
                        "name": item.get("name"),
                        "legacy": True,
                    }
                )
            elif item.get("collection_names"):
                knowledge_files.append(
                    {
                        "name": item.get("name"),
                        "type": "collection",
                        "collection_names": item.get("collection_names"),
                        "legacy": True,
                    }
                )
            else:
                knowledge_files.append(item)

        # "files" may be present but None; start from a new list in that case.
        files = form_data.get("files") or []
        files.extend(knowledge_files)
        form_data["files"] = files

    # "variables" is removed from the payload; its value is not used here.
    form_data.pop("variables", None)

    # Pipeline inlet filters; any exception propagates to the caller.
    form_data = await process_pipeline_inlet_filter(request, form_data, user, models)

    try:
        filter_functions = [
            Functions.get_function_by_id(filter_id)
            for filter_id in get_sorted_filter_ids(model)
        ]

        form_data, flags = await process_filter_functions(
            request=request,
            filter_functions=filter_functions,
            filter_type="inlet",
            form_data=form_data,
            extra_params=extra_params,
        )
    except Exception as e:
        # Chain the original exception so its traceback is preserved.
        raise Exception(f"Error: {e}") from e

    features = form_data.pop("features", None)
    if features:
        if "web_search" in features and features["web_search"]:
            form_data = await chat_web_search_handler(
                request, form_data, extra_params, user
            )

        if "image_generation" in features and features["image_generation"]:
            form_data = await chat_image_generation_handler(
                request, form_data, extra_params, user
            )

        if "code_interpreter" in features and features["code_interpreter"]:
            form_data["messages"] = add_or_update_user_message(
                (
                    request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE
                    if request.app.state.config.CODE_INTERPRETER_PROMPT_TEMPLATE != ""
                    else DEFAULT_CODE_INTERPRETER_PROMPT
                ),
                form_data["messages"],
            )

    tool_ids = form_data.pop("tool_ids", None)
    files = form_data.pop("files", None)

    # Drop duplicate files (compared by sorted-key JSON); first positions kept.
    if files:
        files = list({json.dumps(f, sort_keys=True): f for f in files}.values())

    metadata = {
        **metadata,
        "tool_ids": tool_ids,
        "files": files,
    }
    form_data["metadata"] = metadata

    # Tool servers carry their tool specs inline under "specs".
    tool_servers = metadata.get("tool_servers", None)

    log.debug(f"{tool_ids=}")
    log.debug(f"{tool_servers=}")

    tools_dict = {}

    if tool_ids:
        tools_dict = get_tools(
            request,
            tool_ids,
            user,
            {
                **extra_params,
                "__model__": models[task_model_id],
                "__messages__": form_data["messages"],
                "__files__": metadata.get("files", []),
            },
        )

    if tool_servers:
        for tool_server in tool_servers:
            tool_specs = tool_server.pop("specs", [])

            for tool in tool_specs:
                tools_dict[tool["name"]] = {
                    "spec": tool,
                    "direct": True,
                    "server": tool_server,
                }

    if tools_dict:
        if metadata.get("function_calling") == "native":
            # Native function calling: pass the tool specs to the model.
            metadata["tools"] = tools_dict
            form_data["tools"] = [
                {"type": "function", "function": tool.get("spec", {})}
                for tool in tools_dict.values()
            ]
        else:
            # Prompt-based function calling through the task model.
            try:
                form_data, flags = await chat_completion_tools_handler(
                    request, form_data, extra_params, user, models, tools_dict
                )
                sources.extend(flags.get("sources", []))
            except Exception as e:
                log.exception(e)

    try:
        form_data, flags = await chat_completion_files_handler(request, form_data, user)
        sources.extend(flags.get("sources", []))
    except Exception as e:
        log.exception(e)

    # Inject any collected context into the messages.
    if len(sources) > 0:
        context_string = ""
        citation_idx = {}
        for source in sources:
            if "document" in source:
                for doc_context, doc_meta in zip(
                    source["document"], source["metadata"]
                ):
                    citation_id = (
                        doc_meta.get("source", None)
                        or source.get("source", {}).get("id", None)
                        or "N/A"
                    )
                    # Number citations by first appearance.
                    if citation_id not in citation_idx:
                        citation_idx[citation_id] = len(citation_idx) + 1
                    context_string += f'<source id="{citation_idx[citation_id]}">{doc_context}</source>\n'

        context_string = context_string.strip()
        prompt = get_last_user_message(form_data["messages"])

        if prompt is None:
            raise Exception("No user message found")
        if (
            request.app.state.config.RELEVANCE_THRESHOLD == 0
            and context_string.strip() == ""
        ):
            log.debug(
                f"With a 0 relevancy threshold for RAG, the context cannot be empty"
            )

        # Ollama-owned models: the RAG prompt is prepended to the first user
        # message. Other models: it is added as the system message.
        if model.get("owned_by") == "ollama":
            form_data["messages"] = prepend_to_first_user_message_content(
                rag_template(
                    request.app.state.config.RAG_TEMPLATE, context_string, prompt
                ),
                form_data["messages"],
            )
        else:
            form_data["messages"] = add_or_update_system_message(
                rag_template(
                    request.app.state.config.RAG_TEMPLATE, context_string, prompt
                ),
                form_data["messages"],
            )

    # Only sources with a name are reported to the client.
    sources = [source for source in sources if source.get("source", {}).get("name", "")]

    if len(sources) > 0:
        events.append({"sources": sources})

    if model_knowledge:
        await event_emitter(
            {
                "type": "status",
                "data": {
                    "action": "knowledge_search",
                    "query": user_message,
                    "done": True,
                    "hidden": True,
                },
            }
        )

    return form_data, metadata, events
|
|
|
|
| async def process_chat_response( |
| request, response, form_data, user, metadata, model, events, tasks |
| ): |
| async def background_tasks_handler(): |
| message_map = Chats.get_messages_by_chat_id(metadata["chat_id"]) |
| message = message_map.get(metadata["message_id"]) if message_map else None |
|
|
| if message: |
| messages = get_message_list(message_map, message.get("id")) |
|
|
| if tasks and messages: |
| if TASKS.TITLE_GENERATION in tasks: |
| if tasks[TASKS.TITLE_GENERATION]: |
| res = await generate_title( |
| request, |
| { |
| "model": message["model"], |
| "messages": messages, |
| "chat_id": metadata["chat_id"], |
| }, |
| user, |
| ) |
|
|
| if res and isinstance(res, dict): |
| if len(res.get("choices", [])) == 1: |
| title_string = ( |
| res.get("choices", [])[0] |
| .get("message", {}) |
| .get("content", message.get("content", "New Chat")) |
| ) |
| else: |
| title_string = "" |
|
|
| title_string = title_string[ |
| title_string.find("{") : title_string.rfind("}") + 1 |
| ] |
|
|
| try: |
| title = json.loads(title_string).get( |
| "title", "New Chat" |
| ) |
| except Exception as e: |
| title = "" |
|
|
| if not title: |
| title = messages[0].get("content", "New Chat") |
|
|
| Chats.update_chat_title_by_id(metadata["chat_id"], title) |
|
|
| await event_emitter( |
| { |
| "type": "chat:title", |
| "data": title, |
| } |
| ) |
| elif len(messages) == 2: |
| title = messages[0].get("content", "New Chat") |
|
|
| Chats.update_chat_title_by_id(metadata["chat_id"], title) |
|
|
| await event_emitter( |
| { |
| "type": "chat:title", |
| "data": message.get("content", "New Chat"), |
| } |
| ) |
|
|
| if TASKS.TAGS_GENERATION in tasks and tasks[TASKS.TAGS_GENERATION]: |
| res = await generate_chat_tags( |
| request, |
| { |
| "model": message["model"], |
| "messages": messages, |
| "chat_id": metadata["chat_id"], |
| }, |
| user, |
| ) |
|
|
| if res and isinstance(res, dict): |
| if len(res.get("choices", [])) == 1: |
| tags_string = ( |
| res.get("choices", [])[0] |
| .get("message", {}) |
| .get("content", "") |
| ) |
| else: |
| tags_string = "" |
|
|
| tags_string = tags_string[ |
| tags_string.find("{") : tags_string.rfind("}") + 1 |
| ] |
|
|
| try: |
| tags = json.loads(tags_string).get("tags", []) |
| Chats.update_chat_tags_by_id( |
| metadata["chat_id"], tags, user |
| ) |
|
|
| await event_emitter( |
| { |
| "type": "chat:tags", |
| "data": tags, |
| } |
| ) |
| except Exception as e: |
| pass |
|
|
| event_emitter = None |
| event_caller = None |
| if ( |
| "session_id" in metadata |
| and metadata["session_id"] |
| and "chat_id" in metadata |
| and metadata["chat_id"] |
| and "message_id" in metadata |
| and metadata["message_id"] |
| ): |
| event_emitter = get_event_emitter(metadata) |
| event_caller = get_event_call(metadata) |
|
|
| |
| if not isinstance(response, StreamingResponse): |
| if event_emitter: |
| if "error" in response: |
| error = response["error"].get("detail", response["error"]) |
| Chats.upsert_message_to_chat_by_id_and_message_id( |
| metadata["chat_id"], |
| metadata["message_id"], |
| { |
| "error": {"content": error}, |
| }, |
| ) |
|
|
| if "selected_model_id" in response: |
| Chats.upsert_message_to_chat_by_id_and_message_id( |
| metadata["chat_id"], |
| metadata["message_id"], |
| { |
| "selectedModelId": response["selected_model_id"], |
| }, |
| ) |
|
|
| choices = response.get("choices", []) |
| if choices and choices[0].get("message", {}).get("content"): |
| content = response["choices"][0]["message"]["content"] |
|
|
| if content: |
|
|
| await event_emitter( |
| { |
| "type": "chat:completion", |
| "data": response, |
| } |
| ) |
|
|
| title = Chats.get_chat_title_by_id(metadata["chat_id"]) |
|
|
| await event_emitter( |
| { |
| "type": "chat:completion", |
| "data": { |
| "done": True, |
| "content": content, |
| "title": title, |
| }, |
| } |
| ) |
|
|
| |
| Chats.upsert_message_to_chat_by_id_and_message_id( |
| metadata["chat_id"], |
| metadata["message_id"], |
| { |
| "content": content, |
| }, |
| ) |
|
|
| |
| if get_active_status_by_user_id(user.id) is None: |
| webhook_url = Users.get_user_webhook_url_by_id(user.id) |
| if webhook_url: |
| post_webhook( |
| request.app.state.WEBUI_NAME, |
| webhook_url, |
| f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}", |
| { |
| "action": "chat", |
| "message": content, |
| "title": title, |
| "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}", |
| }, |
| ) |
|
|
| await background_tasks_handler() |
|
|
| return response |
| else: |
| return response |
|
|
| |
| if not any( |
| content_type in response.headers["Content-Type"] |
| for content_type in ["text/event-stream", "application/x-ndjson"] |
| ): |
| return response |
|
|
| extra_params = { |
| "__event_emitter__": event_emitter, |
| "__event_call__": event_caller, |
| "__user__": { |
| "id": user.id, |
| "email": user.email, |
| "name": user.name, |
| "role": user.role, |
| }, |
| "__metadata__": metadata, |
| "__request__": request, |
| "__model__": model, |
| } |
| filter_functions = [ |
| Functions.get_function_by_id(filter_id) |
| for filter_id in get_sorted_filter_ids(model) |
| ] |
|
|
| |
| if event_emitter and event_caller: |
| task_id = str(uuid4()) |
| model_id = form_data.get("model", "") |
|
|
| Chats.upsert_message_to_chat_by_id_and_message_id( |
| metadata["chat_id"], |
| metadata["message_id"], |
| { |
| "model": model_id, |
| }, |
| ) |
|
|
| def split_content_and_whitespace(content): |
| content_stripped = content.rstrip() |
| original_whitespace = ( |
| content[len(content_stripped) :] |
| if len(content) > len(content_stripped) |
| else "" |
| ) |
| return content_stripped, original_whitespace |
|
|
| def is_opening_code_block(content): |
| backtick_segments = content.split("```") |
| |
| return len(backtick_segments) > 1 and len(backtick_segments) % 2 == 0 |
|
|
| |
async def post_response_handler(response, events):
    """Drive a streamed chat completion to its end as a background task.

    Reads the model's SSE stream from ``response`` and folds it into a list of
    typed "content blocks" (text, reasoning, code_interpreter, solution,
    tool_calls). It emits incremental ``chat:completion`` updates through
    ``event_emitter`` and stores the message with ``Chats``. Native tool calls
    (up to 10 rounds) and code-interpreter blocks (up to 5 rounds) are
    executed, and the model is re-invoked with their results.

    Relies on names from the enclosing scope: ``request``, ``form_data``,
    ``user``, ``metadata``, ``model_id``, ``event_emitter``, ``event_caller``,
    ``filter_functions`` and ``extra_params``.

    Args:
        response: the initial streaming response; its ``body_iterator`` yields
            SSE lines and its ``background`` task is awaited at the end.
        events: event payloads emitted and stored on the message before the
            stream is consumed.
    """

    def _text_block(text=""):
        """Return a new text content block holding ``text``."""
        return {"type": "text", "content": text}

    def _render_tool_calls(block):
        """Render a tool_calls block as ``<details>`` elements for the UI.

        A call whose matching result (by ``tool_call_id``) has truthy content
        is rendered as done, with its result and files; otherwise as executing.
        """
        results = block.get("results", [])
        rendered = ""
        for tool_call in block.get("content", []):
            tool_call_id = tool_call.get("id", "")
            tool_name = tool_call.get("function", {}).get("name", "")
            tool_arguments = tool_call.get("function", {}).get("arguments", "")
            escaped_arguments = html.escape(json.dumps(tool_arguments))

            tool_result = None
            tool_result_files = None
            for result in results:
                if tool_call_id == result.get("tool_call_id", ""):
                    tool_result = result.get("content", None)
                    tool_result_files = result.get("files", None)
                    break

            if tool_result:
                escaped_result = html.escape(json.dumps(tool_result))
                escaped_files = (
                    html.escape(json.dumps(tool_result_files))
                    if tool_result_files
                    else ""
                )
                rendered = f'{rendered}\n<details type="tool_calls" done="true" id="{tool_call_id}" name="{tool_name}" arguments="{escaped_arguments}" result="{escaped_result}" files="{escaped_files}">\n<summary>Tool Executed</summary>\n</details>\n'
            else:
                rendered = f'{rendered}\n<details type="tool_calls" done="false" id="{tool_call_id}" name="{tool_name}" arguments="{escaped_arguments}">\n<summary>Executing...</summary>\n</details>'
        return rendered

    def serialize_content_blocks(content_blocks, raw=False):
        """Serialize content blocks into one message string.

        Args:
            content_blocks: list of block dicts, each with a ``type`` key.
            raw: when False, render the UI form (``<details>`` elements).
                When True, re-emit reasoning and code-interpreter blocks in
                their tag form and omit tool_calls blocks, for sending back
                to the model.

        Returns:
            The concatenated text with surrounding whitespace stripped.
        """
        content = ""

        for block in content_blocks:
            block_type = block["type"]

            if block_type == "text":
                content = f"{content}{block['content'].strip()}\n"

            elif block_type == "tool_calls":
                if not raw:
                    content = f"{content}\n{_render_tool_calls(block)}\n\n"

            elif block_type == "reasoning":
                if raw:
                    content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
                else:
                    reasoning_display_content = "\n".join(
                        (f"> {line}" if not line.startswith(">") else line)
                        for line in block["content"].splitlines()
                    )
                    reasoning_duration = block.get("duration", None)
                    if reasoning_duration is not None:
                        content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
                    else:
                        content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'

            elif block_type == "code_interpreter":
                attributes = block.get("attributes", {})
                output = block.get("output", None)
                lang = attributes.get("lang", "")

                # If the preceding text ends in an unclosed ``` fence (per
                # is_opening_code_block), strip those trailing backticks because
                # the code is wrapped again below.
                content_stripped, original_whitespace = split_content_and_whitespace(
                    content
                )
                if is_opening_code_block(content_stripped):
                    content = (
                        content_stripped.rstrip("`").rstrip() + original_whitespace
                    )
                else:
                    content = content_stripped + original_whitespace

                if output:
                    output = html.escape(json.dumps(output))
                    if raw:
                        content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n```output\n{output}\n```\n'
                    else:
                        content = f'{content}\n<details type="code_interpreter" done="true" output="{output}">\n<summary>Analyzed</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'
                elif raw:
                    content = f'{content}\n<code_interpreter type="code" lang="{lang}">\n{block["content"]}\n</code_interpreter>\n'
                else:
                    content = f'{content}\n<details type="code_interpreter" done="false">\n<summary>Analyzing...</summary>\n```{lang}\n{block["content"]}\n```\n</details>\n'

            else:
                block_content = str(block["content"]).strip()
                content = f"{content}{block_type}: {block_content}\n"

        return content.strip()

    def convert_content_blocks_to_messages(content_blocks):
        """Convert content blocks into OpenAI-style chat messages.

        Each tool_calls block becomes an assistant message (the serialized
        blocks before it as content, plus the tool calls), followed by one
        ``tool`` message per result. Blocks after the last tool_calls block
        become a trailing assistant message when they serialize to non-empty
        text.
        """
        messages = []
        pending_blocks = []

        for block in content_blocks:
            if block["type"] != "tool_calls":
                pending_blocks.append(block)
                continue

            messages.append(
                {
                    "role": "assistant",
                    "content": serialize_content_blocks(pending_blocks),
                    "tool_calls": block.get("content"),
                }
            )
            messages.extend(
                {
                    "role": "tool",
                    "tool_call_id": result["tool_call_id"],
                    "content": result["content"],
                }
                for result in block.get("results", [])
            )
            pending_blocks = []

        if pending_blocks:
            trailing = serialize_content_blocks(pending_blocks)
            if trailing:
                messages.append({"role": "assistant", "content": trailing})

        return messages

    def tag_content_handler(content_type, tags, content, content_blocks):
        """Split tagged spans (e.g. ``<think>...</think>``) out of streamed text.

        If the last block is text and ``content`` contains one of the start
        tags in ``tags``, a new block of ``content_type`` is opened.

        If the last block is already of ``content_type`` and its end tag is
        present in ``content``:
          - the block is closed and stamped with a duration;
          - text after the end tag becomes a new text block;
          - the processed ``<start>...<end>`` span is removed from ``content``.

        A closed, non-empty code_interpreter block is left as the last block
        so the caller can execute it.

        Returns:
            ``(content, content_blocks, end_flag)``. ``end_flag`` is True when
            the end tag was found in this call.
        """
        end_flag = False

        def extract_attributes(tag_content):
            """Return the ``key="value"`` pairs found in a tag's text as a dict."""
            if not tag_content:
                return {}
            return dict(re.findall(r'(\w+)\s*=\s*"([^"]+)"', tag_content))

        if content_blocks[-1]["type"] == "text":
            for start_tag, end_tag in tags:
                start_tag_pattern = rf"<{re.escape(start_tag)}(\s.*?)?>"
                match = re.search(start_tag_pattern, content)
                if not match:
                    continue

                attributes = extract_attributes(match.group(1) or "")
                before_tag = content[: match.start()]
                after_tag = content[match.end() :]

                # Remove the start tag (and anything after it) from the text block.
                content_blocks[-1]["content"] = content_blocks[-1]["content"].replace(
                    match.group(0) + after_tag, ""
                )
                if before_tag:
                    content_blocks[-1]["content"] = before_tag
                if not content_blocks[-1]["content"]:
                    content_blocks.pop()

                content_blocks.append(
                    {
                        "type": content_type,
                        "start_tag": start_tag,
                        "end_tag": end_tag,
                        "attributes": attributes,
                        "content": "",
                        "started_at": time.time(),
                    }
                )

                # Text already streamed after the start tag belongs to the new block.
                if after_tag:
                    content_blocks[-1]["content"] = after_tag
                break

        elif content_blocks[-1]["type"] == content_type:
            start_tag = content_blocks[-1]["start_tag"]
            end_tag = content_blocks[-1]["end_tag"]
            end_tag_pattern = rf"<{re.escape(end_tag)}>"

            if re.search(end_tag_pattern, content):
                end_flag = True

                block_content = re.sub(
                    rf"<{re.escape(start_tag)}(.*?)>", "", content_blocks[-1]["content"]
                ).strip()
                split_content = re.compile(end_tag_pattern, re.DOTALL).split(
                    block_content, maxsplit=1
                )
                block_content = split_content[0].strip() if split_content else ""
                leftover_content = (
                    split_content[1].strip() if len(split_content) > 1 else ""
                )

                if block_content:
                    closed_block = content_blocks[-1]
                    closed_block["content"] = block_content
                    closed_block["ended_at"] = time.time()
                    closed_block["duration"] = int(
                        closed_block["ended_at"] - closed_block["started_at"]
                    )
                    # code_interpreter blocks stay last so the caller runs them.
                    if content_type != "code_interpreter":
                        content_blocks.append(_text_block(leftover_content))
                else:
                    # Nothing between the tags: drop the empty block.
                    content_blocks.pop()
                    content_blocks.append(_text_block(leftover_content))

                content = re.sub(
                    rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>",
                    "",
                    content,
                    flags=re.DOTALL,
                )

        return content, content_blocks, end_flag

    def _merge_tool_call_deltas(response_tool_calls, delta_tool_calls):
        """Accumulate streamed tool-call fragments into ``response_tool_calls``.

        Fragments are matched on ``index``:
          - the first fragment for an index is stored with name/arguments
            defaulted to "";
          - later fragments for that index append to the stored name and
            arguments strings;
          - fragments without an ``index`` are ignored.
        """
        for delta_tool_call in delta_tool_calls:
            tool_call_index = delta_tool_call.get("index")
            if tool_call_index is None:
                continue

            current = next(
                (
                    existing
                    for existing in response_tool_calls
                    if existing.get("index") == tool_call_index
                ),
                None,
            )

            if current is None:
                function = delta_tool_call.setdefault("function", {})
                function.setdefault("name", "")
                function.setdefault("arguments", "")
                response_tool_calls.append(delta_tool_call)
            else:
                delta_function = delta_tool_call.get("function", {})
                delta_name = delta_function.get("name")
                delta_arguments = delta_function.get("arguments")
                if delta_name:
                    current["function"]["name"] += delta_name
                if delta_arguments:
                    current["function"]["arguments"] += delta_arguments

    def _parse_tool_arguments(arguments):
        """Parse tool-call arguments given as a Python literal or JSON.

        Returns ``{}`` when neither parser accepts the input.
        """
        params = {}
        try:
            # literal_eval only evaluates Python literals (no code execution).
            params = ast.literal_eval(arguments)
        except Exception as e:
            log.debug(e)
            try:
                params = json.loads(arguments)
            except Exception:
                log.debug(f"Error parsing tool call arguments: {arguments}")
        return params

    async def _execute_tool_call(tool_call, tools):
        """Run one tool call and return its result entry.

        - Unknown tools yield ``content: None``.
        - Exceptions raised while invoking a tool become their string form.
        - In a list result, ``data:`` strings are moved into ``files``.
        - Dict/list results are JSON-encoded.
        """
        tool_call_id = tool_call.get("id", "")
        tool_name = tool_call.get("function", {}).get("name", "")
        tool_function_params = _parse_tool_arguments(
            tool_call.get("function", {}).get("arguments", "{}")
        )

        tool_result = None
        if tool_name in tools:
            tool = tools[tool_name]
            spec = tool.get("spec", {})
            try:
                # Only forward parameters declared in the tool's spec.
                allowed_params = (
                    spec.get("parameters", {}).get("properties", {}).keys()
                )
                tool_function_params = {
                    k: v for k, v in tool_function_params.items() if k in allowed_params
                }

                if tool.get("direct", False):
                    tool_result = await event_caller(
                        {
                            "type": "execute:tool",
                            "data": {
                                "id": str(uuid4()),
                                "name": tool_name,
                                "params": tool_function_params,
                                "server": tool.get("server", {}),
                                "session_id": metadata.get("session_id", None),
                            },
                        }
                    )
                else:
                    tool_function = tool["callable"]
                    tool_result = await tool_function(**tool_function_params)
            except Exception as e:
                tool_result = str(e)

        tool_result_files = []
        if isinstance(tool_result, list):
            # Build new lists instead of removing while iterating, which
            # skipped adjacent items and mutated the tool's own return value.
            tool_result_files = [
                item
                for item in tool_result
                if isinstance(item, str) and item.startswith("data:")
            ]
            tool_result = [
                item
                for item in tool_result
                if not (isinstance(item, str) and item.startswith("data:"))
            ]

        if isinstance(tool_result, (dict, list)):
            tool_result = json.dumps(tool_result, indent=2)

        return {
            "tool_call_id": tool_call_id,
            "content": tool_result,
            **({"files": tool_result_files} if tool_result_files else {}),
        }

    async def _run_code(code):
        """Execute ``code`` with the configured code-interpreter engine."""
        config = request.app.state.config
        engine = config.CODE_INTERPRETER_ENGINE

        if engine == "pyodide":
            return await event_caller(
                {
                    "type": "execute:python",
                    "data": {
                        "id": str(uuid4()),
                        "code": code,
                        "session_id": metadata.get("session_id", None),
                    },
                }
            )
        if engine == "jupyter":
            return await execute_code_jupyter(
                config.CODE_INTERPRETER_JUPYTER_URL,
                code,
                (
                    config.CODE_INTERPRETER_JUPYTER_AUTH_TOKEN
                    if config.CODE_INTERPRETER_JUPYTER_AUTH == "token"
                    else None
                ),
                (
                    config.CODE_INTERPRETER_JUPYTER_AUTH_PASSWORD
                    if config.CODE_INTERPRETER_JUPYTER_AUTH == "password"
                    else None
                ),
                config.CODE_INTERPRETER_JUPYTER_TIMEOUT,
            )
        return {"stdout": "Code interpreter engine not configured."}

    def _replace_base64_images(text):
        """Save inline base64 PNG lines to CACHE_DIR/images and link them.

        Each line containing ``data:image/png;base64`` is decoded to a new
        file and replaced by a markdown image reference to it.
        """
        lines = text.split("\n")
        for idx, line in enumerate(lines):
            if "data:image/png;base64" not in line:
                continue

            image_id = str(uuid4())
            os.makedirs(os.path.join(CACHE_DIR, "images"), exist_ok=True)
            image_path = os.path.join(CACHE_DIR, f"images/{image_id}.png")
            with open(image_path, "wb") as f:
                f.write(base64.b64decode(line.split(",")[1]))

            # NOTE(review): assumes files under CACHE_DIR are served at /cache —
            # confirm against the app's routes.
            lines[idx] = f"![Output Image {idx}](/cache/images/{image_id}.png)"
        return "\n".join(lines)

    async def _emit_content():
        """Push the current serialized content blocks to the client."""
        await event_emitter(
            {
                "type": "chat:completion",
                "data": {
                    "content": serialize_content_blocks(content_blocks),
                },
            }
        )

    def _save_content():
        """Persist the current serialized content blocks on the message."""
        Chats.upsert_message_to_chat_by_id_and_message_id(
            metadata["chat_id"],
            metadata["message_id"],
            {
                "content": serialize_content_blocks(content_blocks),
            },
        )

    message = Chats.get_message_by_id_and_message_id(
        metadata["chat_id"], metadata["message_id"]
    )

    # Each entry is the list of tool calls requested by one streamed response.
    tool_calls = []

    last_assistant_message = None
    try:
        if form_data["messages"][-1]["role"] == "assistant":
            last_assistant_message = get_last_assistant_message(form_data["messages"])
    except Exception as e:
        # Best effort: a missing or empty message list just means no prefill.
        log.debug(f"Could not read last assistant message: {e}")

    if message:
        content = message.get("content", "")
    else:
        content = last_assistant_message if last_assistant_message else ""

    content_blocks = [_text_block(content)]

    DETECT_REASONING = True
    DETECT_SOLUTION = True
    DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
        "code_interpreter", False
    )

    reasoning_tags = [
        ("think", "/think"),
        ("thinking", "/thinking"),
        ("reason", "/reason"),
        ("reasoning", "/reasoning"),
        ("thought", "/thought"),
        ("Thought", "/Thought"),
        ("|begin_of_thought|", "|end_of_thought|"),
    ]
    code_interpreter_tags = [("code_interpreter", "/code_interpreter")]
    solution_tags = [("|begin_of_solution|", "|end_of_solution|")]

    try:
        for event in events:
            await event_emitter(
                {
                    "type": "chat:completion",
                    "data": event,
                }
            )
            Chats.upsert_message_to_chat_by_id_and_message_id(
                metadata["chat_id"],
                metadata["message_id"],
                {
                    **event,
                },
            )

        async def stream_body_handler(response):
            """Consume one streaming response into ``content``/``content_blocks``.

            Each ``data:`` SSE line is processed as follows:
              - JSON-decode it and pass it through the stream filters;
              - fold it into the content blocks;
              - emit an update to the client.

            Streamed tool calls are accumulated and queued on ``tool_calls``.
            Reading stops early once a code_interpreter end tag is detected.
            """
            nonlocal content
            nonlocal content_blocks

            response_tool_calls = []

            async for line in response.body_iterator:
                line = line.decode("utf-8") if isinstance(line, bytes) else line
                data = line

                # Skip blank lines and anything that is not an SSE data line.
                if not data.strip() or not data.startswith("data:"):
                    continue

                data = data[len("data:") :].strip()

                try:
                    data = json.loads(data)

                    data, _ = await process_filter_functions(
                        request=request,
                        filter_functions=filter_functions,
                        filter_type="stream",
                        form_data=data,
                        extra_params=extra_params,
                    )

                    if data:
                        if "event" in data:
                            await event_emitter(data.get("event", {}))

                        if "selected_model_id" in data:
                            selected_model_id = data["selected_model_id"]
                            Chats.upsert_message_to_chat_by_id_and_message_id(
                                metadata["chat_id"],
                                metadata["message_id"],
                                {
                                    "selectedModelId": selected_model_id,
                                },
                            )
                        else:
                            choices = data.get("choices", [])
                            if not choices:
                                error = data.get("error", {})
                                if error:
                                    await event_emitter(
                                        {
                                            "type": "chat:completion",
                                            "data": {
                                                "error": error,
                                            },
                                        }
                                    )
                                usage = data.get("usage", {})
                                if usage:
                                    await event_emitter(
                                        {
                                            "type": "chat:completion",
                                            "data": {
                                                "usage": usage,
                                            },
                                        }
                                    )
                                continue

                            delta = choices[0].get("delta", {})
                            delta_tool_calls = delta.get("tool_calls", None)
                            if delta_tool_calls:
                                _merge_tool_call_deltas(
                                    response_tool_calls, delta_tool_calls
                                )

                            value = delta.get("content")

                            reasoning_content = delta.get(
                                "reasoning_content"
                            ) or delta.get("reasoning")
                            if reasoning_content:
                                if (
                                    not content_blocks
                                    or content_blocks[-1]["type"] != "reasoning"
                                ):
                                    reasoning_block = {
                                        "type": "reasoning",
                                        "start_tag": "think",
                                        "end_tag": "/think",
                                        "attributes": {
                                            "type": "reasoning_content"
                                        },
                                        "content": "",
                                        "started_at": time.time(),
                                    }
                                    content_blocks.append(reasoning_block)
                                else:
                                    reasoning_block = content_blocks[-1]

                                reasoning_block["content"] += reasoning_content

                                data = {
                                    "content": serialize_content_blocks(
                                        content_blocks
                                    )
                                }

                            if value:
                                # First answer token after provider-side
                                # reasoning: close the reasoning block.
                                if (
                                    content_blocks
                                    and content_blocks[-1]["type"] == "reasoning"
                                    and content_blocks[-1]
                                    .get("attributes", {})
                                    .get("type")
                                    == "reasoning_content"
                                ):
                                    reasoning_block = content_blocks[-1]
                                    reasoning_block["ended_at"] = time.time()
                                    reasoning_block["duration"] = int(
                                        reasoning_block["ended_at"]
                                        - reasoning_block["started_at"]
                                    )
                                    content_blocks.append(_text_block())

                                content = f"{content}{value}"
                                if not content_blocks:
                                    content_blocks.append(_text_block())
                                content_blocks[-1]["content"] = (
                                    content_blocks[-1]["content"] + value
                                )

                                if DETECT_REASONING:
                                    content, content_blocks, _ = tag_content_handler(
                                        "reasoning",
                                        reasoning_tags,
                                        content,
                                        content_blocks,
                                    )

                                if DETECT_CODE_INTERPRETER:
                                    content, content_blocks, end = tag_content_handler(
                                        "code_interpreter",
                                        code_interpreter_tags,
                                        content,
                                        content_blocks,
                                    )
                                    # Stop reading so the code block can be run.
                                    if end:
                                        break

                                if DETECT_SOLUTION:
                                    content, content_blocks, _ = tag_content_handler(
                                        "solution",
                                        solution_tags,
                                        content,
                                        content_blocks,
                                    )

                                if ENABLE_REALTIME_CHAT_SAVE:
                                    _save_content()
                                else:
                                    data = {
                                        "content": serialize_content_blocks(
                                            content_blocks
                                        ),
                                    }

                        await event_emitter(
                            {
                                "type": "chat:completion",
                                "data": data,
                            }
                        )
                except Exception as e:
                    # The terminating "data: [DONE]" line is not JSON; ignore it.
                    if "data: [DONE]" not in line:
                        log.debug(f"Error: {e}")

            # Clean up the trailing text block.
            if content_blocks and content_blocks[-1]["type"] == "text":
                content_blocks[-1]["content"] = content_blocks[-1]["content"].strip()
                if not content_blocks[-1]["content"]:
                    content_blocks.pop()
                    if not content_blocks:
                        content_blocks.append(_text_block())

            if response_tool_calls:
                tool_calls.append(response_tool_calls)

            if response.background:
                await response.background()

        await stream_body_handler(response)

        MAX_TOOL_CALL_RETRIES = 10
        tool_call_retries = 0

        while len(tool_calls) > 0 and tool_call_retries < MAX_TOOL_CALL_RETRIES:
            tool_call_retries += 1

            response_tool_calls = tool_calls.pop(0)
            content_blocks.append(
                {
                    "type": "tool_calls",
                    "content": response_tool_calls,
                }
            )
            await _emit_content()

            tools = metadata.get("tools", {})
            results = []
            for tool_call in response_tool_calls:
                results.append(await _execute_tool_call(tool_call, tools))

            content_blocks[-1]["results"] = results
            content_blocks.append(_text_block())
            await _emit_content()

            try:
                res = await generate_chat_completion(
                    request,
                    {
                        "model": model_id,
                        "stream": True,
                        "tools": form_data["tools"],
                        "messages": [
                            *form_data["messages"],
                            *convert_content_blocks_to_messages(content_blocks),
                        ],
                    },
                    user,
                )

                if isinstance(res, StreamingResponse):
                    await stream_body_handler(res)
                else:
                    break
            except Exception as e:
                log.debug(e)
                break

        if DETECT_CODE_INTERPRETER:
            MAX_RETRIES = 5
            retries = 0

            while (
                content_blocks[-1]["type"] == "code_interpreter"
                and retries < MAX_RETRIES
            ):
                await _emit_content()

                retries += 1
                log.debug(f"Attempt count: {retries}")

                output = ""
                try:
                    if content_blocks[-1]["attributes"].get("type") == "code":
                        output = await _run_code(content_blocks[-1]["content"])
                        log.debug(f"Code interpreter output: {output}")

                        if isinstance(output, dict):
                            stdout = output.get("stdout", "")
                            if isinstance(stdout, str):
                                output["stdout"] = _replace_base64_images(stdout)

                            result = output.get("result", "")
                            if isinstance(result, str):
                                output["result"] = _replace_base64_images(result)
                except Exception as e:
                    output = str(e)

                content_blocks[-1]["output"] = output
                content_blocks.append(_text_block())
                await _emit_content()

                try:
                    res = await generate_chat_completion(
                        request,
                        {
                            "model": model_id,
                            "stream": True,
                            "messages": [
                                *form_data["messages"],
                                {
                                    "role": "assistant",
                                    "content": serialize_content_blocks(
                                        content_blocks, raw=True
                                    ),
                                },
                            ],
                        },
                        user,
                    )

                    if isinstance(res, StreamingResponse):
                        await stream_body_handler(res)
                    else:
                        break
                except Exception as e:
                    log.debug(e)
                    break

        title = Chats.get_chat_title_by_id(metadata["chat_id"])
        data = {
            "done": True,
            "content": serialize_content_blocks(content_blocks),
            "title": title,
        }

        if not ENABLE_REALTIME_CHAT_SAVE:
            _save_content()

        # Notify via webhook when the user has no active session.
        if get_active_status_by_user_id(user.id) is None:
            webhook_url = Users.get_user_webhook_url_by_id(user.id)
            if webhook_url:
                post_webhook(
                    request.app.state.WEBUI_NAME,
                    webhook_url,
                    f"{title} - {request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}\n\n{content}",
                    {
                        "action": "chat",
                        "message": content,
                        "title": title,
                        "url": f"{request.app.state.config.WEBUI_URL}/c/{metadata['chat_id']}",
                    },
                )

        await event_emitter(
            {
                "type": "chat:completion",
                "data": data,
            }
        )

        await background_tasks_handler()
    except asyncio.CancelledError:
        log.warning("Task was cancelled!")
        await event_emitter({"type": "task-cancelled"})

        if not ENABLE_REALTIME_CHAT_SAVE:
            _save_content()

    # NOTE(review): on a normal finish, stream_body_handler has already awaited
    # this response's background task. This call also covers cancellation —
    # confirm the task is safe to run twice.
    if response.background is not None:
        await response.background()
|
|
| |
| task_id, _ = create_task( |
| post_response_handler(response, events), id=metadata["chat_id"] |
| ) |
| return {"status": True, "task_id": task_id} |
|
|
| else: |
| |
async def stream_wrapper(original_generator, events):
    """Yield queued events as SSE lines, then the filtered upstream stream.

    Every item (queued event or upstream chunk) first goes through the
    "stream" filter functions, and items the filters turn falsy are dropped.
    Events are JSON-encoded and framed as ``data: ...`` SSE lines. Upstream
    chunks are forwarded exactly as the filters return them.
    """

    async def apply_stream_filters(payload):
        filtered, _unused = await process_filter_functions(
            request=request,
            filter_functions=filter_functions,
            filter_type="stream",
            form_data=payload,
            extra_params=extra_params,
        )
        return filtered

    for pending_event in events:
        filtered_event = await apply_stream_filters(pending_event)
        if not filtered_event:
            continue
        yield f"data: {json.dumps(filtered_event)}\n\n"

    async for chunk in original_generator:
        filtered_chunk = await apply_stream_filters(chunk)
        if not filtered_chunk:
            continue
        yield filtered_chunk
|
|
| return StreamingResponse( |
| stream_wrapper(response.body_iterator, events), |
| headers=dict(response.headers), |
| background=response.background, |
| ) |
|
|