Core23 commited on
Commit
9170dfe
1 Parent(s): d926ad2

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +258 -0
utils.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import sys
4
+ import argparse
5
+ import logging
6
+ import json
7
+ import subprocess
8
+ import numpy as np
9
+ from scipy.io.wavfile import read
10
+ import torch
11
+
12
+ MATPLOTLIB_FLAG = False
13
+
14
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
15
+ logger = logging
16
+
17
+
18
+ def load_checkpoint(checkpoint_path, model, optimizer=None):
19
+ assert os.path.isfile(checkpoint_path)
20
+ checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
21
+ iteration = checkpoint_dict['iteration']
22
+ learning_rate = checkpoint_dict['learning_rate']
23
+ if optimizer is not None:
24
+ optimizer.load_state_dict(checkpoint_dict['optimizer'])
25
+ saved_state_dict = checkpoint_dict['model']
26
+ if hasattr(model, 'module'):
27
+ state_dict = model.module.state_dict()
28
+ else:
29
+ state_dict = model.state_dict()
30
+ new_state_dict= {}
31
+ for k, v in state_dict.items():
32
+ try:
33
+ new_state_dict[k] = saved_state_dict[k]
34
+ except:
35
+ logger.info("%s is not in the checkpoint" % k)
36
+ new_state_dict[k] = v
37
+ if hasattr(model, 'module'):
38
+ model.module.load_state_dict(new_state_dict)
39
+ else:
40
+ model.load_state_dict(new_state_dict)
41
+ logger.info("Loaded checkpoint '{}' (iteration {})" .format(
42
+ checkpoint_path, iteration))
43
+ return model, optimizer, learning_rate, iteration
44
+
45
+
46
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
47
+ logger.info("Saving model and optimizer state at iteration {} to {}".format(
48
+ iteration, checkpoint_path))
49
+ if hasattr(model, 'module'):
50
+ state_dict = model.module.state_dict()
51
+ else:
52
+ state_dict = model.state_dict()
53
+ torch.save({'model': state_dict,
54
+ 'iteration': iteration,
55
+ 'optimizer': optimizer.state_dict(),
56
+ 'learning_rate': learning_rate}, checkpoint_path)
57
+
58
+
59
+ def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sampling_rate=22050):
60
+ for k, v in scalars.items():
61
+ writer.add_scalar(k, v, global_step)
62
+ for k, v in histograms.items():
63
+ writer.add_histogram(k, v, global_step)
64
+ for k, v in images.items():
65
+ writer.add_image(k, v, global_step, dataformats='HWC')
66
+ for k, v in audios.items():
67
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
68
+
69
+
70
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
71
+ f_list = glob.glob(os.path.join(dir_path, regex))
72
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
73
+ x = f_list[-1]
74
+ print(x)
75
+ return x
76
+
77
+
78
+ def plot_spectrogram_to_numpy(spectrogram):
79
+ global MATPLOTLIB_FLAG
80
+ if not MATPLOTLIB_FLAG:
81
+ import matplotlib
82
+ matplotlib.use("Agg")
83
+ MATPLOTLIB_FLAG = True
84
+ mpl_logger = logging.getLogger('matplotlib')
85
+ mpl_logger.setLevel(logging.WARNING)
86
+ import matplotlib.pylab as plt
87
+ import numpy as np
88
+
89
+ fig, ax = plt.subplots(figsize=(10,2))
90
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
91
+ interpolation='none')
92
+ plt.colorbar(im, ax=ax)
93
+ plt.xlabel("Frames")
94
+ plt.ylabel("Channels")
95
+ plt.tight_layout()
96
+
97
+ fig.canvas.draw()
98
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
99
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
100
+ plt.close()
101
+ return data
102
+
103
+
104
+ def plot_alignment_to_numpy(alignment, info=None):
105
+ global MATPLOTLIB_FLAG
106
+ if not MATPLOTLIB_FLAG:
107
+ import matplotlib
108
+ matplotlib.use("Agg")
109
+ MATPLOTLIB_FLAG = True
110
+ mpl_logger = logging.getLogger('matplotlib')
111
+ mpl_logger.setLevel(logging.WARNING)
112
+ import matplotlib.pylab as plt
113
+ import numpy as np
114
+
115
+ fig, ax = plt.subplots(figsize=(6, 4))
116
+ im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
117
+ interpolation='none')
118
+ fig.colorbar(im, ax=ax)
119
+ xlabel = 'Decoder timestep'
120
+ if info is not None:
121
+ xlabel += '\n\n' + info
122
+ plt.xlabel(xlabel)
123
+ plt.ylabel('Encoder timestep')
124
+ plt.tight_layout()
125
+
126
+ fig.canvas.draw()
127
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
128
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
129
+ plt.close()
130
+ return data
131
+
132
+
133
+ def load_wav_to_torch(full_path):
134
+ sampling_rate, data = read(full_path)
135
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
136
+
137
+
138
+ def load_filepaths_and_text(filename, split="|"):
139
+ with open(filename, encoding='utf-8') as f:
140
+ filepaths_and_text = [line.strip().split(split) for line in f]
141
+ return filepaths_and_text
142
+
143
+
144
+ def get_hparams(init=True):
145
+ parser = argparse.ArgumentParser()
146
+ parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
147
+ help='JSON file for configuration')
148
+ parser.add_argument('-m', '--model', type=str, required=True,
149
+ help='Model name')
150
+
151
+ args = parser.parse_args()
152
+ model_dir = os.path.join("./logs", args.model)
153
+
154
+ if not os.path.exists(model_dir):
155
+ os.makedirs(model_dir)
156
+
157
+ config_path = args.config
158
+ config_save_path = os.path.join(model_dir, "config.json")
159
+ if init:
160
+ with open(config_path, "r") as f:
161
+ data = f.read()
162
+ with open(config_save_path, "w") as f:
163
+ f.write(data)
164
+ else:
165
+ with open(config_save_path, "r") as f:
166
+ data = f.read()
167
+ config = json.loads(data)
168
+
169
+ hparams = HParams(**config)
170
+ hparams.model_dir = model_dir
171
+ return hparams
172
+
173
+
174
+ def get_hparams_from_dir(model_dir):
175
+ config_save_path = os.path.join(model_dir, "config.json")
176
+ with open(config_save_path, "r") as f:
177
+ data = f.read()
178
+ config = json.loads(data)
179
+
180
+ hparams =HParams(**config)
181
+ hparams.model_dir = model_dir
182
+ return hparams
183
+
184
+
185
+ def get_hparams_from_file(config_path):
186
+ with open(config_path, "r") as f:
187
+ data = f.read()
188
+ config = json.loads(data)
189
+
190
+ hparams =HParams(**config)
191
+ return hparams
192
+
193
+
194
+ def check_git_hash(model_dir):
195
+ source_dir = os.path.dirname(os.path.realpath(__file__))
196
+ if not os.path.exists(os.path.join(source_dir, ".git")):
197
+ logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
198
+ source_dir
199
+ ))
200
+ return
201
+
202
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
203
+
204
+ path = os.path.join(model_dir, "githash")
205
+ if os.path.exists(path):
206
+ saved_hash = open(path).read()
207
+ if saved_hash != cur_hash:
208
+ logger.warn("git hash values are different. {}(saved) != {}(current)".format(
209
+ saved_hash[:8], cur_hash[:8]))
210
+ else:
211
+ open(path, "w").write(cur_hash)
212
+
213
+
214
+ def get_logger(model_dir, filename="train.log"):
215
+ global logger
216
+ logger = logging.getLogger(os.path.basename(model_dir))
217
+ logger.setLevel(logging.DEBUG)
218
+
219
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
220
+ if not os.path.exists(model_dir):
221
+ os.makedirs(model_dir)
222
+ h = logging.FileHandler(os.path.join(model_dir, filename))
223
+ h.setLevel(logging.DEBUG)
224
+ h.setFormatter(formatter)
225
+ logger.addHandler(h)
226
+ return logger
227
+
228
+
229
+ class HParams():
230
+ def __init__(self, **kwargs):
231
+ for k, v in kwargs.items():
232
+ if type(v) == dict:
233
+ v = HParams(**v)
234
+ self[k] = v
235
+
236
+ def keys(self):
237
+ return self.__dict__.keys()
238
+
239
+ def items(self):
240
+ return self.__dict__.items()
241
+
242
+ def values(self):
243
+ return self.__dict__.values()
244
+
245
+ def __len__(self):
246
+ return len(self.__dict__)
247
+
248
+ def __getitem__(self, key):
249
+ return getattr(self, key)
250
+
251
+ def __setitem__(self, key, value):
252
+ return setattr(self, key, value)
253
+
254
+ def __contains__(self, key):
255
+ return key in self.__dict__
256
+
257
+ def __repr__(self):
258
+ return self.__dict__.__repr__()