alysa committed
Commit 7629bb6
1 Parent(s): 5402140

Upload utils.py

Files changed (1)
  1. utils.py +319 -0
utils.py ADDED
@@ -0,0 +1,319 @@
import os
import glob
import sys
import argparse
import logging
import json
import subprocess

import numpy as np
import torch
from scipy.io.wavfile import read

MATPLOTLIB_FLAG = False

logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# Module-level fallback logger; get_logger() rebinds this to a per-run Logger.
logger = logging


def load_checkpoint(checkpoint_path, model, optimizer=None):
    """Load model (and optionally optimizer) state, keeping current weights for missing keys."""
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict["iteration"]
    learning_rate = checkpoint_dict["learning_rate"]
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint_dict["optimizer"])
    saved_state_dict = checkpoint_dict["model"]
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            logger.info("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    if hasattr(model, "module"):
        model.module.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(new_state_dict)
    logger.info(
        "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
    )
    return model, optimizer, learning_rate, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    logger.info(
        "Saving model and optimizer state at iteration {} to {}".format(
            iteration, checkpoint_path
        )
    )
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(
        {
            "model": state_dict,
            "iteration": iteration,
            "optimizer": optimizer.state_dict(),
            "learning_rate": learning_rate,
        },
        checkpoint_path,
    )

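
# A minimal usage sketch (illustrative only): pairing save_checkpoint and
# load_checkpoint to resume a run. The names net_g, optim_g, and the
# checkpoint path below are hypothetical and not defined in this module.
#
#   save_checkpoint(net_g, optim_g, learning_rate=2e-4, iteration=10000,
#                   checkpoint_path="./logs/example/G_10000.pth")
#   net_g, optim_g, lr, iteration = load_checkpoint(
#       "./logs/example/G_10000.pth", net_g, optim_g
#   )
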
def load_model(checkpoint_path, model):
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    saved_state_dict = checkpoint_dict["model"]
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
        except KeyError:
            logger.info("%s is not in the checkpoint" % k)
            new_state_dict[k] = v
    if hasattr(model, "module"):
        model.module.load_state_dict(new_state_dict)
    else:
        model.load_state_dict(new_state_dict)
    return model


def save_model(model, checkpoint_path):
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save({"model": state_dict}, checkpoint_path)


def summarize(
    writer,
    global_step,
    scalars={},
    histograms={},
    images={},
    audios={},
    audio_sampling_rate=22050,
):
    for k, v in scalars.items():
        writer.add_scalar(k, v, global_step)
    for k, v in histograms.items():
        writer.add_histogram(k, v, global_step)
    for k, v in images.items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in audios.items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)

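
# A minimal usage sketch (illustrative only): summarize() expects a
# torch.utils.tensorboard SummaryWriter; the hps, global_step, and loss/mel
# names below are hypothetical.
#
#   from torch.utils.tensorboard import SummaryWriter
#   writer = SummaryWriter(log_dir=hps.model_dir)
#   summarize(writer, global_step,
#             scalars={"loss/g/total": loss_g_total},
#             images={"mel": plot_spectrogram_to_numpy(mel[0].cpu().numpy())})
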
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    f_list = glob.glob(os.path.join(dir_path, regex))
    # Sort numerically by the digits in the path, so G_10000.pth sorts after G_2000.pth.
    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
    x = f_list[-1]
    print(x)
    return x


def plot_spectrogram_to_numpy(spectrogram):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def plot_alignment_to_numpy(alignment, info=None):
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        import matplotlib

        matplotlib.use("Agg")
        MATPLOTLIB_FLAG = True
        mpl_logger = logging.getLogger("matplotlib")
        mpl_logger.setLevel(logging.WARNING)
    import matplotlib.pylab as plt
    import numpy as np

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close()
    return data


def load_wav_to_torch(full_path):
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_filepaths_and_text(filename, split="|"):
    with open(filename, encoding="utf-8") as f:
        filepaths_and_text = []
        for line in f:
            path_text = line.strip().split(split)
            filepaths_and_text.append(path_text)
    return filepaths_and_text


def get_hparams(init=True):
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="./configs/bert_vits.json",
        help="JSON file for configuration",
    )
    parser.add_argument("-m", "--model", type=str, required=True, help="Model name")

    args = parser.parse_args()
    model_dir = os.path.join("./logs", args.model)

    if not os.path.exists(model_dir):
        os.makedirs(model_dir)

    config_path = args.config
    config_save_path = os.path.join(model_dir, "config.json")
    if init:
        # Copy the chosen config into the model directory so the run is reproducible.
        with open(config_path, "r") as f:
            data = f.read()
        with open(config_save_path, "w") as f:
            f.write(data)
    else:
        with open(config_save_path, "r") as f:
            data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_dir(model_dir):
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams


def get_hparams_from_file(config_path):
    with open(config_path, "r") as f:
        data = f.read()
    config = json.loads(data)

    hparams = HParams(**config)
    return hparams


def check_git_hash(model_dir):
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        with open(path) as f:
            saved_hash = f.read()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        with open(path, "w") as f:
            f.write(cur_hash)


def get_logger(model_dir, filename="train.log"):
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    h = logging.FileHandler(os.path.join(model_dir, filename))
    h.setLevel(logging.DEBUG)
    h.setFormatter(formatter)
    logger.addHandler(h)
    return logger


class HParams:
    """Dict-like hyperparameter container that exposes nested dicts as attributes."""

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
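

# A small, self-contained check (illustrative only) of how HParams exposes a
# nested config as attributes; the values below are made up, not a real training config.
if __name__ == "__main__":
    demo = HParams(
        train={"batch_size": 16, "learning_rate": 2e-4},
        model={"hidden_channels": 192},
    )
    print(demo.train.batch_size)       # 16
    print(demo.model.hidden_channels)  # 192
    print("train" in demo)             # True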