|
from fastapi import Depends, HTTPException, Query |
|
from fastapi.responses import StreamingResponse |
|
|
|
import io |
|
from pydantic import BaseModel |
|
import soundfile as sf |
|
from fastapi.responses import FileResponse |
|
|
|
|
|
from modules.normalization import text_normalize |
|
|
|
from modules import generate_audio as generate |
|
|
|
from modules.api import utils as api_utils |
|
from modules.api.Api import APIManager |
|
from modules.synthesize_audio import synthesize_audio |
|
|
|
|
|
class TTSParams(BaseModel):
    """Query-string parameters for the ``GET /v1/tts`` endpoint.

    Every field is declared through ``fastapi.Query`` so that the model can be
    injected with ``Depends()`` and have its values read from the query string.
    """

    # Input text; the only required parameter.
    text: str = Query(default=..., description="Text to synthesize")

    # Speaker selection and style preset.
    spk: str = Query(
        default="female2",
        description="Specific speaker by speaker name or speaker seed",
    )
    style: str = Query(default="chat", description="Specific style by style name")

    # Sampling controls.
    temperature: float = Query(
        default=0.3,
        description="Temperature for sampling (may be overridden by style or spk)",
    )
    top_P: float = Query(
        default=0.5,
        description="Top P for sampling (may be overridden by style or spk)",
    )
    top_K: int = Query(
        default=20,
        description="Top K for sampling (may be overridden by style or spk)",
    )
    seed: int = Query(
        default=42,
        description="Seed for generate (may be overridden by style or spk)",
    )

    # Output encoding of the response body.
    format: str = Query(default="mp3", description="Response audio format: [mp3,wav]")

    # Optional prompt/prefix text passed through to inference.
    prompt1: str = Query(default="", description="Text prompt for inference")
    prompt2: str = Query(default="", description="Text prompt for inference")
    prefix: str = Query(default="", description="Text prefix for inference")

    # Declared as strings; the request handler converts them with int().
    bs: str = Query(default="8", description="Batch size for inference")
    thr: str = Query(default="100", description="Threshold for sentence spliter")
|
|
|
|
|
async def synthesize_tts(params: TTSParams = Depends()):
    """Synthesize speech for ``params.text`` and stream it back as audio.

    Steps:
      1. Normalize the text with ``text_normalize``.
      2. Resolve speaker/style presets via ``api_utils.calc_spk_style``.
      3. Run ``synthesize_audio``.
      4. Encode the result as WAV.
      5. Convert it with ``api_utils.wav_to_mp3`` when ``params.format`` is
         ``"mp3"``.

    Returns:
        A ``StreamingResponse`` whose media type matches the encoding actually
        produced: ``audio/mpeg`` for MP3, ``audio/wav`` otherwise.

    Raises:
        HTTPException: 400 if ``bs`` or ``thr`` is not an integer string;
            500 for any other failure during synthesis or encoding.
    """
    import logging

    # bs/thr arrive as strings. Reject malformed values as a client error,
    # rather than letting int() fall into the catch-all below and become a 500.
    try:
        batch_size = int(params.bs)
        threshold = int(params.thr)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    try:
        text = text_normalize(params.text, is_end=False)

        calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)

        spk = calc_params.get("spk", params.spk)
        # NOTE(review): these `or` chains let the request value win, and only
        # fall back to the style/spk preset when the request value is falsy
        # (e.g. seed=0, temperature=0.0, empty prompt). Because seed and
        # temperature have non-zero defaults, presets effectively never
        # override them. That contradicts the "may be overridden by style or
        # spk" field descriptions — confirm which precedence is intended.
        seed = params.seed or calc_params.get("seed", params.seed)
        temperature = params.temperature or calc_params.get(
            "temperature", params.temperature
        )
        prefix = params.prefix or calc_params.get("prefix", params.prefix)
        prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
        prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)

        sample_rate, audio_data = synthesize_audio(
            text,
            temperature=temperature,
            top_P=params.top_P,
            top_K=params.top_K,
            spk=spk,
            infer_seed=seed,
            prompt1=prompt1,
            prompt2=prompt2,
            prefix=prefix,
            batch_size=batch_size,
            spliter_threshold=threshold,
        )

        buffer = io.BytesIO()
        sf.write(buffer, audio_data, sample_rate, format="wav")
        buffer.seek(0)

        media_type = "audio/wav"
        # Compare against the requested format. The original code compared the
        # builtin `format` function, which is never equal to "mp3", so MP3
        # output was unreachable.
        if params.format.lower() == "mp3":
            # assumes wav_to_mp3 returns a readable, rewound file-like object
            # — TODO confirm against api_utils.
            buffer = api_utils.wav_to_mp3(buffer)
            media_type = "audio/mpeg"

        return StreamingResponse(buffer, media_type=media_type)

    except Exception as e:
        logging.exception("TTS synthesis failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
|
def setup(api_manager: APIManager):
    """Register the ``GET /v1/tts`` route on *api_manager*.

    The route is handled by ``synthesize_tts`` and declared with
    ``response_class=FileResponse``.
    """
    register_get = api_manager.get("/v1/tts", response_class=FileResponse)
    register_get(synthesize_tts)
|
|