# webui/langchain/retrievers/merger_retriever.py
# zhangyi617's picture
# Upload folder using huggingface_hub
# 129cd69
import asyncio
from typing import List
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
class MergerRetriever(BaseRetriever):
    """Retriever that merges the results of multiple retrievers.

    Results are interleaved round-robin: the first document returned by each
    retriever (in ``retrievers`` order), then the second document of each,
    and so on. A retriever that returned fewer documents is simply skipped
    once its results are exhausted. No de-duplication is performed.
    """

    retrievers: List[BaseRetriever]
    """A list of retrievers to merge."""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """
        Get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Callback manager for this retriever run; a child
                manager is handed to each wrapped retriever.

        Returns:
            A list of relevant documents, interleaved across retrievers.
        """
        return self.merge_documents(query, run_manager)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """
        Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Async callback manager for this retriever run; a
                child manager is handed to each wrapped retriever.

        Returns:
            A list of relevant documents, interleaved across retrievers.
        """
        return await self.amerge_documents(query, run_manager)

    @staticmethod
    def _interleave(retriever_docs: List[List[Document]]) -> List[Document]:
        """Interleave per-retriever result lists round-robin.

        Args:
            retriever_docs: One list of documents per retriever, in
                retriever order.

        Returns:
            A flat list holding the rank-0 document of every list, then the
            rank-1 document of every list, and so on. Empty when
            ``retriever_docs`` is empty or every list in it is empty.
        """
        # ``default=0`` keeps an empty retriever list from raising
        # ``ValueError: max() arg is an empty sequence``.
        longest = max(map(len, retriever_docs), default=0)
        merged_documents: List[Document] = []
        for rank in range(longest):
            for docs in retriever_docs:
                # Shorter result lists drop out once they are exhausted.
                if rank < len(docs):
                    merged_documents.append(docs[rank])
        return merged_documents

    def merge_documents(
        self, query: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Merge the results of the retrievers.

        Each retriever is queried in turn with a child callback manager
        tagged ``retriever_1``, ``retriever_2``, ... and the results are
        interleaved round-robin.

        Args:
            query: The query to search for.
            run_manager: Callback manager whose children are passed to the
                wrapped retrievers.

        Returns:
            A list of merged documents (empty if there are no retrievers).
        """
        # Get the results of all retrievers.
        retriever_docs = [
            retriever.get_relevant_documents(
                query,
                callbacks=run_manager.get_child("retriever_{}".format(position)),
            )
            for position, retriever in enumerate(self.retrievers, start=1)
        ]
        # Merge the results of the retrievers.
        return self._interleave(retriever_docs)

    async def amerge_documents(
        self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """
        Asynchronously merge the results of the retrievers.

        All retrievers are awaited together via ``asyncio.gather``, each with
        a child callback manager tagged ``retriever_1``, ``retriever_2``, ...
        The gathered results keep retriever order and are interleaved
        round-robin.

        Args:
            query: The query to search for.
            run_manager: Async callback manager whose children are passed to
                the wrapped retrievers.

        Returns:
            A list of merged documents (empty if there are no retrievers).
        """
        # Get the results of all retrievers.
        retriever_docs = await asyncio.gather(
            *(
                retriever.aget_relevant_documents(
                    query,
                    callbacks=run_manager.get_child(
                        "retriever_{}".format(position)
                    ),
                )
                for position, retriever in enumerate(self.retrievers, start=1)
            )
        )
        # Merge the results of the retrievers.
        return self._interleave(retriever_docs)