CarlDennis's picture
Upload app.py
a0c929e
raw history blame
No virus
4.06 kB
import re
import gradio as gr
import torch
import unicodedata
import commons
import utils
from models import SynthesizerTrn
from text import text_to_sequence
config_json = "muse_tricolor_b.json"  # hyperparameter/config file for the synthesizer
pth_path = "G=496.pth"  # generator checkpoint weights loaded into the model
def get_text(text, hps, cleaned=False):
    """Convert input text into a LongTensor of symbol ids for the model.

    When *cleaned* is True the text is assumed to be already normalized and
    no cleaner pipeline is applied; otherwise the cleaners named in the
    hparams are run.
    """
    cleaners = [] if cleaned else hps.data.text_cleaners
    seq = text_to_sequence(text, hps.symbols, cleaners)
    if hps.data.add_blank:
        # interleave a blank token (id 0) between symbols, matching training
        seq = commons.intersperse(seq, 0)
    return torch.LongTensor(seq)
def get_label(text, label):
    """Report whether ``[label]`` occurs in *text* and strip it if present.

    Returns a ``(found, text)`` tuple; *text* is unchanged when the tag is
    absent.
    """
    tag = f'[{label}]'
    if tag not in text:
        return False, text
    return True, text.replace(tag, '')
def clean_text(text):
    """NFKC-normalize *text* and wrap it in [JA]/[ZH] language tags.

    The language is chosen by whether any hiragana/katakana character is
    present after normalization.
    """
    print(text)  # debug trace of the raw user input
    kana = re.compile(r'[\u3040-\u309F\u30A0-\u30FF]')  # hiragana + katakana ranges
    text = unicodedata.normalize('NFKC', text)
    if kana.search(text):
        return f"[JA]{text}[JA]"
    return f"[ZH]{text}[ZH]"
def load_model(config_json, pth_path):
    """Build a SynthesizerTrn from a config file and load checkpoint weights.

    Returns the model in eval mode, placed on CUDA when available.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    hps = utils.get_hparams_from_file(f"{config_json}")
    # Older configs may omit these fields; fall back to 0 in that case.
    speaker_count = hps.data.n_speakers if 'n_speakers' in hps.data.keys() else 0
    symbol_count = len(hps.symbols) if 'symbols' in hps.keys() else 0
    model = SynthesizerTrn(
        symbol_count,
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=speaker_count,
        **hps.model,
    ).to(device)
    model.eval()
    utils.load_checkpoint(pth_path, model)
    return model
# Load the synthesizer once at import time; reused by every infer() call.
net_g_ms = load_model(config_json, pth_path)
def selection(speaker):
    """Map a speaker display name to the model's integer speaker id.

    Replaces a nine-branch if/elif chain with a dict lookup. Returns None
    for unknown names, matching the original chain's implicit fall-through
    (callers that int() the result will still fail the same way on bad
    input).
    """
    speaker_ids = {
        "南小鸟": 0,
        "园田海未": 1,
        "小泉花阳": 2,
        "星空凛": 3,
        "东条希": 4,
        "矢泽妮可": 5,
        "绚濑绘里": 6,
        "西木野真姬": 7,
        "高坂穗乃果": 8,
    }
    return speaker_ids.get(speaker)
def infer(text, speaker_id, n_scale=0.667, n_scale_w=0.8, l_scale=1):
    """Synthesize speech for *text* spoken by *speaker_id*.

    n_scale / n_scale_w control the noise scales, l_scale the duration
    scale. Returns a ``(sampling_rate, waveform)`` tuple suitable for a
    Gradio Audio output.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    hps = utils.get_hparams_from_file(f"{config_json}")
    tagged = clean_text(text)
    sid_index = int(selection(speaker_id))
    with torch.no_grad():
        seq = get_text(tagged, hps, cleaned=False)
        x = seq.unsqueeze(0).to(device)
        x_lengths = torch.LongTensor([seq.size(0)]).to(device)
        sid = torch.LongTensor([sid_index]).to(device)
        result = net_g_ms.infer(
            x, x_lengths, sid=sid,
            noise_scale=n_scale,
            noise_scale_w=n_scale_w,
            length_scale=l_scale,
        )
        audio = result[0][0, 0].data.cpu().float().numpy()
    return (hps.data.sampling_rate, audio)
# Speaker display names; list order must match the ids used by selection().
idols = ["南小鸟", "园田海未", "小泉花阳", "星空凛", "东条希", "矢泽妮可", "绚濑绘里", "西木野真姬", "高坂穗乃果"]

# Build the Gradio UI: a single "Basic" tab with text input, three synthesis
# sliders, a speaker dropdown, and an audio output wired to infer().
app = gr.Blocks()
with app:
    with gr.Tabs():
        with gr.TabItem("Basic"):
            text_input = gr.TextArea(label="请输入纯中文或纯日文", value="大家好")
            noise_slider = gr.Slider(minimum=0.01, maximum=1.0, label="更改噪声比例", value=0.667)
            noise_w_slider = gr.Slider(minimum=0.01, maximum=1.0, label="更改噪声偏差", value=0.8)
            length_slider = gr.Slider(minimum=0.1, maximum=10, label="更改时间比例", value=1)
            generate_button = gr.Button("Generate", variant="primary")
            speaker_dropdown = gr.Dropdown(label="选择说话人", choices=idols, value="高坂穗乃果", interactive=True)
            audio_output = gr.Audio(label="Output")
            generate_button.click(
                infer,
                [text_input, speaker_dropdown, noise_slider, noise_w_slider, length_slider],
                [audio_output],
            )
app.launch()