SwatGarg committed on
Commit
63b46ac
1 Parent(s): eb5e96e

Create retriever.py

Browse files
Files changed (1) hide show
  1. retriever.py +120 -0
retriever.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.embeddings import HuggingFaceEmbeddings
2
+ from langchain.retrievers import ParentDocumentRetriever
3
+ from langchain.storage import InMemoryStore
4
+ from langchain_community.vectorstores import Chroma
5
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
6
+ from langchain_community.document_loaders import PyMuPDFLoader
7
+ import os
8
+
9
+ # Function to create embeddings
10
+ # def create_embeddings(text_chunks):
11
+ # embeddings = embeddings_model.encode(text_chunks, show_progress_bar=True)
12
+ # return embeddings
13
+
# Resolve the on-disk location of the persisted Chroma database:
# two directory levels above the current working directory, then
# src/vector_db/chroma_db_v2.
# NOTE(review): this depends on os.getcwd() — i.e. where the process is
# launched from — not on this file's location. Confirm callers always run
# from the expected directory; os.path.dirname(__file__) may be safer.
curr_dir = os.getcwd()
db_path = os.path.join(os.path.dirname(os.path.dirname(curr_dir)), 'src','vector_db','chroma_db_v2')
16
+
def process_pdf_document(file_path, parent_chunk_size=2000, child_chunk_size=500):
    '''
    Load a PDF document and build the text splitters for parent/child chunking.

    Args:
        file_path (str): The path to the PDF document
        parent_chunk_size (int): The size of the parent chunks
        child_chunk_size (int): The size of the child chunks

    Returns:
        documents (list): The list of documents loaded from the PDF
        parent_splitter (RecursiveCharacterTextSplitter): The text splitter for the parent documents
        child_splitter (RecursiveCharacterTextSplitter): The text splitter for the child documents
    '''
    # Read every page of the PDF into Document objects.
    documents = PyMuPDFLoader(file_path).load()

    # One splitter per granularity level: large parent chunks and the
    # smaller child chunks that get embedded/indexed.
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=parent_chunk_size)
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=child_chunk_size)

    return documents, parent_splitter, child_splitter
41
+
42
+
43
+ # Function to create the vectorstore
44
+ def create_vectorstore(embeddings_model="all-MiniLM-L6-v2"):
45
+ '''
46
+ Create the vectorstore and store for the documents
47
+
48
+ Args:
49
+ embeddings_model (HuggingFaceEmbeddings): The embeddings model
50
+ documents (list): The list of documents
51
+
52
+ Returns:
53
+ vectorstore (Chroma): The vectorstore
54
+ store (InMemoryStore): The store
55
+
56
+ '''
57
+
58
+ # Initialize the embedding model
59
+ embeddings_model = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
60
+
61
+ # This text splitter is used to create the parent documents
62
+ parent_splitter = RecursiveCharacterTextSplitter(chunk_size=2000)
63
+
64
+ # This text splitter is used to create the child documents
65
+ # It should create documents smaller than the parent
66
+ child_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
67
+
68
+ # The vectorstore to use to index the child chunks
69
+ vectorstore = Chroma(
70
+ collection_name="split_parents", embedding_function=embeddings_model
71
+ )
72
+ vectordb = Chroma(persist_directory=db_path,
73
+ embedding_function=embeddings_model)
74
+ # The storage layer for the parent documents
75
+ store = InMemoryStore()
76
+
77
+ return vectordb, store
78
+
79
+
80
+
81
+ def rag_retriever(vectorstore):
82
+ '''
83
+ Create the retriever for the RAG model
84
+
85
+ Args:
86
+ vectorstore (Chroma): The vectorstore
87
+ store (InMemoryStore): The store
88
+ parent_splitter (RecursiveCharacterTextSplitter): The text splitter for the parent documents
89
+ child_splitter (RecursiveCharacterTextSplitter): The text splitter for the child documents
90
+
91
+ Returns:
92
+ retriever (ParentDocumentRetriever): The retriever
93
+
94
+ '''
95
+
96
+ # retriever = ParentDocumentRetriever(
97
+ # vectorstore=vectorstore,
98
+ # docstore=store,
99
+ # child_splitter=None,
100
+ # parent_splitter=None,
101
+ # docs=documents
102
+ # )
103
+
104
+ # retriever.add_documents(documents)
105
+ retriever = vectorstore.as_retriever()
106
+
107
+ return retriever
108
+
109
+
110
+
111
+
112
+ # def retrieve_context(query, top_k):
113
+
114
+ # # Retrieve the top k similar documents
115
+ # sub_docs = vectorstore.similarity_search(query, k=top_k, return_documents=True)
116
+
117
+ # # Get the context of the first document
118
+ # context = sub_docs[0].page_content
119
+
120
+ # return context