# MyGO_VIts-bert / app.py
# Uploaded by Mahiruoshi — commit 9169788 ("Upload 120 files"), 10.6 kB.
import argparse
import os
from pathlib import Path
import logging
import re_matching
# Keep chatty third-party libraries quiet: only WARNING and above from them.
for _noisy in ("numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(_noisy).setLevel(logging.WARNING)
del _noisy

# Root handler for this app: "| logger name | LEVEL | message" at INFO level.
logging.basicConfig(
    format="| %(name)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
import librosa
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import Wav2Vec2Processor
from transformers.models.wav2vec2.modeling_wav2vec2 import (
Wav2Vec2Model,
Wav2Vec2PreTrainedModel,
)
import gradio as gr
import utils
from config import config
import torch
import commons
from text import cleaned_text_to_sequence, get_bert
from emo_gen import process_func, EmotionModel, Wav2Vec2Processor, Wav2Vec2Model, Wav2Vec2PreTrainedModel, RegressionHead
from text.cleaner import clean_text
import utils
from models import SynthesizerTrn
from text.symbols import symbols
import sys
# Synthesizer instance; assigned in the __main__ section and read by infer()/loadmodel().
net_g = None

# Pick the compute device: CUDA if available, else Apple MPS on macOS, else CPU.
if torch.cuda.is_available():
    device = "cuda:0"
elif sys.platform == "darwin" and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# UI layout: one tab per band (key), one nested tab per character name (values).
BandList = {
    "MyGo&AveMujica(Part)": ["燈", "愛音", "そよ", "立希", "楽奈"],
    "AveMujica": ["祥子", "睦", "海鈴", "にゃむ", "初華"],
}
def get_net_g(model_path: str, version: str, device: str, hps):
    """Build a ``SynthesizerTrn`` from *hps*, load *model_path* into it, return it.

    The model is moved to *device* and put in eval mode before the checkpoint
    is loaded (optimizer state is skipped). *version* is accepted for
    interface compatibility but is not used.
    """
    spec_channels = hps.data.filter_length // 2 + 1
    segment_frames = hps.train.segment_size // hps.data.hop_length
    model = SynthesizerTrn(
        len(symbols),
        spec_channels,
        segment_frames,
        n_speakers=hps.data.n_speakers,
        **hps.model,
    ).to(device)
    model.eval()
    utils.load_checkpoint(model_path, model, None, skip_optimizer=True)
    return model
def get_text(text, language_str, hps, device):
    """Turn raw *text* into model inputs for *language_str* ("ZH", "JP" or "EN").

    Returns ``(bert, ja_bert, en_bert, phone, tone, language)``. Of the three
    feature tensors, the one matching *language_str* is the output of
    ``get_bert`` (checked to have ``len(phone)`` as its last dimension); the
    other two are ``torch.zeros(1024, len(phone))``. The last three items are
    ``LongTensor`` versions of the phone, tone and language id sequences.

    Raises ValueError for any other *language_str* (after BERT extraction),
    and AssertionError if the BERT length does not match the phone count.
    """
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
    print(text)

    if hps.data.add_blank:
        # Interleave 0 between ids (commons.intersperse); the word-to-phone
        # counts are doubled and the first bumped by one -- presumably so they
        # stay aligned with the lengthened phone sequence.
        phone, tone, language = (
            commons.intersperse(seq, 0) for seq in (phone, tone, language)
        )
        word2ph[:] = [count * 2 for count in word2ph]
        word2ph[0] += 1

    bert_ori = get_bert(norm_text, word2ph, language_str, device)
    n_phones = len(phone)
    assert bert_ori.shape[-1] == n_phones, phone

    # Exactly one of the (bert, ja_bert, en_bert) slots gets the real features.
    slot_for_language = {"ZH": 0, "JP": 1, "EN": 2}
    if language_str not in slot_for_language:
        raise ValueError("language_str should be ZH, JP or EN")
    real_slot = slot_for_language[language_str]
    bert, ja_bert, en_bert = (
        bert_ori if slot == real_slot else torch.zeros(1024, n_phones)
        for slot in range(3)
    )
    assert bert.shape[-1] == n_phones, f"Bert seq len {bert.shape[-1]} != {n_phones}"

    return (
        bert,
        ja_bert,
        en_bert,
        torch.LongTensor(phone),
        torch.LongTensor(tone),
        torch.LongTensor(language),
    )
def get_emo_(reference_audio, emotion):
    """Return the emotion conditioning tensor passed on to ``net_g.infer``.

    With a truthy *reference_audio* (an audio file path), the embedding is
    computed by ``get_emo``; otherwise *emotion* is wrapped in a one-element
    float tensor.
    """
    if not reference_audio:
        return torch.Tensor([emotion])
    return torch.from_numpy(get_emo(reference_audio))
def get_emo(path):
    """Compute an emotion embedding for the audio file at *path*.

    The file is loaded with librosa at 16 kHz, given a leading batch axis, and
    passed to ``process_func`` together with the module-level
    ``emotional_model`` / ``emotional_processor`` (created in the ``__main__``
    section) with ``embeddings=True``. The leading axis of the result is
    squeezed away before returning.
    """
    # ``sr`` must be a keyword: since librosa 0.10 every parameter after
    # ``path`` is keyword-only, so ``librosa.load(path, 16000)`` raises TypeError.
    wav, sr = librosa.load(path, sr=16000)
    # NOTE(review): this uses the bert_gen config device, not the module-level
    # ``device`` the emotion model was moved to at start-up -- confirm they agree
    # (or that process_func relocates the model).
    device = config.bert_gen_config.device
    return process_func(
        # ``np.float`` (an alias of builtin ``float``, i.e. float64) was removed
        # in NumPy 1.24; ``np.float64`` is the identical dtype.
        np.expand_dims(wav, 0).astype(np.float64),
        sr,
        emotional_model,
        emotional_processor,
        device,
        embeddings=True,
    ).squeeze(0)
def infer(
    text,
    sdp_ratio,
    noise_scale,
    noise_scale_w,
    length_scale,
    sid,
    reference_audio=None,
    emotion=None,
):
    """Synthesize *text* for speaker *sid*; return ``(sampling_rate, audio)``.

    The language is "JP" when ``is_japanese(text)`` is true, otherwise "ZH"
    (the EN path of ``get_text`` is never chosen here). Emotion conditioning
    comes from *reference_audio* if given, else from *emotion*. The remaining
    scalars are forwarded unchanged to ``net_g.infer``. Uses the module-level
    ``hps``, ``net_g`` and ``device``. ``audio`` is a float32 numpy array.
    """
    lang = "JP" if is_japanese(text) else "ZH"
    bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(text, lang, hps, device)
    emo = get_emo_(reference_audio, emotion)

    with torch.no_grad():
        # Move each per-utterance tensor to the device and add a batch axis.
        x_tst, tones, lang_ids, bert, ja_bert, en_bert, emo = (
            tensor.to(device).unsqueeze(0)
            for tensor in (phones, tones, lang_ids, bert, ja_bert, en_bert, emo)
        )
        x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
        del phones
        speakers = torch.LongTensor([hps.data.spk2id[sid]]).to(device)

        outputs = net_g.infer(
            x_tst,
            x_tst_lengths,
            speakers,
            tones,
            lang_ids,
            bert,
            ja_bert,
            en_bert,
            emo,
            sdp_ratio=sdp_ratio,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=length_scale,
        )
        # Element [0, 0] of the first output, copied to host as float32 numpy.
        audio = outputs[0][0, 0].data.cpu().float().numpy()

        # Drop device-side references before releasing cached CUDA memory.
        del outputs, x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert, emo
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return hps.data.sampling_rate, audio
def is_japanese(string):
    """Return True if *string* contains at least one kana character.

    A character counts when its code point lies in the Hiragana or Katakana
    Unicode blocks (U+3040..U+30FF, inclusive). Kanji are shared with Chinese
    and are not checked, so a kanji-only string is reported as not Japanese.
    """
    # Inclusive bounds: the former exclusive test (> 0x3040, < 0x30FF) skipped
    # U+30FF (KATAKANA DIGRAPH KOTO).
    return any(0x3040 <= ord(ch) <= 0x30FF for ch in string)
def loadmodel(model):
    """Load the checkpoint at path *model* into the module-level ``net_g``.

    Wired to the "载入模型" button. Always returns ``"success"``; any
    exception raised while loading propagates to the caller.
    """
    generator = net_g
    generator.eval()
    utils.load_checkpoint(model, generator, None, skip_optimizer=True)
    return "success"
if __name__ == "__main__":
    # Emotion-embedding model loaded from a local directory. get_emo() reads
    # `emotional_model` / `emotional_processor` as module globals, so they are
    # assigned here at module scope.
    emotional_model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
    # Hub id of the same model; not referenced elsewhere in this file.
    REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
    emotional_processor = Wav2Vec2Processor.from_pretrained(emotional_model_name)
    emotional_model = EmotionModel.from_pretrained(emotional_model_name).to(device)

    # `hps` and `net_g` are read as globals by infer() and loadmodel().
    hps = utils.get_hparams_from_file('Data/BanGDream/configs/config.json')
    # NOTE(review): this path spells the directory "BangDream" while the config
    # path above and the model scan below use "BanGDream"; on a case-sensitive
    # filesystem those are different directories -- confirm which one exists.
    net_g = get_net_g(
        model_path='Data/BangDream/models/G_49000.pth', version="2.1", device=device, hps=hps
    )
    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
    languages = ["Auto", "ZH", "JP"]  # not used by the UI below

    # Every file under the models directory is offered in the model dropdown.
    # Sorted so the default (first) entry does not depend on os.walk order.
    modelPaths = sorted(
        os.path.join(dirpath, filename)
        for dirpath, _dirnames, filenames in os.walk("Data/BanGDream/models/")
        for filename in filenames
    )

    with gr.Blocks() as app:
        # One top-level tab per band, one nested tab per character.
        for band in BandList:
            with gr.TabItem(band):
                for name in BandList[band]:
                    with gr.TabItem(name):
                        with gr.Row():
                            with gr.Column():
                                with gr.Row():
                                    gr.Markdown(
                                        '<div align="center">'
                                        f'<img style="width:auto;height:400px;" src="file/image/{name}.png">'
                                        '</div>'
                                    )
                                length_scale = gr.Slider(
                                    minimum=0.1, maximum=2, value=1, step=0.01, label="语速调节"
                                )
                                emotion = gr.Slider(
                                    minimum=0, maximum=9, value=0, step=1, label="Emotion"
                                )
                                with gr.Accordion(label="参数设定", open=False):
                                    sdp_ratio = gr.Slider(
                                        minimum=0, maximum=1, value=0.2, step=0.01, label="SDP/DP混合比"
                                    )
                                    noise_scale = gr.Slider(
                                        minimum=0.1, maximum=2, value=0.6, step=0.01, label="感情调节"
                                    )
                                    noise_scale_w = gr.Slider(
                                        minimum=0.1, maximum=2, value=0.8, step=0.01, label="音素长度"
                                    )
                                    # Preselect this tab's character as the speaker.
                                    speaker = gr.Dropdown(
                                        choices=speakers, value=name, label="说话人"
                                    )
                                with gr.Accordion(label="切换模型", open=False):
                                    modelstrs = gr.Dropdown(
                                        label="模型",
                                        choices=modelPaths,
                                        # An empty/missing models directory used to
                                        # crash here with IndexError.
                                        value=modelPaths[0] if modelPaths else None,
                                        type="value",
                                    )
                                    btnMod = gr.Button("载入模型")
                                    statusa = gr.TextArea()
                                    btnMod.click(loadmodel, inputs=[modelstrs], outputs=[statusa])
                            with gr.Column():
                                text = gr.TextArea(
                                    label="输入纯日语或者中文",
                                    placeholder="输入纯日语或者中文",
                                    value="为什么要演奏春日影!",
                                )
                                reference_audio = gr.Audio(label="情感参考音频(WAV 格式):用于生成语音的情感参考。(WAV 格式)", type="filepath")
                                btn = gr.Button("点击生成", variant="primary")
                                audio_output = gr.Audio(label="Output Audio")
                        # Input order must match infer()'s positional parameters.
                        btn.click(
                            infer,
                            inputs=[
                                text,
                                sdp_ratio,
                                noise_scale,
                                noise_scale_w,
                                length_scale,
                                speaker,
                                reference_audio,
                                emotion,
                            ],
                            outputs=[audio_output],
                        )
    print("推理页面已开启!")
    app.launch()