# ChatTTS-Forge / modules/api/impl/openai_api.py
# File-viewer metadata (not part of the module): zhzluke96 · update · c5458aa ·
# raw · history blame · No virus · 5.21 kB
from fastapi import File, Form, HTTPException, Body, UploadFile
from fastapi.responses import StreamingResponse
import io
from numpy import clip
import soundfile as sf
from pydantic import BaseModel, Field
from fastapi.responses import FileResponse
from modules.synthesize_audio import synthesize_audio
from modules.normalization import text_normalize
from modules import generate_audio as generate
from typing import List, Literal, Optional, Union
import pyrubberband as pyrb
from modules.api import utils as api_utils
from modules.api.Api import APIManager
from modules.speaker import speaker_mgr
from modules.data import styles_mgr
import numpy as np
class AudioSpeechRequest(BaseModel):
    # JSON request body accepted by `openai_speech_api`.
    # (Kept as `#` comments rather than a docstring so the generated schema
    # description is unchanged.)

    # Text to synthesize.
    input: str
    model: str = "chattts-4w"
    voice: str = "female2"
    response_format: Literal["mp3", "wav"] = "mp3"
    speed: float = Field(
        default=1, ge=0.1, le=10, description="Speed of the audio"
    )
    seed: int = 42
    temperature: float = 0.3
    style: str = ""

    # Whether batch synthesis is enabled; a value less than or equal to 1
    # means batching is not used.
    # Enabling batch synthesis splits the text into sentences automatically.
    batch_size: int = Field(default=1, ge=1, le=20, description="Batch size")
    spliter_threshold: float = Field(
        default=100,
        ge=10,
        le=1024,
        description="Threshold for sentence spliter",
    )
async def openai_speech_api(
    request: AudioSpeechRequest = Body(
        ..., description="JSON body with model, input text, and voice"
    )
):
    """Synthesize speech for an OpenAI-style speech request.

    The input text is normalized, synthesized with the requested voice and
    style, optionally time-stretched to the requested speed, and returned as
    an encoded audio stream.

    Args:
        request: Parsed request body. ``request.model`` is accepted but not
            used by this handler.

    Returns:
        StreamingResponse: The encoded audio, labelled ``audio/mp3`` when
        ``response_format`` is ``"mp3"`` and ``audio/wav`` otherwise.

    Raises:
        HTTPException: 400 when the input text is empty, the voice is unknown
            to ``speaker_mgr``, or the style lookup fails; 500 when
            normalization, synthesis or encoding raises.
    """
    import logging

    input_text = request.input
    voice = request.voice
    style = request.style
    response_format = request.response_format
    batch_size = request.batch_size
    spliter_threshold = request.spliter_threshold
    # The model already bounds `speed`; the clamp is kept as a second guard.
    # `float()` turns numpy's scalar result back into a plain Python float.
    speed = float(clip(request.speed, 0.1, 10))

    if not input_text:
        raise HTTPException(status_code=400, detail="Input text is required.")
    if speaker_mgr.get_speaker(voice) is None:
        raise HTTPException(status_code=400, detail="Invalid voice.")
    if style:
        try:
            # Only used as an existence check: a failed lookup is expected to
            # raise. NOTE(review): confirm it never signals "missing" by
            # returning None instead.
            styles_mgr.find_item_by_name(style)
        except Exception as e:
            raise HTTPException(status_code=400, detail="Invalid style.") from e

    try:
        # Normalize the text
        text = text_normalize(input_text, is_end=True)

        # Calculate speaker and style based on input voice
        params = api_utils.calc_spk_style(spk=voice, style=style)
        spk = params.get("spk", -1)
        # Values resolved from the voice/style take precedence over the
        # request's own seed and temperature.
        # NOTE(review): because of `or`, a request seed of 0 falls back to 42
        # and a temperature of 0 falls back to 0.3 -- confirm this is intended.
        seed = params.get("seed", request.seed or 42)
        temperature = params.get("temperature", request.temperature or 0.3)
        prompt1 = params.get("prompt1", "")
        prompt2 = params.get("prompt2", "")
        prefix = params.get("prefix", "")

        # Generate audio
        sample_rate, audio_data = synthesize_audio(
            text,
            temperature=temperature,
            top_P=0.7,
            top_K=20,
            spk=spk,
            infer_seed=seed,
            batch_size=batch_size,
            spliter_threshold=spliter_threshold,
            prompt1=prompt1,
            prompt2=prompt2,
            prefix=prefix,
        )

        if speed != 1:
            audio_data = pyrb.time_stretch(audio_data, sample_rate, speed)

        # Convert audio data to wav format
        buffer = io.BytesIO()
        sf.write(buffer, audio_data, sample_rate, format="wav")
        buffer.seek(0)

        if response_format == "mp3":
            # Convert wav to mp3
            buffer = api_utils.wav_to_mp3(buffer)
            # Kept as "audio/mp3" for backward compatibility with existing
            # clients (the registered type is "audio/mpeg").
            media_type = "audio/mp3"
        else:
            # Previously WAV payloads were also labelled as MP3.
            media_type = "audio/wav"

        return StreamingResponse(buffer, media_type=media_type)
    except HTTPException:
        # Never downgrade a deliberate HTTP error into a generic 500.
        raise
    except Exception as e:
        logging.exception(e)
        raise HTTPException(status_code=500, detail=str(e)) from e
class TranscribeSegment(BaseModel):
    # One segment entry of `TranscriptionsVerboseResponse.segments`.
    # NOTE(review): the field names look like the per-segment entries of
    # OpenAI's verbose transcription output -- confirm against that API.
    # (Kept as `#` comments rather than a docstring so the generated schema
    # description is unchanged.)

    id: int
    seek: float
    start: float
    end: float
    text: str
    tokens: List[int]
    temperature: float
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float
class TranscriptionsVerboseResponse(BaseModel):
    # Response model declared for the "/v1/audio/transcriptions" route.
    # (Kept as `#` comments rather than a docstring so the generated schema
    # description is unchanged.)

    task: str
    language: str
    duration: float
    text: str
    # Per-segment details; see `TranscribeSegment` for the entry shape.
    segments: List[TranscribeSegment]
def setup(app: APIManager):
    """Register the OpenAI-style audio routes on ``app``.

    Two POST routes are added:

    * ``/v1/audio/speech`` -- handled by ``openai_speech_api``.
    * ``/v1/audio/transcriptions`` -- a placeholder that is not implemented
      yet and answers with HTTP 501.

    Args:
        app: The API manager whose ``post`` decorator is used to register
            the routes.
    """
    # NOTE(review): the handler returns a StreamingResponse while the route is
    # declared with FileResponse; left unchanged to keep the published schema
    # stable -- confirm whether this is intentional.
    app.post(
        "/v1/audio/speech",
        response_class=FileResponse,
        description="""
openai api document:
[https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)
以下属性为本系统自定义属性,不在openai文档中:
- batch_size: 是否开启batch合成,小于等于1表示不使用batch (不推荐)
- spliter_threshold: 开启batch合成时,句子分割的阈值
- style: 风格
> model 可填任意值
""",
    )(openai_speech_api)

    @app.post(
        "/v1/audio/transcriptions",
        response_model=TranscriptionsVerboseResponse,
        description="Transcribes audio into the input language.",
    )
    async def transcribe(
        file: UploadFile = File(...),
        model: str = Form(...),
        language: Optional[str] = Form(None),
        prompt: Optional[str] = Form(None),
        response_format: str = Form("json"),
        temperature: float = Form(0),
        timestamp_granularities: List[str] = Form(["segment"]),
    ):
        """Placeholder for audio transcription; always answers HTTP 501."""
        # TODO: Implement transcribe
        # The previous stub returned a generic "success" payload, which cannot
        # provide the fields required by the declared response model
        # (task, language, duration, text, segments). Report the missing
        # feature explicitly instead of a misleading success or a
        # response-validation failure.
        raise HTTPException(status_code=501, detail="not implemented yet")