File size: 4,279 Bytes
4304c6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from typing import Any

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.entities.common_entities import I18nObject
from core.tools.entities.tool_entities import (
    ToolDescription,
    ToolIdentity,
    ToolInvokeMessage,
    ToolParameter,
    ToolProviderType,
)
from core.tools.tool.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.tools.tool.tool import Tool


class DatasetRetrieverTool(Tool):
    """
    Tool adapter that exposes a dataset retriever to an agent.

    Each instance wraps one ``DatasetRetrieverBaseTool`` and, when invoked,
    forwards the ``query`` parameter to it and returns its output as text.
    """

    # The wrapped retriever that performs the actual lookup.
    # NOTE: the attribute name's spelling is kept as-is for backward compatibility.
    retrival_tool: DatasetRetrieverBaseTool

    @staticmethod
    def get_dataset_tools(tenant_id: str,
                          dataset_ids: list[str],
                          retrieve_config: DatasetRetrieveConfigEntity,
                          return_resource: bool,
                          invoke_from: InvokeFrom,
                          hit_callback: DatasetIndexToolCallbackHandler
                          ) -> list['DatasetRetrieverTool']:
        """
        Build one ``DatasetRetrieverTool`` per retriever tool created for the given datasets.

        The retrieve strategy on ``retrieve_config`` is temporarily forced to
        SINGLE while the retriever tools are created, and it is always restored
        afterwards, even if creation raises.

        :param tenant_id: forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``
        :param dataset_ids: datasets to expose; ``None`` or empty yields no tools
        :param retrieve_config: retrieval configuration; ``None`` yields no tools
        :param return_resource: forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``
        :param invoke_from: forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``
        :param hit_callback: forwarded to ``DatasetRetrieval.to_dataset_retriever_tool``
        :return: the wrapped tools (possibly empty)
        """
        # Nothing to expose without datasets or a retrieval configuration.
        if not dataset_ids or retrieve_config is None:
            return []

        feature = DatasetRetrieval()

        # Agent only supports SINGLE mode: switch the strategy for the duration of
        # tool creation and restore it in ``finally`` so the caller's (possibly
        # shared) config object is never left mutated if creation raises.
        original_retriever_mode = retrieve_config.retrieve_strategy
        retrieve_config.retrieve_strategy = DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE
        try:
            retrival_tools = feature.to_dataset_retriever_tool(
                tenant_id=tenant_id,
                dataset_ids=dataset_ids,
                retrieve_config=retrieve_config,
                return_resource=return_resource,
                invoke_from=invoke_from,
                hit_callback=hit_callback
            )
        finally:
            retrieve_config.retrieve_strategy = original_retriever_mode

        # NOTE(review): guard in case the retrieval helper returns None instead of
        # an empty list; iterating None would raise TypeError.
        if not retrival_tools:
            return []

        # Wrap each retriever tool so it can be used through the generic Tool interface.
        return [
            DatasetRetrieverTool(
                retrival_tool=retrival_tool,
                identity=ToolIdentity(provider='', author='', name=retrival_tool.name,
                                      label=I18nObject(en_US='', zh_Hans='')),
                parameters=[],
                is_team_authorization=True,
                description=ToolDescription(
                    human=I18nObject(en_US='', zh_Hans=''),
                    llm=retrival_tool.description),
                runtime=DatasetRetrieverTool.Runtime()
            )
            for retrival_tool in retrival_tools
        ]

    def get_runtime_parameters(self) -> list[ToolParameter]:
        """
        Return the parameters the LLM must supply at invocation time.

        :return: a single required string parameter named ``query``
        """
        return [
            ToolParameter(name='query',
                          label=I18nObject(en_US='', zh_Hans=''),
                          human_description=I18nObject(en_US='', zh_Hans=''),
                          type=ToolParameter.ToolParameterType.STRING,
                          form=ToolParameter.ToolParameterForm.LLM,
                          llm_description='Query for the dataset to be used to retrieve the dataset.',
                          required=True,
                          default=''),
        ]

    def tool_provider_type(self) -> ToolProviderType:
        """Return the provider type for this tool (always ``DATASET_RETRIEVAL``)."""
        return ToolProviderType.DATASET_RETRIEVAL

    def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
        """
        Run the wrapped retriever on the ``query`` parameter.

        :param user_id: unused by this tool
        :param tool_parameters: parameters supplied by the LLM; ``query`` is read
        :return: a text message with the retriever's output, or a prompt asking
                 for a query when ``query`` is missing or empty
        """
        query = tool_parameters.get('query')
        if not query:
            return self.create_text_message(text='please input query')

        # Delegate the actual retrieval to the wrapped retriever tool.
        result = self.retrival_tool._run(query=query)

        return self.create_text_message(text=result)

    def validate_credentials(self, credentials: dict[str, Any], parameters: dict[str, Any]) -> None:
        """
        Validate credentials for the dataset retriever tool.

        Intentionally a no-op: this tool performs no credential checks.
        """
        pass