hahunavth commited on
Commit
614861a
1 Parent(s): 11cbbad

add source code

Browse files
Files changed (9) hide show
  1. .gitignore +4 -0
  2. Dockerfile +20 -0
  3. credentials.json +13 -0
  4. google_sheet.py +53 -0
  5. login.py +47 -0
  6. main.py +103 -0
  7. model.py +36 -0
  8. requirements.txt +10 -0
  9. tts.py +78 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+
2
+ output.mp3
3
+ .idea
4
+ __pycache__
Dockerfile ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.9
2
+
3
+ WORKDIR /code
4
+
5
+ COPY ./requirements.txt /code/requirements.txt
6
+
7
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
8
+
9
+ RUN useradd -m -u 1000 user
10
+
11
+ USER user
12
+
13
+ ENV HOME=/home/user \
14
+ PATH=/home/user/.local/bin:$PATH
15
+
16
+ WORKDIR $HOME/app
17
+
18
+ COPY --chown=user . $HOME/app
19
+
20
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
credentials.json ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "type": "service_account",
3
+ "project_id": "vlsp2023-tts-api",
4
+ "private_key_id": "411335f1452da99ac34dbaeaaea212dab3ff2400",
5
+ "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQClPtZ/kYuBkYrP\nzbfJ/ye2DmHm4xt2o0Fvdqnuddri63662YA0GmQ+3RpH/yMg4b9Kgcyl0GXJ6Q6r\nZZzdueg3Edx5B9DkBWgYkPaTLFFe5XXOj1w2vgzc8UBN5+oomzS/WaTCpKmrs/D1\nVi05jt7Uo9mrALUkPD8wr2D9vkXQdM1G4NNxuFWCDoAF0MoVKcS5Le3f6bH4OgpL\nl5BniW4sh/UCLcWrWfXS6tuzK+8z/L6AI1ezNUJHlt58AzLjLi5Q8fOdeXviJT60\nf4w22W4gtY5YHdVCG7b7frM7ig6V9bWyVvHjGWqVkmXjYepOCIWnwQHVqja6NpVg\nR45HForBAgMBAAECggEATeqhTamdNE0iPPXtcWvEl82UUEBKFNjJ4/r6CZy8xz7v\nlL81+ltvZUzwNX6SW9DWWBV4H79yH5CrABp7qvkcC8t6P/91ee8qtFq2SZMeEzbz\nI6DphE581jlTbuiputfkOU3VqInoDzRbq/Mkg/1gCLfxzPYac6mMyjIH892iIbYP\nhlexZiXhZWRLqGFu9J3JHQAjlAtCxjhb/vBXkxzxO2j1MXv7BNrp6t7MAgFvX57k\ngRfIg/rmWohHni/NxQgWWQqTydy2JvxEVUUdPr2hgWF7hHgMk4WFbb7hzy6wH/T3\nMcjd5OHMHZcBF/NgrnxMD9NRuhD1cSbSNETJVTt9WQKBgQDS9I3MGOz5/j+jkWY7\n4CdjbbLCgat9RvvyBAkSonogLhVap1/jfKK4XQz1+aDFFY9YWW4MX/hCxKWOM5Fw\nnvDka/m2yjxha66ismflDdxvpfcjAvAt6WNX6cbiPNxFBkpJXQRH55ug9a9PKjXY\nUtc7ZOM5R/NyxfsqPjNm01faVwKBgQDIh6X0ypwCKr0CF8GP7gYPfVeK1Zh22s21\neIOEyGLB40HWOsi+PQxQHkva8teVaGMPKqppdxjs+d26d3h4Lo1rSYQrF9x+BoGm\nUU+d+3vfp1aLOj3VHMIKHKrQVOF0JSy/VdYFkz6vu6nFy3n4WdLg0XXFIQTS4Z+d\nFAD81nhEpwKBgQCqI9WNa/kNM7MuACH9Xq9F8P7BA4ZFVw/yxLBwmBx5gdF1ORMM\nTcSLf3jpljjFW7suHYq1bl2ztBh2lT7TH03YXQGdHIUQaaIC1HMY+VH1tlyZn1AJ\nJ3gZOpJOe5mIDiex/dRrDfCmJCENb1TYMRAodhkRZOeDhQwqqNoaL5BmpwKBgQCM\nWMQB67v8mETorg++2GxNcwBOHugyZzkKBWqnCEh2QsPVWBcfbkKr4Ehe2Q+hdgm+\nl7HlVoGPeeGBnBQoqQw5Rp7GOlELsyoSaV47x8MO6WNc1kpoWVRFF4NFg+K3Ez2a\nPE0qYb/B5qoP0TVwaA17Y531dgKWRWsc2N9IFiLeiQKBgQDNc2ygLHTkAKmqltcF\nzcccEnUTabPHvFkt6qycmx6oOrF6IXr1/x9ODZnwo/2JD5H/Mbub47M2k3I+BDdW\nwb0ETSEW3rdrtk85yPx3eVDIVtHoEr/Zhr9a9EYK+V8uAku3UM0L0TrnHtgqbM0W\nyGvFQOUt1hYxgoydw/Lx7h+1dA==\n-----END PRIVATE KEY-----\n",
6
+ "client_email": "hahunavth@vlsp2023-tts-api.iam.gserviceaccount.com",
7
+ "client_id": "102309854252079057442",
8
+ "auth_uri": "https://accounts.google.com/o/oauth2/auth",
9
+ "token_uri": "https://oauth2.googleapis.com/token",
10
+ "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
11
+ "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/hahunavth%40vlsp2023-tts-api.iam.gserviceaccount.com",
12
+ "universe_domain": "googleapis.com"
13
+ }
google_sheet.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gspread
2
+ from oauth2client.service_account import ServiceAccountCredentials
3
+ from typing import Dict
4
+
5
+ class SheetCRUDRepository:
6
+ def __init__(self, worksheet):
7
+ self.worksheet = worksheet
8
+ self.titles = self.worksheet.row_values(1) # Assuming titles are in the first row
9
+ assert len(set(self.titles)) == len(self.titles), f"Failed to init {SheetCRUDRepository.__class__}, titles: {self.titles} contain duplicated values!"
10
+
11
+ def create(self, data: Dict):
12
+ values = [data.get(title, '') for title in self.titles]
13
+ self.worksheet.append_row(values)
14
+
15
+ def read(self, row_index: int) -> Dict:
16
+ values = self.worksheet.row_values(row_index)
17
+ return {title: value for title, value in zip(self.titles, values)}
18
+
19
+ def update(self, row_index: int, data: Dict):
20
+ values = [data.get(title, '') for title in self.titles]
21
+ self.worksheet.update(f"A{row_index}:Z{row_index}", [values])
22
+
23
+ def delete(self, row_index: int):
24
+ self.worksheet.delete_row(row_index)
25
+
26
+ def find(self, search_dict):
27
+ for col_title, value in search_dict.items():
28
+ if col_title in self.titles:
29
+ col_index = self.titles.index(col_title) + 1 # Adding 1 to match gspread indexing
30
+ cell = self.worksheet.find(value, in_column=col_index)
31
+ if cell is None:
32
+ break
33
+ row_number = cell.row
34
+ return row_number, self.read(row_number)
35
+ return None
36
+
37
+ def create_repositories():
38
+ scope = [
39
+ 'https://www.googleapis.com/auth/spreadsheets',
40
+ 'https://www.googleapis.com/auth/drive'
41
+ ]
42
+ creds = ServiceAccountCredentials.from_json_keyfile_name('credentials.json', scope)
43
+ client = gspread.authorize(creds)
44
+ sheet_url = "https://docs.google.com/spreadsheets/d/17OxKF0iP_aJJ0HCgJkwFsH762EUrtcEIYcPmyiiKnaM"
45
+ sheet = client.open_by_url(sheet_url)
46
+ worksheet = sheet.get_worksheet(0)
47
+ account_repository = SheetCRUDRepository(worksheet)
48
+ return account_repository
49
+
50
+
51
+ if __name__ == "__main__":
52
+ a = create_repositories()
53
+ print(a)
login.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta
2
+ from jose import JWTError, jwt, ExpiredSignatureError
3
+ # from passlib.context import CryptContext
4
+
5
+
6
+ class AuthService:
7
+ def __init__(self, account_repo, secret_key="123"):
8
+ assert account_repo is not None
9
+ assert secret_key is not None
10
+
11
+ self.account_repo = account_repo
12
+ self.secret_key = secret_key
13
+ self.encode_alg = "HS256"
14
+ # self.pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
15
+
16
+ async def authenticate_user(self, email, password):
17
+ key, user = self.account_repo.find({"email": email})
18
+ if not user:
19
+ return False
20
+
21
+ assert 'password' in user.keys()
22
+ if user['password'] == password:
23
+ return True
24
+ # if not self.pwd_context.verify(password, user["hashed_password"]):
25
+ # return False
26
+ return user
27
+
28
+ async def create_token(self, email):
29
+ expire = datetime.utcnow() + timedelta(minutes=30)
30
+ encoded_jwt = jwt.encode(
31
+ {"sub": email, "exp": expire},
32
+ self.secret_key,
33
+ algorithm=self.encode_alg
34
+ )
35
+ return encoded_jwt
36
+
37
+ async def validate_token(self, encoded_token):
38
+ try:
39
+ decoded_token = jwt.decode(encoded_token, self.secret_key, algorithms=[self.encode_alg])
40
+ key, user = self.account_repo.find({"email": decoded_token['sub']})
41
+ return user['email']
42
+ except ExpiredSignatureError:
43
+ return False # Expired
44
+ except JWTError:
45
+ return False
46
+ except Exception as e:
47
+ raise e
main.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated
2
+
3
+ from fastapi.encoders import jsonable_encoder
4
+ from fastapi.exceptions import RequestValidationError
5
+ from starlette.middleware.cors import CORSMiddleware
6
+ from fastapi import FastAPI, Header, UploadFile, Depends, HTTPException, status
7
+ import base64
8
+ from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
9
+ from starlette.responses import JSONResponse
10
+ import soundfile as sf
11
+ from collections import defaultdict
12
+
13
+ from model import SynthesisRequest, SynthesisResponse, TransferRequest, TransferResponse, LoginRequest, LoginResponse, BaseResponse
14
+ from google_sheet import create_repositories
15
+ from login import AuthService
16
+ from tts import TTSService
17
+
18
+ account_repo = create_repositories()
19
+ auth_service = AuthService(account_repo=account_repo)
20
+ tts_service = TTSService()
21
+
22
+
23
+ app = FastAPI()
24
+
25
+ auth = HTTPBearer()
26
+
27
+
28
+ @app.exception_handler(HTTPException)
29
+ async def http_exception_handler(request, exc: HTTPException):
30
+ return JSONResponse(
31
+ status_code=status.HTTP_400_BAD_REQUEST,
32
+ content=jsonable_encoder(BaseResponse(status=0, message=exc.detail))
33
+ )
34
+
35
+ @app.exception_handler(RequestValidationError)
36
+ def validation_exception_handler(request, exc: RequestValidationError) -> JSONResponse:
37
+ reformatted_message = defaultdict(list)
38
+ for pydantic_error in exc.errors():
39
+ loc, msg = pydantic_error["loc"], pydantic_error["msg"]
40
+ filtered_loc = loc[1:] if loc[0] in ("body", "query", "path") else loc
41
+ field_string = ".".join(filtered_loc)
42
+ reformatted_message[field_string].append(msg)
43
+
44
+ return JSONResponse(
45
+ status_code=status.HTTP_400_BAD_REQUEST,
46
+ content=jsonable_encoder(BaseResponse(status=0, message="Invalid request", result=reformatted_message))
47
+ )
48
+ # return JSONResponse(content=jsonable_encoder(BaseResponse(status=0, message="RequestValidationError", result=str(exc))))
49
+
50
+
51
+ async def get_current_user(access_token: Annotated[str, Header(convert_underscores=False)] = None):
52
+ if access_token is None:
53
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Token missing")
54
+
55
+ username = await auth_service.validate_token(access_token)
56
+ if not username:
57
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Token")
58
+ return username
59
+
60
+
61
+ @app.post("/login", response_model=BaseResponse)
62
+ async def login(request: LoginRequest):
63
+ email = request.email
64
+ password = request.password
65
+ user = await auth_service.authenticate_user(email, password)
66
+ if not user:
67
+ raise HTTPException(status_code=400, detail="Incorrect username or password")
68
+ else:
69
+ encoded_jwt = await auth_service.create_token(email)
70
+ return BaseResponse(result={"access_token": encoded_jwt})
71
+
72
+
73
+ @app.post("/test-auth", response_model=BaseResponse)
74
+ def test_auth(username: str = Depends(get_current_user)):
75
+ return BaseResponse(result={"email": username})
76
+
77
+
78
+ @app.post("/tts/sub-task-1", response_model=BaseResponse)
79
+ def synthesis(request: SynthesisRequest): # todo: , username: str = Depends(get_current_user)):
80
+ audio_data = tts_service.synthesis(request.input_text)
81
+ return BaseResponse(result=SynthesisResponse(data=audio_data))
82
+
83
+ @app.post("/tts/sub-task-2", response_model=BaseResponse)
84
+ async def transfer(input_text: str, ref_audio: UploadFile): # request: TransferRequest # todo: , username: str = Depends(get_current_user)):
85
+ if ref_audio.content_type != "audio/mpeg":
86
+ raise HTTPException(status_code=400, detail="Only audio files allowed")
87
+
88
+ # ref_audio_contents = await request.ref_audio.read()
89
+ # Convert the audio file to a NumPy array
90
+ with sf.SoundFile(ref_audio.file, 'rb') as f:
91
+ audio_np_array = f.read(dtype='float32')
92
+
93
+ audio_out = tts_service.transfer(input_text, audio_np_array)
94
+ audio_data = base64.b64encode(audio_out)
95
+ return BaseResponse(result=TransferResponse(data=audio_data))
96
+
97
+ app.add_middleware(
98
+ CORSMiddleware,
99
+ allow_origins=["*"],
100
+ allow_credentials=True,
101
+ allow_methods=["*"],
102
+ allow_headers=["*"],
103
+ )
model.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import UploadFile
2
+ from pydantic import BaseModel
3
+
4
+
5
+ class SynthesisRequest(BaseModel):
6
+ input_text: str
7
+ emotion: str
8
+
9
+
10
+ class SynthesisResponse(BaseModel):
11
+ data: str
12
+
13
+
14
+ class TransferRequest(BaseModel):
15
+ input_text: str
16
+ ref_audio: UploadFile
17
+
18
+
19
+
20
+ class TransferResponse(BaseModel):
21
+ data: str
22
+
23
+
24
+ class LoginRequest(BaseModel):
25
+ email: str
26
+ password: str
27
+
28
+
29
+ class LoginResponse(BaseModel):
30
+ access_token: str
31
+
32
+
33
+ class BaseResponse(BaseModel):
34
+ status: int = 1
35
+ message: str = ""
36
+ result: object = None
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn
3
+ python-multipart
4
+ passlib
5
+ python-jose
6
+ motor
7
+ pyjwt
8
+ gspread
9
+ google-api-python-client
10
+ oauth2client
tts.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from abc import ABC, abstractmethod
3
+
4
+ from gtts import gTTS
5
+ from io import BytesIO
6
+ import numpy as np
7
+
8
+
9
+ class ExpressiveModel(ABC):
10
+ @abstractmethod
11
+ def load(self):
12
+ pass
13
+
14
+ @abstractmethod
15
+ def synthesize(self, text: str, emotion: str):
16
+ """
17
+ Synthesis audio with emotion
18
+ :param text: (str)
19
+ :param emotion: (str) neutral | happy | ...
20
+ :return: np.array
21
+ """
22
+ pass
23
+
24
+
25
+ class StyleTransferModel(ABC):
26
+ @abstractmethod
27
+ def load(self):
28
+ pass
29
+
30
+ @abstractmethod
31
+ def synthesize(self, text: str, ref_audio):
32
+ """
33
+ Synthesis audio with reference audio
34
+ :param text: (str)
35
+ :param ref_audio: (np.array)
36
+ :return: np.array
37
+ """
38
+ pass
39
+
40
+
41
+ class TTSService:
42
+ """
43
+ Get input text (str), emotion label (str) or reference audio (np.array)
44
+ Synthesis audio (np.array)
45
+ Convert audio to base64
46
+ """
47
+ @staticmethod
48
+ def synthesis(text: str) -> str:
49
+ tts = gTTS(text)
50
+ # Using in-memory handling
51
+ audio_data = BytesIO()
52
+ tts.write_to_fp(audio_data)
53
+ encoded_audio = base64.b64encode(audio_data.getvalue()).decode('utf-8')
54
+ return encoded_audio
55
+ #
56
+ # tts.save("output.mp3")
57
+ # with open("output.mp3", "rb") as audio_file:
58
+ # audio_data = audio_file.read()
59
+ # encoded_audio = base64.b64encode(audio_data).decode('utf-8')
60
+ # return encoded_audio
61
+
62
+ @staticmethod
63
+ def transfer(input_text: str, ref_audio: np.array) -> str:
64
+ # Process reference audio
65
+ # ..
66
+ # np.array to audio
67
+ # tts_output_np_array = np.array([0, 1, 0, 1])
68
+ # tts_output_bytes = tts_output_np_array.tobytes()
69
+ # audio_data = base64.b64encode(tts_output_bytes).decode('utf-8')
70
+ # return audio_data
71
+ #
72
+ # example
73
+ tts_text = input_text
74
+ tts = gTTS(tts_text)
75
+ audio_data = BytesIO()
76
+ tts.write_to_fp(audio_data)
77
+ encoded_audio = base64.b64encode(audio_data.getvalue()).decode('utf-8')
78
+ return encoded_audio