|
import logging |
|
from typing import Optional, Tuple |
|
from urllib.parse import urlparse |
|
|
|
import grpc |
|
from open_webui.config import ( |
|
QDRANT_API_KEY, |
|
QDRANT_GRPC_PORT, |
|
QDRANT_ON_DISK, |
|
QDRANT_PREFER_GRPC, |
|
QDRANT_URI, |
|
) |
|
from open_webui.env import SRC_LOG_LEVELS |
|
from open_webui.retrieval.vector.main import ( |
|
GetResult, |
|
SearchResult, |
|
VectorDBBase, |
|
VectorItem, |
|
) |
|
from qdrant_client import QdrantClient as Qclient |
|
from qdrant_client.http.exceptions import UnexpectedResponse |
|
from qdrant_client.http.models import PointStruct |
|
from qdrant_client.models import models |
|
|
|
# Very large value used as the ``limit`` argument when a caller wants every
# matching point rather than a bounded page.
NO_LIMIT = 999999999

# Module logger; its level comes from the "RAG" entry of SRC_LOG_LEVELS.
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
|
|
|
|
class QdrantClient(VectorDBBase): |
|
def __init__(self):
    """Read the Qdrant settings and, when a URI is configured, open a client.

    With an empty ``QDRANT_URI`` the instance is left disabled:
    ``self.client`` is ``None`` and no connection is attempted. The names of
    the shared multi-tenant collections are assigned in both cases so the
    instance is never left half-initialised.
    """
    self.collection_prefix = "open-webui"
    self.QDRANT_URI = QDRANT_URI
    self.QDRANT_API_KEY = QDRANT_API_KEY
    self.QDRANT_ON_DISK = QDRANT_ON_DISK
    self.PREFER_GRPC = QDRANT_PREFER_GRPC
    self.GRPC_PORT = QDRANT_GRPC_PORT

    # Names of the shared (physical) collections. They are set before the
    # early return below so helpers that read them never hit an
    # AttributeError on a disabled instance.
    self.MEMORY_COLLECTION = f"{self.collection_prefix}_memories"
    self.KNOWLEDGE_COLLECTION = f"{self.collection_prefix}_knowledge"
    self.FILE_COLLECTION = f"{self.collection_prefix}_files"
    self.WEB_SEARCH_COLLECTION = f"{self.collection_prefix}_web-search"
    self.HASH_BASED_COLLECTION = f"{self.collection_prefix}_hash-based"

    if not self.QDRANT_URI:
        self.client = None
        return

    # Split the URI so host and port can be passed separately in gRPC mode.
    # Fall back to the raw URI as host and to 6333 as the HTTP port.
    parsed = urlparse(self.QDRANT_URI)
    host = parsed.hostname or self.QDRANT_URI
    http_port = parsed.port or 6333

    if self.PREFER_GRPC:
        # NOTE(review): only host and ports are forwarded here; the URI
        # scheme (http/https) is not passed on -- confirm TLS expectations.
        self.client = Qclient(
            host=host,
            port=http_port,
            grpc_port=self.GRPC_PORT,
            prefer_grpc=self.PREFER_GRPC,
            api_key=self.QDRANT_API_KEY,
        )
    else:
        self.client = Qclient(url=self.QDRANT_URI, api_key=self.QDRANT_API_KEY)
|
|
|
def _result_to_get_result(self, points) -> GetResult:
    """Pack Qdrant points into a ``GetResult`` holding a single result row.

    Each point contributes its ``id`` plus the ``"text"`` and ``"metadata"``
    entries of its payload. The three lists are wrapped in an outer list, so
    the result always has exactly one row.
    """
    # Materialise once so ``points`` may be any iterable.
    records = [(pt.payload, pt.id) for pt in points]

    return GetResult(
        ids=[[point_id for _, point_id in records]],
        documents=[[payload["text"] for payload, _ in records]],
        metadatas=[[payload["metadata"] for payload, _ in records]],
    )
|
|
|
def _get_collection_and_tenant_id(self, collection_name: str) -> Tuple[str, str]:
    """Map a logical collection name to its shared collection and tenant ID.

    The logical name itself is always used as the tenant ID. The shared
    collection is chosen from the name's shape, checked in this order:
    ``user-memory-`` prefix, ``file-`` prefix, ``web-search-`` prefix, a
    string of exactly 63 lowercase hex characters, and finally the knowledge
    collection as the fallback.

    Returns:
        tuple: (collection_name, tenant_id)
    """
    prefix_routes = (
        ("user-memory-", self.MEMORY_COLLECTION),
        ("file-", self.FILE_COLLECTION),
        ("web-search-", self.WEB_SEARCH_COLLECTION),
    )
    for prefix, shared_collection in prefix_routes:
        if collection_name.startswith(prefix):
            return shared_collection, collection_name

    hex_digits = "0123456789abcdef"
    looks_like_hash = len(collection_name) == 63 and all(
        ch in hex_digits for ch in collection_name
    )
    if looks_like_hash:
        return self.HASH_BASED_COLLECTION, collection_name

    return self.KNOWLEDGE_COLLECTION, collection_name
|
|
|
def _extract_error_message(self, exception):
    """Pull a status code and readable message out of a Qdrant exception.

    Handles HTTP errors (``UnexpectedResponse``) and gRPC errors
    (``grpc.RpcError``). Any other exception yields ``(None, str(exception))``.

    Returns:
        tuple: (status_code, error_message)
    """
    if isinstance(exception, UnexpectedResponse):
        try:
            body = exception.structured()
            return exception.status_code, body.get("status", {}).get("error", "")
        except Exception as inner_e:
            # Body was not the expected structure; fall back to the raw text.
            log.error(f"Failed to parse HTTP error: {inner_e}")
            return exception.status_code, str(exception)

    if not isinstance(exception, grpc.RpcError):
        return None, str(exception)

    # gRPC: the code is the first element of the status enum's ``value``.
    grpc_status = None
    if hasattr(exception, "code") and callable(exception.code):
        grpc_status = exception.code().value[0]

    message = str(exception)
    if "details =" in message:
        # Keep only the text after ``details =`` on the first line that has it.
        try:
            candidates = [
                row.strip() for row in message.split("\n") if "details =" in row
            ]
            message = candidates[0].split("details =")[1].strip(' "')
        except (IndexError, AttributeError):
            # Leave the full string in place when the format is unexpected.
            pass

    return grpc_status, message
|
|
|
def _is_collection_not_found_error(self, exception):
    """Tell whether ``exception`` means the target collection is missing.

    True for an error with status 404 whose message contains both
    ``Collection`` and ``doesn't exist``, or for a gRPC error whose code is
    ``NOT_FOUND``.
    """
    status_code, error_msg = self._extract_error_message(exception)

    http_not_found = (
        status_code == 404
        and "Collection" in error_msg
        and "doesn't exist" in error_msg
    )
    if http_not_found:
        return True

    return bool(
        isinstance(exception, grpc.RpcError)
        and exception.code() == grpc.StatusCode.NOT_FOUND
    )
|
|
|
def _is_dimension_mismatch_error(self, exception):
    """Tell whether ``exception`` reports a vector-size mismatch.

    Only the extracted message is inspected; the status code is ignored. The
    message must contain one of the known mismatch phrases.
    """
    _, error_msg = self._extract_error_message(exception)

    mismatch_markers = (
        "Vector dimension error",
        "dimensions mismatch",
        "invalid vector size",
    )
    return any(marker in error_msg for marker in mismatch_markers)
|
|
|
def _create_multi_tenant_collection_if_not_exists(
    self, mt_collection_name: str, dimension: int = 384
):
    """Create a shared multi-tenant collection, tolerating "already exists".

    The collection uses cosine distance and an HNSW config of ``m=0`` and
    ``payload_m=16``. A keyword payload index marked ``is_tenant`` is then
    created on ``tenant_id``. The default ``dimension`` of 384 corresponds to
    'sentence-transformers/all-MiniLM-L6-v2'; insert/upsert pass the real
    vector size when they create a collection on demand.

    Raises:
        UnexpectedResponse / grpc.RpcError: for any Qdrant error other than
            an "already exists" conflict.
    """
    on_disk = self.QDRANT_ON_DISK
    try:
        self.client.create_collection(
            collection_name=mt_collection_name,
            vectors_config=models.VectorParams(
                size=dimension,
                distance=models.Distance.COSINE,
                on_disk=on_disk,
            ),
            hnsw_config=models.HnswConfigDiff(
                payload_m=16,
                m=0,
                on_disk=on_disk,
            ),
        )

        tenant_index = models.KeywordIndexParams(
            type=models.KeywordIndexType.KEYWORD,
            is_tenant=True,
            on_disk=on_disk,
        )
        self.client.create_payload_index(
            collection_name=mt_collection_name,
            field_name="tenant_id",
            field_schema=tenant_index,
            wait=True,
        )

        log.info(
            f"Multi-tenant collection {mt_collection_name} created with dimension {dimension}!"
        )
    except (UnexpectedResponse, grpc.RpcError) as e:
        status_code, error_msg = self._extract_error_message(e)

        # Conflict is HTTP 409 or gRPC ALREADY_EXISTS.
        http_conflict = isinstance(e, UnexpectedResponse) and status_code == 409
        is_conflict = http_conflict or (
            isinstance(e, grpc.RpcError)
            and e.code() == grpc.StatusCode.ALREADY_EXISTS
        )
        if is_conflict and "already exists" in error_msg:
            log.debug(f"Collection {mt_collection_name} already exists")
            return

        raise
|
|
|
def _create_points(self, items: list[VectorItem], tenant_id: str):
    """Turn vector items into ``PointStruct`` objects tagged with a tenant.

    Each item supplies ``id``, ``vector``, ``text`` and ``metadata``. The
    payload stores the text, the metadata and the given ``tenant_id``.
    """
    points = []
    for entry in items:
        point_id = entry["id"]
        vector = entry["vector"]
        payload = {
            "text": entry["text"],
            "metadata": entry["metadata"],
            "tenant_id": tenant_id,
        }
        points.append(PointStruct(id=point_id, vector=vector, payload=payload))
    return points
|
|
|
def has_collection(self, collection_name: str) -> bool:
    """Report whether a logical collection holds at least one point.

    A logical collection "exists" when the shared collection contains a
    point whose ``tenant_id`` equals the logical name. Every failure,
    including a missing shared collection or a missing client, returns
    ``False`` rather than raising.
    """
    if not self.client:
        return False

    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    tenant_filter = models.FieldCondition(
        key="tenant_id", match=models.MatchValue(value=tenant_id)
    )

    try:
        # One matching point is enough to answer the question.
        hits = self.client.query_points(
            collection_name=mt_collection,
            query_filter=models.Filter(must=[tenant_filter]),
            limit=1,
        ).points
        return len(hits) > 0
    except (UnexpectedResponse, grpc.RpcError) as e:
        if self._is_collection_not_found_error(e):
            log.debug(f"Collection {mt_collection} doesn't exist")
        else:
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error: {error_msg}")
        return False
    except Exception as e:
        log.debug(f"Error checking collection {mt_collection}: {e}")
        return False
|
|
|
def delete(
    self,
    collection_name: str,
    ids: Optional[list[str]] = None,
    filter: Optional[dict] = None,
):
    """Delete a tenant's points, selected by ID or by metadata filter.

    The tenant condition is always required. When ``ids`` is given, each ID
    becomes an alternative ("should") match on ``metadata.id``. Otherwise,
    when ``filter`` is given, every key/value pair becomes a required match
    on ``metadata.<key>``. With neither, only the tenant condition applies.

    Returns:
        The client's delete result, or ``None`` when there is no client or
        the shared collection does not exist.

    Raises:
        Any other Qdrant or unexpected error is propagated.
    """
    if not self.client:
        return None

    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    must_conditions = [
        models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )
    ]
    should_conditions = []

    if ids:
        should_conditions = [
            models.FieldCondition(
                key="metadata.id",
                match=models.MatchValue(value=id_value),
            )
            for id_value in ids
        ]
    elif filter:
        must_conditions.extend(
            models.FieldCondition(
                key=f"metadata.{key}",
                match=models.MatchValue(value=value),
            )
            for key, value in filter.items()
        )

    selector = models.FilterSelector(
        filter=models.Filter(must=must_conditions, should=should_conditions)
    )

    try:
        return self.client.delete(
            collection_name=mt_collection,
            points_selector=selector,
        )
    except (UnexpectedResponse, grpc.RpcError) as e:
        if not self._is_collection_not_found_error(e):
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error: {error_msg}")
            raise
        log.debug(f"Collection {mt_collection} doesn't exist, nothing to delete")
        return None
|
|
|
def search(
    self, collection_name: str, vectors: list[list[float | int]], limit: int
) -> Optional[SearchResult]:
    """Run a nearest-neighbour search restricted to one tenant.

    Only the first vector in ``vectors`` is used as the query. If its length
    differs from the collection's vector size, the query vectors are
    truncated or zero-padded to fit before searching.

    Args:
        collection_name: Logical collection name (used as the tenant ID).
        vectors: Query vectors; only ``vectors[0]`` is searched.
        limit: Maximum number of hits to return.

    Returns:
        A single-row ``SearchResult``, or ``None`` when there is no client,
        ``vectors`` is empty, the shared collection does not exist, or a
        non-Qdrant error occurs (which is logged).

    Raises:
        UnexpectedResponse / grpc.RpcError: for Qdrant errors other than
            "collection not found".
    """
    # Nothing to search with. Bail out before any network call instead of
    # failing later on a None dimension.
    if not self.client or not vectors:
        return None

    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    dimension = len(vectors[0])

    try:
        tenant_filter = models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )

        collection_dim = self.client.get_collection(
            mt_collection
        ).config.params.vectors.size

        # Fit the query vectors to the collection's vector size.
        if collection_dim != dimension:
            if collection_dim < dimension:
                vectors = [vector[:collection_dim] for vector in vectors]
            else:
                vectors = [
                    vector + [0] * (collection_dim - dimension)
                    for vector in vectors
                ]

        # Prefetch narrows the candidates to this tenant's points. The outer
        # query then ranks them against the query vector.
        prefetch_query = models.Prefetch(
            filter=models.Filter(must=[tenant_filter]),
            limit=NO_LIMIT,
        )
        query_response = self.client.query_points(
            collection_name=mt_collection,
            query=vectors[0],
            prefetch=prefetch_query,
            limit=limit,
        )

        get_result = self._result_to_get_result(query_response.points)
        return SearchResult(
            ids=get_result.ids,
            documents=get_result.documents,
            metadatas=get_result.metadatas,
            # Rescale each score s to (s + 1) / 2, which maps [-1, 1] onto
            # [0, 1].
            distances=[
                [(point.score + 1.0) / 2.0 for point in query_response.points]
            ],
        )
    except (UnexpectedResponse, grpc.RpcError) as e:
        if self._is_collection_not_found_error(e):
            log.debug(
                f"Collection {mt_collection} doesn't exist, search returns None"
            )
            return None

        _, error_msg = self._extract_error_message(e)
        log.warning(f"Unexpected Qdrant error during search: {error_msg}")
        raise
    except Exception as e:
        # Best effort: any non-Qdrant failure is logged and reported as
        # "no result".
        log.exception(f"Error searching collection '{collection_name}': {e}")
        return None
|
|
|
def query(self, collection_name: str, filter: dict, limit: Optional[int] = None):
    """Fetch a tenant's points whose metadata matches every filter entry.

    Each ``filter`` key/value pair becomes a required match on
    ``metadata.<key>``, combined with the tenant condition. When ``limit``
    is ``None`` the module-level ``NO_LIMIT`` is used.

    Returns:
        The matches converted by ``_result_to_get_result``, or ``None`` when
        there is no client, the shared collection does not exist, or a
        non-Qdrant error occurs (which is logged).

    Raises:
        UnexpectedResponse / grpc.RpcError: for Qdrant errors other than
            "collection not found".
    """
    if not self.client:
        return None

    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    effective_limit = NO_LIMIT if limit is None else limit

    # Tenant condition first, then one condition per metadata key.
    conditions = [
        models.FieldCondition(
            key="tenant_id", match=models.MatchValue(value=tenant_id)
        )
    ]
    conditions += [
        models.FieldCondition(
            key=f"metadata.{key}", match=models.MatchValue(value=value)
        )
        for key, value in filter.items()
    ]
    combined_filter = models.Filter(must=conditions)

    try:
        response = self.client.query_points(
            collection_name=mt_collection,
            query_filter=combined_filter,
            limit=effective_limit,
        )
        return self._result_to_get_result(response.points)
    except (UnexpectedResponse, grpc.RpcError) as e:
        if not self._is_collection_not_found_error(e):
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error during query: {error_msg}")
            raise
        log.debug(f"Collection {mt_collection} doesn't exist, query returns None")
        return None
    except Exception as e:
        log.exception(f"Error querying collection '{collection_name}': {e}")
        return None
|
|
|
def get(self, collection_name: str) -> Optional[GetResult]:
    """Fetch every point that belongs to one tenant.

    Returns:
        The tenant's points converted by ``_result_to_get_result``, or
        ``None`` when there is no client, the shared collection does not
        exist, or a non-Qdrant error occurs (which is logged).

    Raises:
        UnexpectedResponse / grpc.RpcError: for Qdrant errors other than
            "collection not found".
    """
    if not self.client:
        return None

    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    tenant_only = models.Filter(
        must=[
            models.FieldCondition(
                key="tenant_id", match=models.MatchValue(value=tenant_id)
            )
        ]
    )

    try:
        response = self.client.query_points(
            collection_name=mt_collection,
            query_filter=tenant_only,
            limit=NO_LIMIT,
        )
        return self._result_to_get_result(response.points)
    except (UnexpectedResponse, grpc.RpcError) as e:
        if not self._is_collection_not_found_error(e):
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unexpected Qdrant error during get: {error_msg}")
            raise
        log.debug(f"Collection {mt_collection} doesn't exist, get returns None")
        return None
    except Exception as e:
        log.exception(f"Error getting collection '{collection_name}': {e}")
        return None
|
|
|
def _handle_operation_with_error_retry(
    self, operation_name, mt_collection, points, dimension
):
    """Run an insert or upsert, recovering once from two known failures.

    If the shared collection is missing, it is created with ``dimension``
    and the write is retried. If Qdrant reports a vector-size mismatch, the
    vectors are truncated or zero-padded to the collection's size and the
    write is retried. Any other Qdrant error is logged and re-raised.

    Args:
        operation_name: 'insert' or 'upsert'
        mt_collection: The multi-tenant collection name
        points: The vector points to insert/upsert
        dimension: The dimension of the vectors

    Returns:
        The operation result (for upsert) or None (for insert)
    """

    def run(batch):
        # 'insert' uploads and returns nothing; anything else upserts.
        if operation_name == "insert":
            self.client.upload_points(mt_collection, batch)
            return None
        return self.client.upsert(mt_collection, batch)

    try:
        return run(points)
    except (UnexpectedResponse, grpc.RpcError) as e:
        if self._is_collection_not_found_error(e):
            log.info(
                f"Collection {mt_collection} doesn't exist. Creating it with dimension {dimension}."
            )
            self._create_multi_tenant_collection_if_not_exists(
                mt_collection_name=mt_collection, dimension=dimension
            )
            return run(points)

        if not self._is_dimension_mismatch_error(e):
            _, error_msg = self._extract_error_message(e)
            log.warning(f"Unhandled Qdrant error: {error_msg}")
            raise

        # Dimension mismatch: look up the real size and resize the vectors.
        existing_size = self.client.get_collection(
            mt_collection
        ).config.params.vectors.size

        log.info(
            f"Dimension mismatch: Collection {mt_collection} expects {existing_size}, got {dimension}"
        )

        resized = points
        if existing_size < dimension:
            log.info(
                f"Truncating vectors from {dimension} to {existing_size} dimensions"
            )
            resized = [
                PointStruct(
                    id=point.id,
                    vector=point.vector[:existing_size],
                    payload=point.payload,
                )
                for point in points
            ]
        elif existing_size > dimension:
            log.info(
                f"Padding vectors from {dimension} to {existing_size} dimensions with zeros"
            )
            resized = [
                PointStruct(
                    id=point.id,
                    vector=point.vector + [0] * (existing_size - len(point.vector)),
                    payload=point.payload,
                )
                for point in points
            ]

        return run(resized)
|
|
|
def insert(self, collection_name: str, items: list[VectorItem]):
    """Insert items into the tenant's shared collection.

    Does nothing and returns ``None`` when there is no client or ``items``
    is empty. The vector size is taken from the first item and passed on so
    a missing collection can be created with it.
    """
    if not (self.client and items):
        return None

    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
    vector_size = len(items[0]["vector"])
    tagged_points = self._create_points(items, tenant_id)

    return self._handle_operation_with_error_retry(
        "insert", mt_collection, tagged_points, vector_size
    )
|
|
|
def upsert(self, collection_name: str, items: list[VectorItem]):
    """Upsert items into the tenant's shared collection.

    Does nothing and returns ``None`` when there is no client or ``items``
    is empty. The vector size is taken from the first item and passed on so
    a missing collection can be created with it.
    """
    if not (self.client and items):
        return None

    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)
    vector_size = len(items[0]["vector"])
    tagged_points = self._create_points(items, tenant_id)

    return self._handle_operation_with_error_retry(
        "upsert", mt_collection, tagged_points, vector_size
    )
|
|
|
def reset(self):
    """Drop every Qdrant collection whose name starts with the prefix.

    Collections not starting with ``self.collection_prefix`` are left
    untouched. Returns ``None``; does nothing without a client.
    """
    if not self.client:
        return None

    prefix = self.collection_prefix
    for description in self.client.get_collections().collections:
        name = description.name
        if not name.startswith(prefix):
            continue
        self.client.delete_collection(collection_name=name)
|
|
|
def delete_collection(self, collection_name: str):
    """Delete a logical collection, i.e. all points of one tenant.

    After the tenant's points are removed, the shared collection itself is
    dropped if it reports zero remaining points.

    Returns:
        The client's delete result, or ``None`` when there is no client or
        the shared collection does not exist (nothing to delete).

    Raises:
        UnexpectedResponse / grpc.RpcError: for Qdrant errors other than
            "collection not found".
    """
    if not self.client:
        return None

    mt_collection, tenant_id = self._get_collection_and_tenant_id(collection_name)

    tenant_only = models.Filter(
        must=[
            models.FieldCondition(
                key="tenant_id", match=models.MatchValue(value=tenant_id)
            )
        ]
    )

    try:
        update_result = self.client.delete(
            collection_name=mt_collection,
            points_selector=models.FilterSelector(filter=tenant_only),
        )

        # NOTE(review): this relies on points_count already reflecting the
        # delete above -- confirm it is exact enough to drop the collection.
        if self.client.get_collection(mt_collection).points_count == 0:
            self.client.delete_collection(mt_collection)

        return update_result
    except (UnexpectedResponse, grpc.RpcError) as e:
        # Same policy as delete(): a missing shared collection means there
        # is nothing to remove, so it is not an error.
        if self._is_collection_not_found_error(e):
            log.debug(
                f"Collection {mt_collection} doesn't exist, nothing to delete"
            )
            return None

        _, error_msg = self._extract_error_message(e)
        log.warning(f"Unexpected Qdrant error during delete_collection: {error_msg}")
        raise
|
|