Commit 5c0140c by chenjgtea (1 parent: 8dce793)

Add ChatTTS code for GPU mode

Files changed (1):
  1. web/app_gpu.py +76 -51
web/app_gpu.py CHANGED
@@ -1,6 +1,8 @@
 import os, sys
 import spaces
 
+from tool import TorchSeedContext
+
 if sys.platform == "darwin":
     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 now_dir = os.getcwd()
@@ -158,21 +160,17 @@ def main(args):
 
         # reload_chat_button.click()
 
-        generate_button.click(fn=get_chat_infer_text,
-                              inputs=[text_input,
+        generate_button.click(fn=get_chat_infer_audio,
+                              inputs=[text_input,
                                       text_seed_input,
-                                      refine_text_checkBox
-                                      ],
-                              outputs=[text_output]
-                              ).then(fn=get_chat_infer_audio,
-                                     inputs=[text_output,
+                                      refine_text_checkBox,
                                       temperature_slider,
                                       top_p_slider,
                                       top_k_slider,
                                       audio_seed_input,
                                       spk_emb_text
                                       ],
-                                     outputs=[audio_output])
+                              outputs=[text_output,audio_output])
         # initialize the spk_emb_text value
         spk_emb_text.value = on_audio_seed_change(audio_seed_input.value)
         logger.info("Element initialization finished, starting the gradio service =======")
@@ -195,13 +193,16 @@ def main(args):
 In short, "spk_embedding" captures the identity characteristics of the dialogue participant, while "temperature" is a hyperparameter used to adjust the uncertainty of the generated text.
 '''
 @spaces.GPU
-def get_chat_infer_audio(chat_txt,
-                         temperature_slider,
-                         top_p_slider,
-                         top_k_slider,
-                         audio_seed_input,
-                         spk_emb_text):
-    logger.info("======== start generating the audio file =====")
+def get_chat_infer_audio(text,
+                         text_seed_input,
+                         refine_text_checkBox,
+                         temperature_slider,
+                         top_p_slider,
+                         top_k_slider,
+                         audio_seed_input,
+                         spk_emb_text):
+
+    logger.info("======== start processing the TTS model =====")
     # audio parameter settings
     # params_infer_code = Chat2TTS.Chat.InferCodeParams(
     #     spk_emb=spk_emb_text, # add sampled speaker
@@ -209,45 +210,69 @@ def get_chat_infer_audio(chat_txt,
     #     top_P=top_p_slider, # top P decode
     #     top_K=top_k_slider, # top K decode
     # )
-    # torch.manual_seed(audio_seed_input)
-    # rand_spk = torch.randn(768)
-    params_infer_code = {
-        'spk_emb': None,
-        'temperature': temperature_slider,
-        'top_P': top_p_slider,
-        'top_K': top_k_slider,
-    }
-    torch.manual_seed(audio_seed_input)
-    wav = chat.infer(
-        text=chat_txt,
-        skip_refine_text=True,  # skip text refinement
-        params_infer_code=params_infer_code,
-    )
-    yield 24000, float_to_int16(wav[0]).T
-
-@spaces.GPU
-def get_chat_infer_text(text,seed,refine_text_checkBox):
-
-    logger.info("======== start refining the text 2 =====")
-    global chat
+    params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}
     if not refine_text_checkBox:
         logger.info("======== the text needs no refinement =====")
-        return text
-
-    # params_refine_text = Chat2TTS.Chat.RefineTextParams(
-    #     prompt='[oral_2][laugh_0][break_6]',
-    # )
-
-    params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}
-    torch.manual_seed(seed)
-    chat_text = chat.infer(
-        text=text,
-        skip_refine_text=False,
-        refine_text_only=True,  # return only the refined text
-        params_refine_text=params_refine_text,
-    )
+        chat_txt = text
+    else:
+        logger.info("======== start refining the text =====")
+        # torch.manual_seed(text_seed_input)
+        with TorchSeedContext(text_seed_input):
+            chat_txt = chat.infer(
+                text=text,
+                skip_refine_text=False,
+                refine_text_only=True,  # return only the refined text
+                params_refine_text=params_refine_text,
+            )
 
-    return chat_text[0] if isinstance(chat_text, list) else chat_text
+    logger.info("======== start generating the audio file =====")
+    # torch.manual_seed(audio_seed_input)
+
+    with TorchSeedContext(audio_seed_input):
+        rand_spk = torch.randn(768)
+        params_infer_code = {
+            'spk_emb': rand_spk,
+            'temperature': temperature_slider,
+            'top_P': top_p_slider,
+            'top_K': top_k_slider,
+        }
+        wav = chat.infer(
+            text=chat_txt,
+            skip_refine_text=True,  # skip text refinement
+            params_refine_text=params_refine_text,
+            params_infer_code=params_infer_code,
+        )
+    # yield 24000, float_to_int16(wav[0]).T
+    audio_data = np.array(wav[0]).flatten()
+    sample_rate = 24000
+    text_data = text[0] if isinstance(text, list) else text
+
+    return [text_data, (sample_rate, audio_data)]
+
+
+# @spaces.GPU
+# def get_chat_infer_text(text,seed,refine_text_checkBox):
+#
+#     logger.info("======== start refining the text 2 =====")
+#     global chat
+#     if not refine_text_checkBox:
+#         logger.info("======== the text needs no refinement =====")
+#         return text
+#
+#     # params_refine_text = Chat2TTS.Chat.RefineTextParams(
+#     #     prompt='[oral_2][laugh_0][break_6]',
+#     # )
+#
+#     params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}
+#     torch.manual_seed(seed)
+#     chat_text = chat.infer(
+#         text=text,
+#         skip_refine_text=False,
+#         refine_text_only=True,  # return only the refined text
+#         params_refine_text=params_refine_text,
+#     )
+#
+#     return chat_text[0] if isinstance(chat_text, list) else chat_text
 
 @spaces.GPU
 def on_audio_seed_change(audio_seed_input):
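
Note: the new handler wraps both the text-refinement and the audio-generation calls in "with TorchSeedContext(seed):" blocks imported from the local tool module, instead of calling torch.manual_seed directly, presumably so the per-step seeding is scoped and the RNG state is restored afterwards. That helper is not part of this diff; the code below is only a minimal sketch of what such a seed context manager typically looks like, and the actual tool.TorchSeedContext may differ.

import torch

class TorchSeedContext:
    # Minimal sketch: seed torch's RNG for the duration of a with-block and
    # restore the previous CPU RNG state on exit (per-device CUDA RNG state
    # is not handled here).
    def __init__(self, seed):
        self.seed = seed
        self.state = None

    def __enter__(self):
        self.state = torch.random.get_rng_state()  # save the current RNG state
        torch.manual_seed(self.seed)               # make sampling reproducible
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        torch.random.set_rng_state(self.state)     # restore the outer RNG state
        return False                               # do not swallow exceptions

With the seeds scoped this way, the same audio_seed_input reproducibly drives both rand_spk = torch.randn(768) and chat.infer(...). The handler's return value [text_data, (sample_rate, audio_data)] is mapped positionally onto outputs=[text_output, audio_output]; since a Gradio audio component accepts a (sample_rate, numpy_array) tuple directly, the earlier streaming path "yield 24000, float_to_int16(wav[0]).T" could be dropped.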