# harveen
# Add Tamil
# 3b92d66
# raw
# history blame
# 2.51 kB
import numpy as np
import os
import torch
import commons
import models
import utils
from argparse import ArgumentParser
from tqdm import tqdm
from text import text_to_sequence
if __name__ == "__main__":
    # Script entry point: load the latest checkpoint from --model_dir, then
    # run the model in generation mode over every line of the training and
    # validation filelists and save each generated output under --mels_dir.
    parser = ArgumentParser()
    parser.add_argument("-m", "--model_dir", required=True, type=str)
    parser.add_argument("-s", "--mels_dir", required=True, type=str)
    args = parser.parse_args()

    MODEL_DIR = args.model_dir  # path to model dir
    SAVE_MELS_DIR = args.mels_dir  # path to save generated mels

    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs(SAVE_MELS_DIR, exist_ok=True)

    hps = utils.get_hparams_from_dir(MODEL_DIR)
    symbols = list(hps.data.punc) + list(hps.data.chars)
    checkpoint_path = utils.latest_checkpoint_path(MODEL_DIR)
    cleaner = hps.data.text_cleaners
    # Read once; used both for the vocabulary size and for text preparation.
    add_blank = getattr(hps.data, "add_blank", False)

    model = models.FlowGenerator(
        len(symbols) + add_blank,  # one extra symbol id when add_blank is set
        out_channels=hps.data.n_mel_channels,
        **hps.model
    ).to("cuda")

    utils.load_checkpoint(checkpoint_path, model)
    model.decoder.store_inverse()  # do not calculate jacobians for fast decoding
    model.eval()

    def get_mel(text, fpath):
        """Generate model output for ``text`` and save it with ``np.save``.

        Args:
            text: Utterance text; converted to symbol ids via
                ``text_to_sequence`` using the model's symbols and cleaners.
            fpath: Output file name, joined onto ``SAVE_MELS_DIR``.
        """
        if add_blank:
            text_norm = text_to_sequence(text, symbols, cleaner)
            text_norm = commons.intersperse(text_norm, len(symbols))
        else:
            # If not using the "add_blank" option during training, adding
            # spaces at the beginning and the end of the utterance improves
            # quality.
            text = " " + text.strip() + " "
            text_norm = text_to_sequence(text, symbols, cleaner)

        # Add a leading batch axis of size 1.
        sequence = np.array(text_norm)[None, :]
        # torch.autograd.Variable is deprecated; tensors are used directly.
        x_tst = torch.from_numpy(sequence).cuda().long()
        x_tst_lengths = torch.tensor([x_tst.shape[1]]).cuda()

        with torch.no_grad():
            noise_scale = 0.667
            length_scale = 1.0
            # Only the first element of the first returned group is needed.
            (y_gen_tst, *_), *_ = model(
                x_tst,
                x_tst_lengths,
                gen=True,
                noise_scale=noise_scale,
                length_scale=length_scale,
            )

        np.save(os.path.join(SAVE_MELS_DIR, fpath), y_gen_tst.cpu().detach().numpy())

    for filelist in [hps.data.training_files, hps.data.validation_files]:
        # Close the file deterministically and decode explicitly as UTF-8 so
        # non-ASCII transcripts do not depend on the platform locale.
        with open(filelist, encoding="utf-8") as fh:
            # Drop blank lines, which would otherwise fail the unpack below.
            file_lines = [ln for ln in fh.read().splitlines() if ln.strip()]

        for line in tqdm(file_lines):
            # Each line is expected to be "<audio path>|<text>".
            fname, text = line.split("|")
            fname = os.path.basename(fname).replace(".wav", ".npy")
            get_mel(text, fname)