File size: 3,754 Bytes
69c7b60
06441c0
 
 
 
 
 
 
 
 
 
69c7b60
06441c0
 
 
 
 
 
 
69c7b60
151afe1
06441c0
 
 
 
 
 
 
 
 
 
 
7c27bc6
06441c0
7c27bc6
06441c0
 
 
7c27bc6
 
 
 
06441c0
7c27bc6
06441c0
 
3d7a6b8
06441c0
0281d5f
 
7c27bc6
06441c0
 
 
 
 
 
7c27bc6
f7c1f1e
358130b
06441c0
f7c1f1e
8bfdbcf
 
 
 
 
 
 
 
 
4924b5c
 
 
 
f7c1f1e
 
e3eef83
4c7f733
 
8d569d2
f7c1f1e
 
7c27bc6
9161d4c
4d0831c
d34114c
3bf0ca5
8bfdbcf
d34114c
8bfdbcf
d34114c
9161d4c
4c7f733
9161d4c
 
83440e2
f7c1f1e
4d0831c
 
 
75caba3
4d0831c
 
 
f7c1f1e
 
 
 
 
 
8d569d2
b780be3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import gradio as gr
# import matplotlib.pyplot as plt
import logging
# logger = logging.getLogger(__name__)
import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence
import time


def get_text(text, hps):
    """Turn raw input text into the symbol-id tensor the synthesizer consumes.

    The text is converted with ``text_to_sequence`` using the cleaner names in
    ``hps.data.text_cleaners``. When ``hps.data.add_blank`` is truthy the id
    list is additionally passed through ``commons.intersperse(..., 0)``
    (presumably inserting blank id 0 between symbols — confirm in commons).

    Returns:
        A ``torch.LongTensor`` built from the resulting id list.
    """
    cleaner_names = hps.data.text_cleaners
    symbol_ids = text_to_sequence(text, cleaner_names)
    symbol_ids = commons.intersperse(symbol_ids, 0) if hps.data.add_blank else symbol_ids
    return torch.LongTensor(symbol_ids)

def load_model(config_path, pth_path):
    """Build the synthesizer, load its checkpoint, and publish it as globals.

    Side effects (module-level globals read later by ``infer``):
        dev:   ``cuda:0`` when CUDA is available, otherwise the CPU device.
        hps:   hyper-parameters parsed from *config_path*.
        net_g: a ``SynthesizerTrn`` moved to ``dev``, switched to eval mode,
               with the weights from *pth_path* loaded into it.

    Prints a confirmation line containing *pth_path* when done.
    """
    global dev, hps, net_g

    device_name = "cuda:0" if torch.cuda.is_available() else "cpu"
    dev = torch.device(device_name)
    hps = utils.get_hparams_from_file(config_path)

    data_cfg = hps.data
    # Presumably linear-spectrogram bins (n_fft // 2 + 1) and the training
    # segment length in frames — confirm against SynthesizerTrn's signature.
    n_bins = data_cfg.filter_length // 2 + 1
    segment_frames = hps.train.segment_size // data_cfg.hop_length

    generator = SynthesizerTrn(
        len(symbols),
        n_bins,
        segment_frames,
        n_speakers=data_cfg.n_speakers,
        **hps.model,
    )
    net_g = generator.to(dev)
    net_g.eval()
    utils.load_checkpoint(pth_path, net_g)

    print(f"{pth_path}加载成功!")
    
def infer(c_name, text, *, noise_scale=.667, noise_scale_w=0.8, length_scale=1):
    """Synthesize *text* in the voice of the character named *c_name*.

    Args:
        c_name: A key of ``character_dict`` (the radio selection). ``None`` or
            empty when the user has not chosen a character yet.
        text: Input text; converted with ``get_text`` using the loaded ``hps``.
        noise_scale, noise_scale_w, length_scale: Keyword-only sampling
            settings forwarded unchanged to ``net_g.infer``. The defaults are
            the values this function previously hard-coded.

    Returns:
        ``(hps.data.sampling_rate, audio)`` where *audio* is a float32 numpy
        array taken from ``net_g.infer(...)[0][0, 0]`` (presumably the 1-D
        waveform — confirm against the model). Returns ``None`` when no
        character is selected or the text is blank.
    """
    # Bail out before touching any model state. The synthesize button can be
    # pressed before a radio option is chosen, making c_name None; the
    # dictionary lookup below would then raise KeyError. Blank text has
    # nothing to synthesize.
    if not c_name or not text or not text.strip():
        return None

    c_id = character_dict[c_name]
    stn_tst = get_text(text, hps)
    with torch.no_grad():
        # Add a batch dimension of 1 for a single utterance.
        x_tst = stn_tst.to(dev).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(dev)
        sid = torch.LongTensor([c_id]).to(dev)
        audio = net_g.infer(
            x_tst,
            x_tst_lengths,
            sid=sid,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )[0][0, 0].data.cpu().float().numpy()

    return (hps.data.sampling_rate, audio)

# Relative paths to the generator checkpoint and its training config.
pth_path = "model/G_215000.pth"
config_path = "configs/config.json"

# (display name, speaker id) pairs offered in the UI, in display order.
# NOTE(review): speaker id 10 is not listed; presumably deliberate — confirm
# against the speaker list used in training.
_CHARACTER_ROWS = (
    ("夜刀神十香", 1),
    ("鸢一折纸", 2),
    ("时崎狂三", 3),
    ("冰芽川四糸乃", 4),
    ("五河琴里", 5),
    ("八舞夕弦", 6),
    ("八舞耶俱矢", 7),
    ("诱宵美九", 8),
    ("园神凛祢", 9),
    ("园神凛绪", 11),
    ("或守鞠亚", 12),
    ("或守鞠奈", 13),
    ("崇宫真那", 14),
)
# Display name -> speaker id. The keys become the radio choices; the values
# are passed to the model as `sid` by infer().
character_dict = dict(_CHARACTER_ROWS)


# Load the model once at import time; infer() reuses the resulting globals.
load_model(config_path, pth_path)

# Web UI: text input, character picker, synthesize button and audio output.
app = gr.Blocks()
with app:
    with gr.Tabs():
        with gr.Row():
            text = gr.TextArea(
                label="请输入文本(仅支持日语)", value="こんにちは,世界!")
        with gr.Row():
            radio = gr.Radio(list(character_dict.keys()), 
                label="请选择角色")
        with gr.Row():
            tts_submit = gr.Button("合成", variant="primary")
        with gr.Row():
            tts_output = gr.Audio(label="Output")
        tts_submit.click(infer, [radio, text], [tts_output])
        # Re-synthesize immediately whenever a different character is picked.
        radio.change(infer, [radio, text], [tts_output])

    # Footer credits / usage notice. The stray unmatched closing </div> that
    # used to follow the first <div> has been removed.
    gr.HTML("""
    <div style="text-align:center">
                    <h4 class="h-sign" style="font-size: 12px;">
                        这是一个使用<a href="https://github.com/thesupersonic16/DALTools" target="_blank">thesupersonic16/DALTools</a>提供的解包音频作为数据集,
                        使用<a href="https://github.com/jaywalnut310/vits" target="_blank">VITS</a>技术训练的语音合成demo(215k step)。
                    </h4>
                </div>
<div style="text-align:center"> 
    仅供学习交流,不可用于商业或非法用途
    <br/>
    使用本项目模型直接或间接生成的音频,必须声明由AI技术或VITS技术合成
</div>
    """)

# Launch only after the `with app:` context has closed. This is Gradio's
# documented pattern: the Blocks layout and event wiring are finalized on
# context exit.
app.launch()