zhzluke96
committed on
Commit
·
d5d0921
1
Parent(s):
2be0618
update
Browse files- data/speakers/Bob_ft10.pt +3 -0
- modules/ChatTTS/ChatTTS/core.py +1 -1
- modules/SynthesizeSegments.py +40 -7
- modules/api/app_config.py +2 -2
- modules/api/impl/google_api.py +66 -107
- modules/api/impl/handler/AudioHandler.py +37 -0
- modules/api/impl/handler/SSMLHandler.py +94 -0
- modules/api/impl/handler/TTSHandler.py +97 -0
- modules/api/impl/model/audio_model.py +14 -0
- modules/api/impl/model/chattts_model.py +19 -0
- modules/api/impl/model/enhancer_model.py +11 -0
- modules/api/impl/openai_api.py +57 -56
- modules/api/impl/refiner_api.py +1 -0
- modules/api/impl/ssml_api.py +30 -25
- modules/api/impl/tts_api.py +58 -31
- modules/api/impl/xtts_v2_api.py +52 -6
- modules/api/utils.py +2 -11
- modules/devices/devices.py +7 -1
- modules/finetune/train_speaker.py +18 -11
- modules/prompts/news_oral_prompt.txt +14 -0
- modules/prompts/podcast_prompt.txt +1 -0
- modules/ssml_parser/SSMLParser.py +1 -4
- modules/webui/speaker/speaker_editor.py +1 -1
- modules/webui/speaker/speaker_merger.py +2 -6
data/speakers/Bob_ft10.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:91015b82a99c40034048090228b6d647ab99fd7b86e8babd6a7c3a9236e8d800
|
3 |
+
size 4508
|
modules/ChatTTS/ChatTTS/core.py
CHANGED
@@ -17,7 +17,7 @@ from .infer.api import refine_text, infer_code
|
|
17 |
|
18 |
from huggingface_hub import snapshot_download
|
19 |
|
20 |
-
logging.basicConfig(level=logging.
|
21 |
|
22 |
|
23 |
class Chat:
|
|
|
17 |
|
18 |
from huggingface_hub import snapshot_download
|
19 |
|
20 |
+
logging.basicConfig(level=logging.INFO)
|
21 |
|
22 |
|
23 |
class Chat:
|
modules/SynthesizeSegments.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
import copy
|
|
|
2 |
from box import Box
|
3 |
from pydub import AudioSegment
|
4 |
from typing import List, Union
|
@@ -160,7 +161,21 @@ class SynthesizeSegments:
|
|
160 |
for i in range(0, len(bucket), self.batch_size):
|
161 |
batch = bucket[i : i + self.batch_size]
|
162 |
param_arr = [self.segment_to_generate_params(segment) for segment in batch]
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
params = param_arr[0]
|
166 |
audio_datas = generate_audio.generate_audio_batch(
|
@@ -182,6 +197,7 @@ class SynthesizeSegments:
|
|
182 |
|
183 |
audio_segment = audio_data_to_segment(audio_data, sr)
|
184 |
audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
|
|
|
185 |
original_index = src_segments.index(segment)
|
186 |
audio_segments[original_index] = audio_segment
|
187 |
|
@@ -226,13 +242,30 @@ class SynthesizeSegments:
|
|
226 |
|
227 |
sentences = spliter.parse(text)
|
228 |
for sentence in sentences:
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
params=copy.copy(segment.params),
|
234 |
-
)
|
235 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
236 |
|
237 |
return ret_segments
|
238 |
|
|
|
1 |
import copy
|
2 |
+
import re
|
3 |
from box import Box
|
4 |
from pydub import AudioSegment
|
5 |
from typing import List, Union
|
|
|
161 |
for i in range(0, len(bucket), self.batch_size):
|
162 |
batch = bucket[i : i + self.batch_size]
|
163 |
param_arr = [self.segment_to_generate_params(segment) for segment in batch]
|
164 |
+
|
165 |
+
def append_eos(text: str):
    """Strip ``text`` and make sure it carries a break token.

    If the stripped text already contains any of the known break tokens
    (anywhere in the string, not only at the end) it is returned as is;
    otherwise ``self.eos`` of the enclosing object is appended.
    """
    stripped = text.strip()
    break_tokens = ("[uv_break]", "[v_break]", "[lbreak]", "[llbreak]")
    if any(token in stripped for token in break_tokens):
        return stripped
    # No break token found: terminate the text with the configured eos.
    return stripped + self.eos
|
176 |
+
|
177 |
+
# 这里会添加 end_of_text 到 text 之后
|
178 |
+
texts = [append_eos(params.text) for params in param_arr]
|
179 |
|
180 |
params = param_arr[0]
|
181 |
audio_datas = generate_audio.generate_audio_batch(
|
|
|
197 |
|
198 |
audio_segment = audio_data_to_segment(audio_data, sr)
|
199 |
audio_segment = apply_prosody(audio_segment, rate, volume, pitch)
|
200 |
+
# compare by Box object
|
201 |
original_index = src_segments.index(segment)
|
202 |
audio_segments[original_index] = audio_segment
|
203 |
|
|
|
242 |
|
243 |
sentences = spliter.parse(text)
|
244 |
for sentence in sentences:
|
245 |
+
seg = SSMLSegment(
|
246 |
+
text=sentence,
|
247 |
+
attrs=segment.attrs.copy(),
|
248 |
+
params=copy.copy(segment.params),
|
|
|
|
|
249 |
)
|
250 |
+
ret_segments.append(seg)
|
251 |
+
setattr(seg, "_idx", len(ret_segments) - 1)
|
252 |
+
|
253 |
+
def is_none_speak_segment(segment: SSMLSegment):
    """Tell whether a segment holds no speakable text.

    Every non-empty ``[...]`` tag is removed from the segment text; the
    segment counts as "none speak" when only whitespace (or nothing)
    remains afterwards.
    """
    without_tags = re.sub(r"\[[^\]]+?\]", "", segment.text.strip())
    return not without_tags.strip()
|
261 |
+
|
262 |
+
# 将 none_speak 合并到前一个 speak segment
|
263 |
+
for i in range(1, len(ret_segments)):
|
264 |
+
if is_none_speak_segment(ret_segments[i]):
|
265 |
+
ret_segments[i - 1].text += ret_segments[i].text
|
266 |
+
ret_segments[i].text = ""
|
267 |
+
# 移除空的 segment
|
268 |
+
ret_segments = [seg for seg in ret_segments if seg.text.strip()]
|
269 |
|
270 |
return ret_segments
|
271 |
|
modules/api/app_config.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
app_description = """
|
2 |
-
ChatTTS-Forge
|
3 |
-
ChatTTS-Forge is a
|
4 |
|
5 |
项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
|
6 |
|
|
|
1 |
app_description = """
|
2 |
+
🍦 ChatTTS-Forge 是一个围绕 TTS 生成模型 ChatTTS 开发的项目,实现了 API Server 和 基于 Gradio 的 WebUI。<br/>
|
3 |
+
🍦 ChatTTS-Forge is a project developed around the TTS generation model ChatTTS, implementing an API Server and a Gradio-based WebUI.
|
4 |
|
5 |
项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
|
6 |
|
modules/api/impl/google_api.py
CHANGED
@@ -1,38 +1,25 @@
|
|
1 |
-
import
|
2 |
-
from typing import Literal
|
3 |
from fastapi import HTTPException
|
4 |
|
5 |
-
import io
|
6 |
-
import soundfile as sf
|
7 |
from pydantic import BaseModel
|
8 |
|
9 |
|
10 |
-
from modules.Enhancer.ResembleEnhance import (
|
11 |
-
apply_audio_enhance,
|
12 |
-
apply_audio_enhance_full,
|
13 |
-
)
|
14 |
from modules.api.Api import APIManager
|
15 |
-
from modules.
|
16 |
-
from modules.
|
17 |
-
from modules.
|
18 |
-
from modules.
|
|
|
19 |
|
20 |
-
from modules import
|
21 |
-
from modules.speaker import speaker_mgr
|
22 |
|
23 |
|
24 |
-
from modules.ssml_parser.SSMLParser import create_ssml_parser
|
25 |
-
from modules.SynthesizeSegments import (
|
26 |
-
SynthesizeSegments,
|
27 |
-
combine_audio_segments,
|
28 |
-
)
|
29 |
-
|
30 |
from modules.api import utils as api_utils
|
31 |
|
32 |
|
33 |
class SynthesisInput(BaseModel):
|
34 |
-
text: str =
|
35 |
-
ssml: str =
|
36 |
|
37 |
|
38 |
class VoiceSelectionParams(BaseModel):
|
@@ -50,24 +37,15 @@ class VoiceSelectionParams(BaseModel):
|
|
50 |
|
51 |
|
52 |
class AudioConfig(BaseModel):
|
53 |
-
audioEncoding:
|
54 |
speakingRate: float = 1
|
55 |
pitch: float = 0
|
56 |
volumeGainDb: float = 0
|
57 |
sampleRateHertz: int = 24000
|
58 |
-
batchSize: int =
|
59 |
spliterThreshold: int = 100
|
60 |
|
61 |
|
62 |
-
class EnhancerConfig(BaseModel):
|
63 |
-
enabled: bool = False
|
64 |
-
model: str = "resemble-enhance"
|
65 |
-
nfe: int = 32
|
66 |
-
solver: Literal["midpoint", "rk4", "euler"] = "midpoint"
|
67 |
-
lambd: float = 0.5
|
68 |
-
tau: float = 0.5
|
69 |
-
|
70 |
-
|
71 |
class GoogleTextSynthesizeRequest(BaseModel):
|
72 |
input: SynthesisInput
|
73 |
voice: VoiceSelectionParams
|
@@ -92,7 +70,11 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
92 |
voice_name = voice.name
|
93 |
infer_seed = voice.seed or 42
|
94 |
eos = voice.eos or "[uv_break]"
|
95 |
-
audio_format = audioConfig.audioEncoding
|
|
|
|
|
|
|
|
|
96 |
speaking_rate = audioConfig.speakingRate or 1
|
97 |
pitch = audioConfig.pitch or 0
|
98 |
volume_gain_db = audioConfig.volumeGainDb or 0
|
@@ -101,6 +83,7 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
101 |
|
102 |
spliter_threshold = audioConfig.spliterThreshold or 100
|
103 |
|
|
|
104 |
sample_rate = audioConfig.sampleRateHertz or 24000
|
105 |
|
106 |
params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
|
@@ -111,92 +94,68 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
|
|
111 |
status_code=422, detail="The specified voice name is not supported."
|
112 |
)
|
113 |
|
114 |
-
if
|
115 |
raise HTTPException(
|
116 |
-
status_code=422, detail="
|
117 |
)
|
118 |
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
123 |
try:
|
124 |
if input.text:
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
top_P=voice.topP if voice.topP else params.get("top_p", 0.7),
|
135 |
-
top_K=voice.topK if voice.topK else params.get("top_k", 20),
|
136 |
-
spk=params.get("spk", -1),
|
137 |
-
infer_seed=infer_seed,
|
138 |
-
prompt1=params.get("prompt1", ""),
|
139 |
-
prompt2=params.get("prompt2", ""),
|
140 |
-
prefix=params.get("prefix", ""),
|
141 |
-
batch_size=batch_size,
|
142 |
-
spliter_threshold=spliter_threshold,
|
143 |
-
end_of_sentence=eos,
|
144 |
)
|
145 |
|
146 |
-
|
147 |
-
|
148 |
-
segments = parser.parse(input.ssml)
|
149 |
-
for seg in segments:
|
150 |
-
seg["text"] = text_normalize(seg["text"], is_end=True)
|
151 |
-
|
152 |
-
if len(segments) == 0:
|
153 |
-
raise HTTPException(
|
154 |
-
status_code=422, detail="The SSML text is empty or parsing failed."
|
155 |
-
)
|
156 |
-
|
157 |
-
synthesize = SynthesizeSegments(
|
158 |
-
batch_size=batch_size, eos=eos, spliter_thr=spliter_threshold
|
159 |
-
)
|
160 |
-
audio_segments = synthesize.synthesize_segments(segments)
|
161 |
-
combined_audio = combine_audio_segments(audio_segments)
|
162 |
|
163 |
-
|
164 |
-
|
165 |
-
raise HTTPException(
|
166 |
-
status_code=422, detail="Either text or SSML input must be provided."
|
167 |
-
)
|
168 |
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
solver=enhancerConfig.solver,
|
175 |
-
lambd=enhancerConfig.lambd,
|
176 |
-
tau=enhancerConfig.tau,
|
177 |
)
|
178 |
|
179 |
-
|
180 |
-
audio_data,
|
181 |
-
rate=speaking_rate,
|
182 |
-
pitch=pitch,
|
183 |
-
volume=volume_gain_db,
|
184 |
-
sr=sample_rate,
|
185 |
-
)
|
186 |
-
|
187 |
-
buffer = io.BytesIO()
|
188 |
-
sf.write(buffer, audio_data, sample_rate, format="wav")
|
189 |
-
buffer.seek(0)
|
190 |
|
191 |
-
|
192 |
-
buffer = api_utils.wav_to_mp3(buffer)
|
193 |
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
"audioContent": f"data:audio/{audio_format.lower()};base64,{base64_string}"
|
199 |
-
}
|
200 |
|
201 |
except Exception as e:
|
202 |
import logging
|
|
|
1 |
+
from typing import Union
|
|
|
2 |
from fastapi import HTTPException
|
3 |
|
|
|
|
|
4 |
from pydantic import BaseModel
|
5 |
|
6 |
|
|
|
|
|
|
|
|
|
7 |
from modules.api.Api import APIManager
|
8 |
+
from modules.api.impl.handler.SSMLHandler import SSMLHandler
|
9 |
+
from modules.api.impl.handler.TTSHandler import TTSHandler
|
10 |
+
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
11 |
+
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
12 |
+
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
13 |
|
14 |
+
from modules.speaker import Speaker, speaker_mgr
|
|
|
15 |
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
from modules.api import utils as api_utils
|
18 |
|
19 |
|
20 |
class SynthesisInput(BaseModel):
    """Input payload of the synthesize request: plain text or SSML.

    Both fields are optional. The endpoint in this module checks ``text``
    first, then ``ssml``, and raises a 422 ``HTTPException`` when neither
    is set.
    """

    # Plain text to synthesize; takes precedence over ``ssml`` when set.
    text: Union[str, None] = None
    # SSML document to synthesize; used only when ``text`` is empty/None.
    ssml: Union[str, None] = None
|
23 |
|
24 |
|
25 |
class VoiceSelectionParams(BaseModel):
|
|
|
37 |
|
38 |
|
39 |
class AudioConfig(BaseModel):
    """Audio output options of the synthesize request."""

    # Requested output encoding; mp3 unless the request says otherwise.
    audioEncoding: AudioFormat = AudioFormat.mp3
    # Prosody adjustments read by the endpoint (falsy values fall back to
    # 1 / 0 / 0 there).
    speakingRate: float = 1
    pitch: float = 0
    volumeGainDb: float = 0
    # NOTE(review): the endpoint reads this value next to a "# TODO" marker;
    # confirm whether it actually affects the produced audio.
    sampleRateHertz: int = 24000
    batchSize: int = 4
    # Read by the endpoint as the sentence splitter threshold.
    spliterThreshold: int = 100
|
47 |
|
48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
class GoogleTextSynthesizeRequest(BaseModel):
|
50 |
input: SynthesisInput
|
51 |
voice: VoiceSelectionParams
|
|
|
70 |
voice_name = voice.name
|
71 |
infer_seed = voice.seed or 42
|
72 |
eos = voice.eos or "[uv_break]"
|
73 |
+
audio_format = audioConfig.audioEncoding
|
74 |
+
|
75 |
+
if not isinstance(audio_format, AudioFormat) and isinstance(audio_format, str):
|
76 |
+
audio_format = AudioFormat(audio_format)
|
77 |
+
|
78 |
speaking_rate = audioConfig.speakingRate or 1
|
79 |
pitch = audioConfig.pitch or 0
|
80 |
volume_gain_db = audioConfig.volumeGainDb or 0
|
|
|
83 |
|
84 |
spliter_threshold = audioConfig.spliterThreshold or 100
|
85 |
|
86 |
+
# TODO
|
87 |
sample_rate = audioConfig.sampleRateHertz or 24000
|
88 |
|
89 |
params = api_utils.calc_spk_style(spk=voice.name, style=voice.style)
|
|
|
94 |
status_code=422, detail="The specified voice name is not supported."
|
95 |
)
|
96 |
|
97 |
+
if not isinstance(params.get("spk"), Speaker):
|
98 |
raise HTTPException(
|
99 |
+
status_code=422, detail="The specified voice name is not supported."
|
100 |
)
|
101 |
|
102 |
+
speaker = params.get("spk")
|
103 |
+
tts_config = ChatTTSConfig(
|
104 |
+
style=params.get("style", ""),
|
105 |
+
temperature=voice.temperature,
|
106 |
+
top_k=voice.topK,
|
107 |
+
top_p=voice.topP,
|
108 |
+
)
|
109 |
+
infer_config = InferConfig(
|
110 |
+
batch_size=batch_size,
|
111 |
+
spliter_threshold=spliter_threshold,
|
112 |
+
eos=eos,
|
113 |
+
seed=infer_seed,
|
114 |
+
)
|
115 |
+
adjust_config = AdjustConfig(
|
116 |
+
speaking_rate=speaking_rate,
|
117 |
+
pitch=pitch,
|
118 |
+
volume_gain_db=volume_gain_db,
|
119 |
+
)
|
120 |
+
enhancer_config = enhancerConfig
|
121 |
+
|
122 |
+
mime_type = f"audio/{audio_format.value}"
|
123 |
+
if audio_format == AudioFormat.mp3:
|
124 |
+
mime_type = "audio/mpeg"
|
125 |
try:
|
126 |
if input.text:
|
127 |
+
text_content = input.text
|
128 |
+
|
129 |
+
handler = TTSHandler(
|
130 |
+
text_content=text_content,
|
131 |
+
spk=speaker,
|
132 |
+
tts_config=tts_config,
|
133 |
+
infer_config=infer_config,
|
134 |
+
adjust_config=adjust_config,
|
135 |
+
enhancer_config=enhancer_config,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
)
|
137 |
|
138 |
+
base64_string = handler.enqueue_to_base64(format=audio_format)
|
139 |
+
return {"audioContent": f"data:{mime_type};base64,{base64_string}"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
|
141 |
+
elif input.ssml:
|
142 |
+
ssml_content = input.ssml
|
|
|
|
|
|
|
143 |
|
144 |
+
handler = SSMLHandler(
|
145 |
+
ssml_content=ssml_content,
|
146 |
+
infer_config=infer_config,
|
147 |
+
adjust_config=adjust_config,
|
148 |
+
enhancer_config=enhancer_config,
|
|
|
|
|
|
|
149 |
)
|
150 |
|
151 |
+
base64_string = handler.enqueue_to_base64(format=audio_format)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
+
return {"audioContent": f"data:{mime_type};base64,{base64_string}"}
|
|
|
154 |
|
155 |
+
else:
|
156 |
+
raise HTTPException(
|
157 |
+
status_code=422, detail="Invalid input text or ssml specified."
|
158 |
+
)
|
|
|
|
|
159 |
|
160 |
except Exception as e:
|
161 |
import logging
|
modules/api/impl/handler/AudioHandler.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import base64
|
2 |
+
import io
|
3 |
+
import numpy as np
|
4 |
+
import soundfile as sf
|
5 |
+
|
6 |
+
from modules.api.impl.model.audio_model import AudioFormat
|
7 |
+
from modules.api import utils as api_utils
|
8 |
+
|
9 |
+
|
10 |
+
class AudioHandler:
    """Base class that turns synthesized audio into buffers, bytes or base64.

    Subclasses implement :meth:`enqueue`; the remaining methods only
    re-package its result.
    """

    def enqueue(self) -> tuple[np.ndarray, int]:
        """Produce ``(audio_data, sample_rate)``; must be overridden."""
        raise NotImplementedError

    def enqueue_to_buffer(self, format: AudioFormat) -> io.BytesIO:
        """Run :meth:`enqueue` and return the audio as an in-memory file.

        The audio is always written as WAV first; when ``format`` is
        ``AudioFormat.mp3`` the WAV buffer is passed through
        ``api_utils.wav_to_mp3`` and that result is returned instead.
        """
        samples, rate = self.enqueue()

        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, samples, rate, format="wav")
        # Rewind so the consumer (or the mp3 converter) reads from the start.
        wav_buffer.seek(0)

        if format != AudioFormat.mp3:
            return wav_buffer
        return api_utils.wav_to_mp3(wav_buffer)

    def enqueue_to_bytes(self, format: AudioFormat) -> bytes:
        """Return the encoded audio as raw bytes (read from the buffer)."""
        return self.enqueue_to_buffer(format=format).read()

    def enqueue_to_base64(self, format: AudioFormat) -> str:
        """Return the encoded audio as a base64 string (UTF-8 decoded)."""
        payload = self.enqueue_to_bytes(format=format)
        return base64.b64encode(payload).decode("utf-8")
|
modules/api/impl/handler/SSMLHandler.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from fastapi import HTTPException
|
2 |
+
import numpy as np
|
3 |
+
|
4 |
+
from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
|
5 |
+
from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
|
6 |
+
from modules.api.impl.handler.AudioHandler import AudioHandler
|
7 |
+
from modules.api.impl.model.audio_model import AdjustConfig
|
8 |
+
from modules.api.impl.model.chattts_model import InferConfig
|
9 |
+
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
10 |
+
from modules.normalization import text_normalize
|
11 |
+
from modules.ssml_parser.SSMLParser import create_ssml_parser
|
12 |
+
from modules.utils import audio
|
13 |
+
|
14 |
+
|
15 |
+
class SSMLHandler(AudioHandler):
    """Audio handler that synthesizes an SSML document.

    ``enqueue`` parses the SSML into segments, normalizes each segment's
    text, synthesizes and concatenates the segments, optionally runs the
    enhancer, and finally applies the prosody adjustments.
    """

    def __init__(
        self,
        ssml_content: str,
        infer_config: InferConfig,
        adjust_config: AdjustConfig,
        enhancer_config: EnhancerConfig,
    ) -> None:
        """Type-check and store the request configuration.

        Raises:
            AssertionError: if an argument is not of the expected type.
                These are plain ``assert`` statements, so they are skipped
                when Python runs with ``-O``.
        """
        assert isinstance(ssml_content, str), "ssml_content must be a string."
        assert isinstance(
            infer_config, InferConfig
        ), "infer_config must be an InferConfig object."
        assert isinstance(
            adjust_config, AdjustConfig
        ), "adjust_config should be AdjustConfig"
        assert isinstance(
            enhancer_config, EnhancerConfig
        ), "enhancer_config must be an EnhancerConfig object."

        self.ssml_content = ssml_content
        self.infer_config = infer_config
        self.adjust_config = adjust_config
        # The attribute was originally misspelled; the old name is kept so
        # that any existing reader of ``adjest_config`` keeps working.
        self.adjest_config = adjust_config
        self.enhancer_config = enhancer_config

        self.validate()

    def validate(self):
        """Hook for parameter validation; currently a no-op."""
        # TODO params checker
        pass

    def enqueue(self) -> tuple[np.ndarray, int]:
        """Synthesize the SSML content and return ``(audio_data, sample_rate)``.

        Raises:
            HTTPException: status 422 when parsing yields no segments.
        """
        infer_config = self.infer_config
        adjust_config = self.adjust_config
        enhancer_config = self.enhancer_config

        parser = create_ssml_parser()
        segments = parser.parse(self.ssml_content)
        for seg in segments:
            seg["text"] = text_normalize(seg["text"], is_end=True)

        if len(segments) == 0:
            raise HTTPException(
                status_code=422, detail="The SSML text is empty or parsing failed."
            )

        synthesize = SynthesizeSegments(
            batch_size=infer_config.batch_size,
            eos=infer_config.eos,
            spliter_thr=infer_config.spliter_threshold,
        )
        audio_segments = synthesize.synthesize_segments(segments)
        combined_audio = combine_audio_segments(audio_segments)

        # pydub_to_np yields (sample_rate, audio_data); the enhancer below
        # yields (audio_data, sample_rate) -- mind the different order.
        sample_rate, audio_data = audio.pydub_to_np(combined_audio)

        if enhancer_config.enabled:
            audio_data, sample_rate = apply_audio_enhance_full(
                audio_data=audio_data,
                sr=sample_rate,
                nfe=enhancer_config.nfe,
                solver=enhancer_config.solver,
                lambd=enhancer_config.lambd,
                tau=enhancer_config.tau,
            )

        # Prosody is applied last, using the (possibly enhancer-updated)
        # sample rate.
        audio_data = audio.apply_prosody_to_audio_data(
            audio_data=audio_data,
            rate=adjust_config.speed_rate,
            pitch=adjust_config.pitch,
            volume=adjust_config.volume_gain_db,
            sr=sample_rate,
        )

        return audio_data, sample_rate
|
modules/api/impl/handler/TTSHandler.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from modules.Enhancer.ResembleEnhance import apply_audio_enhance_full
|
3 |
+
from modules.api.impl.handler.AudioHandler import AudioHandler
|
4 |
+
from modules.api.impl.model.audio_model import AdjustConfig
|
5 |
+
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
6 |
+
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
7 |
+
from modules.normalization import text_normalize
|
8 |
+
from modules.speaker import Speaker
|
9 |
+
from modules.synthesize_audio import synthesize_audio
|
10 |
+
|
11 |
+
from modules.utils.audio import apply_prosody_to_audio_data
|
12 |
+
|
13 |
+
|
14 |
+
class TTSHandler(AudioHandler):
    """Audio handler that synthesizes plain text with one speaker.

    ``enqueue`` normalizes the text, calls ``synthesize_audio``, optionally
    runs the enhancer, and finally applies the prosody adjustments.
    """

    def __init__(
        self,
        text_content: str,
        spk: Speaker,
        tts_config: ChatTTSConfig,
        infer_config: InferConfig,
        adjust_config: AdjustConfig,
        enhancer_config: EnhancerConfig,
    ):
        """Type-check and store the request configuration.

        Raises:
            AssertionError: if an argument is not of the expected type.
                These are plain ``assert`` statements, so they are skipped
                when Python runs with ``-O``.
        """
        assert isinstance(text_content, str), "text_content should be str"
        assert isinstance(spk, Speaker), "spk should be Speaker"
        assert isinstance(
            tts_config, ChatTTSConfig
        ), "tts_config should be ChatTTSConfig"
        assert isinstance(
            infer_config, InferConfig
        ), "infer_config should be InferConfig"
        assert isinstance(
            adjust_config, AdjustConfig
        ), "adjust_config should be AdjustConfig"
        assert isinstance(
            enhancer_config, EnhancerConfig
        ), "enhancer_config should be EnhancerConfig"

        self.text_content = text_content
        self.spk = spk
        self.tts_config = tts_config
        self.infer_config = infer_config
        self.adjust_config = adjust_config
        # The attribute was originally misspelled; the old name is kept so
        # that any existing reader of ``adjest_config`` keeps working.
        self.adjest_config = adjust_config
        self.enhancer_config = enhancer_config

        self.validate()

    def validate(self):
        """Hook for parameter validation; currently a no-op."""
        # TODO params checker
        pass

    def enqueue(self) -> tuple[np.ndarray, int]:
        """Synthesize the text and return ``(audio_data, sample_rate)``."""
        text = text_normalize(self.text_content)
        tts_config = self.tts_config
        infer_config = self.infer_config
        adjust_config = self.adjust_config
        enhancer_config = self.enhancer_config

        # synthesize_audio yields (sample_rate, audio_data); the enhancer
        # below yields (audio_data, sample_rate) -- mind the different order.
        sample_rate, audio_data = synthesize_audio(
            text,
            spk=self.spk,
            temperature=tts_config.temperature,
            top_P=tts_config.top_p,
            top_K=tts_config.top_k,
            prompt1=tts_config.prompt1,
            prompt2=tts_config.prompt2,
            prefix=tts_config.prefix,
            infer_seed=infer_config.seed,
            batch_size=infer_config.batch_size,
            spliter_threshold=infer_config.spliter_threshold,
            end_of_sentence=infer_config.eos,
        )

        if enhancer_config.enabled:
            audio_data, sample_rate = apply_audio_enhance_full(
                audio_data=audio_data,
                sr=sample_rate,
                nfe=enhancer_config.nfe,
                solver=enhancer_config.solver,
                lambd=enhancer_config.lambd,
                tau=enhancer_config.tau,
            )

        # Prosody is applied last, using the (possibly enhancer-updated)
        # sample rate.
        audio_data = apply_prosody_to_audio_data(
            audio_data=audio_data,
            rate=adjust_config.speed_rate,
            pitch=adjust_config.pitch,
            volume=adjust_config.volume_gain_db,
            sr=sample_rate,
        )

        return audio_data, sample_rate
|
modules/api/impl/model/audio_model.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import Enum
|
2 |
+
|
3 |
+
from pydantic import BaseModel
|
4 |
+
|
5 |
+
|
6 |
+
class AudioFormat(str, Enum):
    """Output audio encodings accepted by the API.

    Members are also ``str`` instances, so a member compares equal to its
    plain string value (e.g. ``AudioFormat.mp3 == "mp3"``).
    """

    mp3 = "mp3"
    wav = "wav"
|
9 |
+
|
10 |
+
|
11 |
+
class AdjustConfig(BaseModel):
    """Post-synthesis prosody adjustments (pitch, speed, volume gain).

    The handlers read ``pitch``, ``speed_rate`` and ``volume_gain_db`` and
    pass them to the prosody step as ``pitch`` / ``rate`` / ``volume``.
    """

    pitch: float = 0
    speed_rate: float = 1
    volume_gain_db: float = 0

    def __init__(self, **data):
        """Build the config, accepting ``speaking_rate`` as an alias.

        API endpoints in this code base construct the model with
        ``speaking_rate=...`` while the field is named ``speed_rate``.
        Under pydantic's default handling of unknown keyword arguments the
        value would be dropped silently and the speed would always stay at
        its default, so the alias is mapped onto the real field here. An
        explicit ``speed_rate`` always wins over the alias.
        """
        if "speaking_rate" in data and "speed_rate" not in data:
            data["speed_rate"] = data.pop("speaking_rate")
        super().__init__(**data)
|
modules/api/impl/model/chattts_model.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pydantic import BaseModel
|
2 |
+
|
3 |
+
|
4 |
+
class ChatTTSConfig(BaseModel):
    """Sampling and prompt options for one ChatTTS synthesis call."""

    # NOTE(review): not read by the TTSHandler code in this change --
    # confirm where the style name is consumed.
    style: str = ""
    # Sampling parameters; TTSHandler passes them to synthesize_audio as
    # temperature / top_P / top_K.
    temperature: float = 0.3
    top_p: float = 0.7
    top_k: int = 20
    # Prompt strings; TTSHandler passes them to synthesize_audio as
    # prompt1 / prompt2 / prefix.
    prompt1: str = ""
    prompt2: str = ""
    prefix: str = ""
|
12 |
+
|
13 |
+
|
14 |
+
class InferConfig(BaseModel):
    """Inference options shared by the TTS and SSML handlers."""

    # Batch size handed to the synthesizer.
    batch_size: int = 4
    # Threshold handed to the sentence splitter.
    spliter_threshold: int = 100
    # End-of-sentence token (handlers pass it on as ``end_of_sentence`` /
    # ``eos``).
    eos: str = "[uv_break]"
    # Seed; TTSHandler passes it to synthesize_audio as ``infer_seed``.
    seed: int = 42
|
modules/api/impl/model/enhancer_model.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
from pydantic import BaseModel
|
3 |
+
|
4 |
+
|
5 |
+
class EnhancerConfig(BaseModel):
    """Options for the optional audio enhancement step."""

    # When False the handlers skip the enhancer entirely.
    enabled: bool = False
    # NOTE(review): not read by the handlers in this change -- confirm
    # whether the model name is consumed elsewhere.
    model: str = "resemble-enhance"
    # The four values below are forwarded unchanged by the handlers to
    # apply_audio_enhance_full (as nfe / solver / lambd / tau).
    nfe: int = 32
    solver: Literal["midpoint", "rk4", "euler"] = "midpoint"
    lambd: float = 0.5
    tau: float = 0.5
|
modules/api/impl/openai_api.py
CHANGED
@@ -1,42 +1,38 @@
|
|
1 |
from fastapi import File, Form, HTTPException, Body, UploadFile
|
2 |
-
from fastapi.responses import StreamingResponse
|
3 |
|
4 |
-
import io
|
5 |
from numpy import clip
|
6 |
-
import soundfile as sf
|
7 |
from pydantic import BaseModel, Field
|
8 |
-
from fastapi.responses import
|
9 |
-
|
10 |
|
11 |
-
from modules.synthesize_audio import synthesize_audio
|
12 |
-
from modules.normalization import text_normalize
|
13 |
|
14 |
-
from modules import
|
|
|
|
|
|
|
15 |
|
16 |
|
17 |
-
from typing import List,
|
18 |
-
import pyrubberband as pyrb
|
19 |
|
20 |
from modules.api import utils as api_utils
|
21 |
from modules.api.Api import APIManager
|
22 |
|
23 |
-
from modules.speaker import speaker_mgr
|
24 |
from modules.data import styles_mgr
|
25 |
|
26 |
-
import numpy as np
|
27 |
-
|
28 |
|
29 |
class AudioSpeechRequest(BaseModel):
|
30 |
input: str # 需要合成的文本
|
31 |
model: str = "chattts-4w"
|
32 |
voice: str = "female2"
|
33 |
-
response_format:
|
34 |
speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
|
35 |
seed: int = 42
|
|
|
36 |
temperature: float = 0.3
|
|
|
|
|
|
|
37 |
style: str = ""
|
38 |
-
# 是否开启batch合成,小于等于1表示不适用batch
|
39 |
-
# 开启batch合成会自动分割句子
|
40 |
batch_size: int = Field(1, ge=1, le=20, description="Batch size")
|
41 |
spliter_threshold: float = Field(
|
42 |
100, ge=10, le=1024, description="Threshold for sentence spliter"
|
@@ -44,6 +40,9 @@ class AudioSpeechRequest(BaseModel):
|
|
44 |
# end of sentence
|
45 |
eos: str = "[uv_break]"
|
46 |
|
|
|
|
|
|
|
47 |
|
48 |
async def openai_speech_api(
|
49 |
request: AudioSpeechRequest = Body(
|
@@ -55,7 +54,14 @@ async def openai_speech_api(
|
|
55 |
voice = request.voice
|
56 |
style = request.style
|
57 |
eos = request.eos
|
|
|
|
|
58 |
response_format = request.response_format
|
|
|
|
|
|
|
|
|
|
|
59 |
batch_size = request.batch_size
|
60 |
spliter_threshold = request.spliter_threshold
|
61 |
speed = request.speed
|
@@ -71,49 +77,45 @@ async def openai_speech_api(
|
|
71 |
except:
|
72 |
raise HTTPException(status_code=400, detail="Invalid style.")
|
73 |
|
74 |
-
|
75 |
-
# Normalize the text
|
76 |
-
text = text_normalize(input_text, is_end=True)
|
77 |
-
|
78 |
-
# Calculate speaker and style based on input voice
|
79 |
-
params = api_utils.calc_spk_style(spk=voice, style=style)
|
80 |
-
|
81 |
-
spk = params.get("spk", -1)
|
82 |
-
seed = params.get("seed", request.seed or 42)
|
83 |
-
temperature = params.get("temperature", request.temperature or 0.3)
|
84 |
-
prompt1 = params.get("prompt1", "")
|
85 |
-
prompt2 = params.get("prompt2", "")
|
86 |
-
prefix = params.get("prefix", "")
|
87 |
-
|
88 |
-
# Generate audio
|
89 |
-
sample_rate, audio_data = synthesize_audio(
|
90 |
-
text,
|
91 |
-
temperature=temperature,
|
92 |
-
top_P=0.7,
|
93 |
-
top_K=20,
|
94 |
-
spk=spk,
|
95 |
-
infer_seed=seed,
|
96 |
-
batch_size=batch_size,
|
97 |
-
spliter_threshold=spliter_threshold,
|
98 |
-
prompt1=prompt1,
|
99 |
-
prompt2=prompt2,
|
100 |
-
prefix=prefix,
|
101 |
-
end_of_sentence=eos,
|
102 |
-
)
|
103 |
|
104 |
-
|
105 |
-
|
|
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
|
112 |
-
|
113 |
-
# Convert wav to mp3
|
114 |
-
buffer = api_utils.wav_to_mp3(buffer)
|
115 |
|
116 |
-
|
|
|
|
|
|
|
117 |
|
118 |
except Exception as e:
|
119 |
import logging
|
@@ -150,7 +152,6 @@ class TranscriptionsVerboseResponse(BaseModel):
|
|
150 |
def setup(app: APIManager):
|
151 |
app.post(
|
152 |
"/v1/audio/speech",
|
153 |
-
response_class=FileResponse,
|
154 |
description="""
|
155 |
openai api document:
|
156 |
[https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)
|
|
|
1 |
from fastapi import File, Form, HTTPException, Body, UploadFile
|
|
|
2 |
|
|
|
3 |
from numpy import clip
|
|
|
4 |
from pydantic import BaseModel, Field
|
5 |
+
from fastapi.responses import StreamingResponse
|
|
|
6 |
|
|
|
|
|
7 |
|
8 |
+
from modules.api.impl.handler.TTSHandler import TTSHandler
|
9 |
+
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
10 |
+
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
11 |
+
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
12 |
|
13 |
|
14 |
+
from typing import List, Optional
|
|
|
15 |
|
16 |
from modules.api import utils as api_utils
|
17 |
from modules.api.Api import APIManager
|
18 |
|
19 |
+
from modules.speaker import Speaker, speaker_mgr
|
20 |
from modules.data import styles_mgr
|
21 |
|
|
|
|
|
22 |
|
23 |
class AudioSpeechRequest(BaseModel):
|
24 |
input: str # 需要合成的文本
|
25 |
model: str = "chattts-4w"
|
26 |
voice: str = "female2"
|
27 |
+
response_format: AudioFormat = "mp3"
|
28 |
speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
|
29 |
seed: int = 42
|
30 |
+
|
31 |
temperature: float = 0.3
|
32 |
+
top_k: int = 20
|
33 |
+
top_p: float = 0.7
|
34 |
+
|
35 |
style: str = ""
|
|
|
|
|
36 |
batch_size: int = Field(1, ge=1, le=20, description="Batch size")
|
37 |
spliter_threshold: float = Field(
|
38 |
100, ge=10, le=1024, description="Threshold for sentence spliter"
|
|
|
40 |
# end of sentence
|
41 |
eos: str = "[uv_break]"
|
42 |
|
43 |
+
enhance: bool = False
|
44 |
+
denoise: bool = False
|
45 |
+
|
46 |
|
47 |
async def openai_speech_api(
|
48 |
request: AudioSpeechRequest = Body(
|
|
|
54 |
voice = request.voice
|
55 |
style = request.style
|
56 |
eos = request.eos
|
57 |
+
seed = request.seed
|
58 |
+
|
59 |
response_format = request.response_format
|
60 |
+
if not isinstance(response_format, AudioFormat) and isinstance(
|
61 |
+
response_format, str
|
62 |
+
):
|
63 |
+
response_format = AudioFormat(response_format)
|
64 |
+
|
65 |
batch_size = request.batch_size
|
66 |
spliter_threshold = request.spliter_threshold
|
67 |
speed = request.speed
|
|
|
77 |
except:
|
78 |
raise HTTPException(status_code=400, detail="Invalid style.")
|
79 |
|
80 |
+
ctx_params = api_utils.calc_spk_style(spk=voice, style=style)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
+
speaker = ctx_params.get("spk")
|
83 |
+
if not isinstance(speaker, Speaker):
|
84 |
+
raise HTTPException(status_code=400, detail="Invalid voice.")
|
85 |
|
86 |
+
tts_config = ChatTTSConfig(
|
87 |
+
style=style,
|
88 |
+
temperature=request.temperature,
|
89 |
+
top_k=request.top_k,
|
90 |
+
top_p=request.top_p,
|
91 |
+
)
|
92 |
+
infer_config = InferConfig(
|
93 |
+
batch_size=batch_size,
|
94 |
+
spliter_threshold=spliter_threshold,
|
95 |
+
eos=eos,
|
96 |
+
seed=seed,
|
97 |
+
)
|
98 |
+
adjust_config = AdjustConfig(speaking_rate=speed)
|
99 |
+
enhancer_config = EnhancerConfig(
|
100 |
+
enabled=request.enhance or request.denoise or False,
|
101 |
+
lambd=0.9 if request.denoise else 0.1,
|
102 |
+
)
|
103 |
+
try:
|
104 |
+
handler = TTSHandler(
|
105 |
+
text_content=input_text,
|
106 |
+
spk=speaker,
|
107 |
+
tts_config=tts_config,
|
108 |
+
infer_config=infer_config,
|
109 |
+
adjust_config=adjust_config,
|
110 |
+
enhancer_config=enhancer_config,
|
111 |
+
)
|
112 |
|
113 |
+
buffer = handler.enqueue_to_buffer(response_format)
|
|
|
|
|
114 |
|
115 |
+
mime_type = f"audio/{response_format.value}"
|
116 |
+
if response_format == AudioFormat.mp3:
|
117 |
+
mime_type = "audio/mpeg"
|
118 |
+
return StreamingResponse(buffer, media_type=mime_type)
|
119 |
|
120 |
except Exception as e:
|
121 |
import logging
|
|
|
152 |
def setup(app: APIManager):
|
153 |
app.post(
|
154 |
"/v1/audio/speech",
|
|
|
155 |
description="""
|
156 |
openai api document:
|
157 |
[https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)
|
modules/api/impl/refiner_api.py
CHANGED
@@ -31,6 +31,7 @@ async def refiner_prompt_post(request: RefineTextRequest):
|
|
31 |
text = request.text
|
32 |
if request.normalize:
|
33 |
text = text_normalize(request.text)
|
|
|
34 |
refined_text = refiner.refine_text(
|
35 |
text=text,
|
36 |
prompt=request.prompt,
|
|
|
31 |
text = request.text
|
32 |
if request.normalize:
|
33 |
text = text_normalize(request.text)
|
34 |
+
# TODO 其实这里可以做 spliter 和 batch 处理
|
35 |
refined_text = refiner.refine_text(
|
36 |
text=text,
|
37 |
prompt=request.prompt,
|
modules/api/impl/ssml_api.py
CHANGED
@@ -1,27 +1,22 @@
|
|
1 |
from fastapi import HTTPException, Body
|
2 |
from fastapi.responses import StreamingResponse
|
3 |
|
4 |
-
import io
|
5 |
from pydantic import BaseModel
|
6 |
from fastapi.responses import FileResponse
|
7 |
|
8 |
|
9 |
-
from modules.
|
10 |
-
from modules.
|
11 |
-
from modules.
|
12 |
-
|
13 |
-
combine_audio_segments,
|
14 |
-
)
|
15 |
|
16 |
|
17 |
-
from modules.api import utils as api_utils
|
18 |
-
|
19 |
from modules.api.Api import APIManager
|
20 |
|
21 |
|
22 |
class SSMLRequest(BaseModel):
|
23 |
ssml: str
|
24 |
-
format:
|
25 |
|
26 |
# NOTE: 🤔 也许这个值应该配置成系统变量? 传进来有点奇怪
|
27 |
batch_size: int = 4
|
@@ -31,6 +26,9 @@ class SSMLRequest(BaseModel):
|
|
31 |
|
32 |
spliter_thr: int = 100
|
33 |
|
|
|
|
|
|
|
34 |
|
35 |
async def synthesize_ssml_api(
|
36 |
request: SSMLRequest = Body(
|
@@ -43,6 +41,8 @@ async def synthesize_ssml_api(
|
|
43 |
batch_size = request.batch_size
|
44 |
eos = request.eos
|
45 |
spliter_thr = request.spliter_thr
|
|
|
|
|
46 |
|
47 |
if batch_size < 1:
|
48 |
raise HTTPException(
|
@@ -62,22 +62,27 @@ async def synthesize_ssml_api(
|
|
62 |
status_code=400, detail="Format must be 'mp3' or 'wav'."
|
63 |
)
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
synthesize = SynthesizeSegments(
|
71 |
-
batch_size=batch_size, eos=eos, spliter_thr=spliter_thr
|
72 |
)
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
except Exception as e:
|
83 |
import logging
|
|
|
1 |
from fastapi import HTTPException, Body
|
2 |
from fastapi.responses import StreamingResponse
|
3 |
|
|
|
4 |
from pydantic import BaseModel
|
5 |
from fastapi.responses import FileResponse
|
6 |
|
7 |
|
8 |
+
from modules.api.impl.handler.SSMLHandler import SSMLHandler
|
9 |
+
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
10 |
+
from modules.api.impl.model.chattts_model import InferConfig
|
11 |
+
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
|
|
|
|
12 |
|
13 |
|
|
|
|
|
14 |
from modules.api.Api import APIManager
|
15 |
|
16 |
|
17 |
class SSMLRequest(BaseModel):
|
18 |
ssml: str
|
19 |
+
format: AudioFormat = "mp3"
|
20 |
|
21 |
# NOTE: 🤔 也许这个值应该配置成系统变量? 传进来有点奇怪
|
22 |
batch_size: int = 4
|
|
|
26 |
|
27 |
spliter_thr: int = 100
|
28 |
|
29 |
+
enhancer: EnhancerConfig = EnhancerConfig()
|
30 |
+
adjuster: AdjustConfig = AdjustConfig()
|
31 |
+
|
32 |
|
33 |
async def synthesize_ssml_api(
|
34 |
request: SSMLRequest = Body(
|
|
|
41 |
batch_size = request.batch_size
|
42 |
eos = request.eos
|
43 |
spliter_thr = request.spliter_thr
|
44 |
+
enhancer = request.enhancer
|
45 |
+
adjuster = request.adjuster
|
46 |
|
47 |
if batch_size < 1:
|
48 |
raise HTTPException(
|
|
|
62 |
status_code=400, detail="Format must be 'mp3' or 'wav'."
|
63 |
)
|
64 |
|
65 |
+
infer_config = InferConfig(
|
66 |
+
batch_size=batch_size,
|
67 |
+
spliter_threshold=spliter_thr,
|
68 |
+
eos=eos,
|
|
|
|
|
|
|
69 |
)
|
70 |
+
adjust_config = adjuster
|
71 |
+
enhancer_config = enhancer
|
72 |
+
|
73 |
+
handler = SSMLHandler(
|
74 |
+
ssml_content=ssml,
|
75 |
+
infer_config=infer_config,
|
76 |
+
adjust_config=adjust_config,
|
77 |
+
enhancer_config=enhancer_config,
|
78 |
+
)
|
79 |
+
|
80 |
+
buffer = handler.enqueue_to_buffer(format=request.format)
|
81 |
+
|
82 |
+
mime_type = f"audio/{format}"
|
83 |
+
if format == AudioFormat.mp3:
|
84 |
+
mime_type = "audio/mpeg"
|
85 |
+
return StreamingResponse(buffer, media_type=mime_type)
|
86 |
|
87 |
except Exception as e:
|
88 |
import logging
|
modules/api/impl/tts_api.py
CHANGED
@@ -1,17 +1,18 @@
|
|
1 |
from fastapi import Depends, HTTPException, Query
|
2 |
from fastapi.responses import StreamingResponse
|
3 |
|
4 |
-
import io
|
5 |
from pydantic import BaseModel
|
6 |
-
import soundfile as sf
|
7 |
from fastapi.responses import FileResponse
|
8 |
|
9 |
|
10 |
-
from modules.
|
|
|
|
|
|
|
11 |
|
12 |
from modules.api import utils as api_utils
|
13 |
from modules.api.Api import APIManager
|
14 |
-
from modules.
|
15 |
|
16 |
|
17 |
class TTSParams(BaseModel):
|
@@ -23,10 +24,10 @@ class TTSParams(BaseModel):
|
|
23 |
temperature: float = Query(
|
24 |
0.3, description="Temperature for sampling (may be overridden by style or spk)"
|
25 |
)
|
26 |
-
|
27 |
0.5, description="Top P for sampling (may be overridden by style or spk)"
|
28 |
)
|
29 |
-
|
30 |
20, description="Top K for sampling (may be overridden by style or spk)"
|
31 |
)
|
32 |
seed: int = Query(
|
@@ -38,7 +39,14 @@ class TTSParams(BaseModel):
|
|
38 |
prefix: str = Query("", description="Text prefix for inference")
|
39 |
bs: str = Query("8", description="Batch size for inference")
|
40 |
thr: str = Query("100", description="Threshold for sentence spliter")
|
41 |
-
eos: str = Query("", description="End of sentence str")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
|
44 |
async def synthesize_tts(params: TTSParams = Depends()):
|
@@ -55,18 +63,18 @@ async def synthesize_tts(params: TTSParams = Depends()):
|
|
55 |
status_code=422, detail="Temperature must be between 0 and 1"
|
56 |
)
|
57 |
|
58 |
-
# Validate
|
59 |
-
if not (0 <= params.
|
60 |
-
raise HTTPException(status_code=422, detail="
|
61 |
|
62 |
-
# Validate
|
63 |
-
if params.
|
64 |
raise HTTPException(
|
65 |
-
status_code=422, detail="
|
66 |
)
|
67 |
-
if params.
|
68 |
raise HTTPException(
|
69 |
-
status_code=422, detail="
|
70 |
)
|
71 |
|
72 |
# Validate format
|
@@ -76,11 +84,13 @@ async def synthesize_tts(params: TTSParams = Depends()):
|
|
76 |
detail="Invalid format. Supported formats are mp3 and wav",
|
77 |
)
|
78 |
|
79 |
-
text = text_normalize(params.text, is_end=False)
|
80 |
-
|
81 |
calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)
|
82 |
|
83 |
spk = calc_params.get("spk", params.spk)
|
|
|
|
|
|
|
|
|
84 |
seed = params.seed or calc_params.get("seed", params.seed)
|
85 |
temperature = params.temperature or calc_params.get(
|
86 |
"temperature", params.temperature
|
@@ -93,29 +103,46 @@ async def synthesize_tts(params: TTSParams = Depends()):
|
|
93 |
batch_size = int(params.bs)
|
94 |
threshold = int(params.thr)
|
95 |
|
96 |
-
|
97 |
-
|
98 |
temperature=temperature,
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
infer_seed=seed,
|
103 |
prompt1=prompt1,
|
104 |
prompt2=prompt2,
|
105 |
-
|
|
|
106 |
batch_size=batch_size,
|
107 |
spliter_threshold=threshold,
|
108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
109 |
)
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
|
|
|
|
|
|
|
|
114 |
|
115 |
-
|
116 |
-
buffer = api_utils.wav_to_mp3(buffer)
|
117 |
|
118 |
-
|
|
|
|
|
|
|
119 |
|
120 |
except Exception as e:
|
121 |
import logging
|
|
|
1 |
from fastapi import Depends, HTTPException, Query
|
2 |
from fastapi.responses import StreamingResponse
|
3 |
|
|
|
4 |
from pydantic import BaseModel
|
|
|
5 |
from fastapi.responses import FileResponse
|
6 |
|
7 |
|
8 |
+
from modules.api.impl.handler.TTSHandler import TTSHandler
|
9 |
+
from modules.api.impl.model.audio_model import AdjustConfig, AudioFormat
|
10 |
+
from modules.api.impl.model.chattts_model import ChatTTSConfig, InferConfig
|
11 |
+
from modules.api.impl.model.enhancer_model import EnhancerConfig
|
12 |
|
13 |
from modules.api import utils as api_utils
|
14 |
from modules.api.Api import APIManager
|
15 |
+
from modules.speaker import Speaker
|
16 |
|
17 |
|
18 |
class TTSParams(BaseModel):
|
|
|
24 |
temperature: float = Query(
|
25 |
0.3, description="Temperature for sampling (may be overridden by style or spk)"
|
26 |
)
|
27 |
+
top_p: float = Query(
|
28 |
0.5, description="Top P for sampling (may be overridden by style or spk)"
|
29 |
)
|
30 |
+
top_k: int = Query(
|
31 |
20, description="Top K for sampling (may be overridden by style or spk)"
|
32 |
)
|
33 |
seed: int = Query(
|
|
|
39 |
prefix: str = Query("", description="Text prefix for inference")
|
40 |
bs: str = Query("8", description="Batch size for inference")
|
41 |
thr: str = Query("100", description="Threshold for sentence spliter")
|
42 |
+
eos: str = Query("[uv_break]", description="End of sentence str")
|
43 |
+
|
44 |
+
enhance: bool = Query(False, description="Enable enhancer")
|
45 |
+
denoise: bool = Query(False, description="Enable denoiser")
|
46 |
+
|
47 |
+
speed: float = Query(1.0, description="Speed of the audio")
|
48 |
+
pitch: float = Query(0, description="Pitch of the audio")
|
49 |
+
volume_gain: float = Query(0, description="Volume gain of the audio")
|
50 |
|
51 |
|
52 |
async def synthesize_tts(params: TTSParams = Depends()):
|
|
|
63 |
status_code=422, detail="Temperature must be between 0 and 1"
|
64 |
)
|
65 |
|
66 |
+
# Validate top_p
|
67 |
+
if not (0 <= params.top_p <= 1):
|
68 |
+
raise HTTPException(status_code=422, detail="top_p must be between 0 and 1")
|
69 |
|
70 |
+
# Validate top_k
|
71 |
+
if params.top_k <= 0:
|
72 |
raise HTTPException(
|
73 |
+
status_code=422, detail="top_k must be a positive integer"
|
74 |
)
|
75 |
+
if params.top_k > 100:
|
76 |
raise HTTPException(
|
77 |
+
status_code=422, detail="top_k must be less than or equal to 100"
|
78 |
)
|
79 |
|
80 |
# Validate format
|
|
|
84 |
detail="Invalid format. Supported formats are mp3 and wav",
|
85 |
)
|
86 |
|
|
|
|
|
87 |
calc_params = api_utils.calc_spk_style(spk=params.spk, style=params.style)
|
88 |
|
89 |
spk = calc_params.get("spk", params.spk)
|
90 |
+
if not isinstance(spk, Speaker):
|
91 |
+
raise HTTPException(status_code=422, detail="Invalid speaker")
|
92 |
+
|
93 |
+
style = calc_params.get("style", params.style)
|
94 |
seed = params.seed or calc_params.get("seed", params.seed)
|
95 |
temperature = params.temperature or calc_params.get(
|
96 |
"temperature", params.temperature
|
|
|
103 |
batch_size = int(params.bs)
|
104 |
threshold = int(params.thr)
|
105 |
|
106 |
+
tts_config = ChatTTSConfig(
|
107 |
+
style=style,
|
108 |
temperature=temperature,
|
109 |
+
top_k=params.top_k,
|
110 |
+
top_p=params.top_p,
|
111 |
+
prefix=prefix,
|
|
|
112 |
prompt1=prompt1,
|
113 |
prompt2=prompt2,
|
114 |
+
)
|
115 |
+
infer_config = InferConfig(
|
116 |
batch_size=batch_size,
|
117 |
spliter_threshold=threshold,
|
118 |
+
eos=eos,
|
119 |
+
seed=seed,
|
120 |
+
)
|
121 |
+
adjust_config = AdjustConfig(
|
122 |
+
pitch=params.pitch,
|
123 |
+
speed_rate=params.speed,
|
124 |
+
volume_gain_db=params.volume_gain,
|
125 |
+
)
|
126 |
+
enhancer_config = EnhancerConfig(
|
127 |
+
enabled=params.enhance or params.denoise or False,
|
128 |
+
lambd=0.9 if params.denoise else 0.1,
|
129 |
)
|
130 |
|
131 |
+
handler = TTSHandler(
|
132 |
+
text_content=params.text,
|
133 |
+
spk=spk,
|
134 |
+
tts_config=tts_config,
|
135 |
+
infer_config=infer_config,
|
136 |
+
adjust_config=adjust_config,
|
137 |
+
enhancer_config=enhancer_config,
|
138 |
+
)
|
139 |
|
140 |
+
buffer = handler.enqueue_to_buffer(format=AudioFormat(params.format))
|
|
|
141 |
|
142 |
+
media_type = f"audio/{params.format}"
|
143 |
+
if params.format == "mp3":
|
144 |
+
media_type = "audio/mpeg"
|
145 |
+
return StreamingResponse(buffer, media_type=media_type)
|
146 |
|
147 |
except Exception as e:
|
148 |
import logging
|
modules/api/impl/xtts_v2_api.py
CHANGED
@@ -30,8 +30,19 @@ class XTTS_V2_Settings:
|
|
30 |
self.top_k = 20
|
31 |
self.enable_text_splitting = True
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
|
34 |
class TTSSettingsRequest(BaseModel):
|
|
|
35 |
stream_chunk_size: int
|
36 |
temperature: float
|
37 |
speed: float
|
@@ -41,6 +52,15 @@ class TTSSettingsRequest(BaseModel):
|
|
41 |
top_k: int
|
42 |
enable_text_splitting: bool
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
class SynthesisRequest(BaseModel):
|
46 |
text: str
|
@@ -79,17 +99,22 @@ def setup(app: APIManager):
|
|
79 |
|
80 |
text = text_normalize(text, is_end=True)
|
81 |
sample_rate, audio_data = synthesize_audio(
|
82 |
-
|
83 |
-
temperature=XTTSV2.temperature,
|
84 |
# length_penalty=XTTSV2.length_penalty,
|
85 |
# repetition_penalty=XTTSV2.repetition_penalty,
|
|
|
|
|
86 |
top_P=XTTSV2.top_p,
|
87 |
top_K=XTTSV2.top_k,
|
88 |
spk=spk,
|
89 |
-
spliter_threshold=XTTSV2.
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
93 |
)
|
94 |
|
95 |
if XTTSV2.speed:
|
@@ -145,6 +170,8 @@ def setup(app: APIManager):
|
|
145 |
)
|
146 |
|
147 |
XTTSV2.stream_chunk_size = request.stream_chunk_size
|
|
|
|
|
148 |
XTTSV2.temperature = request.temperature
|
149 |
XTTSV2.speed = request.speed
|
150 |
XTTSV2.length_penalty = request.length_penalty
|
@@ -152,6 +179,25 @@ def setup(app: APIManager):
|
|
152 |
XTTSV2.top_p = request.top_p
|
153 |
XTTSV2.top_k = request.top_k
|
154 |
XTTSV2.enable_text_splitting = request.enable_text_splitting
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
return {"message": "Settings successfully applied"}
|
156 |
except Exception as e:
|
157 |
if isinstance(e, HTTPException):
|
|
|
30 |
self.top_k = 20
|
31 |
self.enable_text_splitting = True
|
32 |
|
33 |
+
# 下面是额外配置 xtts_v2 中不包含的,但是本系统需要的
|
34 |
+
self.batch_size = 4
|
35 |
+
self.eos = "[uv_break]"
|
36 |
+
self.infer_seed = 42
|
37 |
+
self.use_decoder = True
|
38 |
+
self.prompt1 = ""
|
39 |
+
self.prompt2 = ""
|
40 |
+
self.prefix = ""
|
41 |
+
self.spliter_threshold = 100
|
42 |
+
|
43 |
|
44 |
class TTSSettingsRequest(BaseModel):
|
45 |
+
# 这个 stream_chunk 现在当作 spliter_threshold 用
|
46 |
stream_chunk_size: int
|
47 |
temperature: float
|
48 |
speed: float
|
|
|
52 |
top_k: int
|
53 |
enable_text_splitting: bool
|
54 |
|
55 |
+
batch_size: int = None
|
56 |
+
eos: str = None
|
57 |
+
infer_seed: int = None
|
58 |
+
use_decoder: bool = None
|
59 |
+
prompt1: str = None
|
60 |
+
prompt2: str = None
|
61 |
+
prefix: str = None
|
62 |
+
spliter_threshold: int = None
|
63 |
+
|
64 |
|
65 |
class SynthesisRequest(BaseModel):
|
66 |
text: str
|
|
|
99 |
|
100 |
text = text_normalize(text, is_end=True)
|
101 |
sample_rate, audio_data = synthesize_audio(
|
102 |
+
# TODO: 这两个参数现在用不着...但是其实gpt是可以用的
|
|
|
103 |
# length_penalty=XTTSV2.length_penalty,
|
104 |
# repetition_penalty=XTTSV2.repetition_penalty,
|
105 |
+
text=text,
|
106 |
+
temperature=XTTSV2.temperature,
|
107 |
top_P=XTTSV2.top_p,
|
108 |
top_K=XTTSV2.top_k,
|
109 |
spk=spk,
|
110 |
+
spliter_threshold=XTTSV2.spliter_threshold,
|
111 |
+
batch_size=XTTSV2.batch_size,
|
112 |
+
end_of_sentence=XTTSV2.eos,
|
113 |
+
infer_seed=XTTSV2.infer_seed,
|
114 |
+
use_decoder=XTTSV2.use_decoder,
|
115 |
+
prompt1=XTTSV2.prompt1,
|
116 |
+
prompt2=XTTSV2.prompt2,
|
117 |
+
prefix=XTTSV2.prefix,
|
118 |
)
|
119 |
|
120 |
if XTTSV2.speed:
|
|
|
170 |
)
|
171 |
|
172 |
XTTSV2.stream_chunk_size = request.stream_chunk_size
|
173 |
+
XTTSV2.spliter_threshold = request.stream_chunk_size
|
174 |
+
|
175 |
XTTSV2.temperature = request.temperature
|
176 |
XTTSV2.speed = request.speed
|
177 |
XTTSV2.length_penalty = request.length_penalty
|
|
|
179 |
XTTSV2.top_p = request.top_p
|
180 |
XTTSV2.top_k = request.top_k
|
181 |
XTTSV2.enable_text_splitting = request.enable_text_splitting
|
182 |
+
|
183 |
+
# TODO: checker
|
184 |
+
if request.batch_size:
|
185 |
+
XTTSV2.batch_size = request.batch_size
|
186 |
+
if request.eos:
|
187 |
+
XTTSV2.eos = request.eos
|
188 |
+
if request.infer_seed:
|
189 |
+
XTTSV2.infer_seed = request.infer_seed
|
190 |
+
if request.use_decoder:
|
191 |
+
XTTSV2.use_decoder = request.use_decoder
|
192 |
+
if request.prompt1:
|
193 |
+
XTTSV2.prompt1 = request.prompt1
|
194 |
+
if request.prompt2:
|
195 |
+
XTTSV2.prompt2 = request.prompt2
|
196 |
+
if request.prefix:
|
197 |
+
XTTSV2.prefix = request.prefix
|
198 |
+
if request.spliter_threshold:
|
199 |
+
XTTSV2.spliter_threshold = request.spliter_threshold
|
200 |
+
|
201 |
return {"message": "Settings successfully applied"}
|
202 |
except Exception as e:
|
203 |
if isinstance(e, HTTPException):
|
modules/api/utils.py
CHANGED
@@ -1,9 +1,8 @@
|
|
1 |
from pydantic import BaseModel
|
2 |
from typing import Any, Union
|
3 |
|
4 |
-
import torch
|
5 |
|
6 |
-
from modules.speaker import
|
7 |
|
8 |
|
9 |
from modules.data import styles_mgr
|
@@ -13,18 +12,10 @@ from pydub import AudioSegment
|
|
13 |
from modules.ssml import merge_prompt
|
14 |
|
15 |
|
16 |
-
from enum import Enum
|
17 |
-
|
18 |
-
|
19 |
class ParamsTypeError(Exception):
|
20 |
pass
|
21 |
|
22 |
|
23 |
-
class AudioFormat(str, Enum):
|
24 |
-
mp3 = "mp3"
|
25 |
-
wav = "wav"
|
26 |
-
|
27 |
-
|
28 |
class BaseResponse(BaseModel):
|
29 |
message: str
|
30 |
data: Any
|
@@ -35,7 +26,7 @@ def success_response(data: Any, message: str = "ok") -> BaseResponse:
|
|
35 |
|
36 |
|
37 |
def wav_to_mp3(wav_data, bitrate="48k"):
|
38 |
-
audio = AudioSegment.from_wav(
|
39 |
wav_data,
|
40 |
)
|
41 |
return audio.export(format="mp3", bitrate=bitrate)
|
|
|
1 |
from pydantic import BaseModel
|
2 |
from typing import Any, Union
|
3 |
|
|
|
4 |
|
5 |
+
from modules.speaker import speaker_mgr
|
6 |
|
7 |
|
8 |
from modules.data import styles_mgr
|
|
|
12 |
from modules.ssml import merge_prompt
|
13 |
|
14 |
|
|
|
|
|
|
|
15 |
class ParamsTypeError(Exception):
|
16 |
pass
|
17 |
|
18 |
|
|
|
|
|
|
|
|
|
|
|
19 |
class BaseResponse(BaseModel):
|
20 |
message: str
|
21 |
data: Any
|
|
|
26 |
|
27 |
|
28 |
def wav_to_mp3(wav_data, bitrate="48k"):
|
29 |
+
audio: AudioSegment = AudioSegment.from_wav(
|
30 |
wav_data,
|
31 |
)
|
32 |
return audio.export(format="mp3", bitrate=bitrate)
|
modules/devices/devices.py
CHANGED
@@ -127,6 +127,12 @@ def reset_device():
|
|
127 |
global dtype_gpt
|
128 |
global dtype_decoder
|
129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
if not config.runtime_env_vars.no_half:
|
131 |
dtype = torch.float16
|
132 |
dtype_dvae = torch.float16
|
@@ -144,7 +150,7 @@ def reset_device():
|
|
144 |
|
145 |
logger.info("Using full precision: torch.float32")
|
146 |
|
147 |
-
if config.runtime_env_vars.use_cpu
|
148 |
device = cpu
|
149 |
else:
|
150 |
device = get_optimal_device()
|
|
|
127 |
global dtype_gpt
|
128 |
global dtype_decoder
|
129 |
|
130 |
+
if "all" in config.runtime_env_vars.use_cpu and not config.runtime_env_vars.no_half:
|
131 |
+
logger.warning(
|
132 |
+
"Cannot use half precision with CPU, using full precision instead"
|
133 |
+
)
|
134 |
+
config.runtime_env_vars.no_half = True
|
135 |
+
|
136 |
if not config.runtime_env_vars.no_half:
|
137 |
dtype = torch.float16
|
138 |
dtype_dvae = torch.float16
|
|
|
150 |
|
151 |
logger.info("Using full precision: torch.float32")
|
152 |
|
153 |
+
if "all" in config.runtime_env_vars.use_cpu:
|
154 |
device = cpu
|
155 |
else:
|
156 |
device = get_optimal_device()
|
modules/finetune/train_speaker.py
CHANGED
@@ -45,9 +45,10 @@ def train_speaker_embeddings(
|
|
45 |
)
|
46 |
for speaker in dataset.speakers
|
47 |
}
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
51 |
|
52 |
SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
|
53 |
AUDIO_EOS_TOKEN_ID = 0
|
@@ -166,13 +167,13 @@ def train_speaker_embeddings(
|
|
166 |
audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
|
167 |
)
|
168 |
loss = audio_loss
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
|
177 |
gpt_gen_mel_specs = decoder_decoder(
|
178 |
audio_hidden_states[:, :-1].transpose(1, 2)
|
@@ -181,7 +182,12 @@ def train_speaker_embeddings(
|
|
181 |
loss += 0.01 * mse_loss
|
182 |
|
183 |
optimizer.zero_grad()
|
184 |
-
|
|
|
|
|
|
|
|
|
|
|
185 |
torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0)
|
186 |
optimizer.step()
|
187 |
logger.meters["loss"].update(loss.item(), n=batch_size)
|
@@ -203,6 +209,7 @@ if __name__ == "__main__":
|
|
203 |
from modules.speaker import Speaker
|
204 |
|
205 |
config.runtime_env_vars.no_half = True
|
|
|
206 |
devices.reset_device()
|
207 |
|
208 |
parser = argparse.ArgumentParser()
|
|
|
45 |
)
|
46 |
for speaker in dataset.speakers
|
47 |
}
|
48 |
+
|
49 |
+
for speaker_embed in speaker_embeds.values():
|
50 |
+
std, mean = chat.pretrain_models["spk_stat"].chunk(2)
|
51 |
+
speaker_embed.data = speaker_embed.data * std + mean
|
52 |
|
53 |
SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
|
54 |
AUDIO_EOS_TOKEN_ID = 0
|
|
|
167 |
audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
|
168 |
)
|
169 |
loss = audio_loss
|
170 |
+
|
171 |
+
text_logits = gpt.head_text(text_hidden_states)
|
172 |
+
text_loss = loss_fn(
|
173 |
+
text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
|
174 |
+
)
|
175 |
+
loss += text_loss
|
176 |
+
logger.meters["text_loss"].update(text_loss.item(), n=batch_size)
|
177 |
|
178 |
gpt_gen_mel_specs = decoder_decoder(
|
179 |
audio_hidden_states[:, :-1].transpose(1, 2)
|
|
|
182 |
loss += 0.01 * mse_loss
|
183 |
|
184 |
optimizer.zero_grad()
|
185 |
+
|
186 |
+
if train_text:
|
187 |
+
# just for test
|
188 |
+
text_loss.backward()
|
189 |
+
else:
|
190 |
+
loss.backward()
|
191 |
torch.nn.utils.clip_grad_norm_(speaker_embeds.values(), 1.0)
|
192 |
optimizer.step()
|
193 |
logger.meters["loss"].update(loss.item(), n=batch_size)
|
|
|
209 |
from modules.speaker import Speaker
|
210 |
|
211 |
config.runtime_env_vars.no_half = True
|
212 |
+
config.runtime_env_vars.use_cpu = []
|
213 |
devices.reset_device()
|
214 |
|
215 |
parser = argparse.ArgumentParser()
|
modules/prompts/news_oral_prompt.txt
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# 任务要求
|
2 |
+
任务: 新闻稿口播化
|
3 |
+
|
4 |
+
你需要将一个新闻稿改写为口语化的口播文本
|
5 |
+
同时,适当的添加一些 附语言 标签为文本增加多样性
|
6 |
+
|
7 |
+
目前可以使用的附语言标签如下:
|
8 |
+
- `[laugh]`: 表示笑声
|
9 |
+
- `[uv_break]`: 表示无声停顿
|
10 |
+
- `[v_break]`: 表示有声停顿,如“嗯”、“啊”等
|
11 |
+
- `[lbreak]`: 表示一个长停顿一般表示段落结束
|
12 |
+
|
13 |
+
# 输入
|
14 |
+
{{USER_INPUT}}
|
modules/prompts/podcast_prompt.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
TODO
|
modules/ssml_parser/SSMLParser.py
CHANGED
@@ -1,13 +1,10 @@
|
|
1 |
from lxml import etree
|
2 |
|
3 |
|
4 |
-
from typing import
|
5 |
import logging
|
6 |
|
7 |
-
from modules.data import styles_mgr
|
8 |
-
from modules.speaker import speaker_mgr
|
9 |
from box import Box
|
10 |
-
import copy
|
11 |
|
12 |
|
13 |
class SSMLContext(Box):
|
|
|
1 |
from lxml import etree
|
2 |
|
3 |
|
4 |
+
from typing import List, Union
|
5 |
import logging
|
6 |
|
|
|
|
|
7 |
from box import Box
|
|
|
8 |
|
9 |
|
10 |
class SSMLContext(Box):
|
modules/webui/speaker/speaker_editor.py
CHANGED
@@ -25,7 +25,7 @@ def speaker_editor_ui():
|
|
25 |
spk: Speaker = Speaker.from_file(spk_file)
|
26 |
spk.name = name
|
27 |
spk.gender = gender
|
28 |
-
spk.
|
29 |
|
30 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
|
31 |
torch.save(spk, tmp_file)
|
|
|
25 |
spk: Speaker = Speaker.from_file(spk_file)
|
26 |
spk.name = name
|
27 |
spk.gender = gender
|
28 |
+
spk.describe = desc
|
29 |
|
30 |
with tempfile.NamedTemporaryFile(delete=False, suffix=".pt") as tmp_file:
|
31 |
torch.save(spk, tmp_file)
|
modules/webui/speaker/speaker_merger.py
CHANGED
@@ -38,12 +38,8 @@ def merge_spk(
|
|
38 |
tensor_c = spk_to_tensor(spk_c)
|
39 |
tensor_d = spk_to_tensor(spk_d)
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
or tensor_b is not None
|
44 |
-
or tensor_c is not None
|
45 |
-
or tensor_d is not None
|
46 |
-
), "At least one speaker should be selected"
|
47 |
|
48 |
merge_tensor = torch.zeros_like(
|
49 |
tensor_a
|
|
|
38 |
tensor_c = spk_to_tensor(spk_c)
|
39 |
tensor_d = spk_to_tensor(spk_d)
|
40 |
|
41 |
+
if tensor_a is None and tensor_b is None and tensor_c is None and tensor_d is None:
|
42 |
+
raise gr.Error("At least one speaker should be selected")
|
|
|
|
|
|
|
|
|
43 |
|
44 |
merge_tensor = torch.zeros_like(
|
45 |
tensor_a
|