SwatGarg committed on
Commit
1346cad
1 Parent(s): 5553e22

Create retrieverV2.py

Browse files
Files changed (1) hide show
  1. retrieverV2.py +112 -0
retrieverV2.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_community.embeddings import HuggingFaceEmbeddings
2
+ from langchain.retrievers import ParentDocumentRetriever
3
+ from langchain.storage import InMemoryStore
4
+ from langchain_community.vectorstores import Chroma
5
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
6
+ from langchain_community.document_loaders import PyMuPDFLoader
7
+ import os
8
+
9
+ # Function to create embeddings
10
+ # def create_embeddings(text_chunks):
11
+ # embeddings = embeddings_model.encode(text_chunks, show_progress_bar=True)
12
+ # return embeddings
13
+
14
# Working directory of the process at import time; the vector DB location
# below is derived from it, so it changes with where the program is launched.
curr_dir = os.getcwd()

# <two levels above the working directory>/src/vector_db/chroma_db
db_path = os.path.join(
    os.path.dirname(os.path.dirname(curr_dir)),
    'src',
    'vector_db',
    'chroma_db',
)
16
+
17
def process_pdf_document(file_path_list):
    '''
    Load one or more PDF files and return their contents as one flat list.

    Args:
        file_path_list (iterable of str): Paths of the PDF documents to load.
            A single path passed as a ``str`` or ``os.PathLike`` is treated
            as a one-element list rather than iterated character by character.

    Returns:
        documents (list): The items returned by each loader's ``load()``
            call, concatenated in the order of ``file_path_list``. Empty
            when no paths are given.

    '''
    # A bare string is itself iterable; without this guard every character
    # would be handed to the loader as if it were a separate file path.
    if isinstance(file_path_list, (str, os.PathLike)):
        file_path_list = [file_path_list]

    # Build every loader before loading anything, so an error raised while
    # constructing a loader surfaces before any file has been parsed.
    loaders = [PyMuPDFLoader(file_path) for file_path in file_path_list]

    documents = []
    for loader in loaders:
        documents.extend(loader.load())

    return documents
43
+
44
+
45
+ # Function to create the vectorstore
46
def create_vectorstore(embeddings_model="all-MiniLM-L6-v2"):
    '''
    Create the Chroma vectorstore and the in-memory store for parent documents.

    Args:
        embeddings_model (str or embeddings object): Name of the HuggingFace
            model to embed with (defaults to "all-MiniLM-L6-v2"). Any
            non-string value is assumed to be a ready-made embeddings object
            and is passed to Chroma unchanged.

    Returns:
        vectordb (Chroma): The vectorstore, persisted under the module-level
            ``db_path``
        store (InMemoryStore): A new, empty store for the parent documents

    '''
    # Honour the caller's choice: previously the argument was overwritten by a
    # hard-coded "all-MiniLM-L6-v2" model, so passing another name had no effect.
    if isinstance(embeddings_model, str):
        embeddings_model = HuggingFaceEmbeddings(model_name=embeddings_model)

    # NOTE(review): an existing collection at db_path presumably has to have
    # been built with the same embedding model - confirm before switching models.
    vectordb = Chroma(persist_directory=db_path,
                      embedding_function=embeddings_model)

    # The storage layer for the parent documents
    store = InMemoryStore()

    return vectordb, store
81
+
82
+
83
+
84
def rag_retriever(vectorstore, store, documents, parent_splitter, child_splitter):
    '''
    Build a ParentDocumentRetriever and add the given documents to it.

    Args:
        vectorstore (Chroma): The vectorstore handed to the retriever
        store (InMemoryStore): The docstore handed to the retriever
        documents (list): The documents passed to ``add_documents``
        parent_splitter (RecursiveCharacterTextSplitter): The text splitter for the parent documents
        child_splitter (RecursiveCharacterTextSplitter): The text splitter for the child documents

    Returns:
        retriever (ParentDocumentRetriever): The retriever, after
            ``add_documents(documents)`` has been called on it

    '''
    # Collect the constructor arguments by name before building the retriever.
    retriever_kwargs = {
        "vectorstore": vectorstore,
        "docstore": store,
        "child_splitter": child_splitter,
        "parent_splitter": parent_splitter,
    }
    parent_doc_retriever = ParentDocumentRetriever(**retriever_kwargs)

    # Add the documents before handing the retriever back to the caller.
    parent_doc_retriever.add_documents(documents)

    return parent_doc_retriever
111
+
112
+