|
import datetime |
|
|
|
import pytz |
|
from flask import request |
|
from flask_login import current_user |
|
from flask_restful import Resource, fields, marshal_with, reqparse |
|
|
|
from configs import dify_config |
|
from constants.languages import supported_language |
|
from controllers.console import api |
|
from controllers.console.workspace.error import ( |
|
AccountAlreadyInitedError, |
|
CurrentPasswordIncorrectError, |
|
InvalidInvitationCodeError, |
|
RepeatPasswordNotMatchError, |
|
) |
|
from controllers.console.wraps import account_initialization_required, setup_required |
|
from extensions.ext_database import db |
|
from fields.member_fields import account_fields |
|
from libs.helper import TimestampField, timezone |
|
from libs.login import login_required |
|
from models import AccountIntegrate, InvitationCode |
|
from services.account_service import AccountService |
|
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError |
|
|
|
|
|
class AccountInitApi(Resource): |
|
@setup_required |
|
@login_required |
|
def post(self): |
|
account = current_user |
|
|
|
if account.status == "active": |
|
raise AccountAlreadyInitedError() |
|
|
|
parser = reqparse.RequestParser() |
|
|
|
if dify_config.EDITION == "CLOUD": |
|
parser.add_argument("invitation_code", type=str, location="json") |
|
|
|
parser.add_argument("interface_language", type=supported_language, required=True, location="json") |
|
parser.add_argument("timezone", type=timezone, required=True, location="json") |
|
args = parser.parse_args() |
|
|
|
if dify_config.EDITION == "CLOUD": |
|
if not args["invitation_code"]: |
|
raise ValueError("invitation_code is required") |
|
|
|
|
|
invitation_code = ( |
|
db.session.query(InvitationCode) |
|
.filter( |
|
InvitationCode.code == args["invitation_code"], |
|
InvitationCode.status == "unused", |
|
) |
|
.first() |
|
) |
|
|
|
if not invitation_code: |
|
raise InvalidInvitationCodeError() |
|
|
|
invitation_code.status = "used" |
|
invitation_code.used_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) |
|
invitation_code.used_by_tenant_id = account.current_tenant_id |
|
invitation_code.used_by_account_id = account.id |
|
|
|
account.interface_language = args["interface_language"] |
|
account.timezone = args["timezone"] |
|
account.interface_theme = "light" |
|
account.status = "active" |
|
account.initialized_at = datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None) |
|
db.session.commit() |
|
|
|
return {"result": "success"} |
|
|
|
|
|
class AccountProfileApi(Resource): |
|
@setup_required |
|
@login_required |
|
@account_initialization_required |
|
@marshal_with(account_fields) |
|
def get(self): |
|
return current_user |
|
|
|
|
|
class AccountNameApi(Resource): |
|
@setup_required |
|
@login_required |
|
@account_initialization_required |
|
@marshal_with(account_fields) |
|
def post(self): |
|
parser = reqparse.RequestParser() |
|
parser.add_argument("name", type=str, required=True, location="json") |
|
args = parser.parse_args() |
|
|
|
|
|
if len(args["name"]) < 3 or len(args["name"]) > 30: |
|
raise ValueError("Account name must be between 3 and 30 characters.") |
|
|
|
updated_account = AccountService.update_account(current_user, name=args["name"]) |
|
|
|
return updated_account |
|
|
|
|
|
class AccountAvatarApi(Resource): |
|
@setup_required |
|
@login_required |
|
@account_initialization_required |
|
@marshal_with(account_fields) |
|
def post(self): |
|
parser = reqparse.RequestParser() |
|
parser.add_argument("avatar", type=str, required=True, location="json") |
|
args = parser.parse_args() |
|
|
|
updated_account = AccountService.update_account(current_user, avatar=args["avatar"]) |
|
|
|
return updated_account |
|
|
|
|
|
class AccountInterfaceLanguageApi(Resource): |
|
@setup_required |
|
@login_required |
|
@account_initialization_required |
|
@marshal_with(account_fields) |
|
def post(self): |
|
parser = reqparse.RequestParser() |
|
parser.add_argument("interface_language", type=supported_language, required=True, location="json") |
|
args = parser.parse_args() |
|
|
|
updated_account = AccountService.update_account(current_user, interface_language=args["interface_language"]) |
|
|
|
return updated_account |
|
|
|
|
|
class AccountInterfaceThemeApi(Resource): |
|
@setup_required |
|
@login_required |
|
@account_initialization_required |
|
@marshal_with(account_fields) |
|
def post(self): |
|
parser = reqparse.RequestParser() |
|
parser.add_argument("interface_theme", type=str, choices=["light", "dark"], required=True, location="json") |
|
args = parser.parse_args() |
|
|
|
updated_account = AccountService.update_account(current_user, interface_theme=args["interface_theme"]) |
|
|
|
return updated_account |
|
|
|
|
|
class AccountTimezoneApi(Resource): |
|
@setup_required |
|
@login_required |
|
@account_initialization_required |
|
@marshal_with(account_fields) |
|
def post(self): |
|
parser = reqparse.RequestParser() |
|
parser.add_argument("timezone", type=str, required=True, location="json") |
|
args = parser.parse_args() |
|
|
|
|
|
if args["timezone"] not in pytz.all_timezones: |
|
raise ValueError("Invalid timezone string.") |
|
|
|
updated_account = AccountService.update_account(current_user, timezone=args["timezone"]) |
|
|
|
return updated_account |
|
|
|
|
|
class AccountPasswordApi(Resource): |
|
@setup_required |
|
@login_required |
|
@account_initialization_required |
|
@marshal_with(account_fields) |
|
def post(self): |
|
parser = reqparse.RequestParser() |
|
parser.add_argument("password", type=str, required=False, location="json") |
|
parser.add_argument("new_password", type=str, required=True, location="json") |
|
parser.add_argument("repeat_new_password", type=str, required=True, location="json") |
|
args = parser.parse_args() |
|
|
|
if args["new_password"] != args["repeat_new_password"]: |
|
raise RepeatPasswordNotMatchError() |
|
|
|
try: |
|
AccountService.update_account_password(current_user, args["password"], args["new_password"]) |
|
except ServiceCurrentPasswordIncorrectError: |
|
raise CurrentPasswordIncorrectError() |
|
|
|
return {"result": "success"} |
|
|
|
|
|
class AccountIntegrateApi(Resource): |
|
integrate_fields = { |
|
"provider": fields.String, |
|
"created_at": TimestampField, |
|
"is_bound": fields.Boolean, |
|
"link": fields.String, |
|
} |
|
|
|
integrate_list_fields = { |
|
"data": fields.List(fields.Nested(integrate_fields)), |
|
} |
|
|
|
@setup_required |
|
@login_required |
|
@account_initialization_required |
|
@marshal_with(integrate_list_fields) |
|
def get(self): |
|
account = current_user |
|
|
|
account_integrates = db.session.query(AccountIntegrate).filter(AccountIntegrate.account_id == account.id).all() |
|
|
|
base_url = request.url_root.rstrip("/") |
|
oauth_base_path = "/console/api/oauth/login" |
|
providers = ["github", "google"] |
|
|
|
integrate_data = [] |
|
for provider in providers: |
|
existing_integrate = next((ai for ai in account_integrates if ai.provider == provider), None) |
|
if existing_integrate: |
|
integrate_data.append( |
|
{ |
|
"id": existing_integrate.id, |
|
"provider": provider, |
|
"created_at": existing_integrate.created_at, |
|
"is_bound": True, |
|
"link": None, |
|
} |
|
) |
|
else: |
|
integrate_data.append( |
|
{ |
|
"id": None, |
|
"provider": provider, |
|
"created_at": None, |
|
"is_bound": False, |
|
"link": f"{base_url}{oauth_base_path}/{provider}", |
|
} |
|
) |
|
|
|
return {"data": integrate_data} |
|
|
|
|
|
|
|
api.add_resource(AccountInitApi, "/account/init") |
|
api.add_resource(AccountProfileApi, "/account/profile") |
|
api.add_resource(AccountNameApi, "/account/name") |
|
api.add_resource(AccountAvatarApi, "/account/avatar") |
|
api.add_resource(AccountInterfaceLanguageApi, "/account/interface-language") |
|
api.add_resource(AccountInterfaceThemeApi, "/account/interface-theme") |
|
api.add_resource(AccountTimezoneApi, "/account/timezone") |
|
api.add_resource(AccountPasswordApi, "/account/password") |
|
api.add_resource(AccountIntegrateApi, "/account/integrates") |
|
|
|
|
|
|