SohomToom commited on
Commit
2885476
·
verified ·
1 Parent(s): 0856e34

Update MeloTTS/melo/utils.py

Browse files
Files changed (1) hide show
  1. MeloTTS/melo/utils.py +424 -424
MeloTTS/melo/utils.py CHANGED
@@ -1,424 +1,424 @@
1
- import os
2
- import glob
3
- import argparse
4
- import logging
5
- import json
6
- import subprocess
7
- import numpy as np
8
- from scipy.io.wavfile import read
9
- import torch
10
- import torchaudio
11
- import librosa
12
- from melo.text import cleaned_text_to_sequence, get_bert
13
- from melo.text.cleaner import clean_text
14
- from melo import commons
15
-
16
- MATPLOTLIB_FLAG = False
17
-
18
- logger = logging.getLogger(__name__)
19
-
20
-
21
-
22
- def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None):
23
- norm_text, phone, tone, word2ph = clean_text(text, language_str)
24
- phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, symbol_to_id)
25
-
26
- if hps.data.add_blank:
27
- phone = commons.intersperse(phone, 0)
28
- tone = commons.intersperse(tone, 0)
29
- language = commons.intersperse(language, 0)
30
- for i in range(len(word2ph)):
31
- word2ph[i] = word2ph[i] * 2
32
- word2ph[0] += 1
33
-
34
- if getattr(hps.data, "disable_bert", False):
35
- bert = torch.zeros(1024, len(phone))
36
- ja_bert = torch.zeros(768, len(phone))
37
- else:
38
- bert = get_bert(norm_text, word2ph, language_str, device)
39
- del word2ph
40
- assert bert.shape[-1] == len(phone), phone
41
-
42
- if language_str == "ZH":
43
- bert = bert
44
- ja_bert = torch.zeros(768, len(phone))
45
- elif language_str in ["JP", "EN", "ZH_MIX_EN", 'KR', 'SP', 'ES', 'FR', 'DE', 'RU']:
46
- ja_bert = bert
47
- bert = torch.zeros(1024, len(phone))
48
- else:
49
- raise NotImplementedError()
50
-
51
- assert bert.shape[-1] == len(
52
- phone
53
- ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
54
-
55
- phone = torch.LongTensor(phone)
56
- tone = torch.LongTensor(tone)
57
- language = torch.LongTensor(language)
58
- return bert, ja_bert, phone, tone, language
59
-
60
- def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
61
- assert os.path.isfile(checkpoint_path)
62
- checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
63
- iteration = checkpoint_dict.get("iteration", 0)
64
- learning_rate = checkpoint_dict.get("learning_rate", 0.)
65
- if (
66
- optimizer is not None
67
- and not skip_optimizer
68
- and checkpoint_dict["optimizer"] is not None
69
- ):
70
- optimizer.load_state_dict(checkpoint_dict["optimizer"])
71
- elif optimizer is None and not skip_optimizer:
72
- # else: Disable this line if Infer and resume checkpoint,then enable the line upper
73
- new_opt_dict = optimizer.state_dict()
74
- new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
75
- new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
76
- new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
77
- optimizer.load_state_dict(new_opt_dict)
78
-
79
- saved_state_dict = checkpoint_dict["model"]
80
- if hasattr(model, "module"):
81
- state_dict = model.module.state_dict()
82
- else:
83
- state_dict = model.state_dict()
84
-
85
- new_state_dict = {}
86
- for k, v in state_dict.items():
87
- try:
88
- # assert "emb_g" not in k
89
- new_state_dict[k] = saved_state_dict[k]
90
- assert saved_state_dict[k].shape == v.shape, (
91
- saved_state_dict[k].shape,
92
- v.shape,
93
- )
94
- except Exception as e:
95
- print(e)
96
- # For upgrading from the old version
97
- if "ja_bert_proj" in k:
98
- v = torch.zeros_like(v)
99
- logger.warn(
100
- f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility"
101
- )
102
- else:
103
- logger.error(f"{k} is not in the checkpoint")
104
-
105
- new_state_dict[k] = v
106
-
107
- if hasattr(model, "module"):
108
- model.module.load_state_dict(new_state_dict, strict=False)
109
- else:
110
- model.load_state_dict(new_state_dict, strict=False)
111
-
112
- logger.info(
113
- "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
114
- )
115
-
116
- return model, optimizer, learning_rate, iteration
117
-
118
-
119
- def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
120
- logger.info(
121
- "Saving model and optimizer state at iteration {} to {}".format(
122
- iteration, checkpoint_path
123
- )
124
- )
125
- if hasattr(model, "module"):
126
- state_dict = model.module.state_dict()
127
- else:
128
- state_dict = model.state_dict()
129
- torch.save(
130
- {
131
- "model": state_dict,
132
- "iteration": iteration,
133
- "optimizer": optimizer.state_dict(),
134
- "learning_rate": learning_rate,
135
- },
136
- checkpoint_path,
137
- )
138
-
139
-
140
- def summarize(
141
- writer,
142
- global_step,
143
- scalars={},
144
- histograms={},
145
- images={},
146
- audios={},
147
- audio_sampling_rate=22050,
148
- ):
149
- for k, v in scalars.items():
150
- writer.add_scalar(k, v, global_step)
151
- for k, v in histograms.items():
152
- writer.add_histogram(k, v, global_step)
153
- for k, v in images.items():
154
- writer.add_image(k, v, global_step, dataformats="HWC")
155
- for k, v in audios.items():
156
- writer.add_audio(k, v, global_step, audio_sampling_rate)
157
-
158
-
159
- def latest_checkpoint_path(dir_path, regex="G_*.pth"):
160
- f_list = glob.glob(os.path.join(dir_path, regex))
161
- f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
162
- x = f_list[-1]
163
- return x
164
-
165
-
166
- def plot_spectrogram_to_numpy(spectrogram):
167
- global MATPLOTLIB_FLAG
168
- if not MATPLOTLIB_FLAG:
169
- import matplotlib
170
-
171
- matplotlib.use("Agg")
172
- MATPLOTLIB_FLAG = True
173
- mpl_logger = logging.getLogger("matplotlib")
174
- mpl_logger.setLevel(logging.WARNING)
175
- import matplotlib.pylab as plt
176
- import numpy as np
177
-
178
- fig, ax = plt.subplots(figsize=(10, 2))
179
- im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
180
- plt.colorbar(im, ax=ax)
181
- plt.xlabel("Frames")
182
- plt.ylabel("Channels")
183
- plt.tight_layout()
184
-
185
- fig.canvas.draw()
186
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
187
- data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
188
- plt.close()
189
- return data
190
-
191
-
192
- def plot_alignment_to_numpy(alignment, info=None):
193
- global MATPLOTLIB_FLAG
194
- if not MATPLOTLIB_FLAG:
195
- import matplotlib
196
-
197
- matplotlib.use("Agg")
198
- MATPLOTLIB_FLAG = True
199
- mpl_logger = logging.getLogger("matplotlib")
200
- mpl_logger.setLevel(logging.WARNING)
201
- import matplotlib.pylab as plt
202
- import numpy as np
203
-
204
- fig, ax = plt.subplots(figsize=(6, 4))
205
- im = ax.imshow(
206
- alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
207
- )
208
- fig.colorbar(im, ax=ax)
209
- xlabel = "Decoder timestep"
210
- if info is not None:
211
- xlabel += "\n\n" + info
212
- plt.xlabel(xlabel)
213
- plt.ylabel("Encoder timestep")
214
- plt.tight_layout()
215
-
216
- fig.canvas.draw()
217
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
218
- data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
219
- plt.close()
220
- return data
221
-
222
-
223
- def load_wav_to_torch(full_path):
224
- sampling_rate, data = read(full_path)
225
- return torch.FloatTensor(data.astype(np.float32)), sampling_rate
226
-
227
-
228
- def load_wav_to_torch_new(full_path):
229
- audio_norm, sampling_rate = torchaudio.load(full_path, frame_offset=0, num_frames=-1, normalize=True, channels_first=True)
230
- audio_norm = audio_norm.mean(dim=0)
231
- return audio_norm, sampling_rate
232
-
233
- def load_wav_to_torch_librosa(full_path, sr):
234
- audio_norm, sampling_rate = librosa.load(full_path, sr=sr, mono=True)
235
- return torch.FloatTensor(audio_norm.astype(np.float32)), sampling_rate
236
-
237
-
238
- def load_filepaths_and_text(filename, split="|"):
239
- with open(filename, encoding="utf-8") as f:
240
- filepaths_and_text = [line.strip().split(split) for line in f]
241
- return filepaths_and_text
242
-
243
-
244
- def get_hparams(init=True):
245
- parser = argparse.ArgumentParser()
246
- parser.add_argument(
247
- "-c",
248
- "--config",
249
- type=str,
250
- default="./configs/base.json",
251
- help="JSON file for configuration",
252
- )
253
- parser.add_argument('--local_rank', type=int, default=0)
254
- parser.add_argument('--world-size', type=int, default=1)
255
- parser.add_argument('--port', type=int, default=10000)
256
- parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
257
- parser.add_argument('--pretrain_G', type=str, default=None,
258
- help='pretrain model')
259
- parser.add_argument('--pretrain_D', type=str, default=None,
260
- help='pretrain model D')
261
- parser.add_argument('--pretrain_dur', type=str, default=None,
262
- help='pretrain model duration')
263
-
264
- args = parser.parse_args()
265
- model_dir = os.path.join("./logs", args.model)
266
-
267
- os.makedirs(model_dir, exist_ok=True)
268
-
269
- config_path = args.config
270
- config_save_path = os.path.join(model_dir, "config.json")
271
- if init:
272
- with open(config_path, "r") as f:
273
- data = f.read()
274
- with open(config_save_path, "w") as f:
275
- f.write(data)
276
- else:
277
- with open(config_save_path, "r") as f:
278
- data = f.read()
279
- config = json.loads(data)
280
-
281
- hparams = HParams(**config)
282
- hparams.model_dir = model_dir
283
- hparams.pretrain_G = args.pretrain_G
284
- hparams.pretrain_D = args.pretrain_D
285
- hparams.pretrain_dur = args.pretrain_dur
286
- hparams.port = args.port
287
- return hparams
288
-
289
-
290
- def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
291
- """Freeing up space by deleting saved ckpts
292
-
293
- Arguments:
294
- path_to_models -- Path to the model directory
295
- n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
296
- sort_by_time -- True -> chronologically delete ckpts
297
- False -> lexicographically delete ckpts
298
- """
299
- import re
300
-
301
- ckpts_files = [
302
- f
303
- for f in os.listdir(path_to_models)
304
- if os.path.isfile(os.path.join(path_to_models, f))
305
- ]
306
-
307
- def name_key(_f):
308
- return int(re.compile("._(\\d+)\\.pth").match(_f).group(1))
309
-
310
- def time_key(_f):
311
- return os.path.getmtime(os.path.join(path_to_models, _f))
312
-
313
- sort_key = time_key if sort_by_time else name_key
314
-
315
- def x_sorted(_x):
316
- return sorted(
317
- [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
318
- key=sort_key,
319
- )
320
-
321
- to_del = [
322
- os.path.join(path_to_models, fn)
323
- for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
324
- ]
325
-
326
- def del_info(fn):
327
- return logger.info(f".. Free up space by deleting ckpt {fn}")
328
-
329
- def del_routine(x):
330
- return [os.remove(x), del_info(x)]
331
-
332
- [del_routine(fn) for fn in to_del]
333
-
334
-
335
- def get_hparams_from_dir(model_dir):
336
- config_save_path = os.path.join(model_dir, "config.json")
337
- with open(config_save_path, "r", encoding="utf-8") as f:
338
- data = f.read()
339
- config = json.loads(data)
340
-
341
- hparams = HParams(**config)
342
- hparams.model_dir = model_dir
343
- return hparams
344
-
345
-
346
- def get_hparams_from_file(config_path):
347
- with open(config_path, "r", encoding="utf-8") as f:
348
- data = f.read()
349
- config = json.loads(data)
350
-
351
- hparams = HParams(**config)
352
- return hparams
353
-
354
-
355
- def check_git_hash(model_dir):
356
- source_dir = os.path.dirname(os.path.realpath(__file__))
357
- if not os.path.exists(os.path.join(source_dir, ".git")):
358
- logger.warn(
359
- "{} is not a git repository, therefore hash value comparison will be ignored.".format(
360
- source_dir
361
- )
362
- )
363
- return
364
-
365
- cur_hash = subprocess.getoutput("git rev-parse HEAD")
366
-
367
- path = os.path.join(model_dir, "githash")
368
- if os.path.exists(path):
369
- saved_hash = open(path).read()
370
- if saved_hash != cur_hash:
371
- logger.warn(
372
- "git hash values are different. {}(saved) != {}(current)".format(
373
- saved_hash[:8], cur_hash[:8]
374
- )
375
- )
376
- else:
377
- open(path, "w").write(cur_hash)
378
-
379
-
380
- def get_logger(model_dir, filename="train.log"):
381
- global logger
382
- logger = logging.getLogger(os.path.basename(model_dir))
383
- logger.setLevel(logging.DEBUG)
384
-
385
- formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
386
- if not os.path.exists(model_dir):
387
- os.makedirs(model_dir, exist_ok=True)
388
- h = logging.FileHandler(os.path.join(model_dir, filename))
389
- h.setLevel(logging.DEBUG)
390
- h.setFormatter(formatter)
391
- logger.addHandler(h)
392
- return logger
393
-
394
-
395
- class HParams:
396
- def __init__(self, **kwargs):
397
- for k, v in kwargs.items():
398
- if type(v) == dict:
399
- v = HParams(**v)
400
- self[k] = v
401
-
402
- def keys(self):
403
- return self.__dict__.keys()
404
-
405
- def items(self):
406
- return self.__dict__.items()
407
-
408
- def values(self):
409
- return self.__dict__.values()
410
-
411
- def __len__(self):
412
- return len(self.__dict__)
413
-
414
- def __getitem__(self, key):
415
- return getattr(self, key)
416
-
417
- def __setitem__(self, key, value):
418
- return setattr(self, key, value)
419
-
420
- def __contains__(self, key):
421
- return key in self.__dict__
422
-
423
- def __repr__(self):
424
- return self.__dict__.__repr__()
 
1
+ import os
2
+ import glob
3
+ import argparse
4
+ import logging
5
+ import json
6
+ import subprocess
7
+ import numpy as np
8
+ from scipy.io.wavfile import read
9
+ import torch
10
+ import torchaudio
11
+ import librosa
12
+ from MeloTTS.melo.text import cleaned_text_to_sequence, get_bert
13
+ from MeloTTS.melo.text.cleaner import clean_text
14
+ from MeloTTS.melo import commons
15
+
16
+ MATPLOTLIB_FLAG = False
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+
22
+ def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None):
23
+ norm_text, phone, tone, word2ph = clean_text(text, language_str)
24
+ phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, symbol_to_id)
25
+
26
+ if hps.data.add_blank:
27
+ phone = commons.intersperse(phone, 0)
28
+ tone = commons.intersperse(tone, 0)
29
+ language = commons.intersperse(language, 0)
30
+ for i in range(len(word2ph)):
31
+ word2ph[i] = word2ph[i] * 2
32
+ word2ph[0] += 1
33
+
34
+ if getattr(hps.data, "disable_bert", False):
35
+ bert = torch.zeros(1024, len(phone))
36
+ ja_bert = torch.zeros(768, len(phone))
37
+ else:
38
+ bert = get_bert(norm_text, word2ph, language_str, device)
39
+ del word2ph
40
+ assert bert.shape[-1] == len(phone), phone
41
+
42
+ if language_str == "ZH":
43
+ bert = bert
44
+ ja_bert = torch.zeros(768, len(phone))
45
+ elif language_str in ["JP", "EN", "ZH_MIX_EN", 'KR', 'SP', 'ES', 'FR', 'DE', 'RU']:
46
+ ja_bert = bert
47
+ bert = torch.zeros(1024, len(phone))
48
+ else:
49
+ raise NotImplementedError()
50
+
51
+ assert bert.shape[-1] == len(
52
+ phone
53
+ ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
54
+
55
+ phone = torch.LongTensor(phone)
56
+ tone = torch.LongTensor(tone)
57
+ language = torch.LongTensor(language)
58
+ return bert, ja_bert, phone, tone, language
59
+
60
+ def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
61
+ assert os.path.isfile(checkpoint_path)
62
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
63
+ iteration = checkpoint_dict.get("iteration", 0)
64
+ learning_rate = checkpoint_dict.get("learning_rate", 0.)
65
+ if (
66
+ optimizer is not None
67
+ and not skip_optimizer
68
+ and checkpoint_dict["optimizer"] is not None
69
+ ):
70
+ optimizer.load_state_dict(checkpoint_dict["optimizer"])
71
+ elif optimizer is None and not skip_optimizer:
72
+ # else: Disable this line if Infer and resume checkpoint,then enable the line upper
73
+ new_opt_dict = optimizer.state_dict()
74
+ new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
75
+ new_opt_dict["param_groups"] = checkpoint_dict["optimizer"]["param_groups"]
76
+ new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
77
+ optimizer.load_state_dict(new_opt_dict)
78
+
79
+ saved_state_dict = checkpoint_dict["model"]
80
+ if hasattr(model, "module"):
81
+ state_dict = model.module.state_dict()
82
+ else:
83
+ state_dict = model.state_dict()
84
+
85
+ new_state_dict = {}
86
+ for k, v in state_dict.items():
87
+ try:
88
+ # assert "emb_g" not in k
89
+ new_state_dict[k] = saved_state_dict[k]
90
+ assert saved_state_dict[k].shape == v.shape, (
91
+ saved_state_dict[k].shape,
92
+ v.shape,
93
+ )
94
+ except Exception as e:
95
+ print(e)
96
+ # For upgrading from the old version
97
+ if "ja_bert_proj" in k:
98
+ v = torch.zeros_like(v)
99
+ logger.warn(
100
+ f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility"
101
+ )
102
+ else:
103
+ logger.error(f"{k} is not in the checkpoint")
104
+
105
+ new_state_dict[k] = v
106
+
107
+ if hasattr(model, "module"):
108
+ model.module.load_state_dict(new_state_dict, strict=False)
109
+ else:
110
+ model.load_state_dict(new_state_dict, strict=False)
111
+
112
+ logger.info(
113
+ "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
114
+ )
115
+
116
+ return model, optimizer, learning_rate, iteration
117
+
118
+
119
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
120
+ logger.info(
121
+ "Saving model and optimizer state at iteration {} to {}".format(
122
+ iteration, checkpoint_path
123
+ )
124
+ )
125
+ if hasattr(model, "module"):
126
+ state_dict = model.module.state_dict()
127
+ else:
128
+ state_dict = model.state_dict()
129
+ torch.save(
130
+ {
131
+ "model": state_dict,
132
+ "iteration": iteration,
133
+ "optimizer": optimizer.state_dict(),
134
+ "learning_rate": learning_rate,
135
+ },
136
+ checkpoint_path,
137
+ )
138
+
139
+
140
+ def summarize(
141
+ writer,
142
+ global_step,
143
+ scalars={},
144
+ histograms={},
145
+ images={},
146
+ audios={},
147
+ audio_sampling_rate=22050,
148
+ ):
149
+ for k, v in scalars.items():
150
+ writer.add_scalar(k, v, global_step)
151
+ for k, v in histograms.items():
152
+ writer.add_histogram(k, v, global_step)
153
+ for k, v in images.items():
154
+ writer.add_image(k, v, global_step, dataformats="HWC")
155
+ for k, v in audios.items():
156
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
157
+
158
+
159
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
160
+ f_list = glob.glob(os.path.join(dir_path, regex))
161
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
162
+ x = f_list[-1]
163
+ return x
164
+
165
+
166
+ def plot_spectrogram_to_numpy(spectrogram):
167
+ global MATPLOTLIB_FLAG
168
+ if not MATPLOTLIB_FLAG:
169
+ import matplotlib
170
+
171
+ matplotlib.use("Agg")
172
+ MATPLOTLIB_FLAG = True
173
+ mpl_logger = logging.getLogger("matplotlib")
174
+ mpl_logger.setLevel(logging.WARNING)
175
+ import matplotlib.pylab as plt
176
+ import numpy as np
177
+
178
+ fig, ax = plt.subplots(figsize=(10, 2))
179
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
180
+ plt.colorbar(im, ax=ax)
181
+ plt.xlabel("Frames")
182
+ plt.ylabel("Channels")
183
+ plt.tight_layout()
184
+
185
+ fig.canvas.draw()
186
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
187
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
188
+ plt.close()
189
+ return data
190
+
191
+
192
+ def plot_alignment_to_numpy(alignment, info=None):
193
+ global MATPLOTLIB_FLAG
194
+ if not MATPLOTLIB_FLAG:
195
+ import matplotlib
196
+
197
+ matplotlib.use("Agg")
198
+ MATPLOTLIB_FLAG = True
199
+ mpl_logger = logging.getLogger("matplotlib")
200
+ mpl_logger.setLevel(logging.WARNING)
201
+ import matplotlib.pylab as plt
202
+ import numpy as np
203
+
204
+ fig, ax = plt.subplots(figsize=(6, 4))
205
+ im = ax.imshow(
206
+ alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
207
+ )
208
+ fig.colorbar(im, ax=ax)
209
+ xlabel = "Decoder timestep"
210
+ if info is not None:
211
+ xlabel += "\n\n" + info
212
+ plt.xlabel(xlabel)
213
+ plt.ylabel("Encoder timestep")
214
+ plt.tight_layout()
215
+
216
+ fig.canvas.draw()
217
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
218
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
219
+ plt.close()
220
+ return data
221
+
222
+
223
+ def load_wav_to_torch(full_path):
224
+ sampling_rate, data = read(full_path)
225
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
226
+
227
+
228
+ def load_wav_to_torch_new(full_path):
229
+ audio_norm, sampling_rate = torchaudio.load(full_path, frame_offset=0, num_frames=-1, normalize=True, channels_first=True)
230
+ audio_norm = audio_norm.mean(dim=0)
231
+ return audio_norm, sampling_rate
232
+
233
+ def load_wav_to_torch_librosa(full_path, sr):
234
+ audio_norm, sampling_rate = librosa.load(full_path, sr=sr, mono=True)
235
+ return torch.FloatTensor(audio_norm.astype(np.float32)), sampling_rate
236
+
237
+
238
+ def load_filepaths_and_text(filename, split="|"):
239
+ with open(filename, encoding="utf-8") as f:
240
+ filepaths_and_text = [line.strip().split(split) for line in f]
241
+ return filepaths_and_text
242
+
243
+
244
+ def get_hparams(init=True):
245
+ parser = argparse.ArgumentParser()
246
+ parser.add_argument(
247
+ "-c",
248
+ "--config",
249
+ type=str,
250
+ default="./configs/base.json",
251
+ help="JSON file for configuration",
252
+ )
253
+ parser.add_argument('--local_rank', type=int, default=0)
254
+ parser.add_argument('--world-size', type=int, default=1)
255
+ parser.add_argument('--port', type=int, default=10000)
256
+ parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
257
+ parser.add_argument('--pretrain_G', type=str, default=None,
258
+ help='pretrain model')
259
+ parser.add_argument('--pretrain_D', type=str, default=None,
260
+ help='pretrain model D')
261
+ parser.add_argument('--pretrain_dur', type=str, default=None,
262
+ help='pretrain model duration')
263
+
264
+ args = parser.parse_args()
265
+ model_dir = os.path.join("./logs", args.model)
266
+
267
+ os.makedirs(model_dir, exist_ok=True)
268
+
269
+ config_path = args.config
270
+ config_save_path = os.path.join(model_dir, "config.json")
271
+ if init:
272
+ with open(config_path, "r") as f:
273
+ data = f.read()
274
+ with open(config_save_path, "w") as f:
275
+ f.write(data)
276
+ else:
277
+ with open(config_save_path, "r") as f:
278
+ data = f.read()
279
+ config = json.loads(data)
280
+
281
+ hparams = HParams(**config)
282
+ hparams.model_dir = model_dir
283
+ hparams.pretrain_G = args.pretrain_G
284
+ hparams.pretrain_D = args.pretrain_D
285
+ hparams.pretrain_dur = args.pretrain_dur
286
+ hparams.port = args.port
287
+ return hparams
288
+
289
+
290
+ def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
291
+ """Freeing up space by deleting saved ckpts
292
+
293
+ Arguments:
294
+ path_to_models -- Path to the model directory
295
+ n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
296
+ sort_by_time -- True -> chronologically delete ckpts
297
+ False -> lexicographically delete ckpts
298
+ """
299
+ import re
300
+
301
+ ckpts_files = [
302
+ f
303
+ for f in os.listdir(path_to_models)
304
+ if os.path.isfile(os.path.join(path_to_models, f))
305
+ ]
306
+
307
+ def name_key(_f):
308
+ return int(re.compile("._(\\d+)\\.pth").match(_f).group(1))
309
+
310
+ def time_key(_f):
311
+ return os.path.getmtime(os.path.join(path_to_models, _f))
312
+
313
+ sort_key = time_key if sort_by_time else name_key
314
+
315
+ def x_sorted(_x):
316
+ return sorted(
317
+ [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
318
+ key=sort_key,
319
+ )
320
+
321
+ to_del = [
322
+ os.path.join(path_to_models, fn)
323
+ for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
324
+ ]
325
+
326
+ def del_info(fn):
327
+ return logger.info(f".. Free up space by deleting ckpt {fn}")
328
+
329
+ def del_routine(x):
330
+ return [os.remove(x), del_info(x)]
331
+
332
+ [del_routine(fn) for fn in to_del]
333
+
334
+
335
+ def get_hparams_from_dir(model_dir):
336
+ config_save_path = os.path.join(model_dir, "config.json")
337
+ with open(config_save_path, "r", encoding="utf-8") as f:
338
+ data = f.read()
339
+ config = json.loads(data)
340
+
341
+ hparams = HParams(**config)
342
+ hparams.model_dir = model_dir
343
+ return hparams
344
+
345
+
346
+ def get_hparams_from_file(config_path):
347
+ with open(config_path, "r", encoding="utf-8") as f:
348
+ data = f.read()
349
+ config = json.loads(data)
350
+
351
+ hparams = HParams(**config)
352
+ return hparams
353
+
354
+
355
+ def check_git_hash(model_dir):
356
+ source_dir = os.path.dirname(os.path.realpath(__file__))
357
+ if not os.path.exists(os.path.join(source_dir, ".git")):
358
+ logger.warn(
359
+ "{} is not a git repository, therefore hash value comparison will be ignored.".format(
360
+ source_dir
361
+ )
362
+ )
363
+ return
364
+
365
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
366
+
367
+ path = os.path.join(model_dir, "githash")
368
+ if os.path.exists(path):
369
+ saved_hash = open(path).read()
370
+ if saved_hash != cur_hash:
371
+ logger.warn(
372
+ "git hash values are different. {}(saved) != {}(current)".format(
373
+ saved_hash[:8], cur_hash[:8]
374
+ )
375
+ )
376
+ else:
377
+ open(path, "w").write(cur_hash)
378
+
379
+
380
+ def get_logger(model_dir, filename="train.log"):
381
+ global logger
382
+ logger = logging.getLogger(os.path.basename(model_dir))
383
+ logger.setLevel(logging.DEBUG)
384
+
385
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
386
+ if not os.path.exists(model_dir):
387
+ os.makedirs(model_dir, exist_ok=True)
388
+ h = logging.FileHandler(os.path.join(model_dir, filename))
389
+ h.setLevel(logging.DEBUG)
390
+ h.setFormatter(formatter)
391
+ logger.addHandler(h)
392
+ return logger
393
+
394
+
395
+ class HParams:
396
+ def __init__(self, **kwargs):
397
+ for k, v in kwargs.items():
398
+ if type(v) == dict:
399
+ v = HParams(**v)
400
+ self[k] = v
401
+
402
+ def keys(self):
403
+ return self.__dict__.keys()
404
+
405
+ def items(self):
406
+ return self.__dict__.items()
407
+
408
+ def values(self):
409
+ return self.__dict__.values()
410
+
411
+ def __len__(self):
412
+ return len(self.__dict__)
413
+
414
+ def __getitem__(self, key):
415
+ return getattr(self, key)
416
+
417
+ def __setitem__(self, key, value):
418
+ return setattr(self, key, value)
419
+
420
+ def __contains__(self, key):
421
+ return key in self.__dict__
422
+
423
+ def __repr__(self):
424
+ return self.__dict__.__repr__()