# Source: hf-legisqa/custom_tools.py
# Author: gabrielaltay; commit b25bfc6 ("agent update")
"""Custom LangChain tool helpers.

TODO: clean all this up.

Modified from
https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py
"""
from functools import partial
from typing import Optional
from langchain_core.callbacks.manager import Callbacks
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
from langchain.tools import Tool
def get_retriever_tool(
    retriever,
    name,
    description,
    format_docs=None,
    *,
    document_prompt: Optional[BasePromptTemplate] = None,
    document_separator: str = "\n\n",
):
    """Build a LangChain ``Tool`` that queries ``retriever`` and returns text.

    Adapted from LangChain's ``create_retriever_tool``. The main difference is
    that the retrieved documents can be rendered by a caller-supplied
    ``format_docs`` callable instead of the prompt/separator scheme.

    Args:
        retriever: Retriever queried with the tool's input string (presumably
            a ``BaseRetriever``; anything exposing ``invoke``/``ainvoke``
            works).
        name: Tool name passed to the language model. It should be unique
            and somewhat descriptive.
        description: Tool description passed to the language model. It should
            be descriptive.
        format_docs: Callable that receives the list of retrieved documents and
            returns the string the tool outputs. When given, it takes
            precedence, and ``document_prompt`` / ``document_separator`` are
            not used.
        document_prompt: Template used to render each document when
            ``format_docs`` is None. Defaults to ``"{page_content}"``.
        document_separator: String placed between rendered documents when
            ``format_docs`` is None.

    Returns:
        A ``Tool`` with both a sync ``func`` and an async ``coroutine``. Its
        argument schema is a single ``query`` string.
    """

    class RetrieverInput(BaseModel):
        """Input to the retriever."""

        query: str = Field(description="query to look up in retriever")

    if format_docs is None:
        # Fall back to LangChain's default rendering so that document_prompt
        # and document_separator actually take effect (previously they were
        # accepted but silently dropped).
        from langchain_core.prompts import format_document

        prompt = document_prompt or PromptTemplate.from_template("{page_content}")

        def _default_format_docs(docs):
            """Render each document with ``prompt`` and join with the separator."""
            return document_separator.join(
                format_document(doc, prompt) for doc in docs
            )

        format_docs = _default_format_docs

    def _get_relevant_documents(query: str, callbacks: Callbacks = None) -> str:
        # ``Tool`` passes child callbacks only when the func declares a
        # ``callbacks`` parameter, so keep it in the signature.
        # ``invoke`` replaces the deprecated ``get_relevant_documents``.
        docs = retriever.invoke(query, config={"callbacks": callbacks})
        return format_docs(docs)

    async def _aget_relevant_documents(
        query: str, callbacks: Callbacks = None
    ) -> str:
        # Async counterpart. ``ainvoke`` replaces ``aget_relevant_documents``.
        docs = await retriever.ainvoke(query, config={"callbacks": callbacks})
        return format_docs(docs)

    return Tool(
        name=name,
        description=description,
        func=_get_relevant_documents,
        coroutine=_aget_relevant_documents,
        args_schema=RetrieverInput,
    )