NeoPy committed
Commit 979717f · verified · 1 Parent(s): 1a2120d

Update infer/modules/train/train.py

Files changed (1)
  1. infer/modules/train/train.py +155 -113
infer/modules/train/train.py CHANGED
@@ -8,6 +8,7 @@ now_dir = os.getcwd()
 sys.path.append(os.path.join(now_dir))
 
 import datetime
+from tqdm import tqdm  # Added import
 
 from infer.lib.train import utils
 
@@ -105,6 +106,7 @@ def main():
     os.environ["MASTER_PORT"] = str(randint(20000, 55555))
     children = []
     logger = utils.get_logger(hps.model_dir)
+    logger.info(f"Starting training with {n_gpus} GPU(s)")
     for i in range(n_gpus):
         subproc = mp.Process(
             target=run,
@@ -120,9 +122,8 @@ def main():
 def run(rank, n_gpus, hps, logger: logging.Logger):
     global global_step
     if rank == 0:
-        # logger = utils.get_logger(hps.model_dir)
+        logger.info(f"Process {rank}/{n_gpus-1} started")
         logger.info(hps)
-        # utils.check_git_hash(hps.model_dir)
         writer = SummaryWriter(log_dir=hps.model_dir)
         writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))
 
@@ -140,18 +141,17 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
     train_sampler = DistributedBucketSampler(
         train_dataset,
         hps.train.batch_size * n_gpus,
-        # [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200,1400],  # 16s
-        [100, 200, 300, 400, 500, 600, 700, 800, 900],  # 16s
+        [100, 200, 300, 400, 500, 600, 700, 800, 900],
         num_replicas=n_gpus,
         rank=rank,
         shuffle=True,
     )
-    # It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.
-    # num_workers=8 -> num_workers=4
+
     if hps.if_f0 == 1:
         collate_fn = TextAudioCollateMultiNSFsid()
     else:
         collate_fn = TextAudioCollate()
+
     train_loader = DataLoader(
         train_dataset,
         num_workers=4,
@@ -162,6 +162,11 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
         persistent_workers=True,
         prefetch_factor=8,
     )
+
+    if rank == 0:
+        logger.info(f"Training dataset size: {len(train_dataset)}")
+        logger.info(f"Number of batches per epoch: {len(train_loader)}")
+
     if hps.if_f0 == 1:
         net_g = RVC_Model_f0(
             hps.data.filter_length // 2 + 1,
@@ -177,11 +182,14 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
         **hps.model,
         is_half=hps.train.fp16_run,
     )
+
     if torch.cuda.is_available():
         net_g = net_g.cuda(rank)
+
     net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
     if torch.cuda.is_available():
         net_d = net_d.cuda(rank)
+
     optim_g = torch.optim.AdamW(
         net_g.parameters(),
         hps.train.learning_rate,
@@ -194,8 +202,7 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
         betas=hps.train.betas,
         eps=hps.train.eps,
     )
-    # net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
-    # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
+
     if hasattr(torch, "xpu") and torch.xpu.is_available():
         pass
     elif torch.cuda.is_available():
@@ -205,52 +212,43 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
         net_g = DDP(net_g)
         net_d = DDP(net_d)
 
-    try:  # auto-resume if a checkpoint can be loaded
+    try:
         _, _, _, epoch_str = utils.load_checkpoint(
             utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
-        )  # loading D usually succeeds
+        )
         if rank == 0:
-            logger.info("loaded D")
-        # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
+            logger.info("Loaded discriminator checkpoint")
+
         _, _, _, epoch_str = utils.load_checkpoint(
             utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
         )
         global_step = (epoch_str - 1) * len(train_loader)
-        # epoch_str = 1
-        # global_step = 0
-    except:  # if nothing can be loaded on the first run, load the pretrained weights
-        # traceback.print_exc()
+        if rank == 0:
+            logger.info(f"Resuming from epoch {epoch_str}, global step {global_step}")
+    except:
        epoch_str = 1
        global_step = 0
        if hps.pretrainG != "":
            if rank == 0:
-                logger.info("loaded pretrained %s" % (hps.pretrainG))
+                logger.info(f"Loading pretrained generator from {hps.pretrainG}")
            if hasattr(net_g, "module"):
-                logger.info(
-                    net_g.module.load_state_dict(
-                        torch.load(hps.pretrainG, map_location="cpu")["model"]
-                    )
-                )  ## test: do not load the optimizer
+                net_g.module.load_state_dict(
+                    torch.load(hps.pretrainG, map_location="cpu")["model"]
+                )
            else:
-                logger.info(
-                    net_g.load_state_dict(
-                        torch.load(hps.pretrainG, map_location="cpu")["model"]
-                    )
-                )  ## test: do not load the optimizer
+                net_g.load_state_dict(
+                    torch.load(hps.pretrainG, map_location="cpu")["model"]
+                )
        if hps.pretrainD != "":
            if rank == 0:
-                logger.info("loaded pretrained %s" % (hps.pretrainD))
+                logger.info(f"Loading pretrained discriminator from {hps.pretrainD}")
            if hasattr(net_d, "module"):
-                logger.info(
-                    net_d.module.load_state_dict(
-                        torch.load(hps.pretrainD, map_location="cpu")["model"]
-                    )
+                net_d.module.load_state_dict(
+                    torch.load(hps.pretrainD, map_location="cpu")["model"]
                )
            else:
-                logger.info(
-                    net_d.load_state_dict(
-                        torch.load(hps.pretrainD, map_location="cpu")["model"]
-                    )
+                net_d.load_state_dict(
+                    torch.load(hps.pretrainD, map_location="cpu")["model"]
                )
 
     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
@@ -263,6 +261,11 @@ def run(rank, n_gpus, hps, logger: logging.Logger):
     scaler = GradScaler(enabled=hps.train.fp16_run)
 
     cache = []
+
+    if rank == 0:
+        logger.info(f"Starting training from epoch {epoch_str} to {hps.train.epochs}")
+        logger.info(f"Total epochs to train: {hps.train.epochs - epoch_str + 1}")
+
     for epoch in range(epoch_str, hps.train.epochs + 1):
         if rank == 0:
             train_and_evaluate(
@@ -313,12 +316,16 @@ def train_and_evaluate(
 
     # Prepare data iterator
     if hps.if_cache_data_in_gpu == True:
-        # Use Cache
-        data_iterator = cache
         if cache == []:
-            # Make new cache
+            if rank == 0:
+                logger.info("Caching data in GPU...")
+            cache_progress = tqdm(total=len(train_loader),
+                                  desc="Caching",
+                                  position=0,
+                                  leave=True,
+                                  disable=(rank != 0))
+
             for batch_idx, info in enumerate(train_loader):
-                # Unpack
                 if hps.if_f0 == 1:
                     (
                         phone,
@@ -341,7 +348,7 @@ def train_and_evaluate(
                         wave_lengths,
                         sid,
                     ) = info
-                # Load on CUDA
+
                 if torch.cuda.is_available():
                     phone = phone.cuda(rank, non_blocking=True)
                     phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
@@ -352,8 +359,7 @@ def train_and_evaluate(
                     spec = spec.cuda(rank, non_blocking=True)
                     spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
                     wave = wave.cuda(rank, non_blocking=True)
-                    wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
-                # Cache on list
+
                 if hps.if_f0 == 1:
                     cache.append(
                         (
@@ -386,18 +392,31 @@ def train_and_evaluate(
                             ),
                         )
                     )
-            else:
-                # Load shuffled cache
-                shuffle(cache)
+
+                if rank == 0:
+                    cache_progress.update(1)
+
+            if rank == 0:
+                cache_progress.close()
+                logger.info(f"Cached {len(cache)} batches in GPU")
+
+        shuffle(cache)
+        data_iterator = cache
     else:
-        # Loader
         data_iterator = enumerate(train_loader)
 
-    # Run steps
+    # Initialize tqdm progress bar for training
+    if rank == 0:
+        epoch_progress = tqdm(total=len(train_loader),
+                              desc=f"Epoch {epoch}/{hps.train.epochs}",
+                              position=0,
+                              leave=True,
+                              bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}')
+
     epoch_recorder = EpochRecorder()
+
     for batch_idx, info in data_iterator:
-        # Data
-        ## Unpack
+        # Unpack data
         if hps.if_f0 == 1:
             (
                 phone,
@@ -412,7 +431,7 @@ def train_and_evaluate(
             ) = info
         else:
             phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
-        ## Load on CUDA
+
         if (hps.if_cache_data_in_gpu == False) and torch.cuda.is_available():
             phone = phone.cuda(rank, non_blocking=True)
             phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
@@ -423,9 +442,8 @@ def train_and_evaluate(
             spec = spec.cuda(rank, non_blocking=True)
             spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
             wave = wave.cuda(rank, non_blocking=True)
-            # wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
 
-        # Calculate
+        # Forward pass
        with autocast(enabled=hps.train.fp16_run):
            if hps.if_f0 == 1:
                (
@@ -443,6 +461,7 @@ def train_and_evaluate(
                     z_mask,
                     (z, z_p, m_p, logs_p, m_q, logs_q),
                 ) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
+
             mel = spec_to_mel_torch(
                 spec,
                 hps.data.filter_length,
@@ -454,6 +473,7 @@ def train_and_evaluate(
             y_mel = commons.slice_segments(
                 mel, ids_slice, hps.train.segment_size // hps.data.hop_length
             )
+
             with autocast(enabled=False):
                 y_hat_mel = mel_spectrogram_torch(
                     y_hat.float().squeeze(1),
@@ -465,26 +485,30 @@ def train_and_evaluate(
                     hps.data.mel_fmin,
                     hps.data.mel_fmax,
                 )
+
             if hps.train.fp16_run == True:
                 y_hat_mel = y_hat_mel.half()
+
             wave = commons.slice_segments(
                 wave, ids_slice * hps.data.hop_length, hps.train.segment_size
-            )  # slice
+            )
 
-            # Discriminator
+            # Discriminator forward
             y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
             with autocast(enabled=False):
                 loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                     y_d_hat_r, y_d_hat_g
                 )
+
+        # Discriminator backward
         optim_d.zero_grad()
         scaler.scale(loss_disc).backward()
         scaler.unscale_(optim_d)
         grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
         scaler.step(optim_d)
 
+        # Generator forward
         with autocast(enabled=hps.train.fp16_run):
-            # Generator
             y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
             with autocast(enabled=False):
                 loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
@@ -492,6 +516,8 @@ def train_and_evaluate(
                 loss_fm = feature_loss(fmap_r, fmap_g)
                 loss_gen, losses_gen = generator_loss(y_d_hat_g)
                 loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
+
+        # Generator backward
         optim_g.zero_grad()
         scaler.scale(loss_gen_all).backward()
         scaler.unscale_(optim_g)
@@ -499,39 +525,43 @@ def train_and_evaluate(
         scaler.step(optim_g)
         scaler.update()
 
+        # Update progress bar and logging
         if rank == 0:
+            if epoch_progress is not None:
+                epoch_progress.update(1)
+
+            # Update progress bar description with current losses
+            if batch_idx % hps.train.log_interval == 0:
+                postfix_dict = {
+                    'G': f'{loss_gen_all:.3f}',
+                    'D': f'{loss_disc:.3f}',
+                    'Mel': f'{loss_mel:.3f}',
+                    'KL': f'{loss_kl:.3f}',
+                    'Step': global_step
+                }
+                epoch_progress.set_postfix(postfix_dict)
+
             if global_step % hps.train.log_interval == 0:
                 lr = optim_g.param_groups[0]["lr"]
-                logger.info(
-                    "Train Epoch: {} [{:.0f}%]".format(
-                        epoch, 100.0 * batch_idx / len(train_loader)
-                    )
-                )
-                # Amor For Tensorboard display
-                if loss_mel > 75:
-                    loss_mel = 75
-                if loss_kl > 9:
-                    loss_kl = 9
-
-                logger.info([global_step, lr])
-                logger.info(
-                    f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f},loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
-                )
+
+                logger.info(f"\nEpoch: {epoch} [{batch_idx}/{len(train_loader)}]")
+                logger.info(f"Global Step: {global_step}")
+                logger.info(f"Learning Rate: {lr:.6f}")
+                logger.info(f"Generator Loss: {loss_gen_all:.3f} (FM: {loss_fm:.3f}, Mel: {loss_mel:.3f}, KL: {loss_kl:.3f})")
+                logger.info(f"Discriminator Loss: {loss_disc:.3f}")
+                logger.info(f"Grad Norm - G: {grad_norm_g:.3f}, D: {grad_norm_d:.3f}")
+
+                # Tensorboard logging
                 scalar_dict = {
                     "loss/g/total": loss_gen_all,
                     "loss/d/total": loss_disc,
                     "learning_rate": lr,
                     "grad_norm_d": grad_norm_d,
                     "grad_norm_g": grad_norm_g,
+                    "loss/g/fm": loss_fm,
+                    "loss/g/mel": loss_mel,
+                    "loss/g/kl": loss_kl,
                 }
-                scalar_dict.update(
-                    {
-                        "loss/g/fm": loss_fm,
-                        "loss/g/mel": loss_mel,
-                        "loss/g/kl": loss_kl,
-                    }
-                )
-
                 scalar_dict.update(
                     {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
                 )
@@ -541,6 +571,7 @@ def train_and_evaluate(
                 scalar_dict.update(
                     {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
                 )
+
                 image_dict = {
                     "slice/mel_org": utils.plot_spectrogram_to_numpy(
                         y_mel[0].data.cpu().numpy()
@@ -552,89 +583,100 @@ def train_and_evaluate(
                         mel[0].data.cpu().numpy()
                     ),
                 }
+
                 utils.summarize(
                     writer=writer,
                     global_step=global_step,
                     images=image_dict,
                     scalars=scalar_dict,
                 )
+
         global_step += 1
-    # /Run steps
-
+
+    # Close progress bar
+    if rank == 0 and epoch_progress is not None:
+        epoch_progress.close()
+
+    # Save checkpoints
     if epoch % hps.save_every_epoch == 0 and rank == 0:
         if hps.if_latest == 0:
+            save_path_g = os.path.join(hps.model_dir, f"G_{global_step}.pth")
+            save_path_d = os.path.join(hps.model_dir, f"D_{global_step}.pth")
             utils.save_checkpoint(
                 net_g,
                 optim_g,
                 hps.train.learning_rate,
                 epoch,
-                os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
+                save_path_g,
             )
             utils.save_checkpoint(
                 net_d,
                 optim_d,
                 hps.train.learning_rate,
                 epoch,
-                os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
+                save_path_d,
             )
+            logger.info(f"Saved checkpoints: {save_path_g}, {save_path_d}")
         else:
+            save_path_g = os.path.join(hps.model_dir, "G_2333333.pth")
+            save_path_d = os.path.join(hps.model_dir, "D_2333333.pth")
             utils.save_checkpoint(
                 net_g,
                 optim_g,
                 hps.train.learning_rate,
                 epoch,
-                os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
+                save_path_g,
             )
             utils.save_checkpoint(
                 net_d,
                 optim_d,
                 hps.train.learning_rate,
                 epoch,
-                os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
+                save_path_d,
             )
+            logger.info(f"Saved latest checkpoints: {save_path_g}, {save_path_d}")
+
     if rank == 0 and hps.save_every_weights == "1":
         if hasattr(net_g, "module"):
             ckpt = net_g.module.state_dict()
         else:
             ckpt = net_g.state_dict()
-        logger.info(
-            "saving ckpt %s_e%s:%s"
-            % (
-                hps.name,
-                epoch,
-                savee(
-                    ckpt,
-                    hps.sample_rate,
-                    hps.if_f0,
-                    hps.name + "_e%s_s%s" % (epoch, global_step),
-                    epoch,
-                    hps.version,
-                    hps,
-                ),
-            )
+
+        model_name = hps.name + f"_e{epoch}_s{global_step}"
+        save_result = savee(
+            ckpt,
+            hps.sample_rate,
+            hps.if_f0,
+            model_name,
+            epoch,
+            hps.version,
+            hps,
         )
+        logger.info(f"Saved weights checkpoint: {model_name}: {save_result}")
 
+    # Log epoch completion
     if rank == 0:
-        logger.info("====> Epoch: {} {}".format(epoch, epoch_recorder.record()))
+        logger.info(f"Completed Epoch {epoch} {epoch_recorder.record()}")
+        logger.info(f"Global Step: {global_step}")
+
+    # End training if completed
     if epoch >= hps.total_epoch and rank == 0:
-        logger.info("Training is done. The program is closed.")
-
+        logger.info("Training completed!")
+
         if hasattr(net_g, "module"):
             ckpt = net_g.module.state_dict()
         else:
             ckpt = net_g.state_dict()
-        logger.info(
-            "saving final ckpt:%s"
-            % (
-                savee(
-                    ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps
-                )
-            )
+
+        final_save = savee(
+            ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps
        )
-        sleep(1)
-        os._exit(2333333)
+        logger.info(f"Saved final model: {final_save}")
+
+        sleep(2)  # Give time for final logging
+        os._exit(0)
 
 
 if __name__ == "__main__":
     torch.multiprocessing.set_start_method("spawn")
-    main()
+    main()
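
For context, below is a minimal, self-contained sketch of the per-epoch tqdm pattern this commit introduces: a bar sized to the number of batches, a postfix refreshed at the log interval, and an explicit close at the end of the epoch. The batch count, loss values, and loop body are illustrative placeholders, not code from the diff.

# Minimal sketch of the per-epoch tqdm pattern; batch count and losses are placeholders.
from tqdm import tqdm

num_batches = 100   # stand-in for len(train_loader)
log_interval = 10   # stand-in for hps.train.log_interval

for epoch in range(1, 3):
    # One bar per epoch, sized to the number of batches, closed when the epoch ends.
    bar = tqdm(total=num_batches, desc=f"Epoch {epoch}/2", position=0, leave=True)
    for batch_idx in range(num_batches):
        loss_g, loss_d = 1.0 / (batch_idx + 1), 0.5  # stand-ins for real losses
        bar.update(1)
        if batch_idx % log_interval == 0:
            # set_postfix mirrors how the training loop surfaces current losses.
            bar.set_postfix({"G": f"{loss_g:.3f}", "D": f"{loss_d:.3f}"})
    bar.close()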