Spaces:
Runtime error
Runtime error
import asyncio | |
from typing import List | |
from langchain_core.documents import Document | |
from langchain_core.retrievers import BaseRetriever | |
from langchain.callbacks.manager import ( | |
AsyncCallbackManagerForRetrieverRun, | |
CallbackManagerForRetrieverRun, | |
) | |
class MergerRetriever(BaseRetriever): | |
"""Retriever that merges the results of multiple retrievers.""" | |
retrievers: List[BaseRetriever] | |
"""A list of retrievers to merge.""" | |
def _get_relevant_documents( | |
self, | |
query: str, | |
*, | |
run_manager: CallbackManagerForRetrieverRun, | |
) -> List[Document]: | |
""" | |
Get the relevant documents for a given query. | |
Args: | |
query: The query to search for. | |
Returns: | |
A list of relevant documents. | |
""" | |
# Merge the results of the retrievers. | |
merged_documents = self.merge_documents(query, run_manager) | |
return merged_documents | |
async def _aget_relevant_documents( | |
self, | |
query: str, | |
*, | |
run_manager: AsyncCallbackManagerForRetrieverRun, | |
) -> List[Document]: | |
""" | |
Asynchronously get the relevant documents for a given query. | |
Args: | |
query: The query to search for. | |
Returns: | |
A list of relevant documents. | |
""" | |
# Merge the results of the retrievers. | |
merged_documents = await self.amerge_documents(query, run_manager) | |
return merged_documents | |
def merge_documents( | |
self, query: str, run_manager: CallbackManagerForRetrieverRun | |
) -> List[Document]: | |
""" | |
Merge the results of the retrievers. | |
Args: | |
query: The query to search for. | |
Returns: | |
A list of merged documents. | |
""" | |
# Get the results of all retrievers. | |
retriever_docs = [ | |
retriever.get_relevant_documents( | |
query, callbacks=run_manager.get_child("retriever_{}".format(i + 1)) | |
) | |
for i, retriever in enumerate(self.retrievers) | |
] | |
# Merge the results of the retrievers. | |
merged_documents = [] | |
max_docs = max(len(docs) for docs in retriever_docs) | |
for i in range(max_docs): | |
for retriever, doc in zip(self.retrievers, retriever_docs): | |
if i < len(doc): | |
merged_documents.append(doc[i]) | |
return merged_documents | |
async def amerge_documents( | |
self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun | |
) -> List[Document]: | |
""" | |
Asynchronously merge the results of the retrievers. | |
Args: | |
query: The query to search for. | |
Returns: | |
A list of merged documents. | |
""" | |
# Get the results of all retrievers. | |
retriever_docs = await asyncio.gather( | |
*( | |
retriever.aget_relevant_documents( | |
query, callbacks=run_manager.get_child("retriever_{}".format(i + 1)) | |
) | |
for i, retriever in enumerate(self.retrievers) | |
) | |
) | |
# Merge the results of the retrievers. | |
merged_documents = [] | |
max_docs = max(len(docs) for docs in retriever_docs) | |
for i in range(max_docs): | |
for retriever, doc in zip(self.retrievers, retriever_docs): | |
if i < len(doc): | |
merged_documents.append(doc[i]) | |
return merged_documents | |