|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
import hparams as hp |
|
import os |
|
import time |
|
|
|
# Pin CUDA device enumeration to PCI bus order so device indices are stable,
# then restrict visibility to the GPUs configured for synthesis in hparams.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"

os.environ["CUDA_VISIBLE_DEVICES"]=hp.synth_visible_devices
|
|
|
import argparse |
|
import re |
|
from string import punctuation |
|
|
|
from fastspeech2 import FastSpeech2 |
|
from vocoder import vocgan_generator |
|
|
|
from text import text_to_sequence, sequence_to_text |
|
import utils |
|
import audio as Audio |
|
|
|
import codecs |
|
from g2pk import G2p |
|
from jamo import h2j |
|
|
|
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
def kor_preprocess(text, g2p=None):
    """Convert a Korean sentence into a phoneme-id tensor for FastSpeech2.

    Args:
        text: raw Korean sentence.
        g2p: optional pre-built ``g2pk.G2p`` instance. Constructing G2p is
            expensive, so callers that synthesize repeatedly should pass a
            shared instance; ``None`` (the default) keeps the original
            behavior of building a fresh one per call.

    Returns:
        ``torch.LongTensor`` of shape (1, seq_len) on ``device``.
    """
    text = text.rstrip(punctuation)
    if g2p is None:
        g2p = G2p()
    phone = g2p(text)
    phone = h2j(phone)
    # Drop spaces, wrap each jamo in braces, and map any stray punctuation
    # token to the {sp} (short pause) symbol before cleaning.
    phone = list(filter(lambda p: p != ' ', phone))
    phone = '{' + '}{'.join(phone) + '}'
    phone = re.sub(r'\{[^\w\s]?\}', '{sp}', phone)
    phone = phone.replace('}{', ' ')

    sequence = np.array(text_to_sequence(phone, hp.text_cleaners))
    sequence = np.stack([sequence])  # add batch dimension
    return torch.from_numpy(sequence).long().to(device)
|
|
|
def get_FastSpeech2(num):
    """Load a FastSpeech2 checkpoint and return it as a frozen eval-mode model.

    Args:
        num: training step number identifying the checkpoint file.

    Returns:
        ``nn.DataParallel``-wrapped FastSpeech2 in eval mode with gradients
        disabled on all parameters.
    """
    checkpoint_path = os.path.join(hp.checkpoint_path, "checkpoint_{}.pth.tar".format(num))
    model = nn.DataParallel(FastSpeech2())
    # map_location lets the checkpoint load on CPU-only machines as well.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device)['model'])
    # The original `model.requires_grad = False` only created a plain attribute
    # on the DataParallel wrapper; actually freeze the parameters instead.
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
    return model
|
|
|
def synthesize(model, vocoder, text, sentence, prefix=''):
    """Synthesize audio for one sentence and save wav/plot artifacts.

    Writes a Griffin-Lim reconstruction wav, optionally a VocGAN wav, and a
    spectrogram/f0/energy plot into ``hp.test_path``.

    Args:
        model: FastSpeech2 model (DataParallel-wrapped), already in eval mode.
        vocoder: VocGAN generator, or ``None`` to skip neural vocoding.
        text: ``LongTensor`` of shape (1, src_len) with phoneme ids on ``device``.
        sentence: original input sentence; only used (truncated) in filenames.
        prefix: filename prefix, e.g. ``'step_700000'``.
    """
    # Only the first 10 characters of the sentence appear in output filenames.
    sentence = sentence[:10]

    # Each *_stat.npy presumably stores [mean, std] stacked on dim 0, which
    # the tuple assignment unpacks — TODO confirm against the preprocessor.
    mean_mel, std_mel = torch.tensor(np.load(os.path.join(hp.preprocessed_path, "mel_stat.npy")), dtype=torch.float).to(device)
    mean_f0, std_f0 = torch.tensor(np.load(os.path.join(hp.preprocessed_path, "f0_stat.npy")), dtype=torch.float).to(device)
    mean_energy, std_energy = torch.tensor(np.load(os.path.join(hp.preprocessed_path, "energy_stat.npy")), dtype=torch.float).to(device)

    mean_mel, std_mel = mean_mel.reshape(1, -1), std_mel.reshape(1, -1)
    mean_f0, std_f0 = mean_f0.reshape(1, -1), std_f0.reshape(1, -1)
    mean_energy, std_energy = mean_energy.reshape(1, -1), std_energy.reshape(1, -1)

    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)

    # Inference only — disable autograd to avoid building a gradient graph.
    with torch.no_grad():
        mel, mel_postnet, log_duration_output, f0_output, energy_output, _, _, mel_len = model(text, src_len)

    # NOTE(review): the original also de-normalized the pre-postnet `mel`
    # into a `mel_torch` that was never used afterwards; that dead work is
    # dropped here. Only the postnet mel feeds the vocoder and plots.
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    f0_output = f0_output[0]
    energy_output = energy_output[0]

    # De-normalize back to real scales. The transpose round-trip exists
    # because utils.de_norm operates on the (batch, time, mel) layout.
    mel_postnet_torch = utils.de_norm(mel_postnet_torch.transpose(1, 2), mean_mel, std_mel).transpose(1, 2)
    f0_output = utils.de_norm(f0_output, mean_f0, std_f0).squeeze().detach().cpu().numpy()
    energy_output = utils.de_norm(energy_output, mean_energy, std_energy).squeeze().detach().cpu().numpy()

    # Race-free replacement for the exists()+makedirs() pair.
    os.makedirs(hp.test_path, exist_ok=True)

    # Griffin-Lim reconstruction is always written as a baseline.
    Audio.tools.inv_mel_spec(mel_postnet_torch[0], os.path.join(hp.test_path, '{}_griffin_lim_{}.wav'.format(prefix, sentence)))

    if vocoder is not None:
        if hp.vocoder.lower() == "vocgan":
            utils.vocgan_infer(mel_postnet_torch, vocoder, path=os.path.join(hp.test_path, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))

    utils.plot_data([(mel_postnet_torch[0].detach().cpu().numpy(), f0_output, energy_output)], ['Synthesized Spectrogram'], filename=os.path.join(hp.test_path, '{}_{}.png'.format(prefix, sentence)))
|
|
|
|
|
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--step', type=int, default=700000)
    args = parser.parse_args()

    # Load the acoustic model at the requested checkpoint step.
    model = get_FastSpeech2(args.step).to(device)

    # VocGAN is the only neural vocoder wired up here; any other setting
    # falls back to Griffin-Lim only (vocoder=None).
    if hp.vocoder == 'vocgan':
        vocoder = utils.get_vocgan(ckpt_path=hp.vocoder_pretrained_model_path)
    else:
        vocoder = None

    # NOTE(review): the original built an unused `g2p = G2p()` here —
    # kor_preprocess constructs its own — so that dead (and slow) init
    # has been removed.

    print('input sentence : ')
    sentence = input()
    print('sentence that will be synthesized: ')

    start = time.time()
    text = kor_preprocess(sentence)
    synthesize(model, vocoder, text, sentence, prefix='step_{}'.format(args.step))
    end = time.time()

    print('run time: {} sec,'.format(end - start), 'Done~')
|
|