Spaces:
Paused
Paused
| import datetime | |
| import pytz | |
| from flask import request | |
| from flask_login import current_user | |
| from flask_restful import Resource, fields, marshal_with, reqparse | |
| from configs import dify_config | |
| from constants.languages import supported_language | |
| from controllers.console import api | |
| from controllers.console.workspace.error import ( | |
| AccountAlreadyInitedError, | |
| CurrentPasswordIncorrectError, | |
| InvalidInvitationCodeError, | |
| RepeatPasswordNotMatchError, | |
| ) | |
| from controllers.console.wraps import account_initialization_required, setup_required | |
| from extensions.ext_database import db | |
| from fields.member_fields import account_fields | |
| from libs.helper import TimestampField, timezone | |
| from libs.login import login_required | |
| from models import AccountIntegrate, InvitationCode | |
| from services.account_service import AccountService | |
| from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError | |
class AccountInitApi(Resource):
    """Console endpoint that completes first-time setup of the current account."""

    # NOTE(review): this copy of the file imported setup_required/login_required
    # but applied no decorators; they are restored here because the handler
    # dereferences current_user. Confirm the order against sibling controllers.
    @setup_required
    @login_required
    def post(self):
        """Activate the current account and store its initial preferences.

        JSON body:
            interface_language: required, validated by ``supported_language``.
            timezone: required, validated by ``timezone`` from ``libs.helper``.
            invitation_code: only parsed when ``dify_config.EDITION == "CLOUD"``;
                it must then match an ``InvitationCode`` row whose status is
                ``"unused"``.

        Returns:
            ``{"result": "success"}`` after the session is committed.

        Raises:
            AccountAlreadyInitedError: the account status is already ``"active"``.
            ValueError: CLOUD edition and no invitation code was supplied.
            InvalidInvitationCodeError: no unused invitation code matches.
        """
        account = current_user

        if account.status == "active":
            raise AccountAlreadyInitedError()

        # Read the edition flag once so parsing and validation cannot disagree.
        is_cloud = dify_config.EDITION == "CLOUD"

        parser = reqparse.RequestParser()
        if is_cloud:
            parser.add_argument("invitation_code", type=str, location="json")
        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
        parser.add_argument("timezone", type=timezone, required=True, location="json")
        args = parser.parse_args()

        if is_cloud:
            if not args["invitation_code"]:
                raise ValueError("invitation_code is required")

            # check invitation code
            invitation_code = (
                db.session.query(InvitationCode)
                .filter(
                    InvitationCode.code == args["invitation_code"],
                    InvitationCode.status == "unused",
                )
                .first()
            )

            if not invitation_code:
                raise InvalidInvitationCodeError()

            # Mark the code as consumed by this account; timestamps are stored
            # as naive datetimes carrying UTC wall-clock time.
            invitation_code.status = "used"
            invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
            invitation_code.used_by_tenant_id = account.current_tenant_id
            invitation_code.used_by_account_id = account.id

        account.interface_language = args["interface_language"]
        account.timezone = args["timezone"]
        account.interface_theme = "light"
        account.status = "active"
        account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None)
        # One commit persists both the invitation-code update and the account changes.
        db.session.commit()

        return {"result": "success"}
class AccountProfileApi(Resource):
    """Console endpoint that returns the profile of the current account."""

    # NOTE(review): the decorators below were missing in this copy although the
    # file imports every one of them; without marshal_with the raw account
    # object would be handed to the JSON encoder. Confirm against sibling
    # controllers.
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def get(self):
        """Return ``current_user`` serialized through ``account_fields``."""
        return current_user
class AccountNameApi(Resource):
    """Console endpoint that changes the display name of the current account."""

    # NOTE(review): decorators restored from the file's otherwise unused
    # imports; confirm against sibling controllers.
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the account name.

        JSON body:
            name: required string, 3 to 30 characters long.

        Returns:
            The result of ``AccountService.update_account``, serialized through
            ``account_fields``.

        Raises:
            ValueError: the name is missing/null or its length is out of range.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=True, location="json")
        args = parser.parse_args()

        name = args["name"]
        # Validate account name length. The explicit None check keeps a null
        # value from reaching len() and surfacing as a TypeError.
        if name is None or not 3 <= len(name) <= 30:
            raise ValueError("Account name must be between 3 and 30 characters.")

        updated_account = AccountService.update_account(current_user, name=name)

        return updated_account
class AccountAvatarApi(Resource):
    """Console endpoint that changes the avatar of the current account."""

    # NOTE(review): decorators restored from the file's otherwise unused
    # imports; confirm against sibling controllers.
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the account avatar.

        JSON body:
            avatar: required string, passed to the service unchanged.

        Returns:
            The result of ``AccountService.update_account``, serialized through
            ``account_fields``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("avatar", type=str, required=True, location="json")
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, avatar=args["avatar"])

        return updated_account
class AccountInterfaceLanguageApi(Resource):
    """Console endpoint that changes the interface language of the current account."""

    # NOTE(review): decorators restored from the file's otherwise unused
    # imports; confirm against sibling controllers.
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the interface language.

        JSON body:
            interface_language: required, validated by ``supported_language``.

        Returns:
            The result of ``AccountService.update_account``, serialized through
            ``account_fields``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("interface_language", type=supported_language, required=True, location="json")
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"])

        return updated_account
class AccountInterfaceThemeApi(Resource):
    """Console endpoint that changes the interface theme of the current account."""

    # NOTE(review): decorators restored from the file's otherwise unused
    # imports; confirm against sibling controllers.
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the interface theme.

        JSON body:
            interface_theme: required, one of ``"light"`` or ``"dark"``.

        Returns:
            The result of ``AccountService.update_account``, serialized through
            ``account_fields``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json")
        args = parser.parse_args()

        updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"])

        return updated_account
class AccountTimezoneApi(Resource):
    """Console endpoint that changes the timezone of the current account."""

    # NOTE(review): decorators restored from the file's otherwise unused
    # imports; confirm against sibling controllers.
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(account_fields)
    def post(self):
        """Update the account timezone.

        JSON body:
            timezone: required timezone name, e.g. ``America/New_York`` or
                ``Asia/Shanghai``.

        Returns:
            The result of ``AccountService.update_account``, serialized through
            ``account_fields``.

        Raises:
            ValueError: the value is not a timezone name known to ``pytz``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("timezone", type=str, required=True, location="json")
        args = parser.parse_args()

        # Validate timezone string, e.g. America/New_York, Asia/Shanghai.
        # all_timezones_set holds the same names as all_timezones but supports
        # hashed membership tests instead of a linear list scan.
        if args["timezone"] not in pytz.all_timezones_set:
            raise ValueError("Invalid timezone string.")

        updated_account = AccountService.update_account(current_user, timezone=args["timezone"])

        return updated_account
class AccountPasswordApi(Resource):
    """Console endpoint that changes the password of the current account."""

    # NOTE(review): auth decorators restored from the file's otherwise unused
    # imports; confirm against sibling controllers. The response is a plain
    # dict, so no marshal_with is applied.
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Change the account password.

        JSON body:
            password: current password; optional in the request and forwarded
                to the service as-is (``None`` when absent).
            new_password: required.
            repeat_new_password: required, must equal ``new_password``.

        Returns:
            ``{"result": "success"}``.

        Raises:
            RepeatPasswordNotMatchError: the two new-password fields differ.
            CurrentPasswordIncorrectError: the service rejected ``password``.
        """
        parser = reqparse.RequestParser()
        parser.add_argument("password", type=str, required=False, location="json")
        parser.add_argument("new_password", type=str, required=True, location="json")
        parser.add_argument("repeat_new_password", type=str, required=True, location="json")
        args = parser.parse_args()

        if args["new_password"] != args["repeat_new_password"]:
            raise RepeatPasswordNotMatchError()

        try:
            AccountService.update_account_password(current_user, args["password"], args["new_password"])
        except ServiceCurrentPasswordIncorrectError:
            # Translate the service-layer error into the controller-layer error
            # type imported from controllers.console.workspace.error.
            raise CurrentPasswordIncorrectError()

        return {"result": "success"}
class AccountIntegrateApi(Resource):
    """Console endpoint that lists OAuth provider bindings for the current account."""

    # Output schema of one provider entry.
    integrate_fields = {
        "provider": fields.String,
        "created_at": TimestampField,
        "is_bound": fields.Boolean,
        "link": fields.String,
    }

    # Output schema of the whole response: {"data": [<integrate_fields>, ...]}.
    integrate_list_fields = {
        "data": fields.List(fields.Nested(integrate_fields)),
    }

    # NOTE(review): integrate_list_fields was defined but never applied in this
    # copy, so a bound entry's created_at never went through TimestampField.
    # The decorators are restored from the file's otherwise unused imports;
    # confirm against sibling controllers.
    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(integrate_list_fields)
    def get(self):
        """Report, for each supported provider, whether the account is bound to it.

        Returns:
            ``{"data": [...]}`` with one entry per provider in ``providers``
            order. A bound entry carries the binding's ``created_at`` and no
            link; an unbound entry carries the OAuth login link instead. The
            ``id`` key built below is not part of ``integrate_fields``.
        """
        account = current_user

        account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all()

        # Index the rows by provider once; setdefault keeps the first row per
        # provider, matching the previous first-match scan.
        bound_by_provider = {}
        for account_integrate in account_integrates:
            bound_by_provider.setdefault(account_integrate.provider, account_integrate)

        base_url = request.url_root.rstrip("/")
        oauth_base_path = "/console/api/oauth/login"
        providers = ["github", "google"]

        integrate_data = []
        for provider in providers:
            existing_integrate = bound_by_provider.get(provider)
            if existing_integrate:
                integrate_data.append(
                    {
                        "id": existing_integrate.id,
                        "provider": provider,
                        "created_at": existing_integrate.created_at,
                        "is_bound": True,
                        "link": None,
                    }
                )
            else:
                integrate_data.append(
                    {
                        "id": None,
                        "provider": provider,
                        "created_at": None,
                        "is_bound": False,
                        "link": f"{base_url}{oauth_base_path}/{provider}",
                    }
                )

        return {"data": integrate_data}
# Register API resources
# Each (resource class, URL path) pair is registered on the console API in the
# order listed.
_ACCOUNT_RESOURCES = (
    (AccountInitApi, "/account/init"),
    (AccountProfileApi, "/account/profile"),
    (AccountNameApi, "/account/name"),
    (AccountAvatarApi, "/account/avatar"),
    (AccountInterfaceLanguageApi, "/account/interface-language"),
    (AccountInterfaceThemeApi, "/account/interface-theme"),
    (AccountTimezoneApi, "/account/timezone"),
    (AccountPasswordApi, "/account/password"),
    (AccountIntegrateApi, "/account/integrates"),
)
for _resource_cls, _url_path in _ACCOUNT_RESOURCES:
    api.add_resource(_resource_cls, _url_path)

# The email endpoints below are not registered:
# api.add_resource(AccountEmailApi, '/account/email')
# api.add_resource(AccountEmailVerifyApi, '/account/email-verify')