",
+ )
+
+ scheme, token = parts
+ if scheme.lower() != "bearer":
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid authentication scheme",
+ )
+
+ # Check if SCIM is enabled
+ scim_enabled = getattr(request.app.state, "SCIM_ENABLED", False)
+ log.info(
+ f"SCIM auth check - raw SCIM_ENABLED: {scim_enabled}, type: {type(scim_enabled)}"
+ )
+ # Handle both PersistentConfig and direct value
+ if hasattr(scim_enabled, "value"):
+ scim_enabled = scim_enabled.value
+ log.info(f"SCIM enabled status after conversion: {scim_enabled}")
+ if not scim_enabled:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="SCIM is not enabled",
+ )
+
+ # Verify the SCIM token
+ scim_token = getattr(request.app.state, "SCIM_TOKEN", None)
+ # Handle both PersistentConfig and direct value
+ if hasattr(scim_token, "value"):
+ scim_token = scim_token.value
+ log.debug(f"SCIM token configured: {bool(scim_token)}")
+ if not scim_token or token != scim_token:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Invalid SCIM token",
+ )
+
+ return True
+ except HTTPException:
+ # Re-raise HTTP exceptions as-is
+ raise
+ except Exception as e:
+ log.error(f"SCIM authentication error: {e}")
+ import traceback
+
+ log.error(f"Traceback: {traceback.format_exc()}")
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Authentication failed",
+ )
+
+
+def user_to_scim(user: UserModel, request: Request) -> SCIMUser:
+ """Convert internal User model to SCIM User"""
+ # Parse display name into name components
+ name_parts = user.name.split(" ", 1) if user.name else ["", ""]
+ given_name = name_parts[0] if name_parts else ""
+ family_name = name_parts[1] if len(name_parts) > 1 else ""
+
+ # Get user's groups
+ user_groups = Groups.get_groups_by_member_id(user.id)
+ groups = [
+ {
+ "value": group.id,
+ "display": group.name,
+ "$ref": f"{request.base_url}api/v1/scim/v2/Groups/{group.id}",
+ "type": "direct",
+ }
+ for group in user_groups
+ ]
+
+ return SCIMUser(
+ id=user.id,
+ userName=user.email,
+ name=SCIMName(
+ formatted=user.name,
+ givenName=given_name,
+ familyName=family_name,
+ ),
+ displayName=user.name,
+ emails=[SCIMEmail(value=user.email)],
+ active=user.role != "pending",
+ photos=(
+ [SCIMPhoto(value=user.profile_image_url)]
+ if user.profile_image_url
+ else None
+ ),
+ groups=groups if groups else None,
+ meta=SCIMMeta(
+ resourceType=SCIM_RESOURCE_TYPE_USER,
+ created=datetime.fromtimestamp(
+ user.created_at, tz=timezone.utc
+ ).isoformat(),
+ lastModified=datetime.fromtimestamp(
+ user.updated_at, tz=timezone.utc
+ ).isoformat(),
+ location=f"{request.base_url}api/v1/scim/v2/Users/{user.id}",
+ ),
+ )
+
+
+def group_to_scim(group: GroupModel, request: Request) -> SCIMGroup:
+ """Convert internal Group model to SCIM Group"""
+ members = []
+ for user_id in group.user_ids:
+ user = Users.get_user_by_id(user_id)
+ if user:
+ members.append(
+ SCIMGroupMember(
+ value=user.id,
+ ref=f"{request.base_url}api/v1/scim/v2/Users/{user.id}",
+ display=user.name,
+ )
+ )
+
+ return SCIMGroup(
+ id=group.id,
+ displayName=group.name,
+ members=members,
+ meta=SCIMMeta(
+ resourceType=SCIM_RESOURCE_TYPE_GROUP,
+ created=datetime.fromtimestamp(
+ group.created_at, tz=timezone.utc
+ ).isoformat(),
+ lastModified=datetime.fromtimestamp(
+ group.updated_at, tz=timezone.utc
+ ).isoformat(),
+ location=f"{request.base_url}api/v1/scim/v2/Groups/{group.id}",
+ ),
+ )
+
+
+# SCIM Service Provider Config
+@router.get("/ServiceProviderConfig")
+async def get_service_provider_config():
+ """Get SCIM Service Provider Configuration"""
+ return {
+ "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"],
+ "patch": {"supported": True},
+ "bulk": {"supported": False, "maxOperations": 1000, "maxPayloadSize": 1048576},
+ "filter": {"supported": True, "maxResults": 200},
+ "changePassword": {"supported": False},
+ "sort": {"supported": False},
+ "etag": {"supported": False},
+ "authenticationSchemes": [
+ {
+ "type": "oauthbearertoken",
+ "name": "OAuth Bearer Token",
+ "description": "Authentication using OAuth 2.0 Bearer Token",
+ }
+ ],
+ }
+
+
+# SCIM Resource Types
+@router.get("/ResourceTypes")
+async def get_resource_types(request: Request):
+ """Get SCIM Resource Types"""
+ return [
+ {
+ "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"],
+ "id": "User",
+ "name": "User",
+ "endpoint": "/Users",
+ "schema": SCIM_USER_SCHEMA,
+ "meta": {
+ "location": f"{request.base_url}api/v1/scim/v2/ResourceTypes/User",
+ "resourceType": "ResourceType",
+ },
+ },
+ {
+ "schemas": ["urn:ietf:params:scim:schemas:core:2.0:ResourceType"],
+ "id": "Group",
+ "name": "Group",
+ "endpoint": "/Groups",
+ "schema": SCIM_GROUP_SCHEMA,
+ "meta": {
+ "location": f"{request.base_url}api/v1/scim/v2/ResourceTypes/Group",
+ "resourceType": "ResourceType",
+ },
+ },
+ ]
+
+
+# SCIM Schemas
+@router.get("/Schemas")
+async def get_schemas():
+    """Return the SCIM schema definitions served by this implementation.
+
+    Only the attributes this router actually maps are advertised — a
+    deliberate subset of the full core User/Group schemas.
+    """
+    return [
+        # Core User schema; userName doubles as the email address here.
+        {
+            "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
+            "id": SCIM_USER_SCHEMA,
+            "name": "User",
+            "description": "User Account",
+            "attributes": [
+                {
+                    "name": "userName",
+                    "type": "string",
+                    "required": True,
+                    "uniqueness": "server",
+                },
+                {"name": "displayName", "type": "string", "required": True},
+                {
+                    "name": "emails",
+                    "type": "complex",
+                    "multiValued": True,
+                    "required": True,
+                },
+                {"name": "active", "type": "boolean", "required": False},
+            ],
+        },
+        # Core Group schema; membership is the only mutable complex attribute.
+        {
+            "schemas": ["urn:ietf:params:scim:schemas:core:2.0:Schema"],
+            "id": SCIM_GROUP_SCHEMA,
+            "name": "Group",
+            "description": "Group",
+            "attributes": [
+                {"name": "displayName", "type": "string", "required": True},
+                {
+                    "name": "members",
+                    "type": "complex",
+                    "multiValued": True,
+                    "required": False,
+                },
+            ],
+        },
+    ]
+
+
+# Users endpoints
+@router.get("/Users", response_model=SCIMListResponse)
+async def get_users(
+ request: Request,
+ startIndex: int = Query(1, ge=1),
+ count: int = Query(20, ge=1, le=100),
+ filter: Optional[str] = None,
+ _: bool = Depends(get_scim_auth),
+):
+ """List SCIM Users"""
+ skip = startIndex - 1
+ limit = count
+
+ # Get users from database
+ if filter:
+ # Simple filter parsing - supports userName eq "email"
+ # In production, you'd want a more robust filter parser
+ if "userName eq" in filter:
+ email = filter.split('"')[1]
+ user = Users.get_user_by_email(email)
+ users_list = [user] if user else []
+ total = 1 if user else 0
+ else:
+ response = Users.get_users(skip=skip, limit=limit)
+ users_list = response["users"]
+ total = response["total"]
+ else:
+ response = Users.get_users(skip=skip, limit=limit)
+ users_list = response["users"]
+ total = response["total"]
+
+ # Convert to SCIM format
+ scim_users = [user_to_scim(user, request) for user in users_list]
+
+ return SCIMListResponse(
+ totalResults=total,
+ itemsPerPage=len(scim_users),
+ startIndex=startIndex,
+ Resources=scim_users,
+ )
+
+
+@router.get("/Users/{user_id}", response_model=SCIMUser)
+async def get_user(
+ user_id: str,
+ request: Request,
+ _: bool = Depends(get_scim_auth),
+):
+ """Get SCIM User by ID"""
+ user = Users.get_user_by_id(user_id)
+ if not user:
+ return scim_error(
+ status_code=status.HTTP_404_NOT_FOUND, detail=f"User {user_id} not found"
+ )
+
+ return user_to_scim(user, request)
+
+
+@router.post("/Users", response_model=SCIMUser, status_code=status.HTTP_201_CREATED)
+async def create_user(
+ request: Request,
+ user_data: SCIMUserCreateRequest,
+ _: bool = Depends(get_scim_auth),
+):
+ """Create SCIM User"""
+ # Check if user already exists
+ existing_user = Users.get_user_by_email(user_data.userName)
+ if existing_user:
+ raise HTTPException(
+ status_code=status.HTTP_409_CONFLICT,
+ detail=f"User with email {user_data.userName} already exists",
+ )
+
+ # Create user
+ user_id = str(uuid.uuid4())
+ email = user_data.emails[0].value if user_data.emails else user_data.userName
+
+ # Parse name if provided
+ name = user_data.displayName
+ if user_data.name:
+ if user_data.name.formatted:
+ name = user_data.name.formatted
+ elif user_data.name.givenName or user_data.name.familyName:
+ name = f"{user_data.name.givenName or ''} {user_data.name.familyName or ''}".strip()
+
+ # Get profile image if provided
+ profile_image = "/user.png"
+ if user_data.photos and len(user_data.photos) > 0:
+ profile_image = user_data.photos[0].value
+
+ # Create user
+ new_user = Users.insert_new_user(
+ id=user_id,
+ name=name,
+ email=email,
+ profile_image_url=profile_image,
+ role="user" if user_data.active else "pending",
+ )
+
+ if not new_user:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Failed to create user",
+ )
+
+ return user_to_scim(new_user, request)
+
+
+@router.put("/Users/{user_id}", response_model=SCIMUser)
+async def update_user(
+    user_id: str,
+    request: Request,
+    user_data: SCIMUserUpdateRequest,
+    _: bool = Depends(get_scim_auth),
+):
+    """Update SCIM User (full update).
+
+    Only fields present in the request are written; 404 if the user does
+    not exist, 500 if the persistence layer rejects the update.
+    """
+    user = Users.get_user_by_id(user_id)
+    if not user:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"User {user_id} not found",
+        )
+
+    # Build update dict
+    update_data = {}
+
+    if user_data.userName:
+        update_data["email"] = user_data.userName
+
+    # displayName takes precedence over the structured name components.
+    if user_data.displayName:
+        update_data["name"] = user_data.displayName
+    elif user_data.name:
+        if user_data.name.formatted:
+            update_data["name"] = user_data.name.formatted
+        elif user_data.name.givenName or user_data.name.familyName:
+            update_data["name"] = (
+                f"{user_data.name.givenName or ''} {user_data.name.familyName or ''}".strip()
+            )
+
+    # NOTE(review): when both userName and emails are supplied, emails[0]
+    # wins because it is assigned after userName.
+    if user_data.emails and len(user_data.emails) > 0:
+        update_data["email"] = user_data.emails[0].value
+
+    # SCIM "active" is persisted via the role ("user" vs "pending").
+    if user_data.active is not None:
+        update_data["role"] = "user" if user_data.active else "pending"
+
+    if user_data.photos and len(user_data.photos) > 0:
+        update_data["profile_image_url"] = user_data.photos[0].value
+
+    # Update user
+    updated_user = Users.update_user_by_id(user_id, update_data)
+    if not updated_user:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Failed to update user",
+        )
+
+    return user_to_scim(updated_user, request)
+
+
+@router.patch("/Users/{user_id}", response_model=SCIMUser)
+async def patch_user(
+    user_id: str,
+    request: Request,
+    patch_data: SCIMPatchRequest,
+    _: bool = Depends(get_scim_auth),
+):
+    """Update SCIM User (partial update).
+
+    Only "replace" operations on a fixed set of paths are applied:
+    active, userName, displayName, the primary email value, and
+    name.formatted.
+    """
+    user = Users.get_user_by_id(user_id)
+    if not user:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"User {user_id} not found",
+        )
+
+    update_data = {}
+
+    # NOTE(review): "add"/"remove" ops and unrecognized paths are silently
+    # ignored here; strict SCIM clients may expect a 400 for unsupported
+    # operations — confirm the desired behavior.
+    for operation in patch_data.Operations:
+        op = operation.op.lower()
+        path = operation.path
+        value = operation.value
+
+        if op == "replace":
+            if path == "active":
+                # Deactivation is modeled by demoting the role to "pending".
+                update_data["role"] = "user" if value else "pending"
+            elif path == "userName":
+                update_data["email"] = value
+            elif path == "displayName":
+                update_data["name"] = value
+            elif path == "emails[primary eq true].value":
+                update_data["email"] = value
+            elif path == "name.formatted":
+                update_data["name"] = value
+
+    # Update user
+    if update_data:
+        updated_user = Users.update_user_by_id(user_id, update_data)
+        if not updated_user:
+            raise HTTPException(
+                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                detail="Failed to update user",
+            )
+    else:
+        # Nothing applicable in the patch: echo the current state.
+        updated_user = user
+
+    return user_to_scim(updated_user, request)
+
+
+@router.delete("/Users/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
+async def delete_user(
+ user_id: str,
+ request: Request,
+ _: bool = Depends(get_scim_auth),
+):
+ """Delete SCIM User"""
+ user = Users.get_user_by_id(user_id)
+ if not user:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"User {user_id} not found",
+ )
+
+ success = Users.delete_user_by_id(user_id)
+ if not success:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Failed to delete user",
+ )
+
+ return None
+
+
+# Groups endpoints
+@router.get("/Groups", response_model=SCIMListResponse)
+async def get_groups(
+ request: Request,
+ startIndex: int = Query(1, ge=1),
+ count: int = Query(20, ge=1, le=100),
+ filter: Optional[str] = None,
+ _: bool = Depends(get_scim_auth),
+):
+ """List SCIM Groups"""
+ # Get all groups
+ groups_list = Groups.get_groups()
+
+ # Apply pagination
+ total = len(groups_list)
+ start = startIndex - 1
+ end = start + count
+ paginated_groups = groups_list[start:end]
+
+ # Convert to SCIM format
+ scim_groups = [group_to_scim(group, request) for group in paginated_groups]
+
+ return SCIMListResponse(
+ totalResults=total,
+ itemsPerPage=len(scim_groups),
+ startIndex=startIndex,
+ Resources=scim_groups,
+ )
+
+
+@router.get("/Groups/{group_id}", response_model=SCIMGroup)
+async def get_group(
+ group_id: str,
+ request: Request,
+ _: bool = Depends(get_scim_auth),
+):
+ """Get SCIM Group by ID"""
+ group = Groups.get_group_by_id(group_id)
+ if not group:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"Group {group_id} not found",
+ )
+
+ return group_to_scim(group, request)
+
+
+@router.post("/Groups", response_model=SCIMGroup, status_code=status.HTTP_201_CREATED)
+async def create_group(
+    request: Request,
+    group_data: SCIMGroupCreateRequest,
+    _: bool = Depends(get_scim_auth),
+):
+    """Create SCIM Group.
+
+    The group is created first and members are attached in a second
+    update call — the two steps are not atomic, so a failure in the
+    second step leaves an empty group behind.
+    """
+    # Extract member IDs
+    member_ids = []
+    if group_data.members:
+        for member in group_data.members:
+            member_ids.append(member.value)
+
+    # Create group
+    from open_webui.models.groups import GroupForm
+
+    form = GroupForm(
+        name=group_data.displayName,
+        description="",
+    )
+
+    # Need to get the creating user's ID - we'll use the first admin
+    # (SCIM has no authenticated end user, so the super admin is recorded
+    # as the group's owner).
+    admin_user = Users.get_super_admin_user()
+    if not admin_user:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="No admin user found",
+        )
+
+    new_group = Groups.insert_new_group(admin_user.id, form)
+    if not new_group:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Failed to create group",
+        )
+
+    # Add members if provided
+    if member_ids:
+        from open_webui.models.groups import GroupUpdateForm
+
+        update_form = GroupUpdateForm(
+            name=new_group.name,
+            description=new_group.description,
+            user_ids=member_ids,
+        )
+        Groups.update_group_by_id(new_group.id, update_form)
+        # Re-read so the response reflects the attached members.
+        new_group = Groups.get_group_by_id(new_group.id)
+
+    return group_to_scim(new_group, request)
+
+
+@router.put("/Groups/{group_id}", response_model=SCIMGroup)
+async def update_group(
+    group_id: str,
+    request: Request,
+    group_data: SCIMGroupUpdateRequest,
+    _: bool = Depends(get_scim_auth),
+):
+    """Update SCIM Group (full update).
+
+    Renames the group and, when `members` is present, replaces the whole
+    membership list.
+    """
+    group = Groups.get_group_by_id(group_id)
+    if not group:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Group {group_id} not found",
+        )
+
+    # Build update form
+    from open_webui.models.groups import GroupUpdateForm
+
+    update_form = GroupUpdateForm(
+        name=group_data.displayName if group_data.displayName else group.name,
+        description=group.description,
+    )
+
+    # Handle members if provided
+    # NOTE(review): when members is omitted, user_ids is never set on the
+    # form — presumably GroupUpdateForm then leaves membership untouched;
+    # confirm it does not reset membership to a default instead.
+    if group_data.members is not None:
+        member_ids = [member.value for member in group_data.members]
+        update_form.user_ids = member_ids
+
+    # Update group
+    updated_group = Groups.update_group_by_id(group_id, update_form)
+    if not updated_group:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Failed to update group",
+        )
+
+    return group_to_scim(updated_group, request)
+
+
+@router.patch("/Groups/{group_id}", response_model=SCIMGroup)
+async def patch_group(
+    group_id: str,
+    request: Request,
+    patch_data: SCIMPatchRequest,
+    _: bool = Depends(get_scim_auth),
+):
+    """Update SCIM Group (partial update).
+
+    Supported operations: replace displayName, replace the full member
+    list, add members, and remove a single member addressed by
+    `members[value eq "<id>"]`. Anything else is silently ignored.
+    """
+    group = Groups.get_group_by_id(group_id)
+    if not group:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail=f"Group {group_id} not found",
+        )
+
+    from open_webui.models.groups import GroupUpdateForm
+
+    # Start from the current state; the operations below mutate this form.
+    update_form = GroupUpdateForm(
+        name=group.name,
+        description=group.description,
+        user_ids=group.user_ids.copy() if group.user_ids else [],
+    )
+
+    for operation in patch_data.Operations:
+        op = operation.op.lower()
+        path = operation.path
+        value = operation.value
+
+        if op == "replace":
+            if path == "displayName":
+                update_form.name = value
+            elif path == "members":
+                # Replace all members
+                # NOTE(review): assumes value is a list of {"value": id}
+                # dicts — a malformed payload raises here; consider validating.
+                update_form.user_ids = [member["value"] for member in value]
+        elif op == "add":
+            if path == "members":
+                # Add members (deduplicated against the current list).
+                if isinstance(value, list):
+                    for member in value:
+                        if isinstance(member, dict) and "value" in member:
+                            if member["value"] not in update_form.user_ids:
+                                update_form.user_ids.append(member["value"])
+        elif op == "remove":
+            if path and path.startswith("members[value eq"):
+                # Remove specific member (id taken from the quoted filter value).
+                member_id = path.split('"')[1]
+                if member_id in update_form.user_ids:
+                    update_form.user_ids.remove(member_id)
+
+    # Update group
+    updated_group = Groups.update_group_by_id(group_id, update_form)
+    if not updated_group:
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail="Failed to update group",
+        )
+
+    return group_to_scim(updated_group, request)
+
+
+@router.delete("/Groups/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
+async def delete_group(
+ group_id: str,
+ request: Request,
+ _: bool = Depends(get_scim_auth),
+):
+ """Delete SCIM Group"""
+ group = Groups.get_group_by_id(group_id)
+ if not group:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail=f"Group {group_id} not found",
+ )
+
+ success = Groups.delete_group_by_id(group_id)
+ if not success:
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Failed to delete group",
+ )
+
+ return None
diff --git a/backend/open_webui/routers/tasks.py b/backend/open_webui/routers/tasks.py
new file mode 100644
index 0000000000000000000000000000000000000000..7585466f69c1663d4fd6bb7cdf2a2abeabad35e6
--- /dev/null
+++ b/backend/open_webui/routers/tasks.py
@@ -0,0 +1,744 @@
+from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
+from fastapi.responses import JSONResponse, RedirectResponse
+
+from pydantic import BaseModel
+from typing import Optional
+import logging
+import re
+
+from open_webui.utils.chat import generate_chat_completion
+from open_webui.utils.task import (
+ title_generation_template,
+ follow_up_generation_template,
+ query_generation_template,
+ image_prompt_generation_template,
+ autocomplete_generation_template,
+ tags_generation_template,
+ emoji_generation_template,
+ moa_response_generation_template,
+)
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.constants import TASKS
+
+from open_webui.routers.pipelines import process_pipeline_inlet_filter
+
+from open_webui.utils.task import get_task_model_id
+
+from open_webui.config import (
+ DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE,
+ DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
+ DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE,
+ DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
+ DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
+ DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
+ DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE,
+ DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE,
+)
+from open_webui.env import SRC_LOG_LEVELS
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+router = APIRouter()
+
+
+##################################
+#
+# Task Endpoints
+#
+##################################
+
+
+@router.get("/config")
+async def get_task_config(request: Request, user=Depends(get_verified_user)):
+ return {
+ "TASK_MODEL": request.app.state.config.TASK_MODEL,
+ "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
+ "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
+ "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
+ "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
+ "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
+ "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
+ "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
+ "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
+ "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
+ "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
+ "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
+ "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
+ "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
+ "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+ }
+
+
+class TaskConfigForm(BaseModel):
+    """Request body for POST /config/update.
+
+    NOTE(review): the Optional[...] fields have no default, so they are
+    still required keys (their value may be null) — confirm that matches
+    the client's expectations.
+    """
+
+    # Model ids used for background tasks (local / external connections).
+    TASK_MODEL: Optional[str]
+    TASK_MODEL_EXTERNAL: Optional[str]
+    ENABLE_TITLE_GENERATION: bool
+    TITLE_GENERATION_PROMPT_TEMPLATE: str
+    IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE: str
+    # Autocomplete generation and its input-size guard.
+    ENABLE_AUTOCOMPLETE_GENERATION: bool
+    AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH: int
+    TAGS_GENERATION_PROMPT_TEMPLATE: str
+    FOLLOW_UP_GENERATION_PROMPT_TEMPLATE: str
+    ENABLE_FOLLOW_UP_GENERATION: bool
+    ENABLE_TAGS_GENERATION: bool
+    ENABLE_SEARCH_QUERY_GENERATION: bool
+    ENABLE_RETRIEVAL_QUERY_GENERATION: bool
+    QUERY_GENERATION_PROMPT_TEMPLATE: str
+    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
+
+
+@router.post("/config/update")
+async def update_task_config(
+    request: Request, form_data: TaskConfigForm, user=Depends(get_admin_user)
+):
+    """Persist the task configuration (admin only) and echo it back.
+
+    Copies every validated form field onto the persistent app config,
+    then returns the stored values so clients can refresh their state.
+    """
+    request.app.state.config.TASK_MODEL = form_data.TASK_MODEL
+    request.app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
+    request.app.state.config.ENABLE_TITLE_GENERATION = form_data.ENABLE_TITLE_GENERATION
+    request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
+        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
+    )
+
+    request.app.state.config.ENABLE_FOLLOW_UP_GENERATION = (
+        form_data.ENABLE_FOLLOW_UP_GENERATION
+    )
+    request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE = (
+        form_data.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
+    )
+
+    request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE = (
+        form_data.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
+    )
+
+    request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION = (
+        form_data.ENABLE_AUTOCOMPLETE_GENERATION
+    )
+    request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH = (
+        form_data.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
+    )
+
+    request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = (
+        form_data.TAGS_GENERATION_PROMPT_TEMPLATE
+    )
+    request.app.state.config.ENABLE_TAGS_GENERATION = form_data.ENABLE_TAGS_GENERATION
+    request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION = (
+        form_data.ENABLE_SEARCH_QUERY_GENERATION
+    )
+    request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION = (
+        form_data.ENABLE_RETRIEVAL_QUERY_GENERATION
+    )
+
+    request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE = (
+        form_data.QUERY_GENERATION_PROMPT_TEMPLATE
+    )
+    request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
+        form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+    )
+
+    # Echo the stored configuration (same shape as GET /config).
+    return {
+        "TASK_MODEL": request.app.state.config.TASK_MODEL,
+        "TASK_MODEL_EXTERNAL": request.app.state.config.TASK_MODEL_EXTERNAL,
+        "ENABLE_TITLE_GENERATION": request.app.state.config.ENABLE_TITLE_GENERATION,
+        "TITLE_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
+        "IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE": request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE,
+        "ENABLE_AUTOCOMPLETE_GENERATION": request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION,
+        "AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH": request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH,
+        "TAGS_GENERATION_PROMPT_TEMPLATE": request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE,
+        "ENABLE_TAGS_GENERATION": request.app.state.config.ENABLE_TAGS_GENERATION,
+        "ENABLE_FOLLOW_UP_GENERATION": request.app.state.config.ENABLE_FOLLOW_UP_GENERATION,
+        "FOLLOW_UP_GENERATION_PROMPT_TEMPLATE": request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE,
+        "ENABLE_SEARCH_QUERY_GENERATION": request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION,
+        "ENABLE_RETRIEVAL_QUERY_GENERATION": request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION,
+        "QUERY_GENERATION_PROMPT_TEMPLATE": request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE,
+        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": request.app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+    }
+
+
+@router.post("/title/completions")
+async def generate_title(
+ request: Request, form_data: dict, user=Depends(get_verified_user)
+):
+
+ if not request.app.state.config.ENABLE_TITLE_GENERATION:
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content={"detail": "Title generation is disabled"},
+ )
+
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
+
+ model_id = form_data["model"]
+ if model_id not in models:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Model not found",
+ )
+
+ # Check if the user has a custom task model
+ # If the user has a custom task model, use that model
+ task_model_id = get_task_model_id(
+ model_id,
+ request.app.state.config.TASK_MODEL,
+ request.app.state.config.TASK_MODEL_EXTERNAL,
+ models,
+ )
+
+ log.debug(
+ f"generating chat title using model {task_model_id} for user {user.email} "
+ )
+
+ if request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
+ template = request.app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
+ else:
+ template = DEFAULT_TITLE_GENERATION_PROMPT_TEMPLATE
+
+ content = title_generation_template(template, form_data["messages"], user)
+
+ max_tokens = (
+ models[task_model_id].get("info", {}).get("params", {}).get("max_tokens", 1000)
+ )
+
+ payload = {
+ "model": task_model_id,
+ "messages": [{"role": "user", "content": content}],
+ "stream": False,
+ **(
+ {"max_tokens": max_tokens}
+ if models[task_model_id].get("owned_by") == "ollama"
+ else {
+ "max_completion_tokens": max_tokens,
+ }
+ ),
+ "metadata": {
+ **(request.state.metadata if hasattr(request.state, "metadata") else {}),
+ "task": str(TASKS.TITLE_GENERATION),
+ "task_body": form_data,
+ "chat_id": form_data.get("chat_id", None),
+ },
+ }
+
+ # Process the payload through the pipeline
+ try:
+ payload = await process_pipeline_inlet_filter(request, payload, user, models)
+ except Exception as e:
+ raise e
+
+ try:
+ return await generate_chat_completion(request, form_data=payload, user=user)
+ except Exception as e:
+ log.error("Exception occurred", exc_info=True)
+ return JSONResponse(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ content={"detail": "An internal error has occurred."},
+ )
+
+
+@router.post("/follow_up/completions")
+async def generate_follow_ups(
+ request: Request, form_data: dict, user=Depends(get_verified_user)
+):
+
+ if not request.app.state.config.ENABLE_FOLLOW_UP_GENERATION:
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content={"detail": "Follow-up generation is disabled"},
+ )
+
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
+
+ model_id = form_data["model"]
+ if model_id not in models:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Model not found",
+ )
+
+ # Check if the user has a custom task model
+ # If the user has a custom task model, use that model
+ task_model_id = get_task_model_id(
+ model_id,
+ request.app.state.config.TASK_MODEL,
+ request.app.state.config.TASK_MODEL_EXTERNAL,
+ models,
+ )
+
+ log.debug(
+ f"generating chat title using model {task_model_id} for user {user.email} "
+ )
+
+ if request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE != "":
+ template = request.app.state.config.FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
+ else:
+ template = DEFAULT_FOLLOW_UP_GENERATION_PROMPT_TEMPLATE
+
+ content = follow_up_generation_template(template, form_data["messages"], user)
+
+ payload = {
+ "model": task_model_id,
+ "messages": [{"role": "user", "content": content}],
+ "stream": False,
+ "metadata": {
+ **(request.state.metadata if hasattr(request.state, "metadata") else {}),
+ "task": str(TASKS.FOLLOW_UP_GENERATION),
+ "task_body": form_data,
+ "chat_id": form_data.get("chat_id", None),
+ },
+ }
+
+ # Process the payload through the pipeline
+ try:
+ payload = await process_pipeline_inlet_filter(request, payload, user, models)
+ except Exception as e:
+ raise e
+
+ try:
+ return await generate_chat_completion(request, form_data=payload, user=user)
+ except Exception as e:
+ log.error("Exception occurred", exc_info=True)
+ return JSONResponse(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ content={"detail": "An internal error has occurred."},
+ )
+
+
+@router.post("/tags/completions")
+async def generate_chat_tags(
+ request: Request, form_data: dict, user=Depends(get_verified_user)
+):
+
+ if not request.app.state.config.ENABLE_TAGS_GENERATION:
+ return JSONResponse(
+ status_code=status.HTTP_200_OK,
+ content={"detail": "Tags generation is disabled"},
+ )
+
+ if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+ models = {
+ request.state.model["id"]: request.state.model,
+ }
+ else:
+ models = request.app.state.MODELS
+
+ model_id = form_data["model"]
+ if model_id not in models:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail="Model not found",
+ )
+
+ # Check if the user has a custom task model
+ # If the user has a custom task model, use that model
+ task_model_id = get_task_model_id(
+ model_id,
+ request.app.state.config.TASK_MODEL,
+ request.app.state.config.TASK_MODEL_EXTERNAL,
+ models,
+ )
+
+ log.debug(
+ f"generating chat tags using model {task_model_id} for user {user.email} "
+ )
+
+ if request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
+ template = request.app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE
+ else:
+ template = DEFAULT_TAGS_GENERATION_PROMPT_TEMPLATE
+
+ content = tags_generation_template(template, form_data["messages"], user)
+
+ payload = {
+ "model": task_model_id,
+ "messages": [{"role": "user", "content": content}],
+ "stream": False,
+ "metadata": {
+ **(request.state.metadata if hasattr(request.state, "metadata") else {}),
+ "task": str(TASKS.TAGS_GENERATION),
+ "task_body": form_data,
+ "chat_id": form_data.get("chat_id", None),
+ },
+ }
+
+ # Process the payload through the pipeline
+ try:
+ payload = await process_pipeline_inlet_filter(request, payload, user, models)
+ except Exception as e:
+ raise e
+
+ try:
+ return await generate_chat_completion(request, form_data=payload, user=user)
+ except Exception as e:
+ log.error(f"Error generating chat completion: {e}")
+ return JSONResponse(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ content={"detail": "An internal error has occurred."},
+ )
+
+
+@router.post("/image_prompt/completions")
+async def generate_image_prompt(
+    request: Request, form_data: dict, user=Depends(get_verified_user)
+):
+    """Generate an image-generation prompt from the chat messages.
+
+    NOTE(review): unlike the title/tags/follow-up endpoints there is no
+    enable/disable config gate for this task — confirm that is intended.
+    """
+    # Direct connections pin the request to a single model; otherwise use
+    # the app-wide model registry.
+    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
+
+    model_id = form_data["model"]
+    if model_id not in models:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    task_model_id = get_task_model_id(
+        model_id,
+        request.app.state.config.TASK_MODEL,
+        request.app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
+
+    log.debug(
+        f"generating image prompt using model {task_model_id} for user {user.email} "
+    )
+
+    # Prefer the admin-configured template; fall back to the default.
+    if request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE != "":
+        template = request.app.state.config.IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
+    else:
+        template = DEFAULT_IMAGE_PROMPT_GENERATION_PROMPT_TEMPLATE
+
+    content = image_prompt_generation_template(template, form_data["messages"], user)
+
+    payload = {
+        "model": task_model_id,
+        "messages": [{"role": "user", "content": content}],
+        "stream": False,
+        "metadata": {
+            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
+            "task": str(TASKS.IMAGE_PROMPT_GENERATION),
+            "task_body": form_data,
+            "chat_id": form_data.get("chat_id", None),
+        },
+    }
+
+    # Process the payload through the pipeline
+    # NOTE(review): this try/except only re-raises and could be removed.
+    try:
+        payload = await process_pipeline_inlet_filter(request, payload, user, models)
+    except Exception as e:
+        raise e
+
+    try:
+        return await generate_chat_completion(request, form_data=payload, user=user)
+    except Exception as e:
+        log.error("Exception occurred", exc_info=True)
+        return JSONResponse(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            content={"detail": "An internal error has occurred."},
+        )
+
+
@router.post("/queries/completions")
async def generate_queries(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate search or retrieval queries from the conversation.

    `form_data["type"]` selects the flavor ("web_search" or "retrieval");
    each can be disabled independently via config (400 when disabled).
    Queries cached earlier on the request state are returned verbatim.
    Raises 404 when the requested model is unknown.
    """
    type = form_data.get("type")
    if type == "web_search":
        if not request.app.state.config.ENABLE_SEARCH_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Search query generation is disabled",
            )
    elif type == "retrieval":
        if not request.app.state.config.ENABLE_RETRIEVAL_QUERY_GENERATION:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Query generation is disabled",
            )

    # Reuse queries generated earlier in this request's lifecycle, if any.
    if getattr(request.state, "cached_queries", None):
        log.info(f"Reusing cached queries: {request.state.cached_queries}")
        return request.state.cached_queries

    # Direct connections carry their own model; otherwise use the registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating {type} queries using model {task_model_id} for user {user.email}"
    )

    if (request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.QUERY_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE

    content = query_generation_template(template, form_data["messages"], user)

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.QUERY_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline; failures propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Fix: log details server-side but return a generic message — the
        # previous `detail=str(e)` leaked internal error text to clients and
        # was inconsistent with the other task endpoints in this router.
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
+
+
@router.post("/auto/completions")
async def generate_autocompletion(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate an inline autocompletion suggestion for the partial `prompt`.

    Raises 400 when the feature is disabled or the prompt exceeds the
    configured maximum input length, and 404 when the model is unknown.
    """
    if not request.app.state.config.ENABLE_AUTOCOMPLETE_GENERATION:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Autocompletion generation is disabled",
        )

    type = form_data.get("type")
    prompt = form_data.get("prompt")
    messages = form_data.get("messages")

    max_input_length = request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH
    # Fix: guard against a missing/None prompt — previously `len(prompt)`
    # raised a TypeError (surfacing as an unhandled 500) when "prompt" was
    # absent from the request body.
    if max_input_length > 0 and prompt is not None:
        if len(prompt) > max_input_length:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Input prompt exceeds maximum length of {max_input_length}",
            )

    # Direct connections carry their own model; otherwise use the registry.
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(
        f"generating autocompletion using model {task_model_id} for user {user.email}"
    )

    if (request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE).strip() != "":
        template = request.app.state.config.AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE
    else:
        template = DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE

    content = autocomplete_generation_template(template, prompt, messages, type, user)

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.AUTOCOMPLETE_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline; failures propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
+
+
@router.post("/emoji/completions")
async def generate_emoji(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate a single emoji reaction for `form_data["prompt"]`.

    Uses the (possibly user-specific) task model and clamps the completion
    to a handful of tokens since only one emoji is expected.
    Raises 404 when the requested model is unknown.
    """
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]
    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    task_model_id = get_task_model_id(
        model_id,
        request.app.state.config.TASK_MODEL,
        request.app.state.config.TASK_MODEL_EXTERNAL,
        models,
    )

    log.debug(f"generating emoji using model {task_model_id} for user {user.email} ")

    template = DEFAULT_EMOJI_GENERATION_PROMPT_TEMPLATE

    content = emoji_generation_template(template, form_data["prompt"], user)

    payload = {
        "model": task_model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        # Clamp output to ~one emoji; Ollama and OpenAI-style backends use
        # different parameter names for the completion-token limit.
        **(
            {"max_tokens": 4}
            if models[task_model_id].get("owned_by") == "ollama"
            else {
                "max_completion_tokens": 4,
            }
        ),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "task": str(TASKS.EMOJI_GENERATION),
            "task_body": form_data,
            "chat_id": form_data.get("chat_id", None),
        },
    }

    # Process the payload through the pipeline; failures propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Fix: log details server-side but return a generic message — the
        # previous `detail=str(e)` leaked internal error text to clients.
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
+
+
@router.post("/moa/completions")
async def generate_moa_response(
    request: Request, form_data: dict, user=Depends(get_verified_user)
):
    """
    Generate a mixture-of-agents (MoA) synthesis of several candidate
    responses (`form_data["responses"]`) to a prompt, using the selected
    model itself rather than the task model. Raises 404 for unknown models.
    """
    if getattr(request.state, "direct", False) and hasattr(request.state, "model"):
        models = {
            request.state.model["id"]: request.state.model,
        }
    else:
        models = request.app.state.MODELS

    model_id = form_data["model"]

    if model_id not in models:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    template = DEFAULT_MOA_GENERATION_PROMPT_TEMPLATE

    content = moa_response_generation_template(
        template,
        form_data["prompt"],
        form_data["responses"],
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        # Unlike the other task endpoints, MoA honors the caller's stream flag.
        "stream": form_data.get("stream", False),
        "metadata": {
            **(request.state.metadata if hasattr(request.state, "metadata") else {}),
            "chat_id": form_data.get("chat_id", None),
            "task": str(TASKS.MOA_RESPONSE_GENERATION),
            "task_body": form_data,
        },
    }

    # Process the payload through the pipeline; failures propagate unchanged.
    payload = await process_pipeline_inlet_filter(request, payload, user, models)

    try:
        return await generate_chat_completion(request, form_data=payload, user=user)
    except Exception as e:
        # Fix: log details server-side but return a generic message — the
        # previous `detail=str(e)` leaked internal error text to clients.
        log.error(f"Error generating chat completion: {e}")
        return JSONResponse(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            content={"detail": "An internal error has occurred."},
        )
diff --git a/backend/open_webui/routers/tools.py b/backend/open_webui/routers/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fa3f6abf615e4da86d8174fe23951c9c66e3242
--- /dev/null
+++ b/backend/open_webui/routers/tools.py
@@ -0,0 +1,632 @@
+import logging
+from pathlib import Path
+from typing import Optional
+import time
+import re
+import aiohttp
+from open_webui.models.groups import Groups
+from pydantic import BaseModel, HttpUrl
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+
+
+from open_webui.models.oauth_sessions import OAuthSessions
+from open_webui.models.tools import (
+ ToolForm,
+ ToolModel,
+ ToolResponse,
+ ToolUserResponse,
+ Tools,
+)
+from open_webui.utils.plugin import (
+ load_tool_module_by_id,
+ replace_imports,
+ get_tool_module_from_cache,
+)
+from open_webui.utils.tools import get_tool_specs
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.access_control import has_access, has_permission
+from open_webui.utils.tools import get_tool_servers
+
+from open_webui.env import SRC_LOG_LEVELS
+from open_webui.config import CACHE_DIR, BYPASS_ADMIN_ACCESS_CONTROL
+from open_webui.constants import ERROR_MESSAGES
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+
+router = APIRouter()
+
+
def get_tool_module(request, tool_id, load_from_db=True):
    """
    Resolve the loaded module object for the tool `tool_id` via the
    request-scoped cache, optionally refreshing it from the database.
    The cached frontmatter is discarded; only the module is returned.
    """
    module, _frontmatter = get_tool_module_from_cache(request, tool_id, load_from_db)
    return module
+
+
+############################
+# GetTools
+############################
+
+
@router.get("/", response_model=list[ToolUserResponse])
async def get_tools(request: Request, user=Depends(get_verified_user)):
    """
    List all tools visible to the requesting user.

    Aggregates three sources:
      1. Locally stored tools (database records plus their loaded modules).
      2. OpenAPI tool servers from TOOL_SERVER_CONNECTIONS.
      3. MCP tool servers (entries with type == "mcp") from the same list.

    Unless the caller is an admin and BYPASS_ADMIN_ACCESS_CONTROL is set,
    the result is filtered to tools the user owns or has read access to.
    """
    tools = []

    # Local Tools
    for tool in Tools.get_tools():
        tool_module = get_tool_module(request, tool.id)
        tools.append(
            ToolUserResponse(
                **{
                    **tool.model_dump(),
                    # Tells the client whether a per-user valves form exists.
                    "has_user_valves": hasattr(tool_module, "UserValves"),
                }
            )
        )

    # OpenAPI Tool Servers — synthesized entries with a "server:" id prefix.
    for server in await get_tool_servers(request):
        tools.append(
            ToolUserResponse(
                **{
                    "id": f"server:{server.get('id')}",
                    "user_id": f"server:{server.get('id')}",
                    "name": server.get("openapi", {})
                    .get("info", {})
                    .get("title", "Tool Server"),
                    "meta": {
                        "description": server.get("openapi", {})
                        .get("info", {})
                        .get("description", ""),
                    },
                    # Access control lives on the connection entry, looked up
                    # by the server's positional index in the config list.
                    "access_control": request.app.state.config.TOOL_SERVER_CONNECTIONS[
                        server.get("idx", 0)
                    ]
                    .get("config", {})
                    .get("access_control", None),
                    # Synthetic timestamps: servers have no stored created/updated times.
                    "updated_at": int(time.time()),
                    "created_at": int(time.time()),
                }
            )
        )

    # MCP Tool Servers
    for server in request.app.state.config.TOOL_SERVER_CONNECTIONS:
        if server.get("type", "openapi") == "mcp":
            server_id = server.get("info", {}).get("id")
            auth_type = server.get("auth_type", "none")

            session_token = None
            if auth_type == "oauth_2.1":
                # Strip any namespace prefix ("...:<id>") before the token lookup.
                # NOTE(review): assumes server_id is a non-None string here;
                # a connection without info.id would raise — confirm upstream.
                splits = server_id.split(":")
                server_id = splits[-1] if len(splits) > 1 else server_id

                session_token = (
                    await request.app.state.oauth_client_manager.get_oauth_token(
                        user.id, f"mcp:{server_id}"
                    )
                )

            tools.append(
                ToolUserResponse(
                    **{
                        "id": f"server:mcp:{server.get('info', {}).get('id')}",
                        "user_id": f"server:mcp:{server.get('info', {}).get('id')}",
                        "name": server.get("info", {}).get("name", "MCP Tool Server"),
                        "meta": {
                            "description": server.get("info", {}).get(
                                "description", ""
                            ),
                        },
                        "access_control": server.get("config", {}).get(
                            "access_control", None
                        ),
                        "updated_at": int(time.time()),
                        "created_at": int(time.time()),
                        # Only OAuth-protected servers report authentication state.
                        **(
                            {
                                "authenticated": session_token is not None,
                            }
                            if auth_type == "oauth_2.1"
                            else {}
                        ),
                    }
                )
            )

    if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
        # Admin can see all tools
        return tools
    else:
        # Filter to owned tools plus tools readable via the user's groups.
        user_group_ids = {group.id for group in Groups.get_groups_by_member_id(user.id)}
        tools = [
            tool
            for tool in tools
            if tool.user_id == user.id
            or has_access(user.id, "read", tool.access_control, user_group_ids)
        ]
        return tools
+
+
+############################
+# GetToolList
+############################
+
+
@router.get("/list", response_model=list[ToolUserResponse])
async def get_tool_list(user=Depends(get_verified_user)):
    """Return the tools the user can edit; bypassing admins see everything."""
    if user.role == "admin" and BYPASS_ADMIN_ACCESS_CONTROL:
        return Tools.get_tools()
    return Tools.get_tools_by_user_id(user.id, "write")
+
+
+############################
+# LoadFunctionFromLink
+############################
+
+
class LoadUrlForm(BaseModel):
    """Request body for importing tool source from a URL."""

    # HttpUrl validates the scheme/host at parse time, rejecting malformed URLs.
    url: HttpUrl
+
+
def github_url_to_raw_url(url: str) -> str:
    """
    Rewrite a github.com 'tree' (folder) or 'blob' (file) URL to its
    raw.githubusercontent.com equivalent; folder URLs get '/main.py'
    appended. Any other URL is returned unchanged.
    """
    github_pattern = r"https://github\.com/([^/]+)/([^/]+)/{kind}/([^/]+)/(.*)"

    # Folder URL: point at the folder's main.py on the raw host.
    folder = re.match(github_pattern.format(kind="tree"), url)
    if folder:
        org, repo, branch, path = folder.groups()
        base = f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}"
        return f"{base}/{path.rstrip('/')}/main.py"

    # File URL: direct raw-content translation.
    file = re.match(github_pattern.format(kind="blob"), url)
    if file:
        org, repo, branch, path = file.groups()
        return (
            f"https://raw.githubusercontent.com/{org}/{repo}/refs/heads/{branch}/{path}"
        )

    # Not a recognized GitHub URL; pass through untouched.
    return url
+
+
@router.post("/load/url", response_model=Optional[dict])
async def load_tool_from_url(
    request: Request, form_data: LoadUrlForm, user=Depends(get_admin_user)
):
    """
    Fetch tool source code from a URL and return {"name": ..., "content": ...}.

    GitHub tree/blob URLs are rewritten to their raw-content equivalents;
    the tool name is derived from the file name, falling back to the parent
    folder for generic entry-point file names.
    """
    # NOTE: This is NOT a SSRF vulnerability:
    # This endpoint is admin-only (see get_admin_user), meant for *trusted* internal use,
    # and does NOT accept untrusted user input. Access is enforced by authentication.

    url = str(form_data.url)
    if not url:
        raise HTTPException(status_code=400, detail="Please enter a valid URL")

    url = github_url_to_raw_url(url)
    url_parts = url.rstrip("/").split("/")

    file_name = url_parts[-1]
    # Use the file name (sans ".py") unless it is a generic entry-point file,
    # in which case fall back to the enclosing folder name.
    tool_name = (
        file_name[:-3]
        if (
            file_name.endswith(".py")
            and (not file_name.startswith(("main.py", "index.py", "__init__.py")))
        )
        else url_parts[-2] if len(url_parts) > 1 else "function"
    )

    try:
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(
                url, headers={"Content-Type": "application/json"}
            ) as resp:
                if resp.status != 200:
                    raise HTTPException(
                        status_code=resp.status, detail="Failed to fetch the tool"
                    )
                data = await resp.text()
                if not data:
                    raise HTTPException(
                        status_code=400, detail="No data received from the URL"
                    )
                return {
                    "name": tool_name,
                    "content": data,
                }
    except HTTPException:
        # Fix: preserve the deliberate HTTP errors raised above (upstream
        # status, empty body) instead of collapsing them into a generic 500
        # via the broad handler below.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error importing tool: {e}")
+
+
+############################
+# ExportTools
+############################
+
+
@router.get("/export", response_model=list[ToolModel])
async def export_tools(user=Depends(get_admin_user)):
    """Export every tool record; restricted to admins by the dependency."""
    return Tools.get_tools()
+
+
+############################
+# CreateNewTools
+############################
+
+
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_tools(
    request: Request,
    form_data: ToolForm,
    user=Depends(get_verified_user),
):
    """
    Create a new tool from user-supplied source code.

    The content is import-rewritten, loaded as a module to extract its
    frontmatter and callable specs, registered in the in-memory TOOLS cache,
    and persisted. Raises 401 when the user lacks the workspace.tools
    permission, and 400 for an invalid id, a load/DB failure, or a taken id.
    """
    # Non-admins need the explicit workspace.tools permission.
    if user.role != "admin" and not has_permission(
        user.id, "workspace.tools", request.app.state.config.USER_PERMISSIONS
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    # The id doubles as a Python module name, so it must be an identifier.
    if not form_data.id.isidentifier():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only alphanumeric characters and underscores are allowed in the id",
        )

    form_data.id = form_data.id.lower()

    tools = Tools.get_tool_by_id(form_data.id)
    if tools is None:
        try:
            # Normalize imports, then load the module to validate the code
            # and capture its frontmatter manifest.
            form_data.content = replace_imports(form_data.content)
            tool_module, frontmatter = load_tool_module_by_id(
                form_data.id, content=form_data.content
            )
            form_data.meta.manifest = frontmatter

            # Register the live module so the tool is immediately invocable.
            TOOLS = request.app.state.TOOLS
            TOOLS[form_data.id] = tool_module

            specs = get_tool_specs(TOOLS[form_data.id])
            tools = Tools.insert_new_tool(user.id, form_data, specs)

            # Ensure the tool's cache directory exists for runtime artifacts.
            tool_cache_dir = CACHE_DIR / "tools" / form_data.id
            tool_cache_dir.mkdir(parents=True, exist_ok=True)

            if tools:
                return tools
            else:
                # NOTE(review): this HTTPException is caught by the broad
                # `except Exception` below and re-wrapped — confirm intended.
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=ERROR_MESSAGES.DEFAULT("Error creating tools"),
                )
        except Exception as e:
            log.exception(f"Failed to load the tool by id {form_data.id}: {e}")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(str(e)),
            )
    else:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ID_TAKEN,
        )
+
+
+############################
+# GetToolsById
+############################
+
+
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_tools_by_id(id: str, user=Depends(get_verified_user)):
    """
    Fetch a single tool by id.

    Visible to admins, the tool's owner, and users with read access. Both
    a missing tool and an inaccessible tool produce the same NOT_FOUND
    response so existence is not leaked.
    """
    tools = Tools.get_tool_by_id(id)

    if tools is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    if (
        user.role == "admin"
        or tools.user_id == user.id
        or has_access(user.id, "read", tools.access_control)
    ):
        return tools

    # Fix: previously this branch fell through and returned None with an
    # HTTP 200; deny access explicitly (same response as a missing tool).
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.NOT_FOUND,
    )
+
+
+############################
+# UpdateToolsById
+############################
+
+
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_tools_by_id(
    request: Request,
    id: str,
    form_data: ToolForm,
    user=Depends(get_verified_user),
):
    """
    Update an existing tool's source, metadata, and derived specs.

    Only the original creator, users with write access, or admins may
    update. Raises 401 when the tool is missing or the user lacks
    permission; 400 when the new source fails to load or the DB update fails.
    """
    tools = Tools.get_tool_by_id(id)
    if not tools:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Is the user the original creator, in a group with write access, or an admin
    if (
        tools.user_id != user.id
        and not has_access(user.id, "write", tools.access_control)
        and user.role != "admin"
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    try:
        # Reload the module from the new source to validate it and refresh
        # the frontmatter manifest and callable specs.
        form_data.content = replace_imports(form_data.content)
        tool_module, frontmatter = load_tool_module_by_id(id, content=form_data.content)
        form_data.meta.manifest = frontmatter

        # Replace the cached live module so subsequent calls use the new code.
        TOOLS = request.app.state.TOOLS
        TOOLS[id] = tool_module

        specs = get_tool_specs(TOOLS[id])

        # The id itself is never changed by an update.
        updated = {
            **form_data.model_dump(exclude={"id"}),
            "specs": specs,
        }

        log.debug(updated)
        tools = Tools.update_tool_by_id(id, updated)

        if tools:
            return tools
        else:
            # NOTE(review): this HTTPException is caught by the broad
            # `except Exception` below and re-wrapped — confirm intended.
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT("Error updating tools"),
            )

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),
        )
+
+
+############################
+# DeleteToolsById
+############################
+
+
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_tools_by_id(
    request: Request, id: str, user=Depends(get_verified_user)
):
    """Delete a tool; allowed for its owner, users with write access, or admins."""
    tool = Tools.get_tool_by_id(id)
    if not tool:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    allowed = (
        tool.user_id == user.id
        or has_access(user.id, "write", tool.access_control)
        or user.role == "admin"
    )
    if not allowed:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )

    result = Tools.delete_tool_by_id(id)
    if result:
        # Drop the cached module so the tool cannot be invoked after deletion.
        request.app.state.TOOLS.pop(id, None)

    return result
+
+
+############################
+# GetToolValves
+############################
+
+
@router.get("/id/{id}/valves", response_model=Optional[dict])
async def get_tools_valves_by_id(id: str, user=Depends(get_verified_user)):
    """Return the stored global valve values for a tool, or raise if absent."""
    tool = Tools.get_tool_by_id(id)
    if not tool:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    try:
        return Tools.get_tool_valves_by_id(id)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),
        )
+
+
+############################
+# GetToolValvesSpec
+############################
+
+
@router.get("/id/{id}/valves/spec", response_model=Optional[dict])
async def get_tools_valves_spec_by_id(
    request: Request, id: str, user=Depends(get_verified_user)
):
    """
    Return the JSON schema of a tool's `Valves` model, or None when the tool
    defines no valves. Raises 401 when the tool does not exist.
    """
    tools = Tools.get_tool_by_id(id)
    if tools:
        # Load (and cache) the tool module on demand.
        if id in request.app.state.TOOLS:
            tools_module = request.app.state.TOOLS[id]
        else:
            tools_module, _ = load_tool_module_by_id(id)
            request.app.state.TOOLS[id] = tools_module

        if hasattr(tools_module, "Valves"):
            Valves = tools_module.Valves
            # Fix: `.schema()` is the deprecated Pydantic v1 API; this module
            # already uses v2 (`model_dump`), so use `model_json_schema()`.
            return Valves.model_json_schema()
        return None
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )
+
+
+############################
+# UpdateToolValves
+############################
+
+
@router.post("/id/{id}/valves/update", response_model=Optional[dict])
async def update_tools_valves_by_id(
    request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
    """
    Validate and persist a tool's global valve values.

    Requires ownership, write access, or admin role. Submitted values are
    validated against the module's `Valves` pydantic model before storage.
    Raises 401 for a missing tool or missing Valves definition, 400 for an
    access or validation failure.
    """
    tools = Tools.get_tool_by_id(id)
    if not tools:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Owner, group write access, or admin only.
    if (
        tools.user_id != user.id
        and not has_access(user.id, "write", tools.access_control)
        and user.role != "admin"
    ):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    # Load (and cache) the tool module on demand.
    if id in request.app.state.TOOLS:
        tools_module = request.app.state.TOOLS[id]
    else:
        tools_module, _ = load_tool_module_by_id(id)
        request.app.state.TOOLS[id] = tools_module

    if not hasattr(tools_module, "Valves"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )
    Valves = tools_module.Valves

    try:
        # Drop explicit nulls so they do not override model defaults.
        form_data = {k: v for k, v in form_data.items() if v is not None}
        valves = Valves(**form_data)
        valves_dict = valves.model_dump(exclude_unset=True)
        Tools.update_tool_valves_by_id(id, valves_dict)
        return valves_dict
    except Exception as e:
        log.exception(f"Failed to update tool valves by id {id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),
        )
+
+
+############################
+# ToolUserValves
+############################
+
+
@router.get("/id/{id}/valves/user", response_model=Optional[dict])
async def get_tools_user_valves_by_id(id: str, user=Depends(get_verified_user)):
    """Return the calling user's personal valve values for a tool."""
    tool = Tools.get_tool_by_id(id)
    if not tool:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    try:
        return Tools.get_user_valves_by_id_and_user_id(id, user.id)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),
        )
+
+
@router.get("/id/{id}/valves/user/spec", response_model=Optional[dict])
async def get_tools_user_valves_spec_by_id(
    request: Request, id: str, user=Depends(get_verified_user)
):
    """
    Return the JSON schema of a tool's `UserValves` model, or None when the
    tool defines no per-user valves. Raises 401 when the tool does not exist.
    """
    tools = Tools.get_tool_by_id(id)
    if tools:
        # Load (and cache) the tool module on demand.
        if id in request.app.state.TOOLS:
            tools_module = request.app.state.TOOLS[id]
        else:
            tools_module, _ = load_tool_module_by_id(id)
            request.app.state.TOOLS[id] = tools_module

        if hasattr(tools_module, "UserValves"):
            UserValves = tools_module.UserValves
            # Fix: `.schema()` is the deprecated Pydantic v1 API; this module
            # already uses v2 (`model_dump`), so use `model_json_schema()`.
            return UserValves.model_json_schema()
        return None
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )
+
+
@router.post("/id/{id}/valves/user/update", response_model=Optional[dict])
async def update_tools_user_valves_by_id(
    request: Request, id: str, form_data: dict, user=Depends(get_verified_user)
):
    """Validate and persist the calling user's personal valve values for a tool."""
    tools = Tools.get_tool_by_id(id)
    if not tools:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Load (and cache) the tool module on demand.
    if id in request.app.state.TOOLS:
        tools_module = request.app.state.TOOLS[id]
    else:
        tools_module, _ = load_tool_module_by_id(id)
        request.app.state.TOOLS[id] = tools_module

    if not hasattr(tools_module, "UserValves"):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )
    UserValves = tools_module.UserValves

    try:
        # Drop explicit nulls so they do not override model defaults.
        cleaned = {k: v for k, v in form_data.items() if v is not None}
        user_valves = UserValves(**cleaned)
        user_valves_dict = user_valves.model_dump(exclude_unset=True)
        Tools.update_user_valves_by_id_and_user_id(id, user.id, user_valves_dict)
        return user_valves_dict
    except Exception as e:
        log.exception(f"Failed to update user valves by id {id}: {e}")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(str(e)),
        )
diff --git a/backend/open_webui/routers/users.py b/backend/open_webui/routers/users.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ee3f9f88c00d9cc6f249f2f802a4a84927aeb7b
--- /dev/null
+++ b/backend/open_webui/routers/users.py
@@ -0,0 +1,550 @@
+import logging
+from typing import Optional
+import base64
+import io
+
+
+from fastapi import APIRouter, Depends, HTTPException, Request, status
+from fastapi.responses import Response, StreamingResponse, FileResponse
+from pydantic import BaseModel
+
+
+from open_webui.models.auths import Auths
+from open_webui.models.oauth_sessions import OAuthSessions
+
+from open_webui.models.groups import Groups
+from open_webui.models.chats import Chats
+from open_webui.models.users import (
+ UserModel,
+ UserListResponse,
+ UserInfoListResponse,
+ UserIdNameListResponse,
+ UserRoleUpdateForm,
+ Users,
+ UserSettings,
+ UserUpdateForm,
+)
+
+
+from open_webui.socket.main import (
+ get_active_status_by_user_id,
+ get_active_user_ids,
+ get_user_active_status,
+)
+from open_webui.constants import ERROR_MESSAGES
+from open_webui.env import SRC_LOG_LEVELS, STATIC_DIR
+
+
+from open_webui.utils.auth import get_admin_user, get_password_hash, get_verified_user
+from open_webui.utils.access_control import get_permissions, has_permission
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+router = APIRouter()
+
+
+############################
+# GetActiveUsers
+############################
+
+
+@router.get("/active")
+async def get_active_users(
+    user=Depends(get_verified_user),
+):
+    """
+    Return the IDs of users with an active websocket session.
+    """
+    return {
+        "user_ids": get_active_user_ids(),
+    }
+
+
+############################
+# GetUsers
+############################
+
+
+PAGE_ITEM_COUNT = 30
+
+
+@router.get("/", response_model=UserListResponse)
+async def get_users(
+ query: Optional[str] = None,
+ order_by: Optional[str] = None,
+ direction: Optional[str] = None,
+ page: Optional[int] = 1,
+ user=Depends(get_admin_user),
+):
+ limit = PAGE_ITEM_COUNT
+
+ page = max(1, page)
+ skip = (page - 1) * limit
+
+ filter = {}
+ if query:
+ filter["query"] = query
+ if order_by:
+ filter["order_by"] = order_by
+ if direction:
+ filter["direction"] = direction
+
+ return Users.get_users(filter=filter, skip=skip, limit=limit)
+
+
+@router.get("/all", response_model=UserInfoListResponse)
+async def get_all_users(
+    user=Depends(get_admin_user),
+):
+    """Return every user without pagination (admin only)."""
+    return Users.get_users()
+
+
+@router.get("/search", response_model=UserIdNameListResponse)
+async def search_users(
+    query: Optional[str] = None,
+    user=Depends(get_verified_user),
+):
+    """Search users by query, returning only ids and names (first page only)."""
+    limit = PAGE_ITEM_COUNT
+
+    page = 1  # Always return the first page for search
+    skip = (page - 1) * limit
+
+    filter = {}
+    if query:
+        filter["query"] = query
+
+    return Users.get_users(filter=filter, skip=skip, limit=limit)
+
+
+############################
+# User Groups
+############################
+
+
+@router.get("/groups")
+async def get_user_groups(user=Depends(get_verified_user)):
+    """Return the groups the session user belongs to."""
+    return Groups.get_groups_by_member_id(user.id)
+
+
+############################
+# User Permissions
+############################
+
+
+@router.get("/permissions")
+async def get_user_permissisions(request: Request, user=Depends(get_verified_user)):
+    """Return the session user's effective permissions.
+
+    NOTE(review): function name has a typo ("permissisions"); left unchanged
+    since the route path, not the name, is the public interface.
+    """
+    user_permissions = get_permissions(
+        user.id, request.app.state.config.USER_PERMISSIONS
+    )
+
+    return user_permissions
+
+
+############################
+# User Default Permissions
+############################
+class WorkspacePermissions(BaseModel):
+    """Workspace-area feature toggles; everything is off by default."""
+
+    models: bool = False
+    knowledge: bool = False
+    prompts: bool = False
+    tools: bool = False
+
+
+class SharingPermissions(BaseModel):
+    """Controls whether users may share items publicly; everything is on by default."""
+
+    public_models: bool = True
+    public_knowledge: bool = True
+    public_prompts: bool = True
+    public_tools: bool = True
+    public_notes: bool = True
+
+
+class ChatPermissions(BaseModel):
+    """Per-feature chat toggles; all enabled by default except enforced-temporary."""
+
+    controls: bool = True
+    valves: bool = True
+    system_prompt: bool = True
+    params: bool = True
+    file_upload: bool = True
+    delete: bool = True
+    delete_message: bool = True
+    continue_response: bool = True
+    regenerate_response: bool = True
+    rate_response: bool = True
+    edit: bool = True
+    share: bool = True
+    export: bool = True
+    stt: bool = True
+    tts: bool = True
+    call: bool = True
+    multiple_models: bool = True
+    temporary: bool = True
+    # When True, every chat is forced to be temporary (not persisted).
+    temporary_enforced: bool = False
+
+
+class FeaturesPermissions(BaseModel):
+    """Global feature toggles; direct tool servers are opt-in, the rest default on."""
+
+    direct_tool_servers: bool = False
+    web_search: bool = True
+    image_generation: bool = True
+    code_interpreter: bool = True
+    notes: bool = True
+
+
+class UserPermissions(BaseModel):
+    """Aggregate permission document stored in config.USER_PERMISSIONS."""
+
+    workspace: WorkspacePermissions
+    sharing: SharingPermissions
+    chat: ChatPermissions
+    features: FeaturesPermissions
+
+
+@router.get("/default/permissions", response_model=UserPermissions)
+async def get_default_user_permissions(request: Request, user=Depends(get_admin_user)):
+    """Return the configured default permissions, filling gaps with model defaults."""
+    return {
+        "workspace": WorkspacePermissions(
+            **request.app.state.config.USER_PERMISSIONS.get("workspace", {})
+        ),
+        "sharing": SharingPermissions(
+            **request.app.state.config.USER_PERMISSIONS.get("sharing", {})
+        ),
+        "chat": ChatPermissions(
+            **request.app.state.config.USER_PERMISSIONS.get("chat", {})
+        ),
+        "features": FeaturesPermissions(
+            **request.app.state.config.USER_PERMISSIONS.get("features", {})
+        ),
+    }
+
+
+@router.post("/default/permissions")
+async def update_default_user_permissions(
+    request: Request, form_data: UserPermissions, user=Depends(get_admin_user)
+):
+    """Replace the default user permissions in config and echo them back."""
+    request.app.state.config.USER_PERMISSIONS = form_data.model_dump()
+    return request.app.state.config.USER_PERMISSIONS
+
+
+############################
+# GetUserSettingsBySessionUser
+############################
+
+
+@router.get("/user/settings", response_model=Optional[UserSettings])
+async def get_user_settings_by_session_user(user=Depends(get_verified_user)):
+    """Return the session user's stored settings (may be None)."""
+    user = Users.get_user_by_id(user.id)
+    if user:
+        return user.settings
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
+
+
+############################
+# UpdateUserSettingsBySessionUser
+############################
+
+
+@router.post("/user/settings/update", response_model=UserSettings)
+async def update_user_settings_by_session_user(
+ request: Request, form_data: UserSettings, user=Depends(get_verified_user)
+):
+ updated_user_settings = form_data.model_dump()
+ if (
+ user.role != "admin"
+ and "toolServers" in updated_user_settings.get("ui").keys()
+ and not has_permission(
+ user.id,
+ "features.direct_tool_servers",
+ request.app.state.config.USER_PERMISSIONS,
+ )
+ ):
+ # If the user is not an admin and does not have permission to use tool servers, remove the key
+ updated_user_settings["ui"].pop("toolServers", None)
+
+ user = Users.update_user_settings_by_id(user.id, updated_user_settings)
+ if user:
+ return user.settings
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.USER_NOT_FOUND,
+ )
+
+
+############################
+# GetUserInfoBySessionUser
+############################
+
+
+@router.get("/user/info", response_model=Optional[dict])
+async def get_user_info_by_session_user(user=Depends(get_verified_user)):
+    """Return the session user's free-form `info` dict (may be None)."""
+    user = Users.get_user_by_id(user.id)
+    if user:
+        return user.info
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
+
+
+############################
+# UpdateUserInfoBySessionUser
+############################
+
+
+@router.post("/user/info/update", response_model=Optional[dict])
+async def update_user_info_by_session_user(
+    form_data: dict, user=Depends(get_verified_user)
+):
+    """Shallow-merge `form_data` into the session user's `info` dict and return it."""
+    user = Users.get_user_by_id(user.id)
+    if user:
+        if user.info is None:
+            user.info = {}
+
+        # Merge rather than replace, so unrelated keys are preserved.
+        user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
+        if user:
+            return user.info
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.USER_NOT_FOUND,
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
+
+
+############################
+# GetUserById
+############################
+
+
+class UserResponse(BaseModel):
+    """Minimal public view of a user: name, avatar, and live-session status."""
+
+    name: str
+    profile_image_url: str
+    active: Optional[bool] = None
+
+
+@router.get("/{user_id}", response_model=UserResponse)
+async def get_user_by_id(user_id: str, user=Depends(get_verified_user)):
+    """Return a minimal public profile for `user_id`.
+
+    A "shared-<chat_id>" identifier resolves to the owner of that shared chat.
+    """
+    # Check if user_id is a shared chat
+    # If it is, get the user_id from the chat
+    if user_id.startswith("shared-"):
+        chat_id = user_id.replace("shared-", "")
+        chat = Chats.get_chat_by_id(chat_id)
+        if chat:
+            user_id = chat.user_id
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.USER_NOT_FOUND,
+            )
+
+    user = Users.get_user_by_id(user_id)
+
+    if user:
+        return UserResponse(
+            **{
+                "name": user.name,
+                "profile_image_url": user.profile_image_url,
+                "active": get_active_status_by_user_id(user_id),
+            }
+        )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
+
+
+@router.get("/{user_id}/oauth/sessions")
+async def get_user_oauth_sessions_by_id(user_id: str, user=Depends(get_admin_user)):
+    """Return a user's OAuth sessions (admin only).
+
+    NOTE(review): an existing user with zero sessions is reported as 400
+    USER_NOT_FOUND — confirm this is the intended contract.
+    """
+    sessions = OAuthSessions.get_sessions_by_user_id(user_id)
+    if sessions and len(sessions) > 0:
+        return sessions
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.USER_NOT_FOUND,
+        )
+
+
+############################
+# GetUserProfileImageById
+############################
+
+
+@router.get("/{user_id}/profile/image")
+async def get_user_profile_image_by_id(user_id: str, user=Depends(get_verified_user)):
+ user = Users.get_user_by_id(user_id)
+ if user:
+ if user.profile_image_url:
+ # check if it's url or base64
+ if user.profile_image_url.startswith("http"):
+ return Response(
+ status_code=status.HTTP_302_FOUND,
+ headers={"Location": user.profile_image_url},
+ )
+ elif user.profile_image_url.startswith("data:image"):
+ try:
+ header, base64_data = user.profile_image_url.split(",", 1)
+ image_data = base64.b64decode(base64_data)
+ image_buffer = io.BytesIO(image_data)
+
+ return StreamingResponse(
+ image_buffer,
+ media_type="image/png",
+ headers={"Content-Disposition": "inline; filename=image.png"},
+ )
+ except Exception as e:
+ pass
+ return FileResponse(f"{STATIC_DIR}/user.png")
+ else:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.USER_NOT_FOUND,
+ )
+
+
+############################
+# GetUserActiveStatusById
+############################
+
+
+@router.get("/{user_id}/active", response_model=dict)
+async def get_user_active_status_by_id(user_id: str, user=Depends(get_verified_user)):
+    """Return whether `user_id` currently has an active session."""
+    return {
+        "active": get_user_active_status(user_id),
+    }
+
+
+############################
+# UpdateUserById
+############################
+
+
+@router.post("/{user_id}/update", response_model=Optional[UserModel])
+async def update_user_by_id(
+ user_id: str,
+ form_data: UserUpdateForm,
+ session_user=Depends(get_admin_user),
+):
+ # Prevent modification of the primary admin user by other admins
+ try:
+ first_user = Users.get_first_user()
+ if first_user:
+ if user_id == first_user.id:
+ if session_user.id != user_id:
+ # If the user trying to update is the primary admin, and they are not the primary admin themselves
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+ )
+
+ if form_data.role != "admin":
+ # If the primary admin is trying to change their own role, prevent it
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+ )
+
+ except Exception as e:
+ log.error(f"Error checking primary admin status: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Could not verify primary admin status.",
+ )
+
+ user = Users.get_user_by_id(user_id)
+
+ if user:
+ if form_data.email.lower() != user.email:
+ email_user = Users.get_user_by_email(form_data.email.lower())
+ if email_user:
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.EMAIL_TAKEN,
+ )
+
+ if form_data.password:
+ hashed = get_password_hash(form_data.password)
+ log.debug(f"hashed: {hashed}")
+ Auths.update_user_password_by_id(user_id, hashed)
+
+ Auths.update_email_by_id(user_id, form_data.email.lower())
+ updated_user = Users.update_user_by_id(
+ user_id,
+ {
+ "role": form_data.role,
+ "name": form_data.name,
+ "email": form_data.email.lower(),
+ "profile_image_url": form_data.profile_image_url,
+ },
+ )
+
+ if updated_user:
+ return updated_user
+
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.DEFAULT(),
+ )
+
+ raise HTTPException(
+ status_code=status.HTTP_400_BAD_REQUEST,
+ detail=ERROR_MESSAGES.USER_NOT_FOUND,
+ )
+
+
+############################
+# DeleteUserById
+############################
+
+
+@router.delete("/{user_id}", response_model=bool)
+async def delete_user_by_id(user_id: str, user=Depends(get_admin_user)):
+ # Prevent deletion of the primary admin user
+ try:
+ first_user = Users.get_first_user()
+ if first_user and user_id == first_user.id:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+ )
+ except Exception as e:
+ log.error(f"Error checking primary admin status: {e}")
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail="Could not verify primary admin status.",
+ )
+
+ if user.id != user_id:
+ result = Auths.delete_auth_by_id(user_id)
+
+ if result:
+ return True
+
+ raise HTTPException(
+ status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+ detail=ERROR_MESSAGES.DELETE_USER_ERROR,
+ )
+
+ # Prevent self-deletion
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+ )
+
+
+############################
+# GetUserGroupsById
+############################
+
+
+@router.get("/{user_id}/groups")
+async def get_user_groups_by_id(user_id: str, user=Depends(get_admin_user)):
+    """Return the groups a given user belongs to (admin only)."""
+    return Groups.get_groups_by_member_id(user_id)
diff --git a/backend/open_webui/routers/utils.py b/backend/open_webui/routers/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e6768a6716295c027a35e2b254137321f914348
--- /dev/null
+++ b/backend/open_webui/routers/utils.py
@@ -0,0 +1,135 @@
+import black
+import logging
+import markdown
+
+from open_webui.models.chats import ChatTitleMessagesForm
+from open_webui.config import DATA_DIR, ENABLE_ADMIN_EXPORT
+from open_webui.constants import ERROR_MESSAGES
+from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
+from pydantic import BaseModel
+from starlette.responses import FileResponse
+
+
+from open_webui.utils.misc import get_gravatar_url
+from open_webui.utils.pdf_generator import PDFGenerator
+from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.code_interpreter import execute_code_jupyter
+from open_webui.env import SRC_LOG_LEVELS
+
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MAIN"])
+
+router = APIRouter()
+
+
+@router.get("/gravatar")
+async def get_gravatar(email: str, user=Depends(get_verified_user)):
+    """Return the Gravatar URL for the given email address."""
+    return get_gravatar_url(email)
+
+
+class CodeForm(BaseModel):
+    """Request body carrying a single source-code string."""
+
+    code: str
+
+
+@router.post("/code/format")
+async def format_code(form_data: CodeForm, user=Depends(get_admin_user)):
+    """Format Python source with black; echo the input when already formatted."""
+    try:
+        formatted_code = black.format_str(form_data.code, mode=black.Mode())
+        return {"code": formatted_code}
+    except black.NothingChanged:
+        # Input was already black-formatted; return it unchanged.
+        return {"code": form_data.code}
+    except Exception as e:
+        raise HTTPException(status_code=400, detail=str(e))
+
+
+@router.post("/code/execute")
+async def execute_code(
+    request: Request, form_data: CodeForm, user=Depends(get_verified_user)
+):
+    """Execute the submitted code via the configured Jupyter backend.
+
+    Exactly one of the token/password credentials is forwarded, depending on
+    the configured auth mode. 400 if the engine is not "jupyter".
+    """
+    if request.app.state.config.CODE_EXECUTION_ENGINE == "jupyter":
+        output = await execute_code_jupyter(
+            request.app.state.config.CODE_EXECUTION_JUPYTER_URL,
+            form_data.code,
+            (
+                request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_TOKEN
+                if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "token"
+                else None
+            ),
+            (
+                request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH_PASSWORD
+                if request.app.state.config.CODE_EXECUTION_JUPYTER_AUTH == "password"
+                else None
+            ),
+            request.app.state.config.CODE_EXECUTION_JUPYTER_TIMEOUT,
+        )
+
+        return output
+    else:
+        raise HTTPException(
+            status_code=400,
+            detail="Code execution engine not supported",
+        )
+
+
+class MarkdownForm(BaseModel):
+    """Request body carrying a markdown string."""
+
+    md: str
+
+
+@router.post("/markdown")
+async def get_html_from_markdown(
+    form_data: MarkdownForm, user=Depends(get_verified_user)
+):
+    """Render the submitted markdown to HTML."""
+    return {"html": markdown.markdown(form_data.md)}
+
+
+class ChatForm(BaseModel):
+    """Chat title plus raw message dicts.
+
+    NOTE(review): appears unused here — the PDF endpoint below takes
+    ChatTitleMessagesForm; confirm whether this model can be removed.
+    """
+
+    title: str
+    messages: list[dict]
+
+
+@router.post("/pdf")
+async def download_chat_as_pdf(
+    form_data: ChatTitleMessagesForm, user=Depends(get_verified_user)
+):
+    """Render the given chat to a PDF and return it as an attachment."""
+    try:
+        pdf_bytes = PDFGenerator(form_data).generate_chat_pdf()
+
+        return Response(
+            content=pdf_bytes,
+            media_type="application/pdf",
+            headers={"Content-Disposition": "attachment;filename=chat.pdf"},
+        )
+    except Exception as e:
+        log.exception(f"Error generating PDF: {e}")
+        raise HTTPException(status_code=400, detail=str(e))
+
+
+@router.get("/db/download")
+async def download_db(user=Depends(get_admin_user)):
+    """Download the raw SQLite database file (admin only).
+
+    Requires ENABLE_ADMIN_EXPORT and a sqlite backend; other engines get 400.
+    NOTE(review): 401 is used for the export-disabled case — 403 may be more
+    accurate; confirm before changing the API contract.
+    """
+    if not ENABLE_ADMIN_EXPORT:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+    # Imported lazily to avoid touching the DB engine at module import time.
+    from open_webui.internal.db import engine
+
+    if engine.name != "sqlite":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DB_NOT_SQLITE,
+        )
+    return FileResponse(
+        engine.url.database,
+        media_type="application/octet-stream",
+        filename="webui.db",
+    )
+
+
+@router.get("/litellm/config")
+async def download_litellm_config_yaml(user=Depends(get_admin_user)):
+    """Download the LiteLLM config file from the data directory (admin only).
+
+    NOTE(review): no existence check — a missing file surfaces as an
+    unhandled error from FileResponse; confirm whether a 404 is preferred.
+    """
+    return FileResponse(
+        f"{DATA_DIR}/litellm/config.yaml",
+        media_type="application/octet-stream",
+        filename="config.yaml",
+    )
diff --git a/backend/open_webui/services/meilisearch_service.py b/backend/open_webui/services/meilisearch_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aaa90951a8c147a50ca896a3102faaeb3d5779a
--- /dev/null
+++ b/backend/open_webui/services/meilisearch_service.py
@@ -0,0 +1,303 @@
+import meilisearch
+import os
+import logging
+from typing import Optional, List, Dict, Any
+
+from open_webui.config import MEILISEARCH_URL, MEILISEARCH_API_KEY
+
+log = logging.getLogger(__name__)
+
+# Global variable to hold the MeiliSearchService instance
+meilisearch_service_instance: Optional["MeiliSearchService"] = None
+
+class MeiliSearchService:
+    def __init__(self, url: str, api_key: str):
+        """Create the Meilisearch client and bind the shared "messages" index.
+
+        NOTE(review): `index.update` assumes the "messages" index exists (or is
+        auto-created server-side) — confirm against the deployed Meilisearch.
+        """
+        self.client = meilisearch.Client(url, api_key)
+        self.index = self.client.index("messages")
+        self.index.update(primary_key='id')
+        log.info("MeiliSearch client initialized and primary key set.")
+
+    def setup_index(self):
+        """Sets up the index settings: search `content` only, filter/sort on metadata."""
+        settings = {
+            "rankingRules": ["words", "typo", "proximity", "attribute", "sort", "exactness"],
+            "searchableAttributes": ["content"],
+            "filterableAttributes": ["chatId", "userId", "tags", "folderId", "archived", "pinned", "shared", "role", "timestamp"],
+            "sortableAttributes": ["timestamp"],
+            # Allow typos only on words of 4+ chars (1 typo) / 8+ chars (2 typos).
+            "typoTolerance": {
+                "enabled": True,
+                "minWordSizeForTypos": {
+                    "oneTypo": 4,
+                    "twoTypos": 8
+                }
+            }
+        }
+        self.index.update_settings(settings)
+        log.info("MeiliSearch index settings configured - searching content only.")
+
+ def get_branch_path(self, message_id: str, messages_dict: Dict[str, Any]) -> List[str]:
+ """Builds an array of message IDs from the root to the given message."""
+ path = []
+ current_id = message_id
+ while current_id:
+ path.append(current_id)
+ message = messages_dict.get(current_id)
+ if not message:
+ break
+ current_id = message.get("parent")
+ path.reverse()
+ return path
+
+ def extract_messages_from_chat(self, chat_model) -> List[Dict[str, Any]]:
+ """Flatten chat history and prepare messages for indexing."""
+ if not chat_model or not chat_model.chat:
+ log.warning(f"extract_messages_from_chat: chat_model or chat_model.chat is missing for chat_id {chat_model.id}")
+ return []
+
+ messages_dict = chat_model.chat.get("history", {}).get("messages", {})
+ if not messages_dict:
+ log.warning(f"extract_messages_from_chat: messages_dict is empty for chat_id {chat_model.id}")
+ return []
+
+ documents = []
+ for message_id, message in messages_dict.items():
+ if not message.get("content"):
+ continue
+
+ branch_path = self.get_branch_path(message_id, messages_dict)
+
+ # Get timestamp and validate it
+ msg_timestamp = message.get("timestamp", chat_model.updated_at)
+
+ # Auto-correct millisecond timestamps
+ if msg_timestamp > 4102444800:
+ msg_timestamp = msg_timestamp // 1000
+
+ doc = {
+ "id": f"{chat_model.id}-{message_id}",
+ "chatId": chat_model.id,
+ "messageId": message_id,
+ "userId": chat_model.user_id,
+ "chatTitle": chat_model.title,
+ "content": message["content"],
+ "role": message.get("role"),
+ "timestamp": msg_timestamp,
+ "parentId": message.get("parent"),
+ "branchPath": branch_path,
+ "tags": chat_model.meta.get("tags", []),
+ "folderId": chat_model.folder_id,
+ "archived": chat_model.archived,
+ "pinned": chat_model.pinned,
+ "shared": bool(chat_model.share_id),
+ }
+ documents.append(doc)
+ return documents
+
+    def index_chat(self, chat_id: str, user_id: str):
+        """Indexes a single chat and its messages, verifying user ownership first."""
+        # Imported lazily to avoid a circular import with the chats model.
+        from open_webui.models.chats import Chats
+
+        log.info(f"Attempting to index chat_id: {chat_id} for user_id: {user_id}")
+        chat_model = Chats.get_chat_by_id(chat_id)
+        if not chat_model:
+            log.error(f"index_chat: Chat model with id {chat_id} not found in database.")
+            return
+
+        if chat_model.user_id != user_id:
+            log.error(f"index_chat: User {user_id} does not have access to chat {chat_id}.")
+            return
+
+        log.info(f"index_chat: Found chat '{chat_model.title}' for indexing.")
+        documents = self.extract_messages_from_chat(chat_model)
+
+        if documents:
+            log.info(f"index_chat: Extracted {len(documents)} documents to index.")
+            log.debug(f"index_chat: Sample document: {documents[0]}")
+            # add_documents is asynchronous server-side; this only queues a task.
+            task = self.index.add_documents(documents)
+            log.info(f"Successfully queued {len(documents)} documents for chat_id: {chat_id} for indexing. Task UID: {task.task_uid}")
+        else:
+            log.warning(f"index_chat: No documents were extracted from chat_id: {chat_id}")
+
+    def delete_chat_from_index(self, chat_id: str):
+        """Deletes all messages of a chat from the index (best-effort; errors are logged)."""
+        log.info(f"Attempting to delete chat_id: {chat_id} from index")
+        try:
+            # Search for all documents with this chatId and delete them by ID
+            # We need to fetch all documents, so use a high limit
+            search_results = self.index.search("", {
+                "filter": f'chatId = "{chat_id}"',
+                "limit": 10000,
+                "attributesToRetrieve": ["id"]
+            })
+
+            doc_ids = [hit["id"] for hit in search_results.get("hits", [])]
+
+            if doc_ids:
+                log.info(f"Found {len(doc_ids)} documents to delete for chat_id: {chat_id}")
+                self.index.delete_documents(doc_ids)
+                log.info(f"Successfully queued deletion of {len(doc_ids)} documents for chat_id: {chat_id}")
+            else:
+                log.info(f"No documents found for chat_id: {chat_id}")
+        except Exception as e:
+            log.error(f"Failed to delete chat {chat_id} from index: {e}")
+
+    def delete_message_from_index(self, chat_id: str, message_id: str):
+        """Deletes a specific message from the index (best-effort; errors are logged)."""
+        # Document IDs are "<chat_id>-<message_id>", matching extract_messages_from_chat.
+        doc_id = f"{chat_id}-{message_id}"
+        log.info(f"Attempting to delete message document: {doc_id} from index")
+        try:
+            self.index.delete_document(doc_id)
+            log.info(f"Successfully queued deletion for document: {doc_id}")
+        except Exception as e:
+            log.error(f"Failed to delete message document {doc_id}: {e}")
+
+    def delete_messages_from_index(self, chat_id: str, message_ids: list[str]):
+        """Deletes multiple messages of one chat from the index (no-op on empty list)."""
+        if not message_ids:
+            return
+        # Document IDs are "<chat_id>-<message_id>", matching extract_messages_from_chat.
+        doc_ids = [f"{chat_id}-{msg_id}" for msg_id in message_ids]
+        log.info(f"Attempting to delete {len(doc_ids)} message documents from index")
+        try:
+            self.index.delete_documents(doc_ids)
+            log.info(f"Successfully queued deletion for {len(doc_ids)} documents")
+        except Exception as e:
+            log.error(f"Failed to delete message documents: {e}")
+
+ def search_messages(
+ self,
+ query: str,
+ user_id: str,
+ chat_id: Optional[str] = None,
+ page: int = 1,
+ limit: int = 60,
+ sort_by: str = "relevance",
+ filters: Dict[str, Any] = None
+ ) -> Dict[str, Any]:
+ """Searches for messages with security filters and sorting."""
+ search_params = {
+ "limit": limit,
+ "offset": (page - 1) * limit,
+ "attributesToHighlight": ["content"],
+ "highlightPreTag": "",
+ "highlightPostTag": "",
+ "showMatchesPosition": True,
+ "showRankingScore": True,
+ }
+
+ # Build filter conditions
+ filter_conditions = [f"userId = {user_id}"]
+
+ if chat_id:
+ filter_conditions.append(f"chatId = {chat_id}")
+
+ if filters:
+ if filters.get('role'):
+ filter_conditions.append(f"role = {filters['role']}")
+
+ if filters.get('timestamp_gte'):
+ filter_conditions.append(f"timestamp >= {filters['timestamp_gte']}")
+
+ if filters.get('tags'):
+ tag_filters = " OR ".join([f"tags = {tag}" for tag in filters['tags']])
+ filter_conditions.append(f"({tag_filters})")
+
+ if filters.get('folder_id'):
+ filter_conditions.append(f"folderId = {filters['folder_id']}")
+
+ search_params["filter"] = " AND ".join(filter_conditions)
+ results = self.index.search(query, search_params)
+
+ # Client-side sort by date if requested
+ if sort_by == "date" and results.get('hits'):
+ results['hits'] = sorted(results['hits'], key=lambda x: x.get('timestamp', 0), reverse=True)
+
+ return results
+
+ def reindex_all_chats_for_user(self, user_id: str):
+ """Re-indexes all chats for a specific user, removing orphaned documents."""
+ from open_webui.models.chats import Chats
+
+ log.info(f"Re-indexing all chats for user {user_id}")
+
+ # Step 1: Delete all existing documents for this user to ensure clean sync
+ log.info(f"Deleting all existing documents for user {user_id}")
+ deleted_total = 0
+ try:
+ # Paginate through deletions to handle large datasets efficiently
+ batch_size = 10000
+ offset = 0
+
+ while True:
+ search_results = self.index.search("", {
+ "filter": f'userId = "{user_id}"',
+ "limit": batch_size,
+ "offset": offset,
+ "attributesToRetrieve": ["id"]
+ })
+
+ old_doc_ids = [hit["id"] for hit in search_results.get("hits", [])]
+
+ if not old_doc_ids:
+ break # No more documents to delete
+
+ self.index.delete_documents(old_doc_ids)
+ deleted_total += len(old_doc_ids)
+ log.info(f"Deleted batch of {len(old_doc_ids)} documents (total: {deleted_total})")
+
+ # If we got fewer results than batch_size, we're done
+ if len(old_doc_ids) < batch_size:
+ break
+
+ if deleted_total > 0:
+ log.info(f"Successfully queued deletion of {deleted_total} old documents for user {user_id}")
+ else:
+ log.info(f"No existing documents found for user {user_id}")
+ except Exception as e:
+ log.error(f"Failed to delete existing documents for user {user_id}: {e}")
+
+ # Step 2: Re-index all current chats
+ all_user_chats = Chats.get_chats_by_user_id(user_id)
+
+ total_chats = len(all_user_chats)
+ indexed_count = 0
+
+ for chat in all_user_chats:
+ try:
+ self.index_chat(chat.id, chat.user_id)
+ indexed_count += 1
+ log.info(f"Indexed chat {chat.id} ({indexed_count}/{total_chats})")
+ except Exception as e:
+ log.error(f"Failed to index chat {chat.id}: {e}")
+
+ log.info(f"Re-indexing complete: deleted {deleted_total} old documents, indexed {indexed_count}/{total_chats} chats")
+ return {"total": total_chats, "indexed": indexed_count, "deleted": deleted_total}
+
+ def parse_query_with_filters(self, raw_query: str) -> (str, Dict[str, Any]):
+ """Parses special syntax like tag: and folder: from the query."""
+ # TODO: Implement query parsing
+ return raw_query, {}
+
+
+def initialize_meilisearch_service():
+    """Initializes the module-level MeiliSearch service singleton and its index.
+
+    Leaves the singleton as None (search disabled) when configuration is
+    missing or initialization fails; never raises.
+    """
+    global meilisearch_service_instance
+    if MEILISEARCH_URL and MEILISEARCH_API_KEY:
+        try:
+            meilisearch_service_instance = MeiliSearchService(url=MEILISEARCH_URL, api_key=MEILISEARCH_API_KEY)
+            meilisearch_service_instance.setup_index()
+            log.info("MeiliSearch service initialized successfully.")
+        except Exception as e:
+            log.error(f"Failed to initialize MeiliSearch service: {e}")
+            meilisearch_service_instance = None
+    else:
+        log.warning("MeiliSearch URL or API key not configured. Search will be disabled.")
+
+def shutdown_meilisearch_service():
+    """Drop the singleton reference; the HTTP client needs no explicit close."""
+    global meilisearch_service_instance
+    meilisearch_service_instance = None
+    log.info("MeiliSearch service shut down.")
+
+def get_meilisearch_service() -> Optional[MeiliSearchService]:
+    """Returns the singleton instance of the MeiliSearch service (None if disabled)."""
+    return meilisearch_service_instance
\ No newline at end of file
diff --git a/backend/open_webui/socket/main.py b/backend/open_webui/socket/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..818a57807f83e26de5d13635cc4c32c23b7393ff
--- /dev/null
+++ b/backend/open_webui/socket/main.py
@@ -0,0 +1,794 @@
+import asyncio
+import random
+
+import socketio
+import logging
+import sys
+import time
+from typing import Dict, Set
+from redis import asyncio as aioredis
+import pycrdt as Y
+
+from open_webui.models.users import Users, UserNameResponse
+from open_webui.models.channels import Channels
+from open_webui.models.chats import Chats
+from open_webui.models.notes import Notes, NoteUpdateForm
+from open_webui.utils.redis import (
+ get_sentinels_from_env,
+ get_sentinel_url_from_env,
+)
+
+from open_webui.config import (
+ CORS_ALLOW_ORIGIN,
+)
+
+from open_webui.env import (
+ VERSION,
+ ENABLE_WEBSOCKET_SUPPORT,
+ WEBSOCKET_MANAGER,
+ WEBSOCKET_REDIS_URL,
+ WEBSOCKET_REDIS_CLUSTER,
+ WEBSOCKET_REDIS_LOCK_TIMEOUT,
+ WEBSOCKET_SENTINEL_PORT,
+ WEBSOCKET_SENTINEL_HOSTS,
+ REDIS_KEY_PREFIX,
+)
+from open_webui.utils.auth import decode_token
+from open_webui.socket.utils import RedisDict, RedisLock, YdocManager
+from open_webui.tasks import create_task, stop_item_tasks
+from open_webui.utils.redis import get_redis_connection
+from open_webui.utils.access_control import has_access, get_users_with_access
+
+
+from open_webui.env import (
+ GLOBAL_LOG_LEVEL,
+ SRC_LOG_LEVELS,
+)
+
+
+logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["SOCKET"])
+
+
+REDIS = None
+
+# Configure CORS for Socket.IO
+SOCKETIO_CORS_ORIGINS = "*" if CORS_ALLOW_ORIGIN == ["*"] else CORS_ALLOW_ORIGIN
+
+if WEBSOCKET_MANAGER == "redis":
+ if WEBSOCKET_SENTINEL_HOSTS:
+ mgr = socketio.AsyncRedisManager(
+ get_sentinel_url_from_env(
+ WEBSOCKET_REDIS_URL, WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
+ )
+ )
+ else:
+ mgr = socketio.AsyncRedisManager(WEBSOCKET_REDIS_URL)
+ sio = socketio.AsyncServer(
+ cors_allowed_origins=SOCKETIO_CORS_ORIGINS,
+ async_mode="asgi",
+ transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
+ allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
+ always_connect=True,
+ client_manager=mgr,
+ )
+else:
+ sio = socketio.AsyncServer(
+ cors_allowed_origins=SOCKETIO_CORS_ORIGINS,
+ async_mode="asgi",
+ transports=(["websocket"] if ENABLE_WEBSOCKET_SUPPORT else ["polling"]),
+ allow_upgrades=ENABLE_WEBSOCKET_SUPPORT,
+ always_connect=True,
+ )
+
+
+# Timeout duration in seconds
+TIMEOUT_DURATION = 3
+
+# Dictionary to maintain the user pool
+
+if WEBSOCKET_MANAGER == "redis":
+ log.debug("Using Redis to manage websockets.")
+ REDIS = get_redis_connection(
+ redis_url=WEBSOCKET_REDIS_URL,
+ redis_sentinels=get_sentinels_from_env(
+ WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
+ ),
+ redis_cluster=WEBSOCKET_REDIS_CLUSTER,
+ async_mode=True,
+ )
+
+ redis_sentinels = get_sentinels_from_env(
+ WEBSOCKET_SENTINEL_HOSTS, WEBSOCKET_SENTINEL_PORT
+ )
+ SESSION_POOL = RedisDict(
+ f"{REDIS_KEY_PREFIX}:session_pool",
+ redis_url=WEBSOCKET_REDIS_URL,
+ redis_sentinels=redis_sentinels,
+ redis_cluster=WEBSOCKET_REDIS_CLUSTER,
+ )
+ USER_POOL = RedisDict(
+ f"{REDIS_KEY_PREFIX}:user_pool",
+ redis_url=WEBSOCKET_REDIS_URL,
+ redis_sentinels=redis_sentinels,
+ redis_cluster=WEBSOCKET_REDIS_CLUSTER,
+ )
+ USAGE_POOL = RedisDict(
+ f"{REDIS_KEY_PREFIX}:usage_pool",
+ redis_url=WEBSOCKET_REDIS_URL,
+ redis_sentinels=redis_sentinels,
+ redis_cluster=WEBSOCKET_REDIS_CLUSTER,
+ )
+
+ clean_up_lock = RedisLock(
+ redis_url=WEBSOCKET_REDIS_URL,
+ lock_name=f"{REDIS_KEY_PREFIX}:usage_cleanup_lock",
+ timeout_secs=WEBSOCKET_REDIS_LOCK_TIMEOUT,
+ redis_sentinels=redis_sentinels,
+ redis_cluster=WEBSOCKET_REDIS_CLUSTER,
+ )
+ aquire_func = clean_up_lock.aquire_lock
+ renew_func = clean_up_lock.renew_lock
+ release_func = clean_up_lock.release_lock
+else:
+ SESSION_POOL = {}
+ USER_POOL = {}
+ USAGE_POOL = {}
+
+ aquire_func = release_func = renew_func = lambda: True
+
+
+YDOC_MANAGER = YdocManager(
+ redis=REDIS,
+ redis_key_prefix=f"{REDIS_KEY_PREFIX}:ydoc:documents",
+)
+
+
+async def periodic_usage_pool_cleanup():
+ max_retries = 2
+ retry_delay = random.uniform(
+ WEBSOCKET_REDIS_LOCK_TIMEOUT / 2, WEBSOCKET_REDIS_LOCK_TIMEOUT
+ )
+ for attempt in range(max_retries + 1):
+ if aquire_func():
+ break
+ else:
+ if attempt < max_retries:
+ log.debug(
+ f"Cleanup lock already exists. Retry {attempt + 1} after {retry_delay}s..."
+ )
+ await asyncio.sleep(retry_delay)
+ else:
+ log.warning(
+ "Failed to acquire cleanup lock after retries. Skipping cleanup."
+ )
+ return
+
+ log.debug("Running periodic_cleanup")
+ try:
+ while True:
+ if not renew_func():
+ log.error(f"Unable to renew cleanup lock. Exiting usage pool cleanup.")
+ raise Exception("Unable to renew usage pool cleanup lock.")
+
+ now = int(time.time())
+ send_usage = False
+ for model_id, connections in list(USAGE_POOL.items()):
+ # Creating a list of sids to remove if they have timed out
+ expired_sids = [
+ sid
+ for sid, details in connections.items()
+ if now - details["updated_at"] > TIMEOUT_DURATION
+ ]
+
+ for sid in expired_sids:
+ del connections[sid]
+
+ if not connections:
+ log.debug(f"Cleaning up model {model_id} from usage pool")
+ del USAGE_POOL[model_id]
+ else:
+ USAGE_POOL[model_id] = connections
+
+ send_usage = True
+ await asyncio.sleep(TIMEOUT_DURATION)
+ finally:
+ release_func()
+
+
+app = socketio.ASGIApp(
+ sio,
+ socketio_path="/ws/socket.io",
+)
+
+
+def get_models_in_use():
+ # List models that are currently in use
+ models_in_use = list(USAGE_POOL.keys())
+ return models_in_use
+
+
+def get_active_user_ids():
+ """Get the list of active user IDs."""
+ return list(USER_POOL.keys())
+
+
+def get_user_active_status(user_id):
+ """Check if a user is currently active."""
+ return user_id in USER_POOL
+
+
+def get_user_id_from_session_pool(sid):
+ user = SESSION_POOL.get(sid)
+ if user:
+ return user["id"]
+ return None
+
+
+def get_session_ids_from_room(room):
+ """Get all session IDs from a specific room."""
+ active_session_ids = sio.manager.get_participants(
+ namespace="/",
+ room=room,
+ )
+ return [session_id[0] for session_id in active_session_ids]
+
+
+def get_user_ids_from_room(room):
+ active_session_ids = get_session_ids_from_room(room)
+
+    active_user_ids = list(
+        {u["id"] for u in (SESSION_POOL.get(s) for s in active_session_ids) if u}
+    )
+ return active_user_ids
+
+
+def get_active_status_by_user_id(user_id):
+ if user_id in USER_POOL:
+ return True
+ return False
+
+
+@sio.on("usage")
+async def usage(sid, data):
+ if sid in SESSION_POOL:
+ model_id = data["model"]
+ # Record the timestamp for the last update
+ current_time = int(time.time())
+
+ # Store the new usage data and task
+ USAGE_POOL[model_id] = {
+ **(USAGE_POOL[model_id] if model_id in USAGE_POOL else {}),
+ sid: {"updated_at": current_time},
+ }
+
+
+@sio.event
+async def connect(sid, environ, auth):
+ user = None
+ if auth and "token" in auth:
+ data = decode_token(auth["token"])
+
+ if data is not None and "id" in data:
+ user = Users.get_user_by_id(data["id"])
+
+ if user:
+ SESSION_POOL[sid] = user.model_dump(
+ exclude=["date_of_birth", "bio", "gender"]
+ )
+ if user.id in USER_POOL:
+ USER_POOL[user.id] = USER_POOL[user.id] + [sid]
+ else:
+ USER_POOL[user.id] = [sid]
+
+
+@sio.on("user-join")
+async def user_join(sid, data):
+
+ auth = data["auth"] if "auth" in data else None
+ if not auth or "token" not in auth:
+ return
+
+ data = decode_token(auth["token"])
+ if data is None or "id" not in data:
+ return
+
+ user = Users.get_user_by_id(data["id"])
+ if not user:
+ return
+
+ SESSION_POOL[sid] = user.model_dump(exclude=["date_of_birth", "bio", "gender"])
+ if user.id in USER_POOL:
+ USER_POOL[user.id] = USER_POOL[user.id] + [sid]
+ else:
+ USER_POOL[user.id] = [sid]
+
+ # Join all the channels
+ channels = Channels.get_channels_by_user_id(user.id)
+ log.debug(f"{channels=}")
+ for channel in channels:
+ await sio.enter_room(sid, f"channel:{channel.id}")
+ return {"id": user.id, "name": user.name}
+
+
+@sio.on("join-channels")
+async def join_channel(sid, data):
+ auth = data["auth"] if "auth" in data else None
+ if not auth or "token" not in auth:
+ return
+
+ data = decode_token(auth["token"])
+ if data is None or "id" not in data:
+ return
+
+ user = Users.get_user_by_id(data["id"])
+ if not user:
+ return
+
+ # Join all the channels
+ channels = Channels.get_channels_by_user_id(user.id)
+ log.debug(f"{channels=}")
+ for channel in channels:
+ await sio.enter_room(sid, f"channel:{channel.id}")
+
+
+@sio.on("join-note")
+async def join_note(sid, data):
+ auth = data["auth"] if "auth" in data else None
+ if not auth or "token" not in auth:
+ return
+
+ token_data = decode_token(auth["token"])
+ if token_data is None or "id" not in token_data:
+ return
+
+ user = Users.get_user_by_id(token_data["id"])
+ if not user:
+ return
+
+ note = Notes.get_note_by_id(data["note_id"])
+ if not note:
+ log.error(f"Note {data['note_id']} not found for user {user.id}")
+ return
+
+ if (
+ user.role != "admin"
+ and user.id != note.user_id
+ and not has_access(user.id, type="read", access_control=note.access_control)
+ ):
+ log.error(f"User {user.id} does not have access to note {data['note_id']}")
+ return
+
+ log.debug(f"Joining note {note.id} for user {user.id}")
+ await sio.enter_room(sid, f"note:{note.id}")
+
+
+@sio.on("events:channel")
+async def channel_events(sid, data):
+ room = f"channel:{data['channel_id']}"
+ participants = sio.manager.get_participants(
+ namespace="/",
+ room=room,
+ )
+
+ sids = [sid for sid, _ in participants]
+ if sid not in sids:
+ return
+
+ event_data = data["data"]
+ event_type = event_data["type"]
+
+ if event_type == "typing":
+ await sio.emit(
+ "events:channel",
+ {
+ "channel_id": data["channel_id"],
+ "message_id": data.get("message_id", None),
+ "data": event_data,
+ "user": UserNameResponse(**SESSION_POOL[sid]).model_dump(),
+ },
+ room=room,
+ )
+
+
+@sio.on("ydoc:document:join")
+async def ydoc_document_join(sid, data):
+ """Handle user joining a document"""
+ user = SESSION_POOL.get(sid)
+
+ try:
+ document_id = data["document_id"]
+
+ if document_id.startswith("note:"):
+ note_id = document_id.split(":")[1]
+ note = Notes.get_note_by_id(note_id)
+ if not note:
+ log.error(f"Note {note_id} not found")
+ return
+
+ if (
+ user.get("role") != "admin"
+ and user.get("id") != note.user_id
+ and not has_access(
+ user.get("id"), type="read", access_control=note.access_control
+ )
+ ):
+ log.error(
+ f"User {user.get('id')} does not have access to note {note_id}"
+ )
+ return
+
+ user_id = data.get("user_id", sid)
+ user_name = data.get("user_name", "Anonymous")
+ user_color = data.get("user_color", "#000000")
+
+ log.info(f"User {user_id} joining document {document_id}")
+ await YDOC_MANAGER.add_user(document_id=document_id, user_id=sid)
+
+ # Join Socket.IO room
+ await sio.enter_room(sid, f"doc_{document_id}")
+
+ active_session_ids = get_session_ids_from_room(f"doc_{document_id}")
+
+ # Get the Yjs document state
+ ydoc = Y.Doc()
+ updates = await YDOC_MANAGER.get_updates(document_id)
+ for update in updates:
+ ydoc.apply_update(bytes(update))
+
+ # Encode the entire document state as an update
+ state_update = ydoc.get_update()
+ await sio.emit(
+ "ydoc:document:state",
+ {
+ "document_id": document_id,
+ "state": list(state_update), # Convert bytes to list for JSON
+ "sessions": active_session_ids,
+ },
+ room=sid,
+ )
+
+ # Notify other users about the new user
+ await sio.emit(
+ "ydoc:user:joined",
+ {
+ "document_id": document_id,
+ "user_id": user_id,
+ "user_name": user_name,
+ "user_color": user_color,
+ },
+ room=f"doc_{document_id}",
+ skip_sid=sid,
+ )
+
+ log.info(f"User {user_id} successfully joined document {document_id}")
+
+ except Exception as e:
+        log.error(f"Error in ydoc_document_join: {e}")
+ await sio.emit("error", {"message": "Failed to join document"}, room=sid)
+
+
+async def document_save_handler(document_id, data, user):
+ if document_id.startswith("note:"):
+ note_id = document_id.split(":")[1]
+ note = Notes.get_note_by_id(note_id)
+ if not note:
+ log.error(f"Note {note_id} not found")
+ return
+
+ if (
+ user.get("role") != "admin"
+ and user.get("id") != note.user_id
+ and not has_access(
+ user.get("id"), type="read", access_control=note.access_control
+ )
+ ):
+ log.error(f"User {user.get('id')} does not have access to note {note_id}")
+ return
+
+ Notes.update_note_by_id(note_id, NoteUpdateForm(data=data))
+
+
+@sio.on("ydoc:document:state")
+async def yjs_document_state(sid, data):
+ """Send the current state of the Yjs document to the user"""
+ try:
+ document_id = data["document_id"]
+ room = f"doc_{document_id}"
+
+ active_session_ids = get_session_ids_from_room(room)
+
+ if sid not in active_session_ids:
+ log.warning(f"Session {sid} not in room {room}. Cannot send state.")
+ return
+
+ if not await YDOC_MANAGER.document_exists(document_id):
+ log.warning(f"Document {document_id} not found")
+ return
+
+ # Get the Yjs document state
+ ydoc = Y.Doc()
+ updates = await YDOC_MANAGER.get_updates(document_id)
+ for update in updates:
+ ydoc.apply_update(bytes(update))
+
+ # Encode the entire document state as an update
+ state_update = ydoc.get_update()
+
+ await sio.emit(
+ "ydoc:document:state",
+ {
+ "document_id": document_id,
+ "state": list(state_update), # Convert bytes to list for JSON
+ "sessions": active_session_ids,
+ },
+ room=sid,
+ )
+ except Exception as e:
+ log.error(f"Error in yjs_document_state: {e}")
+
+
+@sio.on("ydoc:document:update")
+async def yjs_document_update(sid, data):
+ """Handle Yjs document updates"""
+ try:
+ document_id = data["document_id"]
+
+ try:
+ await stop_item_tasks(REDIS, document_id)
+        except Exception:
+ pass
+
+ user_id = data.get("user_id", sid)
+
+ update = data["update"] # List of bytes from frontend
+
+ await YDOC_MANAGER.append_to_updates(
+ document_id=document_id,
+            update=update,  # raw update payload (list of ints) as received from the client
+ )
+
+ # Broadcast update to all other users in the document
+ await sio.emit(
+ "ydoc:document:update",
+ {
+ "document_id": document_id,
+ "user_id": user_id,
+ "update": update,
+ "socket_id": sid, # Add socket_id to match frontend filtering
+ },
+ room=f"doc_{document_id}",
+ skip_sid=sid,
+ )
+
+ async def debounced_save():
+ await asyncio.sleep(0.5)
+ await document_save_handler(
+ document_id, data.get("data", {}), SESSION_POOL.get(sid)
+ )
+
+ if data.get("data"):
+ await create_task(REDIS, debounced_save(), document_id)
+
+ except Exception as e:
+ log.error(f"Error in yjs_document_update: {e}")
+
+
+@sio.on("ydoc:document:leave")
+async def yjs_document_leave(sid, data):
+ """Handle user leaving a document"""
+ try:
+ document_id = data["document_id"]
+ user_id = data.get("user_id", sid)
+
+ log.info(f"User {user_id} leaving document {document_id}")
+
+ # Remove user from the document
+ await YDOC_MANAGER.remove_user(document_id=document_id, user_id=sid)
+
+ # Leave Socket.IO room
+ await sio.leave_room(sid, f"doc_{document_id}")
+
+ # Notify other users
+ await sio.emit(
+ "ydoc:user:left",
+ {"document_id": document_id, "user_id": user_id},
+ room=f"doc_{document_id}",
+ )
+
+ if (
+ await YDOC_MANAGER.document_exists(document_id)
+ and len(await YDOC_MANAGER.get_users(document_id)) == 0
+ ):
+ log.info(f"Cleaning up document {document_id} as no users are left")
+ await YDOC_MANAGER.clear_document(document_id)
+
+ except Exception as e:
+ log.error(f"Error in yjs_document_leave: {e}")
+
+
+@sio.on("ydoc:awareness:update")
+async def yjs_awareness_update(sid, data):
+ """Handle awareness updates (cursors, selections, etc.)"""
+ try:
+ document_id = data["document_id"]
+ user_id = data.get("user_id", sid)
+ update = data["update"]
+
+ # Broadcast awareness update to all other users in the document
+ await sio.emit(
+ "ydoc:awareness:update",
+ {"document_id": document_id, "user_id": user_id, "update": update},
+ room=f"doc_{document_id}",
+ skip_sid=sid,
+ )
+
+ except Exception as e:
+ log.error(f"Error in yjs_awareness_update: {e}")
+
+
+@sio.event
+async def disconnect(sid):
+ if sid in SESSION_POOL:
+ user = SESSION_POOL[sid]
+ del SESSION_POOL[sid]
+
+ user_id = user["id"]
+ USER_POOL[user_id] = [_sid for _sid in USER_POOL[user_id] if _sid != sid]
+
+ if len(USER_POOL[user_id]) == 0:
+ del USER_POOL[user_id]
+
+ await YDOC_MANAGER.remove_user_from_all_documents(sid)
+ else:
+ pass
+ # print(f"Unknown session ID {sid} disconnected")
+
+
+def get_event_emitter(request_info, update_db=True):
+ async def __event_emitter__(event_data):
+ user_id = request_info["user_id"]
+
+ session_ids = list(
+ set(
+ USER_POOL.get(user_id, [])
+ + (
+ [request_info.get("session_id")]
+ if request_info.get("session_id")
+ else []
+ )
+ )
+ )
+
+ chat_id = request_info.get("chat_id", None)
+ message_id = request_info.get("message_id", None)
+
+ emit_tasks = [
+ sio.emit(
+ "events",
+ {
+ "chat_id": chat_id,
+ "message_id": message_id,
+ "data": event_data,
+ },
+ to=session_id,
+ )
+ for session_id in session_ids
+ ]
+
+ await asyncio.gather(*emit_tasks)
+ if (
+ update_db
+ and message_id
+            and not (request_info.get("chat_id") or "").startswith("local:")
+ ):
+ if "type" in event_data and event_data["type"] == "status":
+ Chats.add_message_status_to_chat_by_id_and_message_id(
+ request_info["chat_id"],
+ request_info["message_id"],
+ event_data.get("data", {}),
+ )
+
+ if "type" in event_data and event_data["type"] == "message":
+ message = Chats.get_message_by_id_and_message_id(
+ request_info["chat_id"],
+ request_info["message_id"],
+ )
+
+ if message:
+ content = message.get("content", "")
+ content += event_data.get("data", {}).get("content", "")
+
+ Chats.upsert_message_to_chat_by_id_and_message_id(
+ request_info["chat_id"],
+ request_info["message_id"],
+ {
+ "content": content,
+ },
+ )
+
+ if "type" in event_data and event_data["type"] == "replace":
+ content = event_data.get("data", {}).get("content", "")
+
+ Chats.upsert_message_to_chat_by_id_and_message_id(
+ request_info["chat_id"],
+ request_info["message_id"],
+ {
+ "content": content,
+ },
+ )
+
+ if "type" in event_data and event_data["type"] == "embeds":
+ message = Chats.get_message_by_id_and_message_id(
+ request_info["chat_id"],
+ request_info["message_id"],
+ )
+
+ embeds = event_data.get("data", {}).get("embeds", [])
+                embeds.extend((message or {}).get("embeds", []))
+
+ Chats.upsert_message_to_chat_by_id_and_message_id(
+ request_info["chat_id"],
+ request_info["message_id"],
+ {
+ "embeds": embeds,
+ },
+ )
+
+ if "type" in event_data and event_data["type"] == "files":
+ message = Chats.get_message_by_id_and_message_id(
+ request_info["chat_id"],
+ request_info["message_id"],
+ )
+
+ files = event_data.get("data", {}).get("files", [])
+                files.extend((message or {}).get("files", []))
+
+ Chats.upsert_message_to_chat_by_id_and_message_id(
+ request_info["chat_id"],
+ request_info["message_id"],
+ {
+ "files": files,
+ },
+ )
+
+ if event_data.get("type") in ["source", "citation"]:
+ data = event_data.get("data", {})
+                if data.get("type") is None:
+ message = Chats.get_message_by_id_and_message_id(
+ request_info["chat_id"],
+ request_info["message_id"],
+ )
+
+                    sources = (message or {}).get("sources", [])
+ sources.append(data)
+
+ Chats.upsert_message_to_chat_by_id_and_message_id(
+ request_info["chat_id"],
+ request_info["message_id"],
+ {
+ "sources": sources,
+ },
+ )
+
+ return __event_emitter__
+
+
+def get_event_call(request_info):
+ async def __event_caller__(event_data):
+ response = await sio.call(
+ "events",
+ {
+ "chat_id": request_info.get("chat_id", None),
+ "message_id": request_info.get("message_id", None),
+ "data": event_data,
+ },
+ to=request_info["session_id"],
+ )
+ return response
+
+ return __event_caller__
+
+
+get_event_caller = get_event_call
diff --git a/backend/open_webui/socket/utils.py b/backend/open_webui/socket/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..168d2fd88efbf2dad7bc8fa61000b8f35da71bcc
--- /dev/null
+++ b/backend/open_webui/socket/utils.py
@@ -0,0 +1,214 @@
+import json
+import uuid
+from open_webui.utils.redis import get_redis_connection
+from open_webui.env import REDIS_KEY_PREFIX
+from typing import Optional, List, Tuple
+import pycrdt as Y
+
+
+class RedisLock:
+ def __init__(
+ self,
+ redis_url,
+ lock_name,
+ timeout_secs,
+ redis_sentinels=[],
+ redis_cluster=False,
+ ):
+
+ self.lock_name = lock_name
+ self.lock_id = str(uuid.uuid4())
+ self.timeout_secs = timeout_secs
+ self.lock_obtained = False
+ self.redis = get_redis_connection(
+ redis_url,
+ redis_sentinels,
+ redis_cluster=redis_cluster,
+ decode_responses=True,
+ )
+
+ def aquire_lock(self):
+ # nx=True will only set this key if it _hasn't_ already been set
+ self.lock_obtained = self.redis.set(
+ self.lock_name, self.lock_id, nx=True, ex=self.timeout_secs
+ )
+ return self.lock_obtained
+
+ def renew_lock(self):
+ # xx=True will only set this key if it _has_ already been set
+ return self.redis.set(
+ self.lock_name, self.lock_id, xx=True, ex=self.timeout_secs
+ )
+
+ def release_lock(self):
+ lock_value = self.redis.get(self.lock_name)
+ if lock_value and lock_value == self.lock_id:
+ self.redis.delete(self.lock_name)
+
+
+class RedisDict:
+ def __init__(self, name, redis_url, redis_sentinels=[], redis_cluster=False):
+ self.name = name
+ self.redis = get_redis_connection(
+ redis_url,
+ redis_sentinels,
+ redis_cluster=redis_cluster,
+ decode_responses=True,
+ )
+
+ def __setitem__(self, key, value):
+ serialized_value = json.dumps(value)
+ self.redis.hset(self.name, key, serialized_value)
+
+ def __getitem__(self, key):
+ value = self.redis.hget(self.name, key)
+ if value is None:
+ raise KeyError(key)
+ return json.loads(value)
+
+ def __delitem__(self, key):
+ result = self.redis.hdel(self.name, key)
+ if result == 0:
+ raise KeyError(key)
+
+ def __contains__(self, key):
+ return self.redis.hexists(self.name, key)
+
+ def __len__(self):
+ return self.redis.hlen(self.name)
+
+ def keys(self):
+ return self.redis.hkeys(self.name)
+
+ def values(self):
+ return [json.loads(v) for v in self.redis.hvals(self.name)]
+
+ def items(self):
+ return [(k, json.loads(v)) for k, v in self.redis.hgetall(self.name).items()]
+
+ def get(self, key, default=None):
+ try:
+ return self[key]
+ except KeyError:
+ return default
+
+ def clear(self):
+ self.redis.delete(self.name)
+
+ def update(self, other=None, **kwargs):
+ if other is not None:
+ for k, v in other.items() if hasattr(other, "items") else other:
+ self[k] = v
+ for k, v in kwargs.items():
+ self[k] = v
+
+ def setdefault(self, key, default=None):
+ if key not in self:
+ self[key] = default
+ return self[key]
+
+
+class YdocManager:
+ def __init__(
+ self,
+ redis=None,
+ redis_key_prefix: str = f"{REDIS_KEY_PREFIX}:ydoc:documents",
+ ):
+ self._updates = {}
+ self._users = {}
+ self._redis = redis
+ self._redis_key_prefix = redis_key_prefix
+
+ async def append_to_updates(self, document_id: str, update: bytes):
+ document_id = document_id.replace(":", "_")
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
+ await self._redis.rpush(redis_key, json.dumps(list(update)))
+ else:
+ if document_id not in self._updates:
+ self._updates[document_id] = []
+ self._updates[document_id].append(update)
+
+ async def get_updates(self, document_id: str) -> List[bytes]:
+ document_id = document_id.replace(":", "_")
+
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
+ updates = await self._redis.lrange(redis_key, 0, -1)
+ return [bytes(json.loads(update)) for update in updates]
+ else:
+ return self._updates.get(document_id, [])
+
+ async def document_exists(self, document_id: str) -> bool:
+ document_id = document_id.replace(":", "_")
+
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
+ return await self._redis.exists(redis_key) > 0
+ else:
+ return document_id in self._updates
+
+ async def get_users(self, document_id: str) -> List[str]:
+ document_id = document_id.replace(":", "_")
+
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:users"
+ users = await self._redis.smembers(redis_key)
+ return list(users)
+ else:
+ return self._users.get(document_id, [])
+
+ async def add_user(self, document_id: str, user_id: str):
+ document_id = document_id.replace(":", "_")
+
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:users"
+ await self._redis.sadd(redis_key, user_id)
+ else:
+ if document_id not in self._users:
+ self._users[document_id] = set()
+ self._users[document_id].add(user_id)
+
+ async def remove_user(self, document_id: str, user_id: str):
+ document_id = document_id.replace(":", "_")
+
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:users"
+ await self._redis.srem(redis_key, user_id)
+ else:
+ if document_id in self._users and user_id in self._users[document_id]:
+ self._users[document_id].remove(user_id)
+
+ async def remove_user_from_all_documents(self, user_id: str):
+ if self._redis:
+ keys = await self._redis.keys(f"{self._redis_key_prefix}:*")
+ for key in keys:
+ if key.endswith(":users"):
+ await self._redis.srem(key, user_id)
+
+ document_id = key.split(":")[-2]
+ if len(await self.get_users(document_id)) == 0:
+ await self.clear_document(document_id)
+
+ else:
+ for document_id in list(self._users.keys()):
+ if user_id in self._users[document_id]:
+ self._users[document_id].remove(user_id)
+ if not self._users[document_id]:
+ del self._users[document_id]
+
+ await self.clear_document(document_id)
+
+ async def clear_document(self, document_id: str):
+ document_id = document_id.replace(":", "_")
+
+ if self._redis:
+ redis_key = f"{self._redis_key_prefix}:{document_id}:updates"
+ await self._redis.delete(redis_key)
+ redis_users_key = f"{self._redis_key_prefix}:{document_id}:users"
+ await self._redis.delete(redis_users_key)
+ else:
+ if document_id in self._updates:
+ del self._updates[document_id]
+ if document_id in self._users:
+ del self._users[document_id]
diff --git a/backend/open_webui/static/apple-touch-icon.png b/backend/open_webui/static/apple-touch-icon.png
new file mode 100644
index 0000000000000000000000000000000000000000..9807373436540a5b80ae43960cd3cb86f31eec4f
Binary files /dev/null and b/backend/open_webui/static/apple-touch-icon.png differ
diff --git a/backend/open_webui/static/assets/pdf-style.css b/backend/open_webui/static/assets/pdf-style.css
new file mode 100644
index 0000000000000000000000000000000000000000..8b4e8d23705d805b0bc3ffb6a2412405b818dc4d
--- /dev/null
+++ b/backend/open_webui/static/assets/pdf-style.css
@@ -0,0 +1,314 @@
+/* HTML and Body */
+@font-face {
+ font-family: 'NotoSans';
+ src: url('fonts/NotoSans-Variable.ttf');
+}
+
+@font-face {
+ font-family: 'NotoSansJP';
+ src: url('fonts/NotoSansJP-Variable.ttf');
+}
+
+@font-face {
+ font-family: 'NotoSansKR';
+ src: url('fonts/NotoSansKR-Variable.ttf');
+}
+
+@font-face {
+ font-family: 'NotoSansSC';
+ src: url('fonts/NotoSansSC-Variable.ttf');
+}
+
+@font-face {
+ font-family: 'NotoSansSC-Regular';
+ src: url('fonts/NotoSansSC-Regular.ttf');
+}
+
+html {
+ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', 'NotoSans', 'NotoSansJP', 'NotoSansKR',
+ 'NotoSansSC', 'Twemoji', 'STSong-Light', 'MSung-Light', 'HeiseiMin-W3', 'HYSMyeongJo-Medium',
+ Roboto, 'Helvetica Neue', Arial, sans-serif;
+ font-size: 14px; /* Default font size */
+ line-height: 1.5;
+}
+
+*,
+*::before,
+*::after {
+ box-sizing: inherit;
+}
+
+body {
+ margin: 0;
+ padding: 0;
+ background-color: #fff;
+ width: auto;
+}
+
+/* Typography */
+h1,
+h2,
+h3,
+h4,
+h5,
+h6 {
+ font-weight: 500;
+ margin: 0;
+}
+
+h1 {
+ font-size: 2.5rem;
+}
+
+h2 {
+ font-size: 2rem;
+}
+
+h3 {
+ font-size: 1.75rem;
+}
+
+h4 {
+ font-size: 1.5rem;
+}
+
+h5 {
+ font-size: 1.25rem;
+}
+
+h6 {
+ font-size: 1rem;
+}
+
+p {
+ margin-top: 0;
+ margin-bottom: 1rem;
+}
+
+/* Grid System */
+.container {
+ width: 100%;
+ padding-right: 15px;
+ padding-left: 15px;
+ margin-right: auto;
+ margin-left: auto;
+}
+
+/* Utilities */
+.text-center {
+ text-align: center;
+}
+
+/* Additional Text Utilities */
+.text-muted {
+ color: #6c757d; /* Muted text color */
+}
+
+/* Small Text */
+small {
+ font-size: 80%; /* Smaller font size relative to the base */
+ color: #6c757d; /* Lighter text color for secondary information */
+ margin-bottom: 0;
+ margin-top: 0;
+}
+
+/* Strong Element Styles */
+strong {
+ font-weight: bolder; /* Ensures the text is bold */
+ color: inherit; /* Inherits the color from its parent element */
+}
+
+/* link */
+a {
+ color: #007bff;
+ text-decoration: none;
+ background-color: transparent;
+}
+
+a:hover {
+ color: #0056b3;
+ text-decoration: underline;
+}
+
+/* General styles for lists */
+ol,
+ul,
+li {
+ padding-left: 40px; /* Increase padding to move bullet points to the right */
+ margin-left: 20px; /* Indent lists from the left */
+}
+
+/* Ordered list styles */
+ol {
+ list-style-type: decimal; /* Use numbers for ordered lists */
+ margin-bottom: 10px; /* Space after each list */
+}
+
+ol li {
+ margin-bottom: 0.5rem; /* Space between ordered list items */
+}
+
+/* Unordered list styles */
+ul {
+ list-style-type: disc; /* Use bullets for unordered lists */
+ margin-bottom: 10px; /* Space after each list */
+}
+
+ul li {
+ margin-bottom: 0.5rem; /* Space between unordered list items */
+}
+
+/* List item styles */
+li {
+ margin-bottom: 5px; /* Space between list items */
+ line-height: 1.5; /* Line height for better readability */
+}
+
+/* Nested lists */
+ol ol,
+ol ul,
+ul ol,
+ul ul {
+ padding-left: 20px;
+ margin-left: 30px; /* Further indent nested lists */
+ margin-bottom: 0; /* Remove extra margin at the bottom of nested lists */
+}
+
+/* Code blocks */
+pre {
+ background-color: #f4f4f4;
+ padding: 10px;
+ overflow-x: auto;
+ max-width: 100%; /* Ensure it doesn't overflow the page */
+ width: 80%; /* Set a specific width for a container-like appearance */
+ margin: 0 1em; /* Center the pre block */
+ box-sizing: border-box; /* Include padding in the width */
+ border: 1px solid #ccc; /* Optional: Add a border for better definition */
+ border-radius: 4px; /* Optional: Add rounded corners */
+}
+
+code {
+ font-family: 'Courier New', Courier, monospace;
+ background-color: #f4f4f4;
+ padding: 2px 4px;
+ border-radius: 4px;
+ box-sizing: border-box; /* Include padding in the width */
+}
+
+.message {
+ margin-top: 8px;
+ margin-bottom: 8px;
+ max-width: 100%;
+ overflow-wrap: break-word;
+}
+
+/* Table Styles */
+table {
+ width: 100%;
+ margin-bottom: 1rem;
+ color: #212529;
+ border-collapse: collapse; /* Removes the space between borders */
+}
+
+th,
+td {
+ margin: 0;
+ padding: 0.75rem;
+ vertical-align: top;
+ border-top: 1px solid #dee2e6;
+}
+
+thead th {
+ vertical-align: bottom;
+ border-bottom: 2px solid #dee2e6;
+}
+
+tbody + tbody {
+ border-top: 2px solid #dee2e6;
+}
+
+/* markdown-section styles */
+.markdown-section blockquote,
+.markdown-section h1,
+.markdown-section h2,
+.markdown-section h3,
+.markdown-section h4,
+.markdown-section h5,
+.markdown-section h6,
+.markdown-section p,
+.markdown-section pre,
+.markdown-section table,
+.markdown-section ul {
+ /* Give most block elements margin top and bottom */
+ margin-top: 1rem;
+}
+
+/* Remove top margin if it's the first child */
+.markdown-section blockquote:first-child,
+.markdown-section h1:first-child,
+.markdown-section h2:first-child,
+.markdown-section h3:first-child,
+.markdown-section h4:first-child,
+.markdown-section h5:first-child,
+.markdown-section h6:first-child,
+.markdown-section p:first-child,
+.markdown-section pre:first-child,
+.markdown-section table:first-child,
+.markdown-section ul:first-child {
+ margin-top: 0;
+}
+
+/* Remove top margin of