Spaces:
Paused
Paused
| import logging | |
| from flask_login import current_user | |
| from flask_restful import Resource, reqparse | |
| from werkzeug.exceptions import Forbidden | |
| from controllers.console import api | |
| from controllers.console.wraps import account_initialization_required, setup_required | |
| from core.model_runtime.entities.model_entities import ModelType | |
| from core.model_runtime.errors.validate import CredentialsValidateFailedError | |
| from core.model_runtime.utils.encoders import jsonable_encoder | |
| from libs.login import login_required | |
| from services.model_load_balancing_service import ModelLoadBalancingService | |
| from services.model_provider_service import ModelProviderService | |
class DefaultModelApi(Resource):
    """Console resource for reading and updating the current tenant's default models."""

    # NOTE(review): the three guards below are imported at the top of this file
    # but were not applied anywhere; presumably they enforce installation,
    # authentication and account status - confirm order against sibling controllers.
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Return the default model registered for one model type.

        Query args:
            model_type: required, must be one of the ``ModelType`` values.

        Returns:
            JSON-encodable ``{"data": <default model entity>}`` for the
            current user's tenant.
        """
        parser = reqparse.RequestParser()
        parser.add_argument(
            "model_type",
            type=str,
            required=True,
            nullable=False,
            choices=[mt.value for mt in ModelType],
            location="args",
        )
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id

        model_provider_service = ModelProviderService()
        default_model_entity = model_provider_service.get_default_model_of_model_type(
            tenant_id=tenant_id, model_type=args["model_type"]
        )

        return jsonable_encoder({"data": default_model_entity})

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Update the default model for each entry of ``model_settings``.

        JSON body:
            model_settings: required list of objects. Each object needs a valid
                ``model_type``; entries without ``provider`` are skipped;
                entries with ``provider`` must also carry ``model``.

        Returns:
            ``{"result": "success"}`` once every entry has been processed.

        Raises:
            Forbidden: the current user is not an admin or owner.
            ValueError: an entry is not an object, has a missing/unknown
                ``model_type``, or names a provider without a model.
        """
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        parser = reqparse.RequestParser()
        parser.add_argument("model_settings", type=list, required=True, nullable=False, location="json")
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id
        model_provider_service = ModelProviderService()

        # Computed once instead of once per entry.
        valid_model_types = tuple(mt.value for mt in ModelType)

        # Entries are validated and applied one at a time, so a failure leaves
        # the earlier entries already saved.
        for model_setting in args["model_settings"]:
            # A non-object entry (number, string, ...) is malformed input, not a server error.
            if not isinstance(model_setting, dict):
                raise ValueError("invalid model type")

            if "model_type" not in model_setting or model_setting["model_type"] not in valid_model_types:
                raise ValueError("invalid model type")

            if "provider" not in model_setting:
                continue

            if "model" not in model_setting:
                raise ValueError("invalid model")

            try:
                model_provider_service.update_default_model_of_model_type(
                    tenant_id=tenant_id,
                    model_type=model_setting["model_type"],
                    provider=model_setting["provider"],
                    model=model_setting["model"],
                )
            except Exception as ex:
                # Log with traceback, then propagate the original exception untouched.
                logging.exception("%s save error: %s", model_setting["model_type"], ex)
                raise

        return {"result": "success"}
class ModelProviderModelApi(Resource):
    """Console resource for listing, saving and removing models of one provider."""

    # NOTE(review): the three guards below are imported at the top of this file
    # but were not applied anywhere; presumably they enforce installation,
    # authentication and account status - confirm order against sibling controllers.
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider):
        """Return ``{"data": <models>}`` for ``provider`` in the current tenant."""
        tenant_id = current_user.current_tenant_id

        model_provider_service = ModelProviderService()
        models = model_provider_service.get_models_by_provider(tenant_id=tenant_id, provider=provider)

        return jsonable_encoder({"data": models})

    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider: str):
        """Save load-balancing settings and (optionally) credentials for a model.

        JSON body:
            model: required model name.
            model_type: required, one of the ``ModelType`` values.
            credentials: optional object passed to ``save_model_credentials``.
            load_balancing: optional object; when its ``enabled`` entry is
                truthy, ``configs`` is required and load balancing is turned
                on, otherwise load balancing is turned off.
            config_from: optional; credentials are saved unless this equals
                ``"predefined-model"``.

        Returns:
            ``({"result": "success"}, 200)``.

        Raises:
            Forbidden: the current user is not an admin or owner.
            ValueError: load balancing is enabled without ``configs``, or the
                credentials fail validation.
        """
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
        parser.add_argument(
            "model_type",
            type=str,
            required=True,
            nullable=False,
            choices=[mt.value for mt in ModelType],
            location="json",
        )
        parser.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
        parser.add_argument("load_balancing", type=dict, required=False, nullable=True, location="json")
        parser.add_argument("config_from", type=str, required=False, nullable=True, location="json")
        args = parser.parse_args()

        model_load_balancing_service = ModelLoadBalancingService()

        # A missing, null or empty "load_balancing" object counts as "disabled".
        load_balancing = args.get("load_balancing")
        if load_balancing and load_balancing.get("enabled"):
            if "configs" not in load_balancing:
                raise ValueError("invalid load balancing configs")

            # save load balancing configs
            model_load_balancing_service.update_load_balancing_configs(
                tenant_id=tenant_id,
                provider=provider,
                model=args["model"],
                model_type=args["model_type"],
                configs=load_balancing["configs"],
            )

            # enable load balancing
            model_load_balancing_service.enable_model_load_balancing(
                tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
            )
        else:
            # disable load balancing
            model_load_balancing_service.disable_model_load_balancing(
                tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
            )

            # Credentials are only stored on the "disabled" path, and only when
            # the request does not come from a predefined model.
            if args.get("config_from", "") != "predefined-model":
                model_provider_service = ModelProviderService()

                try:
                    model_provider_service.save_model_credentials(
                        tenant_id=tenant_id,
                        provider=provider,
                        model=args["model"],
                        model_type=args["model_type"],
                        credentials=args["credentials"],
                    )
                except CredentialsValidateFailedError as ex:
                    logging.exception("save model credentials error: %s", ex)
                    # Surface validation failures as ValueError, keeping the cause attached.
                    raise ValueError(str(ex)) from ex

        return {"result": "success"}, 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, provider: str):
        """Remove the stored credentials of one model of ``provider``.

        JSON body:
            model: required model name.
            model_type: required, one of the ``ModelType`` values.

        Returns:
            ``({"result": "success"}, 204)``.

        Raises:
            Forbidden: the current user is not an admin or owner.
        """
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
        parser.add_argument(
            "model_type",
            type=str,
            required=True,
            nullable=False,
            choices=[mt.value for mt in ModelType],
            location="json",
        )
        args = parser.parse_args()

        model_provider_service = ModelProviderService()
        model_provider_service.remove_model_credentials(
            tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
        )

        # NOTE(review): a 204 response normally carries no body, so clients will
        # likely never see this payload - kept as-is for compatibility.
        return {"result": "success"}, 204
class ModelProviderModelCredentialApi(Resource):
    """Console resource exposing a model's credentials and load-balancing state."""

    # NOTE(review): the three guards below are imported at the top of this file
    # but were not applied anywhere; presumably they enforce installation,
    # authentication and account status - confirm order against sibling controllers.
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider: str):
        """Return credentials and load-balancing configs for one model.

        Query args:
            model: required model name.
            model_type: required, one of the ``ModelType`` values.

        Returns:
            ``{"credentials": ..., "load_balancing": {"enabled": ..., "configs": ...}}``
            built from the model-provider and load-balancing services.
        """
        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument("model", type=str, required=True, nullable=False, location="args")
        parser.add_argument(
            "model_type",
            type=str,
            required=True,
            nullable=False,
            choices=[mt.value for mt in ModelType],
            location="args",
        )
        args = parser.parse_args()

        model_provider_service = ModelProviderService()
        credentials = model_provider_service.get_model_credentials(
            tenant_id=tenant_id, provider=provider, model_type=args["model_type"], model=args["model"]
        )

        model_load_balancing_service = ModelLoadBalancingService()
        is_load_balancing_enabled, load_balancing_configs = model_load_balancing_service.get_load_balancing_configs(
            tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
        )

        return {
            "credentials": credentials,
            "load_balancing": {"enabled": is_load_balancing_enabled, "configs": load_balancing_configs},
        }
class ModelProviderModelEnableApi(Resource):
    """Console resource that enables one model of a provider."""

    # NOTE(review): the three guards below are imported at the top of this file
    # but were not applied anywhere; presumably they enforce installation,
    # authentication and account status - confirm order against sibling controllers.
    # NOTE(review): unlike ModelProviderModelApi.post there is no admin/owner
    # check here - confirm that any member may toggle models.
    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, provider: str):
        """Enable ``model`` of ``model_type`` for the current tenant.

        JSON body:
            model: required model name.
            model_type: required, one of the ``ModelType`` values.

        Returns:
            ``{"result": "success"}``.
        """
        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
        parser.add_argument(
            "model_type",
            type=str,
            required=True,
            nullable=False,
            choices=[mt.value for mt in ModelType],
            location="json",
        )
        args = parser.parse_args()

        model_provider_service = ModelProviderService()
        model_provider_service.enable_model(
            tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
        )

        return {"result": "success"}
class ModelProviderModelDisableApi(Resource):
    """Console resource that disables one model of a provider."""

    # NOTE(review): the three guards below are imported at the top of this file
    # but were not applied anywhere; presumably they enforce installation,
    # authentication and account status - confirm order against sibling controllers.
    # NOTE(review): unlike ModelProviderModelApi.post there is no admin/owner
    # check here - confirm that any member may toggle models.
    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, provider: str):
        """Disable ``model`` of ``model_type`` for the current tenant.

        JSON body:
            model: required model name.
            model_type: required, one of the ``ModelType`` values.

        Returns:
            ``{"result": "success"}``.
        """
        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
        parser.add_argument(
            "model_type",
            type=str,
            required=True,
            nullable=False,
            choices=[mt.value for mt in ModelType],
            location="json",
        )
        args = parser.parse_args()

        model_provider_service = ModelProviderService()
        model_provider_service.disable_model(
            tenant_id=tenant_id, provider=provider, model=args["model"], model_type=args["model_type"]
        )

        return {"result": "success"}
class ModelProviderModelValidateApi(Resource):
    """Console resource that checks model credentials without saving them."""

    # NOTE(review): the three guards below are imported at the top of this file
    # but were not applied anywhere; presumably they enforce installation,
    # authentication and account status - confirm order against sibling controllers.
    @setup_required
    @login_required
    @account_initialization_required
    def post(self, provider: str):
        """Validate ``credentials`` for one model and report the outcome.

        JSON body:
            model: required model name.
            model_type: required, one of the ``ModelType`` values.
            credentials: required object handed to the validation service.

        Returns:
            ``{"result": "success"}`` when validation passes, otherwise
            ``{"result": "error", "error": <message>}``. A
            ``CredentialsValidateFailedError`` is reported in the payload
            rather than raised; any other exception propagates.
        """
        tenant_id = current_user.current_tenant_id

        parser = reqparse.RequestParser()
        parser.add_argument("model", type=str, required=True, nullable=False, location="json")
        parser.add_argument(
            "model_type",
            type=str,
            required=True,
            nullable=False,
            choices=[mt.value for mt in ModelType],
            location="json",
        )
        parser.add_argument("credentials", type=dict, required=True, nullable=False, location="json")
        args = parser.parse_args()

        model_provider_service = ModelProviderService()

        try:
            model_provider_service.model_credentials_validate(
                tenant_id=tenant_id,
                provider=provider,
                model=args["model"],
                model_type=args["model_type"],
                credentials=args["credentials"],
            )
        except CredentialsValidateFailedError as ex:
            return {"result": "error", "error": str(ex)}

        return {"result": "success"}
class ModelProviderModelParameterRuleApi(Resource):
    """Console resource returning the parameter rules of one model."""

    # NOTE(review): the three guards below are imported at the top of this file
    # but were not applied anywhere; presumably they enforce installation,
    # authentication and account status - confirm order against sibling controllers.
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, provider: str):
        """Return ``{"data": <parameter rules>}`` for the ``model`` query arg.

        Query args:
            model: required model name.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("model", type=str, required=True, nullable=False, location="args")
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id

        model_provider_service = ModelProviderService()
        parameter_rules = model_provider_service.get_model_parameter_rules(
            tenant_id=tenant_id, provider=provider, model=args["model"]
        )

        return jsonable_encoder({"data": parameter_rules})
class ModelProviderAvailableModelApi(Resource):
    """Console resource listing the tenant's models of one model type."""

    # NOTE(review): the three guards below are imported at the top of this file
    # but were not applied anywhere; presumably they enforce installation,
    # authentication and account status - confirm order against sibling controllers.
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, model_type):
        """Return ``{"data": <models>}`` for ``model_type`` in the current tenant.

        ``model_type`` comes from the URL and is forwarded to the service
        unchanged; it is not checked against ``ModelType`` here.
        """
        tenant_id = current_user.current_tenant_id

        model_provider_service = ModelProviderService()
        models = model_provider_service.get_models_by_model_type(tenant_id=tenant_id, model_type=model_type)

        return jsonable_encoder({"data": models})
# Route table: (resource class, URL rule, extra keyword options for add_resource).
# Entries are registered in the order listed; only the enable/disable resources
# pass an explicit endpoint name.
_MODEL_ROUTES = (
    (ModelProviderModelApi, "/workspaces/current/model-providers/<string:provider>/models", {}),
    (
        ModelProviderModelEnableApi,
        "/workspaces/current/model-providers/<string:provider>/models/enable",
        {"endpoint": "model-provider-model-enable"},
    ),
    (
        ModelProviderModelDisableApi,
        "/workspaces/current/model-providers/<string:provider>/models/disable",
        {"endpoint": "model-provider-model-disable"},
    ),
    (
        ModelProviderModelCredentialApi,
        "/workspaces/current/model-providers/<string:provider>/models/credentials",
        {},
    ),
    (
        ModelProviderModelValidateApi,
        "/workspaces/current/model-providers/<string:provider>/models/credentials/validate",
        {},
    ),
    (
        ModelProviderModelParameterRuleApi,
        "/workspaces/current/model-providers/<string:provider>/models/parameter-rules",
        {},
    ),
    (ModelProviderAvailableModelApi, "/workspaces/current/models/model-types/<string:model_type>", {}),
    (DefaultModelApi, "/workspaces/current/default-model", {}),
)

for _resource_cls, _url_rule, _route_options in _MODEL_ROUTES:
    api.add_resource(_resource_cls, _url_rule, **_route_options)