Spaces:
Runtime error
Runtime error
Upload synthesis.py
Browse files- synthesis.py +103 -0
synthesis.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding: utf-8
|
2 |
+
"""
|
3 |
+
Synthesize waveforms from a trained model.
|
4 |
+
|
5 |
+
usage: tts.py [options] <checkpoint> <text_list_file> <dst_dir>
|
6 |
+
|
7 |
+
options:
|
8 |
+
--file-name-suffix=<s> File name suffix [default: ].
|
9 |
+
--max-decoder-steps=<N> Max decoder steps [default: 500].
|
10 |
+
-h, --help Show help message.
|
11 |
+
"""
|
12 |
+
from docopt import docopt
|
13 |
+
|
14 |
+
# Use text & audio modules from existing Tacotron implementation.
|
15 |
+
import sys
|
16 |
+
import os
|
17 |
+
from os.path import dirname, join
|
18 |
+
tacotron_lib_dir = join(dirname(__file__), "lib", "tacotron")
|
19 |
+
sys.path.append(tacotron_lib_dir)
|
20 |
+
from text import text_to_sequence, symbols
|
21 |
+
from util import audio
|
22 |
+
from util.plot import plot_alignment
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from torch.autograd import Variable
|
26 |
+
import numpy as np
|
27 |
+
import nltk
|
28 |
+
|
29 |
+
from tacotron_pytorch import Tacotron
|
30 |
+
from hparams import hparams
|
31 |
+
|
32 |
+
from tqdm import tqdm
|
33 |
+
|
34 |
+
use_cuda = torch.cuda.is_available()
|
35 |
+
|
36 |
+
|
37 |
+
def tts(model, text):
    """Synthesize speech for ``text`` using a Tacotron ``model``.

    The text is converted to a symbol-id sequence with ``text_to_sequence``
    (using the cleaners named in ``hparams.cleaners``), given a leading batch
    axis of size one, and fed through the model. The model is moved to the
    GPU first when ``use_cuda`` is set.

    Args:
        model: Tacotron model exposing ``encoder`` and ``postnet`` submodules
            and returning ``(mel_outputs, linear_outputs, alignments)`` when
            called on a batch of symbol ids.
        text: Input sentence to synthesize.

    Returns:
        A ``(waveform, alignment, spectrogram)`` tuple, where ``waveform`` is
        the result of ``audio.inv_spectrogram`` on the transposed linear
        output, ``alignment`` is the first alignment in the batch as a NumPy
        array, and ``spectrogram`` is the linear output passed through
        ``audio._denormalize``.
    """
    if use_cuda:
        model = model.cuda()

    # Only the encoder and the postnet are put into eval mode.
    # TODO: Turning off dropout of decoder's prenet causes serious performance
    # regression, not sure why -- hence the decoder is intentionally left
    # as-is (no model.decoder.eval()).
    for submodule in (model.encoder, model.postnet):
        submodule.eval()

    symbol_ids = np.array(text_to_sequence(text, [hparams.cleaners]))
    # unsqueeze(0) adds the batch axis in front of the symbol axis.
    batch = Variable(torch.from_numpy(symbol_ids)).unsqueeze(0)
    if use_cuda:
        batch = batch.cuda()

    # Greedy decoding; the mel-scale output is not needed here.
    _, linear_outputs, alignments = model(batch)

    # Take the single item of the batch and bring it back to NumPy on the CPU.
    linear_output = linear_outputs[0].cpu().data.numpy()
    alignment = alignments[0].cpu().data.numpy()

    spectrogram = audio._denormalize(linear_output)

    # Predicted audio signal
    waveform = audio.inv_spectrogram(linear_output.T)

    return waveform, alignment, spectrogram
|
64 |
+
|
65 |
+
|
66 |
+
if __name__ == "__main__":
    # Command-line entry point: restore a Tacotron checkpoint, then synthesize
    # one wav file and one alignment plot per line of the input text file.
    args = docopt(__doc__)
    print("Command line args:\n", args)
    checkpoint_path = args["<checkpoint>"]
    text_list_file_path = args["<text_list_file>"]
    dst_dir = args["<dst_dir>"]
    max_decoder_steps = int(args["--max-decoder-steps"])
    file_name_suffix = args["--file-name-suffix"]

    model = Tacotron(n_vocab=len(symbols),
                     embedding_dim=256,
                     mel_dim=hparams.num_mels,
                     linear_dim=hparams.num_freq,
                     r=hparams.outputs_per_step,
                     padding_idx=hparams.padding_idx,
                     use_memory_mask=hparams.use_memory_mask,
                     )

    # Map every storage to the CPU while loading so that a checkpoint saved on
    # a GPU machine can also be restored on a CPU-only host; tts() moves the
    # model to the GPU itself when CUDA is available.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path,
                            map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint["state_dict"])
    model.decoder.max_decoder_steps = max_decoder_steps

    os.makedirs(dst_dir, exist_ok=True)

    with open(text_list_file_path, "rb") as f:
        for idx, line in enumerate(f):
            # Strip only the line terminator. Slicing off the last character
            # unconditionally would truncate a final line that has no trailing
            # newline and would leave a stray "\r" on CRLF-terminated files.
            text = line.decode("utf-8").rstrip("\r\n")
            words = nltk.word_tokenize(text)
            print("{}: {} ({} chars, {} words)".format(
                idx, text, len(text), len(words)))
            waveform, alignment, _ = tts(model, text)
            dst_wav_path = join(
                dst_dir, "{}{}.wav".format(idx, file_name_suffix))
            dst_alignment_path = join(
                dst_dir, "{}_alignment.png".format(idx))
            plot_alignment(alignment.T, dst_alignment_path,
                           info="tacotron, {}".format(checkpoint_path))
            audio.save_wav(waveform, dst_wav_path)

    print("Finished! Check out {} for generated audio samples.".format(dst_dir))
    sys.exit(0)
|