Spaces:
				
			
			
	
			
			
		Paused
		
	
	
	
			
			
	
	
	
	
		
		
		Paused
		
	| # | |
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| import json | |
| import re | |
| from datetime import datetime | |
| from flask import request, session, redirect | |
| from werkzeug.security import generate_password_hash, check_password_hash | |
| from flask_login import login_required, current_user, login_user, logout_user | |
| from api.db.db_models import TenantLLM | |
| from api.db.services.llm_service import TenantLLMService, LLMService | |
| from api.utils.api_utils import server_error_response, validate_request | |
| from api.utils import get_uuid, get_format_time, decrypt, download_img, current_timestamp, datetime_format | |
| from api.db import UserTenantRole, LLMType, FileType | |
| from api.settings import RetCode, GITHUB_OAUTH, FEISHU_OAUTH, CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, PARSERS, \ | |
| API_KEY, \ | |
| LLM_FACTORY, LLM_BASE_URL, RERANK_MDL | |
| from api.db.services.user_service import UserService, TenantService, UserTenantService | |
| from api.db.services.file_service import FileService | |
| from api.settings import stat_logger | |
| from api.utils.api_utils import get_json_result, cors_reponse | |
def login():
    """Authenticate a user with e-mail and an encrypted password.

    Reads a JSON body with ``email`` and ``password``. The password arrives
    encrypted and is passed through ``decrypt`` before
    ``UserService.query_user`` checks it.

    On success a fresh ``access_token`` is generated and the user is logged in
    through flask-login. ``update_time`` / ``update_date`` are refreshed and
    saved. A CORS response is returned carrying the user's JSON (taken before
    the token rotation) and ``auth=user.get_id()``.

    Returns:
        ``RetCode.AUTHENTICATION_ERROR`` when the body is empty, the e-mail is
        unknown, or the credentials do not match;
        ``RetCode.SERVER_ERROR`` when the password cannot be decrypted.
    """
    req = request.json
    if not req:
        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                               retmsg='Unautherized!')

    email = req.get('email', "")
    users = UserService.query(email=email)
    if not users:
        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                               retmsg='This Email is not registered!')

    password = req.get('password')
    try:
        password = decrypt(password)
    except Exception:
        # Narrowed from BaseException: KeyboardInterrupt / SystemExit must not
        # be turned into an HTTP response.
        return get_json_result(data=False, retcode=RetCode.SERVER_ERROR,
                               retmsg='Fail to crypt password')

    user = UserService.query_user(email, password)
    if not user:
        return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                               retmsg='Email and Password do not match!')

    # Snapshot taken before the token rotation below; this is the response payload.
    response_data = user.to_json()
    user.access_token = get_uuid()
    login_user(user)
    # Plain assignments: the previous trailing commas stored 1-tuples here.
    user.update_time = current_timestamp()
    user.update_date = datetime_format(datetime.now())
    user.save()
    return cors_reponse(data=response_data, auth=user.get_id(), retmsg="Welcome back!")
def github_callback():
    """Finish the GitHub OAuth flow and log the user in.

    Flow:
    1. Exchange the ``code`` query parameter for an access token at
       ``GITHUB_OAUTH["url"]`` and require the ``user:email`` scope.
    2. Store the token in the session.
    3. Look up the local user by the GitHub primary e-mail. Unknown e-mails
       are registered on the fly, and partially created rows are rolled back
       on failure. Known users get a fresh ``access_token``.

    Always returns a redirect to ``/``. On success it carries
    ``?auth=<user.get_id()>``; on failure it carries
    ``?error=<url-encoded message>``.
    """
    import requests
    from urllib.parse import quote

    def error_redirect(message):
        # URL-encode so messages containing '&', '#', '+', spaces, ... remain a
        # single intact query value.
        return redirect("/?error=%s" % quote(str(message)))

    # Bounded timeout so a stalled provider cannot hang the worker.
    res = requests.post(GITHUB_OAUTH.get("url"), data={
        "client_id": GITHUB_OAUTH.get("client_id"),
        "client_secret": GITHUB_OAUTH.get("secret_key"),
        "code": request.args.get('code')
    }, headers={"Accept": "application/json"}, timeout=30)
    res = res.json()
    if "error" in res:
        return error_redirect(res.get("error_description", res["error"]))
    if "user:email" not in res.get("scope", "").split(","):
        return error_redirect("user:email not in scope")

    session["access_token"] = res["access_token"]
    session["access_token_from"] = "github"
    try:
        userinfo = user_info_from_github(session["access_token"])
    except Exception as e:
        stat_logger.exception(e)
        return error_redirect("Fail to get user info from GitHub")
    if not userinfo.get("email"):
        # Never look up or register an account by a missing e-mail.
        return error_redirect("GitHub account has no primary e-mail")

    users = UserService.query(email=userinfo["email"])
    user_id = get_uuid()
    if not users:
        # First GitHub login for this e-mail: create the local account.
        try:
            try:
                avatar = download_img(userinfo["avatar_url"])
            except Exception as e:
                # A missing avatar must not block registration.
                stat_logger.exception(e)
                avatar = ""
            users = user_register(user_id, {
                "access_token": session["access_token"],
                "email": userinfo["email"],
                "avatar": avatar,
                "nickname": userinfo["login"],
                "login_channel": "github",
                "last_login_time": get_format_time(),
                "is_superuser": False,
            })
            if not users:
                raise Exception('Register user failure.')
            if len(users) > 1:
                raise Exception('Same E-mail exist!')
            user = users[0]
            login_user(user)
            return redirect("/?auth=%s" % user.get_id())
        except Exception as e:
            rollback_user_registration(user_id)
            stat_logger.exception(e)
            return error_redirect(e)

    user = users[0]
    user.access_token = get_uuid()
    login_user(user)
    user.save()
    return redirect("/?auth=%s" % user.get_id())
def feishu_callback():
    """Finish the Feishu OAuth flow and log the user in.

    Flow:
    1. Obtain an app access token with ``FEISHU_OAUTH["app_id"]`` /
       ``["app_secret"]``.
    2. Exchange the ``code`` query parameter for a user access token and
       require the ``contact:user.email:readonly`` scope.
    3. Store the token in the session.
    4. Look up the local user by the Feishu e-mail. Unknown e-mails are
       registered on the fly, and partially created rows are rolled back on
       failure. Known users get a fresh ``access_token``.

    Always returns a redirect to ``/``. On success it carries
    ``?auth=<user.get_id()>``; on failure it carries
    ``?error=<url-encoded message>``.
    """
    import requests
    from urllib.parse import quote

    def error_redirect(message):
        # URL-encode so arbitrary provider messages remain a single query value.
        return redirect("/?error=%s" % quote(str(message)))

    json_headers = {"Content-Type": "application/json; charset=utf-8"}

    # Bounded timeouts so a stalled provider cannot hang the worker.
    app_access_token_res = requests.post(FEISHU_OAUTH.get("app_access_token_url"), data=json.dumps({
        "app_id": FEISHU_OAUTH.get("app_id"),
        "app_secret": FEISHU_OAUTH.get("app_secret")
    }), headers=json_headers, timeout=30)
    app_access_token_res = app_access_token_res.json()
    if app_access_token_res['code'] != 0:
        return error_redirect(app_access_token_res)

    res = requests.post(FEISHU_OAUTH.get("user_access_token_url"), data=json.dumps({
        "grant_type": FEISHU_OAUTH.get("grant_type"),
        "code": request.args.get('code')
    }), headers={**json_headers,
                 'Authorization': f"Bearer {app_access_token_res['app_access_token']}"},
        timeout=30)
    res = res.json()
    if res['code'] != 0:
        return error_redirect(res["message"])
    if "contact:user.email:readonly" not in res["data"]["scope"].split(" "):
        return error_redirect("contact:user.email:readonly not in scope")

    session["access_token"] = res["data"]["access_token"]
    session["access_token_from"] = "feishu"
    try:
        userinfo = user_info_from_feishu(session["access_token"])
    except Exception as e:
        stat_logger.exception(e)
        return error_redirect("Fail to get user info from Feishu")
    if not userinfo.get("email"):
        # user_info_from_feishu maps an empty e-mail to None; never query or
        # register an account by a missing e-mail.
        return error_redirect("Feishu account has no e-mail")

    users = UserService.query(email=userinfo["email"])
    user_id = get_uuid()
    if not users:
        # First Feishu login for this e-mail: create the local account.
        try:
            try:
                avatar = download_img(userinfo["avatar_url"])
            except Exception as e:
                # A missing avatar must not block registration.
                stat_logger.exception(e)
                avatar = ""
            users = user_register(user_id, {
                "access_token": session["access_token"],
                "email": userinfo["email"],
                "avatar": avatar,
                "nickname": userinfo["en_name"],
                "login_channel": "feishu",
                "last_login_time": get_format_time(),
                "is_superuser": False,
            })
            if not users:
                raise Exception('Register user failure.')
            if len(users) > 1:
                raise Exception('Same E-mail exist!')
            user = users[0]
            login_user(user)
            return redirect("/?auth=%s" % user.get_id())
        except Exception as e:
            rollback_user_registration(user_id)
            stat_logger.exception(e)
            return error_redirect(e)

    user = users[0]
    user.access_token = get_uuid()
    login_user(user)
    user.save()
    return redirect("/?auth=%s" % user.get_id())
def user_info_from_feishu(access_token):
    """Fetch the Feishu profile of the user owning *access_token*.

    Returns:
        The ``data`` object of the ``authen/v1/user_info`` response, with
        ``email`` normalised to ``None`` when it is empty or missing.
    """
    import requests
    headers = {"Content-Type": "application/json; charset=utf-8",
               'Authorization': f"Bearer {access_token}"}
    res = requests.get(
        "https://open.feishu.cn/open-apis/authen/v1/user_info",
        headers=headers,
        timeout=30)
    user_info = res.json()["data"]
    # A missing key used to raise KeyError; treat missing and "" the same way.
    user_info["email"] = user_info.get("email") or None
    return user_info
def user_info_from_github(access_token):
    """Fetch the GitHub profile of the token owner together with their primary e-mail.

    The token is sent only in the ``Authorization`` header. It was previously
    also appended as an ``?access_token=`` query parameter. GitHub has
    deprecated that form of authentication, and it exposes the secret in
    request URLs (and therefore in logs).

    Returns:
        The JSON object from ``GET /user`` with an added ``email`` key holding
        the address flagged ``primary`` in ``GET /user/emails``.

    Raises:
        ValueError: if no primary e-mail can be found. This previously
            surfaced as an opaque ``TypeError``.
    """
    import requests
    headers = {"Accept": "application/json",
               'Authorization': f"token {access_token}"}
    user_info = requests.get("https://api.github.com/user",
                             headers=headers, timeout=30).json()
    email_info = requests.get("https://api.github.com/user/emails",
                              headers=headers, timeout=30).json()
    # An error response may be a JSON object instead of the expected list of entries.
    entries = email_info if isinstance(email_info, list) else []
    primary = next((entry for entry in entries
                    if isinstance(entry, dict) and entry.get("primary")), None)
    if primary is None or not primary.get("email"):
        raise ValueError("No primary e-mail found for the GitHub account")
    user_info["email"] = primary["email"]
    return user_info
def log_out():
    """Log the current user out.

    Blanks the stored ``access_token`` of ``current_user``, persists it, ends
    the flask-login session and answers with ``data=True``.
    """
    # NOTE(review): reads current_user, so presumably the route is protected by
    # @login_required (decorators are not visible in this view) - confirm.
    user = current_user
    user.access_token = ""
    user.save()
    logout_user()
    return get_json_result(data=True)
def setting_user():
    """Update settings of the logged-in user from the JSON body.

    Password change: when ``password`` (the current password, encrypted) is
    present, it must match the stored hash. If ``new_password`` (encrypted)
    is also given, its hash replaces the stored one. ``new_password`` alone is
    ignored.

    All other keys are copied into the update, except identity, privilege
    and bookkeeping fields. Previously every key was written through, which
    let a user set e.g. ``is_superuser`` or ``access_token`` on their own
    account.

    Returns:
        ``data=True`` on success; ``RetCode.AUTHENTICATION_ERROR`` when the
        current password is wrong; ``RetCode.EXCEPTION_ERROR`` when the
        database update fails.
    """
    # Keys never taken from the request payload. The two password keys are
    # handled explicitly; the rest must not be self-assignable.
    protected_fields = {
        "password", "new_password",
        "id", "email", "access_token", "is_superuser", "login_channel",
        "status", "is_active", "is_anonymous", "is_authenticated",
        "last_login_time", "create_time", "create_date",
        "update_time", "update_date",
    }

    update_dict = {}
    request_data = request.json
    if request_data.get("password"):
        new_password = request_data.get("new_password")
        if not check_password_hash(
                current_user.password, decrypt(request_data["password"])):
            return get_json_result(
                data=False, retcode=RetCode.AUTHENTICATION_ERROR, retmsg='Password error!')
        if new_password:
            update_dict["password"] = generate_password_hash(
                decrypt(new_password))

    for key, value in request_data.items():
        if key in protected_fields:
            continue
        update_dict[key] = value

    try:
        UserService.update_by_id(current_user.id, update_dict)
        return get_json_result(data=True)
    except Exception as e:
        stat_logger.exception(e)
        return get_json_result(
            data=False, retmsg='Update failure!', retcode=RetCode.EXCEPTION_ERROR)
def user_info():
    """Return ``current_user.to_dict()`` wrapped in a JSON result."""
    profile = current_user.to_dict()
    return get_json_result(data=profile)
def rollback_user_registration(user_id):
    """Best-effort removal of the rows created when registering *user_id*.

    Removes, in order: the user row, the tenant with the same id, every
    user-tenant membership of that tenant, and the tenant's TenantLLM rows.
    Each step runs independently, so a failing step (e.g. a row that was
    never created) does not stop the others. Failures are logged instead of
    silently discarded and are never raised to the caller.

    NOTE(review): the root folder that ``user_register`` inserts through
    ``FileService.insert`` is not removed here - confirm whether that is
    intended.
    """
    def delete_memberships():
        # Remove every membership row of the tenant, not just the first one.
        for link in UserTenantService.query(tenant_id=user_id) or []:
            UserTenantService.delete_by_id(link.id)

    steps = (
        lambda: UserService.delete_by_id(user_id),
        lambda: TenantService.delete_by_id(user_id),
        delete_memberships,
        lambda: TenantLLM.delete().where(TenantLLM.tenant_id == user_id).execute(),
    )
    for step in steps:
        try:
            step()
        except Exception as e:
            stat_logger.exception(e)
def user_register(user_id, user):
    """Create an account together with its personal tenant.

    Sets ``user["id"] = user_id`` (the passed dict is mutated in place), then
    saves the user. If that succeeds it inserts:

    * a tenant with the same id, named after the nickname, whose model ids
      come from CHAT_MDL, EMBEDDING_MDL, ASR_MDL, IMAGE2TEXT_MDL, RERANK_MDL
      and PARSERS;
    * the OWNER membership row linking the user to that tenant;
    * one TenantLLM row per model returned for ``LLM_FACTORY``, using
      ``API_KEY`` / ``LLM_BASE_URL``;
    * a root ``/`` folder.

    Returns:
        ``UserService.query(email=user["email"])`` on success, or ``None``
        when ``UserService.save`` reports failure (nothing else is written in
        that case).
    """
    user["id"] = user_id

    tenant = dict(
        id=user_id,
        name=user["nickname"] + "‘s Kingdom",
        llm_id=CHAT_MDL,
        embd_id=EMBEDDING_MDL,
        asr_id=ASR_MDL,
        parser_ids=PARSERS,
        img2txt_id=IMAGE2TEXT_MDL,
        rerank_id=RERANK_MDL,
    )
    owner_link = dict(
        tenant_id=user_id,
        user_id=user_id,
        invited_by=user_id,
        role=UserTenantRole.OWNER,
    )

    root_id = get_uuid()
    root_folder = dict(
        id=root_id,
        parent_id=root_id,  # the root folder is its own parent
        tenant_id=user_id,
        created_by=user_id,
        name="/",
        type=FileType.FOLDER.value,
        size=0,
        location="",
    )

    default_llms = [
        dict(tenant_id=user_id,
             llm_factory=LLM_FACTORY,
             llm_name=model.llm_name,
             model_type=model.model_type,
             api_key=API_KEY,
             api_base=LLM_BASE_URL)
        for model in LLMService.query(fid=LLM_FACTORY)
    ]

    saved = UserService.save(**user)
    if not saved:
        return None

    TenantService.insert(**tenant)
    UserTenantService.insert(**owner_link)
    TenantLLMService.insert_many(default_llms)
    FileService.insert(root_folder)
    return UserService.query(email=user["email"])
# Accepted e-mail shape, used with ``fullmatch`` so the whole string must
# match (``re.match`` + ``$`` also accepted a trailing newline). Compared with
# the old pattern, it allows ``+`` in the local part and TLDs longer than four
# characters (e.g. ".museum", ".technology").
EMAIL_PATTERN = re.compile(r"[\w.+-]+@([\w-]+\.)+[\w-]{2,}")


def user_add():
    """Register a new password-based account and log it in.

    Expects a JSON body with ``email``, ``nickname`` and an encrypted
    ``password``.

    Returns:
        A CORS response with the new user's JSON on success.
        ``RetCode.OPERATING_ERROR`` for an already-registered or malformed
        e-mail. ``RetCode.SERVER_ERROR`` when the password cannot be
        decrypted. ``RetCode.EXCEPTION_ERROR`` when registration fails, after
        rolling back any partially created rows.
    """
    req = request.json
    if UserService.query(email=req["email"]):
        return get_json_result(
            data=False, retmsg=f'Email: {req["email"]} has already registered!', retcode=RetCode.OPERATING_ERROR)
    if not EMAIL_PATTERN.fullmatch(req["email"]):
        return get_json_result(data=False, retmsg=f'Invaliad e-mail: {req["email"]}!',
                               retcode=RetCode.OPERATING_ERROR)

    try:
        password = decrypt(req["password"])
    except Exception:
        # Same handling as login(): a bad ciphertext is a client-visible error, not a 500.
        return get_json_result(data=False, retcode=RetCode.SERVER_ERROR,
                               retmsg='Fail to crypt password')

    user_dict = {
        "access_token": get_uuid(),
        "email": req["email"],
        "nickname": req["nickname"],
        "password": password,
        "login_channel": "password",
        "last_login_time": get_format_time(),
        "is_superuser": False,
    }
    user_id = get_uuid()
    try:
        users = user_register(user_id, user_dict)
        if not users:
            raise Exception('Register user failure.')
        if len(users) > 1:
            raise Exception('Same E-mail exist!')
        user = users[0]
        login_user(user)
        return cors_reponse(data=user.to_json(),
                            auth=user.get_id(), retmsg="Welcome aboard!")
    except Exception as e:
        rollback_user_registration(user_id)
        stat_logger.exception(e)
        return get_json_result(
            data=False, retmsg='User registration failure!', retcode=RetCode.EXCEPTION_ERROR)
def tenant_info():
    """Return the first tenant that ``TenantService.get_by_user_id`` yields for the current user.

    Any exception, including the ``[0]`` lookup failing on an empty result,
    is converted by ``server_error_response``.
    """
    try:
        matches = TenantService.get_by_user_id(current_user.id)
        first_tenant = matches[0]
        response = get_json_result(data=first_tenant)
    except Exception as err:
        return server_error_response(err)
    else:
        return response
def set_tenant_info():
    """Update settings of a tenant the current user belongs to.

    Expects a JSON body with ``tenant_id`` plus the tenant fields to change.
    Previously any authenticated user could update any tenant just by
    sending its id. The current user must now have a user-tenant row for
    that tenant.

    Returns:
        ``data=True`` on success; ``RetCode.AUTHENTICATION_ERROR`` when the
        user is not a member of the tenant. Any other exception (e.g. a
        missing ``tenant_id``) goes through ``server_error_response``.
    """
    req = request.json
    try:
        # Work on a copy so the parsed request body is not mutated.
        updates = dict(req)
        tid = updates.pop("tenant_id")
        if not UserTenantService.query(user_id=current_user.id, tenant_id=tid):
            return get_json_result(data=False, retcode=RetCode.AUTHENTICATION_ERROR,
                                   retmsg='No authorization.')
        # The target row is chosen by tenant_id; never let the payload rewrite the key.
        updates.pop("id", None)
        TenantService.update_by_id(tid, updates)
        return get_json_result(data=True)
    except Exception as e:
        return server_error_response(e)