PhoenixStormJr committed
Commit c5ca0a1 · verified · 1 Parent(s): 8b8bd90

Update train/utils.py

Files changed (1)
  1. train/utils.py +486 -486
train/utils.py CHANGED
@@ -1,486 +1,486 @@
 import os, traceback
 import glob
 import sys
 import argparse
 import logging
 import json
 import subprocess
 import numpy as np
 from scipy.io.wavfile import read
 import torch
 
 MATPLOTLIB_FLAG = False
 
 logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
 logger = logging
 
 
 def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
     assert os.path.isfile(checkpoint_path)
     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
 
     ##################
     def go(model, bkey):
         saved_state_dict = checkpoint_dict[bkey]
         if hasattr(model, "module"):
             state_dict = model.module.state_dict()
         else:
             state_dict = model.state_dict()
         new_state_dict = {}
-        for k, v in state_dict.items():  # 模型需要的shape
+        for k, v in state_dict.items():  # The shape required by the model
             try:
                 new_state_dict[k] = saved_state_dict[k]
                 if saved_state_dict[k].shape != state_dict[k].shape:
                     print(
                         "shape-%s-mismatch|need-%s|get-%s"
                         % (k, state_dict[k].shape, saved_state_dict[k].shape)
                     )  #
                     raise KeyError
             except:
                 # logger.info(traceback.format_exc())
-                logger.info("%s is not in the checkpoint" % k)  # pretrain缺失的
-                new_state_dict[k] = v  # 模型自带的随机值
+                logger.info("%s is not in the checkpoint" % k)  # pretrain is missing
+                new_state_dict[k] = v  # Random values that come with the model
         if hasattr(model, "module"):
             model.module.load_state_dict(new_state_dict, strict=False)
         else:
             model.load_state_dict(new_state_dict, strict=False)
 
     go(combd, "combd")
     go(sbd, "sbd")
     #############
     logger.info("Loaded model weights")
 
     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]
     if (
         optimizer is not None and load_opt == 1
-    ):  ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch
+    ):  ###Unable to load. If it is empty, reinitialize it. It may also affect the update of the lr schedule. Therefore, catch it at the outermost edge of the train file.
         # try:
         optimizer.load_state_dict(checkpoint_dict["optimizer"])
     # except:
     # traceback.print_exc()
     logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
     return model, optimizer, learning_rate, iteration
 
 
 # def load_checkpoint(checkpoint_path, model, optimizer=None):
 # assert os.path.isfile(checkpoint_path)
 # checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
 # iteration = checkpoint_dict['iteration']
 # learning_rate = checkpoint_dict['learning_rate']
 # if optimizer is not None:
 # optimizer.load_state_dict(checkpoint_dict['optimizer'])
 # # print(1111)
 # saved_state_dict = checkpoint_dict['model']
 # # print(1111)
 #
 # if hasattr(model, 'module'):
 # state_dict = model.module.state_dict()
 # else:
 # state_dict = model.state_dict()
 # new_state_dict= {}
 # for k, v in state_dict.items():
 # try:
 # new_state_dict[k] = saved_state_dict[k]
 # except:
 # logger.info("%s is not in the checkpoint" % k)
 # new_state_dict[k] = v
 # if hasattr(model, 'module'):
 # model.module.load_state_dict(new_state_dict)
 # else:
 # model.load_state_dict(new_state_dict)
 # logger.info("Loaded checkpoint '{}' (epoch {})" .format(
 # checkpoint_path, iteration))
 # return model, optimizer, learning_rate, iteration
 def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
     assert os.path.isfile(checkpoint_path)
     checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
 
     saved_state_dict = checkpoint_dict["model"]
     if hasattr(model, "module"):
         state_dict = model.module.state_dict()
     else:
         state_dict = model.state_dict()
     new_state_dict = {}
-    for k, v in state_dict.items():  # 模型需要的shape
+    for k, v in state_dict.items():  # The shape required by the model
         try:
             new_state_dict[k] = saved_state_dict[k]
             if saved_state_dict[k].shape != state_dict[k].shape:
                 print(
                     "shape-%s-mismatch|need-%s|get-%s"
                     % (k, state_dict[k].shape, saved_state_dict[k].shape)
                 )  #
                 raise KeyError
         except:
             # logger.info(traceback.format_exc())
-            logger.info("%s is not in the checkpoint" % k)  # pretrain缺失的
-            new_state_dict[k] = v  # 模型自带的随机值
+            logger.info("%s is not in the checkpoint" % k)  # pretrain is missing
+            new_state_dict[k] = v  # Random values that come with the model
     if hasattr(model, "module"):
         model.module.load_state_dict(new_state_dict, strict=False)
     else:
         model.load_state_dict(new_state_dict, strict=False)
     logger.info("Loaded model weights")
 
     iteration = checkpoint_dict["iteration"]
     learning_rate = checkpoint_dict["learning_rate"]
     if (
         optimizer is not None and load_opt == 1
-    ):  ###加载不了,如果是空的的话,重新初始化,可能还会影响lr时间表的更新,因此在train文件最外围catch
+    ):  ###Cannot load, if it is empty, reinitialize, it may also affect the update of lr schedule, so catch at the outermost edge of train file
         # try:
         optimizer.load_state_dict(checkpoint_dict["optimizer"])
     # except:
     # traceback.print_exc()
     logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
     return model, optimizer, learning_rate, iteration
 
 
 def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
     logger.info(
         "Saving model and optimizer state at epoch {} to {}".format(
             iteration, checkpoint_path
         )
     )
     if hasattr(model, "module"):
         state_dict = model.module.state_dict()
     else:
         state_dict = model.state_dict()
     torch.save(
         {
             "model": state_dict,
             "iteration": iteration,
             "optimizer": optimizer.state_dict(),
             "learning_rate": learning_rate,
         },
         checkpoint_path,
     )
 
 
 def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
     logger.info(
         "Saving model and optimizer state at epoch {} to {}".format(
             iteration, checkpoint_path
         )
     )
     if hasattr(combd, "module"):
         state_dict_combd = combd.module.state_dict()
     else:
         state_dict_combd = combd.state_dict()
     if hasattr(sbd, "module"):
         state_dict_sbd = sbd.module.state_dict()
     else:
         state_dict_sbd = sbd.state_dict()
     torch.save(
         {
             "combd": state_dict_combd,
             "sbd": state_dict_sbd,
             "iteration": iteration,
             "optimizer": optimizer.state_dict(),
             "learning_rate": learning_rate,
         },
         checkpoint_path,
     )
 
 
 def summarize(
     writer,
     global_step,
     scalars={},
     histograms={},
     images={},
     audios={},
     audio_sampling_rate=22050,
 ):
     for k, v in scalars.items():
         writer.add_scalar(k, v, global_step)
     for k, v in histograms.items():
         writer.add_histogram(k, v, global_step)
     for k, v in images.items():
         writer.add_image(k, v, global_step, dataformats="HWC")
     for k, v in audios.items():
         writer.add_audio(k, v, global_step, audio_sampling_rate)
 
 
 def latest_checkpoint_path(dir_path, regex="G_*.pth"):
     f_list = glob.glob(os.path.join(dir_path, regex))
     f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
     x = f_list[-1]
     print(x)
     return x
 
 
 def plot_spectrogram_to_numpy(spectrogram):
     global MATPLOTLIB_FLAG
     if not MATPLOTLIB_FLAG:
         import matplotlib
 
         matplotlib.use("Agg")
         MATPLOTLIB_FLAG = True
         mpl_logger = logging.getLogger("matplotlib")
         mpl_logger.setLevel(logging.WARNING)
     import matplotlib.pylab as plt
     import numpy as np
 
     fig, ax = plt.subplots(figsize=(10, 2))
     im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
     plt.colorbar(im, ax=ax)
     plt.xlabel("Frames")
     plt.ylabel("Channels")
     plt.tight_layout()
 
     fig.canvas.draw()
     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
     plt.close()
     return data
 
 
 def plot_alignment_to_numpy(alignment, info=None):
     global MATPLOTLIB_FLAG
     if not MATPLOTLIB_FLAG:
         import matplotlib
 
         matplotlib.use("Agg")
         MATPLOTLIB_FLAG = True
         mpl_logger = logging.getLogger("matplotlib")
         mpl_logger.setLevel(logging.WARNING)
     import matplotlib.pylab as plt
     import numpy as np
 
     fig, ax = plt.subplots(figsize=(6, 4))
     im = ax.imshow(
         alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
     )
     fig.colorbar(im, ax=ax)
     xlabel = "Decoder timestep"
     if info is not None:
         xlabel += "\n\n" + info
     plt.xlabel(xlabel)
     plt.ylabel("Encoder timestep")
     plt.tight_layout()
 
     fig.canvas.draw()
     data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
     data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
     plt.close()
     return data
 
 
 def load_wav_to_torch(full_path):
     sampling_rate, data = read(full_path)
     return torch.FloatTensor(data.astype(np.float32)), sampling_rate
 
 
 def load_filepaths_and_text(filename, split="|"):
     with open(filename, encoding="utf-8") as f:
         filepaths_and_text = [line.strip().split(split) for line in f]
     return filepaths_and_text
 
 
 def get_hparams(init=True):
     """
     todo:
-    结尾七人组:
-    保存频率、总epoch done
+    Ending group of seven:
+    Save frequency, total epoch done
     bs done
-    pretrainGpretrainD done
-    卡号:os.en["CUDA_VISIBLE_DEVICES"] done
+    pretrainG, pretrainD done
+    Card number: os.en["CUDA_VISIBLE_DEVICES"] done
     if_latest done
-    模型:if_f0 done
-    采样率:自动选择config done
-    是否缓存数据集进GPU:if_cache_data_in_gpu done
+    Model: if_f0 done
+    Sampling rate: Automatically select config done
+    Whether to cache the data set into the GPU: if_cache_data_in_gpu done
 
     -m:
-    自动决定training_files路径,改掉train_nsf_load_pretrain.py里的hps.data.training_files done
-    -c不要了
+    Automatically determine the training_files path, change the hps.data.training_files in train_nsf_load_pretrain.py done
+    -c no longer needed
     """
     parser = argparse.ArgumentParser()
     # parser.add_argument('-c', '--config', type=str, default="configs/40k.json",help='JSON file for configuration')
     parser.add_argument(
         "-se",
         "--save_every_epoch",
         type=int,
         required=True,
         help="checkpoint save frequency (epoch)",
     )
     parser.add_argument(
         "-te", "--total_epoch", type=int, required=True, help="total_epoch"
     )
     parser.add_argument(
         "-pg", "--pretrainG", type=str, default="", help="Pretrained Discriminator path"
     )
     parser.add_argument(
         "-pd", "--pretrainD", type=str, default="", help="Pretrained Generator path"
     )
     parser.add_argument("-g", "--gpus", type=str, default="0", help="split by -")
     parser.add_argument(
         "-bs", "--batch_size", type=int, required=True, help="batch size"
     )
     parser.add_argument(
         "-e", "--experiment_dir", type=str, required=True, help="experiment dir"
     )  # -m
     parser.add_argument(
         "-sr", "--sample_rate", type=str, required=True, help="sample rate, 32k/40k/48k"
     )
     parser.add_argument(
         "-sw",
         "--save_every_weights",
         type=str,
         default="0",
         help="save the extracted model in weights directory when saving checkpoints",
     )
     parser.add_argument(
         "-v", "--version", type=str, required=True, help="model version"
     )
     parser.add_argument(
         "-f0",
         "--if_f0",
         type=int,
         required=True,
         help="use f0 as one of the inputs of the model, 1 or 0",
     )
     parser.add_argument(
         "-l",
         "--if_latest",
         type=int,
         required=True,
         help="if only save the latest G/D pth file, 1 or 0",
     )
     parser.add_argument(
         "-c",
         "--if_cache_data_in_gpu",
         type=int,
         required=True,
         help="if caching the dataset in GPU memory, 1 or 0",
     )
 
     args = parser.parse_args()
     name = args.experiment_dir
     experiment_dir = os.path.join("./logs", args.experiment_dir)
 
     if not os.path.exists(experiment_dir):
         os.makedirs(experiment_dir)
 
     if args.version == "v1" or args.sample_rate == "40k":
         config_path = "configs/%s.json" % args.sample_rate
     else:
         config_path = "configs/%s_v2.json" % args.sample_rate
     config_save_path = os.path.join(experiment_dir, "config.json")
     if init:
         with open(config_path, "r") as f:
             data = f.read()
         with open(config_save_path, "w") as f:
             f.write(data)
     else:
         with open(config_save_path, "r") as f:
             data = f.read()
     config = json.loads(data)
 
     hparams = HParams(**config)
     hparams.model_dir = hparams.experiment_dir = experiment_dir
     hparams.save_every_epoch = args.save_every_epoch
     hparams.name = name
     hparams.total_epoch = args.total_epoch
     hparams.pretrainG = args.pretrainG
     hparams.pretrainD = args.pretrainD
     hparams.version = args.version
     hparams.gpus = args.gpus
     hparams.train.batch_size = args.batch_size
     hparams.sample_rate = args.sample_rate
     hparams.if_f0 = args.if_f0
     hparams.if_latest = args.if_latest
     hparams.save_every_weights = args.save_every_weights
     hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu
     hparams.data.training_files = "%s/filelist.txt" % experiment_dir
     return hparams
 
 
 def get_hparams_from_dir(model_dir):
     config_save_path = os.path.join(model_dir, "config.json")
     with open(config_save_path, "r") as f:
         data = f.read()
     config = json.loads(data)
 
     hparams = HParams(**config)
     hparams.model_dir = model_dir
     return hparams
 
 
 def get_hparams_from_file(config_path):
     with open(config_path, "r") as f:
         data = f.read()
     config = json.loads(data)
 
     hparams = HParams(**config)
     return hparams
 
 
 def check_git_hash(model_dir):
     source_dir = os.path.dirname(os.path.realpath(__file__))
     if not os.path.exists(os.path.join(source_dir, ".git")):
         logger.warn(
             "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                 source_dir
             )
         )
         return
 
     cur_hash = subprocess.getoutput("git rev-parse HEAD")
 
     path = os.path.join(model_dir, "githash")
     if os.path.exists(path):
         saved_hash = open(path).read()
         if saved_hash != cur_hash:
             logger.warn(
                 "git hash values are different. {}(saved) != {}(current)".format(
                     saved_hash[:8], cur_hash[:8]
                 )
             )
     else:
         open(path, "w").write(cur_hash)
 
 
 def get_logger(model_dir, filename="train.log"):
     global logger
     logger = logging.getLogger(os.path.basename(model_dir))
     logger.setLevel(logging.DEBUG)
 
     formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
     if not os.path.exists(model_dir):
         os.makedirs(model_dir)
     h = logging.FileHandler(os.path.join(model_dir, filename))
     h.setLevel(logging.DEBUG)
     h.setFormatter(formatter)
     logger.addHandler(h)
     return logger
 
 
 class HParams:
     def __init__(self, **kwargs):
         for k, v in kwargs.items():
             if type(v) == dict:
                 v = HParams(**v)
             self[k] = v
 
     def keys(self):
         return self.__dict__.keys()
 
     def items(self):
         return self.__dict__.items()
 
     def values(self):
         return self.__dict__.values()
 
     def __len__(self):
         return len(self.__dict__)
 
     def __getitem__(self, key):
         return getattr(self, key)
 
     def __setitem__(self, key, value):
         return setattr(self, key, value)
 
     def __contains__(self, key):
         return key in self.__dict__
 
     def __repr__(self):
         return self.__dict__.__repr__()
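
A minimal usage sketch of get_hparams() and HParams from this file, not part of the commit itself: it assumes the script runs from the repository root with a valid configs/40k.json present, and the flag values and "my_experiment" name are purely illustrative.

# Hypothetical invocation; assumes the repo root is on sys.path so train.utils imports.
import sys
from train.utils import get_hparams

sys.argv = [
    "train.py",
    "-se", "5",            # save a checkpoint every 5 epochs
    "-te", "100",          # total epochs
    "-bs", "4",            # batch size
    "-e", "my_experiment", # creates ./logs/my_experiment
    "-sr", "40k",          # with -v v2 this still selects configs/40k.json
    "-v", "v2",
    "-f0", "1",
    "-l", "0",
    "-c", "0",
]

hps = get_hparams()  # copies the config into ./logs/my_experiment/config.json
# Nested dicts from the JSON are wrapped in HParams, so attribute access works:
print(hps.train.batch_size)     # 4, overridden from the CLI
print(hps.data.training_files)  # ./logs/my_experiment/filelist.txt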