""" modified from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py """ from functools import partial from typing import Callable from typing import Iterable from typing import Optional from langchain.schema import Document from langchain.tools import Tool from langchain_core.callbacks.manager import Callbacks from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import Field from langchain_core.retrievers import BaseRetriever class RetrieverInput(BaseModel): """Input to the retriever.""" query: str = Field(description="query to look up in retriever") def _get_relevant_documents( query: str, retriever: BaseRetriever, format_docs: Callable[[Iterable[Document]], str], callbacks: Callbacks = None, ) -> str: docs = retriever.get_relevant_documents(query, callbacks=callbacks) return format_docs(docs) async def _aget_relevant_documents( query: str, retriever: BaseRetriever, format_docs: Callable[[Iterable[Document]], str], callbacks: Callbacks = None, ) -> str: docs = await retriever.aget_relevant_documents(query, callbacks=callbacks) return format_docs(docs) def get_retriever_tool( retriever: BaseRetriever, name: str, description: str, format_docs: Callable[[Iterable[Document]], str], ) -> Tool: """Create a tool to do retrieval of documents. Args: retriever: The retriever to use for the retrieval name: The name for the tool. This will be passed to the language model, so should be unique and somewhat descriptive. description: The description for the tool. This will be passed to the language model, so should be descriptive. format_docs: A function to turn an iterable of docs into a string. Returns: Tool class to pass to an agent """ func = partial( _get_relevant_documents, retriever=retriever, format_docs=format_docs, ) afunc = partial( _aget_relevant_documents, retriever=retriever, format_docs=format_docs, ) return Tool( name=name, description=description, func=func, coroutine=afunc, args_schema=RetrieverInput, )