from .base_retriever import BaseRetriever
from models import BaseModel
from langchain.document_loaders import PyMuPDFLoader, DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from chromadb import PersistentClient
import os


class ChromaRetriever(BaseRetriever):
    """Retriever backed by a persistent Chroma vector store built from a directory of PDFs.

    On construction it either loads an existing Chroma collection or builds a new
    one by loading, splitting, and (optionally) embedding every PDF in ``pdf_dir``.
    """

    def __init__(self, pdf_dir: str, collection_name: str, split_args: dict,
                 embed_model: BaseModel = None, refine_model: BaseModel = None):
        '''
        pdf_dir: directory containing pdfs for the vector database
        collection_name: name of the collection to use (if it exists it is loaded,
            otherwise it is created and populated from pdf_dir)
        split_args: dictionary of arguments for the text splitter
            ("size": size of chunks, "overlap": overlap between chunks)
        embed_model: model used to embed text chunks (if not provided, Chroma's
            default embedding function is used)
        refine_model: accepted for interface compatibility; not used here.

        example:
            from models import GPT4Model
            retriever = ChromaRetriever("papers", "pdfs",
                                        {"size": 2048, "overlap": 10},
                                        embed_model=GPT4Model())
        '''
        self.embed_model = embed_model

        # Chroma persists its index under ./persist; create it idempotently
        # (avoids the exists()/mkdir() race of the check-then-act pattern).
        os.makedirs("persist", exist_ok=True)
        client = PersistentClient(path="persist")

        try:
            collection = client.get_collection(name=collection_name)
        except Exception:
            # get_collection raises when the collection does not exist yet;
            # build it from the PDFs. (Bare `except:` would also have trapped
            # KeyboardInterrupt/SystemExit — keep the catch to Exception.)
            print("Creating new collection...")
            print("Loading pdf papers into the vector database...\n")

            pdf_loader = DirectoryLoader(pdf_dir, loader_cls=PyMuPDFLoader)
            docs = pdf_loader.load()
            text_splitter = RecursiveCharacterTextSplitter(
                chunk_size=split_args["size"],
                chunk_overlap=split_args["overlap"],
            )
            split_docs = text_splitter.split_documents(docs)

            texts = [doc.page_content for doc in split_docs]
            # Some PDFs carry no "title" metadata; fall back to "" instead of
            # raising KeyError (this was the original TODO).
            titles = [doc.metadata.get("title", "") for doc in split_docs]

            collection = client.create_collection(name=collection_name)

            # Build the add() arguments once; attach embeddings only when a
            # custom embedding model was supplied (otherwise Chroma embeds).
            add_kwargs = {
                "documents": texts,
                "ids": [str(i + 1) for i in range(len(texts))],
                "metadatas": [{"title": title} for title in titles],
            }
            if embed_model is not None:
                add_kwargs["embeddings"] = embed_model.embedding(texts)
            collection.add(**add_kwargs)

        self.collection = collection
        print("Papers Loaded.")

    def retrieve(self, query: str, k: int = 5):
        '''
        query: text string used to query the vector database
        k: number of text chunks to return

        returns: tuple (documents, titles) — the list of retrieved text chunks
            and the list of their corresponding titles

        example:
            retriever.retrieve("how do sex chromosomes in rhesus monkeys "
                               "influence proteome?", k=10)
        '''
        # Query with our own embeddings when a custom embed model was used at
        # build time, so query vectors live in the same embedding space.
        if self.embed_model is not None:
            results = self.collection.query(
                query_embeddings=self.embed_model.embedding([query]),
                n_results=k,
            )
        else:
            results = self.collection.query(
                query_texts=[query],
                n_results=k,
            )
        # Chroma returns one result list per query; we issued a single query.
        documents = results['documents'][0]
        titles = [meta["title"] for meta in results['metadatas'][0]]
        return documents, titles