# app/chroma.py
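"""Chroma-backed dense vector store: builds, persists and queries document embeddings."""
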
import shutil
from pathlib import Path
from typing import List, Optional, Tuple

import tqdm
from langchain_community.embeddings.sentence_transformer import (
    SentenceTransformerEmbeddings,
)
from langchain_community.vectorstores import Chroma
from loguru import logger

from app.config.models.configs import Config
from app.parsers.splitter import Document
from app.utils import torch_device


class ChromaDenseVectorDB:
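    """Dense vector store backed by Chroma, using SentenceTransformer embeddings."""
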
def __init__(self, persist_folder: str, config: Config):
self._persist_folder = persist_folder
self._config = config
logger.info(f"Embedding model config: {config}")
        self._embeddings = SentenceTransformerEmbeddings(
            model_name=config.embeddings.embedding_model.model_name,
            model_kwargs={"device": torch_device()},
        )
        self.batch_size = 200  # number of documents embedded per insert batch
self._retriever = None
self._vectordb = None

    @property
def retriever(self):
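        """Lazily construct and cache a retriever over the vector store."""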
if self._retriever is None:
self._retriever = self._load_retriever()
return self._retriever

    @property
def vectordb(self):
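        """Lazily open the Chroma collection from the persist folder."""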
if self._vectordb is None:
self._vectordb = Chroma(
persist_directory=self._persist_folder,
embedding_function=self._embeddings,
)
return self._vectordb

    def generate_embeddings(
self,
docs: List[Document],
clear_persist_folder: bool = True,
):
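        """Embed the documents in batches and persist them to the Chroma folder."""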
if clear_persist_folder:
pf = Path(self._persist_folder)
if pf.exists() and pf.is_dir():
logger.warning(f"Deleting the content of: {pf}")
shutil.rmtree(pf)
logger.info("Generating and persisting the embeddings..")
vectordb = None
for group in tqdm.tqdm(
chunker(docs, size=self.batch_size),
total=int(len(docs) / self.batch_size),
):
ids = [d.metadata["document_id"] for d in group]
if vectordb is None:
vectordb = Chroma.from_documents(
documents=group,
embedding=self._embeddings,
ids=ids,
persist_directory=self._persist_folder,
)
            else:
                # The embedding function is already bound to the collection created
                # above, so only the texts, ids and metadata need to be added here.
                vectordb.add_texts(
                    texts=[doc.page_content for doc in group],
                    ids=ids,
                    metadatas=[doc.metadata for doc in group],
                )
logger.info("Generated embeddings. Persisting...")
if vectordb is not None:
vectordb.persist()

    def _load_retriever(self, **kwargs):
return self.vectordb.as_retriever(**kwargs)

    def get_documents_by_id(self, document_ids: List[str]) -> List[Document]:
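        """Fetch documents from the store by their Chroma ids."""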
        results = self.retriever.vectorstore.get(
            ids=document_ids, include=["metadatas", "documents"]
        )  # type: ignore
docs = [
Document(page_content=d, metadata=m)
for d, m in zip(results["documents"], results["metadatas"])
]
return docs

    def similarity_search_with_relevance_scores(
self, query: str, filter: Optional[dict]
) -> List[Tuple[Document, float]]:
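        """Relevance-scored similarity search, optionally restricted by a metadata filter."""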
        # Chroma accepts only a single top-level key in a `where` filter, so when
        # several metadata fields are given they are combined explicitly with $and.
        if isinstance(filter, dict) and len(filter) > 1:
            filter = {"$and": [{key: {"$eq": value}} for key, value in filter.items()]}
        logger.debug(f"Filter = {filter}")
return self.retriever.vectorstore.similarity_search_with_relevance_scores(
query, k=self._config.semantic_search.max_k, filter=filter
)


def chunker(seq, size):
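    """Yield successive ``size``-sized slices of ``seq``."""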
return (seq[pos: pos + size] for pos in range(0, len(seq), size))
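

# Example usage (a minimal sketch, not part of the module): it assumes a `Config`
# instance has already been loaded by the application's config machinery and that
# `docs` is a List[Document] produced by app.parsers.splitter; the persist path and
# query string below are placeholders.
#
#   db = ChromaDenseVectorDB(persist_folder="/tmp/chroma_index", config=config)
#   db.generate_embeddings(docs)
#   hits = db.similarity_search_with_relevance_scores("example query", filter=None)
#   for doc, score in hits:
#       print(score, doc.metadata.get("document_id"))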