File size: 5,214 Bytes
f34bda5
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f34bda5
01e655b
 
 
 
 
c5458aa
 
 
01e655b
 
 
 
 
 
 
 
02e90e4
c5458aa
 
01e655b
 
 
02e90e4
01e655b
 
 
 
 
 
 
 
 
 
c5458aa
 
 
 
 
 
 
 
 
 
 
 
 
 
01e655b
c5458aa
 
 
 
01e655b
c5458aa
01e655b
 
 
 
 
 
 
02e90e4
 
01e655b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f34bda5
 
 
 
 
 
c5458aa
f34bda5
 
 
 
 
 
 
 
 
 
 
c5458aa
f34bda5
 
 
 
02e90e4
 
 
 
 
 
 
 
 
 
 
 
 
 
f34bda5
 
 
c5458aa
 
f34bda5
 
 
 
 
 
 
 
 
 
 
c5458aa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from fastapi import File, Form, HTTPException, Body, UploadFile
from fastapi.responses import StreamingResponse

import io
from numpy import clip
import soundfile as sf
from pydantic import BaseModel, Field
from fastapi.responses import FileResponse


from modules.synthesize_audio import synthesize_audio
from modules.normalization import text_normalize

from modules import generate_audio as generate


from typing import List, Literal, Optional, Union
import pyrubberband as pyrb

from modules.api import utils as api_utils
from modules.api.Api import APIManager

from modules.speaker import speaker_mgr
from modules.data import styles_mgr

import numpy as np


class AudioSpeechRequest(BaseModel):
    # JSON body accepted by POST /v1/audio/speech: the OpenAI speech-request
    # fields plus project-specific extras (style, batch_size, spliter_threshold).
    input: str  # text to synthesize
    model: str = "chattts-4w"  # accepted for OpenAI compatibility
    voice: str = "female2"
    response_format: Literal["mp3", "wav"] = "mp3"
    speed: float = Field(
        default=1, ge=0.1, le=10, description="Speed of the audio"
    )
    seed: int = 42
    temperature: float = 0.3
    style: str = ""
    # Batched synthesis: a value <= 1 means batching is not used.
    # Enabling batching automatically splits the input into sentences.
    batch_size: int = Field(
        default=1, ge=1, le=20, description="Batch size"
    )
    # Sentence-splitting threshold applied when batching is enabled.
    spliter_threshold: float = Field(
        default=100, ge=10, le=1024, description="Threshold for sentence spliter"
    )


async def openai_speech_api(
    request: AudioSpeechRequest = Body(
        ..., description="JSON body with model, input text, and voice"
    )
):
    """Synthesize speech for an OpenAI-compatible ``/v1/audio/speech`` request.

    ``request.model`` is accepted for API compatibility but not used.

    Args:
        request: Parsed request body (text, voice, style, format, tuning knobs).

    Returns:
        A ``StreamingResponse`` carrying the audio encoded as
        ``request.response_format`` ("mp3" or "wav"), with a matching media type.

    Raises:
        HTTPException: 400 for empty input, an unknown voice, or a style lookup
            that raises; 500 if normalization, synthesis, stretching, or
            encoding fails.
    """
    input_text = request.input
    voice = request.voice
    style = request.style
    response_format = request.response_format
    batch_size = request.batch_size
    spliter_threshold = request.spliter_threshold
    # The model's Field() already bounds speed; clip defensively anyway.
    speed = clip(request.speed, 0.1, 10)

    if not input_text:
        raise HTTPException(status_code=400, detail="Input text is required.")
    if speaker_mgr.get_speaker(voice) is None:
        raise HTTPException(status_code=400, detail="Invalid voice.")
    if style:
        try:
            styles_mgr.find_item_by_name(style)
        except Exception:
            # Narrower than a bare `except:` so KeyboardInterrupt/SystemExit
            # are not turned into a 400.
            raise HTTPException(status_code=400, detail="Invalid style.")

    try:
        # Normalize the text
        text = text_normalize(input_text, is_end=True)

        # Resolve speaker/style parameters from the requested voice and style
        params = api_utils.calc_spk_style(spk=voice, style=style)

        spk = params.get("spk", -1)
        # Use the request seed directly: `request.seed or 42` discarded an
        # explicit seed of 0, which is a valid seed.
        seed = params.get("seed", request.seed)
        # NOTE(review): a temperature of 0 still falls back to 0.3, as before —
        # presumably to avoid a zero-temperature sampler; confirm before changing.
        temperature = params.get("temperature", request.temperature or 0.3)
        prompt1 = params.get("prompt1", "")
        prompt2 = params.get("prompt2", "")
        prefix = params.get("prefix", "")

        # Generate audio
        sample_rate, audio_data = synthesize_audio(
            text,
            temperature=temperature,
            top_P=0.7,
            top_K=20,
            spk=spk,
            infer_seed=seed,
            batch_size=batch_size,
            spliter_threshold=spliter_threshold,
            prompt1=prompt1,
            prompt2=prompt2,
            prefix=prefix,
        )

        if speed != 1:
            audio_data = pyrb.time_stretch(audio_data, sample_rate, speed)

        # Encode as WAV first; MP3 is produced from the WAV buffer.
        buffer = io.BytesIO()
        sf.write(buffer, audio_data, sample_rate, format="wav")
        buffer.seek(0)

        if response_format == "mp3":
            buffer = api_utils.wav_to_mp3(buffer)
            media_type = "audio/mp3"
        else:
            # Previously WAV bodies were mislabelled as audio/mp3.
            media_type = "audio/wav"

        return StreamingResponse(buffer, media_type=media_type)

    except Exception as e:
        import logging

        logging.exception(e)
        raise HTTPException(status_code=500, detail=str(e)) from e


class TranscribeSegment(BaseModel):
    # One segment of a verbose transcription result; the field set is
    # presumably modeled on OpenAI's verbose_json segment — confirm if extending.
    id: int
    seek: float
    start: float
    end: float
    text: str
    tokens: List[int]
    temperature: float
    avg_logprob: float
    compression_ratio: float
    no_speech_prob: float


class TranscriptionsVerboseResponse(BaseModel):
    # Response schema declared for POST /v1/audio/transcriptions: whole-text
    # result plus metadata and per-segment detail.
    task: str
    language: str
    duration: float
    text: str
    segments: List[TranscribeSegment]


def setup(app: APIManager):
    """Register the OpenAI-compatible audio routes on ``app``.

    Adds ``POST /v1/audio/speech`` (handled by ``openai_speech_api``) and a
    placeholder ``POST /v1/audio/transcriptions`` that is not implemented yet.
    """
    # Route description shown in the generated API docs (kept verbatim). It
    # links the OpenAI TTS guide, lists the project-specific fields
    # (batch_size, spliter_threshold, style) and says `model` accepts any value.
    speech_description = """
openai api document: 
[https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)

以下属性为本系统自定义属性,不在openai文档中:
- batch_size: 是否开启batch合成,小于等于1表示不使用batch (不推荐)
- spliter_threshold: 开启batch合成时,句子分割的阈值
- style: 风格

> model 可填任意值
        """
    register_speech = app.post(
        "/v1/audio/speech",
        response_class=FileResponse,
        description=speech_description,
    )
    # NOTE(review): the handler returns a StreamingResponse, so response_class
    # only affects the generated docs here.
    register_speech(openai_speech_api)

    @app.post(
        "/v1/audio/transcriptions",
        response_model=TranscriptionsVerboseResponse,
        description="Transcribes audio into the input language.",
    )
    async def transcribe(
        file: UploadFile = File(...),
        model: str = Form(...),
        language: Optional[str] = Form(None),
        prompt: Optional[str] = Form(None),
        response_format: str = Form("json"),
        temperature: float = Form(0),
        timestamp_granularities: List[str] = Form(["segment"]),
    ):
        # TODO: Implement transcribe — every form field is currently ignored.
        # NOTE(review): the placeholder payload presumably does not match
        # TranscriptionsVerboseResponse; confirm how response validation treats it.
        return api_utils.success_response("not implemented yet")