chenjgtea committed
Commit: 1898711 • Parent(s): 602e4a2

Commit message: "gpu模型下代码更新,已定型" (code update for the GPU model; finalized)

Files changed:
- Chat2TTS/core.py  (+21, −21)
- web/app_gpu.py  (+33, −33)
Chat2TTS/core.py  CHANGED

@@ -181,27 +181,27 @@ class Chat:
         return wav
 
 
-    def sample_random_speaker(self) -> str:
-        return self._encode_spk_emb(self.sample_random_speaker_tensor())
-
-
-    @staticmethod
-    def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
-        with torch.no_grad():
-            arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
-            s = b14.encode_to_string(
-                lzma.compress(
-                    arr.tobytes(),
-                    format=lzma.FORMAT_RAW,
-                    filters=[
-                        {"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
-                    ],
-                ),
-            )
-            del arr
-            return s
-
-    def sample_random_speaker_tensor(self) -> torch.Tensor:
+    # def sample_random_speaker(self) -> str:
+    #     return self._encode_spk_emb(self.sample_random_speaker_tensor())
+    #
+    #
+    # @staticmethod
+    # def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
+    #     with torch.no_grad():
+    #         arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
+    #         s = b14.encode_to_string(
+    #             lzma.compress(
+    #                 arr.tobytes(),
+    #                 format=lzma.FORMAT_RAW,
+    #                 filters=[
+    #                     {"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
+    #                 ],
+    #             ),
+    #         )
+    #         del arr
+    #         return s
+
+    def sample_random_speaker_tensor(self) -> torch.Tensor:
 
         with torch.no_grad():
             dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
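This hunk drops the base16384/lzma string encoding of the speaker embedding: sample_random_speaker_tensor() now hands callers a raw torch.Tensor, which web/app_gpu.py below passes straight into the inference params as spk_emb. A minimal usage sketch (assumptions: chat is an already-initialized Chat2TTS.Chat instance, and 0.3 stands in for the temperature_slider value wired in by the UI):

    # Sketch only: `chat` is assumed to be a loaded Chat2TTS.Chat instance.
    rand_spk = chat.sample_random_speaker_tensor()  # raw torch.Tensor, no string encoding step
    params_infer_code = {
        'spk_emb': rand_spk,   # tensor passed through directly
        'temperature': 0.3,    # illustrative value; the UI passes temperature_slider
    }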
web/app_gpu.py  CHANGED

@@ -1,7 +1,6 @@
 import os, sys
 import spaces
 
-from tool import TorchSeedContext
 
 if sys.platform == "darwin":
     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -12,6 +11,7 @@ from tool.logger import get_logger
 from tool.func import *
 from tool.np import *
 from tool.gpu import select_device
+from tool.ctx import TorchSeedContext
 import Chat2TTS
 import argparse
 import torch._dynamo
@@ -116,18 +116,18 @@ def main(args):
             )
             generate_text_seed = gr.Button("随机生成文本种子", interactive=True)
 
-        with gr.Row():
-            spk_emb_text = gr.Textbox(
-                label="Speaker Embedding",
-                max_lines=3,
-                show_copy_button=True,
-                interactive=False,
-                scale=2,
-
-            )
-            reload_chat_button = gr.Button("Reload", scale=1, interactive=True)
+        # with gr.Row():
+        #     spk_emb_text = gr.Textbox(
+        #         label="Speaker Embedding",
+        #         max_lines=3,
+        #         show_copy_button=True,
+        #         interactive=False,
+        #         scale=2,
+        #
+        #     )
 
         with gr.Row():
+            reload_chat_button = gr.Button("Reload", scale=1, interactive=True)
             generate_button = gr.Button("生成音频文件", scale=1, interactive=True)
 
         with gr.Row():
@@ -152,7 +152,7 @@ def main(args):
         # 针对页面元素新增 监听事件
         voice_selection.change(fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input)
 
-        audio_seed_input.change(fn=on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text)
+        #audio_seed_input.change(fn=on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text)
 
         generate_audio_seed.click(fn=generate_seed, outputs=audio_seed_input)
 
@@ -160,19 +160,18 @@ def main(args):
 
         # reload_chat_button.click()
 
-        generate_button.click(fn=get_chat_infer_audio,
-                              inputs=[text_input,
+        generate_button.click(fn=general_chat_infer_audio,
+                              inputs=[text_input,
                                       text_seed_input,
                                       refine_text_checkBox,
                                       temperature_slider,
                                       top_p_slider,
                                       top_k_slider,
-                                      audio_seed_input,
-                                      spk_emb_text
+                                      audio_seed_input
                                       ],
-                              outputs=[text_output, audio_output])
+                              outputs=[text_output,audio_output])
         # 初始化 spk_emb_text 数值
-        spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
+        #spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
         logger.info("元素初始化完成,启动gradio服务=======")
 
         # 运行gradio服务
@@ -193,14 +192,13 @@ def main(args):
 简而言之,"spk_embedding"关注的是对话参与者的身份特征,而"temperature"是用于调整生成文本不确定性的一个超参数。
 '''
 @spaces.GPU
-def get_chat_infer_audio(text,
-                         text_seed_input,
-                         refine_text_checkBox,
-                         temperature_slider,
-                         top_p_slider,
-                         top_k_slider,
-                         audio_seed_input,
-                         spk_emb_text):
+def general_chat_infer_audio(text,
+                             text_seed_input,
+                             refine_text_checkBox,
+                             temperature_slider,
+                             top_p_slider,
+                             top_k_slider,
+                             audio_seed_input):
 
     logger.info("========开始处理TTS模型=====")
     #音频参数设置
@@ -229,7 +227,9 @@ def get_chat_infer_audio(text,
     #torch.manual_seed(audio_seed_input)
 
     with TorchSeedContext(audio_seed_input):
-        rand_spk = torch.randn(768)
+        #rand_spk = torch.randn(768)
+        rand_spk = chat.sample_random_speaker_tensor()
+        logger.info("========生成音频spk_emb参数完成=====")
     params_infer_code = {
         'spk_emb': rand_spk,
         'temperature': temperature_slider,
@@ -274,12 +274,12 @@ def get_chat_infer_audio(text,
 #
 #     return chat_text[0] if isinstance(chat_text, list) else chat_text
 
-@spaces.GPU
-def on_audio_seed_change(audio_seed_input):
-    global chat
-    torch.manual_seed(audio_seed_input)
-    rand_spk = chat.sample_random_speaker()
-    return rand_spk
+#@spaces.GPU
+# def on_audio_seed_change(audio_seed_input):
+#     global chat
+#     torch.manual_seed(audio_seed_input)
+#     rand_spk = chat.sample_random_speaker()
+#     return rand_spk
     # rand_spk = torch.randn(audio_seed_input)
     # return encode_spk_emb(rand_spk)
 
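In this commit the seed helper import moves from `from tool import TorchSeedContext` to `from tool.ctx import TorchSeedContext`, and the guarded block now calls chat.sample_random_speaker_tensor() instead of torch.randn(768), so the same audio_seed_input reproduces the same speaker embedding. The sketch below shows one plausible shape for such a seed context manager; it is an assumption for illustration, not the actual tool.ctx implementation (it saves and restores the CPU RNG state only, for brevity):

    import torch

    class TorchSeedContext:
        """Illustrative sketch (assumed behavior, not the real tool.ctx class):
        seed torch's RNG on entry and restore the previous CPU RNG state on
        exit, so sampling inside the block is deterministic per seed."""

        def __init__(self, seed: int):
            self.seed = seed
            self._saved_state = None

        def __enter__(self):
            self._saved_state = torch.random.get_rng_state()  # save current CPU RNG state
            torch.manual_seed(self.seed)                       # make sampling reproducible
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            torch.random.set_rng_state(self._saved_state)      # restore previous state

    # Usage mirroring app_gpu.py: the same audio_seed_input yields the same rand_spk.
    # with TorchSeedContext(audio_seed_input):
    #     rand_spk = chat.sample_random_speaker_tensor()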