smjain committed on
Commit d66b5b3
1 Parent(s): 01fa2a4

Upload train_nsf_sim_cache_sid_load_pretrain.py

train_nsf_sim_cache_sid_load_pretrain.py ADDED
@@ -0,0 +1,534 @@
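# Distributed training entry point for the NSF synthesizer: spawns one process per GPU
# with torch.multiprocessing, builds the generator (SynthesizerTrnMs256NSFsid[_nono]) and
# MultiPeriodDiscriminator, resumes from the latest G_*/D_* checkpoints when present or
# otherwise loads the pretrained weights given by hps.pretrainG / hps.pretrainD, and can
# cache whole training batches on the GPU (hps.if_cache_data_in_gpu).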
import sys, os

now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir, "train"))
import lib.utils as utils  # aliased so the utils.* calls below resolve

hps = utils.get_hparams()
os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",")
n_gpus = len(hps.gpus.split("-"))
from random import shuffle
import traceback, json, argparse, itertools, math, torch, pdb

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
from torch import nn, optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler
from infer_pack import commons
from time import sleep
from time import time as ttime
from lib.data_utils import (
    TextAudioLoaderMultiNSFsid,
    TextAudioLoader,
    TextAudioCollateMultiNSFsid,
    TextAudioCollate,
    DistributedBucketSampler,
)
from infer_pack.models import (
    SynthesizerTrnMs256NSFsid,
    SynthesizerTrnMs256NSFsid_nono,
    MultiPeriodDiscriminator,
)
from losses import generator_loss, discriminator_loss, feature_loss, kl_loss
from mel_processing import mel_spectrogram_torch, spec_to_mel_torch


global_step = 0


def main():
    # n_gpus = torch.cuda.device_count()
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "51545"

    mp.spawn(
        run,
        nprocs=n_gpus,
        args=(
            n_gpus,
            hps,
        ),
    )


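# Per-process worker: `rank` selects the GPU; only rank 0 creates the logger and
# TensorBoard writers, while every rank joins the "gloo" process group and trains.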
def run(rank, n_gpus, hps):
    global global_step
    if rank == 0:
        logger = utils.get_logger(hps.model_dir)
        logger.info(hps)
        utils.check_git_hash(hps.model_dir)
        writer = SummaryWriter(log_dir=hps.model_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))

    dist.init_process_group(
        backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
    )
    torch.manual_seed(hps.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    if hps.if_f0 == 1:
        train_dataset = TextAudioLoaderMultiNSFsid(hps.data.training_files, hps.data)
    else:
        train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size * n_gpus,
        # [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200,1400], # 16s
        [100, 200, 300, 400, 500, 600, 700, 800, 900],  # 16s
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    # The dataloader workers may run out of shared memory; if so, raise your shared memory limit.
    # num_workers=8 -> num_workers=4
    if hps.if_f0 == 1:
        collate_fn = TextAudioCollateMultiNSFsid()
    else:
        collate_fn = TextAudioCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=4,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=8,
    )
    if hps.if_f0 == 1:
        net_g = SynthesizerTrnMs256NSFsid(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            **hps.model,
            is_half=hps.train.fp16_run,
            sr=hps.sample_rate,
        )
    else:
        net_g = SynthesizerTrnMs256NSFsid_nono(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            **hps.model,
            is_half=hps.train.fp16_run,
        )
    if torch.cuda.is_available():
        net_g = net_g.cuda(rank)
    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
    if torch.cuda.is_available():
        net_d = net_d.cuda(rank)
    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    # net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
    # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
    if torch.cuda.is_available():
        net_g = DDP(net_g, device_ids=[rank])
        net_d = DDP(net_d, device_ids=[rank])
    else:
        net_g = DDP(net_g)
        net_d = DDP(net_d)

    try:  # auto-resume if a checkpoint can be loaded
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
        )  # loading D usually succeeds
        if rank == 0:
            logger.info("loaded D")
        # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
        )
        global_step = (epoch_str - 1) * len(train_loader)
        # epoch_str = 1
        # global_step = 0
    except:  # on the first run no checkpoint exists, so load the pretrained weights
        # traceback.print_exc()
        epoch_str = 1
        global_step = 0
        if rank == 0:
            logger.info("loaded pretrained %s %s" % (hps.pretrainG, hps.pretrainD))
        print(
            net_g.module.load_state_dict(
                torch.load(hps.pretrainG, map_location="cpu")["model"]
            )
        )  ## test: do not load the optimizer state
        print(
            net_d.module.load_state_dict(
                torch.load(hps.pretrainD, map_location="cpu")["model"]
            )
        )

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )

    scaler = GradScaler(enabled=hps.train.fp16_run)

    cache = []
    for epoch in range(epoch_str, hps.train.epochs + 1):
        if rank == 0:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                [train_loader, None],
                logger,
                [writer, writer_eval],
                cache,
            )
        else:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                [train_loader, None],
                None,
                None,
                cache,
            )
        scheduler_g.step()
        scheduler_d.step()


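# Runs one training epoch. `cache` persists across calls: with hps.if_cache_data_in_gpu
# the first epoch fills it with GPU-resident batches and later epochs only reshuffle it,
# so the DataLoader is traversed just once.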
def train_and_evaluate(
    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
):
    net_g, net_d = nets
    optim_g, optim_d = optims
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers

    train_loader.batch_sampler.set_epoch(epoch)
    global global_step

    net_g.train()
    net_d.train()

    # Prepare data iterator
    if hps.if_cache_data_in_gpu == True:
        # Use Cache
        data_iterator = cache
        if cache == []:
            # Make new cache
            for batch_idx, info in enumerate(train_loader):
                # Unpack
                if hps.if_f0 == 1:
                    (
                        phone,
                        phone_lengths,
                        pitch,
                        pitchf,
                        spec,
                        spec_lengths,
                        wave,
                        wave_lengths,
                        sid,
                    ) = info
                else:
                    (
                        phone,
                        phone_lengths,
                        spec,
                        spec_lengths,
                        wave,
                        wave_lengths,
                        sid,
                    ) = info
                # Load on CUDA
                if torch.cuda.is_available():
                    phone = phone.cuda(rank, non_blocking=True)
                    phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
                    if hps.if_f0 == 1:
                        pitch = pitch.cuda(rank, non_blocking=True)
                        pitchf = pitchf.cuda(rank, non_blocking=True)
                    sid = sid.cuda(rank, non_blocking=True)
                    spec = spec.cuda(rank, non_blocking=True)
                    spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
                    wave = wave.cuda(rank, non_blocking=True)
                    wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
                # Cache on list
                if hps.if_f0 == 1:
                    cache.append(
                        (
                            batch_idx,
                            (
                                phone,
                                phone_lengths,
                                pitch,
                                pitchf,
                                spec,
                                spec_lengths,
                                wave,
                                wave_lengths,
                                sid,
                            ),
                        )
                    )
                else:
                    cache.append(
                        (
                            batch_idx,
                            (
                                phone,
                                phone_lengths,
                                spec,
                                spec_lengths,
                                wave,
                                wave_lengths,
                                sid,
                            ),
                        )
                    )
        else:
            # Load shuffled cache
            shuffle(cache)
    else:
        # Loader
        data_iterator = enumerate(train_loader)

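    # Each step: forward the generator under autocast, build mel targets, update the
    # discriminator on real vs. detached generated audio, then update the generator with
    # adversarial + feature-matching + mel L1 + KL losses via the shared GradScaler.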
    # Run steps
    for batch_idx, info in data_iterator:
        # Data
        ## Unpack
        if hps.if_f0 == 1:
            (
                phone,
                phone_lengths,
                pitch,
                pitchf,
                spec,
                spec_lengths,
                wave,
                wave_lengths,
                sid,
            ) = info
        else:
            phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
        ## Load on CUDA
        if (hps.if_cache_data_in_gpu == False) and torch.cuda.is_available():
            phone = phone.cuda(rank, non_blocking=True)
            phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
            if hps.if_f0 == 1:
                pitch = pitch.cuda(rank, non_blocking=True)
                pitchf = pitchf.cuda(rank, non_blocking=True)
            sid = sid.cuda(rank, non_blocking=True)
            spec = spec.cuda(rank, non_blocking=True)
            spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
            wave = wave.cuda(rank, non_blocking=True)
            wave_lengths = wave_lengths.cuda(rank, non_blocking=True)

        # Calculate
        with autocast(enabled=hps.train.fp16_run):
            if hps.if_f0 == 1:
                (
                    y_hat,
                    ids_slice,
                    x_mask,
                    z_mask,
                    (z, z_p, m_p, logs_p, m_q, logs_q),
                ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
            else:
                (
                    y_hat,
                    ids_slice,
                    x_mask,
                    z_mask,
                    (z, z_p, m_p, logs_p, m_q, logs_q),
                ) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
            )
            with autocast(enabled=False):
                y_hat_mel = mel_spectrogram_torch(
                    y_hat.float().squeeze(1),
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.hop_length,
                    hps.data.win_length,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
            if hps.train.fp16_run == True:
                y_hat_mel = y_hat_mel.half()
            wave = commons.slice_segments(
                wave, ids_slice * hps.data.hop_length, hps.train.segment_size
            )  # slice

            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
            with autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g
                )
        optim_d.zero_grad()
        scaler.scale(loss_disc).backward()
        scaler.unscale_(optim_d)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
            with autocast(enabled=False):
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]["lr"]
                logger.info(
                    "Train Epoch: {} [{:.0f}%]".format(
                        epoch, 100.0 * batch_idx / len(train_loader)
                    )
                )
                # Clamp losses for TensorBoard display
                if loss_mel > 50:
                    loss_mel = 50
                if loss_kl > 5:
                    loss_kl = 5

                logger.info([global_step, lr])
                logger.info(
                    f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f}, loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
                )
                scalar_dict = {
                    "loss/g/total": loss_gen_all,
                    "loss/d/total": loss_disc,
                    "learning_rate": lr,
                    "grad_norm_d": grad_norm_d,
                    "grad_norm_g": grad_norm_g,
                }
                scalar_dict.update(
                    {
                        "loss/g/fm": loss_fm,
                        "loss/g/mel": loss_mel,
                        "loss/g/kl": loss_kl,
                    }
                )

                scalar_dict.update(
                    {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
                )
                scalar_dict.update(
                    {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
                )
                scalar_dict.update(
                    {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
                )
                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
                        y_mel[0].data.cpu().numpy()
                    ),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                        y_hat_mel[0].data.cpu().numpy()
                    ),
                    "all/mel": utils.plot_spectrogram_to_numpy(
                        mel[0].data.cpu().numpy()
                    ),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict,
                )
        global_step += 1
    # /Run steps

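    # Periodic checkpointing: if_latest == 0 writes a new G_/D_ pair named by global_step;
    # otherwise a fixed-name pair ("2333333") is overwritten to keep only the latest state.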
    if epoch % hps.save_every_epoch == 0 and rank == 0:
        if hps.if_latest == 0:
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
            )
        else:
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
            )

    if rank == 0:
        logger.info("====> Epoch: {}".format(epoch))
    if epoch >= hps.total_epoch and rank == 0:
        logger.info("Training is done. The program is closed.")
        from process_ckpt import savee  # def savee(ckpt,sr,if_f0,name,epoch):

        if hasattr(net_g, "module"):
            ckpt = net_g.module.state_dict()
        else:
            ckpt = net_g.state_dict()
        logger.info(
            "saving final ckpt:%s"
            % (savee(ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch))
        )
        sleep(1)
        os._exit(2333333)


if __name__ == "__main__":
    main()