Spaces:
Sleeping
Sleeping
import asyncio | |
import logging | |
from typing import List, Optional, Sequence | |
from langchain_core.callbacks import ( | |
AsyncCallbackManagerForRetrieverRun, | |
CallbackManagerForRetrieverRun, | |
) | |
from langchain_core.documents import Document | |
from langchain_core.language_models import BaseLanguageModel | |
from langchain_core.output_parsers import BaseOutputParser | |
from langchain_core.prompts.prompt import PromptTemplate | |
from langchain_core.retrievers import BaseRetriever | |
from langchain.chains.llm import LLMChain | |
logger = logging.getLogger(__name__) | |
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser that turns newline-separated text into a list of lines."""

    def parse(self, text: str) -> List[str]:
        """Split ``text`` on newlines and return the non-blank lines.

        Args:
            text: Raw text to split, expected to hold one item per line.

        Returns:
            The lines of ``text`` with surrounding whitespace removed. Blank
            and whitespace-only lines are dropped, so empty input yields ``[]``
            rather than ``[""]``.
        """
        # Stripping each line also removes a stray "\r" left by "\r\n" endings.
        stripped = (line.strip() for line in text.strip().split("\n"))
        # Blank separator lines would otherwise come back as empty entries.
        return [line for line in stripped if line]
# Default prompt: asks the LLM for three alternative phrasings of the user
# question, separated by newlines. The template text must not carry any
# copy/paste residue, because every character of it is sent to the model.
DEFAULT_QUERY_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""You are an AI language model assistant. Your task is
to generate 3 different versions of the given user
question to retrieve relevant documents from a vector database.
By generating multiple perspectives on the user question,
your goal is to help the user overcome some of the limitations
of distance-based similarity search. Provide these alternative
questions separated by newlines. Original question: {question}""",
)
def _unique_documents(documents: Sequence[Document]) -> List[Document]: | |
return [doc for i, doc in enumerate(documents) if doc not in documents[:i]][:4] | |
class MultiQueryRetriever(BaseRetriever):
    """Given a query, use an LLM to write a set of queries.

    Retrieve docs for each query. Return the unique union of all retrieved docs.
    """

    # Underlying retriever that each generated query is run against.
    retriever: BaseRetriever
    # Chain that turns the user question into a list of alternative queries.
    llm_chain: LLMChain
    # When true, generated queries and document counts are logged at INFO level.
    verbose: bool = True
    parser_key: str = "lines"
    """DEPRECATED. parser_key is no longer used and should not be specified."""
    include_original: bool = False
    """Whether to include the original query in the list of generated queries."""

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLanguageModel,
        prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
        parser_key: Optional[str] = None,
        include_original: bool = False,
    ) -> "MultiQueryRetriever":
        """Initialize from an LLM, using the default prompt unless one is given.

        Args:
            retriever: Retriever to query documents from.
            llm: LLM used for query generation.
            prompt: Prompt used to generate the alternative queries. Defaults
                to ``DEFAULT_QUERY_PROMPT``.
            parser_key: Deprecated and ignored; accepted only so existing
                callers keep working.
            include_original: Whether to include the original query in the
                list of generated queries.

        Returns:
            MultiQueryRetriever
        """
        output_parser = LineListOutputParser()
        llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
            include_original=include_original,
        )

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Asynchronously get relevant documents given a user query.

        Args:
            query: User query.
            run_manager: Callback manager for this retriever run.

        Returns:
            Unique union of relevant documents from all generated queries.
        """
        queries = await self.agenerate_queries(query, run_manager)
        if self.include_original:
            queries.append(query)
        documents = await self.aretrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    async def agenerate_queries(
        self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[str]:
        """Asynchronously generate queries based upon user input.

        Args:
            question: User query.
            run_manager: Callback manager; a child manager is passed to the chain.

        Returns:
            List of LLM generated queries that are similar to the user input.
        """
        response = await self.llm_chain.acall(
            inputs={"question": question}, callbacks=run_manager.get_child()
        )
        lines = response["text"]
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    async def aretrieve_documents(
        self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries concurrently.

        Args:
            queries: Query list.
            run_manager: Callback manager; each query gets its own child manager.

        Returns:
            List of retrieved Documents, flattened in query order.
        """
        document_lists = await asyncio.gather(
            *(
                self.retriever.aget_relevant_documents(
                    query, callbacks=run_manager.get_child()
                )
                for query in queries
            )
        )
        documents = [doc for docs in document_lists for doc in docs]
        if self.verbose:
            logger.info("retrieve documents-- %d", len(documents))
        return documents

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query.

        Args:
            query: User query.
            run_manager: Callback manager for this retriever run.

        Returns:
            Unique union of relevant documents from all generated queries.
        """
        queries = self.generate_queries(query, run_manager)
        if self.include_original:
            queries.append(query)
        documents = self.retrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    def generate_queries(
        self, question: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: User query.
            run_manager: Callback manager; a child manager is passed to the chain.

        Returns:
            List of LLM generated queries that are similar to the user input.
        """
        response = self.llm_chain(
            {"question": question}, callbacks=run_manager.get_child()
        )
        lines = response["text"]
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    def retrieve_documents(
        self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries, one after another.

        Args:
            queries: Query list.
            run_manager: Callback manager; each query gets its own child manager.

        Returns:
            List of retrieved Documents, in query order.
        """
        documents: List[Document] = []
        for query in queries:
            docs = self.retriever.get_relevant_documents(
                query, callbacks=run_manager.get_child()
            )
            documents.extend(docs)
        # Report through the module logger instead of printing to stdout.
        if self.verbose:
            logger.info("retrieve documents-- %d", len(documents))
        return documents

    def unique_union(self, documents: List[Document]) -> List[Document]:
        """Get unique Documents.

        Delegates to the module-level ``_unique_documents`` helper, which
        drops duplicates and also limits how many documents are returned.

        Args:
            documents: List of retrieved Documents.

        Returns:
            List of unique retrieved Documents.
        """
        if self.verbose:
            logger.info("unique union-- %d", len(documents))
        return _unique_documents(documents)