Spaces:
Sleeping
Sleeping
File size: 6,335 Bytes
bed01bd 01e655b d2b7e94 01e655b d2b7e94 d5d0921 01e655b bed01bd 01e655b d5d0921 01e655b d5d0921 01e655b d5d0921 01e655b bed01bd 01e655b ebc4336 d5d0921 ebc4336 d5d0921 ebc4336 d5d0921 ebc4336 d5d0921 ebc4336 d5d0921 ebc4336 01e655b d5d0921 01e655b 1df74c6 01e655b d5d0921 01e655b d5d0921 01e655b d5d0921 01e655b d5d0921 01e655b d5d0921 01e655b bed01bd 01e655b ebc4336 01e655b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 |
import logging
from fastapi import Depends, HTTPException, Query
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules.api.impl.handler.TTSHandler import TTSHandler
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.speaker import Speaker
logger = logging.getLogger(__name__)
class TTSParams(BaseModel):
    """Query-string parameters accepted by the ``/v1/tts`` endpoint.

    Every field's default is a FastAPI ``Query`` object carrying both the
    default value and the human-readable description. ``synthesize_tts``
    receives an instance of this model through ``Depends()``.
    """

    # --- text and voice selection ---
    text: str = Query(..., description="Text to synthesize")
    spk: str = Query(
        "female2",
        description="Specific speaker by speaker name or speaker seed",
    )
    style: str = Query("chat", description="Specific style by style name")

    # --- sampling controls ---
    temperature: float = Query(
        0.3,
        description="Temperature for sampling (may be overridden by style or spk)",
    )
    top_p: float = Query(
        0.5,
        description="Top P for sampling (may be overridden by style or spk)",
    )
    top_k: int = Query(
        20,
        description="Top K for sampling (may be overridden by style or spk)",
    )
    seed: int = Query(
        42,
        description="Seed for generate (may be overridden by style or spk)",
    )

    # --- output container ---
    format: str = Query("mp3", description="Response audio format: [mp3,wav]")

    # --- prompting ---
    prompt1: str = Query("", description="Text prompt for inference")
    prompt2: str = Query("", description="Text prompt for inference")
    prefix: str = Query("", description="Text prefix for inference")

    # --- inference / splitting ---
    # bs and thr are declared as strings here; synthesize_tts converts them
    # with int() before use.
    bs: str = Query("8", description="Batch size for inference")
    thr: str = Query("100", description="Threshold for sentence spliter")
    eos: str = Query("[uv_break]", description="End of sentence str")

    # --- post-processing ---
    enhance: bool = Query(False, description="Enable enhancer")
    denoise: bool = Query(False, description="Enable denoiser")
    speed: float = Query(1.0, description="Speed of the audio")
    pitch: float = Query(0, description="Pitch of the audio")
    volume_gain: float = Query(0, description="Volume gain of the audio")

    # --- delivery ---
    stream: bool = Query(False, description="Stream the audio")
async def synthesize_tts(params: TTSParams = Depends()):
    """Synthesize speech for the request described by ``params``.

    Validates the query parameters, resolves speaker and style through
    ``api_utils.calc_spk_style``, builds the TTS / inference / adjustment /
    enhancer configs, and hands them to a ``TTSHandler`` whose output is
    returned as a ``StreamingResponse``.

    Args:
        params: Query parameters collected by FastAPI via ``Depends()``.

    Returns:
        A ``StreamingResponse`` whose media type is ``audio/mpeg`` for mp3
        and ``audio/<format>`` otherwise.

    Raises:
        HTTPException: 422 when a parameter fails validation (empty text,
            out-of-range sampling values, unsupported format, a speaker that
            does not resolve to a ``Speaker``, or non-integer ``bs``/``thr``);
            500 for any other error, with the original message as detail.
    """
    try:
        # ---- request validation -------------------------------------
        if not params.text.strip():
            raise HTTPException(
                status_code=422, detail="Text parameter cannot be empty"
            )
        if not (0 <= params.temperature <= 1):
            raise HTTPException(
                status_code=422, detail="Temperature must be between 0 and 1"
            )
        if not (0 <= params.top_p <= 1):
            raise HTTPException(
                status_code=422, detail="top_p must be between 0 and 1"
            )
        if params.top_k <= 0:
            raise HTTPException(
                status_code=422, detail="top_k must be a positive integer"
            )
        if params.top_k > 100:
            raise HTTPException(
                status_code=422,
                detail="top_k must be less than or equal to 100",
            )
        if params.format not in ("mp3", "wav"):
            raise HTTPException(
                status_code=422,
                detail="Invalid format. Supported formats are mp3 and wav",
            )

        # ---- speaker / style resolution -----------------------------
        calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)

        spk = calc_params.get("spk", params.spk)
        if not isinstance(spk, Speaker):
            raise HTTPException(status_code=422, detail="Invalid speaker")

        style = calc_params.get("style", params.style)

        # The request value wins whenever it is truthy; only a falsy request
        # value (0, "") falls back to what calc_spk_style produced.
        seed = params.seed or calc_params.get("seed", params.seed)
        temperature = params.temperature or calc_params.get(
            "temperature", params.temperature
        )
        prefix = params.prefix or calc_params.get("prefix", params.prefix)
        prompt1 = params.prompt1 or calc_params.get("prompt1", params.prompt1)
        prompt2 = params.prompt2 or calc_params.get("prompt2", params.prompt2)
        eos = params.eos or ""

        # bs / thr arrive as strings. A non-numeric value is a client
        # mistake, so answer 422 instead of letting the ValueError reach the
        # generic handler below and surface as a 500.
        try:
            batch_size = int(params.bs)
            threshold = int(params.thr)
        except ValueError:
            raise HTTPException(
                status_code=422, detail="bs and thr must be integers"
            ) from None

        # ---- configuration objects ----------------------------------
        tts_config = ChatTTSConfig(
            style=style,
            temperature=temperature,
            top_k=params.top_k,
            top_p=params.top_p,
            prefix=prefix,
            prompt1=prompt1,
            prompt2=prompt2,
        )
        infer_config = InferConfig(
            batch_size=batch_size,
            spliter_threshold=threshold,
            eos=eos,
            seed=seed,
        )
        adjust_config = AdjustConfig(
            pitch=params.pitch,
            speed_rate=params.speed,
            volume_gain_db=params.volume_gain,
        )
        enhancer_config = EnhancerConfig(
            # Either flag switches the enhancer on; denoise additionally
            # selects the larger lambd value.
            enabled=params.enhance or params.denoise or False,
            lambd=0.9 if params.denoise else 0.1,
        )

        handler = TTSHandler(
            text_content=params.text,
            spk=spk,
            tts_config=tts_config,
            infer_config=infer_config,
            adjust_config=adjust_config,
            enhancer_config=enhancer_config,
        )

        # mp3 has a dedicated MIME type; everything else uses audio/<format>.
        media_type = (
            "audio/mpeg" if params.format == "mp3" else f"audio/{params.format}"
        )
        audio_format = AudioFormat(params.format)

        if params.stream:
            if infer_config.batch_size != 1:
                # Streaming generation only supports a batch size of 1; the
                # batch size from the request will be ignored.
                # NOTE(review): this function only warns and does not change
                # batch_size itself - presumably the handler enforces it;
                # confirm.
                logger.warning(
                    "Batch size %s is not supported in streaming mode, will set to 1",
                    infer_config.batch_size,
                )
            buffer_gen = handler.enqueue_to_stream(format=audio_format)
            return StreamingResponse(buffer_gen, media_type=media_type)

        buffer = handler.enqueue_to_buffer(format=audio_format)
        return StreamingResponse(buffer, media_type=media_type)
    except HTTPException:
        # Deliberate HTTP errors (validation failures) pass through untouched
        # and are not logged as crashes.
        raise
    except Exception as e:
        # Unexpected failure: record the traceback on the module logger and
        # report it to the client as a 500 with the original message.
        logger.exception("Unexpected error while handling /v1/tts request")
        raise HTTPException(status_code=500, detail=str(e)) from e
def setup(api_manager: APIManager):
    """Attach ``synthesize_tts`` to GET ``/v1/tts`` on ``api_manager``.

    Args:
        api_manager: The API manager whose ``get`` method is called with the
            route path and ``FileResponse`` as the response class; the
            callable it returns is then applied to ``synthesize_tts``.
    """
    bind_route = api_manager.get("/v1/tts", response_class=FileResponse)
    bind_route(synthesize_tts)
|