Spaces:
Build error
Build error
import json | |
import logging | |
from typing import Optional | |
from httpx import get | |
from core.model_runtime.utils.encoders import jsonable_encoder | |
from core.tools.entities.api_entities import UserTool, UserToolProvider | |
from core.tools.entities.common_entities import I18nObject | |
from core.tools.entities.tool_bundle import ApiToolBundle | |
from core.tools.entities.tool_entities import ( | |
ApiProviderAuthType, | |
ApiProviderSchemaType, | |
ToolCredentialsOption, | |
ToolProviderCredentials, | |
) | |
from core.tools.provider.api_tool_provider import ApiToolProviderController | |
from core.tools.tool_label_manager import ToolLabelManager | |
from core.tools.tool_manager import ToolManager | |
from core.tools.utils.configuration import ToolConfigurationManager | |
from core.tools.utils.parser import ApiBasedToolSchemaParser | |
from extensions.ext_database import db | |
from models.tools import ApiToolProvider | |
from services.tools.tools_transform_service import ToolTransformService | |
logger = logging.getLogger(__name__) | |
class ApiToolManageService: | |
def parser_api_schema(schema: str) -> list[ApiToolBundle]: | |
""" | |
parse api schema to tool bundle | |
""" | |
try: | |
warnings = {} | |
try: | |
tool_bundles, schema_type = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, warning=warnings) | |
except Exception as e: | |
raise ValueError(f"invalid schema: {str(e)}") | |
credentials_schema = [ | |
ToolProviderCredentials( | |
name="auth_type", | |
type=ToolProviderCredentials.CredentialsType.SELECT, | |
required=True, | |
default="none", | |
options=[ | |
ToolCredentialsOption(value="none", label=I18nObject(en_US="None", zh_Hans="无")), | |
ToolCredentialsOption(value="api_key", label=I18nObject(en_US="Api Key", zh_Hans="Api Key")), | |
], | |
placeholder=I18nObject(en_US="Select auth type", zh_Hans="选择认证方式"), | |
), | |
ToolProviderCredentials( | |
name="api_key_header", | |
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, | |
required=False, | |
placeholder=I18nObject(en_US="Enter api key header", zh_Hans="输入 api key header,如:X-API-KEY"), | |
default="api_key", | |
help=I18nObject(en_US="HTTP header name for api key", zh_Hans="HTTP 头部字段名,用于传递 api key"), | |
), | |
ToolProviderCredentials( | |
name="api_key_value", | |
type=ToolProviderCredentials.CredentialsType.TEXT_INPUT, | |
required=False, | |
placeholder=I18nObject(en_US="Enter api key", zh_Hans="输入 api key"), | |
default="", | |
), | |
] | |
return jsonable_encoder( | |
{ | |
"schema_type": schema_type, | |
"parameters_schema": tool_bundles, | |
"credentials_schema": credentials_schema, | |
"warning": warnings, | |
} | |
) | |
except Exception as e: | |
raise ValueError(f"invalid schema: {str(e)}") | |
def convert_schema_to_tool_bundles( | |
schema: str, extra_info: Optional[dict] = None | |
) -> tuple[list[ApiToolBundle], str]: | |
""" | |
convert schema to tool bundles | |
:return: the list of tool bundles, description | |
""" | |
try: | |
tool_bundles = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema, extra_info=extra_info) | |
return tool_bundles | |
except Exception as e: | |
raise ValueError(f"invalid schema: {str(e)}") | |
def create_api_tool_provider( | |
user_id: str, | |
tenant_id: str, | |
provider_name: str, | |
icon: dict, | |
credentials: dict, | |
schema_type: str, | |
schema: str, | |
privacy_policy: str, | |
custom_disclaimer: str, | |
labels: list[str], | |
): | |
""" | |
create api tool provider | |
""" | |
if schema_type not in [member.value for member in ApiProviderSchemaType]: | |
raise ValueError(f"invalid schema type {schema}") | |
# check if the provider exists | |
provider: ApiToolProvider = ( | |
db.session.query(ApiToolProvider) | |
.filter( | |
ApiToolProvider.tenant_id == tenant_id, | |
ApiToolProvider.name == provider_name, | |
) | |
.first() | |
) | |
if provider is not None: | |
raise ValueError(f"provider {provider_name} already exists") | |
# parse openapi to tool bundle | |
extra_info = {} | |
# extra info like description will be set here | |
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) | |
if len(tool_bundles) > 100: | |
raise ValueError("the number of apis should be less than 100") | |
# create db provider | |
db_provider = ApiToolProvider( | |
tenant_id=tenant_id, | |
user_id=user_id, | |
name=provider_name, | |
icon=json.dumps(icon), | |
schema=schema, | |
description=extra_info.get("description", ""), | |
schema_type_str=schema_type, | |
tools_str=json.dumps(jsonable_encoder(tool_bundles)), | |
credentials_str={}, | |
privacy_policy=privacy_policy, | |
custom_disclaimer=custom_disclaimer, | |
) | |
if "auth_type" not in credentials: | |
raise ValueError("auth_type is required") | |
# get auth type, none or api key | |
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) | |
# create provider entity | |
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) | |
# load tools into provider entity | |
provider_controller.load_bundled_tools(tool_bundles) | |
# encrypt credentials | |
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |
encrypted_credentials = tool_configuration.encrypt_tool_credentials(credentials) | |
db_provider.credentials_str = json.dumps(encrypted_credentials) | |
db.session.add(db_provider) | |
db.session.commit() | |
# update labels | |
ToolLabelManager.update_tool_labels(provider_controller, labels) | |
return {"result": "success"} | |
def get_api_tool_provider_remote_schema(user_id: str, tenant_id: str, url: str): | |
""" | |
get api tool provider remote schema | |
""" | |
headers = { | |
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko)" | |
" Chrome/120.0.0.0 Safari/537.36 Edg/120.0.0.0", | |
"Accept": "*/*", | |
} | |
try: | |
response = get(url, headers=headers, timeout=10) | |
if response.status_code != 200: | |
raise ValueError(f"Got status code {response.status_code}") | |
schema = response.text | |
# try to parse schema, avoid SSRF attack | |
ApiToolManageService.parser_api_schema(schema) | |
except Exception as e: | |
logger.error(f"parse api schema error: {str(e)}") | |
raise ValueError("invalid schema, please check the url you provided") | |
return {"schema": schema} | |
def list_api_tool_provider_tools(user_id: str, tenant_id: str, provider: str) -> list[UserTool]: | |
""" | |
list api tool provider tools | |
""" | |
provider: ApiToolProvider = ( | |
db.session.query(ApiToolProvider) | |
.filter( | |
ApiToolProvider.tenant_id == tenant_id, | |
ApiToolProvider.name == provider, | |
) | |
.first() | |
) | |
if provider is None: | |
raise ValueError(f"you have not added provider {provider}") | |
controller = ToolTransformService.api_provider_to_controller(db_provider=provider) | |
labels = ToolLabelManager.get_tool_labels(controller) | |
return [ | |
ToolTransformService.tool_to_user_tool( | |
tool_bundle, | |
labels=labels, | |
) | |
for tool_bundle in provider.tools | |
] | |
def update_api_tool_provider( | |
user_id: str, | |
tenant_id: str, | |
provider_name: str, | |
original_provider: str, | |
icon: dict, | |
credentials: dict, | |
schema_type: str, | |
schema: str, | |
privacy_policy: str, | |
custom_disclaimer: str, | |
labels: list[str], | |
): | |
""" | |
update api tool provider | |
""" | |
if schema_type not in [member.value for member in ApiProviderSchemaType]: | |
raise ValueError(f"invalid schema type {schema}") | |
# check if the provider exists | |
provider: ApiToolProvider = ( | |
db.session.query(ApiToolProvider) | |
.filter( | |
ApiToolProvider.tenant_id == tenant_id, | |
ApiToolProvider.name == original_provider, | |
) | |
.first() | |
) | |
if provider is None: | |
raise ValueError(f"api provider {provider_name} does not exists") | |
# parse openapi to tool bundle | |
extra_info = {} | |
# extra info like description will be set here | |
tool_bundles, schema_type = ApiToolManageService.convert_schema_to_tool_bundles(schema, extra_info) | |
# update db provider | |
provider.name = provider_name | |
provider.icon = json.dumps(icon) | |
provider.schema = schema | |
provider.description = extra_info.get("description", "") | |
provider.schema_type_str = ApiProviderSchemaType.OPENAPI.value | |
provider.tools_str = json.dumps(jsonable_encoder(tool_bundles)) | |
provider.privacy_policy = privacy_policy | |
provider.custom_disclaimer = custom_disclaimer | |
if "auth_type" not in credentials: | |
raise ValueError("auth_type is required") | |
# get auth type, none or api key | |
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) | |
# create provider entity | |
provider_controller = ApiToolProviderController.from_db(provider, auth_type) | |
# load tools into provider entity | |
provider_controller.load_bundled_tools(tool_bundles) | |
# get original credentials if exists | |
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |
original_credentials = tool_configuration.decrypt_tool_credentials(provider.credentials) | |
masked_credentials = tool_configuration.mask_tool_credentials(original_credentials) | |
# check if the credential has changed, save the original credential | |
for name, value in credentials.items(): | |
if name in masked_credentials and value == masked_credentials[name]: | |
credentials[name] = original_credentials[name] | |
credentials = tool_configuration.encrypt_tool_credentials(credentials) | |
provider.credentials_str = json.dumps(credentials) | |
db.session.add(provider) | |
db.session.commit() | |
# delete cache | |
tool_configuration.delete_tool_credentials_cache() | |
# update labels | |
ToolLabelManager.update_tool_labels(provider_controller, labels) | |
return {"result": "success"} | |
def delete_api_tool_provider(user_id: str, tenant_id: str, provider_name: str): | |
""" | |
delete tool provider | |
""" | |
provider: ApiToolProvider = ( | |
db.session.query(ApiToolProvider) | |
.filter( | |
ApiToolProvider.tenant_id == tenant_id, | |
ApiToolProvider.name == provider_name, | |
) | |
.first() | |
) | |
if provider is None: | |
raise ValueError(f"you have not added provider {provider_name}") | |
db.session.delete(provider) | |
db.session.commit() | |
return {"result": "success"} | |
def get_api_tool_provider(user_id: str, tenant_id: str, provider: str): | |
""" | |
get api tool provider | |
""" | |
return ToolManager.user_get_api_provider(provider=provider, tenant_id=tenant_id) | |
def test_api_tool_preview( | |
tenant_id: str, | |
provider_name: str, | |
tool_name: str, | |
credentials: dict, | |
parameters: dict, | |
schema_type: str, | |
schema: str, | |
): | |
""" | |
test api tool before adding api tool provider | |
""" | |
if schema_type not in [member.value for member in ApiProviderSchemaType]: | |
raise ValueError(f"invalid schema type {schema_type}") | |
try: | |
tool_bundles, _ = ApiBasedToolSchemaParser.auto_parse_to_tool_bundle(schema) | |
except Exception as e: | |
raise ValueError("invalid schema") | |
# get tool bundle | |
tool_bundle = next(filter(lambda tb: tb.operation_id == tool_name, tool_bundles), None) | |
if tool_bundle is None: | |
raise ValueError(f"invalid tool name {tool_name}") | |
db_provider: ApiToolProvider = ( | |
db.session.query(ApiToolProvider) | |
.filter( | |
ApiToolProvider.tenant_id == tenant_id, | |
ApiToolProvider.name == provider_name, | |
) | |
.first() | |
) | |
if not db_provider: | |
# create a fake db provider | |
db_provider = ApiToolProvider( | |
tenant_id="", | |
user_id="", | |
name="", | |
icon="", | |
schema=schema, | |
description="", | |
schema_type_str=ApiProviderSchemaType.OPENAPI.value, | |
tools_str=json.dumps(jsonable_encoder(tool_bundles)), | |
credentials_str=json.dumps(credentials), | |
) | |
if "auth_type" not in credentials: | |
raise ValueError("auth_type is required") | |
# get auth type, none or api key | |
auth_type = ApiProviderAuthType.value_of(credentials["auth_type"]) | |
# create provider entity | |
provider_controller = ApiToolProviderController.from_db(db_provider, auth_type) | |
# load tools into provider entity | |
provider_controller.load_bundled_tools(tool_bundles) | |
# decrypt credentials | |
if db_provider.id: | |
tool_configuration = ToolConfigurationManager(tenant_id=tenant_id, provider_controller=provider_controller) | |
decrypted_credentials = tool_configuration.decrypt_tool_credentials(credentials) | |
# check if the credential has changed, save the original credential | |
masked_credentials = tool_configuration.mask_tool_credentials(decrypted_credentials) | |
for name, value in credentials.items(): | |
if name in masked_credentials and value == masked_credentials[name]: | |
credentials[name] = decrypted_credentials[name] | |
try: | |
provider_controller.validate_credentials_format(credentials) | |
# get tool | |
tool = provider_controller.get_tool(tool_name) | |
tool = tool.fork_tool_runtime( | |
runtime={ | |
"credentials": credentials, | |
"tenant_id": tenant_id, | |
} | |
) | |
result = tool.validate_credentials(credentials, parameters) | |
except Exception as e: | |
return {"error": str(e)} | |
return {"result": result or "empty response"} | |
def list_api_tools(user_id: str, tenant_id: str) -> list[UserToolProvider]: | |
""" | |
list api tools | |
""" | |
# get all api providers | |
db_providers: list[ApiToolProvider] = ( | |
db.session.query(ApiToolProvider).filter(ApiToolProvider.tenant_id == tenant_id).all() or [] | |
) | |
result: list[UserToolProvider] = [] | |
for provider in db_providers: | |
# convert provider controller to user provider | |
provider_controller = ToolTransformService.api_provider_to_controller(db_provider=provider) | |
labels = ToolLabelManager.get_tool_labels(provider_controller) | |
user_provider = ToolTransformService.api_provider_to_user_provider( | |
provider_controller, db_provider=provider, decrypt_credentials=True | |
) | |
user_provider.labels = labels | |
# add icon | |
ToolTransformService.repack_provider(user_provider) | |
tools = provider_controller.get_tools(user_id=user_id, tenant_id=tenant_id) | |
for tool in tools: | |
user_provider.tools.append( | |
ToolTransformService.tool_to_user_tool( | |
tenant_id=tenant_id, tool=tool, credentials=user_provider.original_credentials, labels=labels | |
) | |
) | |
result.append(user_provider) | |
return result | |