import os

import torch
import gradio as gr

from vits import VITS_ROOT_PATH, commons, utils
from vits.models import SynthesizerTrn
from vits.text.symbols import symbols
from vits.text import text_to_sequence

# Voice modes exposed in the UI: "普通声线" = normal voice, "营业声线" = formal ("business") voice.
mode_dict = {
    "普通声线": "normal",
    "营业声线": "formal",
}
default_mode = "普通声线"
default_noise_scale = 0.667
default_noise_scale_w = 0.8
default_length_scale = 1

# English tokens the Chinese text cleaner cannot pronounce are rewritten as
# Chinese phonetic approximations. Longer tokens come first so that e.g.
# "hoshimi" is matched before its substrings "ho", "shi", and "mi".
replace_list = [
    ("candle", "刊豆"),
    ("end", "按的"),
    ("hoshimi", "吼西咪"),
    ("mua", "木啊"),
    ("hsm", "吼西咪"),
    ("ho", "齁"),
    ("na", "呐"),
    ("shi", "西"),
    ("mi", "咪"),
]


def get_text(text, hps):
    # Normalize the text, convert it to a sequence of symbol IDs, and
    # optionally intersperse blank tokens (ID 0) as the model expects.
    text = preprocess_text(text)
    text_norm = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = torch.LongTensor(text_norm)
    return text_norm


def preprocess_text(text):
    # Lower-case the input, then apply the phonetic substitutions,
    # e.g. "Hoshimi" -> "吼西咪".
    text = text.lower()
    for src, dst in replace_list:
        text = text.replace(src, dst)
    return text


class VitsInferencer:
    def __init__(self, hps_path, device="cpu"):
        print("Initializing VitsInferencer")
        self.device = torch.device(device)
        self.hps = utils.get_hparams_from_file(hps_path)
        self.model_paths = {}
        self.models = {}
        for mode in mode_dict:
            self.model_paths[mode] = self.get_latest_model_path_by_mode(mode)
        self.load_models()

    def get_latest_model_path_by_mode(self, mode):
        model_dir_path = os.path.join(VITS_ROOT_PATH, "models", mode_dict[mode])
        return utils.latest_checkpoint_path(model_dir_path, "G_*.pth")

    def infer(self, text, mode, noise_scale=default_noise_scale,
              noise_scale_w=default_noise_scale_w, length_scale=default_length_scale):
        stn_tst = get_text(text, self.hps)
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0).to(self.device)
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(self.device)
            audio = self.models[mode].infer(
                x_tst,
                x_tst_lengths,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )[0][0, 0].data.float().cpu().numpy()
        return (self.hps.data.sampling_rate, audio)

    def change_mode(self, mode):
        # All models are preloaded in __init__, so switching modes only needs
        # to refresh the displayed checkpoint name for the selected mode.
        return gr.update(value=os.path.basename(self.model_paths[mode]))

    def render(self):
        choice_mode = gr.Radio(choices=list(mode_dict), label="声线选择", value=default_mode)
        noise_scale = gr.Slider(minimum=0, maximum=3, value=default_noise_scale, step=0.001,
                                label="noise_scale(效果不可控,谨慎修改)")
        noise_scale_w = gr.Slider(minimum=0, maximum=3, value=default_noise_scale_w, step=0.001,
                                  label="noise_scale_w(效果不可控,谨慎修改)")
        length_scale = gr.Slider(minimum=0, maximum=3, value=default_length_scale, step=0.001,
                                 label="length_scale(数值越大输出音频越长)")
        tts_input = gr.TextArea(
            label="请输入文本(目前只支持汉字、单个英文字母和极个别专有名词,可以使用常用符号和空格来改变语调和停顿,请勿一次性输入过长文本)",
            value="这里是爱喝奶茶,穿得也像奶茶魅力点是普通话二乙的星弥Hoshimi,晚上Ho")
        tts_submit = gr.Button("合成", variant="primary")
        tts_output = gr.Audio(label="Output")
        # Copyright notice (Chinese): the dataset and models belong to 星弥Hoshimi;
        # for learning and exchange only, not for commercial or illegal use.
        gr.HTML('''
版权声明
本项目数据集和模型版权属于星弥Hoshimi
仅供学习交流,不可用于任何商业和非法用途,否则后果自负
''')
        tts_submit.click(
            self.infer,
            [tts_input, choice_mode, noise_scale, noise_scale_w, length_scale],
            [tts_output],
            api_name="infer",
        )

    def load_models(self):
        # Build one generator per voice mode and load its latest checkpoint.
        # Only the generator (SynthesizerTrn) is needed for inference, so no
        # optimizer state is restored (hence the trailing None).
        for key, model_path in self.model_paths.items():
            self.models[key] = SynthesizerTrn(
                len(symbols),
                self.hps.data.filter_length // 2 + 1,
                self.hps.train.segment_size // self.hps.data.hop_length,
                **self.hps.model,
            ).to(self.device)
            _ = self.models[key].eval()
            _ = utils.load_checkpoint(model_path, self.models[key], None)
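
# Usage sketch (not part of the original file; the config path below is an
# assumption). render() only creates Gradio components, so it must be called
# inside an active gr.Blocks context; the click handler is wired on creation.
if __name__ == "__main__":
    inferencer = VitsInferencer(os.path.join(VITS_ROOT_PATH, "configs", "config.json"))
    with gr.Blocks() as demo:
        inferencer.render()
    demo.launch()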