Spaces:
Runtime error
Runtime error
File size: 7,032 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
import asyncio
import logging
from typing import List, Sequence
from langchain_core.documents import Document
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
from langchain.chains.llm import LLMChain
from langchain.llms.base import BaseLLM
from langchain.output_parsers.pydantic import PydanticOutputParser
# Module-level logger; MultiQueryRetriever reports the generated queries
# through it when its ``verbose`` flag is set.
logger = logging.getLogger(__name__)
class LineList(BaseModel):
    """Pydantic container holding a list of text lines."""

    # One entry per line of text.
    lines: List[str] = Field(description="Lines of text")
class LineListOutputParser(PydanticOutputParser):
    """Output parser that turns newline-separated text into a ``LineList``."""

    def __init__(self) -> None:
        super().__init__(pydantic_object=LineList)

    def parse(self, text: str) -> LineList:
        """Split ``text`` into its non-empty lines.

        Args:
            text: raw text to split, one item per line

        Returns:
            LineList whose ``lines`` are the stripped, non-empty lines of
            ``text`` in their original order (empty for blank input)
        """
        # Blank separator lines and stray whitespace (including "\r" from
        # CRLF output) would otherwise be passed on as empty or padded
        # entries, so each line is stripped and empty ones are dropped.
        stripped = (line.strip() for line in text.strip().split("\n"))
        return LineList(lines=[line for line in stripped if line])
# Default prompt: asks the LLM for 3 alternative versions of the user question,
# separated by newlines. ``question`` is the template's only input variable.
DEFAULT_QUERY_PROMPT = PromptTemplate(
    input_variables=["question"],
    template="""You are an AI language model assistant. Your task is
to generate 3 different versions of the given user
question to retrieve relevant documents from a vector database.
By generating multiple perspectives on the user question,
your goal is to help the user overcome some of the limitations
of distance-based similarity search. Provide these alternative
questions separated by newlines. Original question: {question}""",
)
def _unique_documents(documents: Sequence[Document]) -> List[Document]:
    """Return ``documents`` without duplicates, keeping first occurrences.

    Items are compared with ``==`` (not hashed), so unhashable documents are
    supported. The input is only iterated, never sliced or indexed, so any
    iterable works and no per-element copy of the input is made.

    Args:
        documents: documents to de-duplicate

    Returns:
        New list with the first occurrence of each distinct document, in
        the order they appear in ``documents``
    """
    unique: List[Document] = []
    for doc in documents:
        # Checking against the documents kept so far is equivalent to
        # checking against every earlier item, without building a slice.
        if doc not in unique:
            unique.append(doc)
    return unique
class MultiQueryRetriever(BaseRetriever):
    """Given a query, use an LLM to write a set of queries.

    Retrieve docs for each query. Return the unique union of all retrieved docs.
    """

    # Underlying retriever that every generated query is run against.
    retriever: BaseRetriever
    # Chain that produces the alternative queries from the user question.
    llm_chain: LLMChain
    # When True, the generated queries are logged at INFO level.
    verbose: bool = True
    # Attribute of the chain's parsed "text" output that holds the queries.
    parser_key: str = "lines"
    # Whether to include the original query in the list of generated queries.
    include_original: bool = False

    @classmethod
    def from_llm(
        cls,
        retriever: BaseRetriever,
        llm: BaseLLM,
        prompt: PromptTemplate = DEFAULT_QUERY_PROMPT,
        parser_key: str = "lines",
        include_original: bool = False,
    ) -> "MultiQueryRetriever":
        """Initialize from llm, using the default template unless overridden.

        Args:
            retriever: retriever to query documents from
            llm: llm used for query generation
            prompt: prompt used to generate the queries; defaults to
                DEFAULT_QUERY_PROMPT
            parser_key: attribute of the parsed LLM output that holds the
                generated queries
            include_original: Whether to include the original query in the
                list of generated queries.

        Returns:
            MultiQueryRetriever
        """
        output_parser = LineListOutputParser()
        llm_chain = LLMChain(llm=llm, prompt=prompt, output_parser=output_parser)
        return cls(
            retriever=retriever,
            llm_chain=llm_chain,
            parser_key=parser_key,
            include_original=include_original,
        )

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Asynchronously get relevant documents given a user query.

        Args:
            query: user query
            run_manager: callback manager for this retriever run

        Returns:
            Unique union of relevant documents from all generated queries
        """
        queries = await self.agenerate_queries(query, run_manager)
        if self.include_original:
            # Build a new list rather than appending: the generated list is
            # the one held by the parsed LLM response and must not be
            # mutated behind its back.
            queries = [*queries, query]
        documents = await self.aretrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    async def agenerate_queries(
        self, question: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[str]:
        """Asynchronously generate queries based upon user input.

        Args:
            question: user query
            run_manager: callback manager whose child is passed to the chain

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = await self.llm_chain.acall(
            inputs={"question": question}, callbacks=run_manager.get_child()
        )
        # Falls back to an empty list when the chain output has no
        # ``parser_key`` attribute.
        lines = getattr(response["text"], self.parser_key, [])
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    async def aretrieve_documents(
        self, queries: List[str], run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries concurrently.

        Args:
            queries: query list
            run_manager: callback manager whose children are passed to the
                retriever calls

        Returns:
            List of retrieved Documents, flattened in query order
        """
        document_lists = await asyncio.gather(
            *(
                self.retriever.aget_relevant_documents(
                    query, callbacks=run_manager.get_child()
                )
                for query in queries
            )
        )
        return [doc for docs in document_lists for doc in docs]

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get relevant documents given a user query.

        Args:
            query: user query
            run_manager: callback manager for this retriever run

        Returns:
            Unique union of relevant documents from all generated queries
        """
        queries = self.generate_queries(query, run_manager)
        if self.include_original:
            # Build a new list rather than appending: the generated list is
            # the one held by the parsed LLM response and must not be
            # mutated behind its back.
            queries = [*queries, query]
        documents = self.retrieve_documents(queries, run_manager)
        return self.unique_union(documents)

    def generate_queries(
        self, question: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[str]:
        """Generate queries based upon user input.

        Args:
            question: user query
            run_manager: callback manager whose child is passed to the chain

        Returns:
            List of LLM generated queries that are similar to the user input
        """
        response = self.llm_chain(
            {"question": question}, callbacks=run_manager.get_child()
        )
        # Falls back to an empty list when the chain output has no
        # ``parser_key`` attribute.
        lines = getattr(response["text"], self.parser_key, [])
        if self.verbose:
            logger.info("Generated queries: %s", lines)
        return lines

    def retrieve_documents(
        self, queries: List[str], run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Run all LLM generated queries, one after another.

        Args:
            queries: query list
            run_manager: callback manager whose children are passed to the
                retriever calls

        Returns:
            List of retrieved Documents, flattened in query order
        """
        documents: List[Document] = []
        for query in queries:
            docs = self.retriever.get_relevant_documents(
                query, callbacks=run_manager.get_child()
            )
            documents.extend(docs)
        return documents

    def unique_union(self, documents: List[Document]) -> List[Document]:
        """Get unique Documents.

        Args:
            documents: List of retrieved Documents

        Returns:
            List of unique retrieved Documents
        """
        return _unique_documents(documents)
|