"""FastAPI application exposing login and text-to-speech endpoints.

Every endpoint responds with a ``BaseResponse`` envelope; the exception
handlers below report failures as ``BaseResponse(status=0, message=...)``.
"""
import base64
from collections import defaultdict
from typing import Annotated, Optional

import soundfile as sf
from fastapi import Depends, FastAPI, Header, HTTPException, UploadFile, status
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from starlette.middleware.cors import CORSMiddleware
from starlette.responses import JSONResponse

from google_sheet import create_repositories
from login import AuthService
from model import (
    BaseResponse,
    LoginRequest,
    LoginResponse,
    SynthesisRequest,
    SynthesisResponse,
    TransferRequest,
    TransferResponse,
)
from tts import TTSService

# Module-level service singletons used by the route handlers below.
account_repo = create_repositories()
auth_service = AuthService(account_repo=account_repo)
tts_service = TTSService()

app = FastAPI()
# NOTE(review): not referenced by any route in this module; authentication is
# done through the raw ``access_token`` header in get_current_user instead.
auth = HTTPBearer()


@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc: HTTPException) -> JSONResponse:
    """Render an ``HTTPException`` as a ``BaseResponse`` error envelope.

    The exception's own status code and headers are preserved, so e.g. the
    401 raised by ``get_current_user`` reaches the client as 401 instead of
    being rewritten to 400.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=jsonable_encoder(BaseResponse(status=0, message=exc.detail)),
        headers=getattr(exc, "headers", None),
    )


@app.exception_handler(RequestValidationError)
def validation_exception_handler(request, exc: RequestValidationError) -> JSONResponse:
    """Collapse request-validation errors into ``{field_path: [messages]}``.

    A leading ``body``/``query``/``path`` location segment is dropped and the
    remaining segments are joined with dots (e.g. ``items.0.name``).
    Always responds with HTTP 400.
    """
    errors_by_field = defaultdict(list)
    for error in exc.errors():
        loc, msg = error["loc"], error["msg"]
        if loc and loc[0] in ("body", "query", "path"):
            loc = loc[1:]
        # Locations can contain list indices (ints), which str.join rejects.
        field_path = ".".join(str(part) for part in loc)
        errors_by_field[field_path].append(msg)
    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content=jsonable_encoder(
            BaseResponse(status=0, message="Invalid request", result=errors_by_field)
        ),
    )


async def get_current_user(
    access_token: Annotated[Optional[str], Header(convert_underscores=False)] = None,
):
    """Resolve the caller's username from the ``access_token`` request header.

    ``convert_underscores=False`` makes FastAPI read the header literally as
    ``access_token`` rather than ``access-token``.

    Returns:
        Whatever ``auth_service.validate_token`` returns for a valid token
        (used as the username by the routes).

    Raises:
        HTTPException: 401 if the header is absent or the token is rejected
            (``validate_token`` returns a falsy value).
    """
    if access_token is None:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token missing")
    username = await auth_service.validate_token(access_token)
    if not username:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Token")
    return username


@app.post("/login", response_model=BaseResponse)
async def login(request: LoginRequest):
    """Authenticate by email/password and return an access token in ``result``.

    Raises:
        HTTPException: 400 if the credentials are rejected.
    """
    user = await auth_service.authenticate_user(request.email, request.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Incorrect username or password",
        )
    encoded_jwt = await auth_service.create_token(request.email)
    return BaseResponse(result={"access_token": encoded_jwt})


@app.post("/test-auth", response_model=BaseResponse)
def test_auth(username: str = Depends(get_current_user)):
    """Echo back the authenticated user; handy for checking a token."""
    return BaseResponse(result={"email": username})


@app.post("/tts/sub-task-1", response_model=BaseResponse)
def synthesis(request: SynthesisRequest):
    """Synthesize speech for ``request.input_text``."""
    # TODO: require authentication via `username: str = Depends(get_current_user)`.
    # Plain ``def`` so FastAPI runs the (presumably CPU-heavy) TTS call in its
    # threadpool rather than on the event loop.
    audio_data = tts_service.synthesis(request.input_text)
    return BaseResponse(result=SynthesisResponse(data=audio_data))


@app.post("/tts/sub-task-2", response_model=BaseResponse)
def transfer(input_text: str, ref_audio: UploadFile):
    """Synthesize ``input_text`` conditioned on an uploaded reference clip.

    ``input_text`` arrives as a query parameter and ``ref_audio`` as a
    multipart upload. The synthesized audio is returned base64-encoded in
    ``result.data``.

    Raises:
        HTTPException: 400 if the upload is not ``audio/mpeg`` or soundfile
            cannot decode it.
    """
    # TODO: require authentication via `username: str = Depends(get_current_user)`.
    # Plain ``def`` (not ``async def``): the soundfile decode and the TTS call
    # are blocking, so FastAPI must run them in its threadpool instead of
    # stalling the event loop for every other request.
    if ref_audio.content_type != "audio/mpeg":
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only audio files allowed")
    try:
        # NOTE(review): MP3 decoding needs libsndfile >= 1.1.0 — confirm the
        # deployed version.
        with sf.SoundFile(ref_audio.file, "rb") as f:
            audio_np_array = f.read(dtype="float32")
    except RuntimeError as err:
        # soundfile reports undecodable input as RuntimeError (LibsndfileError
        # subclasses it); surface it as a client error instead of a 500.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Could not decode audio file",
        ) from err
    audio_out = tts_service.transfer(input_text, audio_np_array)
    audio_data = base64.b64encode(audio_out)
    return BaseResponse(result=TransferResponse(data=audio_data))


# NOTE(review): a wildcard origin combined with allow_credentials=True may let
# any site issue credentialed cross-origin requests; consider restricting
# allow_origins to known front-end hosts.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)