|
from .base_retriever import BaseRetriever |
|
from models import BaseModel |
|
|
|
from langchain.document_loaders import PyMuPDFLoader, DirectoryLoader |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from chromadb import PersistentClient |
|
import os |
|
|
|
|
|
class ChromaRetriever(BaseRetriever):
    '''
    Retriever that serves text chunks from a persistent Chroma collection.

    The collection is stored under the local "persist" directory. If a
    collection with the requested name already exists it is reused as-is;
    otherwise it is built from the PDFs found in a directory.
    '''

    def __init__(self,
                 pdf_dir: str,
                 collection_name: str,
                 split_args: dict,
                 embed_model: BaseModel = None,
                 refine_model: BaseModel = None,):
        '''
        pdf_dir: directory containing pdfs for vector database
        collection_name: name of collection to be used (if collection exists, it will be loaded, otherwise it will be created)
        split_args: dictionary of arguments for text splitter ("size": size of chunks, "overlap": overlap between chunks)
        embed_model: model to embed text chunks (if not provided, will use chroma's default embeddings)
        refine_model: accepted for interface compatibility; not used by this class

        raises: ValueError if a new collection has to be built but no text
                chunks could be extracted from pdf_dir

        example:
            from models import GPT4Model
            pdf_dir = "papers"
            retriever = ChromaRetriever(pdf_dir, "pdfs", {"size": 2048, "overlap": 10}, embed_model=GPT4Model())
        '''
        self.embed_model = embed_model

        # exist_ok avoids the check-then-create race of os.path.exists + os.mkdir.
        os.makedirs("persist", exist_ok=True)
        client = PersistentClient(path="persist")

        # Only the lookup sits inside the try, so errors raised while building
        # a new collection are never mistaken for "collection not found".
        # The exception type raised for a missing collection has varied across
        # chromadb releases, hence Exception rather than one specific class;
        # unlike a bare except, KeyboardInterrupt and SystemExit still propagate.
        try:
            collection = client.get_collection(name=collection_name)
        except Exception:
            print("Creating new collection...")
            print("Loading pdf papers into the vector database... ")
            collection = self._build_collection(
                client, pdf_dir, collection_name, split_args, embed_model
            )

        self.collection = collection
        print("Papers Loaded.")

    @staticmethod
    def _build_collection(client, pdf_dir, collection_name, split_args, embed_model):
        '''
        Load every PDF in pdf_dir, split it into chunks and store the chunks
        in a newly created collection named collection_name.

        returns: the populated collection
        '''
        pdf_loader = DirectoryLoader(pdf_dir, loader_cls=PyMuPDFLoader)
        docs = pdf_loader.load()

        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=split_args["size"],
            chunk_overlap=split_args["overlap"],
        )
        texts = [chunk.page_content for chunk in text_splitter.split_documents(docs)]

        # Fail before the collection exists: an empty collection left behind
        # would be found and silently reused on the next run.
        if not texts:
            raise ValueError(
                f"No text chunks could be extracted from PDFs in {pdf_dir!r}; "
                f"collection {collection_name!r} was not created."
            )

        add_kwargs = {
            "documents": texts,
            # Ids are the 1-based chunk positions rendered as strings.
            "ids": [str(i) for i in range(1, len(texts) + 1)],
        }
        if embed_model is not None:
            # Without explicit embeddings chroma embeds the documents itself.
            add_kwargs["embeddings"] = embed_model.embedding(texts)

        collection = client.create_collection(name=collection_name)
        collection.add(**add_kwargs)
        return collection

    def retrieve(self, query: str, k: int = 5) -> list:
        '''
        query: text string used to query the vector database
        k: number of text chunks to return
        returns: list of retrieved text chunks

        example:
            retriever.retrieve("how do sex chromosomes in rhesus monkeys influence proteome?", k=10)
        '''
        if self.embed_model is not None:
            # Embed the query with the same model that is used when this
            # class builds the collection.
            results = self.collection.query(
                query_embeddings=self.embed_model.embedding([query]),
                n_results=k,
            )
        else:
            results = self.collection.query(
                query_texts=[query],
                n_results=k,
            )

        # Results are grouped per query; exactly one query was sent.
        return results['documents'][0]
|
|