File size: 10,633 Bytes
4304c6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
from abc import ABC, abstractmethod
from typing import Any, Optional

from pydantic import BaseModel

from core.tools.entities.tool_entities import (
    ToolParameter,
    ToolProviderCredentials,
    ToolProviderIdentity,
    ToolProviderType,
)
from core.tools.entities.user_entities import UserToolProviderCredentials
from core.tools.errors import ToolNotFoundError, ToolParameterValidationError, ToolProviderCredentialValidationError
from core.tools.tool.tool import Tool


class ToolProviderController(BaseModel, ABC):
    """
        base class for tool provider controllers

        a provider controller carries the provider `identity`, the `tools` it offers and the
        `credentials_schema` it accepts, and implements the generic validation of tool
        parameters and provider credentials against those schemas.
        subclasses must implement `get_tools`, `get_tool` and `_validate_credentials`.
    """
    identity: Optional[ToolProviderIdentity] = None
    tools: Optional[list[Tool]] = None
    # credential name -> credential schema; None means the provider declares no credentials
    credentials_schema: Optional[dict[str, ToolProviderCredentials]] = None

    def get_credentials_schema(self) -> dict[str, ToolProviderCredentials]:
        """
            returns a shallow copy of the credentials schema of the provider

            :return: the credentials schema, or an empty dict if the provider declares none
        """
        # an unset schema is treated as "no credentials" (as in `validate_credentials_format`)
        # instead of failing with AttributeError on None
        return (self.credentials_schema or {}).copy()

    def user_get_credentials_schema(self) -> UserToolProviderCredentials:
        """
            returns the credentials schema of the provider wrapped for user-facing use

            :return: a `UserToolProviderCredentials` built from a shallow copy of the schema
                (empty if the provider declares no credentials)
        """
        credentials = (self.credentials_schema or {}).copy()
        return UserToolProviderCredentials(credentials=credentials)

    @abstractmethod
    def get_tools(self) -> list[Tool]:
        """
            returns a list of tools that the provider can provide

            :return: list of tools
        """
        pass

    @abstractmethod
    def get_tool(self, tool_name: str) -> Tool:
        """
            returns a tool that the provider can provide

            :param tool_name: the name of the tool
            :return: tool
        """
        pass

    def get_parameters(self, tool_name: str) -> list[ToolParameter]:
        """
            returns the parameters of the tool

            :param tool_name: the name of the tool, defined in `get_tools`
            :return: list of parameters (whatever the tool's `parameters` attribute holds)
            :raises ToolNotFoundError: if no tool from `get_tools` has that name
        """
        tool = next(filter(lambda x: x.identity.name == tool_name, self.get_tools()), None)
        if tool is None:
            raise ToolNotFoundError(f'tool {tool_name} not found')
        return tool.parameters

    @property
    def app_type(self) -> ToolProviderType:
        """
            returns the type of the provider

            :return: type of the provider (`ToolProviderType.BUILT_IN` for this base class)
        """
        return ToolProviderType.BUILT_IN

    def validate_parameters(self, tool_id: int, tool_name: str, tool_parameters: dict[str, Any]) -> None:
        """
            validate the parameters of the tool and set the default value if needed

            `tool_parameters` is modified in place: every optional parameter that is not
            provided but has a default in the schema is filled in with that default.

            :param tool_id: the id of the tool (not used by the validation itself)
            :param tool_name: the name of the tool, defined in `get_tools`
            :param tool_parameters: the parameters of the tool
            :raises ToolNotFoundError: if the tool is not provided by this provider
            :raises ToolParameterValidationError: if a parameter is unknown, has the wrong type,
                is outside [min, max], is not one of the allowed options, or is required but missing
        """
        # a tool without declared parameters is treated as an empty schema
        tool_parameters_schema = self.get_parameters(tool_name) or []

        tool_parameters_need_to_validate: dict[str, ToolParameter] = {
            parameter.name: parameter for parameter in tool_parameters_schema
        }

        for parameter, value in tool_parameters.items():
            if parameter not in tool_parameters_need_to_validate:
                raise ToolParameterValidationError(f'parameter {parameter} not found in tool {tool_name}')

            # check type
            parameter_schema = tool_parameters_need_to_validate[parameter]
            if parameter_schema.type == ToolParameter.ToolParameterType.STRING:
                if not isinstance(value, str):
                    raise ToolParameterValidationError(f'parameter {parameter} should be string')

            elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
                # NOTE: bool is a subclass of int, so True/False also pass this check
                if not isinstance(value, int | float):
                    raise ToolParameterValidationError(f'parameter {parameter} should be number')

                if parameter_schema.min is not None and value < parameter_schema.min:
                    raise ToolParameterValidationError(f'parameter {parameter} should be greater than {parameter_schema.min}')

                if parameter_schema.max is not None and value > parameter_schema.max:
                    raise ToolParameterValidationError(f'parameter {parameter} should be less than {parameter_schema.max}')

            elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
                if not isinstance(value, bool):
                    raise ToolParameterValidationError(f'parameter {parameter} should be boolean')

            elif parameter_schema.type == ToolParameter.ToolParameterType.SELECT:
                if not isinstance(value, str):
                    raise ToolParameterValidationError(f'parameter {parameter} should be string')

                options = parameter_schema.options
                if not isinstance(options, list):
                    raise ToolParameterValidationError(f'parameter {parameter} options should be list')

                if value not in [x.value for x in options]:
                    raise ToolParameterValidationError(f'parameter {parameter} should be one of {options}')

            # provided and checked: it must not be treated as missing below
            tool_parameters_need_to_validate.pop(parameter)

        # whatever is left was not provided by the caller
        for parameter, parameter_schema in tool_parameters_need_to_validate.items():
            if parameter_schema.required:
                raise ToolParameterValidationError(f'parameter {parameter} is required')

            # the parameter is not set currently, set the default value if needed
            if parameter_schema.default is not None:
                default_value = parameter_schema.default
                # parse default value into the correct type
                if parameter_schema.type in (
                    ToolParameter.ToolParameterType.STRING,
                    ToolParameter.ToolParameterType.SELECT,
                ):
                    default_value = str(default_value)
                elif parameter_schema.type == ToolParameter.ToolParameterType.NUMBER:
                    default_value = float(default_value)
                elif parameter_schema.type == ToolParameter.ToolParameterType.BOOLEAN:
                    default_value = self._cast_bool_default(default_value)

                tool_parameters[parameter] = default_value

    @staticmethod
    def _cast_bool_default(value: Any) -> bool:
        """
            convert a schema default value into a bool

            plain `bool()` turns every non-empty string (including "false" and "0") into True,
            so the common textual spellings are recognised explicitly; anything else falls
            back to `bool(value)`.

            :param value: the raw default value from the parameter schema
            :return: the boolean default
        """
        if isinstance(value, str):
            normalized = value.strip().lower()
            if normalized in ('true', 'yes', 'y', 'on', '1'):
                return True
            if normalized in ('false', 'no', 'n', 'off', '0', ''):
                return False
        return bool(value)

    def validate_credentials_format(self, credentials: dict[str, Any]) -> None:
        """
            validate the format of the credentials of the provider and set the default value if needed

            `credentials` is modified in place: every optional credential that is not provided
            but has a default in the schema is filled in with that default.
            does nothing if the provider declares no credentials schema.

            :param credentials: the credentials of the provider
            :raises ToolProviderCredentialValidationError: if a credential is unknown, is not a
                string where one is expected, is not one of the allowed options, or is required
                but missing
        """
        credentials_schema = self.credentials_schema
        if credentials_schema is None:
            return

        # copy, so popping checked entries below never mutates the provider's own schema
        credentials_need_to_validate: dict[str, ToolProviderCredentials] = dict(credentials_schema)

        for credential_name, value in credentials.items():
            if credential_name not in credentials_need_to_validate:
                raise ToolProviderCredentialValidationError(f'credential {credential_name} not found in provider {self.identity.name}')

            # check type
            credential_schema = credentials_need_to_validate[credential_name]
            # compare the schema's `type` field (not the schema object) with the enum members
            if credential_schema.type in (
                ToolProviderCredentials.CredentialsType.SECRET_INPUT,
                ToolProviderCredentials.CredentialsType.TEXT_INPUT,
            ):
                if not isinstance(value, str):
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')

            elif credential_schema.type == ToolProviderCredentials.CredentialsType.SELECT:
                if not isinstance(value, str):
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be string')

                options = credential_schema.options
                if not isinstance(options, list):
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} options should be list')

                if value not in [x.value for x in options]:
                    raise ToolProviderCredentialValidationError(f'credential {credential_name} should be one of {options}')

            # provided and checked: it must not be treated as missing below
            credentials_need_to_validate.pop(credential_name)

        # whatever is left was not provided by the caller
        for credential_name, credential_schema in credentials_need_to_validate.items():
            if credential_schema.required:
                raise ToolProviderCredentialValidationError(f'credential {credential_name} is required')

            # the credential is not set currently, set the default value if needed
            if credential_schema.default is not None:
                default_value = credential_schema.default
                # parse default value into the correct type
                if credential_schema.type in (
                    ToolProviderCredentials.CredentialsType.SECRET_INPUT,
                    ToolProviderCredentials.CredentialsType.TEXT_INPUT,
                    ToolProviderCredentials.CredentialsType.SELECT,
                ):
                    default_value = str(default_value)

                credentials[credential_name] = default_value

    def validate_credentials(self, credentials: dict[str, Any]) -> None:
        """
            validate the credentials of the provider

            first checks the format against `credentials_schema` (filling in defaults in place),
            then delegates provider-specific checks to `_validate_credentials`.

            :param credentials: the credentials of the provider
            :raises ToolProviderCredentialValidationError: if the format check fails
        """
        # validate credentials format
        self.validate_credentials_format(credentials)

        # validate credentials
        self._validate_credentials(credentials)

    @abstractmethod
    def _validate_credentials(self, credentials: dict[str, Any]) -> None:
        """
            validate the credentials of the provider (provider-specific checks)

            called by `validate_credentials` after the format check has passed.

            :param credentials: the credentials of the provider
        """
        pass