Spaces:
Running
Running
""" | |
TODO: clean all this up.
Modified from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py
""" | |
from functools import partial | |
from typing import Optional | |
from langchain_core.callbacks.manager import Callbacks | |
from langchain_core.prompts import BasePromptTemplate, PromptTemplate | |
from langchain_core.pydantic_v1 import BaseModel, Field | |
from langchain_core.retrievers import BaseRetriever | |
from langchain.tools import Tool | |
def get_retriever_tool(
    retriever: BaseRetriever,
    name: str,
    description: str,
    format_docs,
    *,
    document_prompt: Optional[BasePromptTemplate] = None,
    document_separator: str = "\n\n",
) -> Tool:
    """Build a retrieval ``Tool`` whose result is rendered by ``format_docs``.

    Args:
        retriever: The retriever queried by the tool.
        name: The tool name handed to ``Tool``.
        description: The tool description handed to ``Tool``.
        format_docs: Callable invoked with the retrieved documents; whatever
            it returns is returned as the tool result.
        document_prompt: Forwarded to the tool callables. It is bound into
            them but not consulted when rendering, because ``format_docs``
            produces the output. Falls back to a ``"{page_content}"``
            template when falsy.
        document_separator: Forwarded alongside ``document_prompt``; likewise
            not consulted when rendering.

    Returns:
        A ``Tool`` with a synchronous function, an async coroutine, and a
        single-field (``query``) argument schema.
    """

    class RetrieverInput(BaseModel):
        """Input to the retriever."""

        query: str = Field(description="query to look up in retriever")

    # NOTE(review): ``get_relevant_documents`` / ``aget_relevant_documents``
    # appear to be deprecated in newer langchain-core releases in favour of
    # ``invoke`` / ``ainvoke`` -- confirm against the pinned version.
    def _get_relevant_documents(
        query: str,
        retriever: BaseRetriever,
        document_prompt: BasePromptTemplate,
        document_separator: str,
        callbacks: Callbacks = None,
    ) -> str:
        """Fetch documents for ``query`` and render them with ``format_docs``."""
        # document_prompt / document_separator are accepted so the partials
        # built below can bind them; rendering is delegated to format_docs.
        docs = retriever.get_relevant_documents(query, callbacks=callbacks)
        return format_docs(docs)

    async def _aget_relevant_documents(
        query: str,
        retriever: BaseRetriever,
        document_prompt: BasePromptTemplate,
        document_separator: str,
        callbacks: Callbacks = None,
    ) -> str:
        """Async counterpart of ``_get_relevant_documents``."""
        docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
        return format_docs(docs)

    def create_retriever_tool(
        retriever: BaseRetriever,
        name: str,
        description: str,
        *,
        document_prompt: Optional[BasePromptTemplate] = None,
        document_separator: str = "\n\n",
    ) -> Tool:
        """Create a tool to do retrieval of documents.

        Args:
            retriever: The retriever to use for the retrieval
            name: The name for the tool. This will be passed to the language model,
                so should be unique and somewhat descriptive.
            description: The description for the tool. This will be passed to the language
                model, so should be descriptive.
            document_prompt: Prompt bound into the tool callables; defaults to
                a ``"{page_content}"`` template.
            document_separator: Separator bound into the tool callables.

        Returns:
            Tool class to pass to an agent
        """
        document_prompt = document_prompt or PromptTemplate.from_template(
            "{page_content}"
        )
        func = partial(
            _get_relevant_documents,
            retriever=retriever,
            document_prompt=document_prompt,
            document_separator=document_separator,
        )
        afunc = partial(
            _aget_relevant_documents,
            retriever=retriever,
            document_prompt=document_prompt,
            document_separator=document_separator,
        )
        return Tool(
            name=name,
            description=description,
            func=func,
            coroutine=afunc,
            args_schema=RetrieverInput,
        )

    # Forward the keyword-only options instead of silently dropping them, so
    # a caller-supplied prompt/separator actually reaches the bound callables.
    return create_retriever_tool(
        retriever,
        name,
        description,
        document_prompt=document_prompt,
        document_separator=document_separator,
    )