Spaces:
Runtime error
Runtime error
File size: 3,493 Bytes
129cd69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 |
import asyncio
from typing import List
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from langchain.callbacks.manager import (
AsyncCallbackManagerForRetrieverRun,
CallbackManagerForRetrieverRun,
)
class MergerRetriever(BaseRetriever):
    """Retriever that merges the results of multiple retrievers.

    Every retriever in ``retrievers`` is queried with the same query, and the
    per-retriever result lists are interleaved round-robin by rank: the first
    document from each retriever (in list order), then the second from each,
    and so on. Lists that run out early simply stop contributing. Documents
    returned by more than one retriever are kept as-is (no deduplication).
    """

    retrievers: List[BaseRetriever]
    """A list of retrievers to merge."""

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Callback manager for this retriever run; a child
                manager is derived from it for each sub-retriever.

        Returns:
            A list of relevant documents.
        """
        return self.merge_documents(query, run_manager)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> List[Document]:
        """Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: Async callback manager for this retriever run; a
                child manager is derived from it for each sub-retriever.

        Returns:
            A list of relevant documents.
        """
        return await self.amerge_documents(query, run_manager)

    def merge_documents(
        self, query: str, run_manager: CallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Merge the results of the retrievers.

        Args:
            query: The query to search for.
            run_manager: Callback manager used to create one child manager
                per sub-retriever.

        Returns:
            A list of merged documents, interleaved round-robin by rank.
            Empty if there are no retrievers.
        """
        # Query each retriever sequentially; each gets a child callback
        # manager tagged "retriever_1", "retriever_2", ... by position.
        # ``invoke`` replaces the deprecated ``get_relevant_documents``.
        retriever_docs = [
            retriever.invoke(
                query,
                config={
                    "callbacks": run_manager.get_child(f"retriever_{i + 1}")
                },
            )
            for i, retriever in enumerate(self.retrievers)
        ]
        return self._interleave(retriever_docs)

    async def amerge_documents(
        self, query: str, run_manager: AsyncCallbackManagerForRetrieverRun
    ) -> List[Document]:
        """Asynchronously merge the results of the retrievers.

        Args:
            query: The query to search for.
            run_manager: Async callback manager used to create one child
                manager per sub-retriever.

        Returns:
            A list of merged documents, interleaved round-robin by rank.
            Empty if there are no retrievers.
        """
        # Run all retrievers concurrently; asyncio.gather returns results in
        # the same order as ``self.retrievers``, so the merge order matches
        # the synchronous path.
        retriever_docs = await asyncio.gather(
            *(
                retriever.ainvoke(
                    query,
                    config={
                        "callbacks": run_manager.get_child(f"retriever_{i + 1}")
                    },
                )
                for i, retriever in enumerate(self.retrievers)
            )
        )
        return self._interleave(retriever_docs)

    @staticmethod
    def _interleave(retriever_docs: List[List[Document]]) -> List[Document]:
        """Interleave per-retriever result lists round-robin by rank.

        Args:
            retriever_docs: One list of documents per retriever, in retriever
                order.

        Returns:
            A flat list containing rank 0 from every list, then rank 1 from
            every list, and so on. Empty if ``retriever_docs`` is empty or
            every inner list is empty.
        """
        merged_documents: List[Document] = []
        # ``default=0`` keeps an empty ``retrievers`` list from raising
        # ValueError (max() of an empty sequence).
        max_docs = max((len(docs) for docs in retriever_docs), default=0)
        for rank in range(max_docs):
            for docs in retriever_docs:
                if rank < len(docs):
                    merged_documents.append(docs[rank])
        return merged_documents
|