from typing import List, Optional

from fastapi import Body, File, Form, HTTPException, UploadFile
from fastapi.responses import StreamingResponse
from numpy import clip
from pydantic import BaseModel, Field

from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules.api.impl.handler.TTSHandler import TTSHandler
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
from modules.api.impl.model.enhancer_model import EnhancerConfig
from modules.data import styles_mgr
from modules.speaker import Speaker, speaker_mgr


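# Request body for the OpenAI-compatible /v1/audio/speech endpoint. The first
# group of fields mirrors the OpenAI API; style, batch_size, spliter_threshold,
# eos, enhance, and denoise are extensions specific to this project.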
class AudioSpeechRequest(BaseModel):
    input: str
    model: str = "chattts-4w"
    voice: str = "female2"
    response_format: AudioFormat = "mp3"
    speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
    seed: int = 42

    temperature: float = 0.3
    top_k: int = 20
    top_p: float = 0.7

    style: str = ""
    batch_size: int = Field(1, ge=1, le=20, description="Batch size")
    spliter_threshold: float = Field(
        100, ge=10, le=1024, description="Threshold for sentence spliter"
    )

    eos: str = "[uv_break]"

    enhance: bool = False
    denoise: bool = False


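# Validates the request, builds the TTS / inference / adjust / enhancer
# configs, and streams the synthesized audio back to the client.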
async def openai_speech_api(
    request: AudioSpeechRequest = Body(
        ..., description="JSON body with model, input text, and voice"
    )
):
    # model is accepted for OpenAI compatibility but is not used downstream.
    model = request.model
    input_text = request.input
    voice = request.voice
    style = request.style
    eos = request.eos
    seed = request.seed

    # Coerce a plain string such as "mp3" into the AudioFormat enum.
    response_format = request.response_format
    if not isinstance(response_format, AudioFormat) and isinstance(
        response_format, str
    ):
        response_format = AudioFormat(response_format)

    batch_size = request.batch_size
    spliter_threshold = request.spliter_threshold
    speed = request.speed
    speed = clip(speed, 0.1, 10)

    if not input_text:
        raise HTTPException(status_code=400, detail="Input text is required.")
    if speaker_mgr.get_speaker(voice) is None:
        raise HTTPException(status_code=400, detail="Invalid voice.")
    try:
        if style:
            styles_mgr.find_item_by_name(style)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid style.")

    # Resolve the speaker (and any style overrides) from the voice/style pair.
    ctx_params = api_utils.calc_spk_style(spk=voice, style=style)

    speaker = ctx_params.get("spk")
    if not isinstance(speaker, Speaker):
        raise HTTPException(status_code=400, detail="Invalid voice.")

    tts_config = ChatTTSConfig(
        style=style,
        temperature=request.temperature,
        top_k=request.top_k,
        top_p=request.top_p,
    )
    infer_config = InferConfig(
        batch_size=batch_size,
        spliter_threshold=spliter_threshold,
        eos=eos,
        seed=seed,
    )
    adjust_config = AdjustConfig(speaking_rate=speed)
    # Enhancement runs if either enhance or denoise is requested; denoise uses
    # a stronger lambd.
    enhancer_config = EnhancerConfig(
        enabled=request.enhance or request.denoise,
        lambd=0.9 if request.denoise else 0.1,
    )
    try:
        handler = TTSHandler(
            text_content=input_text,
            spk=speaker,
            tts_config=tts_config,
            infer_config=infer_config,
            adjust_config=adjust_config,
            enhancer_config=enhancer_config,
        )

        buffer = handler.enqueue_to_buffer(response_format)

        # mp3 uses the standard "audio/mpeg" MIME type; other formats map directly.
        mime_type = f"audio/{response_format.value}"
        if response_format == AudioFormat.mp3:
            mime_type = "audio/mpeg"
        return StreamingResponse(buffer, media_type=mime_type)

    except Exception as e:
        import logging

        logging.exception(e)

        if isinstance(e, HTTPException):
            raise e
        else:
            raise HTTPException(status_code=500, detail=str(e))


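# One segment of a transcription result, mirroring the segment schema of
# OpenAI's verbose_json transcription format.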
class TranscribeSegment(BaseModel):
    id: int
    seek: float
    start: float
    end: float
    text: str
    tokens: list[int]
    temperature: float
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float


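# Verbose response body for the /v1/audio/transcriptions endpoint.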
class TranscriptionsVerboseResponse(BaseModel):
    task: str
    language: str
    duration: float
    text: str
    segments: list[TranscribeSegment]


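# Registers the OpenAI-compatible speech and transcription routes on the API manager.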
def setup(app: APIManager):
    app.post(
        "/v1/audio/speech",
        description="""
openai api document:
[https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)

The following properties are custom to this system and are not part of the OpenAI API:
- batch_size: whether to enable batch synthesis; values less than or equal to 1 disable batching (not recommended)
- spliter_threshold: sentence-splitting threshold used when batch synthesis is enabled
- style: speaking style

> model may be set to any value
""",
    )(openai_speech_api)

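    # Example request to /v1/audio/speech (illustrative only; host and port
    # depend on how the server is launched):
    #   curl http://<host>:<port>/v1/audio/speech \
    #     -H "Content-Type: application/json" \
    #     -d '{"input": "Hello world", "voice": "female2", "batch_size": 2}' \
    #     -o speech.mp3
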
    @app.post(
        "/v1/audio/transcriptions",
        response_model=TranscriptionsVerboseResponse,
        description="Transcribes audio into the input language.",
    )
    async def transcribe(
        file: UploadFile = File(...),
        model: str = Form(...),
        language: Optional[str] = Form(None),
        prompt: Optional[str] = Form(None),
        response_format: str = Form("json"),
        temperature: float = Form(0),
        timestamp_granularities: List[str] = Form(["segment"]),
    ):
        # Transcription is not implemented yet; return a placeholder response.
        return api_utils.success_response("not implemented yet")