# src/vector_store_manager/chroma_manager.py
import logging
from typing import Any, Dict, List, Optional

from langchain.schema import Document  # cite: embed_pipeline.py
from langchain_chroma import Chroma  # cite: embed_pipeline.py, query_pipeline.py

from config.settings import CHROMADB_COLLECTION_NAME, PERSIST_DIR  # cite: embed_pipeline.py, query_pipeline.py
from src.embedding_generator.embedder import EmbeddingGenerator

logger = logging.getLogger(__name__)
class ChromaManager:
    """Manages interactions with the ChromaDB vector store.

    Wraps a Langchain ``Chroma`` instance configured with the project's
    persistence directory, collection name, and embedding function, and
    exposes add/update/delete/get/retriever operations on the collection.
    """

    def __init__(self, embedding_generator: EmbeddingGenerator):
        """Initialize the ChromaDB collection.

        Args:
            embedding_generator: Supplies the Langchain embedder instance
                used by Chroma to embed documents and queries.

        Raises:
            Exception: Propagated unchanged if the ChromaDB client cannot
                be initialized.
        """
        self.embedding_generator = embedding_generator
        # --- Financial Ministry Adaptation ---
        # TODO: Configure Chroma client to use a scalable backend (e.g., ClickHouse)
        # instead of or in addition to persist_directory for production.
        # This might involve using chromadb.HttpClient or specific backend configurations.
        # Handle connection errors and retries to the database backend.
        # Implement authentication/authorization for ChromaDB access.
        # ------------------------------------
        try:
            # For production, persist_directory may be replaced with client
            # settings pointing at a remote/scalable backend.
            self.vectordb = Chroma(
                persist_directory=PERSIST_DIR,  # cite: embed_pipeline.py, query_pipeline.py
                collection_name=CHROMADB_COLLECTION_NAME,  # cite: embed_pipeline.py, query_pipeline.py
                embedding_function=self.embedding_generator.embedder,  # Langchain embedder instance
            )
            # Lazy %-style args avoid building the message unless emitted.
            logger.info(
                "Initialized ChromaDB collection: '%s' at '%s'",
                CHROMADB_COLLECTION_NAME,
                PERSIST_DIR,
            )
            # TODO(review): verify the collection exists and is healthy here.
        except Exception:
            logger.critical("Failed to initialize ChromaDB", exc_info=True)
            raise  # bare raise preserves the original traceback (was `raise e`)

    def add_documents(self, chunks: List[Document]) -> None:
        """Add document chunks to the ChromaDB collection.

        Embedding happens inside Langchain's ``add_documents`` using the
        embedding function supplied at initialization. Chroma generates
        document IDs when none are provided.

        Args:
            chunks: Langchain Document chunks (with metadata) to add.

        Raises:
            Exception: Propagated if the underlying add fails, so callers
                can retry or surface the error.
        """
        # --- Financial Ministry Adaptation ---
        # TODO: retry logic / transactional behavior for large batches.
        # TODO: manage stable document IDs (e.g., source + chunk index or a
        # stable hash) so documents can be updated/deleted deterministically,
        # e.g.:
        #   chunk_ids = [f"{c.metadata.get('source', 'unknown')}_{i}"
        #                for i, c in enumerate(chunks)]
        #   self.vectordb.add_documents(chunks, ids=chunk_ids)
        # ------------------------------------
        if not chunks:
            logger.warning("No chunks to add to ChromaDB.")
            return
        try:
            self.vectordb.add_documents(chunks)  # Langchain handles IDs if not provided
            logger.info("Added %d chunks to ChromaDB.", len(chunks))
        except Exception:
            logger.error("Failed to add documents to ChromaDB", exc_info=True)
            # Fix: previously this error was swallowed silently, which could
            # hide data loss from callers; now it propagates like the other
            # mutating methods.
            raise

    def update_documents(
        self,
        ids: List[str],
        documents: List[str],
        metadatas: List[Dict[str, Any]],
    ) -> None:
        """Update documents in the ChromaDB collection by ID.

        Args:
            ids: Document IDs to update.
            documents: New document content, parallel to ``ids``.
            metadatas: New metadata dictionaries, parallel to ``ids``.

        Raises:
            Exception: Propagated if the update fails.
        """
        # TODO(review): validate that the IDs exist before updating.
        try:
            # langchain_chroma exposes no public update API, so we go through
            # the underlying chromadb collection object.
            self.vectordb._collection.update(
                ids=ids,
                documents=documents,
                metadatas=metadatas,
            )
            logger.info("Updated documents with IDs: %s", ids)
        except Exception:
            logger.error("Failed to update documents with IDs %s", ids, exc_info=True)
            raise

    def delete_documents(
        self,
        ids: Optional[List[str]] = None,
        where: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Delete documents from the collection by ID or metadata filter.

        Exactly one of ``ids`` or ``where`` should be supplied; ``ids`` takes
        precedence when both are given. Calling with neither only logs a
        warning (no-op), matching the original behavior.

        Args:
            ids: Document IDs to delete.
            where: Metadata filter (e.g., ``{"source": "old_file.txt"}``).

        Raises:
            Exception: Propagated if the delete fails.
        """
        try:
            if ids:
                # Underlying chromadb collection: no public delete-by-id on
                # the Langchain wrapper for this call pattern.
                self.vectordb._collection.delete(ids=ids)
                logger.info("Deleted documents with IDs: %s", ids)
            elif where:
                self.vectordb._collection.delete(where=where)
                logger.info("Deleted documents matching metadata filter: %s", where)
            else:
                logger.warning("Delete called without specifying ids or where filter.")
        except Exception:
            logger.error(
                "Failed to delete documents (ids: %s, where: %s)", ids, where, exc_info=True
            )
            raise

    def get_documents(
        self,
        ids: Optional[List[str]] = None,
        where: Optional[Dict[str, Any]] = None,
        where_document: Optional[Dict[str, Any]] = None,
        limit: Optional[int] = None,
        offset: Optional[int] = None,
        include: Optional[List[str]] = None,
    ) -> Dict[str, List[Any]]:
        """Retrieve documents and their details from the collection.

        Args:
            ids: Document IDs to retrieve.
            where: Metadata filter.
            where_document: Document content filter.
            limit: Maximum number of results.
            offset: Offset for pagination.
            include: Fields to include (e.g., ``['metadatas', 'documents']``);
                IDs are always included. Defaults to metadatas + documents.

        Returns:
            A dictionary containing the retrieved data (ids, documents,
            metadatas, ...), as returned by chromadb's ``Collection.get``.

        Raises:
            Exception: Propagated if retrieval fails.
        """
        # NOTE(review): sensitive metadata may be returned here — ensure
        # callers handle it appropriately.
        try:
            if include is None:
                include = ['metadatas', 'documents']  # default per Chroma docs
            results = self.vectordb._collection.get(  # underlying chromadb collection
                ids=ids,
                where=where,
                where_document=where_document,
                limit=limit,
                offset=offset,
                include=include,
            )
            logger.debug("Retrieved %d documents from ChromaDB.", len(results.get('ids', [])))
            return results
        except Exception:
            logger.error("Failed to retrieve documents from ChromaDB", exc_info=True)
            raise

    def as_retriever(self, search_kwargs: Optional[Dict[str, Any]] = None):
        """Return a Langchain Retriever for the Chroma collection.

        Args:
            search_kwargs: Retriever arguments (e.g., ``{"k": 5}``); defaults
                to an empty dict.

        Returns:
            A Langchain Retriever bound to this collection. It automatically
            uses the embedding function provided at Chroma initialization.
        """
        if search_kwargs is None:
            search_kwargs = {}
        return self.vectordb.as_retriever(search_kwargs=search_kwargs)  # cite: query_pipeline.py