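"""Gradio inference app for a two-voice VITS text-to-speech model.

Loads the newest generator checkpoint (G_*.pth) for each voice mode and
exposes a simple synthesis UI (text box, sampling sliders, audio output)."""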
import os

import torch
from vits import ROOT_PATH

import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence
import gradio as gr

# Maps the voice-mode label shown in the UI to its model directory name.
mode_dict = {
    "Normal voice": "normal",
    "Formal voice": "formal"
}

default_mode = "Normal voice"
default_noise_scale = 0.667
default_noise_scale_w = 0.8
default_length_scale = 1

def get_text(text, hps):
    """Convert raw text to a LongTensor of symbol IDs, interspersing blank
    tokens between symbols when the model was trained with add_blank."""
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm

class VitsInferencer:
    def __init__(self, hps_path, device="cpu"):
        print("Initializing VitsInferencer")
        self.device = torch.device(device)
        self.hps = utils.get_hparams_from_file(hps_path)
        self.model_paths = {}
        self.models = {}
        # Resolve the newest checkpoint for each voice mode, then load them all.
        for mode in mode_dict:
            self.model_paths[mode] = self.get_latest_model_path_by_mode(mode)
        self.load_models()

    def get_latest_model_path_by_mode(self, mode):
        model_dir_path = os.path.join(ROOT_PATH, "models", mode_dict[mode])
        return utils.latest_checkpoint_path(model_dir_path, "G_*.pth")

    def infer(self, text, mode, noise_scale=.667, noise_scale_w=0.8, length_scale=1):
        stn_tst = get_text(text, self.hps)
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0).to(self.device)
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(self.device)
            # Move the waveform to CPU before converting to numpy so inference
            # also works when the model lives on a CUDA device.
            audio = self.models[mode].infer(
                x_tst, x_tst_lengths,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale)[0][0, 0].data.cpu().float().numpy()
        return (self.hps.data.sampling_rate, audio)

    def change_mode(self, mode):
        # Re-resolve the newest checkpoint for the selected voice mode, reload
        # the models, and report the checkpoint file now in use.
        self.model_paths[mode] = self.get_latest_model_path_by_mode(mode)
        self.load_models()
        return gr.update(value=os.path.basename(self.model_paths[mode]))
    
    def render(self):
        choice_mode = gr.Radio(choices=list(mode_dict), label="Voice mode", value=default_mode)
        noise_scale = gr.Slider(minimum=0, maximum=3, value=default_noise_scale, step=0.001, label="noise_scale (effect is unpredictable; change with caution)")
        noise_scale_w = gr.Slider(minimum=0, maximum=3, value=default_noise_scale_w, step=0.001, label="noise_scale_w (effect is unpredictable; change with caution)")
        length_scale = gr.Slider(minimum=0, maximum=3, value=default_length_scale, step=0.001, label="length_scale (larger values produce longer audio)")

        # The default sentence stays in Chinese: the model only supports
        # Chinese characters and single English letters.
        tts_input = gr.TextArea(
            label="Input text (only Chinese characters and single English letters are supported; common punctuation and spaces adjust intonation and pauses; avoid very long input)",
            value="这里是爱喝奶茶,穿得也像奶茶魅力点是普通话二乙的星弥吼西咪,晚上齁。")
        tts_submit = gr.Button("Synthesize", variant="primary")
        tts_output = gr.Audio(label="Output")
        gr.HTML('''
            <div style="text-align:right;font-size:12px;color:#4D4D4D">
                <div class="font-medium">Copyright notice</div>
                <div>The dataset and models of this project belong to 星弥Hoshimi.</div>
                <div>For learning and exchange only; any commercial or illegal use is prohibited and at your own risk.</div>
            </div>
        ''')
        tts_submit.click(self.infer, [tts_input, choice_mode, noise_scale, noise_scale_w, length_scale], [tts_output], api_name="infer")
    
    def load_models(self):
        # Build one SynthesizerTrn per voice mode and load its checkpoint
        # weights in eval mode (no optimizer state is restored).
        for key, model_path in self.model_paths.items():
            self.models[key] = SynthesizerTrn(
                len(symbols),
                self.hps.data.filter_length // 2 + 1,
                self.hps.train.segment_size // self.hps.data.hop_length,
                **self.hps.model).to(self.device)
            _ = self.models[key].eval()
            _ = utils.load_checkpoint(model_path, self.models[key], None)
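

# Usage sketch (an assumption, not part of the original file): render() must
# run inside a gr.Blocks context so the components attach to the page. The
# hparams config path below is hypothetical; point it at the JSON actually
# used to train the checkpoints.
if __name__ == "__main__":
    inferencer = VitsInferencer(os.path.join(ROOT_PATH, "configs", "config.json"))
    with gr.Blocks() as demo:
        inferencer.render()
    demo.launch()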