import os
from argparse import ArgumentParser

import numpy as np
import torch
from tqdm import tqdm

import commons
import models
import utils
from text import text_to_sequence

if __name__ == "__main__":
    # Synthesize a mel output for every utterance listed in the model's
    # training and validation filelists, using the latest checkpoint found in
    # --model_dir, and save each result as a .npy file in --mels_dir.
    parser = ArgumentParser()
    parser.add_argument("-m", "--model_dir", required=True, type=str)
    parser.add_argument("-s", "--mels_dir", required=True, type=str)
    args = parser.parse_args()

    MODEL_DIR = args.model_dir  # path to model dir
    SAVE_MELS_DIR = args.mels_dir  # path to save generated mels
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(SAVE_MELS_DIR, exist_ok=True)

    hps = utils.get_hparams_from_dir(MODEL_DIR)
    symbols = list(hps.data.punc) + list(hps.data.chars)
    checkpoint_path = utils.latest_checkpoint_path(MODEL_DIR)
    cleaner = hps.data.text_cleaners
    # Read once; decides both the vocabulary size and the text preprocessing.
    add_blank = bool(getattr(hps.data, "add_blank", False))

    # Generation settings passed to model(..., gen=True); unchanged values.
    NOISE_SCALE = 0.667
    LENGTH_SCALE = 1.0

    # With add_blank, one extra token id (len(symbols)) is reserved for the
    # blank that get_mel intersperses between symbols.
    model = models.FlowGenerator(
        len(symbols) + int(add_blank),
        out_channels=hps.data.n_mel_channels,
        **hps.model,
    ).to("cuda")

    utils.load_checkpoint(checkpoint_path, model)
    model.decoder.store_inverse()  # do not calculate jacobians for fast decoding
    model.eval()

    def get_mel(text, fpath):
        """Synthesize a mel for `text` and save it to SAVE_MELS_DIR/fpath.

        Args:
            text: Raw transcript; it is cleaned by text_to_sequence.
            fpath: Output file name relative to SAVE_MELS_DIR. np.save
                appends ".npy" if the name does not already end with it.
        """
        if add_blank:
            text_norm = text_to_sequence(text, symbols, cleaner)
            # Id len(symbols) is the blank token (see model construction).
            text_norm = commons.intersperse(text_norm, len(symbols))
        else:
            # If not using "add_blank" option during training, adding spaces at
            # the beginning and the end of utterance improves quality
            text = " " + text.strip() + " "
            text_norm = text_to_sequence(text, symbols, cleaner)

        # Batch of one sequence: shape (1, T), int64 token ids on the GPU.
        x_tst = torch.tensor([list(text_norm)], dtype=torch.long, device="cuda")
        x_tst_lengths = torch.tensor([x_tst.shape[1]], device="cuda")

        with torch.no_grad():
            # Only the first element of the model's first output is needed.
            (y_gen_tst, *_), *_ = model(
                x_tst,
                x_tst_lengths,
                gen=True,
                noise_scale=NOISE_SCALE,
                length_scale=LENGTH_SCALE,
            )
        np.save(os.path.join(SAVE_MELS_DIR, fpath), y_gen_tst.detach().cpu().numpy())

    def mel_filename(wav_path):
        """Return the output name for `wav_path`: its basename, ".wav" -> ".npy".

        Only a trailing ".wav" extension is rewritten; a ".wav" elsewhere in
        the name is left intact (str.replace would have rewritten it too).
        """
        base = os.path.basename(wav_path)
        if base.endswith(".wav"):
            base = base[: -len(".wav")] + ".npy"
        return base

    for filelist in [hps.data.training_files, hps.data.validation_files]:
        # Close the handle deterministically; filelists are assumed to be UTF-8.
        with open(filelist, encoding="utf-8") as fh:
            file_lines = fh.read().splitlines()
        for line in tqdm(file_lines):
            if not line.strip():
                continue  # tolerate blank / trailing lines in the filelist
            parts = line.split("|")
            if len(parts) != 2:
                raise ValueError(
                    f"{filelist}: expected 'wav_path|text', got {line!r}"
                )
            fname, text = parts
            get_mel(text, mel_filename(fname))