Spaces:
Sleeping
Sleeping
| import asyncio | |
| import logging | |
| from typing import List, Optional, Sequence | |
| from langchain_core.callbacks import ( | |
| AsyncCallbackManagerForRetrieverRun, | |
| CallbackManagerForRetrieverRun, | |
| ) | |
| from langchain_core.documents import Document | |
| from langchain_core.language_models import BaseLanguageModel | |
| from langchain_core.output_parsers import BaseOutputParser | |
| from langchain_core.prompts.prompt import PromptTemplate | |
| from langchain_core.retrievers import BaseRetriever | |
| from langchain.chains.llm import LLMChain | |
| logger = logging.getLogger(__name__) | |
class LineListOutputParser(BaseOutputParser[List[str]]):
    """Output parser that turns newline-separated text into a list of lines."""

    def parse(self, text: str) -> List[str]:
        """Split ``text`` on newlines and drop blank lines.

        Args:
            text: raw text to split (one item per line).

        Returns:
            The non-blank lines of ``text`` in their original order. Lines
            that are empty or whitespace-only are discarded so they are not
            passed on as empty queries.
        """
        # Leading/trailing whitespace of the whole text is removed first;
        # blank lines in the middle are filtered out individually.
        return [line for line in text.strip().split("\n") if line.strip()]
# Prompt used when the caller does not supply one: it asks the model for
# three rephrasings of the user's question, separated by newlines.
DEFAULT_QUERY_PROMPT = PromptTemplate(
    template="""You are an AI language model assistant. Your task is
to generate 3 different versions of the given user
question to retrieve relevant documents from a vector database.
By generating multiple perspectives on the user question,
your goal is to help the user overcome some of the limitations
of distance-based similarity search. Provide these alternative
questions separated by newlines. Original question: {question}""",
    input_variables=["question"],
)
| def _unique_documents(documents: Sequence[Document]) -> List[Document]: | |
| return [doc for i, doc in enumerate(documents) if doc not in documents[:i]][:4] | |
class MultiQueryRetriever(BaseRetriever):
    """Given a query, use an LLM to write a set of queries.

    Retrieve docs for each query. Return the unique union of all retrieved docs.
    """

    # Retriever that every generated query is run against.
    retriever: BaseRetriever
    # Chain whose "text" output is the list of generated queries.
    llm_chain: LLMChain
    # When True, the generated queries are logged at INFO level.
    verbose: bool = True
    parser_key: str = "lines"
    """DEPRECATED. parser_key is no longer used and should not be specified."""
    include_original: bool = False
    """Whether to include the original query in the list of generated queries."""

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLanguageModel,
        prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
        parser_key: Optional[str] = None,
        include_original: bool = False,
    ) -> "MultiQueryRetriever":
        """Initialize from llm using default template.

        Args:
            retriever: retriever to query documents from
            llm: llm for query generation
            prompt: prompt used to generate the queries; defaults to
                DEFAULT_QUERY_PROMPT
            parser_key: DEPRECATED. Accepted for backward compatibility and
                ignored.
            include_original: Whether to include the original query in the list of
                generated queries.

        Returns:
            MultiQueryRetriever
        """
        output_parser = LineListOutputParser()
        llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
            include_original=include_original,
        )

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Asynchronously get relevant documents given a user query.

        Args:
            query: user query
            run_manager: callback manager for this retriever run

        Returns:
            Unique union of relevant documents from all generated queries
        """
        queries = await self.agenerate_queries(query, run_manager)
        if self.include_original:
            # Build a new list rather than mutating the generator's output.
            queries = [*queries, query]
        documents = await self.aretrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    async def agenerate_queries(
        self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[str]:
        """Asynchronously generate queries based upon user input.

        Args:
            question: user query
            run_manager: callback manager whose child is passed to the chain

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = await self.llm_chain.acall(
            inputs={"question": question}, callbacks=run_manager.get_child()
        )
        lines = response["text"]
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    async def aretrieve_documents(
        self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries concurrently.

        Args:
            queries: query list
            run_manager: callback manager whose children are passed to the
                underlying retriever

        Returns:
            List of retrieved Documents, flattened in query order
        """
        document_lists = await asyncio.gather(
            *(
                self.retriever.aget_relevant_documents(
                    query, callbacks=run_manager.get_child()
                )
                for query in queries
            )
        )
        return [doc for docs in document_lists for doc in docs]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query
            run_manager: callback manager for this retriever run

        Returns:
            Unique union of relevant documents from all generated queries
        """
        queries = self.generate_queries(query, run_manager)
        if self.include_original:
            # Build a new list rather than mutating the generator's output.
            queries = [*queries, query]
        documents = self.retrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    def generate_queries(
        self, question: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: user query
            run_manager: callback manager whose child is passed to the chain

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = self.llm_chain(
            {"question": question}, callbacks=run_manager.get_child()
        )
        lines = response["text"]
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    def retrieve_documents(
        self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries, one after another.

        Args:
            queries: query list
            run_manager: callback manager whose children are passed to the
                underlying retriever

        Returns:
            List of retrieved Documents, flattened in query order
        """
        documents: List[Document] = []
        for query in queries:
            docs = self.retriever.get_relevant_documents(
                query, callbacks=run_manager.get_child()
            )
            documents.extend(docs)
        # Diagnostic count goes to the module logger instead of stdout.
        logger.debug("retrieve documents-- %d", len(documents))
        return documents

    def unique_union(self, documents: List[Document]) -> List[Document]:
        """Get unique Documents.

        Deduplication, and any cap on the number of results, is delegated to
        the module-level ``_unique_documents`` helper.

        Args:
            documents: List of retrieved Documents

        Returns:
            List of unique retrieved Documents
        """
        # Diagnostic count goes to the module logger instead of stdout.
        logger.debug("unique union-- %d", len(documents))
        return _unique_documents(documents)