saeki committed on
Commit
7b918f7
·
1 Parent(s): e6364e9
Files changed (8)
  1. aet.py +368 -0
  2. dataset.py +344 -0
  3. eval.py +67 -0
  4. lightning_module.py +875 -0
  5. model.py +854 -0
  6. preprocess.py +152 -0
  7. train.py +106 -0
  8. utils.py +147 -0
aet.py ADDED
@@ -0,0 +1,368 @@
+ import argparse
+ import pathlib
+ import yaml
+ import torch
+ import torchaudio
+ from torch.utils.data import DataLoader
+ import numpy as np
+ import random
+ import librosa
+ from dataset import Dataset
+ import pickle
+ from lightning_module import (
+     SSLStepLightningModule,
+     SSLDualLightningModule,
+ )
+ from utils import plot_and_save_mels
+ import os
+ import tqdm
+
+
+ class AETDataset(Dataset):
+     def __init__(self, filetxt, src_config, tar_config):
+         self.config = src_config
+
+         self.preprocessed_dir_src = pathlib.Path(
+             src_config["general"]["preprocessed_path"]
+         )
+         self.preprocessed_dir_tar = pathlib.Path(
+             tar_config["general"]["preprocessed_path"]
+         )
+         for item in [
+             "sampling_rate",
+             "fft_length",
+             "frame_length",
+             "frame_shift",
+             "fmin",
+             "fmax",
+             "n_mels",
+         ]:
+             assert src_config["preprocess"][item] == tar_config["preprocess"][item]
+
+         self.spec_module = torchaudio.transforms.MelSpectrogram(
+             sample_rate=src_config["preprocess"]["sampling_rate"],
+             n_fft=src_config["preprocess"]["fft_length"],
+             win_length=src_config["preprocess"]["frame_length"],
+             hop_length=src_config["preprocess"]["frame_shift"],
+             f_min=src_config["preprocess"]["fmin"],
+             f_max=src_config["preprocess"]["fmax"],
+             n_mels=src_config["preprocess"]["n_mels"],
+             power=1,
+             center=True,
+             norm="slaney",
+             mel_scale="slaney",
+         )
+
+         with open(self.preprocessed_dir_src / filetxt, "r") as fr:
+             self.filelist_src = [pathlib.Path(path.strip("\n")) for path in fr]
+         with open(self.preprocessed_dir_tar / filetxt, "r") as fr:
+             self.filelist_tar = [pathlib.Path(path.strip("\n")) for path in fr]
+
+         self.d_out = {"src": {}, "tar": {}}
+         for item in ["wavs", "wavsaux"]:
+             self.d_out["src"][item] = []
+             self.d_out["tar"][item] = []
+
+         for swp in self.filelist_src:
+             if src_config["general"]["corpus_type"] == "single":
+                 basename = str(swp.stem)
+             else:
+                 basename = str(swp.parent.name) + "-" + str(swp.stem)
+             with open(
+                 self.preprocessed_dir_src / "{}.pickle".format(basename), "rb"
+             ) as fw:
+                 d_preprocessed = pickle.load(fw)
+             for item in ["wavs", "wavsaux"]:
+                 try:
+                     self.d_out["src"][item].extend(d_preprocessed[item])
+                 except KeyError:
+                     pass
+
+         for twp in self.filelist_tar:
+             if tar_config["general"]["corpus_type"] == "single":
+                 basename = str(twp.stem)
+             else:
+                 basename = str(twp.parent.name) + "-" + str(twp.stem)
+             with open(
+                 self.preprocessed_dir_tar / "{}.pickle".format(basename), "rb"
+             ) as fw:
+                 d_preprocessed = pickle.load(fw)
+             for item in ["wavs", "wavsaux"]:
+                 try:
+                     self.d_out["tar"][item].extend(d_preprocessed[item])
+                 except KeyError:
+                     pass
+
+         min_len = min(len(self.d_out["src"]["wavs"]), len(self.d_out["tar"]["wavs"]))
+         for spk in ["src", "tar"]:
+             for item in ["wavs", "wavsaux"]:
+                 if self.d_out[spk][item] is not None:
+                     self.d_out[spk][item] = np.asarray(self.d_out[spk][item][:min_len])
+
+     def __len__(self):
+         return len(self.d_out["src"]["wavs"])
+
+     def __getitem__(self, idx):
+         d_batch = {}
+
+         for spk in ["src", "tar"]:
+             for item in ["wavs", "wavsaux"]:
+                 if self.d_out[spk][item].size > 0:
+                     d_batch["{}_{}".format(item, spk)] = torch.from_numpy(
+                         self.d_out[spk][item][idx]
+                     )
+                     d_batch["{}_{}".format(item, spk)] = self.normalize_waveform(
+                         d_batch["{}_{}".format(item, spk)], db=-3
+                     )
+
+         d_batch["melspecs_src"] = self.calc_spectrogram(d_batch["wavs_src"])
+         return d_batch
+
+
+ class AETModule(torch.nn.Module):
+     """
+     src: Dataset from which we extract the channel features
+     tar: Dataset to which the src channel features are added
+     """
+
+     def __init__(self, args, chmatch_config, src_config, tar_config):
+         super().__init__()
+         if args.stage == "ssl-step":
+             LModule = SSLStepLightningModule
+         elif args.stage == "ssl-dual":
+             LModule = SSLDualLightningModule
+         else:
+             raise NotImplementedError()
+
+         src_model = LModule(src_config).load_from_checkpoint(
+             checkpoint_path=chmatch_config["general"]["source"]["ckpt_path"],
+             config=src_config,
+         )
+         self.src_config = src_config
+
+         self.encoder_src = src_model.encoder
+         if src_config["general"]["use_gst"]:
+             self.gst_src = src_model.gst
+         else:
+             self.channelfeats_src = src_model.channelfeats
+         self.channel_src = src_model.channel
+
+     def forward(self, melspecs_src, wavsaux_tar):
+         if self.src_config["general"]["use_gst"]:
+             chfeats_src = self.gst_src(melspecs_src.transpose(1, 2))
+         else:
+             _, enc_hidden_src = self.encoder_src(
+                 melspecs_src.unsqueeze(1).transpose(2, 3)
+             )
+             chfeats_src = self.channelfeats_src(enc_hidden_src)
+         wavschmatch_tar = self.channel_src(wavsaux_tar, chfeats_src)
+         return wavschmatch_tar
+
+
+ def get_arg():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--stage", required=True, type=str)
+     parser.add_argument("--config_path", required=True, type=pathlib.Path)
+     parser.add_argument("--exist_src_aux", action="store_true")
+     parser.add_argument("--run_name", required=True, type=str)
+     return parser.parse_args()
+
+
+ def main(args, chmatch_config, device):
+     src_config = yaml.load(
+         open(chmatch_config["general"]["source"]["config_path"], "r"),
+         Loader=yaml.FullLoader,
+     )
+     tar_config = yaml.load(
+         open(chmatch_config["general"]["target"]["config_path"], "r"),
+         Loader=yaml.FullLoader,
+     )
+     output_path = pathlib.Path(chmatch_config["general"]["output_path"]) / args.run_name
+     dataset = AETDataset("test.txt", src_config, tar_config)
+     loader = DataLoader(dataset, batch_size=1, shuffle=False)
+     chmatch_module = AETModule(args, chmatch_config, src_config, tar_config).to(device)
+
+     if args.exist_src_aux:
+         char_vector = calc_deg_charactaristics(chmatch_config)
+
+     for idx, batch in enumerate(tqdm.tqdm(loader)):
+         melspecs_src = batch["melspecs_src"].to(device)
+         wavsdeg_src = batch["wavs_src"].to(device)
+         wavsaux_tar = batch["wavsaux_tar"].to(device)
+         if args.exist_src_aux:
+             wavsdegbaseline_tar = calc_deg_baseline(
+                 batch["wavsaux_tar"], char_vector, tar_config
+             )
+             wavsdegbaseline_tar = normalize_waveform(wavsdegbaseline_tar, tar_config)
+             wavsdeg_tar = batch["wavs_tar"].to(device)
+         wavsmatch_tar = normalize_waveform(
+             chmatch_module(melspecs_src, wavsaux_tar).cpu().detach(), tar_config
+         )
+         torchaudio.save(
+             output_path / "test_wavs" / "{}-src_wavsdeg.wav".format(idx),
+             wavsdeg_src.cpu(),
+             src_config["preprocess"]["sampling_rate"],
+         )
+         torchaudio.save(
+             output_path / "test_wavs" / "{}-tar_wavsaux.wav".format(idx),
+             wavsaux_tar.cpu(),
+             tar_config["preprocess"]["sampling_rate"],
+         )
+         if args.exist_src_aux:
+             torchaudio.save(
+                 output_path / "test_wavs" / "{}-tar_wavsdegbaseline.wav".format(idx),
+                 wavsdegbaseline_tar.cpu(),
+                 tar_config["preprocess"]["sampling_rate"],
+             )
+             torchaudio.save(
+                 output_path / "test_wavs" / "{}-tar_wavsdeg.wav".format(idx),
+                 wavsdeg_tar.cpu(),
+                 tar_config["preprocess"]["sampling_rate"],
+             )
+         torchaudio.save(
+             output_path / "test_wavs" / "{}-tar_wavsmatch.wav".format(idx),
+             wavsmatch_tar.cpu(),
+             tar_config["preprocess"]["sampling_rate"],
+         )
+         plot_and_save_mels(
+             wavsdeg_src[0, ...].cpu().detach(),
+             output_path / "test_mels" / "{}-src_melsdeg.png".format(idx),
+             src_config,
+         )
+         plot_and_save_mels(
+             wavsaux_tar[0, ...].cpu().detach(),
+             output_path / "test_mels" / "{}-tar_melsaux.png".format(idx),
+             tar_config,
+         )
+         if args.exist_src_aux:
+             plot_and_save_mels(
+                 wavsdegbaseline_tar[0, ...].cpu().detach(),
+                 output_path / "test_mels" / "{}-tar_melsdegbaseline.png".format(idx),
+                 tar_config,
+             )
+             plot_and_save_mels(
+                 wavsdeg_tar[0, ...].cpu().detach(),
+                 output_path / "test_mels" / "{}-tar_melsdeg.png".format(idx),
+                 tar_config,
+             )
+         plot_and_save_mels(
+             wavsmatch_tar[0, ...].cpu().detach(),
+             output_path / "test_mels" / "{}-tar_melsmatch.png".format(idx),
+             tar_config,
+         )
+
+
+ def calc_deg_baseline(wav, char_vector, tar_config):
+     wav = wav[0, ...].cpu().detach().numpy()
+     spec = librosa.stft(
+         wav,
+         n_fft=tar_config["preprocess"]["fft_length"],
+         hop_length=tar_config["preprocess"]["frame_shift"],
+         win_length=tar_config["preprocess"]["frame_length"],
+     )
+     spec_converted = spec * char_vector.reshape(-1, 1)
+     wav_converted = librosa.istft(
+         spec_converted,
+         hop_length=tar_config["preprocess"]["frame_shift"],
+         win_length=tar_config["preprocess"]["frame_length"],
+     )
+     wav_converted = torch.from_numpy(wav_converted).to(torch.float32).unsqueeze(0)
+     return wav_converted
+
+
+ def calc_deg_charactaristics(chmatch_config):
+     src_config = yaml.load(
+         open(chmatch_config["general"]["source"]["config_path"], "r"),
+         Loader=yaml.FullLoader,
+     )
+     tar_config = yaml.load(
+         open(chmatch_config["general"]["target"]["config_path"], "r"),
+         Loader=yaml.FullLoader,
+     )
+     # configs
+     preprocessed_dir = pathlib.Path(src_config["general"]["preprocessed_path"])
+     n_train = src_config["preprocess"]["n_train"]
+     SR = src_config["preprocess"]["sampling_rate"]
+
+     os.makedirs(preprocessed_dir, exist_ok=True)
+
+     sourcepath = pathlib.Path(src_config["general"]["source_path"])
+
+     if src_config["general"]["corpus_type"] == "single":
+         fulllist = list(sourcepath.glob("*.wav"))
+         random.seed(0)
+         random.shuffle(fulllist)
+         train_filelist = fulllist[:n_train]
+     elif src_config["general"]["corpus_type"] == "multi-seen":
+         fulllist = list(sourcepath.glob("*/*.wav"))
+         random.seed(0)
+         random.shuffle(fulllist)
+         train_filelist = fulllist[:n_train]
+     elif src_config["general"]["corpus_type"] == "multi-unseen":
+         spk_list = list(set([x.parent for x in sourcepath.glob("*/*.wav")]))
+         train_filelist = []
+         random.seed(0)
+         random.shuffle(spk_list)
+         for i, spk in enumerate(spk_list):
+             sourcespkpath = sourcepath / spk
+             if i < n_train:
+                 train_filelist.extend(list(sourcespkpath.glob("*.wav")))
+     else:
+         raise NotImplementedError(
+             "corpus_type specified in config.yaml should be {single, multi-seen, multi-unseen}"
+         )
+
+     specs_all = np.zeros((tar_config["preprocess"]["fft_length"] // 2 + 1, 1))
+
+     for wp in tqdm.tqdm(train_filelist):
+         wav, _ = librosa.load(wp, sr=SR)
+         spec = np.abs(
+             librosa.stft(
+                 wav,
+                 n_fft=src_config["preprocess"]["fft_length"],
+                 hop_length=src_config["preprocess"]["frame_shift"],
+                 win_length=src_config["preprocess"]["frame_length"],
+             )
+         )
+
+         auxpath = pathlib.Path(src_config["general"]["aux_path"])
+         if src_config["general"]["corpus_type"] == "single":
+             wav_aux, _ = librosa.load(auxpath / wp.name, sr=SR)
+         else:
+             wav_aux, _ = librosa.load(auxpath / wp.parent.name / wp.name, sr=SR)
+         spec_aux = np.abs(
+             librosa.stft(
+                 wav_aux,
+                 n_fft=src_config["preprocess"]["fft_length"],
+                 hop_length=src_config["preprocess"]["frame_shift"],
+                 win_length=src_config["preprocess"]["frame_length"],
+             )
+         )
+         min_len = min(spec.shape[1], spec_aux.shape[1])
+         spec_diff = spec[:, :min_len] / (spec_aux[:, :min_len] + 1e-10)
+         specs_all = np.hstack([specs_all, np.mean(spec_diff, axis=1).reshape(-1, 1)])
+
+     char_vector = np.mean(specs_all, axis=1)
+     char_vector = char_vector / (np.sum(char_vector) + 1e-10)
+     return char_vector
+
+
+ def normalize_waveform(wav, tar_config, db=-3):
+     wav, _ = torchaudio.sox_effects.apply_effects_tensor(
+         wav,
+         tar_config["preprocess"]["sampling_rate"],
+         [["norm", "{}".format(db)]],
+     )
+     return wav
+
+
+ if __name__ == "__main__":
+     args = get_arg()
+     chmatch_config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader)
+     output_path = pathlib.Path(chmatch_config["general"]["output_path"]) / args.run_name
+     os.makedirs(output_path, exist_ok=True)
+     os.makedirs(output_path / "test_wavs", exist_ok=True)
+     os.makedirs(output_path / "test_mels", exist_ok=True)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     main(args, chmatch_config, device)
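
Usage sketch (editorial note, not part of the commit): the snippet below shows how AETDataset and AETModule defined above fit together, mirroring what main() does. The YAML paths and the checkpoint referenced by chmatch_config["general"]["source"]["ckpt_path"] are placeholders for whatever the local setup provides.

import yaml
import torch
from torch.utils.data import DataLoader
from aet import AETDataset, AETModule

# hypothetical config locations; the chmatch config points at source/target configs and a checkpoint
chmatch_config = yaml.load(open("configs/chmatch.yaml", "r"), Loader=yaml.FullLoader)
src_config = yaml.load(open(chmatch_config["general"]["source"]["config_path"], "r"), Loader=yaml.FullLoader)
tar_config = yaml.load(open(chmatch_config["general"]["target"]["config_path"], "r"), Loader=yaml.FullLoader)

class Args:
    # stand-in for the argparse namespace returned by get_arg()
    stage = "ssl-dual"

loader = DataLoader(AETDataset("test.txt", src_config, tar_config), batch_size=1, shuffle=False)
module = AETModule(Args(), chmatch_config, src_config, tar_config).eval()
batch = next(iter(loader))
with torch.no_grad():
    # channel features estimated from the source recording are applied to the clean target waveform
    wavsmatch_tar = module(batch["melspecs_src"], batch["wavsaux_tar"])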
dataset.py ADDED
@@ -0,0 +1,344 @@
+ import pickle
+ import pathlib
+ import torch
+ from torch.utils.data.dataloader import DataLoader
+ import pytorch_lightning as pl
+ import numpy as np
+ import yaml
+ import torchaudio
+ import pyworld
+ import pysptk
+ import random
+
+
+ class DataModule(pl.LightningDataModule):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.batchsize = config["train"]["batchsize"]
+         self.preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"])
+
+     def setup(self, stage):
+
+         if not self.preprocessed_dir.exists():
+             raise RuntimeError("Preprocessed directory was not found")
+
+         if "dual" in self.config:
+             if self.config["dual"]["enable"]:
+                 task_config = yaml.load(
+                     open(self.config["dual"]["config_path"], "r"),
+                     Loader=yaml.FullLoader,
+                 )
+                 task_preprocessed_dir = (
+                     self.preprocessed_dir.parent
+                     / pathlib.Path(task_config["general"]["preprocessed_path"]).name
+                 )
+                 if not task_preprocessed_dir.exists():
+                     raise RuntimeError(
+                         "Preprocessed directory for multi-task learning was not found"
+                     )
+
+         self.flnames = {
+             "train": "train.txt",
+             "val": "val.txt",
+             "test": "test.txt",
+         }
+
+     def get_ds(self, phase):
+         ds = Dataset(self.flnames[phase], self.config)
+         return ds
+
+     def get_loader(self, phase):
+         ds = self.get_ds(phase)
+         dl = DataLoader(
+             ds,
+             self.batchsize,
+             shuffle=True if phase == "train" else False,
+             num_workers=self.config["train"]["num_workers"],
+             drop_last=True,
+         )
+         return dl
+
+     def train_dataloader(self):
+         return self.get_loader(phase="train")
+
+     def val_dataloader(self):
+         return self.get_loader(phase="val")
+
+     def test_dataloader(self):
+         return self.get_loader(phase="test")
+
+
+ class Dataset(torch.utils.data.Dataset):
+     def __init__(self, filetxt, config):
+
+         self.preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"])
+         self.config = config
+         self.spec_module = torchaudio.transforms.MelSpectrogram(
+             sample_rate=config["preprocess"]["sampling_rate"],
+             n_fft=config["preprocess"]["fft_length"],
+             win_length=config["preprocess"]["frame_length"],
+             hop_length=config["preprocess"]["frame_shift"],
+             f_min=config["preprocess"]["fmin"],
+             f_max=config["preprocess"]["fmax"],
+             n_mels=config["preprocess"]["n_mels"],
+             power=1,
+             center=True,
+             norm="slaney",
+             mel_scale="slaney",
+         )
+         self.resample_candidate = [8000, 11025, 12000, 16000]
+         self.quantization_candidate = range(2 ** 6, 2 ** 10 + 2, 2)
+         self.segment_length = config["preprocess"]["segment_length"]
+
+         with open(self.preprocessed_dir / filetxt, "r") as fr:
+             self.filelist = [pathlib.Path(path.strip("\n")) for path in fr]
+
+         self.d_out = dict()
+         for item in ["wavs", "wavsaux"]:
+             self.d_out[item] = []
+
+         for wp in self.filelist:
+
+             if config["general"]["corpus_type"] == "single":
+                 basename = str(wp.stem)
+             else:
+                 basename = str(wp.parent.name) + "-" + str(wp.stem)
+
+             with open(self.preprocessed_dir / "{}.pickle".format(basename), "rb") as fw:
+                 d_preprocessed = pickle.load(fw)
+
+             for item in ["wavs", "wavsaux"]:
+                 try:
+                     self.d_out[item].extend(d_preprocessed[item])
+                 except KeyError:
+                     pass
+
+         for item in ["wavs", "wavsaux"]:
+             if self.d_out[item] is not None:
+                 self.d_out[item] = np.asarray(self.d_out[item])
+
+         if "dual" in self.config:
+             if self.config["dual"]["enable"]:
+                 task_config = yaml.load(
+                     open(config["dual"]["config_path"], "r"),
+                     Loader=yaml.FullLoader,
+                 )
+                 task_preprocessed_dir = (
+                     self.preprocessed_dir.parent
+                     / pathlib.Path(task_config["general"]["preprocessed_path"]).name
+                 )
+                 with open(task_preprocessed_dir / filetxt, "r") as fr:
+                     task_filelist = [pathlib.Path(path.strip("\n")) for path in fr]
+                 self.d_out["wavstask"] = []
+                 for wp in task_filelist:
+                     if task_config["general"]["corpus_type"] == "single":
+                         basename = str(wp.stem)
+                     else:
+                         basename = str(wp.parent.name) + "-" + str(wp.stem)
+                     with open(
+                         task_preprocessed_dir / "{}.pickle".format(basename), "rb"
+                     ) as fw:
+                         d_preprocessed = pickle.load(fw)
+                     self.d_out["wavstask"].extend(d_preprocessed["wavs"])
+                 self.d_out["wavstask"] = np.asarray(self.d_out["wavstask"])
+
+     def __len__(self):
+         return len(self.d_out["wavs"])
+
+     def __getitem__(self, idx):
+
+         d_batch = {}
+
+         if self.d_out["wavs"].size > 0:
+             d_batch["wavs"] = torch.from_numpy(self.d_out["wavs"][idx])
+             if self.segment_length > 0:
+                 d_batch["wavs"] = self.get_segment(d_batch["wavs"], self.segment_length)
+
+         if self.d_out["wavsaux"].size > 0:
+             d_batch["wavsaux"] = torch.from_numpy(self.d_out["wavsaux"][idx])
+             if self.segment_length > 0:
+                 d_batch["wavsaux"] = self.get_segment(
+                     d_batch["wavsaux"], self.segment_length
+                 )
+
+         if self.config["general"]["stage"] == "pretrain":
+             if self.config["train"]["augment"]:
+                 d_batch["wavs"] = self.augmentation(d_batch["wavsaux"])
+             d_batch["wavs"] = self.normalize_waveform(d_batch["wavs"], db=-3)
+             d_batch["wavsaux"] = self.normalize_waveform(d_batch["wavsaux"], db=-3)
+             if len(d_batch["wavs"]) != len(d_batch["wavsaux"]):
+                 min_seq_len = min(len(d_batch["wavs"]), len(d_batch["wavsaux"]))
+                 d_batch["wavs"] = d_batch["wavs"][:min_seq_len]
+                 d_batch["wavsaux"] = d_batch["wavsaux"][:min_seq_len]
+             d_batch["melspecs"] = self.calc_spectrogram(d_batch["wavs"])
+             if self.config["general"]["feature_type"] == "melspec":
+                 d_batch["melspecsaux"] = self.calc_spectrogram(d_batch["wavsaux"])
+             elif self.config["general"]["feature_type"] == "vocfeats":
+                 d_batch["melceps"] = self.calc_melcep(d_batch["wavsaux"])
+                 d_batch["f0s"] = self.calc_f0(d_batch["wavs"])
+                 d_batch["melcepssrc"] = self.calc_melcep(d_batch["wavs"])
+             else:
+                 raise NotImplementedError()
+
+         elif self.config["general"]["stage"].startswith("ssl"):
+             d_batch["wavs"] = self.normalize_waveform(d_batch["wavs"], db=-3)
+             d_batch["melspecs"] = self.calc_spectrogram(d_batch["wavs"])
+             if self.config["general"]["feature_type"] == "vocfeats":
+                 d_batch["f0s"] = self.calc_f0(d_batch["wavs"])
+                 d_batch["melcepssrc"] = self.calc_melcep(d_batch["wavs"])
+             if self.d_out["wavsaux"].size > 0:
+                 d_batch["wavsaux"] = self.normalize_waveform(d_batch["wavsaux"], db=-3)
+                 if self.config["general"]["feature_type"] == "melspec":
+                     d_batch["melspecsaux"] = self.calc_spectrogram(d_batch["wavsaux"])
+                 elif self.config["general"]["feature_type"] == "vocfeats":
+                     d_batch["melceps"] = self.calc_melcep(d_batch["wavsaux"])
+             if "dual" in self.config:
+                 if self.config["dual"]["enable"]:
+                     d_batch["wavstask"] = torch.from_numpy(self.d_out["wavstask"][idx])
+                     d_batch["wavstask"] = self.get_segment(
+                         d_batch["wavstask"], self.segment_length
+                     )
+                     d_batch["wavstask"] = self.normalize_waveform(
+                         d_batch["wavstask"], db=-3
+                     )
+                     if self.config["general"]["feature_type"] == "melspec":
+                         d_batch["melspecstask"] = self.calc_spectrogram(
+                             d_batch["wavstask"]
+                         )
+                     elif self.config["general"]["feature_type"] == "vocfeats":
+                         d_batch["melcepstask"] = self.calc_melcep(d_batch["wavstask"])
+                     else:
+                         raise NotImplementedError()
+         else:
+             raise NotImplementedError()
+
+         return d_batch
+
+     def calc_spectrogram(self, wav):
+         specs = self.spec_module(wav)
+         log_spec = torch.log(
+             torch.clamp_min(specs, self.config["preprocess"]["min_magnitude"])
+             * self.config["preprocess"]["comp_factor"]
+         ).to(torch.float32)
+         return log_spec
+
+     def calc_melcep(self, wav):
+         wav = wav.numpy()
+         _, sp, _ = pyworld.wav2world(
+             wav.astype(np.float64),
+             self.config["preprocess"]["sampling_rate"],
+             fft_size=self.config["preprocess"]["fft_length"],
+             frame_period=(
+                 self.config["preprocess"]["frame_shift"]
+                 / self.config["preprocess"]["sampling_rate"]
+                 * 1000
+             ),
+         )
+         melcep = pysptk.sp2mc(
+             sp,
+             order=self.config["preprocess"]["cep_order"],
+             alpha=pysptk.util.mcepalpha(self.config["preprocess"]["sampling_rate"]),
+         ).transpose(1, 0)
+         melcep = torch.from_numpy(melcep).to(torch.float32)
+         return melcep
+
+     def calc_f0(self, wav):
+         if self.config["preprocess"]["f0_extractor"] == "dio":
+             return self.calc_f0_dio(wav)
+         elif self.config["preprocess"]["f0_extractor"] == "harvest":
+             return self.calc_f0_harvest(wav)
+         elif self.config["preprocess"]["f0_extractor"] == "swipe":
+             return self.calc_f0_swipe(wav)
+         else:
+             raise NotImplementedError()
+
+     def calc_f0_dio(self, wav):
+         wav = wav.numpy()
+         _f0, _t = pyworld.dio(
+             wav.astype(np.float64),
+             self.config["preprocess"]["sampling_rate"],
+             frame_period=(
+                 self.config["preprocess"]["frame_shift"]
+                 / self.config["preprocess"]["sampling_rate"]
+                 * 1000
+             ),
+         )
+         f0 = pyworld.stonemask(
+             wav.astype(np.float64), _f0, _t, self.config["preprocess"]["sampling_rate"]
+         )
+         f0 = torch.from_numpy(f0).to(torch.float32)
+         return f0
+
+     def calc_f0_harvest(self, wav):
+         wav = wav.numpy()
+         _f0, _t = pyworld.harvest(
+             wav.astype(np.float64),
+             self.config["preprocess"]["sampling_rate"],
+             frame_period=(
+                 self.config["preprocess"]["frame_shift"]
+                 / self.config["preprocess"]["sampling_rate"]
+                 * 1000
+             ),
+         )
+         f0 = pyworld.stonemask(
+             wav.astype(np.float64), _f0, _t, self.config["preprocess"]["sampling_rate"]
+         )
+         f0 = torch.from_numpy(f0).to(torch.float32)
+         return f0
+
+     def calc_f0_swipe(self, wav):
+         wav = wav.numpy()
+         f0 = pysptk.sptk.swipe(
+             wav.astype(np.float64),
+             fs=self.config["preprocess"]["sampling_rate"],
+             min=71,
+             max=800,
+             hopsize=self.config["preprocess"]["frame_shift"],
+             otype="f0",
+         )
+         f0 = torch.from_numpy(f0).to(torch.float32)
+         return f0
+
+     def augmentation(self, wav):
+         wav /= torch.max(torch.abs(wav))
+         new_freq = random.choice(self.resample_candidate)
+         new_quantization = random.choice(self.quantization_candidate)
+         mulaw_encoder = torchaudio.transforms.MuLawEncoding(
+             quantization_channels=new_quantization
+         )
+         wav_quantized = mulaw_encoder(wav) / new_quantization * 2.0 - 1.0
+         downsampler = torchaudio.transforms.Resample(
+             orig_freq=self.config["preprocess"]["sampling_rate"],
+             new_freq=new_freq,
+             resampling_method="sinc_interpolation",
+             lowpass_filter_width=6,
+             dtype=torch.float32,
+         )
+         upsampler = torchaudio.transforms.Resample(
+             orig_freq=new_freq,
+             new_freq=self.config["preprocess"]["sampling_rate"],
+             resampling_method="sinc_interpolation",
+             lowpass_filter_width=6,
+             dtype=torch.float32,
+         )
+         wav_processed = upsampler(downsampler(wav_quantized))
+         return wav_processed
+
+     def normalize_waveform(self, wav, db=-3):
+         wav, _ = torchaudio.sox_effects.apply_effects_tensor(
+             wav.unsqueeze(0),
+             self.config["preprocess"]["sampling_rate"],
+             [["norm", "{}".format(db)]],
+         )
+         return wav.squeeze(0)
+
+     def get_segment(self, wav, segment_length):
+         seg_size = self.config["preprocess"]["sampling_rate"] * segment_length
+         if len(wav) >= seg_size:
+             max_wav_start = len(wav) - seg_size
+             wav_start = random.randint(0, max_wav_start)
+             wav = wav[wav_start : wav_start + seg_size]
+         else:
+             wav = torch.nn.functional.pad(wav, (0, seg_size - len(wav)), "constant")
+         return wav
eval.py ADDED
@@ -0,0 +1,67 @@
+ import argparse
+ import os
+ import pathlib
+ import yaml
+
+ from pytorch_lightning import Trainer
+ from pytorch_lightning.callbacks import ModelCheckpoint
+ from pytorch_lightning.loggers.csv_logs import CSVLogger
+ from pytorch_lightning.loggers import TensorBoardLogger
+
+ from dataset import DataModule
+ from lightning_module import (
+     PretrainLightningModule,
+     SSLStepLightningModule,
+     SSLDualLightningModule,
+ )
+
+
+ def get_arg():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--config_path", required=True, type=pathlib.Path)
+     parser.add_argument("--ckpt_path", required=True, type=pathlib.Path)
+     parser.add_argument(
+         "--stage", required=True, type=str, choices=["pretrain", "ssl-step", "ssl-dual"]
+     )
+     parser.add_argument("--run_name", required=True, type=str)
+     return parser.parse_args()
+
+
+ def eval(args, config, output_path):
+
+     csvlogger = CSVLogger(save_dir=output_path, name="test_log")
+     trainer = Trainer(
+         gpus=-1,
+         deterministic=False,
+         auto_select_gpus=True,
+         benchmark=True,
+         logger=[csvlogger],
+         default_root_dir=os.getcwd(),
+     )
+
+     if config["general"]["stage"] == "pretrain":
+         model = PretrainLightningModule(config).load_from_checkpoint(
+             checkpoint_path=args.ckpt_path, config=config
+         )
+     elif config["general"]["stage"] == "ssl-step":
+         model = SSLStepLightningModule(config).load_from_checkpoint(
+             checkpoint_path=args.ckpt_path, config=config
+         )
+     elif config["general"]["stage"] == "ssl-dual":
+         model = SSLDualLightningModule(config).load_from_checkpoint(
+             checkpoint_path=args.ckpt_path, config=config
+         )
+     else:
+         raise NotImplementedError()
+
+     datamodule = DataModule(config)
+     trainer.test(model=model, verbose=True, datamodule=datamodule)
+
+
+ if __name__ == "__main__":
+     args = get_arg()
+     config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader)
+     output_path = str(pathlib.Path(config["general"]["output_path"]) / args.run_name)
+     config["general"]["stage"] = str(getattr(args, "stage"))
+
+     eval(args, config, output_path)
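
Note (editorial, not part of the commit): eval.py is meant to be run as a script. Per get_arg() above it expects --config_path, --ckpt_path, --stage (one of pretrain, ssl-step, ssl-dual) and --run_name, for example: python eval.py --config_path config.yaml --ckpt_path model.ckpt --stage ssl-dual --run_name test_run, where the config and checkpoint paths are placeholders for the local setup.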
lightning_module.py ADDED
@@ -0,0 +1,875 @@
+ import torch
+ import pytorch_lightning as pl
+ import torchaudio
+ import os
+ import pathlib
+ import tqdm
+ from model import (
+     EncoderModule,
+     ChannelFeatureModule,
+     ChannelModule,
+     MultiScaleSpectralLoss,
+     GSTModule,
+ )
+ from utils import (
+     manual_logging,
+     load_vocoder,
+     plot_and_save_mels,
+     plot_and_save_mels_all,
+ )
+
+
+ class PretrainLightningModule(pl.LightningModule):
+     def __init__(self, config):
+         super().__init__()
+         self.save_hyperparameters()
+         self.config = config
+         if config["general"]["use_gst"]:
+             self.encoder = EncoderModule(config)
+             self.gst = GSTModule(config)
+         else:
+             self.encoder = EncoderModule(config, use_channel=True)
+             self.channelfeats = ChannelFeatureModule(config)
+
+         self.channel = ChannelModule(config)
+         self.vocoder = load_vocoder(config)
+
+         self.criteria_a = MultiScaleSpectralLoss(config)
+         if "feature_loss" in config["train"]:
+             if config["train"]["feature_loss"]["type"] == "mae":
+                 self.criteria_b = torch.nn.L1Loss()
+             else:
+                 self.criteria_b = torch.nn.MSELoss()
+         else:
+             self.criteria_b = torch.nn.L1Loss()
+         self.alpha = config["train"]["alpha"]
+
+     def forward(self, melspecs, wavsaux):
+         if self.config["general"]["use_gst"]:
+             enc_out = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
+             chfeats = self.gst(melspecs.transpose(1, 2))
+         else:
+             enc_out, enc_hidden = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
+             chfeats = self.channelfeats(enc_hidden)
+         enc_out = enc_out.squeeze(1).transpose(1, 2)
+         wavsdeg = self.channel(wavsaux, chfeats)
+         return enc_out, wavsdeg
+
+     def training_step(self, batch, batch_idx):
+         if self.config["general"]["use_gst"]:
+             enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
+             chfeats = self.gst(batch["melspecs"].transpose(1, 2))
+         else:
+             enc_out, enc_hidden = self.encoder(
+                 batch["melspecs"].unsqueeze(1).transpose(2, 3)
+             )
+             chfeats = self.channelfeats(enc_hidden)
+         enc_out = enc_out.squeeze(1).transpose(1, 2)
+         wavsdeg = self.channel(batch["wavsaux"], chfeats)
+         loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
+         if self.config["general"]["feature_type"] == "melspec":
+             loss_encoder = self.criteria_b(enc_out, batch["melspecsaux"])
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             loss_encoder = self.criteria_b(enc_out, batch["melceps"])
+         loss = self.alpha * loss_recons + (1.0 - self.alpha) * loss_encoder
+         self.log(
+             "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+         )
+         self.log(
+             "train_loss_recons",
+             loss_recons,
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+         self.log(
+             "train_loss_encoder",
+             loss_encoder,
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         if self.config["general"]["use_gst"]:
+             enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
+             chfeats = self.gst(batch["melspecs"].transpose(1, 2))
+         else:
+             enc_out, enc_hidden = self.encoder(
+                 batch["melspecs"].unsqueeze(1).transpose(2, 3)
+             )
+             chfeats = self.channelfeats(enc_hidden)
+         enc_out = enc_out.squeeze(1).transpose(1, 2)
+         wavsdeg = self.channel(batch["wavsaux"], chfeats)
+         loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
+         if self.config["general"]["feature_type"] == "melspec":
+             val_aux_feats = batch["melspecsaux"]
+             feats_name = "melspec"
+             loss_encoder = self.criteria_b(enc_out, val_aux_feats)
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             val_aux_feats = batch["melceps"]
+             feats_name = "melcep"
+             loss_encoder = self.criteria_b(enc_out, val_aux_feats)
+         loss = self.alpha * loss_recons + (1.0 - self.alpha) * loss_encoder
+         logger_img_dict = {
+             "val_src_melspec": batch["melspecs"],
+             "val_pred_{}".format(feats_name): enc_out,
+             "val_aux_{}".format(feats_name): val_aux_feats,
+         }
+         logger_wav_dict = {
+             "val_src_wav": batch["wavs"],
+             "val_pred_wav": wavsdeg,
+             "val_aux_wav": batch["wavsaux"],
+         }
+         return {
+             "val_loss": loss,
+             "val_loss_recons": loss_recons,
+             "val_loss_encoder": loss_encoder,
+             "logger_dict": [logger_img_dict, logger_wav_dict],
+         }
+
+     def validation_epoch_end(self, outputs):
+         val_loss = torch.stack([out["val_loss"] for out in outputs]).mean().item()
+         val_loss_recons = (
+             torch.stack([out["val_loss_recons"] for out in outputs]).mean().item()
+         )
+         val_loss_encoder = (
+             torch.stack([out["val_loss_encoder"] for out in outputs]).mean().item()
+         )
+         self.log("val_loss", val_loss, on_epoch=True, prog_bar=True, logger=True)
+         self.log(
+             "val_loss_recons",
+             val_loss_recons,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+         self.log(
+             "val_loss_encoder",
+             val_loss_encoder,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+         self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
+         self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")
+
+     def test_step(self, batch, batch_idx):
+         if self.config["general"]["use_gst"]:
+             enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
+             chfeats = self.gst(batch["melspecs"].transpose(1, 2))
+         else:
+             enc_out, enc_hidden = self.encoder(
+                 batch["melspecs"].unsqueeze(1).transpose(2, 3)
+             )
+             chfeats = self.channelfeats(enc_hidden)
+         enc_out = enc_out.squeeze(1).transpose(1, 2)
+         wavsdeg = self.channel(batch["wavsaux"], chfeats)
+         if self.config["general"]["feature_type"] == "melspec":
+             enc_feats = enc_out
+             enc_feats_aux = batch["melspecsaux"]
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
+             enc_feats_aux = torch.cat(
+                 (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1
+             )
+         recons_wav = self.vocoder(enc_feats_aux).squeeze(1)
+         remas = self.vocoder(enc_feats).squeeze(1)
+         if self.config["general"]["feature_type"] == "melspec":
+             enc_feats_input = batch["melspecs"]
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             enc_feats_input = torch.cat(
+                 (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1
+             )
+         input_recons = self.vocoder(enc_feats_input).squeeze(1)
+         if "wavsaux" in batch:
+             gt_wav = batch["wavsaux"]
+         else:
+             gt_wav = None
+         return {
+             "reconstructed": recons_wav,
+             "remastered": remas,
+             "channeled": wavsdeg,
+             "groundtruth": gt_wav,
+             "input": batch["wavs"],
+             "input_recons": input_recons,
+         }
+
+     def test_epoch_end(self, outputs):
+         wav_dir = (
+             pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs"
+         )
+         os.makedirs(wav_dir, exist_ok=True)
+         mel_dir = (
+             pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels"
+         )
+         os.makedirs(mel_dir, exist_ok=True)
+         print("Saving mel spectrogram plots ...")
+         for idx, out in enumerate(tqdm.tqdm(outputs)):
+             for key in [
+                 "reconstructed",
+                 "remastered",
+                 "channeled",
+                 "input",
+                 "input_recons",
+                 "groundtruth",
+             ]:
+                 if out[key] is not None:
+                     torchaudio.save(
+                         wav_dir / "{}-{}.wav".format(idx, key),
+                         out[key][0, ...].unsqueeze(0).cpu(),
+                         sample_rate=self.config["preprocess"]["sampling_rate"],
+                         channels_first=True,
+                     )
+                     plot_and_save_mels(
+                         out[key][0, ...].cpu(),
+                         mel_dir / "{}-{}.png".format(idx, key),
+                         self.config,
+                     )
+             plot_and_save_mels_all(
+                 out,
+                 [
+                     "reconstructed",
+                     "remastered",
+                     "channeled",
+                     "input",
+                     "input_recons",
+                     "groundtruth",
+                 ],
+                 mel_dir / "{}-all.png".format(idx),
+                 self.config,
+             )
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(
+             self.parameters(), lr=self.config["train"]["learning_rate"]
+         )
+         lr_scheduler_config = {
+             "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
+                 optimizer, mode="min", factor=0.5, min_lr=1e-5, verbose=True
+             ),
+             "interval": "epoch",
+             "frequency": 3,
+             "monitor": "val_loss",
+         }
+         return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
+
+     def tflogger(self, logger_dict, data_type):
+         for lg in self.logger.experiment:
+             if type(lg).__name__ == "SummaryWriter":
+                 tensorboard = lg
+         for key in logger_dict.keys():
+             manual_logging(
+                 logger=tensorboard,
+                 item=logger_dict[key],
+                 idx=0,
+                 tag=key,
+                 global_step=self.global_step,
+                 data_type=data_type,
+                 config=self.config,
+             )
+
+
+ class SSLBaseModule(pl.LightningModule):
+     def __init__(self, config):
+         super().__init__()
+         self.save_hyperparameters()
+         self.config = config
+         if config["general"]["use_gst"]:
+             self.encoder = EncoderModule(config)
+             self.gst = GSTModule(config)
+         else:
+             self.encoder = EncoderModule(config, use_channel=True)
+             self.channelfeats = ChannelFeatureModule(config)
+         self.channel = ChannelModule(config)
+
+         if config["train"]["load_pretrained"]:
+             pre_model = PretrainLightningModule.load_from_checkpoint(
+                 checkpoint_path=config["train"]["pretrained_path"]
+             )
+             self.encoder.load_state_dict(pre_model.encoder.state_dict(), strict=False)
+             self.channel.load_state_dict(pre_model.channel.state_dict(), strict=False)
+             if config["general"]["use_gst"]:
+                 self.gst.load_state_dict(pre_model.gst.state_dict(), strict=False)
+             else:
+                 self.channelfeats.load_state_dict(
+                     pre_model.channelfeats.state_dict(), strict=False
+                 )
+
+         self.vocoder = load_vocoder(config)
+         self.criteria = self.get_loss_function(config)
+
+     def training_step(self, batch, batch_idx):
+         raise NotImplementedError()
+
+     def validation_step(self, batch, batch_idx):
+         raise NotImplementedError()
+
+     def validation_epoch_end(self, outputs):
+         raise NotImplementedError()
+
+     def configure_optimizers(self):
+         raise NotImplementedError()
+
+     def get_loss_function(self, config):
+         raise NotImplementedError()
+
+     def forward(self, melspecs, f0s=None):
+         if self.config["general"]["use_gst"]:
+             enc_out = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
+             chfeats = self.gst(melspecs.transpose(1, 2))
+         else:
+             enc_out, enc_hidden = self.encoder(melspecs.unsqueeze(1).transpose(2, 3))
+             chfeats = self.channelfeats(enc_hidden)
+         enc_out = enc_out.squeeze(1).transpose(1, 2)
+         if self.config["general"]["feature_type"] == "melspec":
+             enc_feats = enc_out
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             enc_feats = torch.cat((f0s.unsqueeze(1), enc_out), dim=1)
+         remas = self.vocoder(enc_feats).squeeze(1)
+         wavsdeg = self.channel(remas, chfeats)
+         return remas, wavsdeg
+
+     def test_step(self, batch, batch_idx):
+         if self.config["general"]["use_gst"]:
+             enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
+             chfeats = self.gst(batch["melspecs"].transpose(1, 2))
+         else:
+             enc_out, enc_hidden = self.encoder(
+                 batch["melspecs"].unsqueeze(1).transpose(2, 3)
+             )
+             chfeats = self.channelfeats(enc_hidden)
+         enc_out = enc_out.squeeze(1).transpose(1, 2)
+         if self.config["general"]["feature_type"] == "melspec":
+             enc_feats = enc_out
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
+         remas = self.vocoder(enc_feats).squeeze(1)
+         wavsdeg = self.channel(remas, chfeats)
+         if self.config["general"]["feature_type"] == "melspec":
+             enc_feats_input = batch["melspecs"]
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             enc_feats_input = torch.cat(
+                 (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1
+             )
+         input_recons = self.vocoder(enc_feats_input).squeeze(1)
+         if "wavsaux" in batch:
+             gt_wav = batch["wavsaux"]
+             if self.config["general"]["feature_type"] == "melspec":
+                 enc_feats_aux = batch["melspecsaux"]
+             elif self.config["general"]["feature_type"] == "vocfeats":
+                 enc_feats_aux = torch.cat(
+                     (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1
+                 )
+             recons_wav = self.vocoder(enc_feats_aux).squeeze(1)
+         else:
+             gt_wav = None
+             recons_wav = None
+         return {
+             "reconstructed": recons_wav,
+             "remastered": remas,
+             "channeled": wavsdeg,
+             "input": batch["wavs"],
+             "input_recons": input_recons,
+             "groundtruth": gt_wav,
+         }
+
+     def test_epoch_end(self, outputs):
+         wav_dir = (
+             pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs"
+         )
+         os.makedirs(wav_dir, exist_ok=True)
+         mel_dir = (
+             pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels"
+         )
+         os.makedirs(mel_dir, exist_ok=True)
+         print("Saving mel spectrogram plots ...")
+         for idx, out in enumerate(tqdm.tqdm(outputs)):
+             plot_keys = []
+             for key in [
+                 "reconstructed",
+                 "remastered",
+                 "channeled",
+                 "input",
+                 "input_recons",
+                 "groundtruth",
+             ]:
+                 if out[key] is not None:
+                     plot_keys.append(key)
+                     torchaudio.save(
+                         wav_dir / "{}-{}.wav".format(idx, key),
+                         out[key][0, ...].unsqueeze(0).cpu(),
+                         sample_rate=self.config["preprocess"]["sampling_rate"],
+                         channels_first=True,
+                     )
+                     plot_and_save_mels(
+                         out[key][0, ...].cpu(),
+                         mel_dir / "{}-{}.png".format(idx, key),
+                         self.config,
+                     )
+             plot_and_save_mels_all(
+                 out,
+                 plot_keys,
+                 mel_dir / "{}-all.png".format(idx),
+                 self.config,
+             )
+
+     def tflogger(self, logger_dict, data_type):
+         for lg in self.logger.experiment:
+             if type(lg).__name__ == "SummaryWriter":
+                 tensorboard = lg
+         for key in logger_dict.keys():
+             manual_logging(
+                 logger=tensorboard,
+                 item=logger_dict[key],
+                 idx=0,
+                 tag=key,
+                 global_step=self.global_step,
+                 data_type=data_type,
+                 config=self.config,
+             )
+
+
+ class SSLStepLightningModule(SSLBaseModule):
+     def __init__(self, config):
+         super().__init__(config)
+         if config["train"]["fix_channel"]:
+             for param in self.channel.parameters():
+                 param.requires_grad = False
+
+     def training_step(self, batch, batch_idx, optimizer_idx):
+         if self.config["general"]["use_gst"]:
+             enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
+             chfeats = self.gst(batch["melspecs"].transpose(1, 2))
+         else:
+             enc_out, enc_hidden = self.encoder(
+                 batch["melspecs"].unsqueeze(1).transpose(2, 3)
+             )
+             chfeats = self.channelfeats(enc_hidden)
+         enc_out = enc_out.squeeze(1).transpose(1, 2)
+         if self.config["general"]["feature_type"] == "melspec":
+             enc_feats = enc_out
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
+         remas = self.vocoder(enc_feats).squeeze(1)
+         wavsdeg = self.channel(remas, chfeats)
+         loss = self.criteria(wavsdeg, batch["wavs"])
+         self.log(
+             "train_loss",
+             loss,
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         if self.config["general"]["use_gst"]:
+             enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
+             chfeats = self.gst(batch["melspecs"].transpose(1, 2))
+         else:
+             enc_out, enc_hidden = self.encoder(
+                 batch["melspecs"].unsqueeze(1).transpose(2, 3)
+             )
+             chfeats = self.channelfeats(enc_hidden)
+         enc_out = enc_out.squeeze(1).transpose(1, 2)
+         if self.config["general"]["feature_type"] == "melspec":
+             enc_feats = enc_out
+             feats_name = "melspec"
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
+             feats_name = "melcep"
+         remas = self.vocoder(enc_feats).squeeze(1)
+         wavsdeg = self.channel(remas, chfeats)
+         loss = self.criteria(wavsdeg, batch["wavs"])
+         logger_img_dict = {
+             "val_src_melspec": batch["melspecs"],
+             "val_pred_{}".format(feats_name): enc_out,
+         }
+         for auxfeats in ["melceps", "melspecsaux"]:
+             if auxfeats in batch:
+                 logger_img_dict["val_aux_{}".format(auxfeats)] = batch[auxfeats]
+         logger_wav_dict = {
+             "val_src_wav": batch["wavs"],
+             "val_remastered_wav": remas,
+             "val_pred_wav": wavsdeg,
+         }
+         if "wavsaux" in batch:
+             logger_wav_dict["val_aux_wav"] = batch["wavsaux"]
+         d_out = {"val_loss": loss, "logger_dict": [logger_img_dict, logger_wav_dict]}
+         return d_out
+
+     def validation_epoch_end(self, outputs):
+         self.log(
+             "val_loss",
+             torch.stack([out["val_loss"] for out in outputs]).mean().item(),
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+         self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
+         self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")
+
+     def optimizer_step(
+         self,
+         epoch,
+         batch_idx,
+         optimizer,
+         optimizer_idx,
+         optimizer_closure,
+         on_tpu=False,
+         using_native_amp=False,
+         using_lbfgs=False,
+     ):
+         if epoch < self.config["train"]["epoch_channel"]:
+             if optimizer_idx == 0:
+                 optimizer.step(closure=optimizer_closure)
+             elif optimizer_idx == 1:
+                 optimizer_closure()
+         else:
+             if optimizer_idx == 0:
+                 optimizer_closure()
+             elif optimizer_idx == 1:
+                 optimizer.step(closure=optimizer_closure)
+
+     def configure_optimizers(self):
+         if self.config["train"]["fix_channel"]:
+             if self.config["general"]["use_gst"]:
+                 optimizer_channel = torch.optim.Adam(
+                     self.gst.parameters(), lr=self.config["train"]["learning_rate"]
+                 )
+             else:
+                 optimizer_channel = torch.optim.Adam(
+                     self.channelfeats.parameters(),
+                     lr=self.config["train"]["learning_rate"],
+                 )
+             optimizer_encoder = torch.optim.Adam(
+                 self.encoder.parameters(), lr=self.config["train"]["learning_rate"]
+             )
+         else:
+             if self.config["general"]["use_gst"]:
+                 optimizer_channel = torch.optim.Adam(
+                     [
+                         {"params": self.channel.parameters()},
+                         {"params": self.gst.parameters()},
+                     ],
+                     lr=self.config["train"]["learning_rate"],
+                 )
+             else:
+                 optimizer_channel = torch.optim.Adam(
+                     [
+                         {"params": self.channel.parameters()},
+                         {"params": self.channelfeats.parameters()},
+                     ],
+                     lr=self.config["train"]["learning_rate"],
+                 )
+             optimizer_encoder = torch.optim.Adam(
+                 self.encoder.parameters(), lr=self.config["train"]["learning_rate"]
+             )
+         optimizers = [optimizer_channel, optimizer_encoder]
+         schedulers = [
+             {
+                 "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
+                     optimizers[0], mode="min", factor=0.5, min_lr=1e-5, verbose=True
+                 ),
+                 "interval": "epoch",
+                 "frequency": 3,
+                 "monitor": "val_loss",
+             },
+             {
+                 "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
+                     optimizers[1], mode="min", factor=0.5, min_lr=1e-5, verbose=True
+                 ),
+                 "interval": "epoch",
+                 "frequency": 3,
+                 "monitor": "val_loss",
+             },
+         ]
+         return optimizers, schedulers
+
+     def get_loss_function(self, config):
+         return MultiScaleSpectralLoss(config)
+
+
+ class SSLDualLightningModule(SSLBaseModule):
+     def __init__(self, config):
+         super().__init__(config)
+         if config["train"]["fix_channel"]:
+             for param in self.channel.parameters():
+                 param.requires_grad = False
+         self.spec_module = torchaudio.transforms.MelSpectrogram(
+             sample_rate=config["preprocess"]["sampling_rate"],
+             n_fft=config["preprocess"]["fft_length"],
+             win_length=config["preprocess"]["frame_length"],
+             hop_length=config["preprocess"]["frame_shift"],
+             f_min=config["preprocess"]["fmin"],
+             f_max=config["preprocess"]["fmax"],
+             n_mels=config["preprocess"]["n_mels"],
+             power=1,
+             center=True,
+             norm="slaney",
+             mel_scale="slaney",
+         )
+         self.beta = config["train"]["beta"]
+         self.criteria_a, self.criteria_b = self.get_loss_function(config)
+
+     def training_step(self, batch, batch_idx):
+         if self.config["general"]["use_gst"]:
+             enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
+             chfeats = self.gst(batch["melspecs"].transpose(1, 2))
+         else:
+             enc_out, enc_hidden = self.encoder(
+                 batch["melspecs"].unsqueeze(1).transpose(2, 3)
+             )
+             chfeats = self.channelfeats(enc_hidden)
+         enc_out = enc_out.squeeze(1).transpose(1, 2)
+         if self.config["general"]["feature_type"] == "melspec":
+             enc_feats = enc_out
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
+         remas = self.vocoder(enc_feats).squeeze(1)
+         wavsdeg = self.channel(remas, chfeats)
+         loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
+
+         with torch.no_grad():
+             wavsdegtask = self.channel(batch["wavstask"], chfeats)
+         melspecstask = self.calc_spectrogram(wavsdegtask)
+         if self.config["general"]["use_gst"]:
+             enc_out_task = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
+         else:
+             enc_out_task, _ = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
+         enc_out_task = enc_out_task.squeeze(1).transpose(1, 2)
+         if self.config["general"]["feature_type"] == "melspec":
+             loss_task = self.criteria_b(enc_out_task, batch["melspecstask"])
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             loss_task = self.criteria_b(enc_out_task, batch["melcepstask"])
+         loss = self.beta * loss_recons + (1 - self.beta) * loss_task
+
+         self.log(
+             "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
+         )
+         self.log(
+             "train_loss_recons",
+             loss_recons,
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+         self.log(
+             "train_loss_task",
+             loss_task,
+             on_step=True,
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         if self.config["general"]["use_gst"]:
+             enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
+             chfeats = self.gst(batch["melspecs"].transpose(1, 2))
+         else:
+             enc_out, enc_hidden = self.encoder(
+                 batch["melspecs"].unsqueeze(1).transpose(2, 3)
+             )
+             chfeats = self.channelfeats(enc_hidden)
+         enc_out = enc_out.squeeze(1).transpose(1, 2)
+         if self.config["general"]["feature_type"] == "melspec":
+             enc_feats = enc_out
+             feats_name = "melspec"
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
+             feats_name = "melcep"
+         remas = self.vocoder(enc_feats).squeeze(1)
+         wavsdeg = self.channel(remas, chfeats)
+         loss_recons = self.criteria_a(wavsdeg, batch["wavs"])
+
+         wavsdegtask = self.channel(batch["wavstask"], chfeats)
+         melspecstask = self.calc_spectrogram(wavsdegtask)
+         if self.config["general"]["use_gst"]:
+             enc_out_task = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
+         else:
+             enc_out_task, _ = self.encoder(melspecstask.unsqueeze(1).transpose(2, 3))
+         enc_out_task = enc_out_task.squeeze(1).transpose(1, 2)
+         if self.config["general"]["feature_type"] == "melspec":
+             enc_out_task_truth = batch["melspecstask"]
+             loss_task = self.criteria_b(enc_out_task, enc_out_task_truth)
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             enc_out_task_truth = batch["melcepstask"]
+             loss_task = self.criteria_b(enc_out_task, enc_out_task_truth)
+         loss = self.beta * loss_recons + (1 - self.beta) * loss_task
+
+         logger_img_dict = {
+             "val_src_melspec": batch["melspecs"],
+             "val_pred_{}".format(feats_name): enc_out,
+             "val_truth_{}_task".format(feats_name): enc_out_task_truth,
+             "val_pred_{}_task".format(feats_name): enc_out_task,
+         }
+         for auxfeats in ["melceps", "melspecsaux"]:
+             if auxfeats in batch:
+                 logger_img_dict["val_aux_{}".format(auxfeats)] = batch[auxfeats]
+         logger_wav_dict = {
+             "val_src_wav": batch["wavs"],
+             "val_remastered_wav": remas,
+             "val_pred_wav": wavsdeg,
+             "val_truth_wavtask": batch["wavstask"],
+             "val_deg_wavtask": wavsdegtask,
+         }
+         if "wavsaux" in batch:
+             logger_wav_dict["val_aux_wav"] = batch["wavsaux"]
+
+         d_out = {
+             "val_loss": loss,
+             "val_loss_recons": loss_recons,
+             "val_loss_task": loss_task,
+             "logger_dict": [logger_img_dict, logger_wav_dict],
+         }
+         return d_out
+
+     def validation_epoch_end(self, outputs):
+         self.log(
+             "val_loss",
+             torch.stack([out["val_loss"] for out in outputs]).mean().item(),
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+         self.log(
+             "val_loss_recons",
+             torch.stack([out["val_loss_recons"] for out in outputs]).mean().item(),
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+         self.log(
+             "val_loss_task",
+             torch.stack([out["val_loss_task"] for out in outputs]).mean().item(),
+             on_epoch=True,
+             prog_bar=True,
+             logger=True,
+         )
+         self.tflogger(logger_dict=outputs[-1]["logger_dict"][0], data_type="image")
+         self.tflogger(logger_dict=outputs[-1]["logger_dict"][1], data_type="audio")
+
+     def test_step(self, batch, batch_idx):
+         if self.config["general"]["use_gst"]:
+             enc_out = self.encoder(batch["melspecs"].unsqueeze(1).transpose(2, 3))
+             chfeats = self.gst(batch["melspecs"].transpose(1, 2))
+         else:
+             enc_out, enc_hidden = self.encoder(
+                 batch["melspecs"].unsqueeze(1).transpose(2, 3)
+             )
+             chfeats = self.channelfeats(enc_hidden)
+         enc_out = enc_out.squeeze(1).transpose(1, 2)
+         if self.config["general"]["feature_type"] == "melspec":
+             enc_feats = enc_out
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             enc_feats = torch.cat((batch["f0s"].unsqueeze(1), enc_out), dim=1)
+         remas = self.vocoder(enc_feats).squeeze(1)
+         wavsdeg = self.channel(remas, chfeats)
+         if self.config["general"]["feature_type"] == "melspec":
+             enc_feats_input = batch["melspecs"]
+         elif self.config["general"]["feature_type"] == "vocfeats":
+             enc_feats_input = torch.cat(
+                 (batch["f0s"].unsqueeze(1), batch["melcepssrc"]), dim=1
+             )
+         input_recons = self.vocoder(enc_feats_input).squeeze(1)
+
+         wavsdegtask = self.channel(batch["wavstask"], chfeats)
+         if "wavsaux" in batch:
+             gt_wav = batch["wavsaux"]
+             if self.config["general"]["feature_type"] == "melspec":
+                 enc_feats_aux = batch["melspecsaux"]
+             elif self.config["general"]["feature_type"] == "vocfeats":
+                 enc_feats_aux = torch.cat(
+                     (batch["f0s"].unsqueeze(1), batch["melceps"]), dim=1
+                 )
+             recons_wav = self.vocoder(enc_feats_aux).squeeze(1)
+         else:
+             gt_wav = None
+             recons_wav = None
+         return {
+             "reconstructed": recons_wav,
+             "remastered": remas,
+             "channeled": wavsdeg,
+             "channeled_task": wavsdegtask,
+             "input": batch["wavs"],
+             "input_recons": input_recons,
+             "groundtruth": gt_wav,
+         }
+
+     def test_epoch_end(self, outputs):
+         wav_dir = (
+             pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_wavs"
+         )
+         os.makedirs(wav_dir, exist_ok=True)
+         mel_dir = (
+             pathlib.Path(self.logger.experiment[0].log_dir).parent.parent / "test_mels"
+         )
+         os.makedirs(mel_dir, exist_ok=True)
+         print("Saving mel spectrogram plots ...")
+         for idx, out in enumerate(tqdm.tqdm(outputs)):
+             plot_keys = []
+             for key in [
+                 "reconstructed",
+                 "remastered",
+                 "channeled",
+                 "channeled_task",
+                 "input",
+                 "input_recons",
+                 "groundtruth",
+             ]:
+                 if out[key] is not None:
+                     plot_keys.append(key)
+                     torchaudio.save(
+                         wav_dir / "{}-{}.wav".format(idx, key),
+                         out[key][0, ...].unsqueeze(0).cpu(),
+                         sample_rate=self.config["preprocess"]["sampling_rate"],
+                         channels_first=True,
+                     )
+                     plot_and_save_mels(
+                         out[key][0, ...].cpu(),
+                         mel_dir / "{}-{}.png".format(idx, key),
+                         self.config,
+                     )
+             plot_and_save_mels_all(
+                 out,
+                 plot_keys,
+                 mel_dir / "{}-all.png".format(idx),
+                 self.config,
846
+ )
847
+
848
+ def configure_optimizers(self):
849
+ optimizer = torch.optim.Adam(
850
+ self.parameters(), lr=self.config["train"]["learning_rate"]
851
+ )
852
+ lr_scheduler_config = {
853
+ "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(
854
+ optimizer, mode="min", factor=0.5, min_lr=1e-5, verbose=True
855
+ ),
856
+ "interval": "epoch",
857
+ "frequency": 3,
858
+ "monitor": "val_loss",
859
+ }
860
+ return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}
861
+
862
+ def calc_spectrogram(self, wav):
863
+ specs = self.spec_module(wav)
864
+ log_spec = torch.log(
865
+ torch.clamp_min(specs, self.config["preprocess"]["min_magnitude"])
866
+ * self.config["preprocess"]["comp_factor"]
867
+ ).to(torch.float32)
868
+ return log_spec
869
+
870
+ def get_loss_function(self, config):
871
+ if config["train"]["feature_loss"]["type"] == "mae":
872
+ feature_loss = torch.nn.L1Loss()
873
+ else:
874
+ feature_loss = torch.nn.MSELoss()
875
+ return MultiScaleSpectralLoss(config), feature_loss
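A quick numeric sketch of the beta-weighted objective used in the training and validation steps above (the values of beta and the two losses are assumed here; in the code they come from config["train"]["beta"], the multi-scale spectral criterion, and the L1/MSE feature criterion returned by get_loss_function):

import torch

beta = 0.7                        # assumed; set via config["train"]["beta"]
loss_recons = torch.tensor(0.42)  # multi-scale spectral loss on the re-degraded waveform
loss_task = torch.tensor(0.10)    # feature loss on the predicted mel features
loss = beta * loss_recons + (1 - beta) * loss_task
print(loss)                       # tensor(0.3240)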
model.py ADDED
@@ -0,0 +1,854 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torchaudio
4
+ import torch.nn.functional as F
5
+ import torch.nn.init as init
6
+ import numpy as np
7
+
8
+
9
+ class EncoderModule(nn.Module):
10
+ """
11
+ Analysis module based on 2D conv U-Net
12
+ Inspired by https://github.com/haoheliu/voicefixer
13
+
14
+ Args:
15
+ config (dict): config
16
+ use_channel (bool): output channel feature or not
17
+ """
18
+ def __init__(self, config, use_channel=False):
19
+ super().__init__()
20
+
21
+ self.channels = 1
22
+ self.use_channel = use_channel
23
+ self.downsample_ratio = 2 ** 4
24
+
25
+ self.down_block1 = DownBlockRes2D(
26
+ in_channels=self.channels,
27
+ out_channels=32,
28
+ downsample=(2, 2),
29
+ activation="relu",
30
+ momentum=0.01,
31
+ )
32
+ self.down_block2 = DownBlockRes2D(
33
+ in_channels=32,
34
+ out_channels=64,
35
+ downsample=(2, 2),
36
+ activation="relu",
37
+ momentum=0.01,
38
+ )
39
+ self.down_block3 = DownBlockRes2D(
40
+ in_channels=64,
41
+ out_channels=128,
42
+ downsample=(2, 2),
43
+ activation="relu",
44
+ momentum=0.01,
45
+ )
46
+ self.down_block4 = DownBlockRes2D(
47
+ in_channels=128,
48
+ out_channels=256,
49
+ downsample=(2, 2),
50
+ activation="relu",
51
+ momentum=0.01,
52
+ )
53
+ self.conv_block5 = ConvBlockRes2D(
54
+ in_channels=256,
55
+ out_channels=256,
56
+ size=3,
57
+ activation="relu",
58
+ momentum=0.01,
59
+ )
60
+ self.up_block1 = UpBlockRes2D(
61
+ in_channels=256,
62
+ out_channels=256,
63
+ stride=(2, 2),
64
+ activation="relu",
65
+ momentum=0.01,
66
+ )
67
+ self.up_block2 = UpBlockRes2D(
68
+ in_channels=256,
69
+ out_channels=128,
70
+ stride=(2, 2),
71
+ activation="relu",
72
+ momentum=0.01,
73
+ )
74
+ self.up_block3 = UpBlockRes2D(
75
+ in_channels=128,
76
+ out_channels=64,
77
+ stride=(2, 2),
78
+ activation="relu",
79
+ momentum=0.01,
80
+ )
81
+ self.up_block4 = UpBlockRes2D(
82
+ in_channels=64,
83
+ out_channels=32,
84
+ stride=(2, 2),
85
+ activation="relu",
86
+ momentum=0.01,
87
+ )
88
+
89
+ self.after_conv_block1 = ConvBlockRes2D(
90
+ in_channels=32,
91
+ out_channels=32,
92
+ size=3,
93
+ activation="relu",
94
+ momentum=0.01,
95
+ )
96
+
97
+ self.after_conv2 = nn.Conv2d(
98
+ in_channels=32,
99
+ out_channels=1,
100
+ kernel_size=(1, 1),
101
+ stride=(1, 1),
102
+ padding=(0, 0),
103
+ bias=True,
104
+ )
105
+
106
+ if config["general"]["feature_type"] == "melspec":
107
+ out_dim = config["preprocess"]["n_mels"]
108
+ elif config["general"]["feature_type"] == "vocfeats":
109
+ out_dim = config["preprocess"]["cep_order"] + 1
110
+ else:
111
+ raise NotImplementedError()
112
+
113
+ self.after_linear = nn.Linear(
114
+ in_features=80,
115
+ out_features=out_dim,
116
+ bias=True,
117
+ )
118
+
119
+ if self.use_channel:
120
+ self.conv_channel = ConvBlockRes2D(
121
+ in_channels=256,
122
+ out_channels=256,
123
+ size=3,
124
+ activation="relu",
125
+ momentum=0.01,
126
+ )
127
+
128
+ def forward(self, x):
129
+ """
130
+ Forward
131
+
132
+ Args:
133
+ mel spectrogram: (batch, 1, time, freq)
134
+
135
+ Return:
136
+ speech feature (mel spectrogram or mel cepstrum): (batch, 1, time, freq)
137
+ input of channel feature module (batch, 256, time, freq)
138
+ """
139
+
140
+ origin_len = x.shape[2]
141
+ pad_len = (
142
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
143
+ - origin_len
144
+ )
145
+ x = F.pad(x, pad=(0, 0, 0, pad_len))
146
+ x = x[..., 0 : x.shape[-1] - 1]
147
+
148
+ (x1_pool, x1) = self.down_block1(x)
149
+ (x2_pool, x2) = self.down_block2(x1_pool)
150
+ (x3_pool, x3) = self.down_block3(x2_pool)
151
+ (x4_pool, x4) = self.down_block4(x3_pool)
152
+ x_center = self.conv_block5(x4_pool)
153
+ x5 = self.up_block1(x_center, x4)
154
+ x6 = self.up_block2(x5, x3)
155
+ x7 = self.up_block3(x6, x2)
156
+ x8 = self.up_block4(x7, x1)
157
+ x = self.after_conv_block1(x8)
158
+ x = self.after_conv2(x)
159
+
160
+ x = F.pad(x, pad=(0, 1))
161
+ x = x[:, :, 0:origin_len, :]
162
+
163
+ x = self.after_linear(x)
164
+
165
+ if self.use_channel:
166
+ x_channel = self.conv_channel(x4_pool)
167
+ return x, x_channel
168
+ else:
169
+ return x
170
+
171
+
172
+ class ChannelModule(nn.Module):
173
+ """
174
+ Channel module based on 1D conv U-Net
175
+
176
+ Args:
177
+ config (dict): config
178
+ """
179
+ def __init__(self, config):
180
+ super().__init__()
181
+
182
+ self.channels = 1
183
+ self.downsample_ratio = 2 ** 6 # This number equals 2^{#encoder_blocks}
184
+
185
+ self.down_block1 = DownBlockRes1D(
186
+ in_channels=self.channels,
187
+ out_channels=32,
188
+ downsample=2,
189
+ activation="relu",
190
+ momentum=0.01,
191
+ )
192
+ self.down_block2 = DownBlockRes1D(
193
+ in_channels=32,
194
+ out_channels=64,
195
+ downsample=2,
196
+ activation="relu",
197
+ momentum=0.01,
198
+ )
199
+ self.down_block3 = DownBlockRes1D(
200
+ in_channels=64,
201
+ out_channels=128,
202
+ downsample=2,
203
+ activation="relu",
204
+ momentum=0.01,
205
+ )
206
+ self.down_block4 = DownBlockRes1D(
207
+ in_channels=128,
208
+ out_channels=256,
209
+ downsample=2,
210
+ activation="relu",
211
+ momentum=0.01,
212
+ )
213
+ self.down_block5 = DownBlockRes1D(
214
+ in_channels=256,
215
+ out_channels=512,
216
+ downsample=2,
217
+ activation="relu",
218
+ momentum=0.01,
219
+ )
220
+ self.conv_block6 = ConvBlockRes1D(
221
+ in_channels=512,
222
+ out_channels=384,
223
+ size=3,
224
+ activation="relu",
225
+ momentum=0.01,
226
+ )
227
+ self.up_block1 = UpBlockRes1D(
228
+ in_channels=512,
229
+ out_channels=512,
230
+ stride=2,
231
+ activation="relu",
232
+ momentum=0.01,
233
+ )
234
+ self.up_block2 = UpBlockRes1D(
235
+ in_channels=512,
236
+ out_channels=256,
237
+ stride=2,
238
+ activation="relu",
239
+ momentum=0.01,
240
+ )
241
+ self.up_block3 = UpBlockRes1D(
242
+ in_channels=256,
243
+ out_channels=128,
244
+ stride=2,
245
+ activation="relu",
246
+ momentum=0.01,
247
+ )
248
+ self.up_block4 = UpBlockRes1D(
249
+ in_channels=128,
250
+ out_channels=64,
251
+ stride=2,
252
+ activation="relu",
253
+ momentum=0.01,
254
+ )
255
+ self.up_block5 = UpBlockRes1D(
256
+ in_channels=64,
257
+ out_channels=32,
258
+ stride=2,
259
+ activation="relu",
260
+ momentum=0.01,
261
+ )
262
+
263
+ self.after_conv_block1 = ConvBlockRes1D(
264
+ in_channels=32,
265
+ out_channels=32,
266
+ size=3,
267
+ activation="relu",
268
+ momentum=0.01,
269
+ )
270
+
271
+ self.after_conv2 = nn.Conv1d(
272
+ in_channels=32,
273
+ out_channels=1,
274
+ kernel_size=1,
275
+ stride=1,
276
+ padding=0,
277
+ bias=True,
278
+ )
279
+
280
+ def forward(self, x, h):
281
+ """
282
+ Forward
283
+
284
+ Args:
285
+ clean waveform: (batch, n_channel (1), time)
286
+ channel feature: (batch, feature_dim)
287
+ Outputs:
288
+ degraded waveform: (batch, n_channel (1), time)
289
+ """
290
+ x = x.unsqueeze(1)
291
+
292
+ origin_len = x.shape[2]
293
+ pad_len = (
294
+ int(np.ceil(x.shape[2] / self.downsample_ratio)) * self.downsample_ratio
295
+ - origin_len
296
+ )
297
+ x = F.pad(x, pad=(0, pad_len))
298
+ x = x[..., 0 : x.shape[-1] - 1]
299
+
300
+ (x1_pool, x1) = self.down_block1(x)
301
+ (x2_pool, x2) = self.down_block2(x1_pool)
302
+ (x3_pool, x3) = self.down_block3(x2_pool)
303
+ (x4_pool, x4) = self.down_block4(x3_pool)
304
+ (x5_pool, x5) = self.down_block5(x4_pool)
305
+ x_center = self.conv_block6(x5_pool)
306
+ x_concat = torch.cat(
307
+ (x_center, h.unsqueeze(2).expand(-1, -1, x_center.size(2))), dim=1
308
+ )
309
+ x6 = self.up_block1(x_concat, x5)
310
+ x7 = self.up_block2(x6, x4)
311
+ x8 = self.up_block3(x7, x3)
312
+ x9 = self.up_block4(x8, x2)
313
+ x10 = self.up_block5(x9, x1)
314
+ x = self.after_conv_block1(x10)
315
+ x = self.after_conv2(x)
316
+
317
+ x = F.pad(x, pad=(0, 1))
318
+ x = x[..., 0:origin_len]
319
+
320
+ return x.squeeze(1)
321
+
322
+
323
+ class ChannelFeatureModule(nn.Module):
324
+ """
325
+ Channel feature module based on 2D convolution layers
326
+
327
+ Args:
328
+ config (dict): config
329
+ """
330
+ def __init__(self, config):
331
+ super().__init__()
332
+ self.conv_blocks_in = ConvBlockRes2D(
333
+ in_channels=256,
334
+ out_channels=512,
335
+ size=3,
336
+ activation="relu",
337
+ momentum=0.01,
338
+ )
339
+ self.down_block1 = DownBlockRes2D(
340
+ in_channels=512,
341
+ out_channels=256,
342
+ downsample=(2, 2),
343
+ activation="relu",
344
+ momentum=0.01,
345
+ )
346
+ self.down_block2 = DownBlockRes2D(
347
+ in_channels=256,
348
+ out_channels=256,
349
+ downsample=(2, 2),
350
+ activation="relu",
351
+ momentum=0.01,
352
+ )
353
+ self.conv_block_out = ConvBlockRes2D(
354
+ in_channels=256,
355
+ out_channels=128,
356
+ size=3,
357
+ activation="relu",
358
+ momentum=0.01,
359
+ )
360
+ self.avgpool2d = torch.nn.AdaptiveAvgPool2d(1)
361
+
362
+ def forward(self, x):
363
+ """
364
+ Forward
365
+
366
+ Args:
367
+ output of analysis module: (batch, 256, time, freq)
368
+
369
+ Return:
370
+ channel feature: (batch, feature_dim)
371
+ """
372
+ x = self.conv_blocks_in(x)
373
+ x, _ = self.down_block1(x)
374
+ x, _ = self.down_block2(x)
375
+ x = self.conv_block_out(x)
376
+ x = self.avgpool2d(x)
377
+ x = x.squeeze(3).squeeze(2)
378
+ return x
379
+
380
+
381
+ class ConvBlockRes2D(nn.Module):
382
+ def __init__(self, in_channels, out_channels, size, activation, momentum):
383
+ super().__init__()
384
+
385
+ self.activation = activation
386
+ if type(size) == type((3, 4)):
387
+ pad = size[0] // 2
388
+ size = size[0]
389
+ else:
390
+ pad = size // 2
391
+ size = size
392
+
393
+ self.conv1 = nn.Conv2d(
394
+ in_channels=in_channels,
395
+ out_channels=out_channels,
396
+ kernel_size=(size, size),
397
+ stride=(1, 1),
398
+ dilation=(1, 1),
399
+ padding=(pad, pad),
400
+ bias=False,
401
+ )
402
+
403
+ self.bn1 = nn.BatchNorm2d(in_channels, momentum=momentum)
404
+
405
+ self.conv2 = nn.Conv2d(
406
+ in_channels=out_channels,
407
+ out_channels=out_channels,
408
+ kernel_size=(size, size),
409
+ stride=(1, 1),
410
+ dilation=(1, 1),
411
+ padding=(pad, pad),
412
+ bias=False,
413
+ )
414
+
415
+ self.bn2 = nn.BatchNorm2d(out_channels, momentum=momentum)
416
+
417
+ if in_channels != out_channels:
418
+ self.shortcut = nn.Conv2d(
419
+ in_channels=in_channels,
420
+ out_channels=out_channels,
421
+ kernel_size=(1, 1),
422
+ stride=(1, 1),
423
+ padding=(0, 0),
424
+ )
425
+ self.is_shortcut = True
426
+ else:
427
+ self.is_shortcut = False
428
+
429
+ def forward(self, x):
430
+ origin = x
431
+ x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
432
+ x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
433
+
434
+ if self.is_shortcut:
435
+ return self.shortcut(origin) + x
436
+ else:
437
+ return origin + x
438
+
439
+
440
+ class ConvBlockRes1D(nn.Module):
441
+ def __init__(self, in_channels, out_channels, size, activation, momentum):
442
+ super().__init__()
443
+
444
+ self.activation = activation
445
+ pad = size // 2
446
+
447
+ self.conv1 = nn.Conv1d(
448
+ in_channels=in_channels,
449
+ out_channels=out_channels,
450
+ kernel_size=size,
451
+ stride=1,
452
+ dilation=1,
453
+ padding=pad,
454
+ bias=False,
455
+ )
456
+
457
+ self.bn1 = nn.BatchNorm1d(in_channels, momentum=momentum)
458
+
459
+ self.conv2 = nn.Conv1d(
460
+ in_channels=out_channels,
461
+ out_channels=out_channels,
462
+ kernel_size=size,
463
+ stride=1,
464
+ dilation=1,
465
+ padding=pad,
466
+ bias=False,
467
+ )
468
+
469
+ self.bn2 = nn.BatchNorm1d(out_channels, momentum=momentum)
470
+
471
+ if in_channels != out_channels:
472
+ self.shortcut = nn.Conv1d(
473
+ in_channels=in_channels,
474
+ out_channels=out_channels,
475
+ kernel_size=1,
476
+ stride=1,
477
+ padding=0,
478
+ )
479
+ self.is_shortcut = True
480
+ else:
481
+ self.is_shortcut = False
482
+
483
+ def forward(self, x):
484
+ origin = x
485
+ x = self.conv1(F.leaky_relu_(self.bn1(x), negative_slope=0.01))
486
+ x = self.conv2(F.leaky_relu_(self.bn2(x), negative_slope=0.01))
487
+
488
+ if self.is_shortcut:
489
+ return self.shortcut(origin) + x
490
+ else:
491
+ return origin + x
492
+
493
+
494
+ class DownBlockRes2D(nn.Module):
495
+ def __init__(self, in_channels, out_channels, downsample, activation, momentum):
496
+ super().__init__()
497
+ size = 3
498
+
499
+ self.conv_block1 = ConvBlockRes2D(
500
+ in_channels, out_channels, size, activation, momentum
501
+ )
502
+ self.conv_block2 = ConvBlockRes2D(
503
+ out_channels, out_channels, size, activation, momentum
504
+ )
505
+ self.conv_block3 = ConvBlockRes2D(
506
+ out_channels, out_channels, size, activation, momentum
507
+ )
508
+ self.conv_block4 = ConvBlockRes2D(
509
+ out_channels, out_channels, size, activation, momentum
510
+ )
511
+ self.avg_pool2d = torch.nn.AvgPool2d(downsample)
512
+
513
+ def forward(self, x):
514
+ encoder = self.conv_block1(x)
515
+ encoder = self.conv_block2(encoder)
516
+ encoder = self.conv_block3(encoder)
517
+ encoder = self.conv_block4(encoder)
518
+ encoder_pool = self.avg_pool2d(encoder)
519
+ return encoder_pool, encoder
520
+
521
+
522
+ class DownBlockRes1D(nn.Module):
523
+ def __init__(self, in_channels, out_channels, downsample, activation, momentum):
524
+ super().__init__()
525
+ size = 3
526
+
527
+ self.conv_block1 = ConvBlockRes1D(
528
+ in_channels, out_channels, size, activation, momentum
529
+ )
530
+ self.conv_block2 = ConvBlockRes1D(
531
+ out_channels, out_channels, size, activation, momentum
532
+ )
533
+ self.conv_block3 = ConvBlockRes1D(
534
+ out_channels, out_channels, size, activation, momentum
535
+ )
536
+ self.conv_block4 = ConvBlockRes1D(
537
+ out_channels, out_channels, size, activation, momentum
538
+ )
539
+ self.avg_pool1d = torch.nn.AvgPool1d(downsample)
540
+
541
+ def forward(self, x):
542
+ encoder = self.conv_block1(x)
543
+ encoder = self.conv_block2(encoder)
544
+ encoder = self.conv_block3(encoder)
545
+ encoder = self.conv_block4(encoder)
546
+ encoder_pool = self.avg_pool1d(encoder)
547
+ return encoder_pool, encoder
548
+
549
+
550
+ class UpBlockRes2D(nn.Module):
551
+ def __init__(self, in_channels, out_channels, stride, activation, momentum):
552
+ super().__init__()
553
+ size = 3
554
+ self.activation = activation
555
+
556
+ self.conv1 = torch.nn.ConvTranspose2d(
557
+ in_channels=in_channels,
558
+ out_channels=out_channels,
559
+ kernel_size=(size, size),
560
+ stride=stride,
561
+ padding=(0, 0),
562
+ output_padding=(0, 0),
563
+ bias=False,
564
+ dilation=(1, 1),
565
+ )
566
+
567
+ self.bn1 = nn.BatchNorm2d(in_channels)
568
+ self.conv_block2 = ConvBlockRes2D(
569
+ out_channels * 2, out_channels, size, activation, momentum
570
+ )
571
+ self.conv_block3 = ConvBlockRes2D(
572
+ out_channels, out_channels, size, activation, momentum
573
+ )
574
+ self.conv_block4 = ConvBlockRes2D(
575
+ out_channels, out_channels, size, activation, momentum
576
+ )
577
+ self.conv_block5 = ConvBlockRes2D(
578
+ out_channels, out_channels, size, activation, momentum
579
+ )
580
+
581
+ def prune(self, x, both=False):
582
+ """Prune the shape of x after transpose convolution."""
583
+ if both:
584
+ x = x[:, :, 0:-1, 0:-1]
585
+ else:
586
+ x = x[:, :, 0:-1, :]
587
+ return x
588
+
589
+ def forward(self, input_tensor, concat_tensor, both=False):
590
+ x = self.conv1(F.relu_(self.bn1(input_tensor)))
591
+ x = self.prune(x, both=both)
592
+ x = torch.cat((x, concat_tensor), dim=1)
593
+ x = self.conv_block2(x)
594
+ x = self.conv_block3(x)
595
+ x = self.conv_block4(x)
596
+ x = self.conv_block5(x)
597
+ return x
598
+
599
+
600
+ class UpBlockRes1D(nn.Module):
601
+ def __init__(self, in_channels, out_channels, stride, activation, momentum):
602
+ super().__init__()
603
+ size = 3
604
+ self.activation = activation
605
+
606
+ self.conv1 = torch.nn.ConvTranspose1d(
607
+ in_channels=in_channels,
608
+ out_channels=out_channels,
609
+ kernel_size=size,
610
+ stride=stride,
611
+ padding=0,
612
+ output_padding=0,
613
+ bias=False,
614
+ dilation=1,
615
+ )
616
+
617
+ self.bn1 = nn.BatchNorm1d(in_channels)
618
+ self.conv_block2 = ConvBlockRes1D(
619
+ out_channels * 2, out_channels, size, activation, momentum
620
+ )
621
+ self.conv_block3 = ConvBlockRes1D(
622
+ out_channels, out_channels, size, activation, momentum
623
+ )
624
+ self.conv_block4 = ConvBlockRes1D(
625
+ out_channels, out_channels, size, activation, momentum
626
+ )
627
+ self.conv_block5 = ConvBlockRes1D(
628
+ out_channels, out_channels, size, activation, momentum
629
+ )
630
+
631
+ def prune(self, x):
632
+ """Prune the shape of x after transpose convolution."""
633
+ print(x.shape)
634
+ x = x[:, 0:-1, :]
635
+ print(x.shape)
636
+ return x
637
+
638
+ def forward(self, input_tensor, concat_tensor):
639
+ x = self.conv1(F.relu_(self.bn1(input_tensor)))
640
+ # x = self.prune(x)
641
+ x = torch.cat((x, concat_tensor), dim=1)
642
+ x = self.conv_block2(x)
643
+ x = self.conv_block3(x)
644
+ x = self.conv_block4(x)
645
+ x = self.conv_block5(x)
646
+ return x
647
+
648
+
649
+ class MultiScaleSpectralLoss(nn.Module):
650
+ """
651
+ Multi scale spectral loss
652
+ https://openreview.net/forum?id=B1x1ma4tDr
653
+
654
+ Args:
655
+ config (dict): config
656
+ """
657
+ def __init__(self, config):
658
+ super().__init__()
659
+ try:
660
+ self.use_linear = config["train"]["multi_scale_loss"]["use_linear"]
661
+ self.gamma = config["train"]["multi_scale_loss"]["gamma"]
662
+ except KeyError:
663
+ self.use_linear = False
664
+
665
+ self.fft_sizes = [2048, 512, 256, 128, 64]
666
+ self.spectrograms = []
667
+ for fftsize in self.fft_sizes:
668
+ self.spectrograms.append(
669
+ torchaudio.transforms.Spectrogram(
670
+ n_fft=fftsize, hop_length=fftsize // 4, power=2
671
+ )
672
+ )
673
+ self.spectrograms = nn.ModuleList(self.spectrograms)
674
+ self.criteria = nn.L1Loss()
675
+ self.eps = 1e-10
676
+
677
+ def forward(self, wav_out, wav_target):
678
+ """
679
+ Forward
680
+
681
+ Args:
682
+ wav_out: output of channel module (batch, time)
683
+ wav_target: input degraded waveform (batch, time)
684
+
685
+ Return:
686
+ loss
687
+ """
688
+ loss = 0.0
689
+ length = min(wav_out.size(1), wav_target.size(1))
690
+ for spectrogram in self.spectrograms:
691
+ S_out = spectrogram(wav_out[..., :length])
692
+ S_target = spectrogram(wav_target[..., :length])
693
+ log_S_out = torch.log(S_out + self.eps)
694
+ log_S_target = torch.log(S_target + self.eps)
695
+ if self.use_linear:
696
+ loss += self.criteria(S_out, S_target) + self.gamma * self.criteria(
697
+ log_S_out, log_S_target
698
+ )
699
+ else:
700
+ loss += self.criteria(log_S_out, log_S_target)
701
+ return loss
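A minimal usage sketch of this loss, added here for illustration and not part of the commit (batch size, waveform length, and config values are assumptions):

import torch
from model import MultiScaleSpectralLoss

config = {"train": {"multi_scale_loss": {"use_linear": False, "gamma": 1.0}}}
criterion = MultiScaleSpectralLoss(config)
wav_out = torch.randn(2, 16000)     # (batch, time): output of the channel module
wav_target = torch.randn(2, 16000)  # (batch, time): input degraded waveform
print(criterion(wav_out, wav_target))  # scalar: L1 between log spectrograms summed over 5 FFT sizes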
702
+
703
+
704
+ class ReferenceEncoder(nn.Module):
705
+ def __init__(
706
+ self, idim=80, ref_enc_filters=[32, 32, 64, 64, 128, 128], ref_dim=128
707
+ ):
708
+ super().__init__()
709
+ K = len(ref_enc_filters)
710
+ filters = [1] + ref_enc_filters
711
+
712
+ convs = [
713
+ nn.Conv2d(
714
+ in_channels=filters[i],
715
+ out_channels=filters[i + 1],
716
+ kernel_size=(3, 3),
717
+ stride=(2, 2),
718
+ padding=(1, 1),
719
+ )
720
+ for i in range(K)
721
+ ]
722
+ self.convs = nn.ModuleList(convs)
723
+ self.bns = nn.ModuleList(
724
+ [nn.BatchNorm2d(num_features=ref_enc_filters[i]) for i in range(K)]
725
+ )
726
+
727
+ out_channels = self.calculate_channels(idim, 3, 2, 1, K)
728
+
729
+ self.gru = nn.GRU(
730
+ input_size=ref_enc_filters[-1] * out_channels,
731
+ hidden_size=ref_dim,
732
+ batch_first=True,
733
+ )
734
+ self.n_mel_channels = idim
735
+
736
+ def forward(self, inputs):
737
+
738
+ out = inputs.view(inputs.size(0), 1, -1, self.n_mel_channels)
739
+ for conv, bn in zip(self.convs, self.bns):
740
+ out = conv(out)
741
+ out = bn(out)
742
+ out = F.relu(out)
743
+
744
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
745
+ N, T = out.size(0), out.size(1)
746
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
747
+
748
+ self.gru.flatten_parameters()
749
+
750
+ _, out = self.gru(out)
751
+
752
+ return out.squeeze(0)
753
+
754
+ def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
755
+ for _ in range(n_convs):
756
+ L = (L - kernel_size + 2 * pad) // stride + 1
757
+ return L
758
+
759
+
760
+ class STL(nn.Module):
761
+ def __init__(self, ref_dim=128, num_heads=4, token_num=10, token_dim=128):
762
+ super().__init__()
763
+ self.embed = nn.Parameter(torch.FloatTensor(token_num, token_dim // num_heads))
764
+ d_q = ref_dim
765
+ d_k = token_dim // num_heads
766
+ self.attention = MultiHeadAttention(
767
+ query_dim=d_q, key_dim=d_k, num_units=token_dim, num_heads=num_heads
768
+ )
769
+ init.normal_(self.embed, mean=0, std=0.5)
770
+
771
+ def forward(self, inputs):
772
+ N = inputs.size(0)
773
+ query = inputs.unsqueeze(1)
774
+ keys = (
775
+ torch.tanh(self.embed).unsqueeze(0).expand(N, -1, -1)
776
+ ) # [N, token_num, token_embedding_size // num_heads]
777
+ style_embed = self.attention(query, keys)
778
+ return style_embed
779
+
780
+
781
+ class MultiHeadAttention(nn.Module):
782
+ """
783
+ Multi head attention
784
+ https://github.com/KinglittleQ/GST-Tacotron
785
+
786
+ """
787
+ def __init__(self, query_dim, key_dim, num_units, num_heads):
788
+ super().__init__()
789
+ self.num_units = num_units
790
+ self.num_heads = num_heads
791
+ self.key_dim = key_dim
792
+
793
+ self.W_query = nn.Linear(
794
+ in_features=query_dim, out_features=num_units, bias=False
795
+ )
796
+ self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
797
+ self.W_value = nn.Linear(
798
+ in_features=key_dim, out_features=num_units, bias=False
799
+ )
800
+
801
+ def forward(self, query, key):
802
+ """
803
+ Forward
804
+
805
+ Args:
806
+ query: (batch, T_q, query_dim)
807
+ key: (batch, T_k, key_dim)
808
+
809
+ Return:
810
+ out: (N, T_q, num_units)
811
+ """
812
+ querys = self.W_query(query) # [N, T_q, num_units]
813
+
814
+ keys = self.W_key(key) # [N, T_k, num_units]
815
+ values = self.W_value(key)
816
+
817
+ split_size = self.num_units // self.num_heads
818
+ querys = torch.stack(
819
+ torch.split(querys, split_size, dim=2), dim=0
820
+ ) # [h, N, T_q, num_units/h]
821
+ keys = torch.stack(
822
+ torch.split(keys, split_size, dim=2), dim=0
823
+ ) # [h, N, T_k, num_units/h]
824
+ values = torch.stack(
825
+ torch.split(values, split_size, dim=2), dim=0
826
+ ) # [h, N, T_k, num_units/h]
827
+
828
+ # score = softmax(QK^T / (d_k ** 0.5))
829
+ scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
830
+ scores = scores / (self.key_dim ** 0.5)
831
+ scores = F.softmax(scores, dim=3)
832
+
833
+ # out = score * V
834
+ out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
835
+ out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(
836
+ 0
837
+ ) # [N, T_q, num_units]
838
+
839
+ return out
840
+
841
+
842
+ class GSTModule(nn.Module):
843
+ def __init__(self, config):
844
+ super().__init__()
845
+ self.encoder_post = ReferenceEncoder(
846
+ idim=config["preprocess"]["n_mels"],
847
+ ref_dim=256,
848
+ )
849
+ self.stl = STL(ref_dim=256, num_heads=8, token_num=10, token_dim=128)
850
+
851
+ def forward(self, inputs):
852
+ acoustic_embed = self.encoder_post(inputs)
853
+ style_embed = self.stl(acoustic_embed)
854
+ return style_embed.squeeze(1)
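Shape sketch for the GST-based channel-feature path defined above (batch size, frame count, and n_mels are assumed; the input layout matches melspecs.transpose(1, 2) as used in the lightning modules):

import torch
from model import GSTModule

config = {"preprocess": {"n_mels": 80}}
gst = GSTModule(config)
mel = torch.randn(4, 200, 80)  # (batch, time, n_mels)
style = gst(mel)
print(style.shape)             # torch.Size([4, 128]): channel/style embedding fed to ChannelModule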
preprocess.py ADDED
@@ -0,0 +1,152 @@
1
+ import numpy as np
2
+ import os
3
+ import librosa
4
+ import tqdm
5
+ import pickle
6
+ import random
7
+ import argparse
8
+ import yaml
9
+ import pathlib
10
+
11
+
12
+ def get_arg():
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--config_path", required=True, type=pathlib.Path)
15
+ parser.add_argument("--corpus_type", default=None, type=str)
16
+ parser.add_argument("--source_path", default=None, type=pathlib.Path)
17
+ parser.add_argument("--source_path_task", default=None, type=pathlib.Path)
18
+ parser.add_argument("--aux_path", default=None, type=pathlib.Path)
19
+ parser.add_argument("--preprocessed_path", default=None, type=pathlib.Path)
20
+ parser.add_argument("--n_train", default=None, type=int)
21
+ parser.add_argument("--n_val", default=None, type=int)
22
+ parser.add_argument("--n_test", default=None, type=int)
23
+ return parser.parse_args()
24
+
25
+
26
+ def preprocess(config):
27
+
28
+ # configs
29
+ preprocessed_dir = pathlib.Path(config["general"]["preprocessed_path"])
30
+ n_train = config["preprocess"]["n_train"]
31
+ n_val = config["preprocess"]["n_val"]
32
+ n_test = config["preprocess"]["n_test"]
33
+ SR = config["preprocess"]["sampling_rate"]
34
+
35
+ os.makedirs(preprocessed_dir, exist_ok=True)
36
+
37
+ sourcepath = pathlib.Path(config["general"]["source_path"])
38
+
39
+ if config["general"]["corpus_type"] == "single":
40
+ fulllist = list(sourcepath.glob("*.wav"))
41
+ random.seed(0)
42
+ random.shuffle(fulllist)
43
+ train_filelist = fulllist[:n_train]
44
+ val_filelist = fulllist[n_train : n_train + n_val]
45
+ test_filelist = fulllist[n_train + n_val : n_train + n_val + n_test]
46
+ filelist = train_filelist + val_filelist + test_filelist
47
+ elif config["general"]["corpus_type"] == "multi-seen":
48
+ fulllist = list(sourcepath.glob("*/*.wav"))
49
+ random.seed(0)
50
+ random.shuffle(fulllist)
51
+ train_filelist = fulllist[:n_train]
52
+ val_filelist = fulllist[n_train : n_train + n_val]
53
+ test_filelist = fulllist[n_train + n_val : n_train + n_val + n_test]
54
+ filelist = train_filelist + val_filelist + test_filelist
55
+ elif config["general"]["corpus_type"] == "multi-unseen":
56
+ spk_list = list(set([x.parent.name for x in sourcepath.glob("*/*.wav")]))
57
+ train_filelist = []
58
+ val_filelist = []
59
+ test_filelist = []
60
+ random.seed(0)
61
+ random.shuffle(spk_list)
62
+ for i, spk in enumerate(spk_list):
63
+ sourcespkpath = sourcepath / spk
64
+ if i < n_train:
65
+ train_filelist.extend(list(sourcespkpath.glob("*.wav")))
66
+ elif i < n_train + n_val:
67
+ val_filelist.extend(list(sourcespkpath.glob("*.wav")))
68
+ elif i < n_train + n_val + n_test:
69
+ test_filelist.extend(list(sourcespkpath.glob("*.wav")))
70
+ filelist = train_filelist + val_filelist + test_filelist
71
+ else:
72
+ raise NotImplementedError(
73
+ "corpus_type specified in config.yaml should be {single, multi-seen, multi-unseen}"
74
+ )
75
+
76
+ with open(preprocessed_dir / "train.txt", "w", encoding="utf-8") as f:
77
+ for m in train_filelist:
78
+ f.write(str(m) + "\n")
79
+ with open(preprocessed_dir / "val.txt", "w", encoding="utf-8") as f:
80
+ for m in val_filelist:
81
+ f.write(str(m) + "\n")
82
+ with open(preprocessed_dir / "test.txt", "w", encoding="utf-8") as f:
83
+ for m in test_filelist:
84
+ f.write(str(m) + "\n")
85
+
86
+ for wp in tqdm.tqdm(filelist):
87
+
88
+ if config["general"]["corpus_type"] == "single":
89
+ basename = str(wp.stem)
90
+ else:
91
+ basename = str(wp.parent.name) + "-" + str(wp.stem)
92
+
93
+ wav, _ = librosa.load(wp, sr=SR)
94
+ wavsegs = []
95
+
96
+ if config["general"]["aux_path"] != None:
97
+ auxpath = pathlib.Path(config["general"]["aux_path"])
98
+ if config["general"]["corpus_type"] == "single":
99
+ wav_aux, _ = librosa.load(auxpath / wp.name, sr=SR)
100
+ else:
101
+ wav_aux, _ = librosa.load(auxpath / wp.parent.name / wp.name, sr=SR)
102
+ wavauxsegs = []
103
+
104
+ if config["general"]["aux_path"] == None:
105
+ wavsegs.append(wav)
106
+ else:
107
+ min_seq_len = min(len(wav), len(wav_aux))
108
+ wav = wav[:min_seq_len]
109
+ wav_aux = wav_aux[:min_seq_len]
110
+ wavsegs.append(wav)
111
+ wavauxsegs.append(wav_aux)
112
+
113
+ wavsegs = np.asarray(wavsegs).astype(np.float32)
114
+ if config["general"]["aux_path"] != None:
115
+ wavauxsegs = np.asarray(wavauxsegs).astype(np.float32)
116
+ else:
117
+ wavauxsegs = None
118
+
119
+ d_preprocessed = {"wavs": wavsegs, "wavsaux": wavauxsegs}
120
+
121
+ with open(preprocessed_dir / "{}.pickle".format(basename), "wb") as fw:
122
+ pickle.dump(d_preprocessed, fw)
123
+
124
+
125
+ if __name__ == "__main__":
126
+ args = get_arg()
127
+
128
+ config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader)
129
+ for key in ["corpus_type", "source_path", "aux_path", "preprocessed_path"]:
130
+ if getattr(args, key) != None:
131
+ config["general"][key] = str(getattr(args, key))
132
+ for key in ["n_train", "n_val", "n_test"]:
133
+ if getattr(args, key) != None:
134
+ config["preprocess"][key] = getattr(args, key)
135
+
136
+ print("Performing preprocessing ...")
137
+ preprocess(config)
138
+
139
+ if "dual" in config:
140
+ if config["dual"]["enable"]:
141
+ task_config = yaml.load(
142
+ open(config["dual"]["config_path"], "r"), Loader=yaml.FullLoader
143
+ )
144
+ task_preprocessed_dir = (
145
+ pathlib.Path(config["general"]["preprocessed_path"]).parent
146
+ / pathlib.Path(task_config["general"]["preprocessed_path"]).name
147
+ )
148
+ task_config["general"]["preprocessed_path"] = task_preprocessed_dir
149
+ if args.source_path_task != None:
150
+ task_config["general"]["source_path"] = args.source_path_task
151
+ print("Performing preprocessing for multi-task learning ...")
152
+ preprocess(task_config)
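A hypothetical invocation of this script (file names and counts are placeholders; the flags mirror get_arg() above):

python preprocess.py --config_path config.yaml --preprocessed_path data/preprocessed --n_train 100 --n_val 10 --n_test 10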
train.py ADDED
@@ -0,0 +1,106 @@
1
+ import argparse
2
+ import os
3
+ import pathlib
4
+ import yaml
5
+ from dataset import DataModule
6
+ from pytorch_lightning import Trainer
7
+ from pytorch_lightning.callbacks import ModelCheckpoint
8
+ from pytorch_lightning.loggers.csv_logs import CSVLogger
9
+ from pytorch_lightning.loggers import TensorBoardLogger
10
+ from pytorch_lightning.callbacks.early_stopping import EarlyStopping
11
+ from lightning_module import (
12
+ PretrainLightningModule,
13
+ SSLStepLightningModule,
14
+ SSLDualLightningModule,
15
+ )
16
+ from utils import configure_args
17
+
18
+
19
+ def get_arg():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--config_path", required=True, type=pathlib.Path)
22
+ parser.add_argument(
23
+ "--stage", required=True, type=str, choices=["pretrain", "ssl-step", "ssl-dual"]
24
+ )
25
+ parser.add_argument("--run_name", required=True, type=str)
26
+ parser.add_argument("--corpus_type", default=None, type=str)
27
+ parser.add_argument("--source_path", default=None, type=pathlib.Path)
28
+ parser.add_argument("--aux_path", default=None, type=pathlib.Path)
29
+ parser.add_argument("--preprocessed_path", default=None, type=pathlib.Path)
30
+ parser.add_argument("--n_train", default=None, type=int)
31
+ parser.add_argument("--n_val", default=None, type=int)
32
+ parser.add_argument("--n_test", default=None, type=int)
33
+ parser.add_argument("--epoch", default=None, type=int)
34
+ parser.add_argument("--load_pretrained", action="store_true")
35
+ parser.add_argument("--pretrained_path", default=None, type=pathlib.Path)
36
+ parser.add_argument("--early_stopping", action="store_true")
37
+ parser.add_argument("--alpha", default=None, type=float)
38
+ parser.add_argument("--beta", default=None, type=float)
39
+ parser.add_argument("--learning_rate", default=None, type=float)
40
+ parser.add_argument(
41
+ "--feature_loss_type", default=None, type=str, choices=["mae", "mse"]
42
+ )
43
+ parser.add_argument("--debug", action="store_true")
44
+ return parser.parse_args()
45
+
46
+
47
+ def train(args, config, output_path):
48
+ debug = args.debug
49
+
50
+ csvlogger = CSVLogger(save_dir=str(output_path), name="train_log")
51
+ tblogger = TensorBoardLogger(save_dir=str(output_path), name="tf_log")
52
+
53
+ checkpoint_callback = ModelCheckpoint(
54
+ dirpath=str(output_path),
55
+ save_weights_only=True,
56
+ save_top_k=-1,
57
+ every_n_epochs=1,
58
+ monitor="val_loss",
59
+ )
60
+ callbacks = [checkpoint_callback]
61
+ if config["train"]["early_stopping"]:
62
+ earlystop_callback = EarlyStopping(
63
+ monitor="val_loss", min_delta=0.0, patience=15, mode="min"
64
+ )
65
+ callbacks.append(earlystop_callback)
66
+
67
+ trainer = Trainer(
68
+ max_epochs=1 if debug else config["train"]["epoch"],
69
+ gpus=-1,
70
+ deterministic=False,
71
+ auto_select_gpus=True,
72
+ benchmark=True,
73
+ default_root_dir=os.getcwd(),
74
+ limit_train_batches=0.01 if debug else 1.0,
75
+ limit_val_batches=0.5 if debug else 1.0,
76
+ callbacks=callbacks,
77
+ logger=[csvlogger, tblogger],
78
+ gradient_clip_val=config["train"]["grad_clip_thresh"],
79
+ flush_logs_every_n_steps=config["train"]["logger_step"],
80
+ val_check_interval=0.5,
81
+ )
82
+
83
+ if config["general"]["stage"] == "pretrain":
84
+ model = PretrainLightningModule(config)
85
+ elif config["general"]["stage"] == "ssl-step":
86
+ model = SSLStepLightningModule(config)
87
+ elif config["general"]["stage"] == "ssl-dual":
88
+ model = SSLDualLightningModule(config)
89
+ else:
90
+ raise NotImplementedError()
91
+
92
+ datamodule = DataModule(config)
93
+ trainer.fit(model, datamodule=datamodule)
94
+
95
+
96
+ if __name__ == "__main__":
97
+
98
+ args = get_arg()
99
+ config = yaml.load(open(args.config_path, "r"), Loader=yaml.FullLoader)
100
+
101
+ output_path = pathlib.Path(config["general"]["output_path"]) / args.run_name
102
+ os.makedirs(output_path, exist_ok=True)
103
+
104
+ config, args = configure_args(config, args)
105
+
106
+ train(args=args, config=config, output_path=output_path)
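A hypothetical launch command (config path and run name are placeholders; --stage must be one of pretrain, ssl-step, or ssl-dual as declared in get_arg()):

python train.py --config_path config.yaml --stage ssl-dual --run_name exp1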
utils.py ADDED
@@ -0,0 +1,147 @@
1
+ import librosa.display
2
+ import matplotlib.pyplot as plt
3
+ import json
4
+ import torch
5
+ import torchaudio
6
+ import hifigan
7
+
8
+
9
+ def manual_logging(logger, item, idx, tag, global_step, data_type, config):
10
+
11
+ if data_type == "audio":
12
+ audio = item[idx, ...].detach().cpu().numpy()
13
+ logger.add_audio(
14
+ tag,
15
+ audio,
16
+ global_step,
17
+ sample_rate=config["preprocess"]["sampling_rate"],
18
+ )
19
+ elif data_type == "image":
20
+ image = item[idx, ...].detach().cpu().numpy()
21
+ fig, ax = plt.subplots()
22
+ _ = librosa.display.specshow(
23
+ image,
24
+ x_axis="time",
25
+ y_axis="linear",
26
+ sr=config["preprocess"]["sampling_rate"],
27
+ hop_length=config["preprocess"]["frame_shift"],
28
+ fmax=config["preprocess"]["sampling_rate"] // 2,
29
+ ax=ax,
30
+ )
31
+ logger.add_figure(tag, fig, global_step)
32
+ else:
33
+ raise NotImplementedError(
34
+ "Data type given to logger should be [audio] or [image]"
35
+ )
36
+
37
+
38
+ def load_vocoder(config):
39
+ with open(
40
+ "hifigan/config_{}.json".format(config["general"]["feature_type"]), "r"
41
+ ) as f:
42
+ config_hifigan = hifigan.AttrDict(json.load(f))
43
+ vocoder = hifigan.Generator(config_hifigan)
44
+ vocoder.load_state_dict(torch.load(config["general"]["hifigan_path"])["generator"])
45
+ vocoder.remove_weight_norm()
46
+ for param in vocoder.parameters():
47
+ param.requires_grad = False
48
+ return vocoder
49
+
50
+
51
+ def get_conv_padding(kernel_size, dilation=1):
52
+ return int((kernel_size * dilation - dilation) / 2)
53
+
54
+
55
+ def plot_and_save_mels(wav, save_path, config):
56
+ spec_module = torchaudio.transforms.MelSpectrogram(
57
+ sample_rate=config["preprocess"]["sampling_rate"],
58
+ n_fft=config["preprocess"]["fft_length"],
59
+ win_length=config["preprocess"]["frame_length"],
60
+ hop_length=config["preprocess"]["frame_shift"],
61
+ f_min=config["preprocess"]["fmin"],
62
+ f_max=config["preprocess"]["fmax"],
63
+ n_mels=config["preprocess"]["n_mels"],
64
+ power=1,
65
+ center=True,
66
+ norm="slaney",
67
+ mel_scale="slaney",
68
+ )
69
+ spec = spec_module(wav.unsqueeze(0))
70
+ log_spec = torch.log(
71
+ torch.clamp_min(spec, config["preprocess"]["min_magnitude"])
72
+ * config["preprocess"]["comp_factor"]
73
+ )
74
+ fig, ax = plt.subplots()
75
+ _ = librosa.display.specshow(
76
+ log_spec.squeeze(0).numpy(),
77
+ x_axis="time",
78
+ y_axis="linear",
79
+ sr=config["preprocess"]["sampling_rate"],
80
+ hop_length=config["preprocess"]["frame_shift"],
81
+ fmax=config["preprocess"]["sampling_rate"] // 2,
82
+ ax=ax,
83
+ cmap="viridis",
84
+ )
85
+ fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
86
+
87
+
88
+ def plot_and_save_mels_all(wavs, keys, save_path, config):
89
+ spec_module = torchaudio.transforms.MelSpectrogram(
90
+ sample_rate=config["preprocess"]["sampling_rate"],
91
+ n_fft=config["preprocess"]["fft_length"],
92
+ win_length=config["preprocess"]["frame_length"],
93
+ hop_length=config["preprocess"]["frame_shift"],
94
+ f_min=config["preprocess"]["fmin"],
95
+ f_max=config["preprocess"]["fmax"],
96
+ n_mels=config["preprocess"]["n_mels"],
97
+ power=1,
98
+ center=True,
99
+ norm="slaney",
100
+ mel_scale="slaney",
101
+ )
102
+ fig, ax = plt.subplots(nrows=3, ncols=3, figsize=(18, 18))
103
+ for i, key in enumerate(keys):
104
+ wav = wavs[key][0, ...].cpu()
105
+ spec = spec_module(wav.unsqueeze(0))
106
+ log_spec = torch.log(
107
+ torch.clamp_min(spec, config["preprocess"]["min_magnitude"])
108
+ * config["preprocess"]["comp_factor"]
109
+ )
110
+ ax[i // 3, i % 3].set(title=key)
111
+ _ = librosa.display.specshow(
112
+ log_spec.squeeze(0).numpy(),
113
+ x_axis="time",
114
+ y_axis="linear",
115
+ sr=config["preprocess"]["sampling_rate"],
116
+ hop_length=config["preprocess"]["frame_shift"],
117
+ fmax=config["preprocess"]["sampling_rate"] // 2,
118
+ ax=ax[i // 3, i % 3],
119
+ cmap="viridis",
120
+ )
121
+ fig.savefig(save_path, bbox_inches="tight", pad_inches=0)
122
+
123
+
124
+ def configure_args(config, args):
125
+ for key in ["stage", "corpus_type", "source_path", "aux_path", "preprocessed_path"]:
126
+ if getattr(args, key) != None:
127
+ config["general"][key] = str(getattr(args, key))
128
+
129
+ for key in ["n_train", "n_val", "n_test"]:
130
+ if getattr(args, key) != None:
131
+ config["preprocess"][key] = getattr(args, key)
132
+
133
+ for key in ["alpha", "beta", "learning_rate", "epoch"]:
134
+ if getattr(args, key) != None:
135
+ config["train"][key] = getattr(args, key)
136
+
137
+ for key in ["load_pretrained", "early_stopping"]:
138
+ config["train"][key] = getattr(args, key)
139
+
140
+ if args.feature_loss_type != None:
141
+ config["train"]["feature_loss"]["type"] = args.feature_loss_type
142
+
143
+ for key in ["pretrained_path"]:
144
+ if getattr(args, key) != None:
145
+ config["train"][key] = str(getattr(args, key))
146
+
147
+ return config, args
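A small usage sketch for plot_and_save_mels above (the config path, sampling rate, and output name are assumptions; the keys read are exactly those referenced inside the function):

import torch
import yaml
from utils import plot_and_save_mels

config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)  # assumed config file
wav = torch.randn(22050)  # ~1 s of audio, assuming a 22.05 kHz sampling rate
plot_and_save_mels(wav, "sample_mel.png", config)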