smjain commited on
Commit
150902c
·
verified ·
1 Parent(s): 3a68337

Upload 5 files

Browse files
Files changed (5) hide show
  1. lib/data_utils.py +517 -0
  2. lib/losses.py +58 -0
  3. lib/mel_processing.py +132 -0
  4. lib/process_ckpt.py +261 -0
  5. lib/utils.py +478 -0
lib/data_utils.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import traceback
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+
11
+ from infer.lib.train.mel_processing import spectrogram_torch
12
+ from infer.lib.train.utils import load_filepaths_and_text, load_wav_to_torch
13
+
14
+
15
class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
    """Dataset of (spec, wav, phone, pitch, pitchf, speaker-id) samples.

    1) reads a filelist of audio / feature paths via ``load_filepaths_and_text``
    2) loads the phone, pitch and pitchf arrays stored as ``.npy`` files
    3) computes spectrograms from the audio files, caching them on disk
       next to the audio as ``*.spec.pt``.
    """

    def __init__(self, audiopaths_and_text, hparams):
        """Read the filelist and copy the audio settings from ``hparams``.

        Args:
            audiopaths_and_text: filelist path handed to
                ``load_filepaths_and_text``.
            hparams: object exposing ``max_wav_value``, ``sampling_rate``,
                ``filter_length``, ``hop_length`` and ``win_length``;
                ``min_text_len`` / ``max_text_len`` are optional.
        """
        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.min_text_len = getattr(hparams, "min_text_len", 1)
        self.max_text_len = getattr(hparams, "max_text_len", 5000)
        self._filter()

    def _filter(self):
        """Drop entries whose text field length is out of range; store lengths.

        ``self.lengths`` holds one estimated spectrogram length per kept
        entry and is what the bucket sampler reads.
        """
        # Store spectrogram lengths for bucketing, estimated from file size.
        # NOTE(review): the historical comment described the estimate as
        # file_size / (1 channel * 2 bytes) // hop_length, but the divisor
        # actually used is 3 * hop_length -- confirm which one is intended.
        audiopaths_and_text_new = []
        lengths = []
        for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
            if self.min_text_len <= len(text) <= self.max_text_len:
                audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
                lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
        self.audiopaths_and_text = audiopaths_and_text_new
        self.lengths = lengths

    def get_sid(self, sid):
        """Return the speaker id as a one-element ``LongTensor``."""
        return torch.LongTensor([int(sid)])

    def get_audio_text_pair(self, audiopath_and_text):
        """Load one sample and trim all streams to a common frame count.

        Args:
            audiopath_and_text: ``[audio, phone, pitch, pitchf, sid]`` entry.

        Returns:
            ``(spec, wav, phone, pitch, pitchf, sid)``.
        """
        file, phone, pitch, pitchf, dv = audiopath_and_text[:5]

        phone, pitch, pitchf = self.get_labels(phone, pitch, pitchf)
        spec, wav = self.get_audio(file)
        dv = self.get_sid(dv)

        len_phone = phone.size()[0]
        len_spec = spec.size()[-1]
        if len_phone != len_spec:
            # Features and spectrogram disagree on length: cut everything to
            # the shorter one, and the waveform to the matching sample count.
            len_min = min(len_phone, len_spec)
            len_wav = len_min * self.hop_length

            spec = spec[:, :len_min]
            wav = wav[:, :len_wav]

            phone = phone[:len_min, :]
            pitch = pitch[:len_min]
            pitchf = pitchf[:len_min]

        return (spec, wav, phone, pitch, pitchf, dv)

    def get_labels(self, phone, pitch, pitchf):
        """Load the phone / pitch / pitchf ``.npy`` files as tensors.

        Phone rows are repeated twice along axis 0 and every stream is
        capped at 900 entries.
        """
        phone = np.repeat(np.load(phone), 2, axis=0)
        pitch = np.load(pitch)
        pitchf = np.load(pitchf)
        n_num = min(phone.shape[0], 900)  # DistributedBucketSampler
        phone = torch.FloatTensor(phone[:n_num, :])
        pitch = torch.LongTensor(pitch[:n_num])
        pitchf = torch.FloatTensor(pitchf[:n_num])
        return phone, pitch, pitchf

    def _compute_and_cache_spec(self, audio_norm, spec_filename):
        """Compute the spectrogram of ``audio_norm`` and save it to disk."""
        spec = spectrogram_torch(
            audio_norm,
            self.filter_length,
            self.sampling_rate,
            self.hop_length,
            self.win_length,
            center=False,
        )
        spec = torch.squeeze(spec, 0)
        torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
        return spec

    def get_audio(self, filename):
        """Load a wav file and its spectrogram (cached or freshly computed).

        Returns:
            ``(spec, audio)`` where ``audio`` has a leading dimension of 1.

        Raises:
            ValueError: the file's sampling rate differs from
                ``self.sampling_rate``.
        """
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError(
                "{} SR doesn't match target {} SR".format(
                    sampling_rate, self.sampling_rate
                )
            )
        # The samples are used exactly as load_wav_to_torch returns them; no
        # division by max_wav_value or by the peak is applied here.
        audio_norm = audio.unsqueeze(0)
        spec_filename = filename.replace(".wav", ".spec.pt")
        if os.path.exists(spec_filename):
            try:
                return torch.load(spec_filename), audio_norm
            except Exception:
                # An unreadable cache file is rebuilt below. Only Exception
                # is caught so Ctrl-C / SystemExit still stop the process.
                logger.warning("%s %s", spec_filename, traceback.format_exc())
        return self._compute_and_cache_spec(audio_norm, spec_filename), audio_norm

    def __getitem__(self, index):
        return self.get_audio_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        return len(self.audiopaths_and_text)
145
+
146
+
147
class TextAudioCollateMultiNSFsid:
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collate a list of dataset samples into zero-padded batch tensors.

        PARAMS
        ------
        batch: list of (spec, wav, phone, pitch, pitchf, sid) tuples

        RETURNS
        -------
        (phone_padded, phone_lengths, pitch_padded, pitchf_padded,
         spec_padded, spec_lengths, wave_padded, wave_lengths, sid),
        with rows ordered by decreasing spectrogram length.
        """
        n_items = len(batch)

        # Row order of the output: longest spectrogram first.
        _, order = torch.sort(
            torch.LongTensor([sample[0].size(1) for sample in batch]),
            dim=0,
            descending=True,
        )

        longest_spec = max([sample[0].size(1) for sample in batch])
        longest_wave = max([sample[1].size(1) for sample in batch])
        longest_phone = max([sample[2].size(0) for sample in batch])

        spec_padded = torch.FloatTensor(n_items, batch[0][0].size(0), longest_spec)
        wave_padded = torch.FloatTensor(n_items, 1, longest_wave)
        phone_padded = torch.FloatTensor(
            n_items, longest_phone, batch[0][2].shape[1]
        )
        pitch_padded = torch.LongTensor(n_items, longest_phone)
        pitchf_padded = torch.FloatTensor(n_items, longest_phone)
        for buffer in (
            spec_padded,
            wave_padded,
            phone_padded,
            pitch_padded,
            pitchf_padded,
        ):
            buffer.zero_()

        # These are fully overwritten in the loop below.
        spec_lengths = torch.LongTensor(n_items)
        wave_lengths = torch.LongTensor(n_items)
        phone_lengths = torch.LongTensor(n_items)
        sid = torch.LongTensor(n_items)

        for slot, source_index in enumerate(order):
            spec, wave, phone, pitch, pitchf, speaker = batch[source_index][:6]

            spec_padded[slot, :, : spec.size(1)] = spec
            spec_lengths[slot] = spec.size(1)

            wave_padded[slot, :, : wave.size(1)] = wave
            wave_lengths[slot] = wave.size(1)

            phone_padded[slot, : phone.size(0), :] = phone
            phone_lengths[slot] = phone.size(0)

            pitch_padded[slot, : pitch.size(0)] = pitch
            pitchf_padded[slot, : pitchf.size(0)] = pitchf

            sid[slot] = speaker

        return (
            phone_padded,
            phone_lengths,
            pitch_padded,
            pitchf_padded,
            spec_padded,
            spec_lengths,
            wave_padded,
            wave_lengths,
            sid,
        )
221
+
222
+
223
class TextAudioLoader(torch.utils.data.Dataset):
    """Dataset of (spec, wav, phone, speaker-id) samples (no pitch streams).

    1) reads a filelist of audio / feature paths via ``load_filepaths_and_text``
    2) loads the phone array stored as a ``.npy`` file
    3) computes spectrograms from the audio files, caching them on disk
       next to the audio as ``*.spec.pt``.
    """

    def __init__(self, audiopaths_and_text, hparams):
        """Read the filelist and copy the audio settings from ``hparams``.

        Args:
            audiopaths_and_text: filelist path handed to
                ``load_filepaths_and_text``.
            hparams: object exposing ``max_wav_value``, ``sampling_rate``,
                ``filter_length``, ``hop_length`` and ``win_length``;
                ``min_text_len`` / ``max_text_len`` are optional.
        """
        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.filter_length = hparams.filter_length
        self.hop_length = hparams.hop_length
        self.win_length = hparams.win_length
        self.min_text_len = getattr(hparams, "min_text_len", 1)
        self.max_text_len = getattr(hparams, "max_text_len", 5000)
        self._filter()

    def _filter(self):
        """Drop entries whose text field length is out of range; store lengths.

        ``self.lengths`` holds one estimated spectrogram length per kept
        entry and is what the bucket sampler reads.
        """
        # Store spectrogram lengths for bucketing, estimated from file size.
        # NOTE(review): the historical comment described the estimate as
        # file_size / (1 channel * 2 bytes) // hop_length, but the divisor
        # actually used is 3 * hop_length -- confirm which one is intended.
        audiopaths_and_text_new = []
        lengths = []
        for audiopath, text, dv in self.audiopaths_and_text:
            if self.min_text_len <= len(text) <= self.max_text_len:
                audiopaths_and_text_new.append([audiopath, text, dv])
                lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
        self.audiopaths_and_text = audiopaths_and_text_new
        self.lengths = lengths

    def get_sid(self, sid):
        """Return the speaker id as a one-element ``LongTensor``."""
        return torch.LongTensor([int(sid)])

    def get_audio_text_pair(self, audiopath_and_text):
        """Load one sample and trim all streams to a common frame count.

        Args:
            audiopath_and_text: ``[audio, phone, sid]`` filelist entry.

        Returns:
            ``(spec, wav, phone, sid)``.
        """
        file, phone, dv = audiopath_and_text[:3]

        phone = self.get_labels(phone)
        spec, wav = self.get_audio(file)
        dv = self.get_sid(dv)

        len_phone = phone.size()[0]
        len_spec = spec.size()[-1]
        if len_phone != len_spec:
            # Features and spectrogram disagree on length: cut everything to
            # the shorter one, and the waveform to the matching sample count.
            len_min = min(len_phone, len_spec)
            len_wav = len_min * self.hop_length
            spec = spec[:, :len_min]
            wav = wav[:, :len_wav]
            phone = phone[:len_min, :]
        return (spec, wav, phone, dv)

    def get_labels(self, phone):
        """Load the phone ``.npy`` file as a ``FloatTensor``.

        Rows are repeated twice along axis 0 and capped at 900 entries.
        """
        phone = np.repeat(np.load(phone), 2, axis=0)
        n_num = min(phone.shape[0], 900)  # DistributedBucketSampler
        return torch.FloatTensor(phone[:n_num, :])

    def _compute_and_cache_spec(self, audio_norm, spec_filename):
        """Compute the spectrogram of ``audio_norm`` and save it to disk."""
        spec = spectrogram_torch(
            audio_norm,
            self.filter_length,
            self.sampling_rate,
            self.hop_length,
            self.win_length,
            center=False,
        )
        spec = torch.squeeze(spec, 0)
        torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
        return spec

    def get_audio(self, filename):
        """Load a wav file and its spectrogram (cached or freshly computed).

        Returns:
            ``(spec, audio)`` where ``audio`` has a leading dimension of 1.

        Raises:
            ValueError: the file's sampling rate differs from
                ``self.sampling_rate``.
        """
        audio, sampling_rate = load_wav_to_torch(filename)
        if sampling_rate != self.sampling_rate:
            raise ValueError(
                "{} SR doesn't match target {} SR".format(
                    sampling_rate, self.sampling_rate
                )
            )
        # The samples are used exactly as load_wav_to_torch returns them; no
        # division by max_wav_value or by the peak is applied here.
        audio_norm = audio.unsqueeze(0)
        spec_filename = filename.replace(".wav", ".spec.pt")
        if os.path.exists(spec_filename):
            try:
                return torch.load(spec_filename), audio_norm
            except Exception:
                # An unreadable cache file is rebuilt below. Only Exception
                # is caught so Ctrl-C / SystemExit still stop the process.
                logger.warning("%s %s", spec_filename, traceback.format_exc())
        return self._compute_and_cache_spec(audio_norm, spec_filename), audio_norm

    def __getitem__(self, index):
        return self.get_audio_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        return len(self.audiopaths_and_text)
337
+
338
+
339
class TextAudioCollate:
    """Zero-pads model inputs and targets"""

    def __init__(self, return_ids=False):
        self.return_ids = return_ids

    def __call__(self, batch):
        """Collate a list of dataset samples into zero-padded batch tensors.

        PARAMS
        ------
        batch: list of (spec, wav, phone, sid) tuples

        RETURNS
        -------
        (phone_padded, phone_lengths, spec_padded, spec_lengths,
         wave_padded, wave_lengths, sid), with rows ordered by decreasing
        spectrogram length.
        """
        n_items = len(batch)

        # Row order of the output: longest spectrogram first.
        _, order = torch.sort(
            torch.LongTensor([sample[0].size(1) for sample in batch]),
            dim=0,
            descending=True,
        )

        longest_spec = max([sample[0].size(1) for sample in batch])
        longest_wave = max([sample[1].size(1) for sample in batch])
        longest_phone = max([sample[2].size(0) for sample in batch])

        spec_padded = torch.FloatTensor(n_items, batch[0][0].size(0), longest_spec)
        wave_padded = torch.FloatTensor(n_items, 1, longest_wave)
        phone_padded = torch.FloatTensor(
            n_items, longest_phone, batch[0][2].shape[1]
        )
        for buffer in (spec_padded, wave_padded, phone_padded):
            buffer.zero_()

        # These are fully overwritten in the loop below.
        spec_lengths = torch.LongTensor(n_items)
        wave_lengths = torch.LongTensor(n_items)
        phone_lengths = torch.LongTensor(n_items)
        sid = torch.LongTensor(n_items)

        for slot, source_index in enumerate(order):
            spec, wave, phone, speaker = batch[source_index][:4]

            spec_padded[slot, :, : spec.size(1)] = spec
            spec_lengths[slot] = spec.size(1)

            wave_padded[slot, :, : wave.size(1)] = wave
            wave_lengths[slot] = wave.size(1)

            phone_padded[slot, : phone.size(0), :] = phone
            phone_lengths[slot] = phone.size(0)

            sid[slot] = speaker

        return (
            phone_padded,
            phone_lengths,
            spec_padded,
            spec_lengths,
            wave_padded,
            wave_lengths,
            sid,
        )
399
+
400
+
401
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
    """
    Maintain similar input lengths in a batch.
    Length groups are specified by boundaries.
    Ex) boundaries = [b1, b2, b3] -> any batch is included either {x | b1 < length(x) <=b2} or {x | b2 < length(x) <= b3}.

    It removes samples which are not included in the boundaries.
    Ex) boundaries = [b1, b2, b3] -> any x s.t. length(x) <= b1 or length(x) > b3 are discarded.

    Iterating yields lists of dataset indices (one list per batch).
    """

    def __init__(
        self,
        dataset,
        batch_size,
        boundaries,
        num_replicas=None,
        rank=None,
        shuffle=True,
    ):
        """Group ``dataset.lengths`` into buckets delimited by ``boundaries``.

        Args:
            dataset: dataset exposing a ``lengths`` sequence.
            batch_size: number of samples per batch on each replica.
            boundaries: increasing length boundaries delimiting the buckets.
            num_replicas, rank, shuffle: forwarded to ``DistributedSampler``.
        """
        super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
        self.lengths = dataset.lengths
        self.batch_size = batch_size
        # Work on a private copy: _create_buckets() removes the boundaries of
        # empty buckets and must not edit the caller's list in place.
        self.boundaries = list(boundaries)

        self.buckets, self.num_samples_per_bucket = self._create_buckets()
        self.total_size = sum(self.num_samples_per_bucket)
        self.num_samples = self.total_size // self.num_replicas

    def _create_buckets(self):
        """Assign sample indices to buckets and size each bucket.

        Returns:
            ``(buckets, num_samples_per_bucket)``; every bucket's sample count
            is rounded up to a multiple of ``num_replicas * batch_size``.
        """
        buckets = [[] for _ in range(len(self.boundaries) - 1)]
        for index, length in enumerate(self.lengths):
            idx_bucket = self._bisect(length)
            if idx_bucket != -1:
                buckets[idx_bucket].append(index)

        # Walk backwards so popping does not shift indices not yet visited.
        for i in range(len(buckets) - 1, -1, -1):
            if len(buckets[i]) == 0:
                buckets.pop(i)
                self.boundaries.pop(i + 1)

        total_batch_size = self.num_replicas * self.batch_size
        num_samples_per_bucket = []
        for bucket in buckets:
            len_bucket = len(bucket)
            rem = (
                total_batch_size - (len_bucket % total_batch_size)
            ) % total_batch_size
            num_samples_per_bucket.append(len_bucket + rem)
        return buckets, num_samples_per_bucket

    def __iter__(self):
        # deterministically shuffle based on epoch
        g = torch.Generator()
        g.manual_seed(self.epoch)

        if self.shuffle:
            indices = [
                torch.randperm(len(bucket), generator=g).tolist()
                for bucket in self.buckets
            ]
        else:
            indices = [list(range(len(bucket))) for bucket in self.buckets]

        batches = []
        for bucket, ids_bucket, num_samples_bucket in zip(
            self.buckets, indices, self.num_samples_per_bucket
        ):
            len_bucket = len(bucket)

            # add extra samples to make it evenly divisible
            rem = num_samples_bucket - len_bucket
            ids_bucket = (
                ids_bucket
                + ids_bucket * (rem // len_bucket)
                + ids_bucket[: (rem % len_bucket)]
            )

            # subsample: this replica takes every num_replicas-th entry
            ids_bucket = ids_bucket[self.rank :: self.num_replicas]

            # batching
            for j in range(len(ids_bucket) // self.batch_size):
                batch = [
                    bucket[idx]
                    for idx in ids_bucket[
                        j * self.batch_size : (j + 1) * self.batch_size
                    ]
                ]
                batches.append(batch)

        if self.shuffle:
            batch_ids = torch.randperm(len(batches), generator=g).tolist()
            batches = [batches[i] for i in batch_ids]
        self.batches = batches

        assert len(self.batches) * self.batch_size == self.num_samples
        return iter(self.batches)

    def _bisect(self, x, lo=0, hi=None):
        """Return the bucket index i with boundaries[i] < x <= boundaries[i + 1].

        Returns -1 when ``x`` falls outside every bucket.
        """
        if hi is None:
            hi = len(self.boundaries) - 1

        if hi > lo:
            mid = (hi + lo) // 2
            if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]:
                return mid
            elif x <= self.boundaries[mid]:
                return self._bisect(x, lo, mid)
            else:
                return self._bisect(x, mid + 1, hi)
        else:
            return -1

    def __len__(self):
        """Number of batches yielded per epoch on this replica."""
        return self.num_samples // self.batch_size
lib/losses.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def feature_loss(fmap_r, fmap_g):
    """Return twice the summed mean absolute difference between feature maps.

    ``fmap_r`` and ``fmap_g`` are parallel nested iterables of tensors.
    The ``fmap_r`` tensors are detached, so gradients only flow through
    ``fmap_g``.
    """
    total = 0
    for real_group, gen_group in zip(fmap_r, fmap_g):
        for real_feat, gen_feat in zip(real_group, gen_group):
            gap = real_feat.float().detach() - gen_feat.float()
            total = total + gap.abs().mean()
    return total * 2
13
+
14
+
15
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """Least-squares discriminator loss over paired discriminator outputs.

    For each pair the first-argument output is pushed towards 1 and the
    second-argument output towards 0.

    Returns:
        ``(loss, r_losses, g_losses)``: the summed loss and the per-pair
        first/second terms as Python floats.
    """
    total = 0
    r_losses, g_losses = [], []
    for real_out, gen_out in zip(disc_real_outputs, disc_generated_outputs):
        real_term = ((1 - real_out.float()) ** 2).mean()
        gen_term = (gen_out.float() ** 2).mean()
        total = total + (real_term + gen_term)
        r_losses.append(real_term.item())
        g_losses.append(gen_term.item())
    return total, r_losses, g_losses
29
+
30
+
31
def generator_loss(disc_outputs):
    """Least-squares generator loss: push each discriminator output towards 1.

    Returns:
        ``(loss, gen_losses)``: the summed loss and the list of per-output
        terms (tensors).
    """
    gen_losses = [((1 - out.float()) ** 2).mean() for out in disc_outputs]
    total = 0
    for term in gen_losses:
        total = total + term
    return total, gen_losses
41
+
42
+
43
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
    """Masked KL-style term, normalised by the mask sum.

    z_p, logs_q: [b, h, t_t]
    m_p, logs_p: [b, h, t_t]

    Every input is cast to float32 before use.
    """
    z_p, logs_q, m_p, logs_p, z_mask = (
        t.float() for t in (z_p, logs_q, m_p, logs_p, z_mask)
    )

    squared_dev = (z_p - m_p) ** 2
    per_element = logs_p - logs_q - 0.5 + 0.5 * squared_dev * torch.exp(-2.0 * logs_p)
    masked_total = torch.sum(per_element * z_mask)
    return masked_total / torch.sum(z_mask)
lib/mel_processing.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.utils.data
3
+ from librosa.filters import mel as librosa_mel_fn
4
+ import logging
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+ MAX_WAV_VALUE = 32768.0
9
+
10
+
11
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` after clamping it from below.

    PARAMS
    ------
    x: input tensor
    C: compression factor applied before the logarithm
    clip_val: lower bound applied to ``x`` so the logarithm stays finite
    """
    floored = x.clamp(min=clip_val)
    return (floored * C).log()
18
+
19
+
20
def dynamic_range_decompression_torch(x, C=1):
    """Undo the log compression: exponentiate ``x`` and divide by ``C``.

    PARAMS
    ------
    x: log-compressed tensor
    C: compression factor used to compress
    """
    expanded = x.exp()
    return expanded / C
27
+
28
+
29
def spectral_normalize_torch(magnitudes):
    """Log-compress ``magnitudes`` using the default compression settings."""
    compressed = dynamic_range_compression_torch(magnitudes)
    return compressed
31
+
32
+
33
def spectral_de_normalize_torch(magnitudes):
    """Invert the default log compression applied to ``magnitudes``."""
    restored = dynamic_range_decompression_torch(magnitudes)
    return restored
35
+
36
+
37
# Reusable banks: module-level caches filled lazily by the functions below.
mel_basis = {}
hann_window = {}


def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
    """Convert waveform into Linear-frequency Linear-amplitude spectrogram.

    Args:
        y             :: (B, T) - Audio waveforms
        n_fft         - FFT size
        sampling_rate - accepted for signature compatibility; not used here
        hop_size      - hop between successive frames, in samples
        win_size      - Hann window length, in samples
        center        - forwarded to ``torch.stft``
    Returns:
        :: (B, Freq, Frame) - Linear-frequency Linear-amplitude spectrogram
    """
    # Validation: only reports out-of-range samples at debug level.
    y_min, y_max = torch.min(y), torch.max(y)
    if y_min < -1.07:
        logger.debug("min value is %s", str(y_min))
    if y_max > 1.07:
        logger.debug("max value is %s", str(y_max))

    # Window - cached per (window size, dtype, device)
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window:
        hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
            dtype=y.dtype, device=y.device
        )

    # Padding: reflect-pad both ends by (n_fft - hop_size) / 2 samples.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    # Complex spectrogram :: (B, T) -> (B, Freq, Frame), complex dtype.
    # return_complex=False is deprecated for real inputs, so request the
    # complex result and view it as (..., 2) real/imag pairs instead.
    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window[wnsize_dtype_device],
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    spec = torch.view_as_real(spec)

    # Magnitude :: (B, Freq, Frame, 2) -> (B, Freq, Frame); the 1e-6 keeps
    # the square root away from exactly zero.
    spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
    return spec
95
+
96
+
97
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
    """Project a linear spectrogram onto a mel filterbank and log-compress it.

    Args:
        spec :: (B, Freq, Frame) - Linear-frequency spectrogram
        n_fft, num_mels, sampling_rate, fmin, fmax - filterbank parameters
            forwarded to ``librosa.filters.mel``
    Returns:
        melspec :: (B, Freq=num_mels, Frame) - Mel-frequency Log-amplitude spectrogram
    """
    # MelBasis - cache if needed. The key covers every parameter the
    # filterbank depends on, so calls with a different n_fft / num_mels /
    # sampling_rate / fmin never pick up a basis built for another setup.
    cache_key = "_".join(
        str(part)
        for part in (
            sampling_rate,
            n_fft,
            num_mels,
            fmin,
            fmax,
            spec.dtype,
            spec.device,
        )
    )
    if cache_key not in mel_basis:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis[cache_key] = torch.from_numpy(mel).to(
            dtype=spec.dtype, device=spec.device
        )

    # Mel-frequency Log-amplitude spectrogram :: (B, Freq=num_mels, Frame)
    melspec = torch.matmul(mel_basis[cache_key], spec)
    melspec = spectral_normalize_torch(melspec)
    return melspec
114
+
115
+
116
def mel_spectrogram_torch(
    y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
    """Convert waveform into Mel-frequency Log-amplitude spectrogram.

    Chains ``spectrogram_torch`` and ``spec_to_mel_torch``.

    Args:
        y :: (B, T) - Waveforms
    Returns:
        melspec :: (B, Freq, Frame) - Mel-frequency Log-amplitude spectrogram
    """
    # (B, T) -> linear spectrogram (B, Freq, Frame) -> mel (B, num_mels, Frame)
    linear = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)
    return spec_to_mel_torch(linear, n_fft, num_mels, sampling_rate, fmin, fmax)
lib/process_ckpt.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import traceback
4
+ from collections import OrderedDict
5
+
6
+ import torch
7
+
8
+ from i18n.i18n import I18nAuto
9
+
10
+ i18n = I18nAuto()
11
+
12
+
13
def savee(ckpt, sr, if_f0, name, epoch, version, hps):
    """Save a half-precision copy of a state dict to ``assets/weights/<name>.pth``.

    Args:
        ckpt: mapping of parameter name to tensor; keys containing
            ``"enc_q"`` are left out of the saved file.
        sr: sample-rate tag stored under ``"sr"``.
        if_f0: value stored under ``"f0"``.
        name: output file stem.
        epoch: used to build the ``"info"`` string (``"<epoch>epoch"``).
        version: value stored under ``"version"``.
        hps: hyper-parameters; ``hps.data`` and ``hps.model`` supply the
            entries of the stored ``"config"`` list.

    Returns:
        ``"Success."`` or, on any error, the formatted traceback.
    """
    try:
        opt = OrderedDict()
        opt["weight"] = {
            key: tensor.half() for key, tensor in ckpt.items() if "enc_q" not in key
        }
        opt["config"] = [
            hps.data.filter_length // 2 + 1,
            32,
            hps.model.inter_channels,
            hps.model.hidden_channels,
            hps.model.filter_channels,
            hps.model.n_heads,
            hps.model.n_layers,
            hps.model.kernel_size,
            hps.model.p_dropout,
            hps.model.resblock,
            hps.model.resblock_kernel_sizes,
            hps.model.resblock_dilation_sizes,
            hps.model.upsample_rates,
            hps.model.upsample_initial_channel,
            hps.model.upsample_kernel_sizes,
            hps.model.spk_embed_dim,
            hps.model.gin_channels,
            hps.data.sampling_rate,
        ]
        opt["info"] = "%sepoch" % epoch
        opt["sr"] = sr
        opt["f0"] = if_f0
        opt["version"] = version
        torch.save(opt, "assets/weights/%s.pth" % name)
        return "Success."
    except Exception:
        # Report the failure as text, but let KeyboardInterrupt / SystemExit
        # propagate instead of being turned into a status string.
        return traceback.format_exc()
49
+
50
+
51
def show_info(path):
    """Describe the metadata stored in a saved model file.

    Args:
        path: checkpoint file to read with ``torch.load``.

    Returns:
        A four-line summary of the ``"info"``, ``"sr"``, ``"f0"`` and
        ``"version"`` entries (``"None"`` for missing ones), or the formatted
        traceback if the file cannot be read.
    """
    try:
        a = torch.load(path, map_location="cpu")
        return "模型信息:%s\n采样率:%s\n模型是否输入音高引导:%s\n版本:%s" % (
            a.get("info", "None"),
            a.get("sr", "None"),
            a.get("f0", "None"),
            a.get("version", "None"),
        )
    except Exception:
        # Report the failure as text, but let KeyboardInterrupt / SystemExit
        # propagate instead of being turned into a status string.
        return traceback.format_exc()
62
+
63
+
64
def _small_model_config(n_spec_bins, upsample_rates, upsample_kernel_sizes, sampling_rate):
    """Build the 18-entry ``config`` list stored with an extracted model.

    Only the first entry, the two upsample lists and the sampling rate differ
    between presets; every other entry is the same constant in all of them.
    The entry order is the one ``savee`` writes from ``hps``.
    """
    return [
        n_spec_bins,
        32,
        192,
        192,
        768,
        2,
        6,
        3,
        0,
        "1",
        [3, 7, 11],
        [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        upsample_rates,
        512,
        upsample_kernel_sizes,
        109,
        256,
        sampling_rate,
    ]


def extract_small_model(path, name, sr, if_f0, info, version):
    """Extract half-precision weights from a checkpoint into ``assets/weights``.

    Args:
        path: checkpoint to read; if it has a ``"model"`` entry, that entry
            is used as the state dict.
        name: output file stem (``assets/weights/<name>.pth``).
        sr: ``"40k"``, ``"48k"`` or ``"32k"``; selects the stored config.
        if_f0: stored under ``"f0"`` after conversion with ``int``.
        info: free-text description; ``""`` becomes ``"Extracted model."``.
        version: ``"v1"`` selects the v1 preset for 48k/32k; any other value
            selects the alternative preset.

    Returns:
        ``"Success."`` or, on any error, the formatted traceback.
    """
    try:
        ckpt = torch.load(path, map_location="cpu")
        if "model" in ckpt:
            ckpt = ckpt["model"]
        opt = OrderedDict()
        opt["weight"] = {
            key: tensor.half() for key, tensor in ckpt.items() if "enc_q" not in key
        }

        is_v1 = version == "v1"
        if sr == "40k":
            opt["config"] = _small_model_config(
                1025, [10, 10, 2, 2], [16, 16, 4, 4], 40000
            )
        elif sr == "48k":
            if is_v1:
                opt["config"] = _small_model_config(
                    1025, [10, 6, 2, 2, 2], [16, 16, 4, 4, 4], 48000
                )
            else:
                opt["config"] = _small_model_config(
                    1025, [12, 10, 2, 2], [24, 20, 4, 4], 48000
                )
        elif sr == "32k":
            if is_v1:
                opt["config"] = _small_model_config(
                    513, [10, 4, 2, 2, 2], [16, 16, 4, 4, 4], 32000
                )
            else:
                opt["config"] = _small_model_config(
                    513, [10, 8, 2, 2], [20, 16, 4, 4], 32000
                )
        # NOTE(review): any other ``sr`` leaves "config" unset and the file is
        # still written -- confirm whether that should be rejected instead.

        if info == "":
            info = "Extracted model."
        opt["info"] = info
        opt["version"] = version
        opt["sr"] = sr
        opt["f0"] = int(if_f0)
        torch.save(opt, "assets/weights/%s.pth" % name)
        return "Success."
    except Exception:
        # Report the failure as text, but let KeyboardInterrupt / SystemExit
        # propagate instead of being turned into a status string.
        return traceback.format_exc()
192
+
193
+
194
def change_info(path, info, name):
    """Rewrite the ``"info"`` entry of a saved model and store the result.

    Args:
        path: checkpoint file to read.
        info: new value for the ``"info"`` entry.
        name: output file name under ``assets/weights``; ``""`` reuses the
            base name of ``path``.

    Returns:
        ``"Success."`` or, on any error, the formatted traceback.
    """
    try:
        ckpt = torch.load(path, map_location="cpu")
        ckpt["info"] = info
        if name == "":
            name = os.path.basename(path)
        torch.save(ckpt, "assets/weights/%s" % name)
        return "Success."
    except Exception:
        # Report the failure as text, but let KeyboardInterrupt / SystemExit
        # propagate instead of being turned into a status string.
        return traceback.format_exc()
204
+
205
+
206
def merge(path1, path2, alpha1, sr, f0, info, name, version):
    """Linearly blend the weights of two models and save the result.

    Each weight is ``alpha1 * w1 + (1 - alpha1) * w2``, computed in float32
    and stored as float16 in ``assets/weights/<name>.pth``.

    Args:
        path1, path2: checkpoints to blend. Either may be a small model
            (weights under ``"weight"``) or a checkpoint with a ``"model"``
            entry. The ``"config"`` entry is always taken from ``path1``.
        alpha1: blend weight of the first model.
        sr: stored under ``"sr"``.
        f0: ``"f0"`` is stored as 1 when this equals ``i18n("是")``, else 0.
        info: stored under ``"info"``.
        name: output file stem.
        version: stored under ``"version"``.

    Returns:
        ``"Success."``, a failure message when the two models do not have the
        same parameter names, or the formatted traceback on any other error.
    """
    try:

        def extract(ckpt):
            """Return the ``"model"`` weights without the ``enc_q`` entries."""
            # Must return the bare name -> tensor mapping (not a
            # {"weight": ...} wrapper) so it has the same shape as the
            # ``ckpt["weight"]`` branch below.
            return {
                key: tensor
                for key, tensor in ckpt["model"].items()
                if "enc_q" not in key
            }

        ckpt1 = torch.load(path1, map_location="cpu")
        ckpt2 = torch.load(path2, map_location="cpu")
        cfg = ckpt1["config"]
        ckpt1 = extract(ckpt1) if "model" in ckpt1 else ckpt1["weight"]
        ckpt2 = extract(ckpt2) if "model" in ckpt2 else ckpt2["weight"]
        if sorted(ckpt1.keys()) != sorted(ckpt2.keys()):
            return "Fail to merge the models. The model architectures are not the same."

        opt = OrderedDict()
        opt["weight"] = {}
        for key in ckpt1.keys():
            first, second = ckpt1[key], ckpt2[key]
            if key == "emb_g.weight" and first.shape != second.shape:
                # The two tables differ in size: blend only the rows both
                # models have.
                min_shape0 = min(first.shape[0], second.shape[0])
                first, second = first[:min_shape0], second[:min_shape0]
            opt["weight"][key] = (
                alpha1 * (first.float()) + (1 - alpha1) * (second.float())
            ).half()
        opt["config"] = cfg
        opt["sr"] = sr
        opt["f0"] = 1 if f0 == i18n("是") else 0
        opt["version"] = version
        opt["info"] = info
        torch.save(opt, "assets/weights/%s.pth" % name)
        return "Success."
    except Exception:
        # Report the failure as text, but let KeyboardInterrupt / SystemExit
        # propagate instead of being turned into a status string.
        return traceback.format_exc()
lib/utils.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import json
4
+ import logging
5
+ import os
6
+ import subprocess
7
+ import sys
8
+ import shutil
9
+
10
+ import numpy as np
11
+ import torch
12
+ from scipy.io.wavfile import read
13
+
14
+ MATPLOTLIB_FLAG = False
15
+
16
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
17
+ logger = logging
18
+
19
+
20
def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
    """Restore the ``combd`` and ``sbd`` models from a single checkpoint file.

    The checkpoint is expected to hold the keys ``"combd"``, ``"sbd"``,
    ``"iteration"``, ``"learning_rate"`` and (when the optimizer is restored)
    ``"optimizer"``.

    For every parameter the live model declares, the saved tensor is used only
    when it exists and has the same shape; otherwise the model's current value
    is kept, so a partially compatible checkpoint still loads.

    Args:
        checkpoint_path: Path to the checkpoint file; must exist.
        combd: First model to restore (may be wrapped, i.e. expose ``.module``).
        sbd: Second model to restore (may be wrapped, i.e. expose ``.module``).
        optimizer: Optional optimizer whose state is restored.
        load_opt: The optimizer state is restored only when this equals ``1``.

    Returns:
        Tuple ``(sbd, optimizer, learning_rate, iteration)``.

    Raises:
        AssertionError: If ``checkpoint_path`` is not a file.
    """
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")

    def go(model, bkey):
        """Load ``checkpoint_dict[bkey]`` into ``model`` and return ``model``."""
        saved_state_dict = checkpoint_dict[bkey]
        target = model.module if hasattr(model, "module") else model
        state_dict = target.state_dict()
        new_state_dict = {}
        for k, v in state_dict.items():  # v carries the shape the model needs
            saved = saved_state_dict.get(k)
            if saved is None:
                # Missing from the (pretrained) checkpoint: keep model's value.
                logger.info("%s is not in the checkpoint", k)
                new_state_dict[k] = v
            elif getattr(saved, "shape", None) != v.shape:
                logger.warning(
                    "shape-%s-mismatch. need: %s, get: %s",
                    k,
                    v.shape,
                    getattr(saved, "shape", None),
                )
                # Incompatible tensor: keep the model's own initial value.
                new_state_dict[k] = v
            else:
                new_state_dict[k] = saved
        target.load_state_dict(new_state_dict, strict=False)
        return model

    go(combd, "combd")
    model = go(sbd, "sbd")
    logger.info("Loaded model weights")

    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None and load_opt == 1:
        # Deliberately not guarded: a failure here (e.g. empty optimizer state)
        # may also affect the LR schedule, so it is left to the outermost
        # level of the training script to catch and re-initialise.
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
69
+
70
+
71
+ # def load_checkpoint(checkpoint_path, model, optimizer=None):
72
+ # assert os.path.isfile(checkpoint_path)
73
+ # checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
74
+ # iteration = checkpoint_dict['iteration']
75
+ # learning_rate = checkpoint_dict['learning_rate']
76
+ # if optimizer is not None:
77
+ # optimizer.load_state_dict(checkpoint_dict['optimizer'])
78
+ # # print(1111)
79
+ # saved_state_dict = checkpoint_dict['model']
80
+ # # print(1111)
81
+ #
82
+ # if hasattr(model, 'module'):
83
+ # state_dict = model.module.state_dict()
84
+ # else:
85
+ # state_dict = model.state_dict()
86
+ # new_state_dict= {}
87
+ # for k, v in state_dict.items():
88
+ # try:
89
+ # new_state_dict[k] = saved_state_dict[k]
90
+ # except:
91
+ # logger.info("%s is not in the checkpoint" % k)
92
+ # new_state_dict[k] = v
93
+ # if hasattr(model, 'module'):
94
+ # model.module.load_state_dict(new_state_dict)
95
+ # else:
96
+ # model.load_state_dict(new_state_dict)
97
+ # logger.info("Loaded checkpoint '{}' (epoch {})" .format(
98
+ # checkpoint_path, iteration))
99
+ # return model, optimizer, learning_rate, iteration
100
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
    """Restore ``model`` (and optionally ``optimizer``) from a checkpoint file.

    The checkpoint is expected to hold the keys ``"model"``, ``"iteration"``,
    ``"learning_rate"`` and (when the optimizer is restored) ``"optimizer"``.

    For every parameter the live model declares, the saved tensor is used only
    when it exists and has the same shape; otherwise the model's current value
    is kept, so a partially compatible checkpoint still loads.

    Args:
        checkpoint_path: Path to the checkpoint file; must exist.
        model: Model to restore (may be wrapped, i.e. expose ``.module``).
        optimizer: Optional optimizer whose state is restored.
        load_opt: The optimizer state is restored only when this equals ``1``.

    Returns:
        Tuple ``(model, optimizer, learning_rate, iteration)``.

    Raises:
        AssertionError: If ``checkpoint_path`` is not a file.
    """
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")

    saved_state_dict = checkpoint_dict["model"]
    target = model.module if hasattr(model, "module") else model
    state_dict = target.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():  # v carries the shape the model needs
        saved = saved_state_dict.get(k)
        if saved is None:
            # Missing from the (pretrained) checkpoint: keep model's value.
            logger.info("%s is not in the checkpoint", k)
            new_state_dict[k] = v
        elif getattr(saved, "shape", None) != v.shape:
            logger.warning(
                "shape-%s-mismatch|need-%s|get-%s",
                k,
                v.shape,
                getattr(saved, "shape", None),
            )
            # Incompatible tensor: keep the model's own initial value.
            new_state_dict[k] = v
        else:
            new_state_dict[k] = saved
    target.load_state_dict(new_state_dict, strict=False)
    logger.info("Loaded model weights")

    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None and load_opt == 1:
        # Deliberately not guarded: a failure here (e.g. empty optimizer state)
        # may also affect the LR schedule, so it is left to the outermost
        # level of the training script to catch and re-initialise.
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration
142
+
143
+
144
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Write the model weights, optimizer state and training progress to disk.

    The saved dictionary has the keys ``"model"``, ``"iteration"``,
    ``"optimizer"`` and ``"learning_rate"``. When ``model`` exposes a
    ``.module`` attribute, the wrapped module's weights are the ones saved.
    """
    message = "Saving model and optimizer state at epoch {} to {}".format(
        iteration, checkpoint_path
    )
    logger.info(message)

    source = model.module if hasattr(model, "module") else model
    payload = {
        "model": source.state_dict(),
        "iteration": iteration,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    torch.save(payload, checkpoint_path)
163
+
164
+
165
def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
    """Write the ``combd`` and ``sbd`` weights plus optimizer state to one file.

    The saved dictionary has the keys ``"combd"``, ``"sbd"``, ``"iteration"``,
    ``"optimizer"`` and ``"learning_rate"``. For either model, when it exposes
    a ``.module`` attribute the wrapped module's weights are the ones saved.
    """
    message = "Saving model and optimizer state at epoch {} to {}".format(
        iteration, checkpoint_path
    )
    logger.info(message)

    combd_source = combd.module if hasattr(combd, "module") else combd
    combd_weights = combd_source.state_dict()
    sbd_source = sbd.module if hasattr(sbd, "module") else sbd
    sbd_weights = sbd_source.state_dict()

    payload = {
        "combd": combd_weights,
        "sbd": sbd_weights,
        "iteration": iteration,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    torch.save(payload, checkpoint_path)
189
+
190
+
191
def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """Forward several tag -> value mappings to a summary writer in one call.

    Args:
        writer: Object providing ``add_scalar``, ``add_histogram``,
            ``add_image`` and ``add_audio``.
        global_step: Step value passed along with every entry.
        scalars: Optional mapping of tag to scalar value.
        histograms: Optional mapping of tag to histogram values.
        images: Optional mapping of tag to image; logged with
            ``dataformats="HWC"``.
        audios: Optional mapping of tag to audio data.
        audio_sampling_rate: Sampling rate passed to ``add_audio``.
    """
    # ``None`` defaults replace the former shared mutable ``{}`` defaults;
    # omitting an argument still means "nothing to log" for that category.
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
208
+
209
+
210
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """Return the matching checkpoint path with the highest number in its name.

    Args:
        dir_path: Directory to search.
        regex: Glob pattern (despite the parameter name) joined onto
            ``dir_path``.

    Returns:
        The path whose file name, reduced to its digits, is numerically
        largest. File names without any digit sort before all numbered ones.

    Raises:
        IndexError: If no file matches the pattern.
    """

    def step_of(path):
        # Only the file name is inspected so digits in the directory part
        # cannot leak into the key, and a digit-free name no longer crashes
        # on ``int("")``.
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        return int(digits) if digits else -1

    f_list = glob.glob(os.path.join(dir_path, regex))
    f_list.sort(key=step_of)
    x = f_list[-1]
    logger.debug(x)
    return x
216
+
217
+
218
def plot_spectrogram_to_numpy(spectrogram):
    """Render ``spectrogram`` with matplotlib and return the image as an array.

    Args:
        spectrogram: 2-D data accepted by ``Axes.imshow``.

    Returns:
        A writable ``uint8`` array of shape ``(height, width, 3)`` holding the
        RGB pixels of the rendered figure.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        # One-time setup: select the "Agg" backend and quiet matplotlib logs.
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # np.frombuffer replaces the deprecated binary mode of np.fromstring; the
    # final copy keeps the result writable, as it was before.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)).copy()
    plt.close(fig)
    return data
242
+
243
+
244
def plot_alignment_to_numpy(alignment, info=None):
    """Render an alignment matrix with matplotlib and return it as an array.

    Args:
        alignment: 2-D data providing ``transpose()``; the transposed data is
            drawn with ``Axes.imshow``.
        info: Optional text appended below the x-axis label.

    Returns:
        A writable ``uint8`` array of shape ``(height, width, 3)`` holding the
        RGB pixels of the rendered figure.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        # One-time setup: select the "Agg" backend and quiet matplotlib logs.
        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    fig.canvas.draw()
    # np.frombuffer replaces the deprecated binary mode of np.fromstring; the
    # final copy keeps the result writable, as it was before.
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)).copy()
    plt.close(fig)
    return data
273
+
274
+
275
def load_wav_to_torch(full_path):
    """Read a WAV file and return ``(samples, sampling_rate)``.

    The samples are cast to ``float32`` (no rescaling is applied) and wrapped
    in a ``torch.FloatTensor``.
    """
    rate, samples = read(full_path)
    as_float = samples.astype(np.float32)
    return torch.FloatTensor(as_float), rate
278
+
279
+
280
def load_filepaths_and_text(filename, split="|"):
    """Parse a UTF-8 file list into one list of fields per line.

    Each line is stripped of surrounding whitespace and then split on
    ``split``; the resulting lists are returned in file order.
    """
    rows = []
    with open(filename, encoding="utf-8") as handle:
        for raw_line in handle:
            rows.append(raw_line.strip().split(split))
    return rows
284
+
285
+
286
def get_hparams(init=True):
    """Build the training hyper-parameters from the command line and config.

    Command-line arguments are parsed from ``sys.argv``. The base
    configuration is read from ``./logs/<experiment_dir>/config.json`` and
    wrapped in ``HParams``; the parsed arguments are then stored on top of it.
    The training file list is always set to
    ``<experiment directory>/filelist.txt``.

    Args:
        init: Unused; kept for backward compatibility with existing callers.

    Returns:
        The populated ``HParams`` instance.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-se",
        "--save_every_epoch",
        type=int,
        required=True,
        help="checkpoint save frequency (epoch)",
    )
    parser.add_argument(
        "-te", "--total_epoch", type=int, required=True, help="total_epoch"
    )
    # The help texts of -pg/-pd were swapped in the original; fixed here.
    parser.add_argument(
        "-pg", "--pretrainG", type=str, default="", help="Pretrained Generator path"
    )
    parser.add_argument(
        "-pd",
        "--pretrainD",
        type=str,
        default="",
        help="Pretrained Discriminator path",
    )
    parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -")
    parser.add_argument(
        "-bs", "--batch_size", type=int, required=True, help="batch size"
    )
    parser.add_argument(
        "-e", "--experiment_dir", type=str, required=True, help="experiment dir"
    )
    parser.add_argument(
        "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
    )
    parser.add_argument(
        "-sw",
        "--save_every_weights",
        type=str,
        default="0",
        help="save the extracted model in weights directory when saving checkpoints",
    )
    parser.add_argument(
        "-v", "--version", type=str, required=True, help="model version"
    )
    parser.add_argument(
        "-f0",
        "--if_f0",
        type=int,
        required=True,
        help="use f0 as one of the inputs of the model, 1 or 0",
    )
    parser.add_argument(
        "-l",
        "--if_latest",
        type=int,
        required=True,
        help="if only save the latest G/D pth file, 1 or 0",
    )
    parser.add_argument(
        "-c",
        "--if_cache_data_in_gpu",
        type=int,
        required=True,
        help="if caching the dataset in GPU memory, 1 or 0",
    )

    args = parser.parse_args()
    name = args.experiment_dir
    experiment_dir = os.path.join("./logs", args.experiment_dir)

    config_save_path = os.path.join(experiment_dir, "config.json")
    with open(config_save_path, "r") as f:
        config = json.load(f)

    hparams = HParams(**config)
    hparams.model_dir = hparams.experiment_dir = experiment_dir
    hparams.save_every_epoch = args.save_every_epoch
    hparams.name = name
    hparams.total_epoch = args.total_epoch
    hparams.pretrainG = args.pretrainG
    hparams.pretrainD = args.pretrainD
    hparams.version = args.version
    hparams.gpus = args.gpus
    hparams.train.batch_size = args.batch_size
    hparams.sample_rate = args.sample_rate
    hparams.if_f0 = args.if_f0
    hparams.if_latest = args.if_latest
    hparams.save_every_weights = args.save_every_weights
    hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
    # The file list location is derived from the experiment directory.
    hparams.data.training_files = "%s/filelist.txt" % experiment_dir
    return hparams
387
+
388
+
389
def get_hparams_from_dir(model_dir):
    """Load ``<model_dir>/config.json`` into ``HParams`` and record the directory.

    The returned object additionally carries ``model_dir`` as an attribute.
    """
    with open(os.path.join(model_dir, "config.json"), "r") as config_file:
        raw_text = config_file.read()

    settings = HParams(**json.loads(raw_text))
    settings.model_dir = model_dir
    return settings
398
+
399
+
400
def get_hparams_from_file(config_path):
    """Load the JSON file at ``config_path`` and wrap its content in ``HParams``."""
    with open(config_path, "r") as config_file:
        raw_text = config_file.read()

    return HParams(**json.loads(raw_text))
407
+
408
+
409
def check_git_hash(model_dir):
    """Compare the current git commit with the one recorded in ``model_dir``.

    If the directory containing this source file has no ``.git`` entry, a
    warning is logged and nothing else happens. Otherwise the output of
    ``git rev-parse HEAD`` is compared with the content of
    ``<model_dir>/githash``: a mismatch is logged as a warning, and a missing
    file is created with the current hash.

    Args:
        model_dir: Directory holding (or receiving) the ``githash`` file.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    # NOTE(review): the command runs in the current working directory, not in
    # source_dir - confirm this is intended.
    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        # Context managers close the file handles the original left open.
        with open(path) as f:
            saved_hash = f.read()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        with open(path, "w") as f:
            f.write(cur_hash)
432
+
433
+
434
def get_logger(model_dir, filename="train.log"):
    """Create the file-backed training logger and install it as module ``logger``.

    The logger is named after the last path component of ``model_dir``, set to
    ``DEBUG`` level, and writes to ``<model_dir>/<filename>``. ``model_dir`` is
    created when missing. Calling this again for the same log file reuses the
    existing handler instead of adding another one, so messages are not
    written twice.

    Args:
        model_dir: Directory that receives the log file.
        filename: Name of the log file inside ``model_dir``.

    Returns:
        The configured ``logging.Logger`` (also bound to the module-level
        name ``logger``).
    """
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    log_path = os.path.join(model_dir, filename)
    # FileHandler stores its target as an absolute path in ``baseFilename``.
    wanted = os.path.abspath(log_path)
    already_attached = any(
        isinstance(handler, logging.FileHandler) and handler.baseFilename == wanted
        for handler in logger.handlers
    )
    if not already_attached:
        formatter = logging.Formatter(
            "%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s"
        )
        h = logging.FileHandler(log_path)
        h.setLevel(logging.DEBUG)
        h.setFormatter(formatter)
        logger.addHandler(h)
    return logger
447
+
448
+
449
class HParams:
    """Attribute-style container for (possibly nested) hyper-parameters.

    Every keyword argument becomes an attribute; values that are dictionaries
    (including ``dict`` subclasses) are converted recursively into ``HParams``.
    The object also supports a small mapping-like interface (``keys``,
    ``items``, ``values``, ``len``, ``in`` and ``[]`` access).
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            # isinstance (rather than an exact type check) also converts
            # dict subclasses such as OrderedDict.
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        """Return a view of the parameter names."""
        return self.__dict__.keys()

    def items(self):
        """Return a view of ``(name, value)`` pairs."""
        return self.__dict__.items()

    def values(self):
        """Return a view of the parameter values."""
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()