candlend committed on
Commit b48b8c0
1 Parent(s): 7a669c0

tts_inferencer

Files changed (2)
  1. app.py +5 -19
  2. inferencer.py → tts_inferencer.py +30 -18
app.py CHANGED
@@ -1,27 +1,13 @@
-import os
-import commons
-import utils
-from models import SynthesizerTrn
-from text.symbols import symbols
-from text import text_to_sequence
 import gradio as gr
-from inferencer import Inferencer
+from tts_inferencer import TTSInferencer
+
+tts_inferencer = TTSInferencer("./configs/hoshimi_base.json")
 
 app = gr.Blocks()
 with app:
     with open("header.html", "r") as f:
         gr.HTML(f.read())
     with gr.Tabs():
-        with gr.TabItem("普通声线"):
-            normal_description = """
-            使用星弥Hoshimi录播音频作为数据集训练而成
-            """
-            normal_inferencer = Inferencer("normal", "./configs/hoshimi_base.json", description=normal_description)
-            normal_inferencer.render()
-        with gr.TabItem("营业声线"):
-            formal_description = """
-            使用星弥Hoshimi音声作为数据集训练而成
-            """
-            formal_inferencer = Inferencer("formal", "./configs/hoshimi_base.json", description=formal_description)
-            formal_inferencer.render()
+        with gr.TabItem("语音合成"):
+            tts_inferencer.render()
 app.launch()
inferencer.py → tts_inferencer.py RENAMED
@@ -13,6 +13,12 @@ from text.symbols import symbols
 from text import text_to_sequence
 import gradio as gr
 
+mode_dict = {
+    "普通声线": "normal",
+    "营业声线": "formal"
+}
+
+default_mode = "普通声线"
 default_noise_scale = 0.667
 default_noise_scale_w = 0.8
 default_length_scale = 1
@@ -24,21 +30,23 @@ def get_text(text, hps):
     text_norm = torch.LongTensor(text_norm)
     return text_norm
 
-class Inferencer:
-    def __init__(self, mode, hps_path, description):
+class TTSInferencer:
+    def __init__(self, hps_path, device="cpu"):
+        self.device = torch.device(device)
+        self.hps = utils.get_hparams_from_file(hps_path)
+        self.select_mode(default_mode)
+        self.load_model(self.latest_model_path)
+
+    def select_mode(self, mode):
         self.mode = mode
-        self.description = description
+        self.model_dir_path = os.path.join("models", mode_dict[mode])
         self.models = []
-        self.model_dir_path = os.path.join("models", mode)
         for f in os.listdir(self.model_dir_path):
             if (f.startswith("D_")):
                 continue
             if (f.endswith(".pth")):
                 self.models.append(f)
-        self.device = torch.device("cpu")
-        self.hps = utils.get_hparams_from_file(hps_path)
-        model_path = utils.latest_checkpoint_path(self.model_dir_path, "G_*.pth")
-        self.load_model(model_path)
+        self.latest_model_path = utils.latest_checkpoint_path(self.model_dir_path, "G_*.pth")
 
     def infer(self, text, noise_scale=.667, noise_scale_w=0.8, length_scale=1):
         stn_tst = get_text(text, self.hps)
@@ -48,12 +56,17 @@ class Inferencer:
         audio = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0,0].data.float().numpy()
         return (self.hps.data.sampling_rate, audio)
 
+    def change_mode(self, mode):
+        self.select_mode(mode)
+        return gr.update(choices=self.models, value=os.path.basename(self.latest_model_path))
+
     def change_model(self, model_file_name):
         self.load_model(os.path.join(self.model_dir_path, model_file_name))
-        return "载入模型:" + model_file_name
+        return f"载入模型:{model_file_name}({self.mode})"
 
     def render(self):
-        choice_model = gr.Dropdown(choices=self.models, label=f"模型迭代版本选择({self.description})", value=os.path.basename(self.pth_path))
+        choice_mode = gr.Radio(choices=["普通声线", "营业声线"], label="声线选择", value=default_mode)
+        choice_model = gr.Dropdown(choices=self.models, label=f"模型迭代版本选择", value=os.path.basename(self.pth_path))
         with gr.Row():
             advanced = gr.Checkbox(label="显示高级设置(效果不可控)")
             default = gr.Button("恢复默认设置").style(full_width=False)
@@ -74,15 +87,14 @@ class Inferencer:
             <div>仅供学习交流,不可用于任何商业和非法用途,否则后果自负</div>
             </div>
             ''')
-        for component in [noise_scale, noise_scale_w]:
-            advanced.change(fn=lambda visible: gr.update(visible=visible), inputs=advanced, outputs=component)
-        for component, default_value in [
-            (noise_scale, default_noise_scale),
-            (noise_scale_w, default_noise_scale_w),
-            (length_scale, default_length_scale)]:
-            default.click(fn=lambda visible: gr.update(value=default_value), inputs=advanced, outputs=component)
+        advanced.change(fn=lambda visible: gr.update(visible=visible), inputs=advanced, outputs=noise_scale)
+        advanced.change(fn=lambda visible: gr.update(visible=visible), inputs=advanced, outputs=noise_scale_w)
+        default.click(fn=lambda visible: gr.update(value=default_noise_scale), inputs=advanced, outputs=noise_scale)
+        default.click(fn=lambda visible: gr.update(value=default_noise_scale_w), inputs=advanced, outputs=noise_scale_w)
+        default.click(fn=lambda visible: gr.update(value=default_length_scale), inputs=advanced, outputs=length_scale)
+        choice_mode.change(self.change_mode, inputs=choice_mode, outputs=choice_model)
         choice_model.change(self.change_model, inputs=[choice_model], outputs=[tts_model])
-        tts_submit.click(self.infer, [tts_input, noise_scale, noise_scale_w, length_scale], [tts_output], api_name=f"{self.mode}_infer")
+        tts_submit.click(self.infer, [tts_input, noise_scale, noise_scale_w, length_scale], [tts_output], api_name=f"infer")
 
 
     def load_model(self, model_path):
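
The final hunk stops at the load_model signature, so the code that sets self.net_g and self.pth_path (both used by infer and render) is not visible in this diff. For orientation only, a minimal sketch of what such a loader typically looks like in VITS-derived projects; the SynthesizerTrn constructor arguments and the utils.load_checkpoint call follow upstream VITS conventions, and setting self.pth_path here is an assumption inferred from its use in render, not taken from this commit:

    def load_model(self, model_path):
        # Sketch only -- the real implementation lives outside this diff.
        # Build the generator from the hyperparameters loaded in __init__
        # (standard VITS construction; assumed, not shown in this commit).
        net_g = SynthesizerTrn(
            len(symbols),
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            **self.hps.model).to(self.device)
        net_g.eval()
        # Restore the generator weights from the selected checkpoint.
        utils.load_checkpoint(model_path, net_g, None)
        self.net_g = net_g
        # Assumed: render() reads os.path.basename(self.pth_path) for the dropdown value.
        self.pth_path = model_path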