Upload synthesizer/train.py with huggingface_hub
synthesizer/train.py +269 -0
synthesizer/train.py
ADDED
@@ -0,0 +1,269 @@
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from synthesizer import audio
from synthesizer.models.tacotron import Tacotron
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
from synthesizer.utils import ValueWindow, data_parallel_workaround
from synthesizer.utils.plot import plot_spectrogram
from synthesizer.utils.symbols import symbols
from synthesizer.utils.text import sequence_to_text
from vocoder.display import *
from datetime import datetime
import numpy as np
from pathlib import Path
import sys
import time
import platform


def np_now(x: torch.Tensor): return x.detach().cpu().numpy()


def time_string():
    return datetime.now().strftime("%Y-%m-%d %H:%M")

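# Trains the Tacotron synthesizer on the preprocessed (text, mel, speaker
# embedding) triples found in syn_dir, checkpointing into models_dir/run_id.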
def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
          backup_every: int, force_restart: bool, hparams):

    syn_dir = Path(syn_dir)
    models_dir = Path(models_dir)
    models_dir.mkdir(exist_ok=True)

    model_dir = models_dir.joinpath(run_id)
    plot_dir = model_dir.joinpath("plots")
    wav_dir = model_dir.joinpath("wavs")
    mel_output_dir = model_dir.joinpath("mel-spectrograms")
    meta_folder = model_dir.joinpath("metas")
    model_dir.mkdir(exist_ok=True)
    plot_dir.mkdir(exist_ok=True)
    wav_dir.mkdir(exist_ok=True)
    mel_output_dir.mkdir(exist_ok=True)
    meta_folder.mkdir(exist_ok=True)

    weights_fpath = model_dir.joinpath(run_id).with_suffix(".pt")
    metadata_fpath = syn_dir.joinpath("train.txt")

    print("Checkpoint path: {}".format(weights_fpath))
    print("Loading training data from: {}".format(metadata_fpath))
    print("Using model: Tacotron")

    # Book keeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)

    # From WaveRNN/train_tacotron.py
    if torch.cuda.is_available():
        device = torch.device("cuda")

        for session in hparams.tts_schedule:
            _, _, _, batch_size = session
            if batch_size % torch.cuda.device_count() != 0:
                raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
    else:
        device = torch.device("cpu")
    print("Using device:", device)

    # Instantiate Tacotron Model
    print("\nInitialising Tacotron Model...\n")
    model = Tacotron(embed_dims=hparams.tts_embed_dims,
                     num_chars=len(symbols),
                     encoder_dims=hparams.tts_encoder_dims,
                     decoder_dims=hparams.tts_decoder_dims,
                     n_mels=hparams.num_mels,
                     fft_bins=hparams.num_mels,
                     postnet_dims=hparams.tts_postnet_dims,
                     encoder_K=hparams.tts_encoder_K,
                     lstm_dims=hparams.tts_lstm_dims,
                     postnet_K=hparams.tts_postnet_K,
                     num_highways=hparams.tts_num_highways,
                     dropout=hparams.tts_dropout,
                     stop_threshold=hparams.tts_stop_threshold,
                     speaker_embedding_size=hparams.speaker_embedding_size).to(device)

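    # Note: the learning rate is overridden per training session below, so
    # Adam's default lr here is effectively a placeholder.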
    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())

    # Load the weights
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of Tacotron from scratch\n")
        model.save(weights_fpath)

        # Embeddings metadata
        char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
        with open(char_embedding_fpath, "w", encoding="utf-8") as f:
            for symbol in symbols:
                if symbol == " ":
                    symbol = "\\s"  # For visual purposes, swap space with \s

                f.write("{}\n".format(symbol))

    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("Tacotron weights loaded from step %d" % model.step)

    # Initialize the dataset
    metadata_fpath = syn_dir.joinpath("train.txt")
    mel_dir = syn_dir.joinpath("mels")
    embed_dir = syn_dir.joinpath("embeds")
    dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
    test_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True)

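    # Each entry of hparams.tts_schedule is a (r, lr, max_step, batch_size)
    # session: train with reduction factor r (decoder outputs per step) and
    # learning rate lr until the global step reaches max_step, then move on.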
    for i, session in enumerate(hparams.tts_schedule):
        current_step = model.get_step()

        r, lr, max_step, batch_size = session

        training_steps = max_step - current_step

        # Do we need to change to the next session?
        if current_step >= max_step:
            # Are there no further sessions than the current one?
            if i == len(hparams.tts_schedule) - 1:
                # We have completed training. Save the model and exit
                model.save(weights_fpath, optimizer)
                break
            else:
                # There is a following session, go to it
                continue

        model.r = r

        # Begin the training
        simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
                      ("Batch Size", batch_size),
                      ("Learning Rate", lr),
                      ("Outputs/Step (r)", model.r)])

        for p in optimizer.param_groups:
            p["lr"] = lr

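        # collate_synthesizer receives r so that the padded mel lengths in each
        # batch stay compatible with the decoder emitting r frames per step.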
        data_loader = DataLoader(dataset,
                                 collate_fn=lambda batch: collate_synthesizer(batch, r, hparams),
                                 batch_size=batch_size,
                                 num_workers=2 if platform.system() != "Windows" else 0,
                                 shuffle=True,
                                 pin_memory=True)

        total_iters = len(dataset)
        steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
        epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)

        for epoch in range(1, epochs+1):
            for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
                start_time = time.time()

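                # The stop target is 0 while the utterance is still going and 1
                # from its final frame onward; metadata column 4 holds each
                # utterance's unpadded mel length.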
                # Generate stop tokens for training
                stop = torch.ones(mels.shape[0], mels.shape[2])
                for j, k in enumerate(idx):
                    stop[j, :int(dataset.metadata[k][4])-1] = 0

                texts = texts.to(device)
                mels = mels.to(device)
                embeds = embeds.to(device)
                stop = stop.to(device)

                # Forward pass
                # Parallelize model onto GPUS using workaround due to python bug
                if device.type == "cuda" and torch.cuda.device_count() > 1:
                    m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts,
                                                                                    mels, embeds)
                else:
                    m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)

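                # m1_hat is the decoder (pre-postnet) mel and m2_hat the postnet
                # output; both are trained against the same target, plus a
                # binary cross-entropy loss on the stop-token prediction.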
                # Backward pass
                m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
                m2_loss = F.mse_loss(m2_hat, mels)
                stop_loss = F.binary_cross_entropy(stop_pred, stop)

                loss = m1_loss + m2_loss + stop_loss

                optimizer.zero_grad()
                loss.backward()

                if hparams.tts_clip_grad_norm is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
                    if np.isnan(grad_norm.cpu()):
                        print("grad_norm was NaN!")

                optimizer.step()

                time_window.append(time.time() - start_time)
                loss_window.append(loss.item())

                step = model.get_step()
                k = step // 1000

                msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
                stream(msg)

                # Backup or save model as appropriate
                if backup_every != 0 and step % backup_every == 0:
                    backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
                    model.save(backup_fpath, optimizer)

                if save_every != 0 and step % save_every == 0:
                    # Must save latest optimizer state to ensure that resuming training
                    # doesn't produce artifacts
                    model.save(weights_fpath, optimizer)

                # Evaluate model to generate samples
                epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch  # If epoch is done
                step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0  # Every N steps
                if epoch_eval or step_eval:
                    for sample_idx in range(hparams.tts_eval_num_samples):
                        # At most, generate samples equal to number in the batch
                        if sample_idx + 1 <= len(texts):
                            # Remove padding from mels using frame length in metadata
                            mel_length = int(dataset.metadata[idx[sample_idx]][4])
                            mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
                            target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
                            attention_len = mel_length // model.r

                            eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
                                       mel_prediction=mel_prediction,
                                       target_spectrogram=target_spectrogram,
                                       input_seq=np_now(texts[sample_idx]),
                                       step=step,
                                       plot_dir=plot_dir,
                                       mel_output_dir=mel_output_dir,
                                       wav_dir=wav_dir,
                                       sample_num=sample_idx + 1,
                                       loss=loss,
                                       hparams=hparams)

                # Break out of loop to update training schedule
                if step >= max_step:
                    break

            # Add line break after every epoch
            print("")

def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
               plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
    # Save some results for evaluation
    attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
    save_attention(attention, attention_path)

    # save predicted mel spectrogram to disk (debug)
    mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
    np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False)

    # save griffin lim inverted wav for debug (mel -> wav)
    wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
    wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num))
    audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate)

    # save real and predicted mel-spectrogram plot to disk (control purposes)
    spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
    title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
    plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
                     target_spectrogram=target_spectrogram,
                     max_len=target_spectrogram.size // hparams.num_mels)
    print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
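
For reference, a minimal sketch of how this entry point might be invoked; the wrapper code, the hparams import path, and the directory names are assumptions for illustration, not part of this commit:

    from synthesizer.hparams import hparams  # assumed hparams location
    from synthesizer.train import train

    train(run_id="my_run",                        # checkpoints under <models_dir>/my_run
          syn_dir="datasets/SV2TTS/synthesizer",  # dir containing train.txt, mels/, embeds/
          models_dir="saved_models",
          save_every=1000,                        # overwrite main checkpoint every 1k steps
          backup_every=25000,                     # keep a numbered backup every 25k steps
          force_restart=False,
          hparams=hparams)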