chattts / modules /api /impl /ssml_api.py
zhzluke96
update
d5b3cd8
raw
history blame
2.48 kB
from fastapi import HTTPException, Body
from fastapi.responses import StreamingResponse
import io
from pydantic import BaseModel
from fastapi.responses import FileResponse
from modules.normalization import text_normalize
from modules.ssml import parse_ssml
from modules.SynthesizeSegments import (
SynthesizeSegments,
synthesize_segment,
combine_audio_segments,
)
from modules.api import utils as api_utils
from modules.api.Api import APIManager
class SSMLRequest(BaseModel):
    """JSON request body accepted by the SSML synthesis endpoint."""

    # SSML document to synthesize; the handler rejects an empty string.
    ssml: str
    # Requested output format; only "mp3" triggers a WAV -> MP3 conversion.
    format: str = "mp3"
    # Batch size for synthesis; a value of 1 selects per-segment streaming.
    batch_size: int = 4
async def synthesize_ssml(
    request: SSMLRequest = Body(
        ..., description="JSON body with SSML string and format"
    )
):
    """Synthesize speech from an SSML document and stream the audio back.

    The SSML is parsed into segments and each segment's ``"text"`` entry is
    normalized. Two synthesis modes exist:

    * ``batch_size != 1``: all segments are synthesized in batches, combined
      into one audio clip, and returned as a single buffer.
    * ``batch_size == 1``: segments are synthesized one at a time and each
      encoded segment is yielded to the client as it becomes available.

    Audio is always exported as WAV first and converted only when the
    requested format is ``"mp3"``. For any other format value the payload
    stays WAV while the response media type is still ``audio/<format>``.

    Args:
        request: Request body carrying ``ssml``, ``format`` and
            ``batch_size``.

    Returns:
        A ``StreamingResponse`` with media type ``audio/<format>``.

    Raises:
        HTTPException: 400 when ``batch_size`` is below 1 or ``ssml`` is
            empty; 500 for any other failure while preparing the response.
    """
    try:
        ssml = request.ssml
        # Use a distinct local name so the ``format`` builtin is not shadowed.
        audio_format = request.format
        batch_size = request.batch_size

        if batch_size < 1:
            raise HTTPException(
                status_code=400, detail="Batch size must be greater than 0."
            )
        if not ssml:
            raise HTTPException(status_code=400, detail="SSML content is required.")

        segments = parse_ssml(ssml)
        for seg in segments:
            seg["text"] = text_normalize(seg["text"], is_end=True)

        media_type = f"audio/{audio_format}"

        if batch_size != 1:
            # Batched path: synthesize everything, then return one buffer.
            synthesize = SynthesizeSegments(batch_size)
            audio_segments = synthesize.synthesize_segments(segments)
            combined_audio = combine_audio_segments(audio_segments)

            buffer = io.BytesIO()
            combined_audio.export(buffer, format="wav")
            buffer.seek(0)
            if audio_format == "mp3":
                buffer = api_utils.wav_to_mp3(buffer)
            return StreamingResponse(buffer, media_type=media_type)

        def audio_streamer():
            # Runs lazily while the response is being sent, i.e. after this
            # handler has returned, so errors raised here are NOT covered by
            # the surrounding try/except.
            for segment in segments:
                audio_segment = synthesize_segment(segment=segment)
                segment_buffer = io.BytesIO()
                audio_segment.export(segment_buffer, format="wav")
                segment_buffer.seek(0)
                if audio_format == "mp3":
                    segment_buffer = api_utils.wav_to_mp3(segment_buffer)
                yield segment_buffer.read()

        return StreamingResponse(audio_streamer(), media_type=media_type)
    except HTTPException:
        # Deliberate HTTP errors (the 400 validations above) must reach the
        # client unchanged instead of being rewrapped as a 500 below.
        raise
    except Exception as e:
        import logging

        logging.exception(e)
        raise HTTPException(status_code=500, detail=str(e)) from e
def setup(api_manager: APIManager):
    """Register the SSML synthesis handler as ``POST /v1/ssml``.

    Args:
        api_manager: API manager whose ``post`` method yields the route
            registrar that is applied to ``synthesize_ssml``.
    """
    register_route = api_manager.post("/v1/ssml", response_class=FileResponse)
    register_route(synthesize_ssml)