keithhon committed
Commit 62bcf42
1 Parent(s): 6904995

Upload synthesizer/train.py with huggingface_hub

Files changed (1)
  1. synthesizer/train.py +269 -0
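
The commit message says the file was pushed with huggingface_hub. A minimal sketch of such an upload is shown below; the repo id is a hypothetical placeholder, and the call assumes an authenticated environment (e.g. a token from `huggingface-cli login`).

    # Sketch only: repo_id is hypothetical, not taken from this commit.
    from huggingface_hub import HfApi

    api = HfApi()
    api.upload_file(
        path_or_fileobj="synthesizer/train.py",   # local file to push
        path_in_repo="synthesizer/train.py",      # destination path in the repo
        repo_id="<user>/<repo>",                  # hypothetical target repository
        commit_message="Upload synthesizer/train.py with huggingface_hub",
    )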
synthesizer/train.py ADDED
@@ -0,0 +1,269 @@
import torch
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from synthesizer import audio
from synthesizer.models.tacotron import Tacotron
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
from synthesizer.utils import ValueWindow, data_parallel_workaround
from synthesizer.utils.plot import plot_spectrogram
from synthesizer.utils.symbols import symbols
from synthesizer.utils.text import sequence_to_text
from vocoder.display import *
from datetime import datetime
import numpy as np
from pathlib import Path
import sys
import time
import platform


def np_now(x: torch.Tensor): return x.detach().cpu().numpy()

def time_string():
    return datetime.now().strftime("%Y-%m-%d %H:%M")

def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
          backup_every: int, force_restart: bool, hparams):

    syn_dir = Path(syn_dir)
    models_dir = Path(models_dir)
    models_dir.mkdir(exist_ok=True)

    model_dir = models_dir.joinpath(run_id)
    plot_dir = model_dir.joinpath("plots")
    wav_dir = model_dir.joinpath("wavs")
    mel_output_dir = model_dir.joinpath("mel-spectrograms")
    meta_folder = model_dir.joinpath("metas")
    model_dir.mkdir(exist_ok=True)
    plot_dir.mkdir(exist_ok=True)
    wav_dir.mkdir(exist_ok=True)
    mel_output_dir.mkdir(exist_ok=True)
    meta_folder.mkdir(exist_ok=True)

    weights_fpath = model_dir.joinpath(run_id).with_suffix(".pt")
    metadata_fpath = syn_dir.joinpath("train.txt")

    print("Checkpoint path: {}".format(weights_fpath))
    print("Loading training data from: {}".format(metadata_fpath))
    print("Using model: Tacotron")

    # Bookkeeping
    step = 0
    time_window = ValueWindow(100)
    loss_window = ValueWindow(100)


    # From WaveRNN/train_tacotron.py
    if torch.cuda.is_available():
        device = torch.device("cuda")

        for session in hparams.tts_schedule:
            _, _, _, batch_size = session
            if batch_size % torch.cuda.device_count() != 0:
                raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
    else:
        device = torch.device("cpu")
    print("Using device:", device)

    # Instantiate Tacotron Model
    print("\nInitialising Tacotron Model...\n")
    model = Tacotron(embed_dims=hparams.tts_embed_dims,
                     num_chars=len(symbols),
                     encoder_dims=hparams.tts_encoder_dims,
                     decoder_dims=hparams.tts_decoder_dims,
                     n_mels=hparams.num_mels,
                     fft_bins=hparams.num_mels,
                     postnet_dims=hparams.tts_postnet_dims,
                     encoder_K=hparams.tts_encoder_K,
                     lstm_dims=hparams.tts_lstm_dims,
                     postnet_K=hparams.tts_postnet_K,
                     num_highways=hparams.tts_num_highways,
                     dropout=hparams.tts_dropout,
                     stop_threshold=hparams.tts_stop_threshold,
                     speaker_embedding_size=hparams.speaker_embedding_size).to(device)

    # Initialize the optimizer
    optimizer = optim.Adam(model.parameters())

    # Load the weights
    if force_restart or not weights_fpath.exists():
        print("\nStarting the training of Tacotron from scratch\n")
        model.save(weights_fpath)

        # Embeddings metadata
        char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
        with open(char_embedding_fpath, "w", encoding="utf-8") as f:
            for symbol in symbols:
                if symbol == " ":
                    symbol = "\\s"  # For visual purposes, swap space with \s

                f.write("{}\n".format(symbol))

    else:
        print("\nLoading weights at %s" % weights_fpath)
        model.load(weights_fpath, optimizer)
        print("Tacotron weights loaded from step %d" % model.step)

    # Initialize the dataset
    metadata_fpath = syn_dir.joinpath("train.txt")
    mel_dir = syn_dir.joinpath("mels")
    embed_dir = syn_dir.joinpath("embeds")
    dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
    test_loader = DataLoader(dataset,
                             batch_size=1,
                             shuffle=True,
                             pin_memory=True)

    for i, session in enumerate(hparams.tts_schedule):
        current_step = model.get_step()

        r, lr, max_step, batch_size = session

        training_steps = max_step - current_step

        # Do we need to change to the next session?
        if current_step >= max_step:
            # Are there no further sessions than the current one?
            if i == len(hparams.tts_schedule) - 1:
                # We have completed training. Save the model and exit
                model.save(weights_fpath, optimizer)
                break
            else:
                # There is a following session, go to it
                continue

        model.r = r

        # Begin the training
        simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
                      ("Batch Size", batch_size),
                      ("Learning Rate", lr),
                      ("Outputs/Step (r)", model.r)])

        for p in optimizer.param_groups:
            p["lr"] = lr

        data_loader = DataLoader(dataset,
                                 collate_fn=lambda batch: collate_synthesizer(batch, r, hparams),
                                 batch_size=batch_size,
                                 num_workers=2 if platform.system() != "Windows" else 0,
                                 shuffle=True,
                                 pin_memory=True)

        total_iters = len(dataset)
        steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
        epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)

        for epoch in range(1, epochs + 1):
            for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
                start_time = time.time()

                # Generate stop tokens for training
                stop = torch.ones(mels.shape[0], mels.shape[2])
                for j, k in enumerate(idx):
                    stop[j, :int(dataset.metadata[k][4]) - 1] = 0

                texts = texts.to(device)
                mels = mels.to(device)
                embeds = embeds.to(device)
                stop = stop.to(device)

                # Forward pass
                # Parallelize model onto GPUS using workaround due to python bug
                if device.type == "cuda" and torch.cuda.device_count() > 1:
                    m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts,
                                                                                    mels, embeds)
                else:
                    m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)

                # Backward pass
                m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
                m2_loss = F.mse_loss(m2_hat, mels)
                stop_loss = F.binary_cross_entropy(stop_pred, stop)

                loss = m1_loss + m2_loss + stop_loss

                optimizer.zero_grad()
                loss.backward()

                if hparams.tts_clip_grad_norm is not None:
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
                    if np.isnan(grad_norm.cpu()):
                        print("grad_norm was NaN!")

                optimizer.step()

                time_window.append(time.time() - start_time)
                loss_window.append(loss.item())

                step = model.get_step()
                k = step // 1000

                msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
                stream(msg)

                # Backup or save model as appropriate
                if backup_every != 0 and step % backup_every == 0:
                    backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
                    model.save(backup_fpath, optimizer)

                if save_every != 0 and step % save_every == 0:
                    # Must save latest optimizer state to ensure that resuming training
                    # doesn't produce artifacts
                    model.save(weights_fpath, optimizer)

                # Evaluate model to generate samples
                epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch  # If epoch is done
                step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0  # Every N steps
                if epoch_eval or step_eval:
                    for sample_idx in range(hparams.tts_eval_num_samples):
                        # At most, generate samples equal to number in the batch
                        if sample_idx + 1 <= len(texts):
                            # Remove padding from mels using frame length in metadata
                            mel_length = int(dataset.metadata[idx[sample_idx]][4])
                            mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
                            target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
                            attention_len = mel_length // model.r

                            eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
                                       mel_prediction=mel_prediction,
                                       target_spectrogram=target_spectrogram,
                                       input_seq=np_now(texts[sample_idx]),
                                       step=step,
                                       plot_dir=plot_dir,
                                       mel_output_dir=mel_output_dir,
                                       wav_dir=wav_dir,
                                       sample_num=sample_idx + 1,
                                       loss=loss,
                                       hparams=hparams)

                # Break out of loop to update training schedule
                if step >= max_step:
                    break

            # Add line break after every epoch
            print("")

def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
               plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
    # Save some results for evaluation
    attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
    save_attention(attention, attention_path)

    # save predicted mel spectrogram to disk (debug)
    mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
    np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False)

    # save griffin lim inverted wav for debug (mel -> wav)
    wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
    wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num))
    audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate)

    # save real and predicted mel-spectrogram plot to disk (control purposes)
    spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
    title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
    plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
                     target_spectrogram=target_spectrogram,
                     max_len=target_spectrogram.size // hparams.num_mels)
    print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))