Connor Sutton committed
Commit 09bee79 · 1 Parent(s): fbe3579

added multivector and multiquery retrievers

langchain-streamlit-demo/llm_resources.py CHANGED
@@ -16,6 +16,11 @@ from langchain.schema import Document, BaseRetriever
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.vectorstores import FAISS
 
+from langchain.retrievers.multi_query import MultiQueryRetriever
+from langchain.retrievers.multi_vector import MultiVectorRetriever
+from langchain.storage import InMemoryStore
+import uuid
+
 from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
 from qagen import get_rag_qa_gen_chain
 from summarize import get_rag_summarization_chain
@@ -153,6 +158,77 @@ def get_texts_and_retriever(
         return texts, ensemble_retriever
 
 
+def get_texts_and_retriever2(
+    uploaded_file_bytes: bytes,
+    openai_api_key: str,
+    chunk_size: int = DEFAULT_CHUNK_SIZE,
+    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
+    k: int = DEFAULT_RETRIEVER_K,
+    azure_kwargs: Optional[Dict[str, str]] = None,
+    use_azure: bool = False,
+) -> Tuple[List[Document], BaseRetriever]:
+    with NamedTemporaryFile() as temp_file:
+        temp_file.write(uploaded_file_bytes)
+        temp_file.seek(0)
+
+        loader = PyPDFLoader(temp_file.name)
+        documents = loader.load()
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=10000,
+            chunk_overlap=0,
+        )
+        child_text_splitter = RecursiveCharacterTextSplitter(chunk_size=400)
+
+        texts = text_splitter.split_documents(documents)
+        id_key = "doc_id"
+
+        text_ids = [str(uuid.uuid4()) for _ in texts]
+        sub_texts = []
+        for i, text in enumerate(texts):
+            _id = text_ids[i]
+            _sub_texts = child_text_splitter.split_documents([text])
+            for _text in _sub_texts:
+                _text.metadata[id_key] = _id
+            sub_texts.extend(_sub_texts)
+
+        embeddings_kwargs = {"openai_api_key": openai_api_key}
+        if use_azure and azure_kwargs:
+            azure_kwargs["azure_endpoint"] = azure_kwargs.pop("openai_api_base")
+            embeddings_kwargs.update(azure_kwargs)
+            embeddings = AzureOpenAIEmbeddings(**embeddings_kwargs)
+        else:
+            embeddings = OpenAIEmbeddings(**embeddings_kwargs)
+        store = InMemoryStore()
+
+        # MultiVectorRetriever
+        multivectorstore = FAISS.from_documents(sub_texts, embeddings)
+        multivector_retriever = MultiVectorRetriever(
+            vectorstore=multivectorstore,
+            base_store=store,
+            id_key=id_key,
+        )
+        multivector_retriever.docstore.mset(list(zip(text_ids, texts)))
+        # multivector_retriever.k = k
+
+        multiquery_text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=chunk_size,
+            chunk_overlap=chunk_overlap,
+        )
+        # MultiQueryRetriever
+        multiquery_texts = multiquery_text_splitter.split_documents(documents)
+        multiquerystore = FAISS.from_documents(multiquery_texts, embeddings)
+        multiquery_retriever = MultiQueryRetriever.from_llm(
+            retriever=multiquerystore.as_retriever(search_kwargs={"k": k}),
+            llm=ChatOpenAI(),
+        )
+
+        ensemble_retriever = EnsembleRetriever(
+            retrievers=[multiquery_retriever, multivector_retriever],
+            weights=[0.5, 0.5],
+        )
+        return multiquery_texts, ensemble_retriever
+
+
 class StreamHandler(BaseCallbackHandler):
     def __init__(self, container, initial_text=""):
         self.container = container
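
The new get_texts_and_retriever2 ensembles two strategies: a MultiVectorRetriever that embeds small (400-character) child chunks in FAISS but returns the large (10,000-character) parent chunks they were split from, and a MultiQueryRetriever that has an LLM rephrase the question into several variants before searching a second, conventionally chunked FAISS index. The id_key metadata is what ties a child hit back to its parent via the InMemoryStore. Below is a minimal, dependency-free sketch of that parent/child lookup; the dict store and naive substring matcher stand in for FAISS and InMemoryStore and are not code from this commit.

import uuid

parent_store: dict = {}   # stands in for InMemoryStore: doc_id -> parent chunk
child_index: list = []    # stands in for FAISS: (child chunk, doc_id) pairs

def add_parent(parent_text: str, child_size: int = 400) -> None:
    """Split a parent chunk into children tagged with its doc_id,
    mirroring the text_ids / metadata[id_key] bookkeeping in the diff."""
    doc_id = str(uuid.uuid4())
    parent_store[doc_id] = parent_text
    for i in range(0, len(parent_text), child_size):
        child_index.append((parent_text[i:i + child_size], doc_id))

def retrieve(query: str) -> list:
    """Match against the small children, then return the deduplicated
    parents, as MultiVectorRetriever does via its docstore."""
    hit_ids = [doc_id for child, doc_id in child_index
               if query.lower() in child.lower()]
    seen, parents = set(), []
    for doc_id in hit_ids:
        if doc_id not in seen:
            seen.add(doc_id)
            parents.append(parent_store[doc_id])
    return parents

add_parent("Attention lets a transformer weigh tokens by relevance. " * 40)
print(len(retrieve("attention")))  # one parent, however many children matched

The EnsembleRetriever then merges the two ranked result lists with weighted reciprocal rank fusion (weights=[0.5, 0.5]), so a chunk ranked highly by either strategy surfaces in the final list.
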
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 anthropic==0.7.7
 faiss-cpu==1.7.4
-langchain==0.0.345
+langchain==0.0.346
 langsmith==0.0.69
 numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
 openai==1.3.7
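
Note that MultiQueryRetriever.from_llm(..., llm=ChatOpenAI()) is not given the user-supplied openai_api_key, so the query-rewriting model falls back to the OPENAI_API_KEY environment variable. A hypothetical smoke test for the new code path follows; the sample PDF path and question are assumptions, not part of the commit, and it requires langchain==0.0.346 plus that environment variable.

import os
from llm_resources import get_texts_and_retriever2

# Any local PDF works; "sample.pdf" is a placeholder.
with open("sample.pdf", "rb") as f:
    pdf_bytes = f.read()

texts, retriever = get_texts_and_retriever2(
    uploaded_file_bytes=pdf_bytes,
    openai_api_key=os.environ["OPENAI_API_KEY"],
)

# The returned retriever is the EnsembleRetriever built in the diff above.
docs = retriever.get_relevant_documents("What is this document about?")
print(f"{len(texts)} chunks indexed, {len(docs)} retrieved")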