# Spaces:
# Sleeping
# Sleeping
import json
import logging
from datetime import datetime, timedelta
from typing import List
from typing import Optional

import firebase_admin
from fastapi import FastAPI, HTTPException, Depends, Request
from fastapi import Depends, HTTPException, Form, File, UploadFile
from fastapi import FastAPI, Request, HTTPException
from fastapi import Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.trustedhost import TrustedHostMiddleware
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from firebase_admin import credentials
from pydantic.error_wrappers import ErrorWrapper

from controller import LoginController, FileController, MySQLController, DefaultController, OTPController
from controller import ChatController
from service import MySQLService,LoginService,ChatService
from request import RequestMySQL,RequestLogin,RequestDefault
from auth.authentication import decodeJWT
from repository import UserRepository
from auth import authentication
from service import FileService,DefaultService,LoginService
from request import RequestFile,RequestChat,RequestDefault
from function import support_function
from response import ResponseDefault as res
# Metadata published in the generated OpenAPI schema / Swagger UI.
_API_CONTACT = {
    "name": "Vo Nhu Y",
    "url": "https://pychatbot1.streamlit.app",
    "email": "vonhuy5112002@gmail.com",
}
_API_LICENSE = {
    "name": "Apache 2.0",
    "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
}

app = FastAPI(
    title="ChatBot HCMUTE",
    version="1.0.0",
    description="Python ChatBot is intended for use in the topic Customizing chatbots. With the construction of 2 students Vo Nhu Y - 20133118 and Nguyen Quang Phuc 20133080",
    contact=_API_CONTACT,
    license_info=_API_LICENSE,
    swagger_ui_parameters={"syntaxHighlight.theme": "obsidian"},
)
# Browser origins allowed to call this API cross-origin: the deployed
# Hugging Face Space and a local front-end on port 8501 (presumably a
# Streamlit dev server — 8501 is Streamlit's default port).
origins = [
    "https://kltn20133118-pychatbot.hf.space",
    "http://localhost:8501",
]

_cors_options = dict(
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# File extensions (lower-case, without the dot) accepted for upload.
ALLOWED_EXTENSIONS = {'csv', 'txt', 'doc', 'docx', 'pdf', 'xlsx', 'pptx', 'json', 'md'}


def allowed_file(filename):
    """Return True if *filename* ends in an extension from ALLOWED_EXTENSIONS.

    Only the text after the last dot is considered, compared case-insensitively.
    Names without a dot, and empty or ``None`` names (``UploadFile.filename``
    may be ``None``), are rejected instead of raising.
    """
    if not filename:
        return False
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
# firebase_admin refuses to initialise the default app twice, so only do it
# when its (private) registry of initialised apps is still empty.
_firebase_ready = bool(firebase_admin._apps)
if not _firebase_ready:
    cred = credentials.Certificate("firebase_certificate.json")
    fred = firebase_admin.initialize_app(credential=cred)
class JWTBearer(HTTPBearer):
    """FastAPI security dependency that accepts only known, decodable Bearer tokens.

    When FastAPI calls an instance (``Depends(JWTBearer())``), it returns the raw
    token string after three checks:

    1. the ``Authorization`` scheme is ``Bearer`` (case-insensitive, as HTTP
       authentication schemes are),
    2. ``LoginService.check_token_is_valid`` accepts the token,
    3. ``decodeJWT`` decodes it and its ``sub`` claim decodes via
       ``authentication.str_decode``.

    Raises ``HTTPException`` 403 when a check fails, or 401 when no credentials
    were extracted (only possible with ``auto_error=False``).
    """

    def __init__(self, auto_error: bool = True):
        super().__init__(auto_error=auto_error)

    async def __call__(self, request: Request) -> str:
        credentials: HTTPAuthorizationCredentials = await super().__call__(request)
        if not credentials:
            # With auto_error=False, HTTPBearer returns None instead of raising
            # when the header is missing or malformed.
            raise HTTPException(status_code=401, detail="Invalid authorization code.")
        # Auth schemes are case-insensitive (RFC 7235 §2.1) and HTTPBearer
        # itself already accepts e.g. "bearer", so compare the same way.
        if credentials.scheme.lower() != "bearer":
            raise HTTPException(status_code=403, detail="Invalid authentication scheme.")
        if not self.verify_accesstoken(credentials.credentials):
            raise HTTPException(status_code=403, detail="Token does not exist")
        if not self.verify_jwt(credentials.credentials):
            raise HTTPException(status_code=403, detail="Invalid token or expired token.")
        return credentials.credentials

    def verify_accesstoken(self, jwtoken: str) -> bool:
        """Return the verdict of ``LoginService.check_token_is_valid`` for *jwtoken*."""
        return LoginService.check_token_is_valid(jwtoken)

    def verify_jwt(self, jwtoken: str) -> bool:
        """Return True if *jwtoken* decodes and its ``sub`` claim can be decoded.

        On success the decoded value is stored on ``self.email``; any exception
        during decoding is logged and reported as False.
        """
        try:
            payload = decodeJWT(jwtoken)
            email_encode = payload.get('sub')
            # NOTE(review): a single JWTBearer instance is created per
            # ``Depends(JWTBearer())`` and serves every request on that route,
            # so this attribute is shared across concurrent requests.
            self.email = authentication.str_decode(email_encode)
        except Exception as e:
            logging.getLogger(__name__).warning("JWT verification failed: %s", e)
            return False
        return True
def get_current_user_email(credentials: str = Depends(JWTBearer())):
    """Return the user email carried in the caller's JWT.

    ``credentials`` is the raw token string returned by ``JWTBearer``; its
    ``sub`` claim is decoded with ``authentication.str_decode``.

    Raises:
        HTTPException: 403 if the token cannot be decoded or its ``sub`` claim
            cannot be decoded.
    """
    try:
        payload = decodeJWT(credentials)
        email_encode = payload.get('sub')
        email = authentication.str_decode(email_encode)
    except Exception as e:
        logging.getLogger(__name__).warning("Could not read user email from JWT: %s", e)
        # Chain the cause so the original decoding error stays in tracebacks.
        raise HTTPException(status_code=403, detail="Invalid token or expired token.") from e
    return email
async def override_render_chat(
    user_id: Optional[str] = Query(None),
    current_user_email: str = Depends(get_current_user_email),
):
    """Return ``MySQLService.render_chat_history`` for ``user_id``.

    Any value other than ``True`` from ``support_function.check_value_user_id``
    is returned to the client unchanged.
    """
    verdict = support_function.check_value_user_id(user_id, current_user_email)
    if verdict is True:
        return MySQLService.render_chat_history(
            RequestMySQL.RequestRenderChatHistory(user_id=user_id)
        )
    return verdict
async def override_edit_chat(
    request: RequestMySQL.RequestEditNameChat,
    current_user_email: str = Depends(get_current_user_email),
):
    """Rename a chat of ``request.user_id`` from ``name_old`` to ``name_new``.

    Any value other than ``True`` from ``support_function.check_value_user_id``
    is returned unchanged.

    Raises:
        HTTPException: 400 when ``name_new`` or ``name_old`` is missing or blank
            (``name_new`` is checked first).
    """
    verdict = support_function.check_value_user_id(request.user_id, current_user_email)
    if verdict is not True:
        return verdict
    required_fields = (
        (request.name_new, "name_new field is required."),
        (request.name_old, "name_old field is required."),
    )
    for value, message in required_fields:
        if value is None or value.strip() == "":
            raise HTTPException(status_code=400, detail=message)
    return MySQLService.edit_chat(request)
async def override_delete_chat(
    request: RequestMySQL.RequestDeleteChat,
    current_user_email: str = Depends(get_current_user_email),
):
    """Delete chat ``request.chat_name`` of ``request.user_id`` via MySQLService.

    Any value other than ``True`` from ``support_function.check_value_user_id``
    is returned unchanged.

    Raises:
        HTTPException: 400 when ``chat_name`` is missing or blank.
    """
    verdict = support_function.check_value_user_id(request.user_id, current_user_email)
    if verdict is not True:
        return verdict
    target = request.chat_name
    if target is not None and target.strip() != "":
        return MySQLService.delete_chat(request)
    raise HTTPException(status_code=400, detail="chat_name field is required.")
async def override_load_chat(chat_id: Optional[str] = Query(None), user_id: Optional[str] = Query(None), current_user_email: str = Depends(get_current_user_email)):
    """Load the history of chat ``chat_id`` for ``user_id`` via MySQLService.

    ``chat_id`` comes from the query string and may be wrapped in single or
    double quotes; after stripping them it must parse as an integer greater
    than 0. Validation failures return ``res.ReponseError(status=400, ...)``;
    any value other than ``True`` from ``check_value_user_id`` is returned
    unchanged.
    """
    check = support_function.check_value_user_id(user_id, current_user_email)
    if check is not True:
        return check
    if chat_id is None or chat_id.strip() == "":
        return res.ReponseError(status=400,
                                data=res.Message(message="chat_id field is required."))
    chat_id = chat_id.strip("'").strip('"')
    try:
        chat_id_int = int(chat_id)
    except ValueError:
        return res.ReponseError(status=400,
                                data=res.Message(message="chat_id must be an integer"))
    if not support_function.is_positive_integer(chat_id_int):
        return res.ReponseError(status=400,
                                data=res.Message(message="chat_id must be greater than 0"))
    # Forward the validated integer rather than the raw text: int() tolerates
    # surrounding whitespace and leading zeros that the raw string would keep.
    request = RequestMySQL.RequestLoadChatHistory(chat_id=chat_id_int, user_id=user_id)
    return MySQLService.load_chat_history(request)
async def override_get_user(
    user_id: str = Query(None),
    current_user_email: str = Depends(get_current_user_email),
):
    """Return ``DefaultService.info_user`` for ``user_id``.

    Any value other than ``True`` from ``support_function.check_value_user_id``
    is returned unchanged.
    """
    verdict = support_function.check_value_user_id(user_id, current_user_email)
    if verdict is True:
        return DefaultService.info_user(RequestDefault.RequestInfoUser(user_id=user_id))
    return verdict
async def override_update_user_info(request: RequestLogin.RequestUpdateUserInfo, current_user_email: str = Depends(get_current_user_email)):
    """Update profile data of ``request.user_id`` via ``LoginService.update_user_info``.

    ``uid``, ``email``, ``display_name`` and ``photo_url`` must all be non-blank.
    A missing ``uid`` raises ``HTTPException`` 400; the other missing fields
    return ``res.ReponseError(status=400, ...)``. Any value other than ``True``
    from ``check_value_user_id`` is returned unchanged.
    """
    user_id = request.user_id
    check = support_function.check_value_user_id(user_id, current_user_email)
    # Identity test, as in the other handlers: only the literal True means
    # the request may proceed (PEP 8: never compare to True with ==/!=).
    if check is not True:
        return check
    uid = request.uid
    email = request.email
    display_name = request.display_name
    photo_url = request.photo_url
    if uid is None or uid.strip() == "":
        # NOTE(review): uid is the only field reported via HTTPException; the
        # others return res.ReponseError — confirm which shape clients expect.
        raise HTTPException(status_code=400, detail="uid field is required.")
    if email is None or email.strip() == "":
        return res.ReponseError(status=400,
                                data=res.Message(message="email field is required."))
    if display_name is None or display_name.strip() == "":
        return res.ReponseError(status=400,
                                data=res.Message(message="display_name field is required."))
    if photo_url is None or photo_url.strip() == "":
        return res.ReponseError(status=400,
                                data=res.Message(message="photo_url field is required."))
    return LoginService.update_user_info(request)
async def override_reset_password_firebase(request: RequestLogin.RequestChangePassword, current_user_email: str = Depends(get_current_user_email)):
    """Change the password of ``request.user_id`` via ``LoginService.change_password``.

    Both ``new_password`` and ``current_password`` must be non-blank, otherwise
    ``res.ReponseError(status=400, ...)`` is returned. Any value other than
    ``True`` from ``check_value_user_id`` is returned unchanged.
    """
    user_id = request.user_id
    check = support_function.check_value_user_id(user_id, current_user_email)
    # Identity test, consistent with the other handlers (PEP 8 E712).
    if check is not True:
        return check
    new_password = request.new_password
    current_password = request.current_password
    if new_password is None or new_password.strip() == "":
        return res.ReponseError(status=400,
                                data=res.Message(message="new_password field is required."))
    if current_password is None or current_password.strip() == "":
        return res.ReponseError(status=400,
                                data=res.Message(message="current_password field is required."))
    return LoginService.change_password(request)
async def override_delete_folder(request: RequestFile.RequestDeleteAllFile, current_user_email: str = Depends(get_current_user_email)):
    """Delete all files of ``request.user_id`` via ``FileService.deleteAllFile``.

    Any value other than ``True`` from ``check_value_user_id`` is returned
    unchanged.
    """
    check = support_function.check_value_user_id(request.user_id, current_user_email)
    # Identity test, consistent with the other handlers (PEP 8 E712).
    if check is not True:
        return check
    return FileService.deleteAllFile(request)
async def override_delete_one_file(request: RequestFile.RequestDeleteFile, current_user_email: str = Depends(get_current_user_email)):
    """Delete file ``request.name_file`` of ``request.user_id`` via ``FileService.deleteFile``.

    A missing or blank ``name_file`` returns ``res.ReponseError(status=400, ...)``.
    Any value other than ``True`` from ``check_value_user_id`` is returned
    unchanged.
    """
    user_id = request.user_id
    check = support_function.check_value_user_id(user_id, current_user_email)
    # Identity test, consistent with the other handlers (PEP 8 E712).
    if check is not True:
        return check
    name_file = request.name_file
    if name_file is None or name_file.strip() == "":
        return res.ReponseError(status=400,
                                data=res.Message(message="name_file is required."))
    return FileService.deleteFile(request)
async def override_download_folder_from_dropbox(
    request: RequestFile.RequestDownLoadFolder,
    current_user_email: str = Depends(get_current_user_email),
):
    """Return ``FileService.download_folder(request)`` once the user check passes.

    Any value other than ``True`` from ``check_value_user_id`` is returned
    unchanged.
    """
    verdict = support_function.check_value_user_id(request.user_id, current_user_email)
    return FileService.download_folder(request) if verdict is True else verdict
async def override_download_file_by_id(
    request: RequestFile.RequestDownLoadFile,
    current_user_email: str = Depends(get_current_user_email),
):
    """Return ``FileService.download_file(request)`` once the user check passes.

    Any value other than ``True`` from ``check_value_user_id`` is returned
    unchanged.
    """
    verdict = support_function.check_value_user_id(request.user_id, current_user_email)
    return FileService.download_file(request) if verdict is True else verdict
async def override_upload_files_dropbox(
    user_id: str = Form(None),
    files: List[UploadFile] = File(None),
    current_user_email: str = Depends(get_current_user_email)
):
    """Pass multipart ``files`` for ``user_id`` to ``FileService.upload_files``.

    Any value other than ``True`` from ``check_value_user_id`` is returned
    unchanged.
    """
    verdict = support_function.check_value_user_id(user_id, current_user_email)
    if verdict is True:
        upload_request = RequestFile.RequestUploadFile(files=files, user_id=user_id)
        return FileService.upload_files(upload_request)
    return verdict
async def override_handle_query2_upgrade_old(
    request: Request,
    user_id: str = Form(None),
    text_all: str = Form(...),
    question: str = Form(None),
    chat_name: str = Form(None),
    current_user_email: str = Depends(get_current_user_email),
):
    """Forward the chat form fields to ``ChatService.query2_upgrade_old``.

    ``text_all`` is a required form field. The injected ``request`` is not
    read. Any value other than ``True`` from ``check_value_user_id`` is
    returned unchanged.
    """
    verdict = support_function.check_value_user_id(user_id, current_user_email)
    if verdict is not True:
        return verdict
    # Separate name so the incoming ``request`` parameter is not shadowed.
    chat_query = RequestChat.RequestQuery2UpgradeOld(
        user_id=user_id,
        text_all=text_all,
        question=question,
        chat_name=chat_name,
    )
    return ChatService.query2_upgrade_old(chat_query)
async def override_extract_file(
    user_id: str,
    current_user_email: str = Depends(get_current_user_email),
):
    """Return ``ChatService.extract_file`` for ``user_id`` (required query parameter).

    Any value other than ``True`` from ``check_value_user_id`` is returned
    unchanged.
    """
    verdict = support_function.check_value_user_id(user_id, current_user_email)
    if verdict is True:
        return ChatService.extract_file(RequestChat.RequestExtractFile(user_id=user_id))
    return verdict
async def override_generate_question(
    user_id: str,
    current_user_email: str = Depends(get_current_user_email),
):
    """Return ``ChatService.generate_question`` for ``user_id`` (required query parameter).

    Any value other than ``True`` from ``check_value_user_id`` is returned
    unchanged.
    """
    verdict = support_function.check_value_user_id(user_id, current_user_email)
    if verdict is True:
        return ChatService.generate_question(RequestChat.RequestGenerateQuestion(user_id=user_id))
    return verdict
async def override_upload_image(
    user_id: str = Form(None),
    file: UploadFile = File(...),
    current_user_email: str = Depends(get_current_user_email),
):
    """Pass the uploaded image ``file`` for ``user_id`` to ``DefaultService.upload_image_service``.

    Any value other than ``True`` from ``check_value_user_id`` is returned
    unchanged.
    """
    verdict = support_function.check_value_user_id(user_id, current_user_email)
    if verdict is not True:
        return verdict
    image_request = RequestDefault.RequestUpLoadImage(user_id=user_id, files=file)
    return DefaultService.upload_image_service(image_request)
app.include_router(MySQLController.router, prefix="/api/mysql")
app.include_router(LoginController.router, prefix="/api/users")
app.include_router(FileController.router, prefix="/api/file")
app.include_router(ChatController.router, prefix="/api/chat")
app.include_router(DefaultController.router, prefix="/api/default")

# Single source of truth for the controller endpoints replaced by the
# authenticated handlers defined in this module, in registration order:
# (path, HTTP method, handler, OpenAPI tag). Each handler passes its user_id
# and the token's email to support_function.check_value_user_id.
_AUTHENTICATED_ROUTES = (
    ("/api/mysql/render_chat_history", "GET", override_render_chat, "MySQL"),
    ("/api/mysql/load_chat_history", "GET", override_load_chat, "MySQL"),
    ("/api/mysql/edit_chat/", "PUT", override_edit_chat, "MySQL"),
    ("/api/mysql/delete_chat/", "DELETE", override_delete_chat, "MySQL"),
    ("/api/users/update_user_info", "POST", override_update_user_info, "Login"),
    ("/api/users/change_password", "PUT", override_reset_password_firebase, "Login"),
    ("/api/file/delete_all_file/", "DELETE", override_delete_folder, "File"),
    ("/api/file/delete_one_file/", "DELETE", override_delete_one_file, "File"),
    ("/api/file/chatbot/download_folder/", "POST", override_download_folder_from_dropbox, "File"),
    ("/api/file/chatbot/download_files/", "POST", override_download_file_by_id, "File"),
    ("/api/file/upload_files/", "POST", override_upload_files_dropbox, "File"),
    ("/api/chat/chatbot/query/", "POST", override_handle_query2_upgrade_old, "Chat"),
    ("/api/chat/chatbot/extract_file/", "GET", override_extract_file, "Chat"),
    ("/api/chat/chatbot/generate_question/", "GET", override_generate_question, "Chat"),
    ("/api/default/upload_image/", "POST", override_upload_image, "Default"),
    ("/api/default/info_user", "GET", override_get_user, "Default"),
)

# path -> set of HTTP methods whose controller routes must be dropped.
routes_to_override = {}
for _path, _method, _handler, _tag in _AUTHENTICATED_ROUTES:
    routes_to_override.setdefault(_path, set()).add(_method)
# NOTE(review): this controller route is removed without any replacement, so
# POST /api/chat/query2_upgrade/ disappears entirely — confirm intended.
routes_to_override.setdefault("/api/chat/query2_upgrade/", set()).add("POST")


def _is_overridden(route) -> bool:
    """Return True if *route* serves a (path, method) pair in routes_to_override."""
    # Mounts, websocket routes and class-based endpoints may have no usable
    # ``methods``; treat them as never overridden instead of crashing.
    methods = getattr(route, "methods", None) or ()
    blocked = routes_to_override.get(getattr(route, "path", None), set())
    return bool(blocked.intersection(methods))


app.router.routes = [route for route in app.router.routes if not _is_overridden(route)]

for _path, _method, _handler, _tag in _AUTHENTICATED_ROUTES:
    # NOTE(review): every handler also depends on get_current_user_email, which
    # depends on its own JWTBearer instance. FastAPI's per-request dependency
    # cache is keyed on the dependency callable, and these are distinct
    # instances, so the token appears to be verified twice per request.
    app.add_api_route(
        _path,
        _handler,
        methods=[_method],
        dependencies=[Depends(JWTBearer())],
        tags=[_tag],
    )

app.include_router(OTPController.router, tags=["OTP"], prefix="/api/otp")