|
import chromadb |
|
from chromadb import Settings |
|
from chromadb.utils.batch_utils import create_batches |
|
|
|
from typing import Optional |
|
|
|
from open_webui.apps.retrieval.vector.main import VectorItem, SearchResult, GetResult |
|
from open_webui.config import ( |
|
CHROMA_DATA_PATH, |
|
CHROMA_HTTP_HOST, |
|
CHROMA_HTTP_PORT, |
|
CHROMA_HTTP_HEADERS, |
|
CHROMA_HTTP_SSL, |
|
CHROMA_TENANT, |
|
CHROMA_DATABASE, |
|
) |
|
|
|
|
|
class ChromaClient:
    """Adapter exposing collection-level vector operations on top of a ChromaDB client.

    The underlying client is stored on ``self.client``; every method delegates
    to it and converts results into the project's ``SearchResult`` /
    ``GetResult`` containers where applicable.
    """

    def __init__(self):
        """Create the underlying ChromaDB client.

        An HTTP client is used when ``CHROMA_HTTP_HOST`` is not the empty
        string; otherwise a persistent client rooted at ``CHROMA_DATA_PATH``
        is created. Both variants are configured with ``allow_reset=True``
        (needed by :meth:`reset`) and ``anonymized_telemetry=False``.
        """
        settings = Settings(allow_reset=True, anonymized_telemetry=False)

        if CHROMA_HTTP_HOST != "":
            self.client = chromadb.HttpClient(
                host=CHROMA_HTTP_HOST,
                port=CHROMA_HTTP_PORT,
                headers=CHROMA_HTTP_HEADERS,
                ssl=CHROMA_HTTP_SSL,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
                settings=settings,
            )
        else:
            self.client = chromadb.PersistentClient(
                path=CHROMA_DATA_PATH,
                settings=settings,
                tenant=CHROMA_TENANT,
                database=CHROMA_DATABASE,
            )

    @staticmethod
    def _split_items(items: list[VectorItem]) -> dict[str, list]:
        """Turn a list of items into the parallel column lists Chroma expects."""
        return {
            "ids": [item["id"] for item in items],
            "documents": [item["text"] for item in items],
            "embeddings": [item["vector"] for item in items],
            "metadatas": [item["metadata"] for item in items],
        }

    @staticmethod
    def _to_get_result(result) -> GetResult:
        """Wrap the fields of a ``collection.get()`` response in a ``GetResult``.

        Each field is wrapped in a one-element list.
        NOTE(review): presumably this mirrors the per-query nesting of
        ``collection.query()`` results -- confirm against ``GetResult``.
        """
        return GetResult(
            ids=[result["ids"]],
            documents=[result["documents"]],
            metadatas=[result["metadatas"]],
        )

    def _get_or_create_cosine_collection(self, collection_name: str):
        """Return the named collection, creating it with cosine space if absent."""
        return self.client.get_or_create_collection(
            name=collection_name, metadata={"hnsw:space": "cosine"}
        )

    def has_collection(self, collection_name: str) -> bool:
        """Return True if a collection named ``collection_name`` is listed by the client."""
        # NOTE(review): some chromadb releases (0.6.x, as far as I recall)
        # return bare collection names from list_collections() instead of
        # collection objects -- accept both; confirm against the pinned version.
        return any(
            (entry if isinstance(entry, str) else entry.name) == collection_name
            for entry in self.client.list_collections()
        )

    def delete_collection(self, collection_name: str):
        """Delete the named collection and return whatever the client returns."""
        return self.client.delete_collection(name=collection_name)

    def search(
        self, collection_name: str, vectors: list[list[float | int]], limit: int
    ) -> Optional[SearchResult]:
        """Query the collection with ``vectors``, asking for ``limit`` results each.

        Returns:
            A ``SearchResult`` carrying the ``ids``, ``distances``,
            ``documents`` and ``metadatas`` of the query response, or ``None``
            when the collection is falsy or any exception is raised.
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if not collection:
                return None

            result = collection.query(
                query_embeddings=vectors,
                n_results=limit,
            )
            return SearchResult(
                ids=result["ids"],
                distances=result["distances"],
                documents=result["documents"],
                metadatas=result["metadatas"],
            )
        except Exception as e:
            # Report the failure the same way query() does instead of hiding it.
            print(e)
            return None

    def query(
        self, collection_name: str, filter: dict, limit: Optional[int] = None
    ) -> Optional[GetResult]:
        """Fetch items matching the metadata ``filter``, optionally capped at ``limit``.

        Returns:
            A ``GetResult`` built from the response, or ``None`` when the
            collection is falsy or any exception is raised (the exception is
            printed).
        """
        try:
            collection = self.client.get_collection(name=collection_name)
            if not collection:
                return None

            result = collection.get(
                where=filter,
                limit=limit,
            )
            return self._to_get_result(result)
        except Exception as e:
            print(e)
            return None

    def get(self, collection_name: str) -> Optional[GetResult]:
        """Fetch every item of the collection as a ``GetResult``.

        Unlike :meth:`search` and :meth:`query`, nothing is caught here:
        exceptions raised by the client propagate to the caller. ``None`` is
        returned only when the collection object is falsy.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return None
        return self._to_get_result(collection.get())

    def insert(self, collection_name: str, items: list[VectorItem]):
        """Add ``items`` to the collection, creating it (cosine space) if needed.

        Each item is read through its ``"id"``, ``"text"``, ``"vector"`` and
        ``"metadata"`` keys. The data is passed through ``create_batches`` and
        every resulting batch is added separately.
        """
        collection = self._get_or_create_cosine_collection(collection_name)

        for batch in create_batches(api=self.client, **self._split_items(items)):
            collection.add(*batch)

    def upsert(self, collection_name: str, items: list[VectorItem]):
        """Upsert ``items`` into the collection, creating it (cosine space) if needed.

        Each item is read through its ``"id"``, ``"text"``, ``"vector"`` and
        ``"metadata"`` keys; everything is sent in a single ``upsert`` call.
        """
        collection = self._get_or_create_cosine_collection(collection_name)
        collection.upsert(**self._split_items(items))

    def delete(
        self,
        collection_name: str,
        ids: Optional[list[str]] = None,
        filter: Optional[dict] = None,
    ):
        """Delete items from the collection by ``ids`` or by metadata ``filter``.

        A truthy ``ids`` takes precedence and ``filter`` is then ignored. When
        both are empty or ``None`` nothing is deleted. Exceptions raised by
        the client propagate to the caller.
        """
        collection = self.client.get_collection(name=collection_name)
        if not collection:
            return

        if ids:
            collection.delete(ids=ids)
        elif filter:
            collection.delete(where=filter)

    def reset(self):
        """Reset the underlying client and return whatever ``client.reset()`` returns."""
        return self.client.reset()
|
|