# Chroma-backed dense vector database wrapper.
import shutil | |
from pathlib import Path | |
from typing import List, Optional, Tuple | |
import tqdm | |
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings | |
from langchain_community.vectorstores import Chroma | |
from loguru import logger | |
from app.config.models.configs import Config | |
from app.parsers.splitter import Document | |
from app.utils import torch_device | |
class ChromaDenseVectorDB:
    """Dense vector store backed by Chroma, embedding documents with a
    SentenceTransformer model.

    Embeddings are persisted to ``persist_folder``; the underlying Chroma
    collection and its retriever are created lazily on first access.
    """

    def __init__(self, persist_folder: str, config: Config):
        """Initialize the store.

        Args:
            persist_folder: Directory where the Chroma collection is persisted.
            config: Application config; the embedding model name is read from
                ``config.embeddings.embedding_model.model_name``.
        """
        self._persist_folder = persist_folder
        self._config = config
        logger.info(f"Embedding model config: {config}")
        self._embeddings = SentenceTransformerEmbeddings(
            model_name=config.embeddings.embedding_model.model_name,
            model_kwargs={"device": torch_device()},
        )
        # Number of documents embedded per Chroma write; keeps memory bounded.
        self.batch_size = 200
        self._retriever = None
        self._vectordb = None

    @property
    def retriever(self):
        # NOTE(review): restored @property — internal call sites
        # (get_documents_by_id, similarity_search_with_relevance_scores)
        # access `self.retriever.vectorstore`, i.e. attribute-style.
        if self._retriever is None:
            self._retriever = self._load_retriever()
        return self._retriever

    @property
    def vectordb(self):
        # Lazily open the persisted Chroma collection on first access.
        # Restored @property: `_load_retriever` uses `self.vectordb.as_retriever`.
        if self._vectordb is None:
            self._vectordb = Chroma(
                persist_directory=self._persist_folder,
                embedding_function=self._embeddings,
            )
        return self._vectordb

    def generate_embeddings(
        self,
        docs: List[Document],
        clear_persist_folder: bool = True,
    ):
        """Embed ``docs`` in batches and persist them to the Chroma store.

        Args:
            docs: Documents to embed; each must carry a ``document_id`` in
                its metadata, used as the Chroma record id.
            clear_persist_folder: When True (default), wipe the persist
                directory before generating, replacing any existing index.
        """
        if clear_persist_folder:
            pf = Path(self._persist_folder)
            if pf.exists() and pf.is_dir():
                logger.warning(f"Deleting the content of: {pf}")
                shutil.rmtree(pf)

        logger.info("Generating and persisting the embeddings..")

        vectordb = None
        for group in tqdm.tqdm(
            chunker(docs, size=self.batch_size),
            # Ceiling division so the progress bar total covers the final
            # partial batch (plain int() truncated it).
            total=(len(docs) + self.batch_size - 1) // self.batch_size,
        ):
            ids = [d.metadata["document_id"] for d in group]
            if vectordb is None:
                # First batch creates the collection on disk.
                vectordb = Chroma.from_documents(
                    documents=group,
                    embedding=self._embeddings,
                    ids=ids,
                    persist_directory=self._persist_folder,
                )
            else:
                # Subsequent batches append to the existing collection.
                vectordb.add_texts(
                    texts=[doc.page_content for doc in group],
                    embedding=self._embeddings,
                    ids=ids,
                    metadatas=[doc.metadata for doc in group],
                )
        logger.info("Generated embeddings. Persisting...")
        if vectordb is not None:
            vectordb.persist()

    def _load_retriever(self, **kwargs):
        # Thin wrapper so retriever construction options stay in one place.
        return self.vectordb.as_retriever(**kwargs)

    def get_documents_by_id(self, document_ids: List[str]) -> List[Document]:
        """Fetch documents from the store by their Chroma record ids.

        Args:
            document_ids: Ids to look up (the ``document_id`` metadata values
                used at indexing time).

        Returns:
            Matching documents reconstructed from the stored text and metadata.
        """
        results = self.retriever.vectorstore.get(ids=document_ids, include=["metadatas", "documents"])  # type: ignore
        docs = [
            Document(page_content=d, metadata=m)
            for d, m in zip(results["documents"], results["metadatas"])
        ]
        return docs

    def similarity_search_with_relevance_scores(
        self, query: str, filter: Optional[dict]
    ) -> List[Tuple[Document, float]]:
        """Run a similarity search, returning (document, score) pairs.

        Args:
            query: Free-text query to embed and search with.
            filter: Optional metadata filter. A multi-key dict is rewritten
                into Chroma's explicit ``$and``-of-``$eq`` form, since Chroma
                rejects plain dicts with more than one key.

        Returns:
            Up to ``config.semantic_search.max_k`` (document, relevance score)
            tuples.
        """
        if isinstance(filter, dict) and len(filter) > 1:
            filter = {"$and": [{key: {"$eq": value}} for key, value in filter.items()]}
        # Debug print replaced with proper logging (file uses loguru throughout).
        logger.debug(f"Filter = {filter}")
        return self.retriever.vectorstore.similarity_search_with_relevance_scores(
            query, k=self._config.semantic_search.max_k, filter=filter
        )
def chunker(seq, size):
    """Yield successive slices of *seq*, each at most *size* items long.

    The final slice is shorter when ``len(seq)`` is not a multiple of *size*.
    """
    for start in range(0, len(seq), size):
        yield seq[start:start + size]