Spaces:
Running
Running
import os | |
import re | |
from typing import List, Optional | |
import json | |
import requests | |
from jsonschema import RefResolver | |
from pydantic import BaseModel, ValidationError | |
from requests.exceptions import RequestException, Timeout | |
from .tool import Tool | |
# Maximum number of attempts made for one remote call. Only timeouts are
# retried; any other request error aborts the call immediately.
MAX_RETRY_TIMES = 3
class ParametersSchema(BaseModel):
    """Validated description of a single tool parameter."""

    # Parameter name as exposed to the model.
    name: str
    # Human-readable explanation of the parameter.
    description: str
    # Whether the parameter must be supplied; defaults to required.
    required: Optional[bool] = True
class ToolSchema(BaseModel):
    """Validated description of a tool: its name, purpose and parameters."""

    # Tool name.
    name: str
    # Human-readable explanation of what the tool does.
    description: str
    # One entry per parameter the tool accepts.
    parameters: List[ParametersSchema]
class OpenAPIPluginTool(Tool):
    """Tool that forwards a call to a remote HTTP endpoint.

    The per-tool settings are read from ``cfg[name]``. Recognised keys are
    ``is_remote_tool``, ``url``, ``token``, ``header``, ``method`` (``'POST'``
    or ``'GET'``), ``parameters``, ``description`` and ``responses_param``.
    """
    name: str = 'api tool'
    description: str = 'This is a api tool that ...'
    parameters: list = []

    def __init__(self, cfg, name):
        """Load the tool settings and validate its parameter schema.

        Args:
            cfg: Mapping of tool name to that tool's settings dict.
            name: Key of this tool inside ``cfg``; also used as the tool name.

        Raises:
            ValueError: If name/description/parameters do not validate
                against ``ToolSchema``.
        """
        self.name = name
        self.cfg = cfg.get(self.name, {})
        self.is_remote_tool = self.cfg.get('is_remote_tool', False)
        # remote call
        self.url = self.cfg.get('url', '')
        # Stored for completeness; not read by the methods defined here.
        self.token = self.cfg.get('token', '')
        self.header = self.cfg.get('header', '')
        self.method = self.cfg.get('method', '')
        self.parameters = self.cfg.get('parameters', [])
        self.description = self.cfg.get('description',
                                        'This is a api tool that ...')
        self.responses_param = self.cfg.get('responses_param', [])

        all_para = {
            'name': self.name,
            'description': self.description,
            'parameters': self.parameters
        }
        try:
            self.tool_schema = ToolSchema(**all_para)
        except ValidationError as e:
            # Chain the cause so the offending field stays visible.
            raise ValueError(
                f'Error when parsing parameters of {self.name}') from e
        self._str = self.tool_schema.model_dump_json()
        # presumably inherited from Tool -- not defined in this class
        self._function = self.parse_pydantic_model_to_openai_function(all_para)

    @staticmethod
    def _format_request_error(err):
        """Build the error text for a failed request.

        ``err.response`` is ``None`` when no HTTP response was received
        (connection refused, invalid URL, ...). The comparison must be
        ``is None`` because a ``requests.Response`` with an error status is
        falsy.
        """
        response = getattr(err, 'response', None)
        if response is None:
            return f'Remote call failed with error: {err}'
        return (f'Remote call failed with error code: {response.status_code},'
                f'error message: {response.content.decode("utf-8")}')

    def _request_with_retry(self, method, url, debug_lines, **payload):
        """Send one request, retrying on timeout up to MAX_RETRY_TIMES times.

        Args:
            method: HTTP verb passed to ``requests.request``.
            url: Target URL.
            debug_lines: Strings printed before every attempt.
            **payload: Extra keyword arguments for ``requests.request``
                (``data=...`` or ``params=...``).

        Returns:
            Whatever ``self._parse_output(decoded_json, remote=True)`` returns.

        Raises:
            ValueError: On any non-timeout request error, or when every
                attempt timed out.
        """
        for _ in range(MAX_RETRY_TIMES):
            try:
                for line in debug_lines:
                    print(line)
                response = requests.request(
                    method, url=url, headers=self.header, **payload)
                if response.status_code != requests.codes.ok:
                    response.raise_for_status()
                origin_result = json.loads(response.content.decode('utf-8'))
                # presumably provided by the Tool base class
                return self._parse_output(origin_result, remote=True)
            except Timeout:
                # Only timeouts are retried.
                continue
            except RequestException as e:
                raise ValueError(self._format_request_error(e)) from e
        raise ValueError(
            'Remote call max retry times exceeded! Please try to use local call.'
        )

    def _remote_call(self, *args, **kwargs):
        """Call the configured remote endpoint with the given arguments.

        Keyword arguments are regrouped by ``_remote_parse_input`` and
        serialised to JSON. For ``GET``, ``{placeholder}`` segments of the URL
        are filled from the keyword arguments of the same name.

        Raises:
            ValueError: If no URL is configured, the method is neither
                ``'POST'`` nor ``'GET'``, the request fails, or all attempts
                time out.
        """
        if self.url == '':
            raise ValueError(
                f"Could not use remote call for {self.name} since this tool doesn't have a remote endpoint"
            )

        remote_parsed_input = json.dumps(
            self._remote_parse_input(*args, **kwargs))

        if self.method == 'POST':
            # NOTE(review): the header may carry an Authorization token and
            # is printed to stdout on every attempt.
            debug_lines = (f'data: {kwargs}', f'header: {self.header}')
            return self._request_with_retry(
                'POST', self.url, debug_lines, data=remote_parsed_input)

        if self.method == 'GET':
            new_url = self.url
            for match in re.findall(r'\{(.*?)\}', self.url):
                if match in kwargs:
                    # str() so numeric path parameters do not raise TypeError.
                    new_url = new_url.replace('{' + match + '}',
                                              str(kwargs[match]))
                else:
                    print(
                        f'The parameter {match} was not generated by the model.'
                    )
            debug_lines = (f'GET: {new_url}', f'GET: {self.url}')
            # NOTE(review): params receives the JSON *string*, which requests
            # appends verbatim as the query string -- confirm this is what
            # the remote endpoints expect before changing it to a dict.
            return self._request_with_retry(
                'GET', new_url, debug_lines, params=remote_parsed_input)

        raise ValueError(
            'Remote call method is invalid!We have POST and GET method.')

    def _remote_parse_input(self, *args, **kwargs):
        """Rebuild nested dictionaries from dotted keyword names.

        ``{'a.b': 1, 'c': 2}`` becomes ``{'a': {'b': 1}, 'c': 2}``.
        Positional arguments are ignored.
        """
        restored_dict = {}
        for key, value in kwargs.items():
            # A key without '.' has no parents and is stored at the top level.
            *parents, leaf = key.split('.')
            target = restored_dict
            for parent in parents:
                target = target.setdefault(parent, {})
            target[leaf] = value
        print('传给tool的参数:', restored_dict)
        return restored_dict
# OpenAPI schema conversion helpers: turn a schema into tool config entries to be registered in tool_config.json.
def extract_references(schema_content):
    """Collect every ``$ref`` value found anywhere inside a schema fragment.

    Dicts and lists are walked depth-first; a dict's own ``$ref`` is reported
    before anything found in its values. Any other type yields no references.

    Returns:
        list: The ``$ref`` values in traversal order.
    """
    if isinstance(schema_content, dict):
        found = [schema_content['$ref']] if '$ref' in schema_content else []
        children = schema_content.values()
    elif isinstance(schema_content, list):
        found = []
        children = schema_content
    else:
        return []

    for child in children:
        found += extract_references(child)
    return found
def parse_nested_parameters(param_name, param_info, parameters_list, content):
    """Flatten one request-body property into ``parameters_list``.

    Non-object properties are appended as one entry. Object properties with
    a non-empty ``properties`` mapping are expanded recursively, naming each
    leaf with a dotted path (``outer.inner``). Object properties without
    ``properties`` add nothing.

    Args:
        param_name: Property name (dotted when called recursively).
        param_info: Schema dict of the property; must contain ``'type'``.
        parameters_list: List that receives the entries (mutated in place).
        content: Schema that owns the top-level property; its optional
            ``'required'`` list decides the ``required`` flag.

    Raises:
        KeyError: If ``param_info`` has no ``'type'``.
        ValueError: If a nested property is malformed.
    """
    param_type = param_info['type']
    # Change the default description as needed.
    param_description = param_info.get('description',
                                       f'用户输入的{param_name}')
    # 'required' is optional in a schema object; a missing list means no
    # property is required.
    required_names = content.get('required') or []
    param_required = param_name in required_names

    try:
        if param_type == 'object':
            properties = param_info.get('properties')
            if properties:
                # Nested leaves take their required flag from the top-level
                # ancestor (the first segment of the dotted name).
                inner_param_required = param_name.split(
                    '.')[0] in required_names
                for inner_param_name, inner_param_info in properties.items():
                    inner_param_type = inner_param_info['type']
                    inner_param_description = inner_param_info.get(
                        'description', f'用户输入的{param_name}.{inner_param_name}')
                    if inner_param_type == 'object':
                        # Recurse to handle deeper nesting.
                        parse_nested_parameters(
                            f'{param_name}.{inner_param_name}',
                            inner_param_info, parameters_list, content)
                    else:
                        parameters_list.append({
                            'name': f'{param_name}.{inner_param_name}',
                            'description': inner_param_description,
                            'required': inner_param_required,
                            'type': inner_param_type,
                            'value': inner_param_info.get('enum', '')
                        })
        else:
            # Non-nested parameters are added directly to the parameter list.
            parameters_list.append({
                'name': param_name,
                'description': param_description,
                'required': param_required,
                'type': param_type,
                'value': param_info.get('enum', '')
            })
    except Exception as e:
        raise ValueError(f'{e}:schema结构出错') from e
def parse_responses_parameters(param_name, param_info, parameters_list):
    """Append response-field descriptions for one property.

    A non-object property yields one entry. An object property yields one
    entry per direct child in ``properties`` (named ``parent.child``, not
    expanded further); an object without ``properties`` yields nothing.

    Args:
        param_name: Name of the response property.
        param_info: Schema dict of the property; must contain ``'type'``.
        parameters_list: List that receives the entries (mutated in place).

    Raises:
        KeyError: If ``param_info`` has no ``'type'``.
        ValueError: If a child property is malformed.
    """
    kind = param_info['type']
    # Change the default description as needed.
    summary = param_info.get('description', f'调用api返回的{param_name}')

    try:
        if kind != 'object':
            # Plain fields go straight into the list.
            parameters_list.append({
                'name': param_name,
                'description': summary,
                'type': kind,
            })
            return

        children = param_info.get('properties')
        if not children:
            return
        for child_name, child_info in children.items():
            child_kind = child_info['type']
            child_summary = child_info.get(
                'description', f'调用api返回的{param_name}.{child_name}')
            parameters_list.append({
                'name': f'{param_name}.{child_name}',
                'description': child_summary,
                'type': child_kind,
            })
    except Exception as e:
        raise ValueError(f'{e}:schema结构出错')
def _resolve_authorization(security, auth):
    """Return the Authorization header value for an operation.

    The value is ``'<apikey_type> <apikey>'`` when a ``BearerAuth``
    requirement is present, otherwise an empty string. Each credential comes
    from ``auth`` when present and from the environment variable of the same
    name otherwise; the environment is only consulted when needed.
    """
    authorization = ''
    for sec in security or []:
        if 'BearerAuth' in sec:
            api_token = (auth['apikey']
                         if 'apikey' in auth else os.environ['apikey'])
            api_token_type = (auth['apikey_type'] if 'apikey_type' in auth
                              else os.environ['apikey_type'])
            authorization = f'{api_token_type} {api_token}'
    return authorization


def _post_parameters_and_header(request_body, resolver, authorization):
    """Derive the parameter list and request header for a POST operation.

    Every ``$ref`` found in a content schema is resolved and its properties
    flattened; an inline schema with ``properties`` and no ``$ref`` is parsed
    directly. When several schemas apply, the last one wins. Without any
    usable schema the result is an empty parameter list and a JSON
    content type.
    """
    parameters = []
    header = {
        'Content-Type': 'application/json',
        'Authorization': authorization
    }
    if not request_body:
        return parameters, header

    async_enabled = request_body.get('X-DashScope-Async', '') != ''
    for content_type, content_details in request_body.get('content',
                                                          {}).items():
        schema_content = content_details.get('schema', {})
        body_schemas = [
            resolver.resolve(reference)[1]
            for reference in extract_references(schema_content)
        ]
        if (not body_schemas and isinstance(schema_content, dict)
                and 'properties' in schema_content):
            # Inline request schema without any $ref.
            body_schemas = [schema_content]

        for content in body_schemas:
            parameters = []
            for param_name, param_info in content['properties'].items():
                parse_nested_parameters(param_name, param_info, parameters,
                                        content)
            header = {
                'Content-Type': content_type,
                'Authorization': authorization
            }
            if async_enabled:
                header['X-DashScope-Async'] = 'enable'
    return parameters, header


def openapi_schema_convert(schema, auth):
    """Convert an OpenAPI schema into tool configuration entries.

    Args:
        schema: Parsed OpenAPI document (dict). The first entry of
            ``servers`` supplies the base URL.
        auth: Mapping that may contain ``apikey`` and ``apikey_type``;
            missing values are read from environment variables of the same
            name when an operation requires ``BearerAuth``.

    Returns:
        dict: One config entry per operation, keyed by the operation summary
        with spaces replaced by underscores. Operations that share a summary
        overwrite each other.

    Raises:
        ValueError: If an operation uses a method other than POST or GET.
        KeyError: If a required credential is in neither ``auth`` nor the
            environment.
    """
    resolver = RefResolver.from_schema(schema)
    servers = schema.get('servers', [])
    # Default to an empty base URL so a schema without servers does not fail
    # with an unbound variable at the first endpoint.
    servers_url = ''
    if servers:
        servers_url = servers[0].get('url') or ''
    else:
        print('No URL found in the schema.')

    # Extract endpoints
    endpoints = schema.get('paths', {})
    description = schema.get('info', {}).get('description',
                                             'This is a api tool that ...')
    config_data = {}

    # Iterate over each endpoint and its contents
    for endpoint_path, methods in endpoints.items():
        for method, details in methods.items():
            summary = details.get('summary', 'No summary').replace(' ', '_')
            name = details.get('operationId', 'No operationId')
            url = f'{servers_url}{endpoint_path}'
            # Security (Bearer Token)
            authorization = _resolve_authorization(
                details.get('security', [{}]), auth)

            http_method = method.upper()
            if http_method == 'POST':
                parameters_list, header = _post_parameters_and_header(
                    details.get('requestBody', {}), resolver, authorization)
            elif http_method == 'GET':
                parameters_list = details.get('parameters', [])
                header = {'Authorization': authorization}
            else:
                raise ValueError('method is not POST or GET')

            config_data[summary] = {
                'name': name,
                'description': description,
                'is_active': True,
                'is_remote_tool': True,
                'url': url,
                'method': http_method,
                'parameters': parameters_list,
                'header': header
            }
    return config_data