ikechan8370 committed on
Commit
b772f7c
·
1 Parent(s): 7e90749

feat: add support for gpu

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -1,4 +1,3 @@
1
- # coding=utf-8
2
  import time
3
  import gradio as gr
4
  import utils
@@ -6,14 +5,16 @@ import commons
6
  from models import SynthesizerTrn
7
  from text import text_to_sequence
8
  from torch import no_grad, LongTensor
 
9
 
10
  hps_ms = utils.get_hparams_from_file(r'./model/config.json')
 
11
  net_g_ms = SynthesizerTrn(
12
  len(hps_ms.symbols),
13
  hps_ms.data.filter_length // 2 + 1,
14
  hps_ms.train.segment_size // hps_ms.data.hop_length,
15
  n_speakers=hps_ms.data.n_speakers,
16
- **hps_ms.model)
17
  _ = net_g_ms.eval()
18
  speakers = hps_ms.speakers
19
  model, optimizer, learning_rate, epochs = utils.load_checkpoint(r'./model/G_953000.pth', net_g_ms, None)
@@ -30,7 +31,7 @@ def vits(text, language, speaker_id, noise_scale, noise_scale_w, length_scale):
30
  if not len(text):
31
  return "输入文本不能为空!", None, None
32
  text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
33
- if len(text) > 300:
34
  return f"输入文字过长!{len(text)}>100", None, None
35
  if language == 0:
36
  text = f"[ZH]{text}[ZH]"
@@ -44,7 +45,7 @@ def vits(text, language, speaker_id, noise_scale, noise_scale_w, length_scale):
44
  x_tst_lengths = LongTensor([stn_tst.size(0)])
45
  speaker_id = LongTensor([speaker_id])
46
  audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
47
- length_scale=length_scale)[0][0, 0].data.float().numpy()
48
 
49
  return "生成成功!", (22050, audio), f"生成耗时 {round(time.perf_counter()-start, 2)} s"
50
 
@@ -116,8 +117,8 @@ if __name__ == '__main__':
116
  download = gr.Button("Download Audio")
117
  btn.click(vits, inputs=[input_text, lang, sid, ns, nsw, ls], outputs=[o1, o2, o3], api_name="generate")
118
  download.click(None, [], [], _js=download_audio_js.format())
119
- btn2.click(search_speaker, inputs=[search], outputs=[sid], api_name="search_speaker")
120
- lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls], api_name="fuck")
121
  with gr.TabItem("可用人物一览"):
122
  gr.Radio(label="Speaker", choices=speakers, interactive=False, type="index")
123
- app.queue(concurrency_count=1).launch()
 
 
1
  import time
2
  import gradio as gr
3
  import utils
 
5
  from models import SynthesizerTrn
6
  from text import text_to_sequence
7
  from torch import no_grad, LongTensor
8
+ import torch
9
 
10
  hps_ms = utils.get_hparams_from_file(r'./model/config.json')
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
  net_g_ms = SynthesizerTrn(
13
  len(hps_ms.symbols),
14
  hps_ms.data.filter_length // 2 + 1,
15
  hps_ms.train.segment_size // hps_ms.data.hop_length,
16
  n_speakers=hps_ms.data.n_speakers,
17
+ **hps_ms.model).to(device)
18
  _ = net_g_ms.eval()
19
  speakers = hps_ms.speakers
20
  model, optimizer, learning_rate, epochs = utils.load_checkpoint(r'./model/G_953000.pth', net_g_ms, None)
 
31
  if not len(text):
32
  return "输入文本不能为空!", None, None
33
  text = text.replace('\n', ' ').replace('\r', '').replace(" ", "")
34
+ if len(text) > 500:
35
  return f"输入文字过长!{len(text)}>100", None, None
36
  if language == 0:
37
  text = f"[ZH]{text}[ZH]"
 
45
  x_tst_lengths = LongTensor([stn_tst.size(0)])
46
  speaker_id = LongTensor([speaker_id])
47
  audio = net_g_ms.infer(x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w,
48
+ length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
49
 
50
  return "生成成功!", (22050, audio), f"生成耗时 {round(time.perf_counter()-start, 2)} s"
51
 
 
117
  download = gr.Button("Download Audio")
118
  btn.click(vits, inputs=[input_text, lang, sid, ns, nsw, ls], outputs=[o1, o2, o3], api_name="generate")
119
  download.click(None, [], [], _js=download_audio_js.format())
120
+ btn2.click(search_speaker, inputs=[search], outputs=[sid])
121
+ lang.change(change_lang, inputs=[lang], outputs=[ns, nsw, ls])
122
  with gr.TabItem("可用人物一览"):
123
  gr.Radio(label="Speaker", choices=speakers, interactive=False, type="index")
124
+ app.queue(concurrency_count=1).launch()