"""Build a LangChain ``Tool`` that queries a retriever and formats the results.

Adapted from
https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py

TODO: clean this module up further.
"""

from functools import partial
from typing import Callable, List, Optional

from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.prompts import BasePromptTemplate, PromptTemplate, format_document
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
from langchain.tools import Tool


class RetrieverInput(BaseModel):
    """Input to the retriever."""

    # The class name is kept as-is because it becomes the title of the tool's
    # argument JSON schema.
    query: str = Field(description="query to look up in retriever")


def _default_format_docs(
    docs: List[Document],
    document_prompt: BasePromptTemplate,
    document_separator: str,
) -> str:
    """Render each document with ``document_prompt`` and join them with ``document_separator``."""
    return document_separator.join(
        format_document(doc, document_prompt) for doc in docs
    )


def _get_relevant_documents(
    query: str,
    retriever: BaseRetriever,
    format_docs: Callable[[List[Document]], str],
    callbacks: Callbacks = None,
) -> str:
    """Synchronously retrieve documents for ``query`` and format them to one string."""
    # ``callbacks`` is part of the signature so ``Tool`` can forward its
    # child callback manager to the retriever.
    docs = retriever.get_relevant_documents(query, callbacks=callbacks)
    return format_docs(docs)


async def _aget_relevant_documents(
    query: str,
    retriever: BaseRetriever,
    format_docs: Callable[[List[Document]], str],
    callbacks: Callbacks = None,
) -> str:
    """Asynchronously retrieve documents for ``query`` and format them to one string."""
    docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
    return format_docs(docs)


def get_retriever_tool(
    retriever: BaseRetriever,
    name: str,
    description: str,
    format_docs: Optional[Callable[[List[Document]], str]] = None,
    *,
    document_prompt: Optional[BasePromptTemplate] = None,
    document_separator: str = "\n\n",
) -> Tool:
    """Create a tool that runs ``retriever`` and returns the formatted documents.

    Args:
        retriever: The retriever used to look up documents.
        name: The tool name. It is passed to the language model, so it should
            be unique and somewhat descriptive.
        description: The tool description. It is passed to the language model,
            so it should be descriptive.
        format_docs: Callable that turns the retrieved documents into the
            single string returned by the tool. When ``None``, each document
            is rendered with ``document_prompt`` and the results are joined
            with ``document_separator``.
        document_prompt: Prompt used to render each document when
            ``format_docs`` is ``None``. Defaults to ``"{page_content}"``.
        document_separator: Separator placed between rendered documents when
            ``format_docs`` is ``None``.

    Returns:
        A ``Tool`` with both a sync ``func`` and an async ``coroutine``, using
        ``RetrieverInput`` as its argument schema.
    """
    if format_docs is None:
        # Previously ``document_prompt``/``document_separator`` were accepted
        # but silently dropped. Here they drive the default formatting.
        prompt = document_prompt or PromptTemplate.from_template("{page_content}")
        format_docs = partial(
            _default_format_docs,
            document_prompt=prompt,
            document_separator=document_separator,
        )

    func = partial(
        _get_relevant_documents,
        retriever=retriever,
        format_docs=format_docs,
    )
    afunc = partial(
        _aget_relevant_documents,
        retriever=retriever,
        format_docs=format_docs,
    )
    return Tool(
        name=name,
        description=description,
        func=func,
        coroutine=afunc,
        args_schema=RetrieverInput,
    )