|
import json |
|
from os import getenv |
|
from typing import Any |
|
from urllib.parse import urlencode |
|
|
|
import httpx |
|
|
|
from core.helper import ssrf_proxy |
|
from core.tools.entities.tool_bundle import ApiToolBundle |
|
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolProviderType |
|
from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError |
|
from core.tools.tool.tool import Tool |
|
|
|
# (connect, read) timeout pair handed to every outgoing API tool request.
# Each entry is read from its environment variable and falls back to the
# default listed next to it.
API_TOOL_DEFAULT_TIMEOUT = tuple(
    int(getenv(env_name, fallback))
    for env_name, fallback in (
        ("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10"),
        ("API_TOOL_DEFAULT_READ_TIMEOUT", "60"),
    )
)
|
|
|
|
|
class ApiTool(Tool): |
|
api_bundle: ApiToolBundle |
|
|
|
""" |
|
Api tool |
|
""" |
|
|
|
def fork_tool_runtime(self, runtime: dict[str, Any]) -> "Tool":
    """
    Create a new tool of the same class bound to a fresh runtime.

    :param runtime: keyword arguments used to build the ``Tool.Runtime`` of the
        new tool (the original documentation notes that tenant_id is required)
    :return: the new tool, carrying copies of this tool's identity, parameters,
        description and api bundle
    """
    # Falsy members (None or empty) are forwarded as None instead of being copied.
    forked_identity = self.identity.model_copy() if self.identity else None
    forked_parameters = self.parameters.copy() if self.parameters else None
    forked_description = self.description.model_copy() if self.description else None
    forked_bundle = self.api_bundle.model_copy() if self.api_bundle else None
    forked_runtime = Tool.Runtime(**runtime)

    return self.__class__(
        identity=forked_identity,
        parameters=forked_parameters,
        description=forked_description,
        api_bundle=forked_bundle,
        runtime=forked_runtime,
    )
|
|
|
def validate_credentials(
    self, credentials: dict[str, Any], parameters: dict[str, Any], format_only: bool = False
) -> str:
    """
    Validate the credentials of the Api tool.

    The request headers are always assembled first, so malformed credentials or
    missing required parameters are reported even when no request is sent.

    :param credentials: kept for interface compatibility; this method does not
        read it (headers are assembled by ``assembling_request``)
    :param parameters: tool parameters forwarded to the request
    :param format_only: when true, stop after assembling the headers
    :return: an empty string when ``format_only`` is set, otherwise the parsed
        response text
    """
    request_headers = self.assembling_request(parameters)

    if not format_only:
        raw_response = self.do_http_request(
            self.api_bundle.server_url, self.api_bundle.method, request_headers, parameters
        )
        return self.validate_and_parse_response(raw_response)

    return ""
|
|
|
def tool_provider_type(self) -> ToolProviderType:
    """
    Report the provider type of this tool.

    :return: always ``ToolProviderType.API``
    """
    return ToolProviderType.API
|
|
|
def assembling_request(self, parameters: dict[str, Any]) -> dict[str, Any]:
    """
    Build the request headers from the runtime credentials and check the
    required parameters.

    :param parameters: tool parameters supplied by the caller; a missing
        required parameter that declares a default is filled in place
    :return: the headers to send; contains the API key header when the
        credentials use ``auth_type == "api_key"``
    :raises ToolProviderCredentialValidationError: when ``auth_type`` is
        missing, or ``api_key_value`` is missing or not a string
    :raises ToolParameterValidationError: when a required parameter is missing
        and has no default
    """
    headers = {}
    credentials = self.runtime.credentials or {}

    if "auth_type" not in credentials:
        raise ToolProviderCredentialValidationError("Missing auth_type")

    if credentials["auth_type"] == "api_key":
        api_key_header = credentials.get("api_key_header", "api_key")

        if "api_key_value" not in credentials:
            raise ToolProviderCredentialValidationError("Missing api_key_value")
        api_key_value = credentials["api_key_value"]
        if not isinstance(api_key_value, str):
            raise ToolProviderCredentialValidationError("api_key_value must be a string")

        # Prefix a local copy only. Writing the prefixed value back into the
        # runtime credentials would stack the prefix on every call
        # ("Bearer Bearer ...").
        prefix = credentials.get("api_key_header_prefix")
        if api_key_value:
            if prefix == "basic":
                api_key_value = f"Basic {api_key_value}"
            elif prefix == "bearer":
                api_key_value = f"Bearer {api_key_value}"
            # "custom" (or any other prefix) sends the value exactly as stored.

        headers[api_key_header] = api_key_value

    for parameter in self.api_bundle.parameters or []:
        if not parameter.required or parameter.name in parameters:
            continue
        # A declared default satisfies a missing required parameter; only a
        # parameter with nothing to fall back on is an error.
        if parameter.default is None:
            raise ToolParameterValidationError(f"Missing required parameter {parameter.name}")
        parameters[parameter.name] = parameter.default

    return headers
|
|
|
def validate_and_parse_response(self, response: httpx.Response) -> str:
    """
    Validate the HTTP response and turn it into text.

    :param response: the response returned by the HTTP request
    :return: the JSON body re-serialised as a string when the body parses as
        JSON, the raw response text otherwise, or a fixed hint message when
        the body is empty
    :raises ValueError: when ``response`` is not an ``httpx.Response``
    :raises ToolInvokeError: when the status code is 400 or above
    """
    if not isinstance(response, httpx.Response):
        raise ValueError(f"Invalid response type {type(response)}")

    if response.status_code >= 400:
        raise ToolInvokeError(f"Request failed with status code {response.status_code} and {response.text}")
    if not response.content:
        return "Empty response from the tool, please check your parameters and try again."

    # The parsed body gets its own name: rebinding `response` would make the
    # text fallback below call `.text` on a dict or list.
    try:
        payload = response.json()
        try:
            return json.dumps(payload, ensure_ascii=False)
        except Exception:
            return json.dumps(payload)
    except Exception:
        # Best effort: anything that is not usable JSON is returned as plain text.
        return response.text
|
|
|
@staticmethod
def get_parameter_value(parameter, parameters):
    """
    Resolve the value to send for one OpenAPI parameter definition.

    :param parameter: the parameter definition; its "name" is looked up in ``parameters``
    :param parameters: the values supplied for this call
    :return: the supplied value, else the "default" of the definition's
        "schema", else an empty string
    :raises ToolParameterValidationError: when the parameter is marked
        "required" and no value was supplied
    """
    name = parameter["name"]
    if name in parameters:
        return parameters[name]

    if parameter.get("required", False):
        raise ToolParameterValidationError(f"Missing required parameter {parameter['name']}")

    # "schema" may be absent or explicitly None; both count as "no default".
    schema = parameter.get("schema", {}) or {}
    return schema.get("default", "")
|
|
|
def do_http_request(
    self, url: str, method: str, headers: dict[str, Any], parameters: dict[str, Any]
) -> httpx.Response:
    """
    Send the HTTP request described by the api bundle's OpenAPI operation.

    Declared parameters are routed to the path, query string, cookies or
    headers according to their "in" field, and the request body is built from
    the schema of the first declared content type.

    :param url: target URL; ``{name}`` placeholders are filled from path parameters
    :param method: HTTP method name, matched case-insensitively
    :param headers: headers to send; header parameters and the Content-Type are
        added to this dict in place
    :param parameters: the values supplied for this call
    :return: the response of the request
    :raises ToolParameterValidationError: when a required parameter or body
        property is missing
    :raises ValueError: when the HTTP method is not supported
    """
    method = method.lower()

    params = {}
    path_params = {}
    body = {}
    cookies = {}

    # Route every declared parameter to the place named by its "in" field.
    for parameter in self.api_bundle.openapi.get("parameters", []):
        value = self.get_parameter_value(parameter, parameters)
        location = parameter["in"]
        if location == "path":
            path_params[parameter["name"]] = value
        elif location == "query":
            # An empty string means "not supplied and no default": leave it out.
            if value != "":
                params[parameter["name"]] = value
        elif location == "cookie":
            cookies[parameter["name"]] = value
        elif location == "header":
            headers[parameter["name"]] = value

    request_body = self.api_bundle.openapi.get("requestBody")
    if request_body is not None and "content" in request_body:
        # Only the first declared content type is used, hence the `break`.
        for content_type in request_body["content"]:
            headers["Content-Type"] = content_type
            body_schema = request_body["content"][content_type]["schema"]
            required = body_schema.get("required", [])
            properties = body_schema.get("properties", {})
            for name, property_schema in properties.items():
                if name in parameters:
                    # Coerce the supplied value to the type declared in the schema.
                    body[name] = self._convert_body_property_type(property_schema, parameters[name])
                elif name in required:
                    raise ToolParameterValidationError(
                        f"Missing required parameter {name} in operation {self.api_bundle.operation_id}"
                    )
                elif "default" in property_schema:
                    body[name] = property_schema["default"]
                else:
                    body[name] = None
            break

    # Fill the {placeholders} of the URL template.
    for name, value in path_params.items():
        url = url.replace(f"{{{name}}}", f"{value}")

    # Serialise the body for the two content types handled explicitly; any
    # other content type passes the dict through unchanged.
    declared_content_type = headers.get("Content-Type")
    if declared_content_type == "application/json":
        body = json.dumps(body)
    elif declared_content_type == "application/x-www-form-urlencoded":
        body = urlencode(body)

    if method not in {"get", "head", "post", "put", "delete", "patch"}:
        # Report the offending method itself; the previous message formatted
        # `self.method`, an attribute this class does not define.
        raise ValueError(f"Invalid http method {method}")

    return getattr(ssrf_proxy, method)(
        url,
        params=params,
        headers=headers,
        cookies=cookies,
        data=body,
        timeout=API_TOOL_DEFAULT_TIMEOUT,
        follow_redirects=True,
    )
|
|
|
def _convert_body_property_any_of(
    self, property: dict[str, Any], value: Any, any_of: list[dict[str, Any]], max_recursive=10
) -> Any:
    """
    Coerce ``value`` to the first ``anyOf`` option it can be converted to.

    Options are tried in order; an option that cannot convert the value is
    skipped. Nested ``anyOf`` lists are followed recursively.

    :param property: the property schema, forwarded unchanged to nested calls
    :param value: the value to convert
    :param any_of: the list of candidate schemas
    :param max_recursive: remaining nesting depth allowed
    :return: the converted value, or ``value`` unchanged when no option applies
    :raises Exception: when the nesting depth is exhausted
    """
    if max_recursive <= 0:
        raise Exception("Max recursion depth reached")

    for option in any_of or []:
        try:
            if "type" in option:
                option_type = option["type"]
                if option_type in ("integer", "int"):
                    return int(value)
                if option_type == "number":
                    return float(value) if "." in str(value) else int(value)
                if option_type == "string":
                    return str(value)
                if option_type == "boolean":
                    lowered = str(value).lower()
                    if lowered in {"true", "1"}:
                        return True
                    if lowered in {"false", "0"}:
                        return False
                    continue
                if option_type == "null" and not value:
                    return None
                # Unsupported type, or "null" with a truthy value: try the next option.
                continue
            elif "anyOf" in option and isinstance(option["anyOf"], list):
                return self._convert_body_property_any_of(property, value, option["anyOf"], max_recursive - 1)
        except (ValueError, TypeError, OverflowError):
            # int(None) and int([..]) raise TypeError and int(float("inf"))
            # raises OverflowError; like a ValueError, they only mean this
            # option does not fit, so move on instead of failing the request.
            continue

    return value
|
|
|
def _convert_body_property_type(self, property: dict[str, Any], value: Any) -> Any:
    """
    Coerce a request-body value to the type declared by its property schema.

    :param property: the schema of the body property ("type" or "anyOf" is consulted)
    :param value: the value supplied for the property
    :return: the converted value; ``value`` unchanged when it cannot be
        converted or when the schema declares neither "type" nor "anyOf"
    """
    if "type" not in property:
        if "anyOf" in property and isinstance(property["anyOf"], list):
            return self._convert_body_property_any_of(property, value, property["anyOf"])
        # Schemas without "type"/"anyOf" (for example "$ref") carry no
        # conversion rule: pass the value through instead of dropping it.
        return value

    declared_type = property["type"]
    try:
        if declared_type in ("integer", "int"):
            return int(value)
        if declared_type == "number":
            return float(value) if "." in str(value) else int(value)
        if declared_type == "string":
            return str(value)
        if declared_type == "boolean":
            # bool("false") is True, so textual booleans are read by content,
            # using the same spellings as the anyOf conversion.
            if isinstance(value, str):
                lowered = value.lower()
                if lowered in {"true", "1"}:
                    return True
                if lowered in {"false", "0"}:
                    return False
            return bool(value)
        if declared_type == "null":
            return None
        if declared_type in ("object", "array"):
            if isinstance(value, str):
                # Try the text as real JSON first so apostrophes inside valid
                # JSON strings survive; fall back to tolerating Python-style
                # single-quoted literals.
                try:
                    return json.loads(value)
                except ValueError:
                    pass
                try:
                    return json.loads(value.replace("'", '"'))
                except ValueError:
                    return value
            return value
    except (ValueError, TypeError, OverflowError):
        # Conversion is best effort: an unconvertible value (including None
        # or a container for a numeric type) is sent as supplied.
        return value

    # Unknown declared type: leave the value untouched.
    return value
|
|
|
def _invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> ToolInvokeMessage | list[ToolInvokeMessage]:
    """
    Invoke the tool by performing its HTTP request.

    The headers are assembled (which also checks required parameters), the
    request is sent, and the parsed response is wrapped in a text message.

    :param user_id: id of the invoking user; not read by this method
    :param tool_parameters: the values supplied for this call
    :return: a text message holding the parsed response
    """
    request_headers = self.assembling_request(tool_parameters)
    raw_response = self.do_http_request(
        self.api_bundle.server_url, self.api_bundle.method, request_headers, tool_parameters
    )
    parsed_text = self.validate_and_parse_response(raw_response)
    return self.create_text_message(parsed_text)
|
|