Plachta commited on
Commit
70f58f0
β€’
1 Parent(s): 8f014b7

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +217 -66
utils.py CHANGED
@@ -1,75 +1,226 @@
 
 
 
 
1
  import logging
2
- from json import loads
3
- from torch import load, FloatTensor
4
- from numpy import float32
5
- import librosa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
 
8
  class HParams():
9
- def __init__(self, **kwargs):
10
- for k, v in kwargs.items():
11
- if type(v) == dict:
12
- v = HParams(**v)
13
- self[k] = v
14
-
15
- def keys(self):
16
- return self.__dict__.keys()
17
-
18
- def items(self):
19
- return self.__dict__.items()
20
-
21
- def values(self):
22
- return self.__dict__.values()
23
-
24
- def __len__(self):
25
- return len(self.__dict__)
26
-
27
- def __getitem__(self, key):
28
- return getattr(self, key)
29
-
30
- def __setitem__(self, key, value):
31
- return setattr(self, key, value)
32
-
33
- def __contains__(self, key):
34
- return key in self.__dict__
35
-
36
- def __repr__(self):
37
- return self.__dict__.__repr__()
38
-
39
-
40
- def load_checkpoint(checkpoint_path, model):
41
- checkpoint_dict = load(checkpoint_path, map_location='cpu')
42
- iteration = checkpoint_dict['iteration']
43
- saved_state_dict = checkpoint_dict['model']
44
- if hasattr(model, 'module'):
45
- state_dict = model.module.state_dict()
46
- else:
47
- state_dict = model.state_dict()
48
- new_state_dict= {}
49
- for k, v in state_dict.items():
50
- try:
51
- new_state_dict[k] = saved_state_dict[k]
52
- except:
53
- logging.info("%s is not in the checkpoint" % k)
54
- new_state_dict[k] = v
55
- if hasattr(model, 'module'):
56
- model.module.load_state_dict(new_state_dict)
57
- else:
58
- model.load_state_dict(new_state_dict)
59
- logging.info("Loaded checkpoint '{}' (iteration {})" .format(
60
- checkpoint_path, iteration))
61
- return
62
 
 
 
63
 
64
- def get_hparams_from_file(config_path):
65
- with open(config_path, "r") as f:
66
- data = f.read()
67
- config = loads(data)
 
 
 
 
 
 
 
68
 
69
- hparams = HParams(**config)
70
- return hparams
71
 
 
 
72
 
73
- def load_audio_to_torch(full_path, target_sampling_rate):
74
- audio, sampling_rate = librosa.load(full_path, sr=target_sampling_rate, mono=True)
75
- return FloatTensor(audio.astype(float32))
 
1
+ import os
2
+ import glob
3
+ import sys
4
+ import argparse
5
  import logging
6
+ import json
7
+ import subprocess
8
+ import numpy as np
9
+ from scipy.io.wavfile import read
10
+ import torch
11
+
12
+ MATPLOTLIB_FLAG = False
13
+
14
+ logging.basicConfig(stream=sys.stdout, level=logging.ERROR)
15
+ logger = logging
16
+
17
+
18
+ def load_checkpoint(checkpoint_path, model, optimizer=None):
19
+ assert os.path.isfile(checkpoint_path)
20
+ checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
21
+ iteration = checkpoint_dict['iteration']
22
+ learning_rate = checkpoint_dict['learning_rate']
23
+ if optimizer is not None:
24
+ optimizer.load_state_dict(checkpoint_dict['optimizer'])
25
+ saved_state_dict = checkpoint_dict['model']
26
+ if hasattr(model, 'module'):
27
+ state_dict = model.module.state_dict()
28
+ else:
29
+ state_dict = model.state_dict()
30
+ new_state_dict = {}
31
+ for k, v in state_dict.items():
32
+ try:
33
+ new_state_dict[k] = saved_state_dict[k]
34
+ except:
35
+ logger.info("%s is not in the checkpoint" % k)
36
+ new_state_dict[k] = v
37
+ if hasattr(model, 'module'):
38
+ model.module.load_state_dict(new_state_dict)
39
+ else:
40
+ model.load_state_dict(new_state_dict)
41
+ logger.info("Loaded checkpoint '{}' (iteration {})".format(
42
+ checkpoint_path, iteration))
43
+ return model, optimizer, learning_rate, iteration
44
+
45
+
46
+ def plot_spectrogram_to_numpy(spectrogram):
47
+ global MATPLOTLIB_FLAG
48
+ if not MATPLOTLIB_FLAG:
49
+ import matplotlib
50
+ matplotlib.use("Agg")
51
+ MATPLOTLIB_FLAG = True
52
+ mpl_logger = logging.getLogger('matplotlib')
53
+ mpl_logger.setLevel(logging.WARNING)
54
+ import matplotlib.pylab as plt
55
+ import numpy as np
56
+
57
+ fig, ax = plt.subplots(figsize=(10, 2))
58
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
59
+ interpolation='none')
60
+ plt.colorbar(im, ax=ax)
61
+ plt.xlabel("Frames")
62
+ plt.ylabel("Channels")
63
+ plt.tight_layout()
64
+
65
+ fig.canvas.draw()
66
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
67
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
68
+ plt.close()
69
+ return data
70
+
71
+
72
+ def plot_alignment_to_numpy(alignment, info=None):
73
+ global MATPLOTLIB_FLAG
74
+ if not MATPLOTLIB_FLAG:
75
+ import matplotlib
76
+ matplotlib.use("Agg")
77
+ MATPLOTLIB_FLAG = True
78
+ mpl_logger = logging.getLogger('matplotlib')
79
+ mpl_logger.setLevel(logging.WARNING)
80
+ import matplotlib.pylab as plt
81
+ import numpy as np
82
+
83
+ fig, ax = plt.subplots(figsize=(6, 4))
84
+ im = ax.imshow(alignment.transpose(), aspect='auto', origin='lower',
85
+ interpolation='none')
86
+ fig.colorbar(im, ax=ax)
87
+ xlabel = 'Decoder timestep'
88
+ if info is not None:
89
+ xlabel += '\n\n' + info
90
+ plt.xlabel(xlabel)
91
+ plt.ylabel('Encoder timestep')
92
+ plt.tight_layout()
93
+
94
+ fig.canvas.draw()
95
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')
96
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
97
+ plt.close()
98
+ return data
99
+
100
+
101
+ def load_wav_to_torch(full_path):
102
+ sampling_rate, data = read(full_path)
103
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
104
+
105
+
106
+ def load_filepaths_and_text(filename, split="|"):
107
+ with open(filename, encoding='utf-8') as f:
108
+ filepaths_and_text = [line.strip().split(split) for line in f]
109
+ return filepaths_and_text
110
+
111
+
112
+ def get_hparams(init=True):
113
+ parser = argparse.ArgumentParser()
114
+ parser.add_argument('-c', '--config', type=str, default="./configs/base.json",
115
+ help='JSON file for configuration')
116
+ parser.add_argument('-m', '--model', type=str, required=True,
117
+ help='Model name')
118
+
119
+ args = parser.parse_args()
120
+ model_dir = os.path.join("./logs", args.model)
121
+
122
+ if not os.path.exists(model_dir):
123
+ os.makedirs(model_dir)
124
+
125
+ config_path = args.config
126
+ config_save_path = os.path.join(model_dir, "config.json")
127
+ if init:
128
+ with open(config_path, "r") as f:
129
+ data = f.read()
130
+ with open(config_save_path, "w") as f:
131
+ f.write(data)
132
+ else:
133
+ with open(config_save_path, "r") as f:
134
+ data = f.read()
135
+ config = json.loads(data)
136
+
137
+ hparams = HParams(**config)
138
+ hparams.model_dir = model_dir
139
+ return hparams
140
+
141
+
142
+ def get_hparams_from_dir(model_dir):
143
+ config_save_path = os.path.join(model_dir, "config.json")
144
+ with open(config_save_path, "r") as f:
145
+ data = f.read()
146
+ config = json.loads(data)
147
+
148
+ hparams = HParams(**config)
149
+ hparams.model_dir = model_dir
150
+ return hparams
151
+
152
+
153
+ def get_hparams_from_file(config_path):
154
+ with open(config_path, "r", encoding="utf-8") as f:
155
+ data = f.read()
156
+ config = json.loads(data)
157
+
158
+ hparams = HParams(**config)
159
+ return hparams
160
+
161
+
162
+ def check_git_hash(model_dir):
163
+ source_dir = os.path.dirname(os.path.realpath(__file__))
164
+ if not os.path.exists(os.path.join(source_dir, ".git")):
165
+ logger.warn("{} is not a git repository, therefore hash value comparison will be ignored.".format(
166
+ source_dir
167
+ ))
168
+ return
169
+
170
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
171
+
172
+ path = os.path.join(model_dir, "githash")
173
+ if os.path.exists(path):
174
+ saved_hash = open(path).read()
175
+ if saved_hash != cur_hash:
176
+ logger.warn("git hash values are different. {}(saved) != {}(current)".format(
177
+ saved_hash[:8], cur_hash[:8]))
178
+ else:
179
+ open(path, "w").write(cur_hash)
180
+
181
+
182
+ def get_logger(model_dir, filename="train.log"):
183
+ global logger
184
+ logger = logging.getLogger(os.path.basename(model_dir))
185
+ logger.setLevel(logging.DEBUG)
186
+
187
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
188
+ if not os.path.exists(model_dir):
189
+ os.makedirs(model_dir)
190
+ h = logging.FileHandler(os.path.join(model_dir, filename))
191
+ h.setLevel(logging.DEBUG)
192
+ h.setFormatter(formatter)
193
+ logger.addHandler(h)
194
+ return logger
195
 
196
 
197
  class HParams():
198
+ def __init__(self, **kwargs):
199
+ for k, v in kwargs.items():
200
+ if type(v) == dict:
201
+ v = HParams(**v)
202
+ self[k] = v
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
+ def keys(self):
205
+ return self.__dict__.keys()
206
 
207
+ def items(self):
208
+ return self.__dict__.items()
209
+
210
+ def values(self):
211
+ return self.__dict__.values()
212
+
213
+ def __len__(self):
214
+ return len(self.__dict__)
215
+
216
+ def __getitem__(self, key):
217
+ return getattr(self, key)
218
 
219
+ def __setitem__(self, key, value):
220
+ return setattr(self, key, value)
221
 
222
+ def __contains__(self, key):
223
+ return key in self.__dict__
224
 
225
+ def __repr__(self):
226
+ return self.__dict__.__repr__()