Sandaruth commited on
Commit
fb208af
·
1 Parent(s): 2ffda8f

multi query

Browse files
Files changed (2) hide show
  1. MultiQueryRetriever.py +216 -0
  2. Retrieval.py +1 -2
MultiQueryRetriever.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import logging
3
+ from typing import List, Optional, Sequence
4
+
5
+ from langchain_core.callbacks import (
6
+ AsyncCallbackManagerForRetrieverRun,
7
+ CallbackManagerForRetrieverRun,
8
+ )
9
+ from langchain_core.documents import Document
10
+ from langchain_core.language_models import BaseLanguageModel
11
+ from langchain_core.output_parsers import BaseOutputParser
12
+ from langchain_core.prompts.prompt import PromptTemplate
13
+ from langchain_core.retrievers import BaseRetriever
14
+
15
+ from langchain.chains.llm import LLMChain
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
+ class LineListOutputParser(BaseOutputParser[List[str]]):
21
+ """Output parser for a list of lines."""
22
+
23
+ def parse(self, text: str) -> List[str]:
24
+ lines = text.strip().split("\n")
25
+ return lines
26
+
27
+
28
+ # Default prompt
29
+ DEFAULT_QUERY_PROMPT = PromptTemplate(
30
+ input_variables=["question"],
31
+ template="""You are an AI language model assistant. Your task is
32
+ to generate 3 different versions of the given user
33
+ question to retrieve relevant documents from a vector database.
34
+ By generating multiple perspectives on the user question,
35
+ your goal is to help the user overcome some of the limitations
36
+ of distance-based similarity search. Provide these alternative
37
+ questions separated by newlines. Original question: {question}""",
38
+ )
39
+
40
+
41
+ def _unique_documents(documents: Sequence[Document]) -> List[Document]:
42
+ return [doc for i, doc in enumerate(documents) if doc not in documents[:i]][:4]
43
+
44
+
45
+ class MultiQueryRetriever(BaseRetriever):
46
+ """Given a query, use an LLM to write a set of queries.
47
+
48
+ Retrieve docs for each query. Return the unique union of all retrieved docs.
49
+ """
50
+
51
+ retriever: BaseRetriever
52
+ llm_chain: LLMChain
53
+ verbose: bool = True
54
+ parser_key: str = "lines"
55
+ """DEPRECATED. parser_key is no longer used and should not be specified."""
56
+ include_original: bool = False
57
+ """Whether to include the original query in the list of generated queries."""
58
+
59
+ @classmethod
60
+ def from_llm(
61
+ cls,
62
+ retriever: BaseRetriever,
63
+ llm: BaseLanguageModel,
64
+ prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
65
+ parser_key: Optional[str] = None,
66
+ include_original: bool = False,
67
+ ) -> "MultiQueryRetriever":
68
+ """Initialize from llm using default template.
69
+
70
+ Args:
71
+ retriever: retriever to query documents from
72
+ llm: llm for query generation using DEFAULT_QUERY_PROMPT
73
+ include_original: Whether to include the original query in the list of
74
+ generated queries.
75
+
76
+ Returns:
77
+ MultiQueryRetriever
78
+ """
79
+ output_parser = LineListOutputParser()
80
+ llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
81
+ return cls(
82
+ retriever=retriever,
83
+ llm_chain=llm_chain,
84
+ include_original=include_original,
85
+ )
86
+
87
+ async def _aget_relevant_documents(
88
+ self,
89
+ query: str,
90
+ *,
91
+ run_manager: AsyncCallbackManagerForRetrieverRun,
92
+ ) -> List[Document]:
93
+ """Get relevant documents given a user query.
94
+
95
+ Args:
96
+ question: user query
97
+
98
+ Returns:
99
+ Unique union of relevant documents from all generated queries
100
+ """
101
+ queries = await self.agenerate_queries(query, run_manager)
102
+ if self.include_original:
103
+ queries.append(query)
104
+ documents = await self.aretrieve_documents(queries, run_manager)
105
+ return self.unique_union(documents)
106
+
107
+ async def agenerate_queries(
108
+ self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
109
+ ) -> List[str]:
110
+ """Generate queries based upon user input.
111
+
112
+ Args:
113
+ question: user query
114
+
115
+ Returns:
116
+ List of LLM generated queries that are similar to the user input
117
+ """
118
+ response = await self.llm_chain.acall(
119
+ inputs={"question": question}, callbacks=run_manager.get_child()
120
+ )
121
+ lines = response["text"]
122
+ if self.verbose:
123
+ logger.info(f"Generated queries: {lines}")
124
+ return lines
125
+
126
+ async def aretrieve_documents(
127
+ self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
128
+ ) -> List[Document]:
129
+ """Run all LLM generated queries.
130
+
131
+ Args:
132
+ queries: query list
133
+
134
+ Returns:
135
+ List of retrieved Documents
136
+ """
137
+ document_lists = await asyncio.gather(
138
+ *(
139
+ self.retriever.aget_relevant_documents(
140
+ query, callbacks=run_manager.get_child()
141
+ )
142
+ for query in queries
143
+ )
144
+ )
145
+ return [doc for docs in document_lists for doc in docs]
146
+
147
+ def _get_relevant_documents(
148
+ self,
149
+ query: str,
150
+ *,
151
+ run_manager: CallbackManagerForRetrieverRun,
152
+ ) -> List[Document]:
153
+ """Get relevant documents given a user query.
154
+
155
+ Args:
156
+ question: user query
157
+
158
+ Returns:
159
+ Unique union of relevant documents from all generated queries
160
+ """
161
+ queries = self.generate_queries(query, run_manager)
162
+ if self.include_original:
163
+ queries.append(query)
164
+ documents = self.retrieve_documents(queries, run_manager)
165
+ return self.unique_union(documents)
166
+
167
+ def generate_queries(
168
+ self, question: str, run_manager: CallbackManagerForRetrieverRun
169
+ ) -> List[str]:
170
+ """Generate queries based upon user input.
171
+
172
+ Args:
173
+ question: user query
174
+
175
+ Returns:
176
+ List of LLM generated queries that are similar to the user input
177
+ """
178
+ response = self.llm_chain(
179
+ {"question": question}, callbacks=run_manager.get_child()
180
+ )
181
+ lines = response["text"]
182
+ if self.verbose:
183
+ logger.info(f"Generated queries: {lines}")
184
+ return lines
185
+
186
+ def retrieve_documents(
187
+ self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
188
+ ) -> List[Document]:
189
+ """Run all LLM generated queries.
190
+
191
+ Args:
192
+ queries: query list
193
+
194
+ Returns:
195
+ List of retrieved Documents
196
+ """
197
+ documents = []
198
+ for query in queries:
199
+ docs = self.retriever.get_relevant_documents(
200
+ query, callbacks=run_manager.get_child()
201
+ )
202
+ documents.extend(docs)
203
+ print("retrieve documents--", len(documents))
204
+ return documents
205
+
206
+ def unique_union(self, documents: List[Document]) -> List[Document]:
207
+ """Get unique Documents.
208
+
209
+ Args:
210
+ documents: List of retrieved Documents
211
+
212
+ Returns:
213
+ List of unique retrieved Documents
214
+ """
215
+ print("unique union--", len(documents))
216
+ return _unique_documents(documents)
Retrieval.py CHANGED
@@ -16,8 +16,7 @@ bsic_chain = RetrievalQA.from_chain_type(
16
 
17
 
18
 
19
- from langchain.retrievers.multi_query import MultiQueryRetriever
20
- # from kk import MultiQueryRetriever
21
 
22
  retriever_from_llm = MultiQueryRetriever.from_llm(
23
  retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
 
16
 
17
 
18
 
19
+ from MultiQueryRetriever import MultiQueryRetriever
 
20
 
21
  retriever_from_llm = MultiQueryRetriever.from_llm(
22
  retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),