# Source: synthesis.py uploaded by sudip1310 (commit 0a1979f).
# coding: utf-8
"""
Synthesize waveforms from a trained Tacotron model.

usage: synthesis.py [options] <checkpoint> <text_list_file> <dst_dir>

options:
    --file-name-suffix=<s>   File name suffix [default: ].
    --max-decoder-steps=<N>  Max decoder steps [default: 500].
    -h, --help               Show help message.
"""
from docopt import docopt
# Use text & audio modules from existing Tacotron implementation.
import sys
import os
from os.path import dirname, join
tacotron_lib_dir = join(dirname(__file__), "lib", "tacotron")
sys.path.append(tacotron_lib_dir)
from text import text_to_sequence, symbols
from util import audio
from util.plot import plot_alignment
import torch
from torch.autograd import Variable
import numpy as np
import nltk
from tacotron_pytorch import Tacotron
from hparams import hparams
from tqdm import tqdm
use_cuda = torch.cuda.is_available()


def tts(model, text):
    """Convert text to a speech waveform with a Tacotron model.

    Args:
        model: Tacotron model with weights already loaded. When CUDA is
            available it is moved to the GPU before decoding.
        text: Input sentence. It is cleaned and mapped to symbol IDs via
            ``text_to_sequence`` using ``hparams.cleaners``.

    Returns:
        Tuple ``(waveform, alignment, spectrogram)``:
          - waveform: result of ``audio.inv_spectrogram`` on the (transposed)
            predicted linear spectrogram of the first batch item.
          - alignment: attention alignment of the first batch item as a
            NumPy array.
          - spectrogram: ``audio._denormalize`` applied to the predicted
            linear spectrogram.
    """
    if use_cuda:
        model = model.cuda()

    # TODO: Turning off dropout of decoder's prenet causes serious performance
    # regression, not sure why.
    # model.decoder.eval()
    model.encoder.eval()
    model.postnet.eval()

    # Force int64 symbol IDs: np.array() of Python ints uses the platform
    # default integer, which may be int32 (e.g. Windows with NumPy < 2), and
    # embedding lookups expect LongTensor indices.
    symbol_ids = np.array(text_to_sequence(text, [hparams.cleaners]),
                          dtype=np.int64)
    sequence = Variable(torch.from_numpy(symbol_ids)).unsqueeze(0)  # batch of 1
    if use_cuda:
        sequence = sequence.cuda()

    # Greedy decoding
    _mel_outputs, linear_outputs, alignments = model(sequence)

    linear_output = linear_outputs[0].cpu().data.numpy()
    spectrogram = audio._denormalize(linear_output)
    alignment = alignments[0].cpu().data.numpy()

    # Predicted audio signal
    waveform = audio.inv_spectrogram(linear_output.T)

    return waveform, alignment, spectrogram
# Command-line entry point: for each line of <text_list_file>, synthesize a
# wav file and an attention-alignment plot into <dst_dir>.
if __name__ == "__main__":
    args = docopt(__doc__)
    print("Command line args:\n", args)

    checkpoint_path = args["<checkpoint>"]
    text_list_file_path = args["<text_list_file>"]
    dst_dir = args["<dst_dir>"]
    max_decoder_steps = int(args["--max-decoder-steps"])
    file_name_suffix = args["--file-name-suffix"]

    model = Tacotron(n_vocab=len(symbols),
                     embedding_dim=256,
                     mel_dim=hparams.num_mels,
                     linear_dim=hparams.num_freq,
                     r=hparams.outputs_per_step,
                     padding_idx=hparams.padding_idx,
                     use_memory_mask=hparams.use_memory_mask,
                     )

    # A checkpoint saved from a GPU cannot be deserialized on a CPU-only
    # machine unless its storages are remapped to the CPU. When CUDA is
    # available, tts() moves the model to the GPU itself.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    if use_cuda:
        checkpoint = torch.load(checkpoint_path)
    else:
        checkpoint = torch.load(
            checkpoint_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint["state_dict"])
    model.decoder.max_decoder_steps = max_decoder_steps

    os.makedirs(dst_dir, exist_ok=True)

    # Read everything up front so the file is not held open during synthesis.
    with open(text_list_file_path, "rb") as f:
        lines = f.readlines()

    for idx, line in enumerate(lines):
        # Strip only the line terminator. The previous `[:-1]` removed a real
        # character from a final line lacking a newline and left a stray '\r'
        # on CRLF-terminated files.
        text = line.decode("utf-8").rstrip("\r\n")
        words = nltk.word_tokenize(text)
        print("{}: {} ({} chars, {} words)".format(idx, text, len(text), len(words)))

        waveform, alignment, _ = tts(model, text)

        # Both outputs carry the suffix so runs with different suffixes do not
        # overwrite each other's alignment plots.
        dst_wav_path = join(dst_dir, "{}{}.wav".format(idx, file_name_suffix))
        dst_alignment_path = join(
            dst_dir, "{}{}_alignment.png".format(idx, file_name_suffix))
        plot_alignment(alignment.T, dst_alignment_path,
                       info="tacotron, {}".format(checkpoint_path))
        audio.save_wav(waveform, dst_wav_path)

    print("Finished! Check out {} for generated audio samples.".format(dst_dir))
    sys.exit(0)