Spaces:
Runtime error
Runtime error
candlend
commited on
Commit
•
7eaf36d
1
Parent(s):
3d5f773
vits
Browse files- vits/tts_inferencer.py +24 -22
vits/tts_inferencer.py
CHANGED
@@ -33,21 +33,23 @@ def get_text(text, hps):
|
|
33 |
|
34 |
class TTSInferencer:
|
35 |
def __init__(self, hps_path, device="cpu"):
|
36 |
-
print("init")
|
37 |
self.device = torch.device(device)
|
38 |
self.hps = utils.get_hparams_from_file(hps_path)
|
39 |
-
self.
|
40 |
-
self.
|
41 |
-
for key, value in mode_dict.items():
|
42 |
-
self.model_paths[key] = self.get_latest_model_path_by_mode(key)
|
43 |
-
self.load_models()
|
44 |
|
45 |
-
def
|
46 |
-
|
47 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
|
49 |
-
def infer(self, text,
|
50 |
-
print(self.pth_path)
|
51 |
stn_tst = get_text(text, self.hps)
|
52 |
with torch.no_grad():
|
53 |
x_tst = stn_tst.unsqueeze(0).to(self.device)
|
@@ -65,6 +67,7 @@ class TTSInferencer:
|
|
65 |
|
66 |
def render(self):
|
67 |
choice_mode = gr.Radio(choices=["普通声线", "营业声线"], label="声线选择", value=default_mode)
|
|
|
68 |
# with gr.Row():
|
69 |
# advanced = gr.Checkbox(label="显示高级设置(效果不可控)")
|
70 |
# default = gr.Button("恢复默认设置").style(full_width=False)
|
@@ -77,6 +80,7 @@ class TTSInferencer:
|
|
77 |
value="这里是爱喝奶茶,穿得也像奶茶魅力点是普通话二乙的星弥吼西咪,晚上齁。")
|
78 |
tts_submit = gr.Button("合成", variant="primary")
|
79 |
tts_output = gr.Audio(label="Output")
|
|
|
80 |
gr.HTML('''
|
81 |
<div style="text-align:right;font-size:12px;color:#4D4D4D">
|
82 |
<div class="font-medium">版权声明</div>
|
@@ -89,19 +93,17 @@ class TTSInferencer:
|
|
89 |
# default.click(fn=lambda visible: gr.update(value=default_noise_scale), inputs=advanced, outputs=noise_scale)
|
90 |
# default.click(fn=lambda visible: gr.update(value=default_noise_scale_w), inputs=advanced, outputs=noise_scale_w)
|
91 |
# default.click(fn=lambda visible: gr.update(value=default_length_scale), inputs=advanced, outputs=length_scale)
|
92 |
-
|
|
|
|
|
|
|
93 |
|
94 |
-
def
|
95 |
-
|
96 |
-
|
97 |
len(symbols),
|
98 |
self.hps.data.filter_length // 2 + 1,
|
99 |
self.hps.train.segment_size // self.hps.data.hop_length,
|
100 |
**self.hps.model).to(self.device)
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
def __del__(self):
|
105 |
-
print("del")
|
106 |
-
del self.net_g
|
107 |
-
self.net_g = None
|
|
|
33 |
|
34 |
class TTSInferencer:
|
35 |
def __init__(self, hps_path, device="cpu"):
|
|
|
36 |
self.device = torch.device(device)
|
37 |
self.hps = utils.get_hparams_from_file(hps_path)
|
38 |
+
self.select_mode(default_mode)
|
39 |
+
self.load_model(self.latest_model_path)
|
|
|
|
|
|
|
40 |
|
41 |
+
def select_mode(self, mode):
|
42 |
+
self.mode = mode
|
43 |
+
self.model_dir_path = os.path.join(ROOT_PATH, "models", mode_dict[mode])
|
44 |
+
self.models = []
|
45 |
+
for f in os.listdir(self.model_dir_path):
|
46 |
+
if (f.startswith("D_")):
|
47 |
+
continue
|
48 |
+
if (f.endswith(".pth")):
|
49 |
+
self.models.append(f)
|
50 |
+
self.latest_model_path = utils.latest_checkpoint_path(self.model_dir_path, "G_*.pth")
|
51 |
|
52 |
+
def infer(self, text, noise_scale=.667, noise_scale_w=0.8, length_scale=1):
|
|
|
53 |
stn_tst = get_text(text, self.hps)
|
54 |
with torch.no_grad():
|
55 |
x_tst = stn_tst.unsqueeze(0).to(self.device)
|
|
|
67 |
|
68 |
def render(self):
|
69 |
choice_mode = gr.Radio(choices=["普通声线", "营业声线"], label="声线选择", value=default_mode)
|
70 |
+
choice_model = gr.Dropdown(choices=self.models, label=f"模型迭代版本选择", value=os.path.basename(self.pth_path))
|
71 |
# with gr.Row():
|
72 |
# advanced = gr.Checkbox(label="显示高级设置(效果不可控)")
|
73 |
# default = gr.Button("恢复默认设置").style(full_width=False)
|
|
|
80 |
value="这里是爱喝奶茶,穿得也像奶茶魅力点是普通话二乙的星弥吼西咪,晚上齁。")
|
81 |
tts_submit = gr.Button("合成", variant="primary")
|
82 |
tts_output = gr.Audio(label="Output")
|
83 |
+
tts_model = gr.Markdown(f"载入模型:{os.path.basename(self.latest_model_path)}({self.mode})")
|
84 |
gr.HTML('''
|
85 |
<div style="text-align:right;font-size:12px;color:#4D4D4D">
|
86 |
<div class="font-medium">版权声明</div>
|
|
|
93 |
# default.click(fn=lambda visible: gr.update(value=default_noise_scale), inputs=advanced, outputs=noise_scale)
|
94 |
# default.click(fn=lambda visible: gr.update(value=default_noise_scale_w), inputs=advanced, outputs=noise_scale_w)
|
95 |
# default.click(fn=lambda visible: gr.update(value=default_length_scale), inputs=advanced, outputs=length_scale)
|
96 |
+
choice_mode.change(self.change_mode, inputs=choice_mode, outputs=choice_model)
|
97 |
+
choice_model.change(self.change_model, inputs=[choice_model], outputs=[tts_model])
|
98 |
+
tts_submit.click(self.infer, [tts_input, noise_scale, noise_scale_w, length_scale], [tts_output], api_name=f"infer")
|
99 |
+
|
100 |
|
101 |
+
def load_model(self, model_path):
|
102 |
+
self.pth_path = model_path
|
103 |
+
self.net_g = SynthesizerTrn(
|
104 |
len(symbols),
|
105 |
self.hps.data.filter_length // 2 + 1,
|
106 |
self.hps.train.segment_size // self.hps.data.hop_length,
|
107 |
**self.hps.model).to(self.device)
|
108 |
+
_ = self.net_g.eval()
|
109 |
+
_ = utils.load_checkpoint(self.pth_path, self.net_g, None)
|
|
|
|
|
|
|
|
|
|