Spaces:
Paused
Paused
# | |
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# | |
import json | |
import re | |
from datetime import datetime | |
from flask import request, session, redirect | |
from werkzeug.security import generate_password_hash, check_password_hash | |
from flask_login import login_required, current_user, login_user, logout_user | |
from api.db.db_models import TenantLLM | |
from api.db.services.llm_service import TenantLLMService, LLMService | |
from api.utils.api_utils import server_error_response, validate_request | |
from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format | |
from api.db import UserTenantRole, LLMType, FileType | |
from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \ | |
API_KEY, \ | |
LLM_FACTORY, LLM_BASE_URL, RERANK_MDL | |
from api.db.services.user_service import UserService, TenantService, UserTenantService | |
from api.db.services.file_service import FileService | |
from api.settings import stat_logger | |
from api.utils.api_utils import get_json_result, cors_reponse | |
def login():
    """Authenticate a user with e-mail and client-encrypted password.

    Reads ``email`` and ``password`` from the JSON body. The password is
    decrypted with ``decrypt`` and checked via ``UserService.query_user``.
    On success a fresh ``access_token`` is issued, the user is logged in
    through flask-login, ``update_time``/``update_date`` are stamped and the
    row is saved.

    Returns:
        ``cors_reponse`` with the user's JSON (captured before the token
        rotation) on success. Otherwise a ``get_json_result`` error:
        AUTHENTICATION_ERROR for a missing body, unknown e-mail or wrong
        password, and SERVER_ERROR when the password cannot be decrypted.
    """
    if not request.json:
        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                               retmsg='Unautherized!')

    email = request.json.get('email', "")
    users = UserService.query(email=email)
    if not users:
        return get_json_result(
            data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='This Email is not registered!')

    password = request.json.get('password')
    try:
        password = decrypt(password)
    except Exception:
        # Exception, not BaseException: don't swallow KeyboardInterrupt/SystemExit.
        return get_json_result(
            data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')

    user = UserService.query_user(email, password)
    if not user:
        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                               retmsg='Email and Password do not match!')

    response_data = user.to_json()
    user.access_token = get_uuid()
    login_user(user)
    # No trailing commas here: they previously turned both values into 1-tuples.
    user.update_time = current_timestamp()
    user.update_date = datetime_format(datetime.now())
    user.save()
    msg = "Welcome back!"
    return cors_reponse(data=response_data, auth=user.get_id(), retmsg=msg)
def github_callback():
    """Finish the GitHub OAuth flow and log the user in, registering on first visit.

    Exchanges the ``code`` query argument for an access token and requires
    the ``user:email`` scope. It then fetches the profile and either logs in
    the account with that e-mail or registers a new one via ``user_register``.
    A failed registration is rolled back with ``rollback_user_registration``.

    Returns:
        A redirect to ``/`` with ``auth=<user id>`` on success, or with a
        URL-encoded ``error=<message>`` on failure.
    """
    import requests
    from urllib.parse import quote

    res = requests.post(GITHUB_OAUTH.get("url"), data={
        "client_id": GITHUB_OAUTH.get("client_id"),
        "client_secret": GITHUB_OAUTH.get("secret_key"),
        "code": request.args.get('code')
    }, headers={"Accept": "application/json"})
    res = res.json()
    if "error" in res:
        # Fall back to the error code in case no description is supplied.
        return redirect("/?error=%s" % quote(str(res.get("error_description", res["error"]))))
    if "user:email" not in res.get("scope", "").split(","):
        return redirect("/?error=%s" % quote("user:email not in scope"))

    session["access_token"] = res["access_token"]
    session["access_token_from"] = "github"
    userinfo = user_info_from_github(session["access_token"])
    if not userinfo.get("email"):
        # Without an e-mail we cannot look up or create the account safely.
        return redirect("/?error=%s" % quote("No primary e-mail found on the GitHub account"))

    users = UserService.query(email=userinfo["email"])
    user_id = get_uuid()
    if not users:
        try:
            try:
                avatar = download_img(userinfo["avatar_url"])
            except Exception as e:
                # A missing avatar must not block registration.
                stat_logger.exception(e)
                avatar = ""
            users = user_register(user_id, {
                "access_token": session["access_token"],
                "email": userinfo["email"],
                "avatar": avatar,
                "nickname": userinfo["login"],
                "login_channel": "github",
                "last_login_time": get_format_time(),
                "is_superuser": False,
            })
            if not users:
                raise Exception('Register user failure.')
            if len(users) > 1:
                raise Exception('Same E-mail exist!')
            user = users[0]
            login_user(user)
            return redirect("/?auth=%s" % user.get_id())
        except Exception as e:
            rollback_user_registration(user_id)
            stat_logger.exception(e)
            return redirect("/?error=%s" % quote(str(e)))

    user = users[0]
    user.access_token = get_uuid()
    login_user(user)
    user.save()
    return redirect("/?auth=%s" % user.get_id())
def feishu_callback():
    """Finish the Feishu OAuth flow and log the user in, registering on first visit.

    Obtains an app access token, then exchanges the ``code`` query argument
    for a user access token. It requires the ``contact:user.email:readonly``
    scope, fetches the profile and either logs in the account with that
    e-mail or registers a new one via ``user_register``. A failed
    registration is rolled back with ``rollback_user_registration``.

    Returns:
        A redirect to ``/`` with ``auth=<user id>`` on success, or with a
        URL-encoded ``error=<message>`` on failure.
    """
    import requests
    from urllib.parse import quote

    app_access_token_res = requests.post(FEISHU_OAUTH.get("app_access_token_url"), data=json.dumps({
        "app_id": FEISHU_OAUTH.get("app_id"),
        "app_secret": FEISHU_OAUTH.get("app_secret")
    }), headers={"Content-Type": "application/json; charset=utf-8"})
    app_access_token_res = app_access_token_res.json()
    if app_access_token_res['code'] != 0:
        return redirect("/?error=%s" % quote(str(app_access_token_res)))

    res = requests.post(FEISHU_OAUTH.get("user_access_token_url"), data=json.dumps({
        "grant_type": FEISHU_OAUTH.get("grant_type"),
        "code": request.args.get('code')
    }), headers={"Content-Type": "application/json; charset=utf-8",
                 'Authorization': f"Bearer {app_access_token_res['app_access_token']}"})
    res = res.json()
    if res['code'] != 0:
        # Fall back to the numeric code if no message is present.
        return redirect("/?error=%s" % quote(str(res.get("message", res["code"]))))
    if "contact:user.email:readonly" not in res["data"]["scope"].split(" "):
        return redirect("/?error=%s" % quote("contact:user.email:readonly not in scope"))

    session["access_token"] = res["data"]["access_token"]
    session["access_token_from"] = "feishu"
    userinfo = user_info_from_feishu(session["access_token"])
    if not userinfo.get("email"):
        # user_info_from_feishu maps an empty e-mail to None; never query/register with it.
        return redirect("/?error=%s" % quote("No e-mail found on the Feishu account"))

    users = UserService.query(email=userinfo["email"])
    user_id = get_uuid()
    if not users:
        try:
            try:
                avatar = download_img(userinfo["avatar_url"])
            except Exception as e:
                # A missing avatar must not block registration.
                stat_logger.exception(e)
                avatar = ""
            users = user_register(user_id, {
                "access_token": session["access_token"],
                "email": userinfo["email"],
                "avatar": avatar,
                "nickname": userinfo["en_name"],
                "login_channel": "feishu",
                "last_login_time": get_format_time(),
                "is_superuser": False,
            })
            if not users:
                raise Exception('Register user failure.')
            if len(users) > 1:
                raise Exception('Same E-mail exist!')
            user = users[0]
            login_user(user)
            return redirect("/?auth=%s" % user.get_id())
        except Exception as e:
            rollback_user_registration(user_id)
            stat_logger.exception(e)
            return redirect("/?error=%s" % quote(str(e)))

    user = users[0]
    user.access_token = get_uuid()
    login_user(user)
    user.save()
    return redirect("/?auth=%s" % user.get_id())
def user_info_from_feishu(access_token):
    """Fetch the Feishu profile belonging to ``access_token``.

    Returns:
        The ``data`` object of the ``authen/v1/user_info`` response. Its
        ``email`` is normalized so that an empty or missing address is ``None``.
    """
    import requests
    headers = {"Content-Type": "application/json; charset=utf-8",
               'Authorization': f"Bearer {access_token}"}
    res = requests.get(
        "https://open.feishu.cn/open-apis/authen/v1/user_info",
        headers=headers)
    user_info = res.json()["data"]
    # Use .get so a profile without an "email" key no longer raises KeyError.
    user_info["email"] = user_info.get("email") or None
    return user_info
def primary_email_from_github(email_info):
    """Return the primary address from a GitHub ``/user/emails`` payload, or None.

    Handles a non-list payload (e.g. an error object) and a list with no
    entry flagged ``primary``.
    """
    if not isinstance(email_info, list):
        return None
    return next(
        (entry.get("email") for entry in email_info
         if isinstance(entry, dict) and entry.get("primary")),
        None)


def user_info_from_github(access_token):
    """Fetch the GitHub profile for ``access_token`` and attach its primary e-mail.

    The token is sent only in the ``Authorization`` header. GitHub no longer
    accepts it as an ``access_token`` query parameter.

    Returns:
        The ``/user`` JSON with an ``email`` key holding the primary
        address, or ``None`` if none could be determined.
    """
    import requests
    headers = {"Accept": "application/json",
               'Authorization': f"token {access_token}"}
    res = requests.get("https://api.github.com/user", headers=headers)
    user_info = res.json()
    email_info = requests.get("https://api.github.com/user/emails", headers=headers).json()
    user_info["email"] = primary_email_from_github(email_info)
    return user_info
def log_out():
    """End the current user's session.

    Blanks and persists the stored access token before calling flask-login's
    ``logout_user``. Always reports success.
    """
    me = current_user
    me.access_token = ""
    me.save()
    logout_user()
    return get_json_result(data=True)
# Columns a user must not be able to set through the profile-settings endpoint.
# "password"/"new_password" are handled explicitly by setting_user; the rest
# would otherwise allow privilege escalation or identity takeover
# (e.g. is_superuser, email, access_token, id).
PROTECTED_USER_FIELDS = frozenset({
    "password", "new_password",
    "id", "email", "access_token", "is_superuser", "login_channel",
    "last_login_time", "status", "is_active", "is_authenticated", "is_anonymous",
    "create_time", "create_date", "update_time", "update_date",
})


def updatable_user_fields(request_data):
    """Return the subset of ``request_data`` a user may change on their own profile.

    Drops every key in ``PROTECTED_USER_FIELDS`` and passes the rest through unchanged.
    """
    return {key: value for key, value in request_data.items()
            if key not in PROTECTED_USER_FIELDS}


def setting_user():
    """Update the current user's profile and, optionally, their password.

    A password change requires the current ``password`` (client-encrypted)
    to match; ``new_password`` is then decrypted and hashed. Other keys in
    the JSON body are applied only if they are not protected.

    Returns:
        ``data=True`` on success. AUTHENTICATION_ERROR when the current
        password is wrong, or EXCEPTION_ERROR when the DB update fails.
    """
    update_dict = {}
    request_data = request.json
    if request_data.get("password"):
        new_password = request_data.get("new_password")
        if not check_password_hash(
                current_user.password, decrypt(request_data["password"])):
            return get_json_result(
                data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
        if new_password:
            update_dict["password"] = generate_password_hash(
                decrypt(new_password))

    # Whitelist-by-exclusion: never copy sensitive columns straight from the request.
    update_dict.update(updatable_user_fields(request_data))
    try:
        UserService.update_by_id(current_user.id, update_dict)
        return get_json_result(data=True)
    except Exception as e:
        stat_logger.exception(e)
        return get_json_result(
            data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
def user_info():
    """Return the logged-in user's profile, as produced by ``to_dict``."""
    profile = current_user.to_dict()
    return get_json_result(data=profile)
def rollback_user_registration(user_id):
    """Best-effort removal of the rows a failed ``user_register`` may have written.

    Deletes the user, its personal tenant (same id), every user-tenant link
    of that tenant and its TenantLLM rows. Each step runs independently, and
    a failing step is logged rather than raised so the remaining steps still run.

    NOTE(review): the root folder inserted by ``user_register`` via
    ``FileService.insert`` is not removed here; confirm whether it should be.
    """
    def _drop_user_tenants():
        # Delete every link, not just the first one returned.
        for user_tenant in UserTenantService.query(tenant_id=user_id):
            UserTenantService.delete_by_id(user_tenant.id)

    cleanup_steps = (
        lambda: UserService.delete_by_id(user_id),
        lambda: TenantService.delete_by_id(user_id),
        _drop_user_tenants,
        lambda: TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute(),
    )
    for step in cleanup_steps:
        try:
            step()
        except Exception as e:
            # Keep going: rollback is best-effort, but don't hide the failure.
            stat_logger.exception(e)
def user_register(user_id, user):
    """Create a user plus its personal tenant, membership, root folder and default models.

    Side effect: sets ``user["id"] = user_id`` on the caller's dict. The
    tenant shares the user's id, and the user is recorded as its OWNER.

    Returns:
        ``None`` if ``UserService.save`` reports failure, in which case
        nothing else is written. Otherwise the result of
        ``UserService.query(email=user["email"])``.
    """
    user["id"] = user_id

    tenant = dict(
        id=user_id,
        name=user["nickname"] + "‘s Kingdom",
        llm_id=CHAT_MDL,
        embd_id=EMBEDDING_MDL,
        asr_id=ASR_MDL,
        parser_ids=PARSERS,
        img2txt_id=IMAGE2TEXT_MDL,
        rerank_id=RERANK_MDL,
    )
    membership = dict(
        tenant_id=user_id,
        user_id=user_id,
        invited_by=user_id,
        role=UserTenantRole.OWNER,
    )

    # The root folder is its own parent.
    root_id = get_uuid()
    root_folder = dict(
        id=root_id,
        parent_id=root_id,
        tenant_id=user_id,
        created_by=user_id,
        name="/",
        type=FileType.FOLDER.value,
        size=0,
        location="",
    )

    default_models = [
        dict(
            tenant_id=user_id,
            llm_factory=LLM_FACTORY,
            llm_name=model.llm_name,
            model_type=model.model_type,
            api_key=API_KEY,
            api_base=LLM_BASE_URL,
        )
        for model in LLMService.query(fid=LLM_FACTORY)
    ]

    saved = UserService.save(**user)
    if not saved:
        return None

    TenantService.insert(**tenant)
    UserTenantService.insert(**membership)
    TenantLLMService.insert_many(default_models)
    FileService.insert(root_folder)
    return UserService.query(email=user["email"])
# E-mail shape check used by user_add. The TLD length is open-ended ({2,}),
# so long TLDs such as ".museum" or ".travel" are no longer rejected.
EMAIL_PATTERN = re.compile(r"[\w._-]+@([\w_-]+\.)+[\w-]{2,}")


def user_add():
    """Register a new password-login user and log them in.

    Expects JSON with ``email``, ``nickname`` and a client-encrypted
    ``password``. A partially created account is rolled back via
    ``rollback_user_registration`` if registration fails.

    Returns:
        ``cors_reponse`` with the new user's JSON on success. Otherwise a
        ``get_json_result`` error:
        - OPERATING_ERROR for an already-registered or malformed e-mail.
        - SERVER_ERROR when the password cannot be decrypted.
        - EXCEPTION_ERROR when registration fails.
    """
    req = request.json
    email_address = req["email"]
    if UserService.query(email=email_address):
        return get_json_result(
            data=False, retmsg=f'Email: {email_address} has already registered!', retcode=RetCode.OPERATING_ERROR)
    # fullmatch: unlike re.match + "$", a trailing newline is not accepted.
    if not EMAIL_PATTERN.fullmatch(email_address):
        return get_json_result(data=False, retmsg=f'Invaliad e-mail: {email_address}!',
                               retcode=RetCode.OPERATING_ERROR)

    try:
        password = decrypt(req["password"])
    except Exception:
        # Nothing has been written yet, so no rollback is needed.
        return get_json_result(
            data=False, retcode=RetCode.SERVER_ERROR, retmsg='Fail to crypt password')

    user_dict = {
        "access_token": get_uuid(),
        "email": email_address,
        "nickname": req["nickname"],
        "password": password,
        "login_channel": "password",
        "last_login_time": get_format_time(),
        "is_superuser": False,
    }
    user_id = get_uuid()
    try:
        users = user_register(user_id, user_dict)
        if not users:
            raise Exception('Register user failure.')
        if len(users) > 1:
            raise Exception('Same E-mail exist!')
        user = users[0]
        login_user(user)
        return cors_reponse(data=user.to_json(),
                            auth=user.get_id(), retmsg="Welcome aboard!")
    except Exception as e:
        rollback_user_registration(user_id)
        stat_logger.exception(e)
        return get_json_result(
            data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
def tenant_info():
    """Return the first tenant that ``TenantService.get_by_user_id`` lists for the current user.

    Any exception, including an empty tenant list, is reported through ``server_error_response``.
    """
    try:
        user_tenants = TenantService.get_by_user_id(current_user.id)
        first_tenant = user_tenants[0]
        return get_json_result(data=first_tenant)
    except Exception as err:
        return server_error_response(err)
def set_tenant_info():
    """Update the settings of a tenant the current user belongs to.

    Expects JSON with ``tenant_id`` plus the tenant columns to change.
    ``tenant_id`` is removed from the payload before the update.

    Returns:
        ``data=True`` on success. AUTHENTICATION_ERROR when the current user
        has no user-tenant link to ``tenant_id``. ``server_error_response``
        for any other failure, including a missing ``tenant_id``.
    """
    req = request.json
    try:
        tid = req.pop("tenant_id")
        # Authorization: previously any logged-in user could modify any tenant by id.
        if not UserTenantService.query(user_id=current_user.id, tenant_id=tid):
            return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                                   retmsg='Authorization error!')
        TenantService.update_by_id(tid, req)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)