import gradio as gr
import torch

import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence
def get_text(text, hps):
    # Convert raw text to a sequence of symbol IDs using the configured cleaners.
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        # Insert a blank token between symbols when the config requests it.
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm
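
# A minimal illustration (not part of the original app): commons.intersperse from
# the VITS repo places the blank token between and around the symbol IDs, so a
# sequence of length n becomes length 2n + 1, e.g.
#   commons.intersperse([5, 9, 7], 0)  ->  [0, 5, 0, 9, 0, 7, 0]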
def load_model(config_path, pth_path):
    global dev, hps, net_g
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    hps = utils.get_hparams_from_file(config_path)
    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model).to(dev)
    _ = net_g.eval()
    _ = utils.load_checkpoint(pth_path, net_g)
    print(f"{pth_path} loaded successfully!")
def infer(c_name, text):
    c_id = character_dict[c_name]
    stn_tst = get_text(text, hps)
    with torch.no_grad():
        x_tst = stn_tst.to(dev).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
        sid = torch.LongTensor([c_id]).to(dev)  # speaker ID for the multi-speaker model
        audio = net_g.infer(x_tst, x_tst_lengths, sid=sid,
                            noise_scale=.667, noise_scale_w=0.8,
                            length_scale=1)[0][0, 0].data.cpu().float().numpy()
    return (hps.data.sampling_rate, audio)
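
# A usage sketch outside the Gradio UI (assumption: scipy is available; the
# character name must be a key of character_dict defined below):
#   from scipy.io import wavfile
#   sr, wav = infer("夜刀神十香", "こんにちは。")
#   wavfile.write("output.wav", sr, wav)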
pth_path = "model/G_215000.pth"
config_path = "configs/config.json"
character_dict = {
    "夜刀神十香": 1,
    "鸢一折纸": 2,
    "时崎狂三": 3,
    "冰芽川四糸乃": 4,
    "五河琴里": 5,
    "八舞夕弦": 6,
    "八舞耶俱矢": 7,
    "诱宵美九": 8,
    "园神凛祢": 9,
    "园神凛绪": 11,
    "或守鞠亚": 12,
    "或守鞠奈": 13,
    "崇宫真那": 14,
}  # note: speaker ID 10 is not mapped
load_model(config_path, pth_path)
app = gr.Blocks()
with app:
    with gr.Tabs():
        with gr.Row():
            text = gr.TextArea(
                label="Enter text (Japanese only)", value="こんにちは,世界!")
        with gr.Row():
            radio = gr.Radio(list(character_dict.keys()),
                             label="Select a character")
        with gr.Row():
            tts_submit = gr.Button("Synthesize", variant="primary")
        with gr.Row():
            tts_output = gr.Audio(label="Output")
        tts_submit.click(infer, [radio, text], [tts_output])
        radio.change(infer, [radio, text], [tts_output])
gr.HTML("""
<div style="text-align:center">
<h4 class="h-sign" style="font-size: 12px;">
这是一个使用<a href="https://github.com/thesupersonic16/DALTools" target="_blank">thesupersonic16/DALTools</a>提供的解包音频作为数据集,
使用<a href="https://github.com/jaywalnut310/vits" target="_blank">VITS</a>技术训练的语音合成demo(215k step)。
</h4>
</div>
</div>
<div style="text-align:center">
仅供学习交流,不可用于商业或非法用途
<br/>
使用本项目模型直接或间接生成的音频,必须声明由AI技术或VITS技术合成
</div>
""")
app.launch()
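
# Assumption (not in the original app): when running locally, Gradio can expose
# a temporary public URL via the share flag, e.g.
#   app.launch(share=True)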