import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import numpy as np
import os
import argparse
import re

from fastspeech2 import FastSpeech2
from loss import FastSpeech2Loss

from dataset import Dataset
from text import text_to_sequence, sequence_to_text

import hparams as hp
import utils
import audio as Audio

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
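
# Restore a trained FastSpeech2 checkpoint (wrapped in DataParallel), freeze its
# parameters, and put it in eval mode.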
def get_FastSpeech2(num):
    checkpoint_path = os.path.join(hp.checkpoint_path, "checkpoint_{}.pth.tar".format(num))
    model = nn.DataParallel(FastSpeech2())
    model.load_state_dict(torch.load(checkpoint_path)['model'])
    model.requires_grad_(False)
    model.eval()
    return model
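

# Compute validation losses for the given training step and, if a vocoder is
# provided, synthesize one validation sample for listening and plotting.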
def evaluate(model, step, vocoder=None):
    model.eval()
    torch.manual_seed(0)
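
    # Normalization statistics written by preprocessing; each *_stat.npy is unpacked as (mean, std).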
    mean_mel, std_mel = torch.tensor(np.load(os.path.join(hp.preprocessed_path, "mel_stat.npy")), dtype=torch.float).to(device)
    mean_f0, std_f0 = torch.tensor(np.load(os.path.join(hp.preprocessed_path, "f0_stat.npy")), dtype=torch.float).to(device)
    mean_energy, std_energy = torch.tensor(np.load(os.path.join(hp.preprocessed_path, "energy_stat.npy")), dtype=torch.float).to(device)

    eval_path = hp.eval_path
    if not os.path.exists(eval_path):
        os.makedirs(eval_path)
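
    # Validation data, kept in original order and batched with the dataset's collate_fn.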
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=False, collate_fn=dataset.collate_fn, drop_last=False, num_workers=0)

    Loss = FastSpeech2Loss().to(device)
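
    # Per-batch loss accumulators.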
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            id_ = data_of_batch["id"]
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).int().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
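                # Forward pass using the ground-truth duration, F0, and energy targets.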
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, out_mel_len = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len)
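
                # Duration, F0, energy, and mel losses, masked to the valid source/mel positions.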
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output, energy, mel_output, mel_postnet_output, mel_target, ~src_mask, ~mel_mask)

                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                if idx == 0 and vocoder is not None:
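                    # For the first batch only, vocode and plot one sample (prediction vs. ground truth).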
                    for k in range(1):
                        basename = id_[k]
                        gt_length = mel_len[k]
                        out_length = out_mel_len[k]

                        mel_target_torch = mel_target[k:k+1, :gt_length]
                        mel_target_ = mel_target[k, :gt_length]
                        mel_postnet_torch = mel_postnet_output[k:k+1, :out_length]
                        mel_postnet = mel_postnet_output[k, :out_length]

                        mel_target_torch = utils.de_norm(mel_target_torch, mean_mel, std_mel).transpose(1, 2).detach()
                        mel_target_ = utils.de_norm(mel_target_, mean_mel, std_mel).cpu().transpose(0, 1).detach()
                        mel_postnet_torch = utils.de_norm(mel_postnet_torch, mean_mel, std_mel).transpose(1, 2).detach()
                        mel_postnet = utils.de_norm(mel_postnet, mean_mel, std_mel).cpu().transpose(0, 1).detach()

                        if hp.vocoder == "vocgan":
                            utils.vocgan_infer(mel_target_torch, vocoder, path=os.path.join(hp.eval_path, 'eval_groundtruth_{}_{}.wav'.format(basename, hp.vocoder)))
                            utils.vocgan_infer(mel_postnet_torch, vocoder, path=os.path.join(hp.eval_path, 'eval_step_{}_{}_{}.wav'.format(step, basename, hp.vocoder)))
                        np.save(os.path.join(hp.eval_path, 'eval_step_{}_{}_mel.npy'.format(step, basename)), mel_postnet.numpy())

                        f0_ = f0[k, :gt_length]
                        energy_ = energy[k, :gt_length]
                        f0_output_ = f0_output[k, :out_length]
                        energy_output_ = energy_output[k, :out_length]

                        f0_ = utils.de_norm(f0_, mean_f0, std_f0).detach().cpu().numpy()
                        f0_output_ = utils.de_norm(f0_output_, mean_f0, std_f0).detach().cpu().numpy()
                        energy_ = utils.de_norm(energy_, mean_energy, std_energy).detach().cpu().numpy()
                        energy_output_ = utils.de_norm(energy_output_, mean_energy, std_energy).detach().cpu().numpy()

                        utils.plot_data([(mel_postnet.numpy(), f0_output_, energy_output_), (mel_target_.numpy(), f0_, energy_)],
                                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(hp.eval_path, 'eval_step_{}_{}.png'.format(step, basename)))
                    idx += 1
                    print("done")

            current_step += 1
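
    # Average each loss over all validation batches.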
    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        f_log.write(str3 + "\n")
        f_log.write(str4 + "\n")
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")

    model.train()

    return d_l, f_l, e_l, mel_l, mel_p_l

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--step', type=int, default=30000)
    args = parser.parse_args()

    model = get_FastSpeech2(args.step).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)
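
    # Load the pretrained vocoder when configured; evaluate() skips waveform synthesis if it is None.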
    vocoder = None
    if hp.vocoder == 'vocgan':
        vocoder = utils.get_vocgan(ckpt_path=hp.vocoder_pretrained_model_path)
        vocoder.to(device)

    if not os.path.exists(hp.log_path):
        os.makedirs(hp.log_path)
    if not os.path.exists(hp.eval_path):
        os.makedirs(hp.eval_path)

    evaluate(model, args.step, vocoder)