File size: 4,078 Bytes
614861a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from typing import Annotated

from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from starlette.middleware.cors import CORSMiddleware
from fastapi import FastAPI, Header, UploadFile, Depends, HTTPException, status
import base64
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from starlette.responses import JSONResponse
import soundfile as sf
from collections import defaultdict

from model import SynthesisRequest, SynthesisResponse, TransferRequest, TransferResponse, LoginRequest, LoginResponse, BaseResponse
from google_sheet import create_repositories
from login import AuthService
from tts import TTSService

# Module-level singletons shared by the request handlers below. They are built
# at import time, in this order.
# NOTE(review): create_repositories() comes from the google_sheet module and may
# perform network I/O; TTSService() presumably loads the TTS model. Confirm
# that import-time cost and failure behavior are acceptable.
account_repo = create_repositories()
auth_service = AuthService(account_repo=account_repo)
tts_service = TTSService()


app = FastAPI()

# NOTE(review): not referenced by any route in this file. get_current_user
# reads the raw `access_token` header instead of a Bearer Authorization header.
auth = HTTPBearer()


@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc: HTTPException):
    """Render an ``HTTPException`` as a ``BaseResponse`` envelope with ``status=0``.

    The HTTP status code chosen by whoever raised the exception (for example
    the 401 from ``get_current_user``) is kept, so clients can distinguish
    authentication failures from bad requests. Any headers attached to the
    exception are also kept.

    Args:
        request: The incoming request (unused).
        exc: The raised exception. ``exc.detail`` becomes the envelope message.

    Returns:
        A ``JSONResponse`` carrying the serialized ``BaseResponse``.
    """
    return JSONResponse(
        status_code=exc.status_code,
        content=jsonable_encoder(BaseResponse(status=0, message=exc.detail)),
        # FastAPI's HTTPException can carry headers (e.g. WWW-Authenticate).
        # getattr keeps this safe for exception objects without the attribute.
        headers=getattr(exc, "headers", None),
    )

@app.exception_handler(RequestValidationError)
def validation_exception_handler(request, exc: RequestValidationError) -> JSONResponse:
    """Render request-validation failures as a 400 ``BaseResponse``.

    Errors are grouped by field path:

    - A leading ``body``/``query``/``path`` location segment is dropped.
    - The remaining segments are stringified and joined with dots, so
      ``("body", "items", 0, "name")`` becomes ``"items.0.name"``.

    Each key maps to the list of messages reported for that field. The
    mapping is returned in ``result`` with ``status=0``.

    Args:
        request: The incoming request (unused).
        exc: The validation error raised by FastAPI.

    Returns:
        A 400 ``JSONResponse`` with the grouped error messages.
    """
    reformatted_message = defaultdict(list)
    for pydantic_error in exc.errors():
        loc, msg = pydantic_error["loc"], pydantic_error["msg"]
        # Guard against an empty location before peeking at its first segment.
        filtered_loc = loc[1:] if loc and loc[0] in ("body", "query", "path") else loc
        # Location segments can be ints (list indices). str.join raises
        # TypeError on non-str items, so stringify each segment first.
        field_string = ".".join(str(part) for part in filtered_loc)
        reformatted_message[field_string].append(msg)

    return JSONResponse(
        status_code=status.HTTP_400_BAD_REQUEST,
        content=jsonable_encoder(BaseResponse(status=0, message="Invalid request", result=reformatted_message))
    )


async def get_current_user(access_token: Annotated[str, Header(convert_underscores=False)] = None):
    """FastAPI dependency that resolves the caller from the ``access_token`` header.

    The header name is taken literally (``convert_underscores=False``).

    Returns:
        The truthy value returned by ``auth_service.validate_token``. The
        routes in this module treat it as the caller's email.

    Raises:
        HTTPException: 401 with "Token missing" when the header is absent.
            401 with "Invalid Token" when ``validate_token`` returns a falsy
            value.

    NOTE(review): some reverse proxies drop header names containing
    underscores by default (e.g. nginx ``underscores_in_headers off``).
    Confirm the deployment forwards ``access_token``.
    """
    if access_token is None:
        failure_reason = "Token missing"
    else:
        identity = await auth_service.validate_token(access_token)
        if identity:
            return identity
        failure_reason = "Invalid Token"

    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=failure_reason)


@app.post("/login", response_model=BaseResponse)
async def login(request: LoginRequest):
    """Exchange email/password credentials for an access token.

    Returns:
        A ``BaseResponse`` whose ``result`` is
        ``{"access_token": <token from auth_service.create_token>}``.

    Raises:
        HTTPException: 400 when ``auth_service.authenticate_user`` returns a
            falsy value.
    """
    if not await auth_service.authenticate_user(request.email, request.password):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Incorrect username or password")

    token = await auth_service.create_token(request.email)
    return BaseResponse(result={"access_token": token})


@app.post("/test-auth", response_model=BaseResponse)
def test_auth(username: str = Depends(get_current_user)):
    """Echo back the identity resolved by ``get_current_user`` (auth smoke test)."""
    payload = {"email": username}
    return BaseResponse(result=payload)


@app.post("/tts/sub-task-1", response_model=BaseResponse)
def synthesis(request: SynthesisRequest): # todo: , username: str = Depends(get_current_user)):
    """Synthesize speech for ``request.input_text`` via ``tts_service``.

    This is a plain ``def``, so FastAPI runs it in its threadpool. That keeps
    the (presumably blocking) TTS call off the event loop.

    NOTE(review): this route is currently unauthenticated (see the todo on
    the signature).
    """
    speech = tts_service.synthesis(request.input_text)
    return BaseResponse(result=SynthesisResponse(data=speech))

def _read_audio_float32(fileobj):
    """Decode an uploaded audio stream into a float32 sample array using soundfile."""
    with sf.SoundFile(fileobj, 'rb') as f:
        return f.read(dtype='float32')


@app.post("/tts/sub-task-2", response_model=BaseResponse)
async def transfer(input_text: str, ref_audio: UploadFile): # request: TransferRequest # todo: , username: str = Depends(get_current_user)):
    """Synthesize ``input_text`` conditioned on an uploaded reference clip.

    Args:
        input_text: Text to synthesize (a query parameter in FastAPI terms).
        ref_audio: Uploaded reference audio. Only ``audio/mpeg`` is accepted.

    Returns:
        A ``BaseResponse`` whose ``result`` is a ``TransferResponse``. Its
        ``data`` is the base64 encoding of ``tts_service.transfer``'s output.

    Raises:
        HTTPException: 400 if the content type is not ``audio/mpeg``, or if
            soundfile cannot decode the upload.
    """
    # Local import: starlette is already a dependency of this module.
    from starlette.concurrency import run_in_threadpool

    if ref_audio.content_type != "audio/mpeg":
        raise HTTPException(status_code=400, detail="Only audio files allowed")

    # Decoding and inference are synchronous. Calling them directly in this
    # coroutine would block the event loop for every other request, so run
    # them in the threadpool instead. The sync `synthesis` route already calls
    # tts_service from FastAPI's threadpool.
    try:
        # NOTE(review): soundfile decodes MP3 only with libsndfile >= 1.1.0.
        # Confirm the deployed version.
        audio_np_array = await run_in_threadpool(_read_audio_float32, ref_audio.file)
    except RuntimeError as err:
        # soundfile reports undecodable input as RuntimeError (LibsndfileError
        # subclasses it in recent versions). Treat that as a client error.
        raise HTTPException(status_code=400, detail="Invalid audio file") from err

    # NOTE(review): the clip's sample rate is not passed to the service.
    # Confirm tts_service.transfer assumes a fixed rate.
    audio_out = await run_in_threadpool(tts_service.transfer, input_text, audio_np_array)
    audio_data = base64.b64encode(audio_out)
    return BaseResponse(result=TransferResponse(data=audio_data))

# Allow cross-origin requests from any origin, with any method and header.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is very
# permissive. For credentialed requests Starlette echoes the caller's Origin
# instead of "*" (verify against the installed version). Authentication in this
# file uses the `access_token` request header, not cookies, so credentials may
# not be needed. Consider an explicit origin allow-list.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)