# tts-fastspeech-mydata / synthesize.py
import torch
import torch.nn as nn
import numpy as np
import hparams as hp
import os
import time
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]=hp.synth_visible_devices
import argparse
import re
from string import punctuation
from fastspeech2 import FastSpeech2
from vocoder import vocgan_generator
from text import text_to_sequence, sequence_to_text
import utils
import audio as Audio
import codecs
from g2pk import G2p
from jamo import h2j
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
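
# kor_preprocess: Korean text -> batched phoneme-id tensor.
# g2pk converts graphemes to their pronounced form, jamo.h2j decomposes the
# result into jamo, stray punctuation becomes the short-pause token {sp}, and
# text_to_sequence maps the phoneme string to integer ids.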
def kor_preprocess(text):
    text = text.rstrip(punctuation)
    g2p = G2p()
    phone = g2p(text)
    phone = h2j(phone)
    phone = list(filter(lambda p: p != ' ', phone))
    phone = '{' + '}{'.join(phone) + '}'
    phone = re.sub(r'\{[^\w\s]?\}', '{sp}', phone)
    phone = phone.replace('}{', ' ')
    sequence = np.array(text_to_sequence(phone, hp.text_cleaners))
    sequence = np.stack([sequence])
    return torch.from_numpy(sequence).long().to(device)
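
# get_FastSpeech2: restore a DataParallel-wrapped FastSpeech2 model from
# checkpoint_{num}.pth.tar and freeze it for inference.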
def get_FastSpeech2(num):
    checkpoint_path = os.path.join(hp.checkpoint_path, "checkpoint_{}.pth.tar".format(num))
    model = nn.DataParallel(FastSpeech2())
    # map_location lets a GPU-trained checkpoint load on a CPU-only machine
    model.load_state_dict(torch.load(checkpoint_path, map_location=device)['model'])
    # requires_grad_(False) actually freezes the parameters; plain attribute
    # assignment (model.requires_grad = False) has no effect on an nn.Module
    model.requires_grad_(False)
    model.eval()
    return model
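
# synthesize: run the acoustic model, de-normalize mel/f0/energy with the
# training statistics, then write a Griffin-Lim wav, an (optional) VocGAN wav,
# and a spectrogram/f0/energy plot under hp.test_path.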
def synthesize(model, vocoder, text, sentence, prefix=''):
    sentence = sentence[:10]  # truncate: a long filename can raise an OSError

    mean_mel, std_mel = torch.tensor(np.load(os.path.join(hp.preprocessed_path, "mel_stat.npy")), dtype=torch.float).to(device)
    mean_f0, std_f0 = torch.tensor(np.load(os.path.join(hp.preprocessed_path, "f0_stat.npy")), dtype=torch.float).to(device)
    mean_energy, std_energy = torch.tensor(np.load(os.path.join(hp.preprocessed_path, "energy_stat.npy")), dtype=torch.float).to(device)

    mean_mel, std_mel = mean_mel.reshape(1, -1), std_mel.reshape(1, -1)
    mean_f0, std_f0 = mean_f0.reshape(1, -1), std_f0.reshape(1, -1)
    mean_energy, std_energy = mean_energy.reshape(1, -1), std_energy.reshape(1, -1)

    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)
    mel, mel_postnet, log_duration_output, f0_output, energy_output, _, _, mel_len = model(text, src_len)

    mel_torch = mel.transpose(1, 2).detach()
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    f0_output = f0_output[0]
    energy_output = energy_output[0]

    mel_torch = utils.de_norm(mel_torch.transpose(1, 2), mean_mel, std_mel)
    mel_postnet_torch = utils.de_norm(mel_postnet_torch.transpose(1, 2), mean_mel, std_mel).transpose(1, 2)
    f0_output = utils.de_norm(f0_output, mean_f0, std_f0).squeeze().detach().cpu().numpy()
    energy_output = utils.de_norm(energy_output, mean_energy, std_energy).squeeze().detach().cpu().numpy()

    if not os.path.exists(hp.test_path):
        os.makedirs(hp.test_path)

    Audio.tools.inv_mel_spec(mel_postnet_torch[0], os.path.join(hp.test_path, '{}_griffin_lim_{}.wav'.format(prefix, sentence)))

    if vocoder is not None:
        if hp.vocoder.lower() == "vocgan":
            utils.vocgan_infer(mel_postnet_torch, vocoder, path=os.path.join(hp.test_path, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence)))

    utils.plot_data([(mel_postnet_torch[0].detach().cpu().numpy(), f0_output, energy_output)], ['Synthesized Spectrogram'], filename=os.path.join(hp.test_path, '{}_{}.png'.format(prefix, sentence)))
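
# Command-line entry point: load the checkpoint given by --step, read one
# sentence from stdin, and synthesize it.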
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--step', type=int, default=700000)
    args = parser.parse_args()

    model = get_FastSpeech2(args.step).to(device)
    if hp.vocoder == 'vocgan':
        vocoder = utils.get_vocgan(ckpt_path=hp.vocoder_pretrained_model_path)
    else:
        vocoder = None

    print('input sentence : ')
    sentence = input()
    print('sentence that will be synthesized:', sentence)

    start = time.time()
    text = kor_preprocess(sentence)
    synthesize(model, vocoder, text, sentence, prefix='step_{}'.format(args.step))
    end = time.time()

    run_time = end - start
    print('run time: {:.3f} sec,'.format(run_time), 'Done~')
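
# Example usage (the step value is illustrative; it must match an existing
# checkpoint under hp.checkpoint_path, e.g. checkpoint_350000.pth.tar):
#
#   python synthesize.py --step 350000
#
# The script then reads one sentence from stdin and writes the Griffin-Lim
# wav, the VocGAN wav (if hp.vocoder == 'vocgan'), and the plot to hp.test_path.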