candlend committed
Commit 1f55a13
Parent: 7eaf36d

remove choice model

Files changed (1)
  1. vits/tts_inferencer.py +17 -37
vits/tts_inferencer.py CHANGED
@@ -33,44 +33,33 @@ def get_text(text, hps):
 
 class TTSInferencer:
     def __init__(self, hps_path, device="cpu"):
+        print("init")
         self.device = torch.device(device)
         self.hps = utils.get_hparams_from_file(hps_path)
-        self.select_mode(default_mode)
-        self.load_model(self.latest_model_path)
+        self.model_paths = {}
+        self.models = {}
+        for mode in mode_dict:
+            self.model_paths[mode] = self.get_latest_model_path_by_mode(mode)
+        self.load_models()
 
-    def select_mode(self, mode):
-        self.mode = mode
-        self.model_dir_path = os.path.join(ROOT_PATH, "models", mode_dict[mode])
-        self.models = []
-        for f in os.listdir(self.model_dir_path):
-            if (f.startswith("D_")):
-                continue
-            if (f.endswith(".pth")):
-                self.models.append(f)
-        self.latest_model_path = utils.latest_checkpoint_path(self.model_dir_path, "G_*.pth")
+    def get_latest_model_path_by_mode(self, mode):
+        model_dir_path = os.path.join(ROOT_PATH, "models", mode_dict[mode])
+        return utils.latest_checkpoint_path(model_dir_path, "G_*.pth")
 
-    def infer(self, text, noise_scale=.667, noise_scale_w=0.8, length_scale=1):
+    def infer(self, text, mode, noise_scale=.667, noise_scale_w=0.8, length_scale=1):
         stn_tst = get_text(text, self.hps)
         with torch.no_grad():
             x_tst = stn_tst.unsqueeze(0).to(self.device)
             x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(self.device)
-            audio = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0,0].data.float().numpy()
+            audio = self.models[mode].infer(x_tst, x_tst_lengths, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0,0].data.float().numpy()
         return (self.hps.data.sampling_rate, audio)
 
     def change_mode(self, mode):
         self.select_mode(mode)
         return gr.update(choices=self.models, value=os.path.basename(self.latest_model_path))
-
-    def change_model(self, model_file_name):
-        self.load_model(os.path.join(self.model_dir_path, model_file_name))
-        return f"载入模型:{model_file_name}({self.mode})"
 
     def render(self):
         choice_mode = gr.Radio(choices=["普通声线", "营业声线"], label="声线选择", value=default_mode)
-        choice_model = gr.Dropdown(choices=self.models, label=f"模型迭代版本选择", value=os.path.basename(self.pth_path))
-        # with gr.Row():
-        #     advanced = gr.Checkbox(label="显示高级设置(效果不可控)")
-        #     default = gr.Button("恢复默认设置").style(full_width=False)
         noise_scale = gr.Slider(minimum=0, maximum=3, value=default_noise_scale, step=0.001, label="noise_scale(效果不可控,谨慎修改)")
         noise_scale_w = gr.Slider(minimum=0, maximum=3, value=default_noise_scale_w, step=0.001, label="noise_scale_w(效果不可控,谨慎修改)")
         length_scale = gr.Slider(minimum=0, maximum=3, value=default_length_scale, step=0.001, label="length_scale(数值越大输出音频越长)")
@@ -80,7 +69,6 @@ class TTSInferencer:
             value="这里是爱喝奶茶,穿得也像奶茶魅力点是普通话二乙的星弥吼西咪,晚上齁。")
         tts_submit = gr.Button("合成", variant="primary")
         tts_output = gr.Audio(label="Output")
-        tts_model = gr.Markdown(f"载入模型:{os.path.basename(self.latest_model_path)}({self.mode})")
         gr.HTML('''
         <div style="text-align:right;font-size:12px;color:#4D4D4D">
         <div class="font-medium">版权声明</div>
@@ -88,22 +76,14 @@ class TTSInferencer:
         <div>仅供学习交流,不可用于任何商业和非法用途,否则后果自负</div>
         </div>
         ''')
-        # advanced.change(fn=lambda visible: gr.update(visible=visible), inputs=advanced, outputs=noise_scale)
-        # advanced.change(fn=lambda visible: gr.update(visible=visible), inputs=advanced, outputs=noise_scale_w)
-        # default.click(fn=lambda visible: gr.update(value=default_noise_scale), inputs=advanced, outputs=noise_scale)
-        # default.click(fn=lambda visible: gr.update(value=default_noise_scale_w), inputs=advanced, outputs=noise_scale_w)
-        # default.click(fn=lambda visible: gr.update(value=default_length_scale), inputs=advanced, outputs=length_scale)
-        choice_mode.change(self.change_mode, inputs=choice_mode, outputs=choice_model)
-        choice_model.change(self.change_model, inputs=[choice_model], outputs=[tts_model])
-        tts_submit.click(self.infer, [tts_input, noise_scale, noise_scale_w, length_scale], [tts_output], api_name=f"infer")
-
+        tts_submit.click(self.infer, [tts_input, choice_mode, noise_scale, noise_scale_w, length_scale], [tts_output], api_name=f"infer")
 
-    def load_model(self, model_path):
-        self.pth_path = model_path
-        self.net_g = SynthesizerTrn(
+    def load_models(self):
+        for key, model_path in self.model_paths.items():
+            self.models[key] = SynthesizerTrn(
             len(symbols),
             self.hps.data.filter_length // 2 + 1,
             self.hps.train.segment_size // self.hps.data.hop_length,
             **self.hps.model).to(self.device)
-        _ = self.net_g.eval()
-        _ = utils.load_checkpoint(self.pth_path, self.net_g, None)
+            _ = self.models[key].eval()
+            _ = utils.load_checkpoint(model_path, self.models[key], None)
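
For reference, a minimal usage sketch of the class after this commit. It is an illustration under stated assumptions, not part of the commit itself: it assumes the project root is on sys.path so vits.tts_inferencer is importable, that an hparams file exists at the hypothetical path configs/hoshimi.json, and that mode_dict uses the same keys as the 声线选择 radio ("普通声线", "营业声线").

# Sketch: synthesize one utterance with a preloaded voice mode and write it to a WAV file.
# Assumed names (not from this commit): the config path "configs/hoshimi.json" and the
# mode key "普通声线"; adjust both to your checkout.
import scipy.io.wavfile as wavfile

from vits.tts_inferencer import TTSInferencer

inferencer = TTSInferencer("configs/hoshimi.json", device="cpu")  # __init__ now preloads one SynthesizerTrn per mode

text = "这里是爱喝奶茶,穿得也像奶茶魅力点是普通话二乙的星弥吼西咪,晚上齁。"
sampling_rate, audio = inferencer.infer(text, mode="普通声线", length_scale=1.0)

# infer() returns (sampling_rate, float32 ndarray), the same tuple format gr.Audio accepts.
wavfile.write("output.wav", sampling_rate, audio)

Because load_models() builds and loads every checkpoint up front, switching voices is just a dictionary lookup inside infer(), which is what lets this commit drop the choice_model dropdown, the change_model handler, and the tts_model status line from render().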