Spaces:
Runtime error
Runtime error
candlend
commited on
Commit
•
a299114
1
Parent(s):
6c1e802
vits
Browse files- vits/tts_inferencer.py +107 -0
vits/tts_inferencer.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import math
|
4 |
+
import torch
|
5 |
+
from torch import nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
from torch.utils.data import DataLoader
|
8 |
+
from vits import ROOT_PATH
|
9 |
+
|
10 |
+
import commons
|
11 |
+
import utils
|
12 |
+
from models import SynthesizerTrn
|
13 |
+
from text.symbols import symbols
|
14 |
+
from text import text_to_sequence
|
15 |
+
import gradio as gr
|
16 |
+
|
17 |
+
mode_dict = {
|
18 |
+
"普通声线": "normal",
|
19 |
+
"营业声线": "formal"
|
20 |
+
}
|
21 |
+
|
22 |
+
default_mode = "普通声线"
|
23 |
+
default_noise_scale = 0.667
|
24 |
+
default_noise_scale_w = 0.8
|
25 |
+
default_length_scale = 1
|
26 |
+
|
27 |
+
def get_text(text, hps):
|
28 |
+
text_norm = text_to_sequence(text, hps.data.text_cleaners)
|
29 |
+
if hps.data.add_blank:
|
30 |
+
text_norm = commons.intersperse(text_norm, 0)
|
31 |
+
text_norm = torch.LongTensor(text_norm)
|
32 |
+
return text_norm
|
33 |
+
|
34 |
+
class TTSInferencer:
|
35 |
+
def __init__(self, hps_path, device="cpu"):
|
36 |
+
print("init")
|
37 |
+
self.device = torch.device(device)
|
38 |
+
self.hps = utils.get_hparams_from_file(hps_path)
|
39 |
+
self.model_paths = {}
|
40 |
+
self.models = {}
|
41 |
+
for key, value in mode_dict.items():
|
42 |
+
self.model_paths[key] = self.get_latest_model_path_by_mode(key)
|
43 |
+
self.load_models()
|
44 |
+
|
45 |
+
def get_latest_model_path_by_mode(self, mode):
|
46 |
+
model_dir_path = os.path.join(ROOT_PATH, "models", mode_dict[mode])
|
47 |
+
return utils.latest_checkpoint_path(model_dir_path, "G_*.pth")
|
48 |
+
|
49 |
+
def infer(self, text, mode, noise_scale=.667, noise_scale_w=0.8, length_scale=1):
|
50 |
+
print(self.pth_path)
|
51 |
+
stn_tst = get_text(text, self.hps)
|
52 |
+
with torch.no_grad():
|
53 |
+
x_tst = stn_tst.unsqueeze(0).to(self.device)
|
54 |
+
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(self.device)
|
55 |
+
audio = self.net_g.infer(x_tst, x_tst_lengths, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale)[0][0,0].data.float().numpy()
|
56 |
+
return (self.hps.data.sampling_rate, audio)
|
57 |
+
|
58 |
+
def change_mode(self, mode):
|
59 |
+
self.select_mode(mode)
|
60 |
+
return gr.update(choices=self.models, value=os.path.basename(self.latest_model_path))
|
61 |
+
|
62 |
+
def change_model(self, model_file_name):
|
63 |
+
self.load_model(os.path.join(self.model_dir_path, model_file_name))
|
64 |
+
return f"载入模型:{model_file_name}({self.mode})"
|
65 |
+
|
66 |
+
def render(self):
|
67 |
+
choice_mode = gr.Radio(choices=["普通声线", "营业声线"], label="声线选择", value=default_mode)
|
68 |
+
# with gr.Row():
|
69 |
+
# advanced = gr.Checkbox(label="显示高级设置(效果不可控)")
|
70 |
+
# default = gr.Button("恢复默认设置").style(full_width=False)
|
71 |
+
noise_scale = gr.Slider(minimum=0, maximum=3, value=default_noise_scale, step=0.001, label="noise_scale(效果不可控,谨慎修改)")
|
72 |
+
noise_scale_w = gr.Slider(minimum=0, maximum=3, value=default_noise_scale_w, step=0.001, label="noise_scale_w(效果不可控,谨慎修改)")
|
73 |
+
length_scale = gr.Slider(minimum=0, maximum=3, value=default_length_scale, step=0.001, label="length_scale(数值越大输出音频越长)")
|
74 |
+
|
75 |
+
tts_input = gr.TextArea(
|
76 |
+
label="请输入文本(目前只支持汉字和单个英文字母,可以使用常用符号和空格来改变语调和停顿,请勿一次性输入过长文本)",
|
77 |
+
value="这里是爱喝奶茶,穿得也像奶茶魅力点是普通话二乙的星弥吼西咪,晚上齁。")
|
78 |
+
tts_submit = gr.Button("合成", variant="primary")
|
79 |
+
tts_output = gr.Audio(label="Output")
|
80 |
+
gr.HTML('''
|
81 |
+
<div style="text-align:right;font-size:12px;color:#4D4D4D">
|
82 |
+
<div class="font-medium">版权声明</div>
|
83 |
+
<div>本项目数据集和模型版权属于星弥Hoshimi</div>
|
84 |
+
<div>仅供学习交流,不可用于任何商业和非法用途,否则后果自负</div>
|
85 |
+
</div>
|
86 |
+
''')
|
87 |
+
# advanced.change(fn=lambda visible: gr.update(visible=visible), inputs=advanced, outputs=noise_scale)
|
88 |
+
# advanced.change(fn=lambda visible: gr.update(visible=visible), inputs=advanced, outputs=noise_scale_w)
|
89 |
+
# default.click(fn=lambda visible: gr.update(value=default_noise_scale), inputs=advanced, outputs=noise_scale)
|
90 |
+
# default.click(fn=lambda visible: gr.update(value=default_noise_scale_w), inputs=advanced, outputs=noise_scale_w)
|
91 |
+
# default.click(fn=lambda visible: gr.update(value=default_length_scale), inputs=advanced, outputs=length_scale)
|
92 |
+
tts_submit.click(self.infer, [choice_mode, tts_input, noise_scale, noise_scale_w, length_scale], [tts_output], api_name=f"infer")
|
93 |
+
|
94 |
+
def load_models(self):
|
95 |
+
for key, model_path in self.latest_model_paths.items():
|
96 |
+
self.models[key] = SynthesizerTrn(
|
97 |
+
len(symbols),
|
98 |
+
self.hps.data.filter_length // 2 + 1,
|
99 |
+
self.hps.train.segment_size // self.hps.data.hop_length,
|
100 |
+
**self.hps.model).to(self.device)
|
101 |
+
_ = self.models[key].eval()
|
102 |
+
_ = utils.load_checkpoint(model_path, self.models[key], None)
|
103 |
+
|
104 |
+
def __del__(self):
|
105 |
+
print("del")
|
106 |
+
del self.net_g
|
107 |
+
self.net_g = None
|