# notebook-backend/utils/vector_db.py
# Author: mohhhhhit — "Update utils/vector_db.py" (commit c4b5910, verified)
import os
import uuid
import hashlib
from qdrant_client import QdrantClient
from qdrant_client.http import models
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Optional
import threading
import logging
import warnings
# Silence FutureWarnings emitted by sentence-transformers / transformers internals.
warnings.filterwarnings('ignore', category=FutureWarning)
# Raise sentence-transformers' log level so routine INFO chatter is suppressed.
logging.getLogger('sentence_transformers').setLevel(logging.WARNING)
class VectorDatabase:
    """Manage vector database for document embeddings using Qdrant Cloud.

    Exposes a Chroma-style API (``add_documents`` / ``query``) on top of a
    Qdrant Cloud collection, with a process-wide cached SentenceTransformer
    embedding model shared across instances.
    """

    # Process-wide cache: the SentenceTransformer weights are loaded once and
    # shared by all VectorDatabase instances (guarded by the lock below).
    _embedding_model = None
    _embedding_model_name = None
    _embedding_model_lock = threading.Lock()

    def __init__(self, collection_name: str = "documents", persist_directory: str = None):
        """Initialize the Qdrant client and embedding model.

        Args:
            collection_name: Name of the Qdrant collection to use (created if
                missing).
            persist_directory: Ignored for Qdrant Cloud; kept for interface
                compatibility with local backends.

        Raises:
            ValueError: If QDRANT_URL or QDRANT_API_KEY is not set in the
                environment.
        """
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")
        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY must be set in environment variables.")
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
            timeout=60.0
        )
        self.collection_name = collection_name
        self.vector_size = 384  # Size for standard sentence-transformers (e.g. all-MiniLM-L6-v2)
        # Ensure collection exists before any upload/query.
        self._ensure_collection()
        # Load (or reuse) the shared embedding model.
        self.embedding_model = self._get_or_create_embedding_model()

    def _ensure_collection(self) -> None:
        """Create the collection in Qdrant if it doesn't exist (best-effort).

        Failures are logged rather than raised; a genuine connectivity or
        permission problem will surface on the first upload/query instead.
        """
        try:
            collections = self.client.get_collections().collections
            exists = any(c.name == self.collection_name for c in collections)
            if not exists:
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=models.VectorParams(
                        size=self.vector_size,
                        distance=models.Distance.COSINE
                    )
                )
        except Exception:
            # Log with traceback instead of print() so failures are visible
            # in application logs; keep the original best-effort behavior.
            logging.getLogger(__name__).exception("Error checking/creating collection")

    @classmethod
    def _get_or_create_embedding_model(cls):
        """Return the shared SentenceTransformer, loading it on first use.

        Reloads if the EMBEDDING_MODEL env var changed since the last call.
        Thread-safe via the class-level lock.
        """
        with cls._embedding_model_lock:
            # EMBEDDING_MODEL selects the model; defaults to MiniLM (384-dim,
            # matching self.vector_size).
            model_name = os.getenv("EMBEDDING_MODEL", "all-MiniLM-L6-v2")
            if cls._embedding_model is None or cls._embedding_model_name != model_name:
                import torch
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                print(f"Loading embedding model on {device}...")
                cls._embedding_model = SentenceTransformer(model_name, device=device)
                cls._embedding_model_name = model_name
            return cls._embedding_model

    def _string_to_uuid(self, string_id: str) -> str:
        """Deterministically map an arbitrary string ID to a UUID string.

        Qdrant requires point IDs to be UUIDs or integers; MD5 is used here
        purely as a stable non-cryptographic hash of the caller's string ID.
        """
        return str(uuid.UUID(hashlib.md5(string_id.encode()).hexdigest()))

    def add_documents(self, texts: List[str], metadatas: List[Dict], ids: List[str]) -> None:
        """Embed *texts* and upload them to the Qdrant collection.

        Args:
            texts: Document chunks to embed. Empty list is a no-op.
            metadatas: Per-document metadata dicts (entries may be falsy/None).
                The caller's dicts are NOT mutated.
            ids: Stable string IDs; hashed into UUIDs for Qdrant.

        Raises:
            ValueError: If texts, metadatas and ids differ in length.
        """
        if not texts:
            return
        # Fail fast with a clear message instead of an IndexError mid-loop.
        if not (len(texts) == len(metadatas) == len(ids)):
            raise ValueError("texts, metadatas and ids must have the same length.")
        embeddings = self.embedding_model.encode(texts, show_progress_bar=False, batch_size=64).tolist()
        points = []
        for i in range(len(texts)):
            # Copy the metadata dict so adding 'text' does not mutate the
            # caller's object (the original code wrote into it in place).
            payload = dict(metadatas[i]) if metadatas[i] else {}
            payload['text'] = texts[i]  # Store actual text in payload for retrieval
            points.append(models.PointStruct(
                id=self._string_to_uuid(ids[i]),
                vector=embeddings[i],
                payload=payload
            ))
        # upload_points auto-batches (vs. a single monolithic upsert call).
        self.client.upload_points(
            collection_name=self.collection_name,
            points=points,
            batch_size=100,  # Qdrant cuts the upload into chunks of 100
            wait=True        # Block until the upload is acknowledged
        )

    def query(self, query_text: str, n_results: int = 5, filter_dict: Optional[Dict] = None) -> Dict:
        """Similarity-search the collection, returning Chroma-style results.

        Args:
            query_text: Free-text query to embed and search with.
            n_results: Maximum number of hits to return.
            filter_dict: Optional exact-match payload filter ({field: value}).

        Returns:
            Dict with keys 'documents', 'metadatas', 'distances', 'ids', each
            a single-element list wrapping the per-hit lists (ChromaDB shape).
            NOTE: 'distances' holds Qdrant cosine *similarity* scores (higher
            is better), not Chroma-style distances.
        """
        # An empty collection yields the empty Chroma-shaped result directly.
        count = self.get_collection_count()
        if count == 0:
            return {"documents": [[]], "metadatas": [[]], "distances": [[]], "ids": [[]]}
        query_embedding = self.embedding_model.encode([query_text])[0].tolist()
        # Build Qdrant filter if provided (all conditions must match).
        qdrant_filter = None
        if filter_dict:
            conditions = [
                models.FieldCondition(key=k, match=models.MatchValue(value=v))
                for k, v in filter_dict.items()
            ]
            qdrant_filter = models.Filter(must=conditions)
        search_result = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            query_filter=qdrant_filter,
            limit=n_results
        )
        # Format output to match what HybridRetriever expects (ChromaDB style).
        docs, metas, scores, ids = [], [], [], []
        for hit in search_result:
            docs.append(hit.payload.get('text', ''))
            # Strip the stored text out of the metadata so it mimics Chroma.
            meta = {k: v for k, v in hit.payload.items() if k != 'text'}
            metas.append(meta)
            scores.append(hit.score)
            ids.append(str(hit.id))
        return {
            "documents": [docs],
            "metadatas": [metas],
            "distances": [scores],  # Qdrant similarity (higher is better), not Chroma distance.
            "ids": [ids]
        }

    def get_collection_count(self) -> int:
        """Return the number of points in the collection, or 0 on any error."""
        try:
            return self.client.count(collection_name=self.collection_name).count
        except Exception:
            return 0