Spaces:
Running
on
Zero
Running
on
Zero
| import argparse | |
| import os | |
| import torch | |
| import yaml | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader | |
| from utils.model import get_model, get_vocoder | |
| from utils.tools import to_device, log, synth_one_sample | |
| from model import FastSpeech2Loss | |
| from dataset import Dataset | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| def evaluate(model, step, configs, logger=None, vocoder=None): | |
| preprocess_config, model_config, train_config = configs | |
| # Get dataset | |
| dataset = Dataset( | |
| "val.txt", preprocess_config, train_config, sort=False, drop_last=False | |
| ) | |
| batch_size = train_config["optimizer"]["batch_size"] | |
| loader = DataLoader( | |
| dataset, | |
| batch_size=batch_size, | |
| shuffle=False, | |
| collate_fn=dataset.collate_fn, | |
| ) | |
| # Get loss function | |
| Loss = FastSpeech2Loss(preprocess_config, model_config).to(device) | |
| # Evaluation | |
| loss_sums = [0 for _ in range(6)] | |
| for batchs in loader: | |
| for batch in batchs: | |
| batch = to_device(batch, device) | |
| with torch.no_grad(): | |
| # Forward | |
| output = model(*(batch[2:])) | |
| # Cal Loss | |
| losses = Loss(batch, output) | |
| for i in range(len(losses)): | |
| loss_sums[i] += losses[i].item() * len(batch[0]) | |
| loss_means = [loss_sum / len(dataset) for loss_sum in loss_sums] | |
| message = "Validation Step {}, Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Pitch Loss: {:.4f}, Energy Loss: {:.4f}, Duration Loss: {:.4f}".format( | |
| *([step] + [l for l in loss_means]) | |
| ) | |
| if logger is not None: | |
| fig, wav_reconstruction, wav_prediction, tag = synth_one_sample( | |
| batch, | |
| output, | |
| vocoder, | |
| model_config, | |
| preprocess_config, | |
| ) | |
| log(logger, step, losses=loss_means) | |
| log( | |
| logger, | |
| fig=fig, | |
| tag="Validation/step_{}_{}".format(step, tag), | |
| ) | |
| sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"] | |
| log( | |
| logger, | |
| audio=wav_reconstruction, | |
| sampling_rate=sampling_rate, | |
| tag="Validation/step_{}_{}_reconstructed".format(step, tag), | |
| ) | |
| log( | |
| logger, | |
| audio=wav_prediction, | |
| sampling_rate=sampling_rate, | |
| tag="Validation/step_{}_{}_synthesized".format(step, tag), | |
| ) | |
| return message | |
| if __name__ == "__main__": | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument("--restore_step", type=int, default=30000) | |
| parser.add_argument( | |
| "-p", | |
| "--preprocess_config", | |
| type=str, | |
| required=True, | |
| help="path to preprocess.yaml", | |
| ) | |
| parser.add_argument( | |
| "-m", "--model_config", type=str, required=True, help="path to model.yaml" | |
| ) | |
| parser.add_argument( | |
| "-t", "--train_config", type=str, required=True, help="path to train.yaml" | |
| ) | |
| args = parser.parse_args() | |
| # Read Config | |
| preprocess_config = yaml.load( | |
| open(args.preprocess_config, "r"), Loader=yaml.FullLoader | |
| ) | |
| model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader) | |
| train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader) | |
| configs = (preprocess_config, model_config, train_config) | |
| # Get model | |
| model = get_model(args, configs, device, train=False).to(device) | |
| message = evaluate(model, args.restore_step, configs) | |
| print(message) |