Connor Sutton
committed on
Commit
·
09bee79
1
Parent(s):
fbe3579
added multivector and multiquery retrievers
Browse files
langchain-streamlit-demo/llm_resources.py
CHANGED
@@ -16,6 +16,11 @@ from langchain.schema import Document, BaseRetriever
|
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
from langchain.vectorstores import FAISS
|
18 |
|
|
|
|
|
|
|
|
|
|
|
19 |
from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
|
20 |
from qagen import get_rag_qa_gen_chain
|
21 |
from summarize import get_rag_summarization_chain
|
@@ -153,6 +158,77 @@ def get_texts_and_retriever(
|
|
153 |
return texts, ensemble_retriever
|
154 |
|
155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
class StreamHandler(BaseCallbackHandler):
|
157 |
def __init__(self, container, initial_text=""):
|
158 |
self.container = container
|
|
|
16 |
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
17 |
from langchain.vectorstores import FAISS
|
18 |
|
19 |
+
from langchain.retrievers.multi_query import MultiQueryRetriever
|
20 |
+
from langchain.retrievers.multi_vector import MultiVectorRetriever
|
21 |
+
from langchain.storage import InMemoryStore
|
22 |
+
import uuid
|
23 |
+
|
24 |
from defaults import DEFAULT_CHUNK_SIZE, DEFAULT_CHUNK_OVERLAP, DEFAULT_RETRIEVER_K
|
25 |
from qagen import get_rag_qa_gen_chain
|
26 |
from summarize import get_rag_summarization_chain
|
|
|
158 |
return texts, ensemble_retriever
|
159 |
|
160 |
|
161 |
+
def get_texts_and_retriever2(
    uploaded_file_bytes: bytes,
    openai_api_key: str,
    chunk_size: int = DEFAULT_CHUNK_SIZE,
    chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
    k: int = DEFAULT_RETRIEVER_K,
    azure_kwargs: Optional[Dict[str, str]] = None,
    use_azure: bool = False,
) -> Tuple[List[Document], BaseRetriever]:
    """Build an ensemble retriever (MultiQuery + MultiVector) over an uploaded PDF.

    The PDF bytes are written to a temporary file, loaded, and split twice:
    coarse parent chunks (for MultiVectorRetriever's docstore) and fine child
    chunks (for its vectorstore), plus a standard split for MultiQueryRetriever.

    Args:
        uploaded_file_bytes: Raw bytes of the uploaded PDF file.
        openai_api_key: OpenAI API key used for embeddings and the
            query-generation LLM.
        chunk_size: Chunk size for the MultiQueryRetriever split.
        chunk_overlap: Chunk overlap for the MultiQueryRetriever split.
        k: Number of documents the MultiQuery vectorstore retriever returns.
        azure_kwargs: Azure OpenAI keyword arguments; expected to contain
            "openai_api_base", which is renamed to "azure_endpoint".
        use_azure: When True (and azure_kwargs is provided), use
            AzureOpenAIEmbeddings instead of OpenAIEmbeddings.

    Returns:
        Tuple of (texts used for the MultiQuery index, ensemble retriever).
    """
    with NamedTemporaryFile() as temp_file:
        temp_file.write(uploaded_file_bytes)
        temp_file.seek(0)

        loader = PyPDFLoader(temp_file.name)
        documents = loader.load()
        # Coarse "parent" chunks stored whole in the docstore; fine "child"
        # chunks are what actually get embedded for the MultiVectorRetriever.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=10000,
            chunk_overlap=0,
        )
        child_text_splitter = RecursiveCharacterTextSplitter(chunk_size=400)

        texts = text_splitter.split_documents(documents)
        id_key = "doc_id"

        # Tag every child chunk with its parent's id so the retriever can map
        # a matched child back to the full parent document.
        text_ids = [str(uuid.uuid4()) for _ in texts]
        sub_texts = []
        for i, text in enumerate(texts):
            _id = text_ids[i]
            _sub_texts = child_text_splitter.split_documents([text])
            for _text in _sub_texts:
                _text.metadata[id_key] = _id
            sub_texts.extend(_sub_texts)

        embeddings_kwargs = {"openai_api_key": openai_api_key}
        if use_azure and azure_kwargs:
            # Work on a copy so the caller's dict is not mutated by the
            # openai_api_base -> azure_endpoint rename.
            azure_kwargs = dict(azure_kwargs)
            azure_kwargs["azure_endpoint"] = azure_kwargs.pop("openai_api_base")
            embeddings_kwargs.update(azure_kwargs)
            embeddings = AzureOpenAIEmbeddings(**embeddings_kwargs)
        else:
            embeddings = OpenAIEmbeddings(**embeddings_kwargs)
        store = InMemoryStore()

        # MultiVectorRetriever: search over child chunks, return parent docs.
        multivectorstore = FAISS.from_documents(sub_texts, embeddings)
        multivector_retriever = MultiVectorRetriever(
            vectorstore=multivectorstore,
            base_store=store,
            id_key=id_key,
        )
        multivector_retriever.docstore.mset(list(zip(text_ids, texts)))
        # NOTE: `k` is applied only to the MultiQuery retriever below;
        # MultiVectorRetriever uses its library default.

        multiquery_text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
        # MultiQueryRetriever: an LLM rephrases the query several ways and the
        # results are merged.
        multiquery_texts = multiquery_text_splitter.split_documents(documents)
        multiquerystore = FAISS.from_documents(multiquery_texts, embeddings)
        multiquery_retriever = MultiQueryRetriever.from_llm(
            retriever=multiquerystore.as_retriever(search_kwargs={"k": k}),
            # Thread the API key through explicitly (previously relied on the
            # OPENAI_API_KEY environment variable and failed without it).
            # NOTE(review): in the Azure path this still uses the non-Azure
            # chat client — confirm whether an Azure chat model is intended.
            llm=ChatOpenAI(openai_api_key=openai_api_key),
        )

        ensemble_retriever = EnsembleRetriever(
            retrievers=[multiquery_retriever, multivector_retriever],
            weights=[0.5, 0.5],
        )
        return multiquery_texts, ensemble_retriever
|
230 |
+
|
231 |
+
|
232 |
class StreamHandler(BaseCallbackHandler):
|
233 |
def __init__(self, container, initial_text=""):
|
234 |
self.container = container
|
requirements.txt
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
anthropic==0.7.7
|
2 |
faiss-cpu==1.7.4
|
3 |
-
langchain==0.0.
|
4 |
langsmith==0.0.69
|
5 |
numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
|
6 |
openai==1.3.7
|
|
|
1 |
anthropic==0.7.7
|
2 |
faiss-cpu==1.7.4
|
3 |
+
langchain==0.0.346
|
4 |
langsmith==0.0.69
|
5 |
numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
|
6 |
openai==1.3.7
|