"""
ChromaDB vector storage interface.

This module provides a clean interface to ChromaDB for storing and retrieving
document chunks with their embeddings and metadata.
"""
|
|
| import chromadb |
| from typing import List, Optional |
| import numpy as np |
| import json |
| from datetime import datetime |
| from src.config.settings import get_settings, get_collection_name_for_model, EMBEDDING_MODELS |
| from src.utils.logging import get_logger |
| from src.ingestion.models import Chunk |
|
|
| logger = get_logger(__name__) |
|
|
|
|
class VectorStore:
    """ChromaDB interface for vector storage.

    Each embedding model gets its own collection; the collection name is
    derived from the configured base name and the model ID via
    ``get_collection_name_for_model``. The ChromaDB client and the collection
    handle are both created lazily on first use.
    """

    def __init__(self, embedding_model: Optional[str] = None):
        """
        Initialize vector store with settings from configuration.

        Args:
            embedding_model: Optional embedding model ID. If provided, uses
                the model-specific collection; otherwise the model from
                settings is used.
        """
        settings = get_settings()
        self.persist_dir = settings.chroma_persist_dir
        self._base_collection_name = settings.chroma_collection_name
        self._embedding_model = embedding_model or settings.embedding_model

        self.collection_name = get_collection_name_for_model(
            self._embedding_model,
            self._base_collection_name,
        )

        # Both handles are created on demand (see `client` / `get_collection`).
        self._client = None
        self._collection = None

    @property
    def client(self):
        """
        Lazily initialize the ChromaDB client.

        Returns:
            The persistent ChromaDB client rooted at ``self.persist_dir``.
        """
        if self._client is None:
            logger.info(f"Initializing ChromaDB client: {self.persist_dir}")
            self._client = chromadb.PersistentClient(path=self.persist_dir)
            logger.debug("ChromaDB client initialized")
        return self._client

    def get_collection(self):
        """
        Get or create the active collection.

        The handle is cached on the instance until `switch_collection`
        resets it.

        Returns:
            The ChromaDB collection named ``self.collection_name``.
        """
        if self._collection is None:
            self._collection = self.client.get_or_create_collection(
                name=self.collection_name,
                metadata={"description": "Hierarchical PDF chunks with embeddings"},
            )
            logger.info(f"Collection loaded: {self.collection_name}")
        return self._collection

    def add_chunks(self, chunks: List["Chunk"], embeddings: np.ndarray):
        """
        Add chunks with embeddings to ChromaDB.

        An empty ``chunks`` list is a no-op: nothing is sent to the backend.

        Args:
            chunks: List of chunks to store.
            embeddings: Array-like of embeddings (num_chunks x embedding_dim).

        Raises:
            ValueError: If the number of chunks and embeddings differ.
        """
        if len(chunks) != len(embeddings):
            raise ValueError(f"Number of chunks ({len(chunks)}) != number of embeddings ({len(embeddings)})")

        # Nothing to store; skip the backend call rather than submit an
        # empty batch (which ChromaDB is expected to reject).
        if not chunks:
            return

        collection = self.get_collection()

        ids = [str(chunk.chunk_id) for chunk in chunks]
        documents = [chunk.text for chunk in chunks]
        metadatas = [self._prepare_metadata(chunk) for chunk in chunks]

        logger.info(f"Adding {len(chunks)} chunks to ChromaDB")

        collection.add(
            ids=ids,
            # np.asarray lets callers pass a plain list of vectors as well.
            embeddings=np.asarray(embeddings).tolist(),
            documents=documents,
            metadatas=metadatas,
        )

        logger.info(f"Successfully added {len(chunks)} chunks")

    def _prepare_metadata(self, chunk: "Chunk") -> dict:
        """
        Prepare metadata for ChromaDB storage.

        ChromaDB metadata can only contain: str, int, float, bool.
        Lists must be JSON-encoded, so ``page_numbers`` is stored as a JSON
        string. A falsy ``parent_id`` is stored as an empty string.

        Args:
            chunk: Chunk to extract metadata from.

        Returns:
            dict: Flat metadata dictionary for the chunk.
        """
        return {
            "chunk_id": str(chunk.chunk_id),
            "document_id": str(chunk.document_id),
            "parent_id": str(chunk.parent_id) if chunk.parent_id else "",
            "chunk_type": chunk.chunk_type,
            "token_count": chunk.token_count,
            "chunk_index": chunk.chunk_index,
            "page_numbers": json.dumps(chunk.page_numbers),
            "start_char": chunk.start_char,
            "end_char": chunk.end_char,
            "file_hash": chunk.file_hash,
            "filename": chunk.filename,
        }

    def document_exists(self, file_hash: str) -> bool:
        """
        Check if a document with the given hash already exists.

        This is a best-effort check: if the lookup itself fails, the failure
        is logged and ``False`` is returned.

        Args:
            file_hash: Hash of the document, matched against the
                ``file_hash`` metadata field.

        Returns:
            bool: True if at least one stored chunk carries this hash.
        """
        collection = self.get_collection()

        try:
            results = collection.get(
                where={"file_hash": file_hash},
                limit=1,
            )
        except Exception as e:
            # Logged at warning level because a failed lookup is otherwise
            # indistinguishable from "document not indexed".
            logger.warning(f"Document check failed: {e}")
            return False

        exists = len(results['ids']) > 0
        if exists:
            logger.debug(f"Document with hash {file_hash[:8]}... already exists")
        return exists

    def get_chunk(self, chunk_id: str) -> Optional[dict]:
        """
        Retrieve a specific chunk by ID.

        Args:
            chunk_id: ID of the chunk to retrieve.

        Returns:
            Optional[dict]: Dict with ``id``, ``document``, ``metadata`` and
            ``embedding`` keys, or None if the chunk is not found or the
            lookup fails.
        """
        collection = self.get_collection()

        try:
            results = collection.get(
                ids=[chunk_id],
                include=["documents", "metadatas", "embeddings"],
            )

            if len(results['ids']) == 0:
                return None

            # The embeddings field may be a NumPy array (depending on the
            # ChromaDB version). Testing an array for truthiness raises
            # ValueError, so check for None and length explicitly.
            embeddings = results.get('embeddings')
            has_embedding = embeddings is not None and len(embeddings) > 0

            return {
                "id": results['ids'][0],
                "document": results['documents'][0],
                "metadata": results['metadatas'][0],
                "embedding": embeddings[0] if has_embedding else None,
            }
        except Exception as e:
            logger.error(f"Failed to retrieve chunk {chunk_id}: {e}")
            return None

    def delete_document(self, document_id: str):
        """
        Delete all chunks for a document.

        Args:
            document_id: ID of the document, matched against the
                ``document_id`` metadata field.

        Raises:
            Exception: Any backend error is logged and re-raised.
        """
        collection = self.get_collection()

        try:
            collection.delete(
                where={"document_id": document_id}
            )
            logger.info(f"Deleted all chunks for document: {document_id}")
        except Exception as e:
            logger.error(f"Failed to delete document {document_id}: {e}")
            raise

    def get_collection_stats(self) -> dict:
        """
        Get statistics about the active collection.

        Returns:
            dict: ``name``, ``total_chunks``, ``persist_dir`` and
            ``embedding_model``; an empty dict if counting fails.
        """
        collection = self.get_collection()

        try:
            count = collection.count()
            return {
                "name": self.collection_name,
                "total_chunks": count,
                "persist_dir": self.persist_dir,
                "embedding_model": self._embedding_model,
            }
        except Exception as e:
            logger.error(f"Failed to get collection stats: {e}")
            return {}

    def list_all_collections(self) -> List[dict]:
        """
        List the collection for every known embedding model with its stats.

        One entry is produced per key in ``EMBEDDING_MODELS``. A collection
        that cannot be opened or counted is reported with ``total_chunks``
        of 0 rather than being omitted.

        Returns:
            List[dict]: List of collection info dictionaries.
        """
        collections = []

        for model_id, model_config in EMBEDDING_MODELS.items():
            collection_name = get_collection_name_for_model(
                model_id,
                self._base_collection_name,
            )

            # Only the lookup/count is best-effort; the entry itself is
            # always built the same way.
            try:
                total_chunks = self.client.get_collection(name=collection_name).count()
            except Exception as e:
                logger.debug(f"Collection {collection_name} unavailable: {e}")
                total_chunks = 0

            collections.append({
                "collection_name": collection_name,
                "embedding_model": model_id,
                "model_name": model_config.get("name", model_id),
                "dimensions": model_config.get("dimensions"),
                "total_chunks": total_chunks,
                "is_active": model_id == self._embedding_model,
            })

        return collections

    def switch_collection(self, embedding_model: str):
        """
        Switch to a different collection based on embedding model.

        The cached collection handle is dropped so the next
        `get_collection` call resolves the new name.

        Args:
            embedding_model: Embedding model ID to switch to.
        """
        self._embedding_model = embedding_model
        self.collection_name = get_collection_name_for_model(
            embedding_model,
            self._base_collection_name,
        )
        self._collection = None
        logger.info(f"Switched to collection: {self.collection_name}")

    def query(
        self,
        query_embedding: np.ndarray,
        top_k: int = 10,
        filter_filenames: Optional[List[str]] = None,
    ) -> dict:
        """
        Query the collection with an embedding.

        Args:
            query_embedding: Query embedding vector.
            top_k: Number of results to return.
            filter_filenames: Optional list of filenames to restrict results
                to. None or an empty list means no filter.

        Returns:
            dict: The backend's query result (ids, documents, metadatas,
            distances). If the query fails, the error is logged and a dict
            with those four keys mapped to empty lists is returned.
        """
        collection = self.get_collection()

        try:
            where_clause = None
            if filter_filenames:
                if len(filter_filenames) == 1:
                    # Single filename: plain equality filter.
                    where_clause = {"filename": filter_filenames[0]}
                else:
                    where_clause = {"filename": {"$in": filter_filenames}}

            results = collection.query(
                query_embeddings=[query_embedding.tolist()],
                n_results=top_k,
                include=["documents", "metadatas", "distances"],
                where=where_clause,
            )
            return results
        except Exception as e:
            logger.error(f"Query failed: {e}")
            return {"ids": [], "documents": [], "metadatas": [], "distances": []}
|
|