chenjgtea commited on
Commit
1898711
1 Parent(s): 602e4a2

gpu模型下代码更新,已定型

Browse files
Files changed (2) hide show
  1. Chat2TTS/core.py +21 -21
  2. web/app_gpu.py +33 -33
Chat2TTS/core.py CHANGED
@@ -181,27 +181,27 @@ class Chat:
181
  return wav
182
 
183
 
184
- def sample_random_speaker(self) -> str:
185
- return self._encode_spk_emb(self._sample_random_speaker())
186
-
187
-
188
- @staticmethod
189
- def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
190
- with torch.no_grad():
191
- arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
192
- s = b14.encode_to_string(
193
- lzma.compress(
194
- arr.tobytes(),
195
- format=lzma.FORMAT_RAW,
196
- filters=[
197
- {"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
198
- ],
199
- ),
200
- )
201
- del arr
202
- return s
203
-
204
- def _sample_random_speaker(self) -> torch.Tensor:
205
 
206
  with torch.no_grad():
207
  dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
 
181
  return wav
182
 
183
 
184
+ # def sample_random_speaker(self) -> str:
185
+ # return self._encode_spk_emb(self.sample_random_speaker_tensor())
186
+ #
187
+ #
188
+ # @staticmethod
189
+ # def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
190
+ # with torch.no_grad():
191
+ # arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
192
+ # s = b14.encode_to_string(
193
+ # lzma.compress(
194
+ # arr.tobytes(),
195
+ # format=lzma.FORMAT_RAW,
196
+ # filters=[
197
+ # {"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
198
+ # ],
199
+ # ),
200
+ # )
201
+ # del arr
202
+ # return s
203
+
204
+ def sample_random_speaker_tensor(self) -> torch.Tensor:
205
 
206
  with torch.no_grad():
207
  dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
web/app_gpu.py CHANGED
@@ -1,7 +1,6 @@
1
  import os, sys
2
  import spaces
3
 
4
- from tool import TorchSeedContext
5
 
6
  if sys.platform == "darwin":
7
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
@@ -12,6 +11,7 @@ from tool.logger import get_logger
12
  from tool.func import *
13
  from tool.np import *
14
  from tool.gpu import select_device
 
15
  import Chat2TTS
16
  import argparse
17
  import torch._dynamo
@@ -116,18 +116,18 @@ def main(args):
116
  )
117
  generate_text_seed = gr.Button("随机生成文本种子", interactive=True)
118
 
119
- with gr.Row():
120
- spk_emb_text = gr.Textbox(
121
- label="Speaker Embedding",
122
- max_lines=3,
123
- show_copy_button=True,
124
- interactive=False,
125
- scale=2,
126
-
127
- )
128
- reload_chat_button = gr.Button("Reload", scale=1, interactive=True)
129
 
130
  with gr.Row():
 
131
  generate_button = gr.Button("生成音频文件", scale=1, interactive=True)
132
 
133
  with gr.Row():
@@ -152,7 +152,7 @@ def main(args):
152
  # 针对页面元素新增 监听事件
153
  voice_selection.change(fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input)
154
 
155
- audio_seed_input.change(fn=on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text)
156
 
157
  generate_audio_seed.click(fn=generate_seed, outputs=audio_seed_input)
158
 
@@ -160,19 +160,18 @@ def main(args):
160
 
161
  # reload_chat_button.click()
162
 
163
- generate_button.click(fn=get_chat_infer_audio,
164
- inputs=[text_input,
165
  text_seed_input,
166
  refine_text_checkBox,
167
  temperature_slider,
168
  top_p_slider,
169
  top_k_slider,
170
- audio_seed_input,
171
- spk_emb_text
172
  ],
173
- outputs=[text_output,audio_output])
174
  # 初始化 spk_emb_text 数值
175
- spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
176
  logger.info("元素初始化完成,启动gradio服务=======")
177
 
178
  # 运行gradio服务
@@ -193,14 +192,13 @@ def main(args):
193
  简而言之,"spk_embedding"关注的是对话参与者的身份特征,而"temperature"是用于调整生成文本不确定性的一个超参数。
194
  '''
195
  @spaces.GPU
196
- def get_chat_infer_audio(text,
197
- text_seed_input,
198
- refine_text_checkBox,
199
- temperature_slider,
200
- top_p_slider,
201
- top_k_slider,
202
- audio_seed_input,
203
- spk_emb_text):
204
 
205
  logger.info("========开始处理TTS模型=====")
206
  #音频参数设置
@@ -229,7 +227,9 @@ def get_chat_infer_audio(text,
229
  #torch.manual_seed(audio_seed_input)
230
 
231
  with TorchSeedContext(audio_seed_input):
232
- rand_spk = torch.randn(768)
 
 
233
  params_infer_code = {
234
  'spk_emb': rand_spk,
235
  'temperature': temperature_slider,
@@ -274,12 +274,12 @@ def get_chat_infer_audio(text,
274
  #
275
  # return chat_text[0] if isinstance(chat_text, list) else chat_text
276
 
277
- @spaces.GPU
278
- def on_audio_seed_change(audio_seed_input):
279
- global chat
280
- torch.manual_seed(audio_seed_input)
281
- rand_spk = chat.sample_random_speaker()
282
- return rand_spk
283
  # rand_spk = torch.randn(audio_seed_input)
284
  # return encode_spk_emb(rand_spk)
285
 
 
1
  import os, sys
2
  import spaces
3
 
 
4
 
5
  if sys.platform == "darwin":
6
  os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 
11
  from tool.func import *
12
  from tool.np import *
13
  from tool.gpu import select_device
14
+ from tool.ctx import TorchSeedContext
15
  import Chat2TTS
16
  import argparse
17
  import torch._dynamo
 
116
  )
117
  generate_text_seed = gr.Button("随机生成文本种子", interactive=True)
118
 
119
+ # with gr.Row():
120
+ # spk_emb_text = gr.Textbox(
121
+ # label="Speaker Embedding",
122
+ # max_lines=3,
123
+ # show_copy_button=True,
124
+ # interactive=False,
125
+ # scale=2,
126
+ #
127
+ # )
 
128
 
129
  with gr.Row():
130
+ reload_chat_button = gr.Button("Reload", scale=1, interactive=True)
131
  generate_button = gr.Button("生成音频文件", scale=1, interactive=True)
132
 
133
  with gr.Row():
 
152
  # 针对页面元素新增 监听事件
153
  voice_selection.change(fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input)
154
 
155
+ #audio_seed_input.change(fn=on_audio_seed_change, inputs=audio_seed_input, outputs=spk_emb_text)
156
 
157
  generate_audio_seed.click(fn=generate_seed, outputs=audio_seed_input)
158
 
 
160
 
161
  # reload_chat_button.click()
162
 
163
+ generate_button.click(fn=general_chat_infer_audio,
164
+ inputs=[text_input,
165
  text_seed_input,
166
  refine_text_checkBox,
167
  temperature_slider,
168
  top_p_slider,
169
  top_k_slider,
170
+ audio_seed_input
 
171
  ],
172
+ outputs=[text_output,audio_output])
173
  # 初始化 spk_emb_text 数值
174
+ #spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
175
  logger.info("元素初始化完成,启动gradio服务=======")
176
 
177
  # 运行gradio服务
 
192
  简而言之,"spk_embedding"关注的是对话参与者的身份特征,而"temperature"是用于调整生成文本不确定性的一个超参数。
193
  '''
194
  @spaces.GPU
195
+ def general_chat_infer_audio(text,
196
+ text_seed_input,
197
+ refine_text_checkBox,
198
+ temperature_slider,
199
+ top_p_slider,
200
+ top_k_slider,
201
+ audio_seed_input):
 
202
 
203
  logger.info("========开始处理TTS模型=====")
204
  #音频参数设置
 
227
  #torch.manual_seed(audio_seed_input)
228
 
229
  with TorchSeedContext(audio_seed_input):
230
+ #rand_spk = torch.randn(768)
231
+ rand_spk = chat.sample_random_speaker_tensor()
232
+ logger.info("========生成音频spk_emb参数完成=====")
233
  params_infer_code = {
234
  'spk_emb': rand_spk,
235
  'temperature': temperature_slider,
 
274
  #
275
  # return chat_text[0] if isinstance(chat_text, list) else chat_text
276
 
277
+ #@spaces.GPU
278
+ # def on_audio_seed_change(audio_seed_input):
279
+ # global chat
280
+ # torch.manual_seed(audio_seed_input)
281
+ # rand_spk = chat.sample_random_speaker()
282
+ # return rand_spk
283
  # rand_spk = torch.randn(audio_seed_input)
284
  # return encode_spk_emb(rand_spk)
285