Spaces:
Runtime error
Runtime error
candlend
commited on
Commit
•
b48b8c0
1
Parent(s):
7a669c0
tts_inferencer
Browse files- app.py +5 -19
- inferencer.py → tts_inferencer.py +30 -18
app.py
CHANGED
@@ -1,27 +1,13 @@
|
|
1 |
-
import os
|
2 |
-
import commons
|
3 |
-
import utils
|
4 |
-
from models import SynthesizerTrn
|
5 |
-
from text.symbols import symbols
|
6 |
-
from text import text_to_sequence
|
7 |
import gradio as gr
|
8 |
-
from
|
|
|
|
|
9 |
|
10 |
app = gr.Blocks()
|
11 |
with app:
|
12 |
with open("header.html", "r") as f:
|
13 |
gr.HTML(f.read())
|
14 |
with gr.Tabs():
|
15 |
-
with gr.TabItem("
|
16 |
-
|
17 |
-
使用星弥Hoshimi录播音频作为数据集训练而成
|
18 |
-
"""
|
19 |
-
normal_inferencer = Inferencer("normal", "./configs/hoshimi_base.json", description=normal_description)
|
20 |
-
normal_inferencer.render()
|
21 |
-
with gr.TabItem("营业声线"):
|
22 |
-
formal_description = """
|
23 |
-
使用星弥Hoshimi音声作为数据集训练而成
|
24 |
-
"""
|
25 |
-
formal_inferencer = Inferencer("formal", "./configs/hoshimi_base.json", description=formal_description)
|
26 |
-
formal_inferencer.render()
|
27 |
app.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from tts_inferencer import TTSInferencer
|
3 |
+
|
4 |
+
tts_inferencer = TTSInferencer("./configs/hoshimi_base.json")
|
5 |
|
6 |
app = gr.Blocks()
|
7 |
with app:
|
8 |
with open("header.html", "r") as f:
|
9 |
gr.HTML(f.read())
|
10 |
with gr.Tabs():
|
11 |
+
with gr.TabItem("语音合成"):
|
12 |
+
tts_inferencer.render()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
app.launch()
|
inferencer.py → tts_inferencer.py
RENAMED
@@ -13,6 +13,12 @@ from text.symbols import symbols
|
|
13 |
from text import text_to_sequence
|
14 |
import gradio as gr
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
default_noise_scale = 0.667
|
17 |
default_noise_scale_w = 0.8
|
18 |
default_length_scale = 1
|
@@ -24,21 +30,23 @@ def get_text(text, hps):
|
|
24 |
text_norm = torch.LongTensor(text_norm)
|
25 |
return text_norm
|
26 |
|
27 |
-
class
|
28 |
-
def __init__(self,
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
self.mode = mode
|
30 |
-
self.
|
31 |
self.models = []
|
32 |
-
self.model_dir_path = os.path.join("models", mode)
|
33 |
for f in os.listdir(self.model_dir_path):
|
34 |
if (f.startswith("D_")):
|
35 |
continue
|
36 |
if (f.endswith(".pth")):
|
37 |
self.models.append(f)
|
38 |
-
self.
|
39 |
-
self.hps = utils.get_hparams_from_file(hps_path)
|
40 |
-
model_path = utils.latest_checkpoint_path(self.model_dir_path, "G_*.pth")
|
41 |
-
self.load_model(model_path)
|
42 |
|
43 |
def infer(self, text, noise_scale=.667, noise_scale_w=0.8, length_scale=1):
|
44 |
stn_tst = get_text(text, self.hps)
|
@@ -48,12 +56,17 @@ class Inferencer:
|
|
48 |
audio = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0,0].data.float().numpy()
|
49 |
return (self.hps.data.sampling_rate, audio)
|
50 |
|
|
|
|
|
|
|
|
|
51 |
def change_model(self, model_file_name):
|
52 |
self.load_model(os.path.join(self.model_dir_path, model_file_name))
|
53 |
-
return "载入模型:"
|
54 |
|
55 |
def render(self):
|
56 |
-
|
|
|
57 |
with gr.Row():
|
58 |
advanced = gr.Checkbox(label="显示高级设置(效果不可控)")
|
59 |
default = gr.Button("恢复默认设置").style(full_width=False)
|
@@ -74,15 +87,14 @@ class Inferencer:
|
|
74 |
<div>仅供学习交流,不可用于任何商业和非法用途,否则后果自负</div>
|
75 |
</div>
|
76 |
''')
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
default.click(fn=lambda visible: gr.update(value=default_value), inputs=advanced, outputs=component)
|
84 |
choice_model.change(self.change_model, inputs=[choice_model], outputs=[tts_model])
|
85 |
-
tts_submit.click(self.infer, [tts_input, noise_scale, noise_scale_w, length_scale], [tts_output], api_name=f"
|
86 |
|
87 |
|
88 |
def load_model(self, model_path):
|
|
|
13 |
from text import text_to_sequence
|
14 |
import gradio as gr
|
15 |
|
16 |
+
mode_dict = {
|
17 |
+
"普通声线": "normal",
|
18 |
+
"营业声线": "formal"
|
19 |
+
}
|
20 |
+
|
21 |
+
default_mode = "普通声线"
|
22 |
default_noise_scale = 0.667
|
23 |
default_noise_scale_w = 0.8
|
24 |
default_length_scale = 1
|
|
|
30 |
text_norm = torch.LongTensor(text_norm)
|
31 |
return text_norm
|
32 |
|
33 |
+
class TTSInferencer:
|
34 |
+
def __init__(self, hps_path, device="cpu"):
|
35 |
+
self.device = torch.device(device)
|
36 |
+
self.hps = utils.get_hparams_from_file(hps_path)
|
37 |
+
self.select_mode(default_mode)
|
38 |
+
self.load_model(self.latest_model_path)
|
39 |
+
|
40 |
+
def select_mode(self, mode):
|
41 |
self.mode = mode
|
42 |
+
self.model_dir_path = os.path.join("models", mode_dict[mode])
|
43 |
self.models = []
|
|
|
44 |
for f in os.listdir(self.model_dir_path):
|
45 |
if (f.startswith("D_")):
|
46 |
continue
|
47 |
if (f.endswith(".pth")):
|
48 |
self.models.append(f)
|
49 |
+
self.latest_model_path = utils.latest_checkpoint_path(self.model_dir_path, "G_*.pth")
|
|
|
|
|
|
|
50 |
|
51 |
def infer(self, text, noise_scale=.667, noise_scale_w=0.8, length_scale=1):
|
52 |
stn_tst = get_text(text, self.hps)
|
|
|
56 |
audio = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0,0].data.float().numpy()
|
57 |
return (self.hps.data.sampling_rate, audio)
|
58 |
|
59 |
+
def change_mode(self, mode):
|
60 |
+
self.select_mode(mode)
|
61 |
+
return gr.update(choices=self.models, value=os.path.basename(self.latest_model_path))
|
62 |
+
|
63 |
def change_model(self, model_file_name):
|
64 |
self.load_model(os.path.join(self.model_dir_path, model_file_name))
|
65 |
+
return f"载入模型:{model_file_name}({self.mode})"
|
66 |
|
67 |
def render(self):
|
68 |
+
choice_mode = gr.Radio(choices=["普通声线", "营业声线"], label="声线选择", value=default_mode)
|
69 |
+
choice_model = gr.Dropdown(choices=self.models, label=f"模型迭代版本选择", value=os.path.basename(self.pth_path))
|
70 |
with gr.Row():
|
71 |
advanced = gr.Checkbox(label="显示高级设置(效果不可控)")
|
72 |
default = gr.Button("恢复默认设置").style(full_width=False)
|
|
|
87 |
<div>仅供学习交流,不可用于任何商业和非法用途,否则后果自负</div>
|
88 |
</div>
|
89 |
''')
|
90 |
+
advanced.change(fn=lambda visible: gr.update(visible=visible), inputs=advanced, outputs=noise_scale)
|
91 |
+
advanced.change(fn=lambda visible: gr.update(visible=visible), inputs=advanced, outputs=noise_scale_w)
|
92 |
+
default.click(fn=lambda visible: gr.update(value=default_noise_scale), inputs=advanced, outputs=noise_scale)
|
93 |
+
default.click(fn=lambda visible: gr.update(value=default_noise_scale_w), inputs=advanced, outputs=noise_scale_w)
|
94 |
+
default.click(fn=lambda visible: gr.update(value=default_length_scale), inputs=advanced, outputs=length_scale)
|
95 |
+
choice_mode.change(self.change_mode, inputs=choice_mode, outputs=choice_model)
|
|
|
96 |
choice_model.change(self.change_model, inputs=[choice_model], outputs=[tts_model])
|
97 |
+
tts_submit.click(self.infer, [tts_input, noise_scale, noise_scale_w, length_scale], [tts_output], api_name=f"infer")
|
98 |
|
99 |
|
100 |
def load_model(self, model_path):
|