import json
import logging
import os
import re
from typing import List, Optional

import requests
from jsonschema import RefResolver
from pydantic import BaseModel, ValidationError
from requests.exceptions import RequestException, Timeout

from .tool import Tool

logger = logging.getLogger(__name__)

MAX_RETRY_TIMES = 3

# Fields of an OpenAPI Path Item object that are not HTTP operations and must
# therefore not be treated as methods when iterating ``paths[<path>]``.
_PATH_ITEM_NON_OPERATION_KEYS = frozenset(
    {'$ref', 'summary', 'description', 'servers', 'parameters'})


class ParametersSchema(BaseModel):
    """Schema of a single tool parameter as exposed to the model."""
    name: str
    description: str
    required: Optional[bool] = True


class ToolSchema(BaseModel):
    """Schema of a tool: its name, description and list of parameters."""
    name: str
    description: str
    parameters: List[ParametersSchema]


class OpenAPIPluginTool(Tool):
    """Tool that forwards calls to a remote HTTP endpoint.

    The endpoint is described by a config entry such as the ones produced by
    ``openapi_schema_convert`` (keys ``url``, ``method``, ``header``,
    ``parameters``, ``description``, ...).
    """
    name: str = 'api tool'
    description: str = 'This is a api tool that ...'
    parameters: list = []

    def __init__(self, cfg, name):
        """Build the tool from ``cfg[name]``.

        Args:
            cfg: mapping from tool name to that tool's config dict.
            name: key of this tool inside ``cfg``.

        Raises:
            ValueError: if the configured parameters do not validate against
                ``ParametersSchema``.
        """
        self.name = name
        self.cfg = cfg.get(self.name, {})
        self.is_remote_tool = self.cfg.get('is_remote_tool', False)

        # remote call
        self.url = self.cfg.get('url', '')
        self.token = self.cfg.get('token', '')
        self.header = self.cfg.get('header', '')
        self.method = self.cfg.get('method', '')
        # Optional per-request timeout in seconds, forwarded to ``requests``.
        # ``None`` keeps requests' default (wait indefinitely); setting it is
        # what makes the retry-on-Timeout loop meaningful.
        self.timeout = self.cfg.get('timeout', None)
        self.parameters = self.cfg.get('parameters', [])
        self.description = self.cfg.get('description',
                                        'This is a api tool that ...')
        self.responses_param = self.cfg.get('responses_param', [])

        all_para = {
            'name': self.name,
            'description': self.description,
            'parameters': self.parameters
        }
        try:
            self.tool_schema = ToolSchema(**all_para)
        except ValidationError as e:
            raise ValueError(
                f'Error when parsing parameters of {self.name}') from e
        self._str = self.tool_schema.model_dump_json()
        self._function = self.parse_pydantic_model_to_openai_function(
            all_para)

    def _remote_call(self, *args, **kwargs):
        """Call the configured remote endpoint and return the parsed output.

        Dotted keyword names (``a.b``) are folded into nested objects by
        ``_remote_parse_input``. For POST the nested dict is sent as a JSON
        body. For GET, ``{placeholder}`` segments of the URL are filled from
        ``kwargs`` and the nested dict is sent as query parameters.

        Raises:
            ValueError: if no URL is configured, the method is neither POST
                nor GET, the server returns an error, or every attempt timed
                out.
        """
        if self.url == '':
            raise ValueError(
                f"Could not use remote call for {self.name} since this tool doesn't have a remote endpoint"
            )

        parsed_input = self._remote_parse_input(*args, **kwargs)
        method = (self.method or '').upper()
        if method == 'POST':
            # The header is deliberately not logged: it carries the
            # Authorization token.
            logger.debug('data: %s', kwargs)
            return self._request_with_retry(
                'POST', self.url, data=json.dumps(parsed_input))
        if method == 'GET':
            new_url = self._fill_url_placeholders(kwargs)
            logger.debug('GET: %s (template: %s)', new_url, self.url)
            # ``params`` must be a mapping so requests encodes it as
            # ``key=value`` pairs; a JSON string would be appended verbatim.
            return self._request_with_retry(
                'GET', new_url, params=parsed_input)
        raise ValueError(
            'Remote call method is invalid!We have POST and GET method.')

    def _fill_url_placeholders(self, kwargs):
        """Return ``self.url`` with every ``{name}`` replaced by ``kwargs[name]``.

        Placeholders the model did not provide are left in place and reported.
        """
        new_url = self.url
        for match in re.findall(r'\{(.*?)\}', self.url):
            if match in kwargs:
                # str() so non-string arguments (e.g. ints) can be substituted.
                new_url = new_url.replace('{' + match + '}',
                                          str(kwargs[match]))
            else:
                print(
                    f'The parameter {match} was not generated by the model.')
        return new_url

    def _request_with_retry(self, method, url, **request_kwargs):
        """Send the request, retrying up to ``MAX_RETRY_TIMES`` on timeout.

        Returns:
            The result of ``self._parse_output`` on the decoded JSON body.

        Raises:
            ValueError: on any non-timeout request error, or when all
                attempts timed out.
        """
        for _ in range(MAX_RETRY_TIMES):
            try:
                response = requests.request(
                    method,
                    url=url,
                    headers=self.header,
                    timeout=getattr(self, 'timeout', None),
                    **request_kwargs)
                # raise_for_status() only raises for 4xx/5xx responses.
                response.raise_for_status()
            except Timeout:
                continue
            except RequestException as e:
                raise ValueError(self._describe_request_error(e)) from e
            origin_result = json.loads(response.content.decode('utf-8'))
            return self._parse_output(origin_result, remote=True)
        raise ValueError(
            'Remote call max retry times exceeded! Please try to use local call.'
        )

    @staticmethod
    def _describe_request_error(error):
        """Format a ``RequestException`` for the user.

        Connection-level errors carry no response, so ``error.response`` may
        be ``None``.
        """
        response = error.response
        if response is None:
            return f'Remote call failed with error: {error}'
        body = response.content.decode('utf-8', errors='replace')
        return (f'Remote call failed with error code: '
                f'{response.status_code}, error message: {body}')

    def _remote_parse_input(self, *args, **kwargs):
        """Fold dotted keyword names into nested dicts.

        For example ``{'a.b': 1, 'c': 2}`` becomes ``{'a': {'b': 1}, 'c': 2}``.
        Positional ``args`` are ignored.
        """
        restored_dict = {}
        for key, value in kwargs.items():
            if '.' in key:
                # Split keys by "." and create nested dictionary structures
                keys = key.split('.')
                temp_dict = restored_dict
                for k in keys[:-1]:
                    temp_dict = temp_dict.setdefault(k, {})
                temp_dict[keys[-1]] = value
            else:
                # If the key does not contain ".", store the pair directly
                restored_dict[key] = value
        kwargs = restored_dict
        print('传给tool的参数:', kwargs)
        return kwargs


# openapi_schema_convert: turn an OpenAPI schema into entries for
# tool_config.json
def extract_references(schema_content):
    """Return every ``$ref`` value found anywhere inside ``schema_content``.

    Dicts and lists are walked recursively, in document order.
    """
    references = []
    if isinstance(schema_content, dict):
        if '$ref' in schema_content:
            references.append(schema_content['$ref'])
        for value in schema_content.values():
            references.extend(extract_references(value))
    elif isinstance(schema_content, list):
        for item in schema_content:
            references.extend(extract_references(item))
    return references


def parse_nested_parameters(param_name, param_info, parameters_list, content):
    """Append flattened parameter entries for ``param_name``.

    Object parameters that declare ``properties`` are expanded into dotted
    names (``parent.child``), recursing into nested objects. All other
    parameters, including objects without declared properties, are appended
    as a single entry.

    Args:
        param_name: (possibly dotted) parameter name.
        param_info: JSON-schema dict of the parameter; must contain ``type``.
        parameters_list: list that receives the flattened entries.
        content: the top-level object schema. Whether an entry is required is
            decided by its top-level name appearing in ``content['required']``
            (a missing ``required`` list means nothing is required).

    Raises:
        KeyError: if ``param_info`` has no ``type``.
        ValueError: if a nested property is malformed.
    """
    param_type = param_info['type']
    # Adjust the default description as needed.
    param_description = param_info.get('description',
                                       f'用户输入的{param_name}')
    required_names = content.get('required', [])
    top_level_required = param_name.split('.')[0] in required_names
    try:
        properties = (
            param_info.get('properties') if param_type == 'object' else None)
        if not properties:
            # Non-nested parameters (and free-form objects) are added directly
            # to the parameter list.
            parameters_list.append({
                'name': param_name,
                'description': param_description,
                'required': top_level_required,
                'type': param_type,
                'value': param_info.get('enum', '')
            })
            return
        # An object with a non-empty "properties" field: parse its
        # properties recursively.
        for inner_param_name, inner_param_info in properties.items():
            inner_param_type = inner_param_info['type']
            full_name = f'{param_name}.{inner_param_name}'
            if inner_param_type == 'object':
                parse_nested_parameters(full_name, inner_param_info,
                                        parameters_list, content)
            else:
                parameters_list.append({
                    'name': full_name,
                    'description': inner_param_info.get(
                        'description',
                        f'用户输入的{param_name}.{inner_param_name}'),
                    'required': top_level_required,
                    'type': inner_param_type,
                    'value': inner_param_info.get('enum', '')
                })
    except Exception as e:
        raise ValueError(f'{e}:schema结构出错') from e


def parse_responses_parameters(param_name, param_info, parameters_list):
    """Append flattened entries describing a response field.

    Object fields that declare ``properties`` are expanded one level into
    ``parent.child`` entries. Any other field, including an object without
    declared properties, is appended as a single entry.

    Raises:
        KeyError: if ``param_info`` has no ``type``.
        ValueError: if a nested property is malformed.
    """
    param_type = param_info['type']
    # Adjust the default description as needed.
    param_description = param_info.get('description',
                                       f'调用api返回的{param_name}')
    try:
        properties = (
            param_info.get('properties') if param_type == 'object' else None)
        if not properties:
            # Non-nested parameters are added directly to the parameter list
            parameters_list.append({
                'name': param_name,
                'description': param_description,
                'type': param_type,
            })
            return
        for inner_param_name, inner_param_info in properties.items():
            parameters_list.append({
                'name': f'{param_name}.{inner_param_name}',
                'description': inner_param_info.get(
                    'description',
                    f'调用api返回的{param_name}.{inner_param_name}'),
                'type': inner_param_info['type'],
            })
    except Exception as e:
        raise ValueError(f'{e}:schema结构出错') from e


def _resolve_auth_value(auth, key):
    """Return ``auth[key]``, falling back to the ``key`` environment variable.

    The environment is only consulted when ``auth`` lacks the key, so a
    caller-supplied value works even when the variable is unset.

    Raises:
        KeyError: if neither ``auth`` nor the environment provides ``key``.
    """
    if key in auth:
        return auth[key]
    return os.environ[key]


def _bearer_authorization(security, auth):
    """Build the Authorization header value for an operation's ``security``.

    Returns ``''`` unless one of the requirements names ``BearerAuth``.
    """
    authorization = ''
    for sec in security or []:
        if 'BearerAuth' in sec:
            api_token = _resolve_auth_value(auth, 'apikey')
            api_token_type = _resolve_auth_value(auth, 'apikey_type')
            authorization = f'{api_token_type} {api_token}'
    return authorization


def _collect_body_parameters(schema_content, resolver):
    """Flatten a request-body schema into tool parameter entries.

    When the schema contains ``$ref``s, each one is resolved and, as before,
    the last resolved schema defines the parameters. An inline schema (no
    ``$ref``) is parsed directly from its ``properties``.
    """
    references = extract_references(schema_content)
    if references:
        candidates = [resolver.resolve(ref)[1] for ref in references]
    else:
        candidates = [schema_content]
    parameters_list = []
    for content in candidates:
        parameters_list = []
        for param_name, param_info in content.get('properties', {}).items():
            parse_nested_parameters(param_name, param_info, parameters_list,
                                    content)
    return parameters_list


def _post_parameters_and_header(details, resolver, authorization):
    """Return ``(parameters, header)`` for a POST operation.

    When several request content types are declared, the last one wins (as
    before). A missing or empty ``requestBody``/``content`` yields no
    parameters and a JSON content type.
    """
    parameters_list = []
    header = {
        'Content-Type': 'application/json',
        'Authorization': authorization
    }
    request_body = details.get('requestBody', {})
    if not request_body:
        return parameters_list, header
    is_async = request_body.get('X-DashScope-Async', '') != ''
    for content_type, content_details in request_body.get('content',
                                                          {}).items():
        parameters_list = _collect_body_parameters(
            content_details.get('schema', {}), resolver)
        header = {'Content-Type': content_type, 'Authorization': authorization}
        if is_async:
            header['X-DashScope-Async'] = 'enable'
    return parameters_list, header


def _make_config_entry(name, description, url, method, parameters, header):
    """Assemble one tool config entry (key order matches tool_config.json)."""
    return {
        'name': name,
        'description': description,
        'is_active': True,
        'is_remote_tool': True,
        'url': url,
        'method': method,
        'parameters': parameters,
        'header': header
    }


def openapi_schema_convert(schema, auth):
    """Convert an OpenAPI schema into tool config entries.

    Args:
        schema: the OpenAPI document as a dict.
        auth: mapping that may provide ``apikey`` and ``apikey_type`` for
            operations secured with ``BearerAuth``; missing values are read
            from environment variables of the same name.

    Returns:
        Dict mapping each operation's ``summary`` (spaces replaced by ``_``)
        to a config entry usable by ``OpenAPIPluginTool``.

    Raises:
        ValueError: for an operation whose method is neither GET nor POST.
        KeyError: if a BearerAuth operation's credentials are in neither
            ``auth`` nor the environment.
    """
    resolver = RefResolver.from_schema(schema)
    servers = schema.get('servers', [])
    if servers:
        servers_url = servers[0].get('url') or ''
    else:
        # Fall back to the bare path instead of a NameError further down.
        servers_url = ''
        print('No URL found in the schema.')

    # Extract endpoints
    endpoints = schema.get('paths', {})
    description = schema.get('info', {}).get('description',
                                             'This is a api tool that ...')
    config_data = {}

    # Iterate over each endpoint and its operations
    for endpoint_path, methods in endpoints.items():
        for method, details in methods.items():
            # Skip path-level fields and extensions; they are not operations.
            if (method in _PATH_ITEM_NON_OPERATION_KEYS
                    or method.startswith('x-')):
                continue
            summary = details.get('summary', 'No summary').replace(' ', '_')
            name = details.get('operationId', 'No operationId')
            url = f'{servers_url}{endpoint_path}'
            # Security (Bearer Token)
            authorization = _bearer_authorization(
                details.get('security', [{}]), auth)

            method_upper = method.upper()
            if method_upper == 'POST':
                parameters_list, header = _post_parameters_and_header(
                    details, resolver, authorization)
            elif method_upper == 'GET':
                parameters_list = details.get('parameters', [])
                header = {'Authorization': authorization}
            else:
                raise ValueError('method is not POST or GET')

            config_data[summary] = _make_config_entry(
                name, description, url, method_upper, parameters_list, header)
    return config_data