saeki committed
Commit · 7b918f7
1 Parent(s): e6364e9
fix

Browse files:
- aet.py +368 -0
- dataset.py +344 -0
- eval.py +67 -0
- lightning_module.py +875 -0
- model.py +854 -0
- preprocess.py +152 -0
- train.py +106 -0
- utils.py +147 -0
aet.py
ADDED
@@ -0,0 +1,368 @@
import argparse
import pathlib
import yaml
import torch
import torchaudio
from torch.utils.data import DataLoader
import numpy as np
import random
import librosa
from dataset import Dataset
import pickle
from lightning_module import (
    SSLStepLightningModule,
    SSLDualLightningModule,
)
from utils import plot_and_save_mels
import os
import tqdm


class AETDataset(Dataset):
    # Pairs preprocessed source/target utterances for acoustic effect transfer.
    # normalize_waveform and calc_spectrogram are inherited from Dataset.
    def __init__(self, filetxt, src_config, tar_config):
        self.config = src_config

        self.preprocessed_dir_src = pathlib.Path(
            src_config["general"]["preprocessed_path"]
        )
        self.preprocessed_dir_tar = pathlib.Path(
            tar_config["general"]["preprocessed_path"]
        )
        # Both corpora must share the same signal-processing settings.
        for item in [
            "sampling_rate",
            "fft_length",
            "frame_length",
            "frame_shift",
            "fmin",
            "fmax",
            "n_mels",
        ]:
            assert src_config["preprocess"][item] == tar_config["preprocess"][item]

        self.spec_module = torchaudio.transforms.MelSpectrogram(
            sample_rate=src_config["preprocess"]["sampling_rate"],
            n_fft=src_config["preprocess"]["fft_length"],
            win_length=src_config["preprocess"]["frame_length"],
            hop_length=src_config["preprocess"]["frame_shift"],
            f_min=src_config["preprocess"]["fmin"],
            f_max=src_config["preprocess"]["fmax"],
            n_mels=src_config["preprocess"]["n_mels"],
            power=1,
            center=True,
            norm="slaney",
            mel_scale="slaney",
        )

        with open(self.preprocessed_dir_src / filetxt, "r") as fr:
            self.filelist_src = [pathlib.Path(path.strip("\n")) for path in fr]
        with open(self.preprocessed_dir_tar / filetxt, "r") as fr:
            self.filelist_tar = [pathlib.Path(path.strip("\n")) for path in fr]

        self.d_out = {"src": {}, "tar": {}}
        for item in ["wavs", "wavsaux"]:
            self.d_out["src"][item] = []
            self.d_out["tar"][item] = []

        for swp in self.filelist_src:
            if src_config["general"]["corpus_type"] == "single":
                basename = str(swp.stem)
            else:
                basename = str(swp.parent.name) + "-" + str(swp.stem)
            with open(
                self.preprocessed_dir_src / "{}.pickle".format(basename), "rb"
            ) as fr:
                d_preprocessed = pickle.load(fr)
            for item in ["wavs", "wavsaux"]:
                try:
                    self.d_out["src"][item].extend(d_preprocessed[item])
                except KeyError:
                    pass  # this corpus has no auxiliary (clean) waveforms

        for twp in self.filelist_tar:
            if tar_config["general"]["corpus_type"] == "single":
                basename = str(twp.stem)
            else:
                basename = str(twp.parent.name) + "-" + str(twp.stem)
            with open(
                self.preprocessed_dir_tar / "{}.pickle".format(basename), "rb"
            ) as fr:
                d_preprocessed = pickle.load(fr)
            for item in ["wavs", "wavsaux"]:
                try:
                    self.d_out["tar"][item].extend(d_preprocessed[item])
                except KeyError:
                    pass

        # Truncate both sides to the same number of utterances.
        min_len = min(len(self.d_out["src"]["wavs"]), len(self.d_out["tar"]["wavs"]))
        for spk in ["src", "tar"]:
            for item in ["wavs", "wavsaux"]:
                if self.d_out[spk][item] is not None:
                    self.d_out[spk][item] = np.asarray(self.d_out[spk][item][:min_len])

    def __len__(self):
        return len(self.d_out["src"]["wavs"])

    def __getitem__(self, idx):
        d_batch = {}

        for spk in ["src", "tar"]:
            for item in ["wavs", "wavsaux"]:
                if self.d_out[spk][item].size > 0:
                    d_batch["{}_{}".format(item, spk)] = torch.from_numpy(
                        self.d_out[spk][item][idx]
                    )
                    d_batch["{}_{}".format(item, spk)] = self.normalize_waveform(
                        d_batch["{}_{}".format(item, spk)], db=-3
                    )

        d_batch["melspecs_src"] = self.calc_spectrogram(d_batch["wavs_src"])
        return d_batch


class AETModule(torch.nn.Module):
    """
    src: Dataset from which we extract the channel features
    tar: Dataset to which the src channel features are added
    """

    def __init__(self, args, chmatch_config, src_config, tar_config):
        super().__init__()
        if args.stage == "ssl-step":
            LModule = SSLStepLightningModule
        elif args.stage == "ssl-dual":
            LModule = SSLDualLightningModule
        else:
            raise NotImplementedError()

        src_model = LModule(src_config).load_from_checkpoint(
            checkpoint_path=chmatch_config["general"]["source"]["ckpt_path"],
            config=src_config,
        )
        self.src_config = src_config

        self.encoder_src = src_model.encoder
        if src_config["general"]["use_gst"]:
            self.gst_src = src_model.gst
        else:
            self.channelfeats_src = src_model.channelfeats
        self.channel_src = src_model.channel

    def forward(self, melspecs_src, wavsaux_tar):
        # Extract channel features from the source mel spectrogram and
        # imprint them on the clean target waveform.
        if self.src_config["general"]["use_gst"]:
            chfeats_src = self.gst_src(melspecs_src.transpose(1, 2))
        else:
            _, enc_hidden_src = self.encoder_src(
                melspecs_src.unsqueeze(1).transpose(2, 3)
            )
            chfeats_src = self.channelfeats_src(enc_hidden_src)
        wavschmatch_tar = self.channel_src(wavsaux_tar, chfeats_src)
        return wavschmatch_tar


def get_arg():
    parser = argparse.ArgumentParser()
    parser.add_argument("--stage", required=True, type=str)
    parser.add_argument("--config_path", required=True, type=pathlib.Path)
    parser.add_argument("--exist_src_aux", action="store_true")
    parser.add_argument("--run_name", required=True, type=str)
    return parser.parse_args()


def main(args, chmatch_config, device):
    src_config = yaml.load(
        open(chmatch_config["general"]["source"]["config_path"], "r"),
        Loader=yaml.FullLoader,
    )
    tar_config = yaml.load(
        open(chmatch_config["general"]["target"]["config_path"], "r"),
        Loader=yaml.FullLoader,
    )
    output_path = pathlib.Path(chmatch_config["general"]["output_path"]) / args.run_name
    dataset = AETDataset("test.txt", src_config, tar_config)
    loader = DataLoader(dataset, batch_size=1, shuffle=False)
    chmatch_module = AETModule(args, chmatch_config, src_config, tar_config).to(device)

    if args.exist_src_aux:
        char_vector = calc_deg_charactaristics(chmatch_config)

    for idx, batch in enumerate(tqdm.tqdm(loader)):
        melspecs_src = batch["melspecs_src"].to(device)
        wavsdeg_src = batch["wavs_src"].to(device)
        wavsaux_tar = batch["wavsaux_tar"].to(device)
        if args.exist_src_aux:
            wavsdegbaseline_tar = calc_deg_baseline(
                batch["wavsaux_tar"], char_vector, tar_config
            )
            wavsdegbaseline_tar = normalize_waveform(wavsdegbaseline_tar, tar_config)
            wavsdeg_tar = batch["wavs_tar"].to(device)
        wavsmatch_tar = normalize_waveform(
            chmatch_module(melspecs_src, wavsaux_tar).cpu().detach(), tar_config
        )
        torchaudio.save(
            output_path / "test_wavs" / "{}-src_wavsdeg.wav".format(idx),
            wavsdeg_src.cpu(),
            src_config["preprocess"]["sampling_rate"],
        )
        torchaudio.save(
            output_path / "test_wavs" / "{}-tar_wavsaux.wav".format(idx),
            wavsaux_tar.cpu(),
            tar_config["preprocess"]["sampling_rate"],
        )
        if args.exist_src_aux:
            torchaudio.save(
                output_path / "test_wavs" / "{}-tar_wavsdegbaseline.wav".format(idx),
                wavsdegbaseline_tar.cpu(),
                tar_config["preprocess"]["sampling_rate"],
            )
            torchaudio.save(
                output_path / "test_wavs" / "{}-tar_wavsdeg.wav".format(idx),
                wavsdeg_tar.cpu(),
                tar_config["preprocess"]["sampling_rate"],
            )
        torchaudio.save(
            output_path / "test_wavs" / "{}-tar_wavsmatch.wav".format(idx),
            wavsmatch_tar.cpu(),
            tar_config["preprocess"]["sampling_rate"],
        )
        plot_and_save_mels(
            wavsdeg_src[0, ...].cpu().detach(),
            output_path / "test_mels" / "{}-src_melsdeg.png".format(idx),
            src_config,
        )
        plot_and_save_mels(
            wavsaux_tar[0, ...].cpu().detach(),
            output_path / "test_mels" / "{}-tar_melsaux.png".format(idx),
            tar_config,
        )
        if args.exist_src_aux:
            plot_and_save_mels(
                wavsdegbaseline_tar[0, ...].cpu().detach(),
                output_path / "test_mels" / "{}-tar_melsdegbaseline.png".format(idx),
                tar_config,
            )
            plot_and_save_mels(
                wavsdeg_tar[0, ...].cpu().detach(),
                output_path / "test_mels" / "{}-tar_melsdeg.png".format(idx),
                tar_config,
            )
        plot_and_save_mels(
            wavsmatch_tar[0, ...].cpu().detach(),
            output_path / "test_mels" / "{}-tar_melsmatch.png".format(idx),
            tar_config,
        )


def calc_deg_baseline(wav, char_vector, tar_config):
    # Imprint the averaged channel characteristic on a clean waveform
    # by per-bin multiplication in the STFT domain.
    wav = wav[0, ...].cpu().detach().numpy()
    spec = librosa.stft(
        wav,
        n_fft=tar_config["preprocess"]["fft_length"],
        hop_length=tar_config["preprocess"]["frame_shift"],
        win_length=tar_config["preprocess"]["frame_length"],
    )
    spec_converted = spec * char_vector.reshape(-1, 1)
    wav_converted = librosa.istft(
        spec_converted,
        hop_length=tar_config["preprocess"]["frame_shift"],
        win_length=tar_config["preprocess"]["frame_length"],
    )
    wav_converted = torch.from_numpy(wav_converted).to(torch.float32).unsqueeze(0)
    return wav_converted


def calc_deg_charactaristics(chmatch_config):
    src_config = yaml.load(
        open(chmatch_config["general"]["source"]["config_path"], "r"),
        Loader=yaml.FullLoader,
    )
    tar_config = yaml.load(
        open(chmatch_config["general"]["target"]["config_path"], "r"),
        Loader=yaml.FullLoader,
    )
    # configs
    preprocessed_dir = pathlib.Path(src_config["general"]["preprocessed_path"])
    n_train = src_config["preprocess"]["n_train"]
    SR = src_config["preprocess"]["sampling_rate"]

    os.makedirs(preprocessed_dir, exist_ok=True)

    sourcepath = pathlib.Path(src_config["general"]["source_path"])

    if src_config["general"]["corpus_type"] == "single":
        fulllist = list(sourcepath.glob("*.wav"))
        random.seed(0)
        random.shuffle(fulllist)
        train_filelist = fulllist[:n_train]
    elif src_config["general"]["corpus_type"] == "multi-seen":
        fulllist = list(sourcepath.glob("*/*.wav"))
        random.seed(0)
        random.shuffle(fulllist)
        train_filelist = fulllist[:n_train]
    elif src_config["general"]["corpus_type"] == "multi-unseen":
        spk_list = list(set([x.parent for x in sourcepath.glob("*/*.wav")]))
        train_filelist = []
        random.seed(0)
        random.shuffle(spk_list)
        for i, spk in enumerate(spk_list):
            sourcespkpath = sourcepath / spk
            if i < n_train:
                train_filelist.extend(list(sourcespkpath.glob("*.wav")))
    else:
        raise NotImplementedError(
            "corpus_type specified in config.yaml should be {single, multi-seen, multi-unseen}"
        )

    specs_all = np.zeros((tar_config["preprocess"]["fft_length"] // 2 + 1, 1))

    for wp in tqdm.tqdm(train_filelist):
        wav, _ = librosa.load(wp, sr=SR)
        spec = np.abs(
            librosa.stft(
                wav,
                n_fft=src_config["preprocess"]["fft_length"],
                hop_length=src_config["preprocess"]["frame_shift"],
                win_length=src_config["preprocess"]["frame_length"],
            )
        )

        auxpath = pathlib.Path(src_config["general"]["aux_path"])
        if src_config["general"]["corpus_type"] == "single":
            wav_aux, _ = librosa.load(auxpath / wp.name, sr=SR)
        else:
            wav_aux, _ = librosa.load(auxpath / wp.parent.name / wp.name, sr=SR)
        spec_aux = np.abs(
            librosa.stft(
                wav_aux,
                n_fft=src_config["preprocess"]["fft_length"],
                hop_length=src_config["preprocess"]["frame_shift"],
                win_length=src_config["preprocess"]["frame_length"],
            )
        )
        min_len = min(spec.shape[1], spec_aux.shape[1])
        spec_diff = spec[:, :min_len] / (spec_aux[:, :min_len] + 1e-10)
        specs_all = np.hstack([specs_all, np.mean(spec_diff, axis=1).reshape(-1, 1)])

    char_vector = np.mean(specs_all, axis=1)
    char_vector = char_vector / (np.sum(char_vector) + 1e-10)
    return char_vector


def normalize_waveform(wav, tar_config, db=-3):
    wav, _ = torchaudio.sox_effects.apply_effects_tensor(
        wav,
        tar_config["preprocess"]["sampling_rate"],
        [["norm", "{}".format(db)]],
    )
    return wav


if __name__ == "__main__":
    args = get_arg()
    chmatch_config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader)
    output_path = pathlib.Path(chmatch_config["general"]["output_path"]) / args.run_name
    os.makedirs(output_path, exist_ok=True)
    os.makedirs(output_path / "test_wavs", exist_ok=True)
    os.makedirs(output_path / "test_mels", exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    main(args, chmatch_config, device)
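Note: calc_deg_charactaristics and calc_deg_baseline above form a simple spectral-ratio baseline: the channel is summarized as the per-frequency-bin mean of |STFT(degraded)| / |STFT(clean)| over the training files, and is imprinted on a clean waveform by multiplication in the STFT domain. A self-contained sketch of the same idea for a single file pair; the FFT parameters here are illustrative assumptions, not the repository's configuration:

import numpy as np
import librosa

def estimate_char_vector(deg_wav, clean_wav, n_fft=1024, hop=256, win=1024):
    # Per-bin mean magnitude ratio, mirroring calc_deg_charactaristics.
    spec_deg = np.abs(librosa.stft(deg_wav, n_fft=n_fft, hop_length=hop, win_length=win))
    spec_cln = np.abs(librosa.stft(clean_wav, n_fft=n_fft, hop_length=hop, win_length=win))
    min_len = min(spec_deg.shape[1], spec_cln.shape[1])
    ratio = spec_deg[:, :min_len] / (spec_cln[:, :min_len] + 1e-10)
    char = ratio.mean(axis=1)
    return char / (char.sum() + 1e-10)

def apply_char_vector(clean_wav, char, n_fft=1024, hop=256, win=1024):
    # Imprint the characteristic and invert, mirroring calc_deg_baseline.
    spec = librosa.stft(clean_wav, n_fft=n_fft, hop_length=hop, win_length=win)
    return librosa.istft(spec * char.reshape(-1, 1), hop_length=hop, win_length=win)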
dataset.py
ADDED
@@ -0,0 +1,344 @@
import pickle
import pathlib
import torch
from torch.utils.data.dataloader import DataLoader
import pytorch_lightning as pl
import numpy as np
import yaml
import torchaudio
import pyworld
import pysptk
import random


class DataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batchsize = config["train"]["batchsize"]
        self.preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"])

    def setup(self, stage):

        if not self.preprocessed_dir.exists():
            raise RuntimeError("Preprocessed directory was not found")

        if "dual" in self.config:
            if self.config["dual"]["enable"]:
                task_config = yaml.load(
                    open(self.config["dual"]["config_path"], "r"),
                    Loader=yaml.FullLoader,
                )
                task_preprocessed_dir = (
                    self.preprocessed_dir.parent
                    / pathlib.Path(task_config["general"]["preprocessed_path"]).name
                )
                if not task_preprocessed_dir.exists():
                    raise RuntimeError(
                        "Preprocessed directory for multi-task learning was not found"
                    )

        self.flnames = {
            "train": "train.txt",
            "val": "val.txt",
            "test": "test.txt",
        }

    def get_ds(self, phase):
        ds = Dataset(self.flnames[phase], self.config)
        return ds

    def get_loader(self, phase):
        ds = self.get_ds(phase)
        dl = DataLoader(
            ds,
            self.batchsize,
            shuffle=(phase == "train"),
            num_workers=self.config["train"]["num_workers"],
            drop_last=True,
        )
        return dl

    def train_dataloader(self):
        return self.get_loader(phase="train")

    def val_dataloader(self):
        return self.get_loader(phase="val")

    def test_dataloader(self):
        return self.get_loader(phase="test")
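Note: DataModule follows the standard LightningDataModule contract, so it plugs directly into a Trainer. A minimal wiring sketch; the config path and the choice of PretrainLightningModule are assumptions for illustration (train.py, added in this commit but not shown here, is the actual entry point):

import yaml
from pytorch_lightning import Trainer

from dataset import DataModule
from lightning_module import PretrainLightningModule

config = yaml.load(open("configs/pretrain.yaml", "r"), Loader=yaml.FullLoader)  # assumed path
datamodule = DataModule(config)  # builds train/val/test loaders from the filelists
model = PretrainLightningModule(config)
Trainer(max_epochs=1).fit(model, datamodule=datamodule)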
class Dataset(torch.utils.data.Dataset):
    def __init__(self, filetxt, config):

        self.preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"])
        self.config = config
        self.spec_module = torchaudio.transforms.MelSpectrogram(
            sample_rate=config["preprocess"]["sampling_rate"],
            n_fft=config["preprocess"]["fft_length"],
            win_length=config["preprocess"]["frame_length"],
            hop_length=config["preprocess"]["frame_shift"],
            f_min=config["preprocess"]["fmin"],
            f_max=config["preprocess"]["fmax"],
            n_mels=config["preprocess"]["n_mels"],
            power=1,
            center=True,
            norm="slaney",
            mel_scale="slaney",
        )
        self.resample_candidate = [8000, 11025, 12000, 16000]
        self.quantization_candidate = range(2 ** 6, 2 ** 10 + 2, 2)
        self.segment_length = config["preprocess"]["segment_length"]

        with open(self.preprocessed_dir / filetxt, "r") as fr:
            self.filelist = [pathlib.Path(path.strip("\n")) for path in fr]

        self.d_out = dict()
        for item in ["wavs", "wavsaux"]:
            self.d_out[item] = []

        for wp in self.filelist:

            if config["general"]["corpus_type"] == "single":
                basename = str(wp.stem)
            else:
                basename = str(wp.parent.name) + "-" + str(wp.stem)

            with open(self.preprocessed_dir / "{}.pickle".format(basename), "rb") as fr:
                d_preprocessed = pickle.load(fr)

            for item in ["wavs", "wavsaux"]:
                try:
                    self.d_out[item].extend(d_preprocessed[item])
                except KeyError:
                    pass  # auxiliary waveforms may be absent

        for item in ["wavs", "wavsaux"]:
            if self.d_out[item] is not None:
                self.d_out[item] = np.asarray(self.d_out[item])

        if "dual" in self.config:
            if self.config["dual"]["enable"]:
                task_config = yaml.load(
                    open(config["dual"]["config_path"], "r"),
                    Loader=yaml.FullLoader,
                )
                task_preprocessed_dir = (
                    self.preprocessed_dir.parent
                    / pathlib.Path(task_config["general"]["preprocessed_path"]).name
                )
                with open(task_preprocessed_dir / filetxt, "r") as fr:
                    task_filelist = [pathlib.Path(path.strip("\n")) for path in fr]
                self.d_out["wavstask"] = []
                for wp in task_filelist:
                    if task_config["general"]["corpus_type"] == "single":
                        basename = str(wp.stem)
                    else:
                        basename = str(wp.parent.name) + "-" + str(wp.stem)
                    with open(
                        task_preprocessed_dir / "{}.pickle".format(basename), "rb"
                    ) as fr:
                        d_preprocessed = pickle.load(fr)
                    self.d_out["wavstask"].extend(d_preprocessed["wavs"])
                self.d_out["wavstask"] = np.asarray(self.d_out["wavstask"])

    def __len__(self):
        return len(self.d_out["wavs"])

    def __getitem__(self, idx):

        d_batch = {}

        if self.d_out["wavs"].size > 0:
            d_batch["wavs"] = torch.from_numpy(self.d_out["wavs"][idx])
            if self.segment_length > 0:
                d_batch["wavs"] = self.get_segment(d_batch["wavs"], self.segment_length)

        if self.d_out["wavsaux"].size > 0:
            d_batch["wavsaux"] = torch.from_numpy(self.d_out["wavsaux"][idx])
            if self.segment_length > 0:
                d_batch["wavsaux"] = self.get_segment(
                    d_batch["wavsaux"], self.segment_length
                )

        if self.config["general"]["stage"] == "pretrain":
            if self.config["train"]["augment"]:
                d_batch["wavs"] = self.augmentation(d_batch["wavsaux"])
            d_batch["wavs"] = self.normalize_waveform(d_batch["wavs"], db=-3)
            d_batch["wavsaux"] = self.normalize_waveform(d_batch["wavsaux"], db=-3)
            if len(d_batch["wavs"]) != len(d_batch["wavsaux"]):
                min_seq_len = min(len(d_batch["wavs"]), len(d_batch["wavsaux"]))
                d_batch["wavs"] = d_batch["wavs"][:min_seq_len]
                d_batch["wavsaux"] = d_batch["wavsaux"][:min_seq_len]
            d_batch["melspecs"] = self.calc_spectrogram(d_batch["wavs"])
            if self.config["general"]["feature_type"] == "melspec":
                d_batch["melspecsaux"] = self.calc_spectrogram(d_batch["wavsaux"])
            elif self.config["general"]["feature_type"] == "vocfeats":
                d_batch["melceps"] = self.calc_melcep(d_batch["wavsaux"])
                d_batch["f0s"] = self.calc_f0(d_batch["wavs"])
                d_batch["melcepssrc"] = self.calc_melcep(d_batch["wavs"])
            else:
                raise NotImplementedError()

        elif self.config["general"]["stage"].startswith("ssl"):
            d_batch["wavs"] = self.normalize_waveform(d_batch["wavs"], db=-3)
            d_batch["melspecs"] = self.calc_spectrogram(d_batch["wavs"])
            if self.config["general"]["feature_type"] == "vocfeats":
                d_batch["f0s"] = self.calc_f0(d_batch["wavs"])
                d_batch["melcepssrc"] = self.calc_melcep(d_batch["wavs"])
            if self.d_out["wavsaux"].size > 0:
                d_batch["wavsaux"] = self.normalize_waveform(d_batch["wavsaux"], db=-3)
                if self.config["general"]["feature_type"] == "melspec":
                    d_batch["melspecsaux"] = self.calc_spectrogram(d_batch["wavsaux"])
                elif self.config["general"]["feature_type"] == "vocfeats":
                    d_batch["melceps"] = self.calc_melcep(d_batch["wavsaux"])
            if "dual" in self.config:
                if self.config["dual"]["enable"]:
                    d_batch["wavstask"] = torch.from_numpy(self.d_out["wavstask"][idx])
                    d_batch["wavstask"] = self.get_segment(
                        d_batch["wavstask"], self.segment_length
                    )
                    d_batch["wavstask"] = self.normalize_waveform(
                        d_batch["wavstask"], db=-3
                    )
                    if self.config["general"]["feature_type"] == "melspec":
                        d_batch["melspecstask"] = self.calc_spectrogram(
                            d_batch["wavstask"]
                        )
                    elif self.config["general"]["feature_type"] == "vocfeats":
                        d_batch["melcepstask"] = self.calc_melcep(d_batch["wavstask"])
                    else:
                        raise NotImplementedError()
        else:
            raise NotImplementedError()

        return d_batch

    def calc_spectrogram(self, wav):
        specs = self.spec_module(wav)
        log_spec = torch.log(
            torch.clamp_min(specs, self.config["preprocess"]["min_magnitude"])
            * self.config["preprocess"]["comp_factor"]
        ).to(torch.float32)
        return log_spec

    def calc_melcep(self, wav):
        # WORLD spectral envelope, then mel-cepstral coefficients via SPTK.
        wav = wav.numpy()
        _, sp, _ = pyworld.wav2world(
            wav.astype(np.float64),
            self.config["preprocess"]["sampling_rate"],
            fft_size=self.config["preprocess"]["fft_length"],
            frame_period=(
                self.config["preprocess"]["frame_shift"]
                / self.config["preprocess"]["sampling_rate"]
                * 1000
            ),
        )
        melcep = pysptk.sp2mc(
            sp,
            order=self.config["preprocess"]["cep_order"],
            alpha=pysptk.util.mcepalpha(self.config["preprocess"]["sampling_rate"]),
        ).transpose(1, 0)
        melcep = torch.from_numpy(melcep).to(torch.float32)
        return melcep

    def calc_f0(self, wav):
        if self.config["preprocess"]["f0_extractor"] == "dio":
            return self.calc_f0_dio(wav)
        elif self.config["preprocess"]["f0_extractor"] == "harvest":
            return self.calc_f0_harvest(wav)
        elif self.config["preprocess"]["f0_extractor"] == "swipe":
            return self.calc_f0_swipe(wav)
        else:
            raise NotImplementedError()

    def calc_f0_dio(self, wav):
        wav = wav.numpy()
        _f0, _t = pyworld.dio(
            wav.astype(np.float64),
            self.config["preprocess"]["sampling_rate"],
            frame_period=(
                self.config["preprocess"]["frame_shift"]
                / self.config["preprocess"]["sampling_rate"]
                * 1000
            ),
        )
        f0 = pyworld.stonemask(
            wav.astype(np.float64), _f0, _t, self.config["preprocess"]["sampling_rate"]
        )
        f0 = torch.from_numpy(f0).to(torch.float32)
        return f0

    def calc_f0_harvest(self, wav):
        wav = wav.numpy()
        _f0, _t = pyworld.harvest(
            wav.astype(np.float64),
            self.config["preprocess"]["sampling_rate"],
            frame_period=(
                self.config["preprocess"]["frame_shift"]
                / self.config["preprocess"]["sampling_rate"]
                * 1000
            ),
        )
        f0 = pyworld.stonemask(
            wav.astype(np.float64), _f0, _t, self.config["preprocess"]["sampling_rate"]
        )
        f0 = torch.from_numpy(f0).to(torch.float32)
        return f0

    def calc_f0_swipe(self, wav):
        wav = wav.numpy()
        f0 = pysptk.sptk.swipe(
            wav.astype(np.float64),
            fs=self.config["preprocess"]["sampling_rate"],
            min=71,
            max=800,
            hopsize=self.config["preprocess"]["frame_shift"],
            otype="f0",
        )
        f0 = torch.from_numpy(f0).to(torch.float32)
        return f0

    def augmentation(self, wav):
        # Pseudo-degradation: random mu-law quantization plus a random
        # down/up-resampling round trip.
        wav /= torch.max(torch.abs(wav))
        new_freq = random.choice(self.resample_candidate)
        new_quantization = random.choice(self.quantization_candidate)
        mulaw_encoder = torchaudio.transforms.MuLawEncoding(
            quantization_channels=new_quantization
        )
        wav_quantized = mulaw_encoder(wav) / new_quantization * 2.0 - 1.0
        downsampler = torchaudio.transforms.Resample(
            orig_freq=self.config["preprocess"]["sampling_rate"],
            new_freq=new_freq,
            resampling_method="sinc_interpolation",
            lowpass_filter_width=6,
            dtype=torch.float32,
        )
        upsampler = torchaudio.transforms.Resample(
            orig_freq=new_freq,
            new_freq=self.config["preprocess"]["sampling_rate"],
            resampling_method="sinc_interpolation",
            lowpass_filter_width=6,
            dtype=torch.float32,
        )
        wav_processed = upsampler(downsampler(wav_quantized))
        return wav_processed

    def normalize_waveform(self, wav, db=-3):
        wav, _ = torchaudio.sox_effects.apply_effects_tensor(
            wav.unsqueeze(0),
            self.config["preprocess"]["sampling_rate"],
            [["norm", "{}".format(db)]],
        )
        return wav.squeeze(0)

    def get_segment(self, wav, segment_length):
        # Random crop to a fixed length, zero-padding short utterances.
        seg_size = self.config["preprocess"]["sampling_rate"] * segment_length
        if len(wav) >= seg_size:
            max_wav_start = len(wav) - seg_size
            wav_start = random.randint(0, max_wav_start)
            wav = wav[wav_start : wav_start + seg_size]
        else:
            wav = torch.nn.functional.pad(wav, (0, seg_size - len(wav)), "constant")
        return wav
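Note: Dataset.augmentation synthesizes pseudo-degraded inputs for pretraining by mu-law quantization at a random resolution followed by a random down/up-resampling round trip. A self-contained sketch of that pipeline on a test tone; the sample rate, target rate, and quantization level are illustrative assumptions:

import math
import torch
import torchaudio

sr, new_freq, n_channels = 22050, 16000, 256
wav = torch.sin(2 * math.pi * 440 * torch.arange(sr, dtype=torch.float32) / sr)  # 1 s, 440 Hz

wav = wav / wav.abs().max()
mulaw = torchaudio.transforms.MuLawEncoding(quantization_channels=n_channels)
wav_q = mulaw(wav) / n_channels * 2.0 - 1.0  # quantize, then rescale to roughly [-1, 1]

down = torchaudio.transforms.Resample(orig_freq=sr, new_freq=new_freq)
up = torchaudio.transforms.Resample(orig_freq=new_freq, new_freq=sr)
wav_aug = up(down(wav_q))  # band-limited, quantized "degraded" waveform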
eval.py
ADDED
@@ -0,0 +1,67 @@
import argparse
import os
import pathlib
import yaml

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.loggers import TensorBoardLogger

from dataset import DataModule
from lightning_module import (
    PretrainLightningModule,
    SSLStepLightningModule,
    SSLDualLightningModule,
)


def get_arg():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", required=True, type=pathlib.Path)
    parser.add_argument("--ckpt_path", required=True, type=pathlib.Path)
    parser.add_argument(
        "--stage", required=True, type=str, choices=["pretrain", "ssl-step", "ssl-dual"]
    )
    parser.add_argument("--run_name", required=True, type=str)
    return parser.parse_args()


def eval(args, config, output_path):

    csvlogger = CSVLogger(save_dir=output_path, name="test_log")
    trainer = Trainer(
        gpus=-1,
        deterministic=False,
        auto_select_gpus=True,
        benchmark=True,
        logger=[csvlogger],
        default_root_dir=os.getcwd(),
    )

    if config["general"]["stage"] == "pretrain":
        model = PretrainLightningModule(config).load_from_checkpoint(
            checkpoint_path=args.ckpt_path, config=config
        )
    elif config["general"]["stage"] == "ssl-step":
        model = SSLStepLightningModule(config).load_from_checkpoint(
            checkpoint_path=args.ckpt_path, config=config
        )
    elif config["general"]["stage"] == "ssl-dual":
        model = SSLDualLightningModule(config).load_from_checkpoint(
            checkpoint_path=args.ckpt_path, config=config
        )
    else:
        raise NotImplementedError()

    datamodule = DataModule(config)
    trainer.test(model=model, verbose=True, datamodule=datamodule)


if __name__ == "__main__":
    args = get_arg()
    config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader)
    output_path = str(pathlib.Path(config["general"]["output_path"]) / args.run_name)
    config["general"]["stage"] = str(getattr(args, "stage"))

    eval(args, config, output_path)
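Note: load_from_checkpoint is a classmethod in PyTorch Lightning, so the pattern PretrainLightningModule(config).load_from_checkpoint(...) used above (and in aet.py's AETModule) builds a throwaway instance before loading. An equivalent form that skips the redundant construction; the names refer to eval() above:

# load_from_checkpoint constructs the module itself and forwards extra
# keyword arguments (here, config) to __init__.
model = PretrainLightningModule.load_from_checkpoint(
    checkpoint_path=args.ckpt_path, config=config
)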
lightning_module.py
ADDED
@@ -0,0 +1,875 @@
import torch
import pytorch_lightning as pl
import torchaudio
import os
import pathlib
import tqdm
from model import (
    EncoderModule,
    ChannelFeatureModule,
    ChannelModule,
    MultiScaleSpectralLoss,
    GSTModule,
)
from utils import (
    manual_logging,
    load_vocoder,
    plot_and_save_mels,
    plot_and_save_mels_all,
)


class PretrainLightningModule(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.save_hyperparameters()
        self.config = config
        if config["general"]["use_gst"]:
            self.encoder = EncoderModule(config)
            self.gst = GSTModule(config)
        else:
            self.encoder = EncoderModule(config, use_channel=True)
            self.channelfeats = ChannelFeatureModule(config)

        self.channel = ChannelModule(config)
        self.vocoder = load_vocoder(config)

        self.criteria_a = MultiScaleSpectralLoss(config)
        if "feature_loss" in config["train"]:
            if config["train"]["feature_loss"]["type"] == "mae":
                self.criteria_b = torch.nn.L1Loss()
            else:
                self.criteria_b = torch.nn.MSELoss()
        else:
            self.criteria_b = torch.nn.L1Loss()  # default feature loss (MAE)
        self.alpha = config["train"]["alpha"]

    def forward(self, melspecs, wavsaux):
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(melspecs.transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        wavsdeg = self.channel(wavsaux, chfeats)
        return enc_out, wavsdeg

    def training_step(self, batch, batch_idx):
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        wavsdeg = self.channel(batch["wavsaux"], chfeats)
        loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
        if self.config["general"]["feature_type"] == "melspec":
            loss_encoder = self.criteria_b(enc_out, batch["melspecsaux"])
        elif self.config["general"]["feature_type"] == "vocfeats":
            loss_encoder = self.criteria_b(enc_out, batch["melceps"])
        loss = self.alpha * loss_recons + (1.0 - self.alpha) * loss_encoder
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "train_loss_recons",
            loss_recons,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "train_loss_encoder",
            loss_encoder,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        wavsdeg = self.channel(batch["wavsaux"], chfeats)
        loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
        if self.config["general"]["feature_type"] == "melspec":
            val_aux_feats = batch["melspecsaux"]
            feats_name = "melspec"
            loss_encoder = self.criteria_b(enc_out, val_aux_feats)
        elif self.config["general"]["feature_type"] == "vocfeats":
            val_aux_feats = batch["melceps"]
            feats_name = "melcep"
            loss_encoder = self.criteria_b(enc_out, val_aux_feats)
        loss = self.alpha * loss_recons + (1.0 - self.alpha) * loss_encoder
        logger_img_dict = {
            "val_src_melspec": batch["melspecs"],
            "val_pred_{}".format(feats_name): enc_out,
            "val_aux_{}".format(feats_name): val_aux_feats,
        }
        logger_wav_dict = {
            "val_src_wav": batch["wavs"],
            "val_pred_wav": wavsdeg,
            "val_aux_wav": batch["wavsaux"],
        }
        return {
            "val_loss": loss,
            "val_loss_recons": loss_recons,
            "val_loss_encoder": loss_encoder,
            "logger_dict": [logger_img_dict, logger_wav_dict],
        }

    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([out["val_loss"] for out in outputs]).mean().item()
        val_loss_recons = (
            torch.stack([out["val_loss_recons"] for out in outputs]).mean().item()
        )
        val_loss_encoder = (
            torch.stack([out["val_loss_encoder"] for out in outputs]).mean().item()
        )
        self.log("val_loss", val_loss, on_epoch=True, prog_bar=True, logger=True)
        self.log(
            "val_loss_recons",
            val_loss_recons,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "val_loss_encoder",
            val_loss_encoder,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
        self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")

    def test_step(self, batch, batch_idx):
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        wavsdeg = self.channel(batch["wavsaux"], chfeats)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
            enc_feats_aux = batch["melspecsaux"]
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
            enc_feats_aux = torch.cat(
                (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1
            )
        recons_wav = self.vocoder(enc_feats_aux).squeeze(1)
        remas = self.vocoder(enc_feats).squeeze(1)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats_input = batch["melspecs"]
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats_input = torch.cat(
                (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1
            )
        input_recons = self.vocoder(enc_feats_input).squeeze(1)
        if "wavsaux" in batch:
            gt_wav = batch["wavsaux"]
        else:
            gt_wav = None
        return {
            "reconstructed": recons_wav,
            "remastered": remas,
            "channeled": wavsdeg,
            "groundtruth": gt_wav,
            "input": batch["wavs"],
            "input_recons": input_recons,
        }

    def test_epoch_end(self, outputs):
        wav_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs"
        )
        os.makedirs(wav_dir, exist_ok=True)
        mel_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels"
        )
        os.makedirs(mel_dir, exist_ok=True)
        print("Saving mel spectrogram plots ...")
        for idx, out in enumerate(tqdm.tqdm(outputs)):
            for key in [
                "reconstructed",
                "remastered",
                "channeled",
                "input",
                "input_recons",
                "groundtruth",
            ]:
                if out[key] is not None:
                    torchaudio.save(
                        wav_dir / "{}-{}.wav".format(idx, key),
                        out[key][0, ...].unsqueeze(0).cpu(),
                        sample_rate=self.config["preprocess"]["sampling_rate"],
                        channels_first=True,
                    )
                    plot_and_save_mels(
                        out[key][0, ...].cpu(),
                        mel_dir / "{}-{}.png".format(idx, key),
                        self.config,
                    )
            plot_and_save_mels_all(
                out,
                [
                    "reconstructed",
                    "remastered",
                    "channeled",
                    "input",
                    "input_recons",
                    "groundtruth",
                ],
                mel_dir / "{}-all.png".format(idx),
                self.config,
            )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.config["train"]["learning_rate"]
        )
        lr_scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="min", factor=0.5, min_lr=1e-5, verbose=True
            ),
            "interval": "epoch",
            "frequency": 3,
            "monitor": "val_loss",
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

    def tflogger(self, logger_dict, data_type):
        for lg in self.logger.experiment:
            if type(lg).__name__ == "SummaryWriter":
                tensorboard = lg
        for key in logger_dict.keys():
            manual_logging(
                logger=tensorboard,
                item=logger_dict[key],
                idx=0,
                tag=key,
                global_step=self.global_step,
                data_type=data_type,
                config=self.config,
            )
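Note: PretrainLightningModule optimizes a weighted sum of a waveform reconstruction term (criteria_a, the multi-scale spectral loss) and an encoder feature-matching term (criteria_b). A minimal sketch of the combination with dummy tensors; the shapes and the alpha value are illustrative assumptions:

import torch

criteria_a = torch.nn.L1Loss()  # stand-in for MultiScaleSpectralLoss
criteria_b = torch.nn.L1Loss()
alpha = 0.5

wavsdeg, wavs = torch.randn(2, 8000), torch.randn(2, 8000)  # simulated channel output vs. observed input
enc_out, target_feats = torch.randn(2, 80, 100), torch.randn(2, 80, 100)  # encoder output vs. clean features

loss_recons = criteria_a(wavsdeg, wavs)
loss_encoder = criteria_b(enc_out, target_feats)
loss = alpha * loss_recons + (1.0 - alpha) * loss_encoder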
276 |
+
class SSLBaseModule(pl.LightningModule):
|
277 |
+
def __init__(self, config):
|
278 |
+
super().__init__()
|
279 |
+
self.save_hyperparameters()
|
280 |
+
self.config = config
|
281 |
+
if config["general"]["use_gst"]:
|
282 |
+
self.encoder = EncoderModule(config)
|
283 |
+
self.gst = GSTModule(config)
|
284 |
+
else:
|
285 |
+
self.encoder = EncoderModule(config, use_channel=True)
|
286 |
+
self.channelfeats = ChannelFeatureModule(config)
|
287 |
+
self.channel = ChannelModule(config)
|
288 |
+
|
289 |
+
if config["train"]["load_pretrained"]:
|
290 |
+
pre_model = PretrainLightningModule.load_from_checkpoint(
|
291 |
+
checkpoint_path=config["train"]["pretrained_path"]
|
292 |
+
)
|
293 |
+
self.encoder.load_state_dict(pre_model.encoder.state_dict(), strict=False)
|
294 |
+
self.channel.load_state_dict(pre_model.channel.state_dict(), strict=False)
|
295 |
+
if config["general"]["use_gst"]:
|
296 |
+
self.gst.load_state_dict(pre_model.gst.state_dict(), strict=False)
|
297 |
+
else:
|
298 |
+
self.channelfeats.load_state_dict(
|
299 |
+
pre_model.channelfeats.state_dict(), strict=False
|
300 |
+
)
|
301 |
+
|
302 |
+
self.vocoder = load_vocoder(config)
|
303 |
+
self.criteria = self.get_loss_function(config)
|
304 |
+
|
305 |
+
def training_step(self, batch, batch_idx):
|
306 |
+
raise NotImplementedError()
|
307 |
+
|
308 |
+
def validation_step(self, batch, batch_idx):
|
309 |
+
raise NotImplementedError()
|
310 |
+
|
311 |
+
def validation_epoch_end(self, outputs):
|
312 |
+
raise NotImplementedError()
|
313 |
+
|
314 |
+
def configure_optimizers(self):
|
315 |
+
raise NotImplementedError()
|
316 |
+
|
317 |
+
def get_loss_function(self, config):
|
318 |
+
raise NotImplementedError()
|
319 |
+
|
320 |
+
def forward(self, melspecs, f0s=None):
|
321 |
+
if self.config["general"]["use_gst"]:
|
322 |
+
enc_out = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
|
323 |
+
chfeats = self.gst(melspecs.transpose(1, 2))
|
324 |
+
else:
|
325 |
+
enc_out, enc_hidden = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
|
326 |
+
chfeats = self.channelfeats(enc_hidden)
|
327 |
+
enc_out = enc_out.squeeze(1).transpose(1, 2)
|
328 |
+
if self.config["general"]["feature_type"] == "melspec":
|
329 |
+
enc_feats = enc_out
|
330 |
+
elif self.config["general"]["feature_type"] == "vocfeats":
|
331 |
+
enc_feats = torch.cat((f0s.unsqueeze(1), enc_out), dim=1)
|
332 |
+
remas = self.vocoder(enc_feats).squeeze(1)
|
333 |
+
wavsdeg = self.channel(remas, chfeats)
|
334 |
+
return remas, wavsdeg
|
335 |
+
|
336 |
+
def test_step(self, batch, batch_idx):
|
337 |
+
if self.config["general"]["use_gst"]:
|
338 |
+
enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
|
339 |
+
chfeats = self.gst(batch["melspecs"].transpose(1, 2))
|
340 |
+
else:
|
341 |
+
enc_out, enc_hidden = self.encoder(
|
342 |
+
batch["melspecs"].unsqueeze(1).transpose(2, 3)
|
343 |
+
)
|
344 |
+
chfeats = self.channelfeats(enc_hidden)
|
345 |
+
enc_out = enc_out.squeeze(1).transpose(1, 2)
|
346 |
+
if self.config["general"]["feature_type"] == "melspec":
|
347 |
+
enc_feats = enc_out
|
348 |
+
elif self.config["general"]["feature_type"] == "vocfeats":
|
349 |
+
enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
|
350 |
+
remas = self.vocoder(enc_feats).squeeze(1)
|
351 |
+
wavsdeg = self.channel(remas, chfeats)
|
352 |
+
if self.config["general"]["feature_type"] == "melspec":
|
353 |
+
enc_feats_input = batch["melspecs"]
|
354 |
+
elif self.config["general"]["feature_type"] == "vocfeats":
|
355 |
+
enc_feats_input = torch.cat(
|
356 |
+
(batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1
|
357 |
+
)
|
358 |
+
input_recons = self.vocoder(enc_feats_input).squeeze(1)
|
359 |
+
if "wavsaux" in batch:
|
360 |
+
gt_wav = batch["wavsaux"]
|
361 |
+
if self.config["general"]["feature_type"] == "melspec":
|
362 |
+
enc_feats_aux = batch["melspecsaux"]
|
363 |
+
elif self.config["general"]["feature_type"] == "vocfeats":
|
364 |
+
enc_feats_aux = torch.cat(
|
365 |
+
(batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1
|
366 |
+
)
|
367 |
+
recons_wav = self.vocoder(enc_feats_aux).squeeze(1)
|
368 |
+
else:
|
369 |
+
gt_wav = None
|
370 |
+
recons_wav = None
|
371 |
+
return {
|
372 |
+
"reconstructed": recons_wav,
|
373 |
+
"remastered": remas,
|
374 |
+
"channeled": wavsdeg,
|
375 |
+
"input": batch["wavs"],
|
376 |
+
"input_recons": input_recons,
|
377 |
+
"groundtruth": gt_wav,
|
378 |
+
}
|
379 |
+
|
380 |
+
def test_epoch_end(self, outputs):
|
381 |
+
wav_dir = (
|
382 |
+
pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs"
|
383 |
+
)
|
384 |
+
os.makedirs(wav_dir, exist_ok=True)
|
385 |
+
mel_dir = (
|
386 |
+
pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels"
|
387 |
+
)
|
388 |
+
os.makedirs(mel_dir, exist_ok=True)
|
389 |
+
print("Saving mel spectrogram plots ...")
|
390 |
+
for idx, out in enumerate(tqdm.tqdm(outputs)):
|
391 |
+
plot_keys = []
|
392 |
+
for key in [
|
393 |
+
"reconstructed",
|
394 |
+
"remastered",
|
395 |
+
"channeled",
|
396 |
+
"input",
|
397 |
+
"input_recons",
|
398 |
+
"groundtruth",
|
399 |
+
]:
|
400 |
+
if out[key] != None:
|
401 |
+
plot_keys.append(key)
|
402 |
+
torchaudio.save(
|
403 |
+
wav_dir / "{}-{}.wav".format(idx, key),
|
404 |
+
out[key][0, ...].unsqueeze(0).cpu(),
|
405 |
+
sample_rate=self.config["preprocess"]["sampling_rate"],
|
406 |
+
channels_first=True,
|
407 |
+
)
|
408 |
+
plot_and_save_mels(
|
409 |
+
out[key][0, ...].cpu(),
|
410 |
+
mel_dir / "{}-{}.png".format(idx, key),
|
411 |
+
self.config,
|
412 |
+
)
|
413 |
+
plot_and_save_mels_all(
|
414 |
+
out,
|
415 |
+
plot_keys,
|
416 |
+
mel_dir / "{}-all.png".format(idx),
|
417 |
+
self.config,
|
        )

    def tflogger(self, logger_dict, data_type):
        for lg in self.logger.experiment:
            if type(lg).__name__ == "SummaryWriter":
                tensorboard = lg
        for key in logger_dict.keys():
            manual_logging(
                logger=tensorboard,
                item=logger_dict[key],
                idx=0,
                tag=key,
                global_step=self.global_step,
                data_type=data_type,
                config=self.config,
            )


class SSLStepLightningModule(SSLBaseModule):
    def __init__(self, config):
        super().__init__(config)
        if config["train"]["fix_channel"]:
            for param in self.channel.parameters():
                param.requires_grad = False

    def training_step(self, batch, batch_idx, optimizer_idx):
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
        remas = self.vocoder(enc_feats).squeeze(1)
        wavsdeg = self.channel(remas, chfeats)
        loss = self.criteria(wavsdeg, batch["wavs"])
        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
            feats_name = "melspec"
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
            feats_name = "melcep"
        remas = self.vocoder(enc_feats).squeeze(1)
        wavsdeg = self.channel(remas, chfeats)
        loss = self.criteria(wavsdeg, batch["wavs"])
        logger_img_dict = {
            "val_src_melspec": batch["melspecs"],
            "val_pred_{}".format(feats_name): enc_out,
        }
        for auxfeats in ["melceps", "melspecsaux"]:
            if auxfeats in batch:
                logger_img_dict["val_aux_{}".format(auxfeats)] = batch[auxfeats]
        logger_wav_dict = {
            "val_src_wav": batch["wavs"],
            "val_remastered_wav": remas,
            "val_pred_wav": wavsdeg,
        }
        if "wavsaux" in batch:
            logger_wav_dict["val_aux_wav"] = batch["wavsaux"]
        d_out = {"val_loss": loss, "logger_dict": [logger_img_dict, logger_wav_dict]}
        return d_out

    def validation_epoch_end(self, outputs):
        self.log(
            "val_loss",
            torch.stack([out["val_loss"] for out in outputs]).mean().item(),
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
        self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        on_tpu=False,
        using_native_amp=False,
        using_lbfgs=False,
    ):
        if epoch < self.config["train"]["epoch_channel"]:
            if optimizer_idx == 0:
                optimizer.step(closure=optimizer_closure)
            elif optimizer_idx == 1:
                optimizer_closure()
        else:
            if optimizer_idx == 0:
                optimizer_closure()
            elif optimizer_idx == 1:
                optimizer.step(closure=optimizer_closure)

    def configure_optimizers(self):
        if self.config["train"]["fix_channel"]:
            if self.config["general"]["use_gst"]:
                optimizer_channel = torch.optim.Adam(
                    self.gst.parameters(), lr=self.config["train"]["learning_rate"]
                )
            else:
                optimizer_channel = torch.optim.Adam(
                    self.channelfeats.parameters(),
                    lr=self.config["train"]["learning_rate"],
                )
            optimizer_encoder = torch.optim.Adam(
                self.encoder.parameters(), lr=self.config["train"]["learning_rate"]
            )
        else:
            if self.config["general"]["use_gst"]:
                optimizer_channel = torch.optim.Adam(
                    [
                        {"params": self.channel.parameters()},
                        {"params": self.gst.parameters()},
                    ],
                    lr=self.config["train"]["learning_rate"],
                )
            else:
                optimizer_channel = torch.optim.Adam(
                    [
                        {"params": self.channel.parameters()},
                        {"params": self.channelfeats.parameters()},
                    ],
                    lr=self.config["train"]["learning_rate"],
                )
            optimizer_encoder = torch.optim.Adam(
                self.encoder.parameters(), lr=self.config["train"]["learning_rate"]
            )
        optimizers = [optimizer_channel, optimizer_encoder]
        schedulers = [
            {
                "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizers[0], mode="min", factor=0.5, min_lr=1e-5, verbose=True
                ),
                "interval": "epoch",
                "frequency": 3,
                "monitor": "val_loss",
            },
            {
                "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                    optimizers[1], mode="min", factor=0.5, min_lr=1e-5, verbose=True
                ),
                "interval": "epoch",
                "frequency": 3,
                "monitor": "val_loss",
            },
        ]
        return optimizers, schedulers

    def get_loss_function(self, config):
        return MultiScaleSpectralLoss(config)
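
# Scheduling sketch (illustrative comments, not executed): with
# optimizers = [optimizer_channel, optimizer_encoder] as configured above,
# optimizer_step alternates the two parameter groups by epoch:
#
#     if epoch < config["train"]["epoch_channel"]:
#         optimizer_channel.step()   # encoder closure evaluated, not stepped
#     else:
#         optimizer_encoder.step()   # channel closure evaluated, not stepped

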
class SSLDualLightningModule(SSLBaseModule):
    def __init__(self, config):
        super().__init__(config)
        if config["train"]["fix_channel"]:
            for param in self.channel.parameters():
                param.requires_grad = False
        self.spec_module = torchaudio.transforms.MelSpectrogram(
            sample_rate=config["preprocess"]["sampling_rate"],
            n_fft=config["preprocess"]["fft_length"],
            win_length=config["preprocess"]["frame_length"],
            hop_length=config["preprocess"]["frame_shift"],
            f_min=config["preprocess"]["fmin"],
            f_max=config["preprocess"]["fmax"],
            n_mels=config["preprocess"]["n_mels"],
            power=1,
            center=True,
            norm="slaney",
            mel_scale="slaney",
        )
        self.beta = config["train"]["beta"]
        self.criteria_a, self.criteria_b = self.get_loss_function(config)

    def training_step(self, batch, batch_idx):
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
        remas = self.vocoder(enc_feats).squeeze(1)
        wavsdeg = self.channel(remas, chfeats)
        loss_recons = self.criteria_a(wavsdeg, batch["wavs"])

        with torch.no_grad():
            wavsdegtask = self.channel(batch["wavstask"], chfeats)
        melspecstask = self.calc_spectrogram(wavsdegtask)
        if self.config["general"]["use_gst"]:
            enc_out_task = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
        else:
            enc_out_task, _ = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
        enc_out_task = enc_out_task.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            loss_task = self.criteria_b(enc_out_task, batch["melspecstask"])
        elif self.config["general"]["feature_type"] == "vocfeats":
            loss_task = self.criteria_b(enc_out_task, batch["melcepstask"])
        loss = self.beta * loss_recons + (1 - self.beta) * loss_task

        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        self.log(
            "train_loss_recons",
            loss_recons,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "train_loss_task",
            loss_task,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def validation_step(self, batch, batch_idx):
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
            feats_name = "melspec"
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
            feats_name = "melcep"
        remas = self.vocoder(enc_feats).squeeze(1)
        wavsdeg = self.channel(remas, chfeats)
        loss_recons = self.criteria_a(wavsdeg, batch["wavs"])

        wavsdegtask = self.channel(batch["wavstask"], chfeats)
        melspecstask = self.calc_spectrogram(wavsdegtask)
        if self.config["general"]["use_gst"]:
            enc_out_task = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
        else:
            enc_out_task, _ = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
        enc_out_task = enc_out_task.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_out_task_truth = batch["melspecstask"]
            loss_task = self.criteria_b(enc_out_task, enc_out_task_truth)
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_out_task_truth = batch["melcepstask"]
            loss_task = self.criteria_b(enc_out_task, enc_out_task_truth)
        loss = self.beta * loss_recons + (1 - self.beta) * loss_task

        logger_img_dict = {
            "val_src_melspec": batch["melspecs"],
            "val_pred_{}".format(feats_name): enc_out,
            "val_truth_{}_task".format(feats_name): enc_out_task_truth,
            "val_pred_{}_task".format(feats_name): enc_out_task,
        }
        for auxfeats in ["melceps", "melspecsaux"]:
            if auxfeats in batch:
                logger_img_dict["val_aux_{}".format(auxfeats)] = batch[auxfeats]
        logger_wav_dict = {
            "val_src_wav": batch["wavs"],
            "val_remastered_wav": remas,
            "val_pred_wav": wavsdeg,
            "val_truth_wavtask": batch["wavstask"],
            "val_deg_wavtask": wavsdegtask,
        }
        if "wavsaux" in batch:
            logger_wav_dict["val_aux_wav"] = batch["wavsaux"]

        d_out = {
            "val_loss": loss,
            "val_loss_recons": loss_recons,
            "val_loss_task": loss_task,
            "logger_dict": [logger_img_dict, logger_wav_dict],
        }
        return d_out

    def validation_epoch_end(self, outputs):
        self.log(
            "val_loss",
            torch.stack([out["val_loss"] for out in outputs]).mean().item(),
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "val_loss_recons",
            torch.stack([out["val_loss_recons"] for out in outputs]).mean().item(),
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.log(
            "val_loss_task",
            torch.stack([out["val_loss_task"] for out in outputs]).mean().item(),
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
        self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")

    def test_step(self, batch, batch_idx):
        if self.config["general"]["use_gst"]:
            enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
            chfeats = self.gst(batch["melspecs"].transpose(1, 2))
        else:
            enc_out, enc_hidden = self.encoder(
                batch["melspecs"].unsqueeze(1).transpose(2, 3)
            )
            chfeats = self.channelfeats(enc_hidden)
        enc_out = enc_out.squeeze(1).transpose(1, 2)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats = enc_out
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
        remas = self.vocoder(enc_feats).squeeze(1)
        wavsdeg = self.channel(remas, chfeats)
        if self.config["general"]["feature_type"] == "melspec":
            enc_feats_input = batch["melspecs"]
        elif self.config["general"]["feature_type"] == "vocfeats":
            enc_feats_input = torch.cat(
                (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1
            )
        input_recons = self.vocoder(enc_feats_input).squeeze(1)

        wavsdegtask = self.channel(batch["wavstask"], chfeats)
        if "wavsaux" in batch:
            gt_wav = batch["wavsaux"]
            if self.config["general"]["feature_type"] == "melspec":
                enc_feats_aux = batch["melspecsaux"]
            elif self.config["general"]["feature_type"] == "vocfeats":
                enc_feats_aux = torch.cat(
                    (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1
                )
            recons_wav = self.vocoder(enc_feats_aux).squeeze(1)
        else:
            gt_wav = None
            recons_wav = None
        return {
            "reconstructed": recons_wav,
            "remastered": remas,
            "channeled": wavsdeg,
            "channeled_task": wavsdegtask,
            "input": batch["wavs"],
            "input_recons": input_recons,
            "groundtruth": gt_wav,
        }

    def test_epoch_end(self, outputs):
        wav_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs"
        )
        os.makedirs(wav_dir, exist_ok=True)
        mel_dir = (
            pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels"
        )
        os.makedirs(mel_dir, exist_ok=True)
        print("Saving mel spectrogram plots ...")
        for idx, out in enumerate(tqdm.tqdm(outputs)):
            plot_keys = []
            for key in [
                "reconstructed",
                "remastered",
                "channeled",
                "channeled_task",
                "input",
                "input_recons",
                "groundtruth",
            ]:
                if out[key] is not None:
                    plot_keys.append(key)
                    torchaudio.save(
                        wav_dir / "{}-{}.wav".format(idx, key),
                        out[key][0, ...].unsqueeze(0).cpu(),
                        sample_rate=self.config["preprocess"]["sampling_rate"],
                        channels_first=True,
                    )
                    plot_and_save_mels(
                        out[key][0, ...].cpu(),
                        mel_dir / "{}-{}.png".format(idx, key),
                        self.config,
                    )
            plot_and_save_mels_all(
                out,
                plot_keys,
                mel_dir / "{}-all.png".format(idx),
                self.config,
            )

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.config["train"]["learning_rate"]
        )
        lr_scheduler_config = {
            "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, mode="min", factor=0.5, min_lr=1e-5, verbose=True
            ),
            "interval": "epoch",
            "frequency": 3,
            "monitor": "val_loss",
        }
        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

    def calc_spectrogram(self, wav):
        specs = self.spec_module(wav)
        log_spec = torch.log(
            torch.clamp_min(specs, self.config["preprocess"]["min_magnitude"])
            * self.config["preprocess"]["comp_factor"]
        ).to(torch.float32)
        return log_spec

    def get_loss_function(self, config):
        if config["train"]["feature_loss"]["type"] == "mae":
            feature_loss = torch.nn.L1Loss()
        else:
            feature_loss = torch.nn.MSELoss()
        return MultiScaleSpectralLoss(config), feature_loss
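
# Objective sketch (illustrative): SSLDualLightningModule optimizes a convex
# combination of the two losses above, with beta = config["train"]["beta"]:
#
#     loss = beta * loss_recons + (1 - beta) * loss_task
#
# loss_recons is the multi-scale spectral loss between the re-degraded output
# and the input recording; loss_task is the L1/MSE feature loss on the clean
# task branch.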
model.py
ADDED
@@ -0,0 +1,854 @@
import torch
import torch.nn as nn
import torchaudio
import torch.nn.functional as F
import torch.nn.init as init
import numpy as np


class EncoderModule(nn.Module):
    """
    Analysis module based on 2D conv U-Net
    Inspired by https://github.com/haoheliu/voicefixer

    Args:
        config (dict): config
        use_channel (bool): output channel feature or not
    """
    def __init__(self, config, use_channel=False):
        super().__init__()

        self.channels = 1
        self.use_channel = use_channel
        self.downsample_ratio = 2 ** 4

        self.down_block1 = DownBlockRes2D(
            in_channels=self.channels,
            out_channels=32,
            downsample=(2, 2),
            activation="relu",
            momentum=0.01,
        )
        self.down_block2 = DownBlockRes2D(
            in_channels=32,
            out_channels=64,
            downsample=(2, 2),
            activation="relu",
            momentum=0.01,
        )
        self.down_block3 = DownBlockRes2D(
            in_channels=64,
            out_channels=128,
            downsample=(2, 2),
            activation="relu",
            momentum=0.01,
        )
        self.down_block4 = DownBlockRes2D(
            in_channels=128,
            out_channels=256,
            downsample=(2, 2),
            activation="relu",
            momentum=0.01,
        )
        self.conv_block5 = ConvBlockRes2D(
            in_channels=256,
            out_channels=256,
            size=3,
            activation="relu",
            momentum=0.01,
        )
        self.up_block1 = UpBlockRes2D(
            in_channels=256,
            out_channels=256,
            stride=(2, 2),
            activation="relu",
            momentum=0.01,
        )
        self.up_block2 = UpBlockRes2D(
            in_channels=256,
            out_channels=128,
            stride=(2, 2),
            activation="relu",
            momentum=0.01,
        )
        self.up_block3 = UpBlockRes2D(
            in_channels=128,
            out_channels=64,
            stride=(2, 2),
            activation="relu",
            momentum=0.01,
        )
        self.up_block4 = UpBlockRes2D(
            in_channels=64,
            out_channels=32,
            stride=(2, 2),
            activation="relu",
            momentum=0.01,
        )

        self.after_conv_block1 = ConvBlockRes2D(
            in_channels=32,
            out_channels=32,
            size=3,
            activation="relu",
            momentum=0.01,
        )

        self.after_conv2 = nn.Conv2d(
            in_channels=32,
            out_channels=1,
            kernel_size=(1, 1),
            stride=(1, 1),
            padding=(0, 0),
            bias=True,
        )

        if config["general"]["feature_type"] == "melspec":
            out_dim = config["preprocess"]["n_mels"]
        elif config["general"]["feature_type"] == "vocfeats":
            out_dim = config["preprocess"]["cep_order"] + 1
        else:
            raise NotImplementedError()

        self.after_linear = nn.Linear(
            in_features=80,
            out_features=out_dim,
            bias=True,
        )

        if self.use_channel:
            self.conv_channel = ConvBlockRes2D(
                in_channels=256,
                out_channels=256,
                size=3,
                activation="relu",
                momentum=0.01,
            )

    def forward(self, x):
        """
        Forward

        Args:
            mel spectrogram: (batch, 1, time, freq)

        Return:
            speech feature (mel spectrogram or mel cepstrum): (batch, 1, time, freq)
            input of channel feature module (batch, 256, time, freq)
        """

        origin_len = x.shape[2]
        pad_len = (
            int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
            - origin_len
        )
        x = F.pad(x, pad=(0, 0, 0, pad_len))
        x = x[..., 0 : x.shape[-1] - 1]

        (x1_pool, x1) = self.down_block1(x)
        (x2_pool, x2) = self.down_block2(x1_pool)
        (x3_pool, x3) = self.down_block3(x2_pool)
        (x4_pool, x4) = self.down_block4(x3_pool)
        x_center = self.conv_block5(x4_pool)
        x5 = self.up_block1(x_center, x4)
        x6 = self.up_block2(x5, x3)
        x7 = self.up_block3(x6, x2)
        x8 = self.up_block4(x7, x1)
        x = self.after_conv_block1(x8)
        x = self.after_conv2(x)

        x = F.pad(x, pad=(0, 1))
        x = x[:, :, 0:origin_len, :]

        x = self.after_linear(x)

        if self.use_channel:
            x_channel = self.conv_channel(x4_pool)
            return x, x_channel
        else:
            return x
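
# Shape sketch (illustrative; mirrors how lightning_module.py calls the
# encoder, with a hypothetical config):
#
#     enc = EncoderModule(config)
#     x = mel.unsqueeze(1).transpose(2, 3)       # mel: (batch, n_mels, time)
#     feats = enc(x)                             # (batch, 1, time, out_dim)
#     feats = feats.squeeze(1).transpose(1, 2)   # (batch, out_dim, time)

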
class ChannelModule(nn.Module):
    """
    Channel module based on 1D conv U-Net

    Args:
        config (dict): config
    """
    def __init__(self, config):
        super().__init__()

        self.channels = 1
        self.downsample_ratio = 2 ** 6  # This number equals 2^{#encoder_blocks}

        self.down_block1 = DownBlockRes1D(
            in_channels=self.channels,
            out_channels=32,
            downsample=2,
            activation="relu",
            momentum=0.01,
        )
        self.down_block2 = DownBlockRes1D(
            in_channels=32,
            out_channels=64,
            downsample=2,
            activation="relu",
            momentum=0.01,
        )
        self.down_block3 = DownBlockRes1D(
            in_channels=64,
            out_channels=128,
            downsample=2,
            activation="relu",
            momentum=0.01,
        )
        self.down_block4 = DownBlockRes1D(
            in_channels=128,
            out_channels=256,
            downsample=2,
            activation="relu",
            momentum=0.01,
        )
        self.down_block5 = DownBlockRes1D(
            in_channels=256,
            out_channels=512,
            downsample=2,
            activation="relu",
            momentum=0.01,
        )
        self.conv_block6 = ConvBlockRes1D(
            in_channels=512,
            out_channels=384,
            size=3,
            activation="relu",
            momentum=0.01,
        )
        self.up_block1 = UpBlockRes1D(
            in_channels=512,
            out_channels=512,
            stride=2,
            activation="relu",
            momentum=0.01,
        )
        self.up_block2 = UpBlockRes1D(
            in_channels=512,
            out_channels=256,
            stride=2,
            activation="relu",
            momentum=0.01,
        )
        self.up_block3 = UpBlockRes1D(
            in_channels=256,
            out_channels=128,
            stride=2,
            activation="relu",
            momentum=0.01,
        )
        self.up_block4 = UpBlockRes1D(
            in_channels=128,
            out_channels=64,
            stride=2,
            activation="relu",
            momentum=0.01,
        )
        self.up_block5 = UpBlockRes1D(
            in_channels=64,
            out_channels=32,
            stride=2,
            activation="relu",
            momentum=0.01,
        )

        self.after_conv_block1 = ConvBlockRes1D(
            in_channels=32,
            out_channels=32,
            size=3,
            activation="relu",
            momentum=0.01,
        )

        self.after_conv2 = nn.Conv1d(
            in_channels=32,
            out_channels=1,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

    def forward(self, x, h):
        """
        Forward

        Args:
            clean waveform: (batch, n_channel (1), time)
            channel feature: (batch, feature_dim)
        Outputs:
            degraded waveform: (batch, n_channel (1), time)
        """
        x = x.unsqueeze(1)

        origin_len = x.shape[2]
        pad_len = (
            int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
            - origin_len
        )
        x = F.pad(x, pad=(0, pad_len))
        x = x[..., 0 : x.shape[-1] - 1]

        (x1_pool, x1) = self.down_block1(x)
        (x2_pool, x2) = self.down_block2(x1_pool)
        (x3_pool, x3) = self.down_block3(x2_pool)
        (x4_pool, x4) = self.down_block4(x3_pool)
        (x5_pool, x5) = self.down_block5(x4_pool)
        x_center = self.conv_block6(x5_pool)
        x_concat = torch.cat(
            (x_center, h.unsqueeze(2).expand(-1, -1, x_center.size(2))), dim=1
        )
        x6 = self.up_block1(x_concat, x5)
        x7 = self.up_block2(x6, x4)
        x8 = self.up_block3(x7, x3)
        x9 = self.up_block4(x8, x2)
        x10 = self.up_block5(x9, x1)
        x = self.after_conv_block1(x10)
        x = self.after_conv2(x)

        x = F.pad(x, pad=(0, 1))
        x = x[..., 0:origin_len]

        return x.squeeze(1)
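
# Usage sketch (illustrative, hypothetical config): the channel module maps a
# clean waveform plus a per-utterance channel feature h to a degraded waveform.
#
#     channel = ChannelModule(config)
#     wavs_deg = channel(wavs, h)   # wavs: (batch, time), h: (batch, 128)
#
# h is broadcast over time and concatenated with the 384-channel bottleneck,
# matching the 512 input channels of up_block1.

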
class ChannelFeatureModule(nn.Module):
    """
    Channel feature module based on 2D convolution layers

    Args:
        config (dict): config
    """
    def __init__(self, config):
        super().__init__()
        self.conv_blocks_in = ConvBlockRes2D(
            in_channels=256,
            out_channels=512,
            size=3,
            activation="relu",
            momentum=0.01,
        )
        self.down_block1 = DownBlockRes2D(
            in_channels=512,
            out_channels=256,
            downsample=(2, 2),
            activation="relu",
            momentum=0.01,
        )
        self.down_block2 = DownBlockRes2D(
            in_channels=256,
            out_channels=256,
            downsample=(2, 2),
            activation="relu",
            momentum=0.01,
        )
        self.conv_block_out = ConvBlockRes2D(
            in_channels=256,
            out_channels=128,
            size=3,
            activation="relu",
            momentum=0.01,
        )
        self.avgpool2d = torch.nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """
        Forward

        Args:
            output of analysis module: (batch, 256, time, freq)

        Return:
            channel feature: (batch, feature_dim)
        """
        x = self.conv_blocks_in(x)
        x, _ = self.down_block1(x)
        x, _ = self.down_block2(x)
        x = self.conv_block_out(x)
        x = self.avgpool2d(x)
        x = x.squeeze(3).squeeze(2)
        return x
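
# Sketch (illustrative): together with EncoderModule(use_channel=True), the
# channel feature is read off the U-Net bottleneck:
#
#     feats, hidden = encoder(x)    # hidden: (batch, 256, time', freq')
#     h = channelfeats(hidden)      # h: (batch, 128)

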
class ConvBlockRes2D(nn.Module):
    def __init__(self, in_channels, out_channels, size, activation, momentum):
        super().__init__()

        self.activation = activation
        if isinstance(size, tuple):
            pad = size[0] // 2
            size = size[0]
        else:
            pad = size // 2

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(size, size),
            stride=(1, 1),
            dilation=(1, 1),
            padding=(pad, pad),
            bias=False,
        )

        self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=(size, size),
            stride=(1, 1),
            dilation=(1, 1),
            padding=(pad, pad),
            bias=False,
        )

        self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)

        if in_channels != out_channels:
            self.shortcut = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(1, 1),
                stride=(1, 1),
                padding=(0, 0),
            )
            self.is_shortcut = True
        else:
            self.is_shortcut = False

    def forward(self, x):
        origin = x
        x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))

        if self.is_shortcut:
            return self.shortcut(origin) + x
        else:
            return origin + x


class ConvBlockRes1D(nn.Module):
    def __init__(self, in_channels, out_channels, size, activation, momentum):
        super().__init__()

        self.activation = activation
        pad = size // 2

        self.conv1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=size,
            stride=1,
            dilation=1,
            padding=pad,
            bias=False,
        )

        self.bn1 = nn.BatchNorm1d(in_channels, momentum=momentum)

        self.conv2 = nn.Conv1d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=size,
            stride=1,
            dilation=1,
            padding=pad,
            bias=False,
        )

        self.bn2 = nn.BatchNorm1d(out_channels, momentum=momentum)

        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
            )
            self.is_shortcut = True
        else:
            self.is_shortcut = False

    def forward(self, x):
        origin = x
        x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
        x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))

        if self.is_shortcut:
            return self.shortcut(origin) + x
        else:
            return origin + x


class DownBlockRes2D(nn.Module):
    def __init__(self, in_channels, out_channels, downsample, activation, momentum):
        super().__init__()
        size = 3

        self.conv_block1 = ConvBlockRes2D(
            in_channels, out_channels, size, activation, momentum
        )
        self.conv_block2 = ConvBlockRes2D(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes2D(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes2D(
            out_channels, out_channels, size, activation, momentum
        )
        self.avg_pool2d = torch.nn.AvgPool2d(downsample)

    def forward(self, x):
        encoder = self.conv_block1(x)
        encoder = self.conv_block2(encoder)
        encoder = self.conv_block3(encoder)
        encoder = self.conv_block4(encoder)
        encoder_pool = self.avg_pool2d(encoder)
        return encoder_pool, encoder


class DownBlockRes1D(nn.Module):
    def __init__(self, in_channels, out_channels, downsample, activation, momentum):
        super().__init__()
        size = 3

        self.conv_block1 = ConvBlockRes1D(
            in_channels, out_channels, size, activation, momentum
        )
        self.conv_block2 = ConvBlockRes1D(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes1D(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes1D(
            out_channels, out_channels, size, activation, momentum
        )
        self.avg_pool1d = torch.nn.AvgPool1d(downsample)

    def forward(self, x):
        encoder = self.conv_block1(x)
        encoder = self.conv_block2(encoder)
        encoder = self.conv_block3(encoder)
        encoder = self.conv_block4(encoder)
        encoder_pool = self.avg_pool1d(encoder)
        return encoder_pool, encoder


class UpBlockRes2D(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation, momentum):
        super().__init__()
        size = 3
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(size, size),
            stride=stride,
            padding=(0, 0),
            output_padding=(0, 0),
            bias=False,
            dilation=(1, 1),
        )

        self.bn1 = nn.BatchNorm2d(in_channels)
        self.conv_block2 = ConvBlockRes2D(
            out_channels * 2, out_channels, size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes2D(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes2D(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block5 = ConvBlockRes2D(
            out_channels, out_channels, size, activation, momentum
        )

    def prune(self, x, both=False):
        """Prune the shape of x after transpose convolution."""
        if both:
            x = x[:, :, 0:-1, 0:-1]
        else:
            x = x[:, :, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor, both=False):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        x = self.prune(x, both=both)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        return x


class UpBlockRes1D(nn.Module):
    def __init__(self, in_channels, out_channels, stride, activation, momentum):
        super().__init__()
        size = 3
        self.activation = activation

        self.conv1 = torch.nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=size,
            stride=stride,
            padding=0,
            output_padding=0,
            bias=False,
            dilation=1,
        )

        self.bn1 = nn.BatchNorm1d(in_channels)
        self.conv_block2 = ConvBlockRes1D(
            out_channels * 2, out_channels, size, activation, momentum
        )
        self.conv_block3 = ConvBlockRes1D(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block4 = ConvBlockRes1D(
            out_channels, out_channels, size, activation, momentum
        )
        self.conv_block5 = ConvBlockRes1D(
            out_channels, out_channels, size, activation, momentum
        )

    def prune(self, x):
        """Prune the shape of x after transpose convolution."""
        x = x[:, 0:-1, :]
        return x

    def forward(self, input_tensor, concat_tensor):
        x = self.conv1(F.relu_(self.bn1(input_tensor)))
        # Pruning is not needed here: ChannelModule trims its padded input by
        # one sample, so the transpose-conv output already matches
        # concat_tensor in length.
        # x = self.prune(x)
        x = torch.cat((x, concat_tensor), dim=1)
        x = self.conv_block2(x)
        x = self.conv_block3(x)
        x = self.conv_block4(x)
        x = self.conv_block5(x)
        return x


class MultiScaleSpectralLoss(nn.Module):
    """
    Multi scale spectral loss
    https://openreview.net/forum?id=B1x1ma4tDr

    Args:
        config (dict): config
    """
    def __init__(self, config):
        super().__init__()
        try:
            self.use_linear = config["train"]["multi_scale_loss"]["use_linear"]
            self.gamma = config["train"]["multi_scale_loss"]["gamma"]
        except KeyError:
            self.use_linear = False

        self.fft_sizes = [2048, 512, 256, 128, 64]
        self.spectrograms = []
        for fftsize in self.fft_sizes:
            self.spectrograms.append(
                torchaudio.transforms.Spectrogram(
                    n_fft=fftsize, hop_length=fftsize // 4, power=2
                )
            )
        self.spectrograms = nn.ModuleList(self.spectrograms)
        self.criteria = nn.L1Loss()
        self.eps = 1e-10

    def forward(self, wav_out, wav_target):
        """
        Forward

        Args:
            wav_out: output of channel module (batch, time)
            wav_target: input degraded waveform (batch, time)

        Return:
            loss
        """
        loss = 0.0
        length = min(wav_out.size(1), wav_target.size(1))
        for spectrogram in self.spectrograms:
            S_out = spectrogram(wav_out[..., :length])
            S_target = spectrogram(wav_target[..., :length])
            log_S_out = torch.log(S_out + self.eps)
            log_S_target = torch.log(S_target + self.eps)
            if self.use_linear:
                loss += self.criteria(S_out, S_target) + self.gamma * self.criteria(
                    log_S_out, log_S_target
                )
            else:
                loss += self.criteria(log_S_out, log_S_target)
        return loss
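
# Usage sketch (illustrative, hypothetical config): the loss compares
# log-magnitude spectrograms at FFT sizes 2048/512/256/128/64 and sums the L1
# distances; with use_linear it also adds the linear-magnitude L1 term,
# weighting the log term by gamma.
#
#     criterion = MultiScaleSpectralLoss(config)
#     loss = criterion(wav_pred, wav_target)   # both (batch, time)

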
class ReferenceEncoder(nn.Module):
    def __init__(
        self, idim=80, ref_enc_filters=[32, 32, 64, 64, 128, 128], ref_dim=128
    ):
        super().__init__()
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters

        convs = [
            nn.Conv2d(
                in_channels=filters[i],
                out_channels=filters[i + 1],
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1),
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(
            [nn.BatchNorm2d(num_features=ref_enc_filters[i]) for i in range(K)]
        )

        out_channels = self.calculate_channels(idim, 3, 2, 1, K)

        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=ref_dim,
            batch_first=True,
        )
        self.n_mel_channels = idim

    def forward(self, inputs):
        out = inputs.view(inputs.size(0), 1, -1, self.n_mel_channels)
        for conv, bn in zip(self.convs, self.bns):
            out = conv(out)
            out = bn(out)
            out = F.relu(out)

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        N, T = out.size(0), out.size(1)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        self.gru.flatten_parameters()

        _, out = self.gru(out)

        return out.squeeze(0)

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        for _ in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L


class STL(nn.Module):
    def __init__(self, ref_dim=128, num_heads=4, token_num=10, token_dim=128):
        super().__init__()
        self.embed = nn.Parameter(torch.FloatTensor(token_num, token_dim // num_heads))
        d_q = ref_dim
        d_k = token_dim // num_heads
        self.attention = MultiHeadAttention(
            query_dim=d_q, key_dim=d_k, num_units=token_dim, num_heads=num_heads
        )
        init.normal_(self.embed, mean=0, std=0.5)

    def forward(self, inputs):
        N = inputs.size(0)
        query = inputs.unsqueeze(1)
        keys = (
            torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)
        )  # [N, token_num, token_embedding_size // num_heads]
        style_embed = self.attention(query, keys)
        return style_embed


class MultiHeadAttention(nn.Module):
    """
    Multi head attention
    https://github.com/KinglittleQ/GST-Tacotron

    """
    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(
            in_features=query_dim, out_features=num_units, bias=False
        )
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(
            in_features=key_dim, out_features=num_units, bias=False
        )

    def forward(self, query, key):
        """
        Forward

        Args:
            query: (batch, T_q, query_dim)
            key: (batch, T_k, key_dim)

        Return:
            out: (N, T_q, num_units)
        """
        querys = self.W_query(query)  # [N, T_q, num_units]

        keys = self.W_key(key)  # [N, T_k, num_units]
        values = self.W_value(key)

        split_size = self.num_units // self.num_heads
        querys = torch.stack(
            torch.split(querys, split_size, dim=2), dim=0
        )  # [h, N, T_q, num_units/h]
        keys = torch.stack(
            torch.split(keys, split_size, dim=2), dim=0
        )  # [h, N, T_k, num_units/h]
        values = torch.stack(
            torch.split(values, split_size, dim=2), dim=0
        )  # [h, N, T_k, num_units/h]

        # score = softmax(QK^T / (d_k ** 0.5))
        scores = torch.matmul(querys, keys.transpose(2, 3))  # [h, N, T_q, T_k]
        scores = scores / (self.key_dim ** 0.5)
        scores = F.softmax(scores, dim=3)

        # out = score * V
        out = torch.matmul(scores, values)  # [h, N, T_q, num_units/h]
        out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(
            0
        )  # [N, T_q, num_units]

        return out


class GSTModule(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.encoder_post = ReferenceEncoder(
            idim=config["preprocess"]["n_mels"],
            ref_dim=256,
        )
        self.stl = STL(ref_dim=256, num_heads=8, token_num=10, token_dim=128)

    def forward(self, inputs):
        acoustic_embed = self.encoder_post(inputs)
        style_embed = self.stl(acoustic_embed)
        return style_embed.squeeze(1)
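
if __name__ == "__main__":
    # Minimal shape check (an illustrative sketch, not used by training):
    # GSTModule maps batched mel spectrograms (batch, time, n_mels) to one
    # 128-dim style/channel embedding per utterance.
    _config = {"preprocess": {"n_mels": 80}}  # hypothetical minimal config
    gst = GSTModule(_config)
    dummy_mels = torch.randn(2, 240, 80)
    print(gst(dummy_mels).shape)  # torch.Size([2, 128])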
preprocess.py
ADDED
@@ -0,0 +1,152 @@
import numpy as np
import os
import librosa
import tqdm
import pickle
import random
import argparse
import yaml
import pathlib


def get_arg():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", required=True, type=pathlib.Path)
    parser.add_argument("--corpus_type", default=None, type=str)
    parser.add_argument("--source_path", default=None, type=pathlib.Path)
    parser.add_argument("--source_path_task", default=None, type=pathlib.Path)
    parser.add_argument("--aux_path", default=None, type=pathlib.Path)
    parser.add_argument("--preprocessed_path", default=None, type=pathlib.Path)
    parser.add_argument("--n_train", default=None, type=int)
    parser.add_argument("--n_val", default=None, type=int)
    parser.add_argument("--n_test", default=None, type=int)
    return parser.parse_args()


def preprocess(config):

    # configs
    preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"])
    n_train = config["preprocess"]["n_train"]
    n_val = config["preprocess"]["n_val"]
    n_test = config["preprocess"]["n_test"]
    SR = config["preprocess"]["sampling_rate"]

    os.makedirs(preprocessed_dir, exist_ok=True)

    sourcepath = pathlib.Path(config["general"]["source_path"])

    if config["general"]["corpus_type"] == "single":
        fulllist = list(sourcepath.glob("*.wav"))
        random.seed(0)
        random.shuffle(fulllist)
        train_filelist = fulllist[:n_train]
        val_filelist = fulllist[n_train : n_train + n_val]
        test_filelist = fulllist[n_train + n_val : n_train + n_val + n_test]
        filelist = train_filelist + val_filelist + test_filelist
    elif config["general"]["corpus_type"] == "multi-seen":
        fulllist = list(sourcepath.glob("*/*.wav"))
        random.seed(0)
        random.shuffle(fulllist)
        train_filelist = fulllist[:n_train]
        val_filelist = fulllist[n_train : n_train + n_val]
        test_filelist = fulllist[n_train + n_val : n_train + n_val + n_test]
        filelist = train_filelist + val_filelist + test_filelist
    elif config["general"]["corpus_type"] == "multi-unseen":
        spk_list = list(set([x.parent.name for x in sourcepath.glob("*/*.wav")]))
        train_filelist = []
        val_filelist = []
        test_filelist = []
        random.seed(0)
        random.shuffle(spk_list)
        for i, spk in enumerate(spk_list):
            sourcespkpath = sourcepath / spk
            if i < n_train:
                train_filelist.extend(list(sourcespkpath.glob("*.wav")))
            elif i < n_train + n_val:
                val_filelist.extend(list(sourcespkpath.glob("*.wav")))
            elif i < n_train + n_val + n_test:
                test_filelist.extend(list(sourcespkpath.glob("*.wav")))
        filelist = train_filelist + val_filelist + test_filelist
    else:
        raise NotImplementedError(
            "corpus_type specified in config.yaml should be {single, multi-seen, multi-unseen}"
        )

    with open(preprocessed_dir / "train.txt", "w", encoding="utf-8") as f:
        for m in train_filelist:
            f.write(str(m) + "\n")
    with open(preprocessed_dir / "val.txt", "w", encoding="utf-8") as f:
        for m in val_filelist:
            f.write(str(m) + "\n")
    with open(preprocessed_dir / "test.txt", "w", encoding="utf-8") as f:
        for m in test_filelist:
            f.write(str(m) + "\n")

    for wp in tqdm.tqdm(filelist):

        if config["general"]["corpus_type"] == "single":
            basename = str(wp.stem)
        else:
            basename = str(wp.parent.name) + "-" + str(wp.stem)

        wav, _ = librosa.load(wp, sr=SR)
        wavsegs = []

        if config["general"]["aux_path"] is not None:
            auxpath = pathlib.Path(config["general"]["aux_path"])
            if config["general"]["corpus_type"] == "single":
                wav_aux, _ = librosa.load(auxpath / wp.name, sr=SR)
            else:
                wav_aux, _ = librosa.load(auxpath / wp.parent.name / wp.name, sr=SR)
            wavauxsegs = []

        if config["general"]["aux_path"] is None:
            wavsegs.append(wav)
        else:
            min_seq_len = min(len(wav), len(wav_aux))
            wav = wav[:min_seq_len]
            wav_aux = wav_aux[:min_seq_len]
            wavsegs.append(wav)
            wavauxsegs.append(wav_aux)

        wavsegs = np.asarray(wavsegs).astype(np.float32)
        if config["general"]["aux_path"] is not None:
            wavauxsegs = np.asarray(wavauxsegs).astype(np.float32)
        else:
            wavauxsegs = None

        d_preprocessed = {"wavs": wavsegs, "wavsaux": wavauxsegs}

        with open(preprocessed_dir / "{}.pickle".format(basename), "wb") as fw:
            pickle.dump(d_preprocessed, fw)


if __name__ == "__main__":
    args = get_arg()

    config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader)
    for key in ["corpus_type", "source_path", "aux_path", "preprocessed_path"]:
        if getattr(args, key) is not None:
            config["general"][key] = str(getattr(args, key))
    for key in ["n_train", "n_val", "n_test"]:
        if getattr(args, key) is not None:
            config["preprocess"][key] = getattr(args, key)

    print("Performing preprocessing ...")
    preprocess(config)

    if "dual" in config:
        if config["dual"]["enable"]:
            task_config = yaml.load(
                open(config["dual"]["config_path"], "r"), Loader=yaml.FullLoader
            )
            task_preprocessed_dir = (
                pathlib.Path(config["general"]["preprocessed_path"]).parent
                / pathlib.Path(task_config["general"]["preprocessed_path"]).name
            )
            task_config["general"]["preprocessed_path"] = task_preprocessed_dir
            if args.source_path_task is not None:
                task_config["general"]["source_path"] = args.source_path_task
            print("Performing preprocessing for multi-task learning ...")
            preprocess(task_config)
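
# Expected layout (illustrative): corpus_type "single" reads
# <source_path>/*.wav, while "multi-seen" / "multi-unseen" read
# <source_path>/<speaker>/*.wav. A hypothetical invocation (all paths and
# counts are placeholders):
#
#     python preprocess.py --config_path config.yaml --corpus_type single \
#         --source_path data/src --preprocessed_path data/preprocessed \
#         --n_train 450 --n_val 20 --n_test 30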
train.py
ADDED
@@ -0,0 +1,106 @@
import argparse
import os
import pathlib
import yaml
from dataset import DataModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers.csv_logs import CSVLogger
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from lightning_module import (
    PretrainLightningModule,
    SSLStepLightningModule,
    SSLDualLightningModule,
)
from utils import configure_args


def get_arg():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", required=True, type=pathlib.Path)
    parser.add_argument(
        "--stage", required=True, type=str, choices=["pretrain", "ssl-step", "ssl-dual"]
    )
    parser.add_argument("--run_name", required=True, type=str)
    parser.add_argument("--corpus_type", default=None, type=str)
    parser.add_argument("--source_path", default=None, type=pathlib.Path)
    parser.add_argument("--aux_path", default=None, type=pathlib.Path)
    parser.add_argument("--preprocessed_path", default=None, type=pathlib.Path)
    parser.add_argument("--n_train", default=None, type=int)
    parser.add_argument("--n_val", default=None, type=int)
    parser.add_argument("--n_test", default=None, type=int)
    parser.add_argument("--epoch", default=None, type=int)
    parser.add_argument("--load_pretrained", action="store_true")
    parser.add_argument("--pretrained_path", default=None, type=pathlib.Path)
    parser.add_argument("--early_stopping", action="store_true")
    parser.add_argument("--alpha", default=None, type=float)
    parser.add_argument("--beta", default=None, type=float)
    parser.add_argument("--learning_rate", default=None, type=float)
    parser.add_argument(
        "--feature_loss_type", default=None, type=str, choices=["mae", "mse"]
    )
    parser.add_argument("--debug", action="store_true")
    return parser.parse_args()


def train(args, config, output_path):
    debug = args.debug

    csvlogger = CSVLogger(save_dir=str(output_path), name="train_log")
    tblogger = TensorBoardLogger(save_dir=str(output_path), name="tf_log")

    checkpoint_callback = ModelCheckpoint(
        dirpath=str(output_path),
        save_weights_only=True,
        save_top_k=-1,
        every_n_epochs=1,
        monitor="val_loss",
    )
    callbacks = [checkpoint_callback]
    if config["train"]["early_stopping"]:
        earlystop_callback = EarlyStopping(
            monitor="val_loss", min_delta=0.0, patience=15, mode="min"
        )
        callbacks.append(earlystop_callback)

    trainer = Trainer(
        max_epochs=1 if debug else config["train"]["epoch"],
        gpus=-1,
        deterministic=False,
        auto_select_gpus=True,
        benchmark=True,
        default_root_dir=os.getcwd(),
        limit_train_batches=0.01 if debug else 1.0,
        limit_val_batches=0.5 if debug else 1.0,
        callbacks=callbacks,
        logger=[csvlogger, tblogger],
        gradient_clip_val=config["train"]["grad_clip_thresh"],
        flush_logs_every_n_steps=config["train"]["logger_step"],
        val_check_interval=0.5,
    )

    if config["general"]["stage"] == "pretrain":
        model = PretrainLightningModule(config)
    elif config["general"]["stage"] == "ssl-step":
        model = SSLStepLightningModule(config)
    elif config["general"]["stage"] == "ssl-dual":
        model = SSLDualLightningModule(config)
    else:
        raise NotImplementedError()

    datamodule = DataModule(config)
    trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":

    args = get_arg()
    config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader)

    output_path = pathlib.Path(config["general"]["output_path"]) / args.run_name
    os.makedirs(output_path, exist_ok=True)

    config, args = configure_args(config, args)

    train(args=args, config=config, output_path=output_path)
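
# Hypothetical invocation (config path and run name are placeholders):
#
#     python train.py --config_path config.yaml --stage ssl-dual --run_name demo
#
# --stage selects the LightningModule above; --debug shortens training to a
# single epoch on a fraction of the data.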
utils.py
ADDED
@@ -0,0 +1,147 @@
import librosa.display
import matplotlib.pyplot as plt
import json
import torch
import torchaudio
import hifigan

def manual_logging(logger, item, idx, tag, global_step, data_type, config):
    if data_type == "audio":
        audio = item[idx, ...].detach().cpu().numpy()
        logger.add_audio(
            tag,
            audio,
            global_step,
            sample_rate=config["preprocess"]["sampling_rate"],
        )
    elif data_type == "image":
        image = item[idx, ...].detach().cpu().numpy()
        fig, ax = plt.subplots()
        _ = librosa.display.specshow(
            image,
            x_axis="time",
            y_axis="linear",
            sr=config["preprocess"]["sampling_rate"],
            hop_length=config["preprocess"]["frame_shift"],
            fmax=config["preprocess"]["sampling_rate"] // 2,
            ax=ax,
        )
        logger.add_figure(tag, fig, global_step)
    else:
        raise NotImplementedError(
            "Data type given to logger should be [audio] or [image]"
        )

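# Usage sketch (hypothetical tensors): `item` is a batched tensor and `logger`
# a TensorBoard-style writer exposing add_audio/add_figure, e.g.
#
#   wavs = torch.randn(4, 16000)  # (batch, samples)
#   manual_logging(writer, wavs, idx=0, tag="val/audio",
#                  global_step=1000, data_type="audio", config=config)
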
def load_vocoder(config):
    with open(
        "hifigan/config_{}.json".format(config["general"]["feature_type"]), "r"
    ) as f:
        config_hifigan = hifigan.AttrDict(json.load(f))
    vocoder = hifigan.Generator(config_hifigan)
    vocoder.load_state_dict(torch.load(config["general"]["hifigan_path"])["generator"])
    vocoder.remove_weight_norm()
    for param in vocoder.parameters():
        param.requires_grad = False
    return vocoder

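# Usage sketch (hypothetical shapes): the standard HiFi-GAN generator maps a
# mel batch (batch, n_mels, frames) to waveforms (batch, 1, samples); eval()
# plus torch.no_grad() matches the parameters frozen above:
#
#   vocoder = load_vocoder(config).eval()
#   with torch.no_grad():
#       wav = vocoder(mel).squeeze(1)  # (batch, samples)
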
def get_conv_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)

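# Worked example: for a stride-1 convolution with an odd kernel, this padding
# preserves the input length ("same" padding). kernel_size=7, dilation=2
# gives (7*2 - 2) // 2 = 6, and L + 2*6 - 2*(7 - 1) = L:
#
#   conv = torch.nn.Conv1d(80, 80, kernel_size=7, dilation=2,
#                          padding=get_conv_padding(7, dilation=2))
#   conv(torch.randn(1, 80, 100)).shape  # -> torch.Size([1, 80, 100])
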
def plot_and_save_mels(wav, save_path, config):
    spec_module = torchaudio.transforms.MelSpectrogram(
        sample_rate=config["preprocess"]["sampling_rate"],
        n_fft=config["preprocess"]["fft_length"],
        win_length=config["preprocess"]["frame_length"],
        hop_length=config["preprocess"]["frame_shift"],
        f_min=config["preprocess"]["fmin"],
        f_max=config["preprocess"]["fmax"],
        n_mels=config["preprocess"]["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    spec = spec_module(wav.unsqueeze(0))
    log_spec = torch.log(
        torch.clamp_min(spec, config["preprocess"]["min_magnitude"])
        * config["preprocess"]["comp_factor"]
    )
    fig, ax = plt.subplots()
    _ = librosa.display.specshow(
        log_spec.squeeze(0).numpy(),
        x_axis="time",
        y_axis="linear",
        sr=config["preprocess"]["sampling_rate"],
        hop_length=config["preprocess"]["frame_shift"],
        fmax=config["preprocess"]["sampling_rate"] // 2,
        ax=ax,
        cmap="viridis",
    )
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)  # release the figure so repeated calls do not accumulate

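# Usage sketch (hypothetical tensor): expects a 1-D CPU waveform and writes a
# log-mel spectrogram image to the given path, e.g.
#
#   plot_and_save_mels(torch.randn(48000), "sample_mel.png", config)
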
def plot_and_save_mels_all(wavs, keys, save_path, config):
    spec_module = torchaudio.transforms.MelSpectrogram(
        sample_rate=config["preprocess"]["sampling_rate"],
        n_fft=config["preprocess"]["fft_length"],
        win_length=config["preprocess"]["frame_length"],
        hop_length=config["preprocess"]["frame_shift"],
        f_min=config["preprocess"]["fmin"],
        f_max=config["preprocess"]["fmax"],
        n_mels=config["preprocess"]["n_mels"],
        power=1,
        center=True,
        norm="slaney",
        mel_scale="slaney",
    )
    # Fixed 3x3 grid: supports at most nine entries in `keys` per figure.
    fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(18, 18))
    for i, key in enumerate(keys):
        wav = wavs[key][0, ...].cpu()
        spec = spec_module(wav.unsqueeze(0))
        log_spec = torch.log(
            torch.clamp_min(spec, config["preprocess"]["min_magnitude"])
            * config["preprocess"]["comp_factor"]
        )
        ax[i // 3, i % 3].set(title=key)
        _ = librosa.display.specshow(
            log_spec.squeeze(0).numpy(),
            x_axis="time",
            y_axis="linear",
            sr=config["preprocess"]["sampling_rate"],
            hop_length=config["preprocess"]["frame_shift"],
            fmax=config["preprocess"]["sampling_rate"] // 2,
            ax=ax[i // 3, i % 3],
            cmap="viridis",
        )
    fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
    plt.close(fig)  # release the figure so repeated calls do not accumulate

def configure_args(config, args):
    for key in ["stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"]:
        if getattr(args, key) is not None:
            config["general"][key] = str(getattr(args, key))

    for key in ["n_train", "n_val", "n_test"]:
        if getattr(args, key) is not None:
            config["preprocess"][key] = getattr(args, key)

    for key in ["alpha", "beta", "learning_rate", "epoch"]:
        if getattr(args, key) is not None:
            config["train"][key] = getattr(args, key)

    for key in ["load_pretrained", "early_stopping"]:
        config["train"][key] = getattr(args, key)

    if args.feature_loss_type is not None:
        config["train"]["feature_loss"]["type"] = args.feature_loss_type

    for key in ["pretrained_path"]:
        if getattr(args, key) is not None:
            config["train"][key] = str(getattr(args, key))

    return config, args
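# Precedence sketch: flags that default to None overwrite the YAML value only
# when actually given, while the store_true flags (load_pretrained,
# early_stopping) always overwrite it, e.g. with args from the training
# script's get_arg():
#
#   args = get_arg()                          # run with --learning_rate 1e-4
#   config, args = configure_args(config, args)
#   config["train"]["learning_rate"]          # -> 1e-4, not the YAML value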