import os
import uuid
import hashlib
import threading
import logging
import warnings
from typing import List, Dict, Optional

from qdrant_client import QdrantClient
from qdrant_client.http import models
from sentence_transformers import SentenceTransformer

warnings.filterwarnings('ignore', category=FutureWarning)
logging.getLogger('sentence_transformers').setLevel(logging.WARNING)
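
# Required environment variables (as referenced below):
#   QDRANT_URL      - URL of the Qdrant Cloud cluster
#   QDRANT_API_KEY  - API key for that cluster
#   EMBEDDING_MODEL - optional; sentence-transformers model name
#                     (defaults to all-MiniLM-L6-v2)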

class VectorDatabase:
    """Manage vector database for document embeddings using Qdrant Cloud."""

    # Class-level cache so the embedding model is loaded only once per process
    _embedding_model = None
    _embedding_model_name = None
    _embedding_model_lock = threading.Lock()

    def __init__(self, collection_name: str = "documents", persist_directory: Optional[str] = None):
        """Initialize the Qdrant client (persist_directory is ignored for Cloud)."""
        qdrant_url = os.getenv("QDRANT_URL")
        qdrant_api_key = os.getenv("QDRANT_API_KEY")
        if not qdrant_url or not qdrant_api_key:
            raise ValueError("QDRANT_URL and QDRANT_API_KEY must be set in environment variables.")
        self.client = QdrantClient(
            url=qdrant_url,
            api_key=qdrant_api_key,
            timeout=60.0,
        )
        self.collection_name = collection_name
        self.vector_size = 384  # Dimension of standard sentence-transformers models (e.g. all-MiniLM-L6-v2)
        # Ensure the collection exists before any reads or writes
        self._ensure_collection()
        # Load (or reuse) the shared embedding model
        self.embedding_model = self._get_or_create_embedding_model()

    def _ensure_collection(self):
        """Create the collection in Qdrant if it doesn't exist."""
        try:
            collections = self.client.get_collections().collections
            exists = any(c.name == self.collection_name for c in collections)
            if not exists:
                self.client.create_collection(
                    collection_name=self.collection_name,
                    vectors_config=models.VectorParams(
                        size=self.vector_size,
                        distance=models.Distance.COSINE,
                    ),
                )
        except Exception as e:
            print(f"Error checking/creating collection: {e}")

    @classmethod
    def _get_or_create_embedding_model(cls):
        """Return the shared SentenceTransformer model, loading it on first use."""
        with cls._embedding_model_lock:
            # EMBEDDING_MODEL can be set in the config, defaulting to MiniLM
            model_name = os.getenv("EMBEDDING_MODEL", "all-MiniLM-L6-v2")
            if cls._embedding_model is None or cls._embedding_model_name != model_name:
                import torch
                device = 'cuda' if torch.cuda.is_available() else 'cpu'
                print(f"Loading embedding model on {device}...")
                cls._embedding_model = SentenceTransformer(model_name, device=device)
                cls._embedding_model_name = model_name
            return cls._embedding_model

    def _string_to_uuid(self, string_id: str) -> str:
        """Qdrant point IDs must be unsigned integers or UUIDs, so hash custom string IDs into deterministic UUIDs."""
        return str(uuid.UUID(hashlib.md5(string_id.encode()).hexdigest()))

    def add_documents(self, texts: List[str], metadatas: List[Dict], ids: List[str]):
        """Embed the given texts and upload them to the collection."""
        if not texts:
            return
        embeddings = self.embedding_model.encode(texts, show_progress_bar=False, batch_size=64).tolist()
        points = []
        for i in range(len(texts)):
            # Copy the metadata so the caller's dict isn't mutated
            payload = dict(metadatas[i]) if metadatas[i] else {}
            payload['text'] = texts[i]  # Store the actual text in the payload for retrieval
            points.append(models.PointStruct(
                id=self._string_to_uuid(ids[i]),
                vector=embeddings[i],
                payload=payload,
            ))
        # upload_points batches natively, unlike a single upsert() call
        self.client.upload_points(
            collection_name=self.collection_name,
            points=points,
            batch_size=100,  # Qdrant automatically cuts the payload into chunks of 100
            wait=True,       # Ensure the upload finishes before returning to the caller
        )

    def query(self, query_text: str, n_results: int = 5, filter_dict: Optional[Dict] = None) -> Dict:
        """Search the collection and return results in ChromaDB-style format."""
        # Short-circuit if the collection is empty
        count = self.get_collection_count()
        if count == 0:
            return {"documents": [[]], "metadatas": [[]], "distances": [[]], "ids": [[]]}
        query_embedding = self.embedding_model.encode([query_text])[0].tolist()
        # Build a Qdrant filter if provided
        qdrant_filter = None
        if filter_dict:
            conditions = [
                models.FieldCondition(key=k, match=models.MatchValue(value=v))
                for k, v in filter_dict.items()
            ]
            qdrant_filter = models.Filter(must=conditions)
        search_result = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            query_filter=qdrant_filter,
            limit=n_results,
        )
        # Format the output to match what HybridRetriever expects (ChromaDB style)
        docs, metas, scores, ids = [], [], [], []
        for hit in search_result:
            docs.append(hit.payload.get('text', ''))
            # Remove the text field from the metadata so it mimics Chroma
            meta = {k: v for k, v in hit.payload.items() if k != 'text'}
            metas.append(meta)
            scores.append(hit.score)
            ids.append(str(hit.id))
        return {
            "documents": [docs],
            "metadatas": [metas],
            # Note: Qdrant returns cosine similarity (higher is better); Chroma returns distance (lower is better)
            "distances": [scores],
            "ids": [ids],
        }

    def get_collection_count(self) -> int:
        """Return the number of points in the collection (0 if unreachable)."""
        try:
            return self.client.count(collection_name=self.collection_name).count
        except Exception:
            return 0
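

# --- Usage sketch (illustrative only) ---
# A minimal example of how this class might be exercised, assuming
# QDRANT_URL and QDRANT_API_KEY are set in the environment. The sample
# texts, metadata, and IDs below are hypothetical.
if __name__ == "__main__":
    db = VectorDatabase(collection_name="documents")
    db.add_documents(
        texts=["Qdrant is a vector database.", "MiniLM embeddings have 384 dimensions."],
        metadatas=[{"source": "notes"}, {"source": "notes"}],
        ids=["doc-1", "doc-2"],
    )
    results = db.query("What is Qdrant?", n_results=2)
    for doc, score in zip(results["documents"][0], results["distances"][0]):
        print(f"{score:.3f}  {doc}")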