smjain committed
Commit aea7d07
1 parent: d9dd983

Upload train.py

Files changed (1): lib/train.py (+643, -0)
lib/train.py ADDED
@@ -0,0 +1,643 @@
import os
import sys
import logging

logger = logging.getLogger(__name__)

now_dir = os.getcwd()
sys.path.append(os.path.join(now_dir))

import datetime

from lib import utils

hps = utils.get_hparams()
os.environ["CUDA_VISIBLE_DEVICES"] = hps.gpus.replace("-", ",")
n_gpus = len(hps.gpus.split("-"))
from random import randint, shuffle

import torch

try:
    import intel_extension_for_pytorch as ipex  # pylint: disable=import-error, unused-import

    if torch.xpu.is_available():
        from infer.modules.ipex import ipex_init
        from infer.modules.ipex.gradscaler import gradscaler_init
        from torch.xpu.amp import autocast

        GradScaler = gradscaler_init()
        ipex_init()
    else:
        from torch.cuda.amp import GradScaler, autocast
except Exception:
    from torch.cuda.amp import GradScaler, autocast

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = False
from time import sleep
from time import time as ttime

import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn import functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from lib import commons
from lib.data_utils import (
    DistributedBucketSampler,
    TextAudioCollate,
    TextAudioCollateMultiNSFsid,
    TextAudioLoader,
    TextAudioLoaderMultiNSFsid,
)

if hps.version == "v1":
    from infer.lib.infer_pack.models import MultiPeriodDiscriminator
    from infer.lib.infer_pack.models import SynthesizerTrnMs256NSFsid as RVC_Model_f0
    from infer.lib.infer_pack.models import (
        SynthesizerTrnMs256NSFsid_nono as RVC_Model_nof0,
    )
else:
    from infer.lib.infer_pack.models import (
        SynthesizerTrnMs768NSFsid as RVC_Model_f0,
        SynthesizerTrnMs768NSFsid_nono as RVC_Model_nof0,
        MultiPeriodDiscriminatorV2 as MultiPeriodDiscriminator,
    )

from infer.lib.train.losses import (
    discriminator_loss,
    feature_loss,
    generator_loss,
    kl_loss,
)
from infer.lib.train.mel_processing import mel_spectrogram_torch, spec_to_mel_torch
from infer.lib.train.process_ckpt import savee

global_step = 0


class EpochRecorder:
    def __init__(self):
        self.last_time = ttime()

    def record(self):
        now_time = ttime()
        elapsed_time = now_time - self.last_time
        self.last_time = now_time
        elapsed_time_str = str(datetime.timedelta(seconds=elapsed_time))
        current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        return f"[{current_time}] | ({elapsed_time_str})"


def main():
    n_gpus = torch.cuda.device_count()

    if not torch.cuda.is_available() and torch.backends.mps.is_available():
        n_gpus = 1
    if n_gpus < 1:
        # patch to unblock people without gpus. there is probably a better way.
        print("NO GPU DETECTED: falling back to CPU - this may take a while")
        n_gpus = 1
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(randint(20000, 55555))
    children = []
    for i in range(n_gpus):
        subproc = mp.Process(
            target=run,
            args=(i, n_gpus, hps),
        )
        children.append(subproc)
        subproc.start()

    for i in range(n_gpus):
        children[i].join()


def run(
    rank,
    n_gpus,
    hps,
):
    global global_step
    if rank == 0:
        logger = utils.get_logger(hps.model_dir)
        logger.info(hps)
        # utils.check_git_hash(hps.model_dir)
        writer = SummaryWriter(log_dir=hps.model_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(hps.model_dir, "eval"))

    dist.init_process_group(
        backend="gloo", init_method="env://", world_size=n_gpus, rank=rank
    )
    torch.manual_seed(hps.train.seed)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)

    if hps.if_f0 == 1:
        train_dataset = TextAudioLoaderMultiNSFsid(hps.data.training_files, hps.data)
    else:
        train_dataset = TextAudioLoader(hps.data.training_files, hps.data)
    train_sampler = DistributedBucketSampler(
        train_dataset,
        hps.train.batch_size * n_gpus,
        # [100, 200, 300, 400, 500, 600, 700, 800, 900, 1000, 1200,1400], # 16s
        [100, 200, 300, 400, 500, 600, 700, 800, 900], # 16s
        num_replicas=n_gpus,
        rank=rank,
        shuffle=True,
    )
    # It is possible that dataloader's workers are out of shared memory. Please try to raise your shared memory limit.
    # num_workers=8 -> num_workers=4
    if hps.if_f0 == 1:
        collate_fn = TextAudioCollateMultiNSFsid()
    else:
        collate_fn = TextAudioCollate()
    train_loader = DataLoader(
        train_dataset,
        num_workers=4,
        shuffle=False,
        pin_memory=True,
        collate_fn=collate_fn,
        batch_sampler=train_sampler,
        persistent_workers=True,
        prefetch_factor=8,
    )
    if hps.if_f0 == 1:
        net_g = RVC_Model_f0(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            **hps.model,
            is_half=hps.train.fp16_run,
            sr=hps.sample_rate,
        )
    else:
        net_g = RVC_Model_nof0(
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            **hps.model,
            is_half=hps.train.fp16_run,
        )
    if torch.cuda.is_available():
        net_g = net_g.cuda(rank)
    net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm)
    if torch.cuda.is_available():
        net_d = net_d.cuda(rank)
    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    optim_d = torch.optim.AdamW(
        net_d.parameters(),
        hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )
    # net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
    # net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        pass
    elif torch.cuda.is_available():
        net_g = DDP(net_g, device_ids=[rank])
        net_d = DDP(net_d, device_ids=[rank])
    else:
        net_g = DDP(net_g)
        net_d = DDP(net_d)

    try:  # resume automatically if a checkpoint can be loaded
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "D_*.pth"), net_d, optim_d
        )  # loading D usually succeeds
        if rank == 0:
            logger.info("loaded D")
        # _, _, _, epoch_str = utils.load_checkpoint(utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g,load_opt=0)
        _, _, _, epoch_str = utils.load_checkpoint(
            utils.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g, optim_g
        )
        global_step = (epoch_str - 1) * len(train_loader)
        # epoch_str = 1
        # global_step = 0
    except Exception:  # nothing to resume on a first run, so load the pretrained weights instead
        # traceback.print_exc()
        epoch_str = 1
        global_step = 0
        if hps.pretrainG != "":
            if rank == 0:
                logger.info("loaded pretrained %s" % (hps.pretrainG))
            if hasattr(net_g, "module"):
                logger.info(
                    net_g.module.load_state_dict(
                        torch.load(hps.pretrainG, map_location="cpu")["model"]
                    )
                )  # test: do not load the optimizer state
            else:
                logger.info(
                    net_g.load_state_dict(
                        torch.load(hps.pretrainG, map_location="cpu")["model"]
                    )
                )  # test: do not load the optimizer state
        if hps.pretrainD != "":
            if rank == 0:
                logger.info("loaded pretrained %s" % (hps.pretrainD))
            if hasattr(net_d, "module"):
                logger.info(
                    net_d.module.load_state_dict(
                        torch.load(hps.pretrainD, map_location="cpu")["model"]
                    )
                )
            else:
                logger.info(
                    net_d.load_state_dict(
                        torch.load(hps.pretrainD, map_location="cpu")["model"]
                    )
                )

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2
    )

    scaler = GradScaler(enabled=hps.train.fp16_run)

    cache = []
    for epoch in range(epoch_str, hps.train.epochs + 1):
        if rank == 0:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                [train_loader, None],
                logger,
                [writer, writer_eval],
                cache,
            )
        else:
            train_and_evaluate(
                rank,
                epoch,
                hps,
                [net_g, net_d],
                [optim_g, optim_d],
                [scheduler_g, scheduler_d],
                scaler,
                [train_loader, None],
                None,
                None,
                cache,
            )
        scheduler_g.step()
        scheduler_d.step()


def train_and_evaluate(
    rank, epoch, hps, nets, optims, schedulers, scaler, loaders, logger, writers, cache
):
    net_g, net_d = nets
    optim_g, optim_d = optims
    train_loader, eval_loader = loaders
    if writers is not None:
        writer, writer_eval = writers

    train_loader.batch_sampler.set_epoch(epoch)
    global global_step

    net_g.train()
    net_d.train()

    # Prepare data iterator
    if hps.if_cache_data_in_gpu:
        # Use Cache
        data_iterator = cache
        if cache == []:
            # Make new cache
            for batch_idx, info in enumerate(train_loader):
                # Unpack
                if hps.if_f0 == 1:
                    (
                        phone,
                        phone_lengths,
                        pitch,
                        pitchf,
                        spec,
                        spec_lengths,
                        wave,
                        wave_lengths,
                        sid,
                    ) = info
                else:
                    (
                        phone,
                        phone_lengths,
                        spec,
                        spec_lengths,
                        wave,
                        wave_lengths,
                        sid,
                    ) = info
                # Load on CUDA
                if torch.cuda.is_available():
                    phone = phone.cuda(rank, non_blocking=True)
                    phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
                    if hps.if_f0 == 1:
                        pitch = pitch.cuda(rank, non_blocking=True)
                        pitchf = pitchf.cuda(rank, non_blocking=True)
                    sid = sid.cuda(rank, non_blocking=True)
                    spec = spec.cuda(rank, non_blocking=True)
                    spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
                    wave = wave.cuda(rank, non_blocking=True)
                    wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
                # Cache on list
                if hps.if_f0 == 1:
                    cache.append(
                        (
                            batch_idx,
                            (
                                phone,
                                phone_lengths,
                                pitch,
                                pitchf,
                                spec,
                                spec_lengths,
                                wave,
                                wave_lengths,
                                sid,
                            ),
                        )
                    )
                else:
                    cache.append(
                        (
                            batch_idx,
                            (
                                phone,
                                phone_lengths,
                                spec,
                                spec_lengths,
                                wave,
                                wave_lengths,
                                sid,
                            ),
                        )
                    )
        else:
            # Load shuffled cache
            shuffle(cache)
    else:
        # Loader
        data_iterator = enumerate(train_loader)

    # Run steps
    epoch_recorder = EpochRecorder()
    for batch_idx, info in data_iterator:
        # Data
        ## Unpack
        if hps.if_f0 == 1:
            (
                phone,
                phone_lengths,
                pitch,
                pitchf,
                spec,
                spec_lengths,
                wave,
                wave_lengths,
                sid,
            ) = info
        else:
            phone, phone_lengths, spec, spec_lengths, wave, wave_lengths, sid = info
        ## Load on CUDA
        if (not hps.if_cache_data_in_gpu) and torch.cuda.is_available():
            phone = phone.cuda(rank, non_blocking=True)
            phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
            if hps.if_f0 == 1:
                pitch = pitch.cuda(rank, non_blocking=True)
                pitchf = pitchf.cuda(rank, non_blocking=True)
            sid = sid.cuda(rank, non_blocking=True)
            spec = spec.cuda(rank, non_blocking=True)
            spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
            wave = wave.cuda(rank, non_blocking=True)
            # wave_lengths = wave_lengths.cuda(rank, non_blocking=True)

        # Calculate
        with autocast(enabled=hps.train.fp16_run):
            if hps.if_f0 == 1:
                (
                    y_hat,
                    ids_slice,
                    x_mask,
                    z_mask,
                    (z, z_p, m_p, logs_p, m_q, logs_q),
                ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
            else:
                (
                    y_hat,
                    ids_slice,
                    x_mask,
                    z_mask,
                    (z, z_p, m_p, logs_p, m_q, logs_q),
                ) = net_g(phone, phone_lengths, spec, spec_lengths, sid)
            mel = spec_to_mel_torch(
                spec,
                hps.data.filter_length,
                hps.data.n_mel_channels,
                hps.data.sampling_rate,
                hps.data.mel_fmin,
                hps.data.mel_fmax,
            )
            y_mel = commons.slice_segments(
                mel, ids_slice, hps.train.segment_size // hps.data.hop_length
            )
            with autocast(enabled=False):
                y_hat_mel = mel_spectrogram_torch(
                    y_hat.float().squeeze(1),
                    hps.data.filter_length,
                    hps.data.n_mel_channels,
                    hps.data.sampling_rate,
                    hps.data.hop_length,
                    hps.data.win_length,
                    hps.data.mel_fmin,
                    hps.data.mel_fmax,
                )
            if hps.train.fp16_run:
                y_hat_mel = y_hat_mel.half()
            wave = commons.slice_segments(
                wave, ids_slice * hps.data.hop_length, hps.train.segment_size
            )  # slice

            # Discriminator
            y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
            with autocast(enabled=False):
                loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(
                    y_d_hat_r, y_d_hat_g
                )
        optim_d.zero_grad()
        scaler.scale(loss_disc).backward()
        scaler.unscale_(optim_d)
        grad_norm_d = commons.clip_grad_value_(net_d.parameters(), None)
        scaler.step(optim_d)

        with autocast(enabled=hps.train.fp16_run):
            # Generator
            y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
            with autocast(enabled=False):
                loss_mel = F.l1_loss(y_mel, y_hat_mel) * hps.train.c_mel
                loss_kl = kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * hps.train.c_kl
                loss_fm = feature_loss(fmap_r, fmap_g)
                loss_gen, losses_gen = generator_loss(y_d_hat_g)
                loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
        optim_g.zero_grad()
        scaler.scale(loss_gen_all).backward()
        scaler.unscale_(optim_g)
        grad_norm_g = commons.clip_grad_value_(net_g.parameters(), None)
        scaler.step(optim_g)
        scaler.update()

        if rank == 0:
            if global_step % hps.train.log_interval == 0:
                lr = optim_g.param_groups[0]["lr"]
                logger.info(
                    "Train Epoch: {} [{:.0f}%]".format(
                        epoch, 100.0 * batch_idx / len(train_loader)
                    )
                )
                # Clamp outliers for TensorBoard display
                if loss_mel > 75:
                    loss_mel = 75
                if loss_kl > 9:
                    loss_kl = 9

                logger.info([global_step, lr])
                logger.info(
                    f"loss_disc={loss_disc:.3f}, loss_gen={loss_gen:.3f}, loss_fm={loss_fm:.3f}, loss_mel={loss_mel:.3f}, loss_kl={loss_kl:.3f}"
                )
                scalar_dict = {
                    "loss/g/total": loss_gen_all,
                    "loss/d/total": loss_disc,
                    "learning_rate": lr,
                    "grad_norm_d": grad_norm_d,
                    "grad_norm_g": grad_norm_g,
                }
                scalar_dict.update(
                    {
                        "loss/g/fm": loss_fm,
                        "loss/g/mel": loss_mel,
                        "loss/g/kl": loss_kl,
                    }
                )

                scalar_dict.update(
                    {"loss/g/{}".format(i): v for i, v in enumerate(losses_gen)}
                )
                scalar_dict.update(
                    {"loss/d_r/{}".format(i): v for i, v in enumerate(losses_disc_r)}
                )
                scalar_dict.update(
                    {"loss/d_g/{}".format(i): v for i, v in enumerate(losses_disc_g)}
                )
                image_dict = {
                    "slice/mel_org": utils.plot_spectrogram_to_numpy(
                        y_mel[0].data.cpu().numpy()
                    ),
                    "slice/mel_gen": utils.plot_spectrogram_to_numpy(
                        y_hat_mel[0].data.cpu().numpy()
                    ),
                    "all/mel": utils.plot_spectrogram_to_numpy(
                        mel[0].data.cpu().numpy()
                    ),
                }
                utils.summarize(
                    writer=writer,
                    global_step=global_step,
                    images=image_dict,
                    scalars=scalar_dict,
                )
        global_step += 1
    # /Run steps

    if epoch % hps.save_every_epoch == 0 and rank == 0:
        if hps.if_latest == 0:
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "G_{}.pth".format(global_step)),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "D_{}.pth".format(global_step)),
            )
        else:
            utils.save_checkpoint(
                net_g,
                optim_g,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "G_{}.pth".format(2333333)),
            )
            utils.save_checkpoint(
                net_d,
                optim_d,
                hps.train.learning_rate,
                epoch,
                os.path.join(hps.model_dir, "D_{}.pth".format(2333333)),
            )
        if rank == 0 and hps.save_every_weights == "1":
            if hasattr(net_g, "module"):
                ckpt = net_g.module.state_dict()
            else:
                ckpt = net_g.state_dict()
            logger.info(
                "saving ckpt %s_e%s:%s"
                % (
                    hps.name,
                    epoch,
                    savee(
                        ckpt,
                        hps.sample_rate,
                        hps.if_f0,
                        hps.name + "_e%s_s%s" % (epoch, global_step),
                        epoch,
                        hps.version,
                        hps,
                    ),
                )
            )

    if rank == 0:
        logger.info("====> Epoch: {} {}".format(epoch, epoch_recorder.record()))
    if epoch >= hps.total_epoch and rank == 0:
        logger.info("Training is done. The program is closed.")

        if hasattr(net_g, "module"):
            ckpt = net_g.module.state_dict()
        else:
            ckpt = net_g.state_dict()
        logger.info(
            "saving final ckpt:%s"
            % (
                savee(
                    ckpt, hps.sample_rate, hps.if_f0, hps.name, epoch, hps.version, hps
                )
            )
        )
        sleep(1)
        os._exit(2333333)


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")
    main()