Respair committed on
Commit
34296dd
1 Parent(s): ebb7bde

Upload 2 files

Files changed (2)
  1. accelerate_train_second.py +1001 -0
  2. train_first.py +459 -0
accelerate_train_second.py ADDED
@@ -0,0 +1,1001 @@
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import traceback
15
+ import warnings
+ import os
+ import os.path as osp
+ import copy  # os / osp / copy are referenced later in this script
16
+
17
+ warnings.simplefilter('ignore')
18
+ from autoclip.torch import QuantileClip
19
+ from meldataset import build_dataloader
20
+
21
+ from Utils.ASR.models import ASRCNN
22
+ from Utils.JDC.model import JDCNet
23
+ from Utils.PLBERT.util import load_plbert
24
+
25
+ from models import *
26
+ from losses import *
27
+ from utils import *
28
+
29
+ from Modules.slmadv import SLMAdversarialLoss
30
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
31
+
32
+ from optimizers import build_optimizer
33
+
34
+ from accelerate import Accelerator, DistributedDataParallelKwargs
35
+ from accelerate.utils import tqdm, ProjectConfiguration
36
+
37
+ try:
38
+ import wandb
39
+ except ImportError:
40
+ wandb = None
41
+
42
+ # from Utils.fsdp_patch import replace_fsdp_state_dict_type
43
+
44
+ # replace_fsdp_state_dict_type()
45
+
46
+ import logging
47
+
48
+ from accelerate.logging import get_logger
49
+ from logging import StreamHandler
50
+
51
+ logger = get_logger(__name__)
52
+ logger.setLevel(logging.DEBUG)
53
+ # handler.setLevel(logging.DEBUG)
54
+ # logger.addHandler(handler)
55
+
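+ # Typical invocation (assuming the standard Accelerate CLI and the default config path):
+ #   accelerate launch accelerate_train_second.py --config_path Configs/config.yml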
56
+ @click.command()
57
+ @click.option('-p', '--config_path', default='Configs/config.yml', type=str)
58
+ def main(config_path):
59
+ config = yaml.safe_load(open(config_path))
60
+
61
+ log_dir = config['log_dir']
62
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
63
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
64
+
65
+ # write logs
66
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
67
+ file_handler.setLevel(logging.DEBUG)
68
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
69
+ logger.logger.addHandler(file_handler)
70
+
71
+ batch_size = config.get('batch_size', 10)
72
+
73
+ epochs = config.get('epochs_2nd', 200)
74
+ save_freq = config.get('save_freq', 2)
75
+ save_iter = 10000
76
+ log_interval = 10
77
+ saving_epoch = config.get('save_freq', 2)
78
+
79
+ data_params = config.get('data_params', None)
80
+ sr = config['preprocess_params'].get('sr', 24000)
81
+ hop = config['preprocess_params']["spect_params"].get('hop_length', 300)
82
+ win = config['preprocess_params']["spect_params"].get('win_length', 1200)
83
+ train_path = data_params['train_data']
84
+ val_path = data_params['val_data']
85
+ root_path = data_params['root_path']
86
+ min_length = data_params['min_length']
87
+ OOD_data = data_params['OOD_data']
88
+
89
+ max_len = config.get('max_len', 200)
90
+
91
+ loss_params = Munch(config['loss_params'])
92
+ diff_epoch = loss_params.diff_epoch
93
+ joint_epoch = loss_params.joint_epoch
94
+
95
+ optimizer_params = Munch(config['optimizer_params'])
96
+
97
+ train_list, val_list = get_data_path_list(train_path, val_path)
98
+
99
+ try:
100
+ tracker = 'tensorboard'
101
+ except KeyError:
102
+ tracker = "mlflow"
103
+
104
+ def log_audio(accelerator, audio, bib="", name="Validation", epoch=0, sr=24000, tracker="tensorboard"):
105
+ if tracker == "tensorboard":
106
+ ltracker = accelerator.get_tracker("tensorboard")
107
+ np_aud = np.stack([np.asarray(aud) for aud in audio])
108
+ ltracker.writer.add_audio(f"{name}-{bib}", np_aud, epoch, sample_rate=sr)
109
+ if tracker == "wandb":
110
+ try:
111
+ ltracker = accelerator.get_tracker("wandb")
112
+ ltracker.log(
113
+ {
114
+ "validation": [
115
+ wandb.Audio(audios, caption=f"{name}-{bib}", sample_rate=sr)
116
+ for i, audios in enumerate(audio)
117
+ ]
118
+ }
119
+ , step=int(bib))
120
+ except IndexError:
121
+ pass
122
+
123
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True, broadcast_buffers=False)
124
+ configAcc = ProjectConfiguration(project_dir=log_dir, logging_dir=log_dir)
125
+ accelerator = Accelerator(log_with=tracker,
126
+ project_config=configAcc,
127
+ split_batches=True,
128
+ kwargs_handlers=[ddp_kwargs],
129
+ mixed_precision='bf16')
130
+
131
+ accelerator.init_trackers(project_name="StyleTTS2-Second-Stage",
132
+ config=config if tracker == "wandb" else None)
133
+ HF = config["data_params"].get("HF", False)
134
+ name = config["data_params"].get("split", None)
135
+ split = config["data_params"].get("split", None)
136
+ val_split = config["data_params"].get("val_split", None)
137
+ ood_split = config["data_params"].get("OOD_split", None)
138
+ audcol = config["data_params"].get("audio_column", "speech")
139
+ phoncol = config["data_params"].get("phoneme_column", "phoneme")
140
+ specol = config["data_params"].get("speaker_column", "speaker ID")
141
+
142
+ if not HF:
143
+ train_list, val_list = get_data_path_list(train_path, val_path)
144
+ ds_conf = {"sr": sr, "hop": hop, "win": win}
145
+ vds_conf = {"sr": sr, "hop": hop, "win": win}
146
+ else:
147
+ train_list, val_list = train_path, val_path
148
+ ds_conf = {"sr": sr,
149
+ "hop": hop,
150
+ "split": split,
151
+ "OOD_split": ood_split,
152
+ "dataset_name": name,
153
+ "audio_column": audcol,
154
+ "phoneme_column": phoncol,
155
+ "speaker_id_column": specol,
156
+ "win": win}
157
+ vds_conf = {"sr": sr,
158
+ "hop": hop,
159
+ "split": val_split,
160
+ "OOD_split": ood_split,
161
+ "dataset_name": name,
162
+ "audio_column": audcol,
163
+ "phoneme_column": phoncol,
164
+ "speaker_id_column": specol,
165
+ "win": win}
166
+ device = accelerator.device
167
+
168
+ with accelerator.main_process_first():
169
+ train_dataloader = build_dataloader(train_list,
170
+ root_path,
171
+ OOD_data=OOD_data,
172
+ min_length=min_length,
173
+ batch_size=batch_size,
174
+ num_workers=2,
175
+ dataset_config={},
176
+ device=device)
177
+
178
+ val_dataloader = build_dataloader(val_list,
179
+ root_path,
180
+ OOD_data=OOD_data,
181
+ min_length=min_length,
182
+ batch_size=batch_size,
183
+ validation=True,
184
+ num_workers=0,
185
+ device=device,
186
+ dataset_config={})
187
+
188
+ # load pretrained ASR model
189
+ ASR_config = config.get('ASR_config', False)
190
+ ASR_path = config.get('ASR_path', False)
191
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
192
+
193
+ # load pretrained F0 model
194
+ F0_path = config.get('F0_path', False)
195
+ pitch_extractor = load_F0_models(F0_path)
196
+
197
+ # load PL-BERT model
198
+ BERT_path = config.get('PLBERT_dir', False)
199
+ plbert = load_plbert(BERT_path)
200
+
201
+ # build model
202
+ config['model_params']["sr"] = sr
203
+
204
+ model_params = recursive_munch(config['model_params'])
205
+ multispeaker = model_params.multispeaker
206
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
207
+ _ = [model[key].to(device) for key in model]
208
+
209
+ # # # DP
210
+ # for key in model:
211
+ # if key != "mpd" and key != "msd" and key != "wd":
212
+ # model[key] = accelerator.prepare(model[key])
213
+
214
+
215
+ # for k in model:
216
+ # model[k] = nn.SyncBatchNorm.convert_sync_batchnorm(model[k])
217
+
218
+ for k in model:
219
+ model[k] = accelerator.prepare(model[k])
220
+
221
+ start_epoch = 0
222
+ iters = 0
223
+
224
+ load_pretrained = config.get('pretrained_model', '') != '' and config.get('second_stage_load_pretrained', False)
225
+
226
+ if not load_pretrained:
227
+ if config.get('first_stage_path', '') != '':
228
+ first_stage_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
229
+ accelerator.print('Loading the first stage model at %s ...' % first_stage_path)
230
+ model, _, start_epoch, iters = load_checkpoint(model,
231
+ None,
232
+ first_stage_path,
233
+ load_only_params=True,
234
+ ignore_modules=['bert', 'bert_encoder', 'predictor',
235
+ 'predictor_encoder', 'msd', 'mpd', 'wd',
236
+ 'diffusion']) # keep starting epoch for tensorboard log
237
+
238
+ # these epochs should be counted from the start epoch
239
+ diff_epoch += start_epoch
240
+ joint_epoch += start_epoch
241
+ epochs += start_epoch
242
+ model.style_encoder.train()
243
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
244
+ else:
245
+ raise ValueError('You need to specify the path to the first stage model.')
246
+
247
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
248
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
249
+ wl = WavLMLoss(model_params.slm.model,
250
+ model.wd,
251
+ sr,
252
+ model_params.slm.sr).to(device)
253
+
254
+ gl = accelerator.prepare(gl)
255
+ dl = accelerator.prepare(dl)
256
+ wl = accelerator.prepare(wl)
257
+ wl = wl.eval()
258
+
259
+ sampler = DiffusionSampler(
260
+ model.diffusion.module.diffusion,
261
+ sampler=ADPM2Sampler(),
262
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
263
+ clamp=False
264
+ )
265
+
266
+ scheduler_params = {
267
+ "max_lr": optimizer_params.lr * accelerator.num_processes,
268
+ "pct_start": float(0),
269
+ "epochs": epochs,
270
+ "steps_per_epoch": len(train_dataloader),
271
+ }
272
+ scheduler_params_dict = {key: scheduler_params.copy() for key in model}
273
+ scheduler_params_dict['bert']['max_lr'] = optimizer_params.bert_lr * 2
274
+ scheduler_params_dict['decoder']['max_lr'] = optimizer_params.ft_lr * 2
275
+ scheduler_params_dict['style_encoder']['max_lr'] = optimizer_params.ft_lr * 2
276
+
277
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
278
+ scheduler_params_dict=scheduler_params_dict,
279
+ lr=optimizer_params.lr * accelerator.num_processes)
280
+
281
+ # adjust BERT learning rate
282
+ for g in optimizer.optimizers['bert'].param_groups:
283
+ g['betas'] = (0.9, 0.99)
284
+ g['lr'] = optimizer_params.bert_lr
285
+ g['initial_lr'] = optimizer_params.bert_lr
286
+ g['min_lr'] = 0
287
+ g['weight_decay'] = 0.01
288
+
289
+ # adjust acoustic module learning rate
290
+ for module in ["decoder", "style_encoder"]:
291
+ for g in optimizer.optimizers[module].param_groups:
292
+ g['betas'] = (0.0, 0.99)
293
+ g['lr'] = optimizer_params.ft_lr
294
+ g['initial_lr'] = optimizer_params.ft_lr
295
+ g['min_lr'] = 0
296
+ g['weight_decay'] = 1e-4
297
+
298
+ # load models if there is a model
299
+ if load_pretrained:
300
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
301
+ load_only_params=config.get('load_only_params', True))
302
+
303
+ n_down = model.text_aligner.module.n_down
304
+
305
+ # for k in model:
306
+ # model[k] = accelerator.prepare(model[k])
307
+
308
+ best_loss = float('inf') # best test loss
309
+ iters = 0
310
+
311
+ criterion = nn.L1Loss() # F0 loss (regression)
312
+ torch.cuda.empty_cache()
313
+
314
+ stft_loss = MultiResolutionSTFTLoss().to(device)
315
+
316
+ accelerator.print('BERT', optimizer.optimizers['bert'])
317
+ accelerator.print('decoder', optimizer.optimizers['decoder'])
318
+
319
+ start_ds = False
320
+
321
+ running_std = []
322
+
323
+ slmadv_params = Munch(config['slmadv_params'])
324
+
325
+ slmadv = SLMAdversarialLoss(model, wl, sampler,
326
+ slmadv_params.min_len,
327
+ slmadv_params.max_len,
328
+ batch_percentage=slmadv_params.batch_percentage,
329
+ skip_update=slmadv_params.iter,
330
+ sig=slmadv_params.sig
331
+ )
332
+
333
+ for k, v in optimizer.optimizers.items():
334
+ optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
335
+ optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])
336
+
337
+ train_dataloader = accelerator.prepare(train_dataloader)
338
+
339
+ for epoch in range(start_epoch, epochs):
340
+ running_loss = 0
341
+ start_time = time.time()
342
+
343
+ _ = [model[key].eval() for key in model]
344
+
345
+ model.text_aligner.train()
346
+ model.text_encoder.train()
347
+
348
+ model.predictor.train()
349
+ model.predictor_encoder.train()
350
+ model.bert_encoder.train()
351
+ model.bert.train()
352
+ model.msd.train()
353
+ model.mpd.train()
354
+ model.wd.train()
355
+
356
+ if epoch >= diff_epoch:
357
+ start_ds = True
358
+
359
+ for i, batch in enumerate(train_dataloader):
360
+ waves = batch[0]
361
+ batch = [b.to(device) for b in batch[1:]]
362
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
363
+
364
+ with torch.no_grad():
365
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
366
+ mel_mask = length_to_mask(mel_input_length).to(device)
367
+ text_mask = length_to_mask(input_lengths).to(texts.device)
368
+
369
+ try:
370
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
371
+ s2s_attn = s2s_attn.transpose(-1, -2)
372
+ s2s_attn = s2s_attn[..., 1:]
373
+ s2s_attn = s2s_attn.transpose(-1, -2)
374
+ except Exception:
375
+ continue
376
+
377
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
378
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
379
+
380
+ # encode
381
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
382
+ asr = (t_en @ s2s_attn_mono)
383
+
384
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
385
+
386
+ # compute reference styles
387
+ if multispeaker and epoch >= diff_epoch:
388
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
389
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
390
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
391
+
392
+ # compute the style of the entire utterance
393
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
394
+ ss = []
395
+ gs = []
396
+ for bib in range(len(mel_input_length)):
397
+ mel_length = int(mel_input_length[bib].item())
398
+ mel = mels[bib, :, :mel_input_length[bib]]
399
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
400
+ ss.append(s)
401
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
402
+ gs.append(s)
403
+
404
+ s_dur = torch.stack(ss).squeeze(1) # global prosodic styles
405
+ gs = torch.stack(gs).squeeze(1) # global acoustic styles
406
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
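+ # The per-utterance loop above trims each mel to its true length before the style
+ # encoders because their average pooling is not length-masked; a masked/length-aware
+ # pooling would allow a single batched call, as the comment above suggests.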
407
+
408
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
409
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
410
+
411
+ # denoiser training
412
+ if epoch >= diff_epoch:
413
+ num_steps = np.random.randint(3, 5)
414
+
415
+ if model_params.diffusion.dist.estimate_sigma_data:
416
+ model.diffusion.module.diffusion.sigma_data = s_trg.std(
417
+ axis=-1).mean().item() # batch-wise std estimation
418
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
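+ # sigma_data is re-estimated from the std of the current batch's target styles; the
+ # running values are averaged and written back into the config when checkpoints are saved.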
419
+
420
+ if multispeaker:
421
+ s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
422
+ embedding=bert_dur,
423
+ embedding_scale=1,
424
+ features=ref, # reference from the same speaker as the embedding
425
+ embedding_mask_proba=0.1,
426
+ num_steps=num_steps).squeeze(1)
427
+ loss_diff = model.diffusion(s_trg.unsqueeze(1), embedding=bert_dur, features=ref).mean() # EDM loss
428
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
429
+ else:
430
+ s_preds = sampler(noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
431
+ embedding=bert_dur,
432
+ embedding_scale=1,
433
+ embedding_mask_proba=0.1,
434
+ num_steps=num_steps).squeeze(1)
435
+ loss_diff = model.diffusion.module.diffusion(s_trg.unsqueeze(1),
436
+ embedding=bert_dur).mean() # EDM loss
437
+ loss_sty = F.l1_loss(s_preds, s_trg.detach()) # style reconstruction loss
438
+ # print(loss_sty)
439
+ else:
440
+ # print("here")
441
+ loss_sty = 0
442
+ loss_diff = 0
443
+
444
+ d, p = model.predictor(d_en, s_dur,
445
+ input_lengths,
446
+ s2s_attn_mono,
447
+ text_mask)
448
+
449
+ # mel_len = int(mel_input_length.min().item() / 2 - 1)
450
+
451
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
452
+ mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
453
+
454
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
455
+ en = []
456
+ gt = []
457
+ st = []
458
+ p_en = []
459
+ wav = []
460
+
461
+ for bib in range(len(mel_input_length)):
462
+ mel_length = int(mel_input_length[bib].item() / 2)
463
+
464
+ random_start = np.random.randint(0, mel_length - mel_len)
465
+ en.append(asr[bib, :, random_start:random_start + mel_len])
466
+ p_en.append(p[bib, :, random_start:random_start + mel_len])
467
+ gt.append(mels[bib, :, (random_start * 2):((random_start + mel_len) * 2)])
468
+
469
+ y = waves[bib][(random_start * 2) * 300:((random_start + mel_len) * 2) * 300]
470
+ wav.append(torch.from_numpy(y).to(device))
471
+
472
+ # style reference (better to be different from the GT)
473
+ random_start = np.random.randint(0, mel_length - mel_len_st)
474
+ st.append(mels[bib, :, (random_start * 2):((random_start + mel_len_st) * 2)])
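+ # Index math for the clips above: asr/p are at half the mel frame rate (hence * 2),
+ # and each mel frame spans hop_length waveform samples; the hard-coded 300 assumes
+ # hop_length == 300, the default read from the config.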
475
+
476
+ wav = torch.stack(wav).float().detach()
477
+
478
+ en = torch.stack(en)
479
+ p_en = torch.stack(p_en)
480
+ gt = torch.stack(gt).detach()
481
+ st = torch.stack(st).detach()
482
+
483
+ if gt.size(-1) < 80:
484
+ continue
485
+
486
+ s_dur = model.predictor_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
487
+ s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
488
+
489
+ with torch.no_grad():
490
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
491
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2])
492
+
493
+ asr_real = model.text_aligner.module.get_feature(gt)
494
+
495
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
496
+
497
+ y_rec_gt = wav.unsqueeze(1)
498
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
499
+
500
+ if epoch >= joint_epoch:
501
+ # ground truth from recording
502
+ wav = y_rec_gt # use recording since decoder is tuned
503
+ else:
504
+ # ground truth from reconstruction
505
+ wav = y_rec_gt_pred # use reconstruction since decoder is fixed
506
+
507
+ F0_fake, N_fake = model.predictor(texts=p_en, style=s_dur, f0=True)
508
+
509
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
510
+
511
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
512
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
513
+
514
+ if start_ds:
515
+ optimizer.zero_grad()
516
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
517
+ accelerator.backward(d_loss)
518
+ optimizer.step('msd')
519
+ optimizer.step('mpd')
520
+ else:
521
+ d_loss = 0
522
+
523
+ # generator loss
524
+ optimizer.zero_grad()
525
+
526
+ loss_mel = stft_loss(y_rec, wav)
527
+ if start_ds:
528
+ loss_gen_all = gl(wav, y_rec).mean()
529
+ else:
530
+ loss_gen_all = 0
531
+ loss_lm = wl(wav.detach().squeeze(1), y_rec.squeeze(1)).mean()
532
+
533
+ loss_ce = 0
534
+ loss_dur = 0
535
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
536
+ _s2s_pred = _s2s_pred[:_text_length, :]
537
+ _text_input = _text_input[:_text_length].long()
538
+ _s2s_trg = torch.zeros_like(_s2s_pred)
539
+ for p in range(_s2s_trg.shape[0]):
540
+ _s2s_trg[p, :_text_input[p]] = 1
541
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
542
+
543
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length - 1],
544
+ _text_input[1:_text_length - 1])
545
+ loss_ce += F.binary_cross_entropy_with_logits(_s2s_pred.flatten(), _s2s_trg.flatten())
546
+
547
+ loss_ce /= texts.size(0)
548
+ loss_dur /= texts.size(0)
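+ # In the loop above, row p of _s2s_trg is a run of d_gt[p] ones (the aligner's duration
+ # for phoneme p); BCE pushes the predictor logits toward that run, and
+ # sigmoid(logits).sum(axis=1) recovers the duration used for the L1 term, with the
+ # first and last tokens excluded.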
549
+
550
+ g_loss = loss_params.lambda_mel * loss_mel + \
551
+ loss_params.lambda_F0 * loss_F0_rec + \
552
+ loss_params.lambda_ce * loss_ce + \
553
+ loss_params.lambda_norm * loss_norm_rec + \
554
+ loss_params.lambda_dur * loss_dur + \
555
+ loss_params.lambda_gen * loss_gen_all + \
556
+ loss_params.lambda_slm * loss_lm + \
557
+ loss_params.lambda_sty * loss_sty + \
558
+ loss_params.lambda_diff * loss_diff
559
+
560
+ running_loss += accelerator.gather(loss_mel).mean().item()
561
+ accelerator.backward(g_loss)
562
+
563
+
564
+
565
+ # clipper_bert_enc = QuantileClip(model.bert_encoder.parameters(), quantile=0.9, history_length=1000) # Adaptive clipping of gradient
566
+ # clipper_bert = QuantileClip(model.bert.parameters(), quantile=0.9, history_length=1000)
567
+ # clipper_pred = QuantileClip(model.predictor.parameters(), quantile=0.9, history_length=1000)
568
+ # clipper_pred_enc = QuantileClip(model.predictor_encoder.parameters(), quantile=0.9, history_length=1000)
569
+
570
+ # accelerator.clip_grad_norm_(model.bert_encoder.parameters(), max_norm=2.0)
571
+ # accelerator.clip_grad_norm_(model.bert.parameters(), max_norm=2.0)
572
+ # accelerator.clip_grad_norm_(model.predictor.parameters(), max_norm=2.0)
573
+ # accelerator.clip_grad_norm_(model.predictor_encoder.parameters(), max_norm=2.0)
574
+
575
+ # if iters % 10 == 0: # Monitor every 10 steps
576
+ # components = ['bert_encoder', 'bert', 'predictor', 'predictor_encoder']
577
+ # if epoch >= diff_epoch:
578
+ # components.append('diffusion')
579
+
580
+ # for key in components:
581
+ # if key in model:
582
+ # grad_norm = accelerator.clip_grad_norm_(model[key].parameters(), float('inf'))
583
+ # accelerator.print(f"key: {key} grad norm: {grad_norm:.4f}")
584
+
585
+
586
+ # if torch.isnan(g_loss):
587
+ # from IPython.core.debugger import set_trace
588
+ # set_trace()
589
+
590
+
591
+
592
+ # clipper_bert_enc.step()
593
+ # clipper_bert.step()
594
+ # clipper_pred.step()
595
+ # clipper_pred_enc.step()
596
+
597
+ optimizer.step('bert_encoder')
598
+ optimizer.step('bert')
599
+ optimizer.step('predictor')
600
+ optimizer.step('predictor_encoder')
601
+
602
+ if epoch >= diff_epoch:
603
+ # accelerator.clip_grad_norm_(model.diffusion.parameters(), max_norm=1.0)
604
+ optimizer.step('diffusion')
605
+
606
+ if epoch >= joint_epoch:
607
+
608
+ optimizer.step('style_encoder')
609
+ optimizer.step('decoder')
610
+
611
+ d_loss_slm, loss_gen_lm = 0, 0
612
+
613
+ # # randomly pick whether to use in-distribution text
614
+ # if np.random.rand() < 0.5:
615
+ # use_ind = True
616
+ # else:
617
+ # use_ind = False
618
+
619
+ # if use_ind:
620
+ # ref_lengths = input_lengths
621
+ # ref_texts = texts
622
+
623
+ # slm_out = slmadv(i,
624
+ # y_rec_gt,
625
+ # y_rec_gt_pred,
626
+ # waves,
627
+ # mel_input_length,
628
+ # ref_texts,
629
+ # ref_lengths, use_ind, s_trg.detach(), ref if multispeaker else None)
630
+
631
+
632
+
633
+ # if slm_out is None:
634
+ # continue
635
+ # # if slm_out is not None:
636
+ # # d_loss_slm, loss_gen_lm, y_pred = slm_out
637
+ # # optimizer.zero_grad()
638
+ # # # accelerator.clip_grad_norm_(model.decoder.parameters(), 1)
639
+ # # # print("here")
640
+ # # accelerator.backward(loss_gen_lm)
641
+ # # # print("here2")
642
+ # # # SLM discriminator loss
643
+
644
+
645
+ # # # compute the gradient norm
646
+
647
+
648
+ # # total_norm = {}
649
+ # # for key in model.keys():
650
+ # # total_norm[key] = 0
651
+ # # parameters = [p for p in model[key].parameters() if p.grad is not None and p.requires_grad]
652
+ # # for p in parameters:
653
+ # # param_norm = p.grad.detach().data.norm(2)
654
+ # # total_norm[key] += param_norm.item() ** 2
655
+ # # total_norm[key] = total_norm[key] ** 0.5
656
+
657
+ # # # gradient scaling
658
+ # # if total_norm['predictor'] > slmadv_params.thresh:
659
+ # # for key in model.keys():
660
+ # # for p in model[key].parameters():
661
+ # # if p.grad is not None:
662
+ # # p.grad *= (1 / total_norm['predictor'])
663
+
664
+ # # for p in model.predictor.module.duration_proj.parameters():
665
+ # # if p.grad is not None:
666
+ # # p.grad *= slmadv_params.scale
667
+
668
+ # # for p in model.predictor.module.lstm.parameters():
669
+ # # if p.grad is not None:
670
+ # # p.grad *= slmadv_params.scale
671
+
672
+ # # for p in model.diffusion.module.parameters():
673
+ # # if p.grad is not None:
674
+ # # p.grad *= slmadv_params.scale
675
+
676
+ # # optimizer.step('bert_encoder')
677
+ # # optimizer.step('bert')
678
+ # # optimizer.step('predictor')
679
+ # # optimizer.step('diffusion')
680
+
681
+ # # # SLM discriminator loss
682
+ # # if d_loss_slm != 0:
683
+ # # optimizer.zero_grad()
684
+ # # # print("hey1")
685
+ # # accelerator.backward(d_loss_slm, retain_graph=True)
686
+ # # optimizer.step('wd')
687
+ # # # print("hey2")
688
+
689
+ else:
690
+ d_loss_slm, loss_gen_lm = 0, 0
691
+
692
+ iters = iters + 1
693
+ if (i + 1) % log_interval == 0:
694
+ logger.info(
695
+ 'Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
696
+ % (epoch + 1, epochs, i + 1, len(train_list) // batch_size, running_loss / log_interval, d_loss,
697
+ loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff,
698
+ d_loss_slm, loss_gen_lm), main_process_only=True)
699
+ if accelerator.is_main_process:
700
+ print('Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f'
701
+ % (epoch + 1, epochs, i + 1, len(train_list) // batch_size, running_loss / log_interval, d_loss,
702
+ loss_dur, loss_ce, loss_norm_rec, loss_F0_rec, loss_lm, loss_gen_all, loss_sty, loss_diff,
703
+ d_loss_slm, loss_gen_lm))
704
+ accelerator.log({'train/mel_loss': float(running_loss / log_interval),
705
+ 'train/gen_loss': float(loss_gen_all),
706
+ 'train/d_loss': float(d_loss),
707
+ 'train/ce_loss': float(loss_ce),
708
+ 'train/dur_loss': float(loss_dur),
709
+ 'train/slm_loss': float(loss_lm),
710
+ 'train/norm_loss': float(loss_norm_rec),
711
+ 'train/F0_loss': float(loss_F0_rec),
712
+ 'train/sty_loss': float(loss_sty),
713
+ 'train/diff_loss': float(loss_diff),
714
+ 'train/d_loss_slm': float(d_loss_slm),
715
+ 'train/gen_loss_slm': float(loss_gen_lm),
716
+ 'epoch': int(epoch) + 1}, step=iters)
717
+
718
+ running_loss = 0
719
+
720
+ accelerator.print('Time elapsed:', time.time() - start_time)
721
+
722
+ loss_test = 0
723
+ loss_align = 0
724
+ loss_f = 0
725
+
726
+ _ = [model[key].eval() for key in model]
727
+
728
+ with torch.no_grad():
729
+ iters_test = 0
730
+ for batch_idx, batch in enumerate(val_dataloader):
731
+ optimizer.zero_grad()
732
+
733
+ try:
734
+ waves = batch[0]
735
+ batch = [b.to(device) for b in batch[1:]]
736
+ texts, input_lengths, ref_texts, ref_lengths, mels, mel_input_length, ref_mels = batch
737
+ with torch.no_grad():
738
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to(device)
739
+ text_mask = length_to_mask(input_lengths).to(texts.device)
740
+
741
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
742
+ s2s_attn = s2s_attn.transpose(-1, -2)
743
+ s2s_attn = s2s_attn[..., 1:]
744
+ s2s_attn = s2s_attn.transpose(-1, -2)
745
+
746
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
747
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
748
+
749
+ # encode
750
+ # print("t_en", t_en.shape, t_en)
751
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
752
+ asr = (t_en @ s2s_attn_mono)
753
+
754
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
755
+
756
+ ss = []
757
+ gs = []
758
+
759
+ for bib in range(len(mel_input_length)):
760
+ mel_length = int(mel_input_length[bib].item())
761
+ mel = mels[bib, :, :mel_input_length[bib]]
762
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
763
+ ss.append(s)
764
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
765
+ gs.append(s)
766
+
767
+ s = torch.stack(ss).squeeze(1)
768
+ gs = torch.stack(gs).squeeze(1)
769
+ s_trg = torch.cat([s, gs], dim=-1).detach()
770
+ # print("texts", texts.shape, texts)
771
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
772
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
773
+ d, p = model.predictor(d_en, s,
774
+ input_lengths,
775
+ s2s_attn_mono,
776
+ text_mask)
777
+ # get clips
778
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
779
+ en = []
780
+ gt = []
781
+ p_en = []
782
+ wav = []
783
+
784
+ for bib in range(len(mel_input_length)):
785
+ mel_length = int(mel_input_length[bib].item() / 2)
786
+
787
+ random_start = np.random.randint(0, mel_length - mel_len)
788
+ en.append(asr[bib, :, random_start:random_start + mel_len])
789
+ p_en.append(p[bib, :, random_start:random_start + mel_len])
790
+
791
+ gt.append(mels[bib, :, (random_start * 2):((random_start + mel_len) * 2)])
792
+
793
+ y = waves[bib][(random_start * 2) * 300:((random_start + mel_len) * 2) * 300]
794
+ wav.append(torch.from_numpy(y).to(device))
795
+
796
+ wav = torch.stack(wav).float().detach()
797
+
798
+ en = torch.stack(en)
799
+ p_en = torch.stack(p_en)
800
+ gt = torch.stack(gt).detach()
801
+
802
+ s = model.predictor_encoder(gt.unsqueeze(1))
803
+
804
+ F0_fake, N_fake = model.predictor(texts=p_en, style=s, f0=True)
805
+
806
+ loss_dur = 0
807
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
808
+ _s2s_pred = _s2s_pred[:_text_length, :]
809
+ _text_input = _text_input[:_text_length].long()
810
+ _s2s_trg = torch.zeros_like(_s2s_pred)
811
+ for bib in range(_s2s_trg.shape[0]):
812
+ _s2s_trg[bib, :_text_input[bib]] = 1
813
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
814
+ loss_dur += F.l1_loss(_dur_pred[1:_text_length - 1],
815
+ _text_input[1:_text_length - 1])
816
+
817
+ loss_dur /= texts.size(0)
818
+
819
+ s = model.style_encoder(gt.unsqueeze(1))
820
+
821
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
822
+ loss_mel = stft_loss(y_rec.squeeze(1), wav.detach())
823
+
824
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
825
+
826
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
827
+
828
+ loss_test += accelerator.gather(loss_mel).mean()
829
+ loss_align += accelerator.gather(loss_dur).mean()
830
+ loss_f += accelerator.gather(loss_F0).mean()
831
+
832
+ iters_test += 1
833
+ except Exception as e:
834
+ accelerator.print(f"Eval errored with: \n {str(e)}")
835
+ continue
836
+
837
+ accelerator.print('Epochs:', epoch + 1)
838
+ try:
839
+ logger.info('Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f' % (
840
+ loss_test / iters_test, loss_align / iters_test, loss_f / iters_test) + '\n', main_process_only=True)
841
+
842
+
843
+ accelerator.log({'eval/mel_loss': float(loss_test / iters_test),
844
+ 'eval/dur_loss': float(loss_align / iters_test),
845
+ 'eval/F0_loss': float(loss_f / iters_test)},
846
+ step=(i + 1) * (epoch + 1))
847
+ except ZeroDivisionError:
848
+ accelerator.print("Eval loss was divided by zero... skipping eval cycle")
849
+
850
+ if epoch < diff_epoch:
851
+ # generating reconstruction examples with GT duration
852
+
853
+ with torch.no_grad():
854
+ for bib in range(len(asr)):
855
+ mel_length = int(mel_input_length[bib].item())
856
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
857
+ en = asr[bib, :, :mel_length // 2].unsqueeze(0)
858
+
859
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
860
+ F0_real = F0_real.unsqueeze(0)
861
+ s = model.style_encoder(gt.unsqueeze(1))
862
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
863
+
864
+ try:
865
+ y_rec = model.decoder(en, F0_real.squeeze(0), real_norm, s)
866
+ except Exception as e:
867
+ accelerator.print(str(e))
868
+ accelerator.print(F0_real.size())
869
+ accelerator.print(F0_real.squeeze(0).size())
870
+
871
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
872
+ p_en = p[bib, :, :mel_length // 2].unsqueeze(0)
873
+ F0_fake, N_fake = model.predictor(texts=p_en, style=s_dur, f0=True)
874
+
875
+ y_pred = model.decoder(en, F0_fake, N_fake, s)
876
+
877
+ # writer.add_audio('pred/y' + str(bib), y_pred.cpu().numpy().squeeze(), epoch, sample_rate=sr)
878
+ if accelerator.is_main_process:
879
+ log_audio(accelerator, y_pred.detach().cpu().numpy().squeeze(), bib, "pred/y", epoch, sr, tracker=tracker)
880
+
881
+ if epoch == 0:
882
+ # writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)
883
+ if accelerator.is_main_process:
884
+ log_audio(accelerator, waves[bib].squeeze(), bib, "gt/y", epoch, sr, tracker=tracker)
885
+
886
+ if bib >= 10:
887
+ break
888
+ else:
889
+
890
+ try:
891
+ # generating sampled speech from text directly
892
+ with torch.no_grad():
893
+ # compute reference styles
894
+ if multispeaker and epoch >= diff_epoch:
895
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
896
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
897
+ ref_s = torch.cat([ref_ss, ref_sp], dim=1)
898
+
899
+ for bib in range(len(d_en)):
900
+ if multispeaker:
901
+ s_pred = sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(texts.device),
902
+ embedding=bert_dur[bib].unsqueeze(0),
903
+ embedding_scale=1,
904
+ features=ref_s[bib].unsqueeze(0),
905
+ # reference from the same speaker as the embedding
906
+ num_steps=5).squeeze(1)
907
+ else:
908
+ s_pred = sampler(noise=torch.ones((1, 1, 256)).to(texts.device)*0.5,
909
+ embedding=bert_dur[bib].unsqueeze(0),
910
+ embedding_scale=1,
911
+ num_steps=5).squeeze(1)
912
+
913
+ s = s_pred[:, 128:]
914
+ ref = s_pred[:, :128]
915
+ # print(model.predictor)
916
+ # print(d_en[bib, :, :input_lengths[bib]])
917
+ d = model.predictor.module.text_encoder(d_en[bib, :, :input_lengths[bib]].unsqueeze(0),
918
+ s, input_lengths[bib, ...].unsqueeze(0),
919
+ text_mask[bib, :input_lengths[bib]].unsqueeze(0))
920
+
921
+ x = model.predictor.module.lstm(d)
922
+ x_mod = model.predictor.module.prepare_projection(x) # 640 -> 512
923
+ duration = model.predictor.module.duration_proj(x_mod)
924
+
925
+ duration = torch.sigmoid(duration).sum(axis=-1)
926
+ pred_dur = torch.round(duration.squeeze(0)).clamp(min=1)
927
+
928
+ pred_dur[-1] += 5
929
+
930
+ pred_aln_trg = torch.zeros(input_lengths[bib], int(pred_dur.sum().data))
931
+ c_frame = 0
932
+ for i in range(pred_aln_trg.size(0)):
933
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
934
+ c_frame += int(pred_dur[i].data)
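+ # pred_aln_trg is a hard text-to-frame alignment: row i holds pred_dur[i] consecutive
+ # ones starting where the previous phoneme ended, mirroring the monotonic alignment
+ # used during training.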
935
+
936
+ # encode prosody
937
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(texts.device))
938
+ F0_pred, N_pred = model.predictor(texts=en, style=s, f0=True)
939
+ out = model.decoder(
940
+ (t_en[bib, :, :input_lengths[bib]].unsqueeze(0) @ pred_aln_trg.unsqueeze(0).to(texts.device)),
941
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
942
+
943
+ # writer.add_audio('pred/y' + str(bib), out.cpu().numpy().squeeze(), epoch, sample_rate=sr)
944
+ if accelerator.is_main_process:
945
+ log_audio(accelerator, out.detach().cpu().numpy().squeeze(), bib, "pred/y", epoch, sr, tracker=tracker)
946
+
947
+ if bib >= 5:
948
+ break
949
+ except Exception as e:
950
+ accelerator.print('error -> ', e)
951
+ accelerator.print("some of the samples couldn't be evaluated, skipping those.")
952
+
953
+ if epoch % saving_epoch == 0:
954
+ if (loss_test / iters_test) < best_loss:
955
+ best_loss = loss_test / iters_test
956
+ try:
957
+ accelerator.print('Saving..')
958
+ state = {
959
+ 'net': {key: model[key].state_dict() for key in model},
960
+ 'optimizer': optimizer.state_dict(),
961
+ 'iters': iters,
962
+ 'val_loss': loss_test / iters_test,
963
+ 'epoch': epoch,
964
+ }
965
+ except ZeroDivisionError:
966
+ accelerator.print('No iter test, Re-Saving..')
967
+ state = {
968
+ 'net': {key: model[key].state_dict() for key in model},
969
+ 'optimizer': optimizer.state_dict(),
970
+ 'iters': iters,
971
+ 'val_loss': 0.1, # not zero just in case
972
+ 'epoch': epoch,
973
+ }
974
+
975
+ if accelerator.is_main_process:
976
+ save_path = osp.join(log_dir, 'epoch_2nd_%05d.pth' % epoch)
977
+ torch.save(state, save_path)
978
+
979
+ # if estimating sigma, save the estimated sigma
980
+ if model_params.diffusion.dist.estimate_sigma_data:
981
+ config['model_params']['diffusion']['dist']['sigma_data'] = float(np.mean(running_std))
982
+
983
+ with open(osp.join(log_dir, osp.basename(config_path)), 'w') as outfile:
984
+ yaml.dump(config, outfile, default_flow_style=True)
985
+ if accelerator.is_main_process:
986
+ print('Saving last pth..')
987
+ state = {
988
+ 'net': {key: model[key].state_dict() for key in model},
989
+ 'optimizer': optimizer.state_dict(),
990
+ 'iters': iters,
991
+ 'val_loss': loss_test / iters_test,
992
+ 'epoch': epoch,
993
+ }
994
+ save_path = osp.join(log_dir, '2nd_phase_last.pth')
995
+ torch.save(state, save_path)
996
+
997
+ accelerator.end_training()
998
+
999
+
1000
+ if __name__ == "__main__":
1001
+ main()
train_first.py ADDED
@@ -0,0 +1,459 @@
1
+ import os
2
+ import os.path as osp
3
+ import re
4
+ import sys
5
+ import yaml
6
+ import shutil
7
+ import numpy as np
8
+ import torch
9
+ import click
10
+ import warnings
11
+ warnings.simplefilter('ignore')
12
+
13
+ # load packages
14
+ import random
15
+ import yaml
16
+ from munch import Munch
17
+ import numpy as np
18
+ import torch
19
+ from torch import nn
20
+ import torch.nn.functional as F
21
+ import torchaudio
22
+ import librosa
23
+
24
+ from models import *
25
+ from meldataset import build_dataloader
26
+ from utils import *
27
+ from losses import *
28
+ from optimizers import build_optimizer
29
+ import time
30
+
31
+ from accelerate import Accelerator
32
+ from accelerate.utils import LoggerType
33
+ from accelerate import DistributedDataParallelKwargs
34
+
35
+ from torch.utils.tensorboard import SummaryWriter
36
+
37
+ import logging
38
+ from accelerate.logging import get_logger
39
+ logger = get_logger(__name__, log_level="DEBUG")
40
+
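+ # Typical invocation (assuming the standard Accelerate CLI and the default config path):
+ #   accelerate launch train_first.py --config_path Configs/config.yml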
41
+ @click.command()
42
+ @click.option('-p', '--config_path', default='Configs/config.yml', type=str)
43
+ def main(config_path):
44
+ config = yaml.safe_load(open(config_path))
45
+
46
+ save_iter = 10500
47
+
48
+ log_dir = config['log_dir']
49
+ if not osp.exists(log_dir): os.makedirs(log_dir, exist_ok=True)
50
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
51
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
52
+ accelerator = Accelerator(project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs], mixed_precision='bf16')
53
+ if accelerator.is_main_process:
54
+ writer = SummaryWriter(log_dir + "/tensorboard")
55
+
56
+ # write logs
57
+ file_handler = logging.FileHandler(osp.join(log_dir, 'train.log'))
58
+ file_handler.setLevel(logging.DEBUG)
59
+ file_handler.setFormatter(logging.Formatter('%(levelname)s:%(asctime)s: %(message)s'))
60
+ logger.logger.addHandler(file_handler)
61
+
62
+ batch_size = config.get('batch_size', 10)
63
+ device = accelerator.device
64
+
65
+ epochs = config.get('epochs_1st', 200)
66
+ save_freq = config.get('save_freq', 2)
67
+ log_interval = config.get('log_interval', 10)
68
+ saving_epoch = config.get('save_freq', 2)
69
+
70
+ data_params = config.get('data_params', None)
71
+ sr = config['preprocess_params'].get('sr', 24000)
72
+ train_path = data_params['train_data']
73
+ val_path = data_params['val_data']
74
+ root_path = data_params['root_path']
75
+ min_length = data_params['min_length']
76
+ OOD_data = data_params['OOD_data']
77
+
78
+ max_len = config.get('max_len', 200)
79
+
80
+ # load data
81
+ train_list, val_list = get_data_path_list(train_path, val_path)
82
+
83
+ train_dataloader = build_dataloader(train_list,
84
+ root_path,
85
+ OOD_data=OOD_data,
86
+ min_length=min_length,
87
+ batch_size=batch_size,
88
+ num_workers=2,
89
+ dataset_config={},
90
+ device=device)
91
+
92
+ val_dataloader = build_dataloader(val_list,
93
+ root_path,
94
+ OOD_data=OOD_data,
95
+ min_length=min_length,
96
+ batch_size=batch_size,
97
+ validation=True,
98
+ num_workers=0,
99
+ device=device,
100
+ dataset_config={})
101
+
102
+ with accelerator.main_process_first():
103
+ # load pretrained ASR model
104
+ ASR_config = config.get('ASR_config', False)
105
+ ASR_path = config.get('ASR_path', False)
106
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
107
+
108
+ # load pretrained F0 model
109
+ F0_path = config.get('F0_path', False)
110
+ pitch_extractor = load_F0_models(F0_path)
111
+
112
+ # load BERT model
113
+ from Utils.PLBERT.util import load_plbert
114
+ BERT_path = config.get('PLBERT_dir', False)
115
+ plbert = load_plbert(BERT_path)
116
+
117
+ scheduler_params = {
118
+ "max_lr": float(config['optimizer_params'].get('lr', 1e-4)),
119
+ "pct_start": float(config['optimizer_params'].get('pct_start', 0.0)),
120
+ "epochs": epochs,
121
+ "steps_per_epoch": len(train_dataloader),
122
+ }
123
+
124
+ model_params = recursive_munch(config['model_params'])
125
+ multispeaker = model_params.multispeaker
126
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
127
+
128
+ best_loss = float('inf') # best test loss
129
+ loss_train_record = list([])
130
+ loss_test_record = list([])
131
+
132
+ loss_params = Munch(config['loss_params'])
133
+ TMA_epoch = loss_params.TMA_epoch
134
+
135
+ for k in model:
136
+ model[k] = accelerator.prepare(model[k])
137
+
138
+ train_dataloader, val_dataloader = accelerator.prepare(
139
+ train_dataloader, val_dataloader
140
+ )
141
+
142
+ _ = [model[key].to(device) for key in model]
143
+
144
+ # initialize optimizers after preparing models for compatibility with FSDP
145
+ optimizer = build_optimizer({key: model[key].parameters() for key in model},
146
+ scheduler_params_dict= {key: scheduler_params.copy() for key in model},
147
+ lr=float(config['optimizer_params'].get('lr', 1e-4)))
148
+
149
+ for k, v in optimizer.optimizers.items():
150
+ optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
151
+ optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])
152
+
153
+ with accelerator.main_process_first():
154
+ if config.get('pretrained_model', '') != '':
155
+ model, optimizer, start_epoch, iters = load_checkpoint(model, optimizer, config['pretrained_model'],
156
+ load_only_params=config.get('load_only_params', True))
157
+ else:
158
+ start_epoch = 0
159
+ iters = 0
160
+
161
+ # in case not distributed
162
+ try:
163
+ n_down = model.text_aligner.module.n_down
164
+ except AttributeError:
165
+ n_down = model.text_aligner.n_down
166
+
167
+ # wrapped losses for compatibility with mixed precision
168
+ stft_loss = MultiResolutionSTFTLoss().to(device)
169
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
170
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
171
+ wl = WavLMLoss(model_params.slm.model,
172
+ model.wd,
173
+ sr,
174
+ model_params.slm.sr).to(device)
175
+
176
+ for epoch in range(start_epoch, epochs):
177
+ running_loss = 0
178
+ start_time = time.time()
179
+
180
+ _ = [model[key].train() for key in model]
181
+
182
+ for i, batch in enumerate(train_dataloader):
183
+ waves = batch[0]
184
+ batch = [b.to(device) for b in batch[1:]]
185
+ texts, input_lengths, _, _, mels, mel_input_length, _ = batch
186
+
187
+ with torch.no_grad():
188
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
189
+ text_mask = length_to_mask(input_lengths).to(texts.device)
190
+
191
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
192
+
193
+ s2s_attn = s2s_attn.transpose(-1, -2)
194
+ s2s_attn = s2s_attn[..., 1:]
195
+ s2s_attn = s2s_attn.transpose(-1, -2)
196
+
197
+ with torch.no_grad():
198
+ attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
199
+ attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
200
+ attn_mask = (attn_mask < 1)
201
+
202
+ s2s_attn.masked_fill_(attn_mask, 0.0)
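+ # attn_mask marks (frame, token) cells where either the mel frame or the text token is
+ # padding; those attention weights are zeroed so padded positions cannot contribute
+ # to the alignment.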
203
+
204
+ with torch.no_grad():
205
+ mask_ST = mask_from_lens(s2s_attn, input_lengths, mel_input_length // (2 ** n_down))
206
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
207
+
208
+ # encode
209
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
210
+
211
+ # 50% of chance of using monotonic version
212
+ if bool(random.getrandbits(1)):
213
+ asr = (t_en @ s2s_attn)
214
+ else:
215
+ asr = (t_en @ s2s_attn_mono)
216
+
217
+ # get clips
218
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
219
+ mel_len = min([int(mel_input_length_all.min().item() / 2 - 1), max_len // 2])
220
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
221
+
222
+ en = []
223
+ gt = []
224
+ wav = []
225
+ st = []
226
+
227
+ for bib in range(len(mel_input_length)):
228
+ mel_length = int(mel_input_length[bib].item() / 2)
229
+
230
+ random_start = np.random.randint(0, mel_length - mel_len)
231
+ en.append(asr[bib, :, random_start:random_start+mel_len])
232
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
233
+
234
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
235
+ wav.append(torch.from_numpy(y).to(device))
236
+
237
+ # style reference (better to be different from the GT)
238
+ random_start = np.random.randint(0, mel_length - mel_len_st)
239
+ st.append(mels[bib, :, (random_start * 2):((random_start+mel_len_st) * 2)])
240
+
241
+ en = torch.stack(en)
242
+ gt = torch.stack(gt).detach()
243
+ st = torch.stack(st).detach()
244
+
245
+ wav = torch.stack(wav).float().detach()
246
+
247
+ # clip too short to be used by the style encoder
248
+ if gt.shape[-1] < 80:
249
+ continue
250
+
251
+ with torch.no_grad():
252
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1).detach()
253
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
254
+
255
+ s = model.style_encoder(st.unsqueeze(1) if multispeaker else gt.unsqueeze(1))
256
+
257
+ y_rec = model.decoder(en, F0_real, real_norm, s)
258
+
259
+ # discriminator loss
260
+
261
+ if epoch >= TMA_epoch:
262
+ optimizer.zero_grad()
263
+ d_loss = dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()
264
+ accelerator.backward(d_loss)
265
+ optimizer.step('msd')
266
+ optimizer.step('mpd')
267
+ else:
268
+ d_loss = 0
269
+
270
+ # generator loss
271
+ optimizer.zero_grad()
272
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
273
+
274
+ if epoch >= TMA_epoch: # start TMA training
275
+ loss_s2s = 0
276
+ for _s2s_pred, _text_input, _text_length in zip(s2s_pred, texts, input_lengths):
277
+ loss_s2s += F.cross_entropy(_s2s_pred[:_text_length], _text_input[:_text_length])
278
+ loss_s2s /= texts.size(0)
279
+
280
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
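+ # Once TMA training starts, the soft alignment is pulled toward its monotonic version
+ # (loss_mono) while cross-entropy on the aligner's seq2seq predictions (loss_s2s)
+ # keeps the alignment tied to the text.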
281
+
282
+ loss_gen_all = gl(wav.detach().unsqueeze(1).float(), y_rec).mean()
283
+ loss_slm = wl(wav.detach(), y_rec).mean()
284
+
285
+ g_loss = loss_params.lambda_mel * loss_mel + \
286
+ loss_params.lambda_mono * loss_mono + \
287
+ loss_params.lambda_s2s * loss_s2s + \
288
+ loss_params.lambda_gen * loss_gen_all + \
289
+ loss_params.lambda_slm * loss_slm
290
+
291
+ else:
292
+ loss_s2s = 0
293
+ loss_mono = 0
294
+ loss_gen_all = 0
295
+ loss_slm = 0
296
+ g_loss = loss_mel
297
+
298
+ running_loss += accelerator.gather(loss_mel).mean().item()
299
+
300
+ accelerator.backward(g_loss)
301
+
302
+ optimizer.step('text_encoder')
303
+ optimizer.step('style_encoder')
304
+ optimizer.step('decoder')
305
+
306
+ if epoch >= TMA_epoch:
307
+ optimizer.step('text_aligner')
308
+ optimizer.step('pitch_extractor')
309
+
310
+ iters = iters + 1
311
+
312
+ if (i+1)%log_interval == 0 and accelerator.is_main_process:
313
+ log_print ('Epoch [%d/%d], Step [%d/%d], Mel Loss: %.5f, Gen Loss: %.5f, Disc Loss: %.5f, Mono Loss: %.5f, S2S Loss: %.5f, SLM Loss: %.5f'
314
+ %(epoch+1, epochs, i+1, len(train_list)//batch_size, running_loss / log_interval, loss_gen_all, d_loss, loss_mono, loss_s2s, loss_slm), logger)
315
+
316
+ writer.add_scalar('train/mel_loss', running_loss / log_interval, iters)
317
+ writer.add_scalar('train/gen_loss', loss_gen_all, iters)
318
+ writer.add_scalar('train/d_loss', d_loss, iters)
319
+ writer.add_scalar('train/mono_loss', loss_mono, iters)
320
+ writer.add_scalar('train/s2s_loss', loss_s2s, iters)
321
+ writer.add_scalar('train/slm_loss', loss_slm, iters)
322
+
323
+ running_loss = 0
324
+
325
+ print('Time elapsed:', time.time()-start_time)
326
+
327
+ if (i+1)%save_iter == 0 and accelerator.is_main_process:
328
+
329
+ print(f'Saving on step {epoch*len(train_dataloader)+i}...')
330
+ state = {
331
+ 'net': {key: model[key].state_dict() for key in model},
332
+ 'optimizer': optimizer.state_dict(),
333
+ 'iters': iters,
334
+ 'epoch': epoch,
335
+ }
336
+ save_path = osp.join(log_dir, f'1st_phase_{epoch*len(train_dataloader)+i}.pth')
337
+ torch.save(state, save_path)
338
+
339
+ loss_test = 0
340
+
341
+ _ = [model[key].eval() for key in model]
342
+
343
+ with torch.no_grad():
344
+ iters_test = 0
345
+ for batch_idx, batch in enumerate(val_dataloader):
346
+ optimizer.zero_grad()
347
+
348
+ waves = batch[0]
349
+ batch = [b.to(device) for b in batch[1:]]
350
+ texts, input_lengths, _, _, mels, mel_input_length, _ = batch
351
+
352
+ with torch.no_grad():
353
+ mask = length_to_mask(mel_input_length // (2 ** n_down)).to('cuda')
354
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
355
+
356
+ s2s_attn = s2s_attn.transpose(-1, -2)
357
+ s2s_attn = s2s_attn[..., 1:]
358
+ s2s_attn = s2s_attn.transpose(-1, -2)
359
+
360
+ text_mask = length_to_mask(input_lengths).to(texts.device)
361
+ attn_mask = (~mask).unsqueeze(-1).expand(mask.shape[0], mask.shape[1], text_mask.shape[-1]).float().transpose(-1, -2)
362
+ attn_mask = attn_mask.float() * (~text_mask).unsqueeze(-1).expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1]).float()
363
+ attn_mask = (attn_mask < 1)
364
+ s2s_attn.masked_fill_(attn_mask, 0.0)
365
+
366
+ # encode
367
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
368
+
369
+ asr = (t_en @ s2s_attn)
370
+
371
+ # get clips
372
+ mel_input_length_all = accelerator.gather(mel_input_length) # for balanced load
373
+ mel_len = min([int(mel_input_length.min().item() / 2 - 1), max_len // 2])
374
+
375
+ en = []
376
+ gt = []
377
+ wav = []
378
+ for bib in range(len(mel_input_length)):
379
+ mel_length = int(mel_input_length[bib].item() / 2)
380
+
381
+ random_start = np.random.randint(0, mel_length - mel_len)
382
+ en.append(asr[bib, :, random_start:random_start+mel_len])
383
+ gt.append(mels[bib, :, (random_start * 2):((random_start+mel_len) * 2)])
384
+ y = waves[bib][(random_start * 2) * 300:((random_start+mel_len) * 2) * 300]
385
+ wav.append(torch.from_numpy(y).to('cuda'))
386
+
387
+ wav = torch.stack(wav).float().detach()
388
+
389
+ en = torch.stack(en)
390
+ gt = torch.stack(gt).detach()
391
+
392
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
393
+ s = model.style_encoder(gt.unsqueeze(1))
394
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
395
+ y_rec = model.decoder(en, F0_real, real_norm, s)
396
+
397
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
398
+
399
+ loss_test += accelerator.gather(loss_mel).mean().item()
400
+ iters_test += 1
401
+
402
+ if accelerator.is_main_process:
403
+ print('Epochs:', epoch + 1)
404
+ log_print('Validation loss: %.3f' % (loss_test / iters_test) + '\n\n\n\n', logger)
405
+ print('\n\n\n')
406
+ writer.add_scalar('eval/mel_loss', loss_test / iters_test, epoch + 1)
407
+ attn_image = get_image(s2s_attn[0].cpu().numpy().squeeze())
408
+ writer.add_figure('eval/attn', attn_image, epoch)
409
+
410
+ with torch.no_grad():
411
+ for bib in range(len(asr)):
412
+ mel_length = int(mel_input_length[bib].item())
413
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
414
+ en = asr[bib, :, :mel_length // 2].unsqueeze(0)
415
+
416
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
417
+ F0_real = F0_real.unsqueeze(0)
418
+ s = model.style_encoder(gt.unsqueeze(1))
419
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
420
+
421
+ y_rec = model.decoder(en, F0_real, real_norm, s)
422
+
423
+ writer.add_audio('eval/y' + str(bib), y_rec.cpu().numpy().squeeze(), epoch, sample_rate=sr)
424
+ if epoch == 0:
425
+ writer.add_audio('gt/y' + str(bib), waves[bib].squeeze(), epoch, sample_rate=sr)
426
+
427
+ if bib >= 15:
428
+ break
429
+
430
+ if epoch % saving_epoch == 0:
431
+ if (loss_test / iters_test) < best_loss:
432
+ best_loss = loss_test / iters_test
433
+ print('Saving..')
434
+ state = {
435
+ 'net': {key: model[key].state_dict() for key in model},
436
+ 'optimizer': optimizer.state_dict(),
437
+ 'iters': iters,
438
+ 'val_loss': loss_test / iters_test,
439
+ 'epoch': epoch,
440
+ }
441
+ save_path = osp.join(log_dir, 'epoch_1st_%05d.pth' % epoch)
442
+ torch.save(state, save_path)
443
+
444
+ if accelerator.is_main_process:
445
+ print('Saving..')
446
+ state = {
447
+ 'net': {key: model[key].state_dict() for key in model},
448
+ 'optimizer': optimizer.state_dict(),
449
+ 'iters': iters,
450
+ 'val_loss': loss_test / iters_test,
451
+ 'epoch': epoch,
452
+ }
453
+ save_path = osp.join(log_dir, config.get('first_stage_path', 'first_stage.pth'))
454
+ torch.save(state, save_path)
455
+
456
+
457
+
458
+ if __name__=="__main__":
459
+ main()