File size: 2,981 Bytes
b25bfc6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
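Build a LangChain retriever tool whose text output is produced by a caller-supplied document formatter.

Adapted from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py
TODO: clean up this module.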
"""

from functools import partial
from typing import Optional

from langchain_core.callbacks.manager import Callbacks
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.retrievers import BaseRetriever
from langchain.tools import Tool


def get_retriever_tool(
    retriever,
    name,
    description,
    format_docs=None,
    *,
    document_prompt: Optional[BasePromptTemplate] = None,
    document_separator: str = "\n\n",
):
    """Wrap ``retriever`` in a LangChain ``Tool`` that returns retrieved documents as text.

    Args:
        retriever: The retriever queried with the tool's ``query`` input. It must provide
            ``get_relevant_documents`` (sync path) and ``aget_relevant_documents``
            (async path).
        name: Name of the tool as shown to the language model. It should be unique and
            somewhat descriptive.
        description: Description of the tool as shown to the language model.
        format_docs: Callable that receives the list of retrieved documents and returns
            the string the tool outputs. If ``None``, each document is rendered with
            ``document_prompt`` and the pieces are joined with ``document_separator``.
        document_prompt: Per-document prompt, used only when ``format_docs`` is ``None``.
            Defaults to ``PromptTemplate.from_template("{page_content}")``.
        document_separator: String placed between rendered documents, used only when
            ``format_docs`` is ``None``.

    Returns:
        A ``Tool`` with a sync ``func`` and an async ``coroutine``. Its single-field
        ``args_schema`` is named ``query``.
    """
    if format_docs is None:
        # Previously these keyword arguments were accepted but silently ignored.
        # Honour them here, matching the upstream LangChain retriever tool.
        from langchain_core.prompts import format_document

        prompt = document_prompt or PromptTemplate.from_template("{page_content}")

        def _default_format(docs):
            return document_separator.join(format_document(doc, prompt) for doc in docs)

        formatter = _default_format
    else:
        formatter = format_docs

    class RetrieverInput(BaseModel):
        """Input to the retriever."""

        query: str = Field(description="query to look up in retriever")

    def _run(query: str, callbacks: Callbacks = None) -> str:
        # Sync path: fetch documents (forwarding callbacks) and render them as one string.
        docs = retriever.get_relevant_documents(query, callbacks=callbacks)
        return formatter(docs)

    async def _arun(query: str, callbacks: Callbacks = None) -> str:
        # Async path: same as _run but awaits the retriever's async API.
        docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
        return formatter(docs)

    return Tool(
        name=name,
        description=description,
        func=_run,
        coroutine=_arun,
        args_schema=RetrieverInput,
    )