yuancwang committed on
Commit dce1ab4
Parent: 5548515

add models

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. app.py +18 -0
  2. models/__init__.py +0 -0
  3. models/base/__init__.py +7 -0
  4. models/base/base_dataset.py +344 -0
  5. models/base/base_inference.py +220 -0
  6. models/base/base_sampler.py +136 -0
  7. models/base/base_trainer.py +348 -0
  8. models/base/new_dataset.py +50 -0
  9. models/base/new_inference.py +249 -0
  10. models/base/new_trainer.py +722 -0
  11. models/svc/__init__.py +0 -0
  12. models/svc/base/__init__.py +7 -0
  13. models/svc/base/svc_dataset.py +437 -0
  14. models/svc/base/svc_inference.py +15 -0
  15. models/svc/base/svc_trainer.py +111 -0
  16. models/svc/comosvc/__init__.py +4 -0
  17. models/svc/comosvc/comosvc.py +377 -0
  18. models/svc/comosvc/comosvc_inference.py +39 -0
  19. models/svc/comosvc/comosvc_trainer.py +295 -0
  20. models/svc/comosvc/utils.py +31 -0
  21. models/svc/diffusion/__init__.py +0 -0
  22. models/svc/diffusion/diffusion_inference.py +63 -0
  23. models/svc/diffusion/diffusion_inference_pipeline.py +47 -0
  24. models/svc/diffusion/diffusion_trainer.py +88 -0
  25. models/svc/diffusion/diffusion_wrapper.py +73 -0
  26. models/svc/transformer/__init__.py +0 -0
  27. models/svc/transformer/conformer.py +405 -0
  28. models/svc/transformer/transformer.py +82 -0
  29. models/svc/transformer/transformer_inference.py +45 -0
  30. models/svc/transformer/transformer_trainer.py +52 -0
  31. models/svc/vits/__init__.py +0 -0
  32. models/svc/vits/vits.py +271 -0
  33. models/svc/vits/vits_inference.py +84 -0
  34. models/svc/vits/vits_trainer.py +483 -0
  35. models/tta/autoencoder/__init__.py +0 -0
  36. models/tta/autoencoder/autoencoder.py +405 -0
  37. models/tta/autoencoder/autoencoder_dataset.py +114 -0
  38. models/tta/autoencoder/autoencoder_loss.py +305 -0
  39. models/tta/autoencoder/autoencoder_trainer.py +187 -0
  40. models/tta/ldm/__init__.py +0 -0
  41. models/tta/ldm/attention.py +329 -0
  42. models/tta/ldm/audioldm.py +926 -0
  43. models/tta/ldm/audioldm_dataset.py +153 -0
  44. models/tta/ldm/audioldm_inference.py +193 -0
  45. models/tta/ldm/audioldm_trainer.py +251 -0
  46. models/tta/ldm/inference_utils/utils.py +62 -0
  47. models/tta/ldm/inference_utils/vocoder.py +408 -0
  48. models/tts/naturalspeech2/ns2_dataset.py +0 -2
  49. models/vocoders/autoregressive/autoregressive_vocoder_dataset.py +0 -0
  50. models/vocoders/autoregressive/autoregressive_vocoder_inference.py +0 -0
app.py ADDED
@@ -0,0 +1,18 @@
1
+ import os
2
+ import time
3
+ import numpy as np
4
+ import torch
5
+ from tqdm import tqdm
6
+ import torch.nn as nn
7
+ from collections import OrderedDict
8
+ import json
9
+
10
+ from models.tta.autoencoder.autoencoder import AutoencoderKL
11
+ from models.tta.ldm.inference_utils.vocoder import Generator
12
+ from models.tta.ldm.audioldm import AudioLDM
13
+ from transformers import T5EncoderModel, AutoTokenizer
14
+ from diffusers import PNDMScheduler
15
+
16
+ import matplotlib.pyplot as plt
17
+ from scipy.io.wavfile import write
18
+
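Review note: app.py so far only pulls in the text-to-audio (TTA) stack. As a hedged sketch of how the external pieces among these imports are typically wired together (the repo-specific AutoencoderKL / AudioLDM / Generator classes need Amphion configs and checkpoints, so they are left out; the "t5-base" checkpoint and the scheduler betas below are assumptions, not values taken from this commit):

import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import PNDMScheduler

# Text encoder + noise scheduler, the two Hugging Face dependencies imported above.
tokenizer = AutoTokenizer.from_pretrained("t5-base")
text_encoder = T5EncoderModel.from_pretrained("t5-base").eval()
scheduler = PNDMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear")

with torch.no_grad():
    tokens = tokenizer(["a dog barking in the rain"], return_tensors="pt", padding=True)
    text_emb = text_encoder(**tokens).last_hidden_state  # conditioning for the latent diffusion model
print(text_emb.shape, scheduler.config.num_train_timesteps)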
models/__init__.py ADDED
File without changes
models/base/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .new_trainer import BaseTrainer
7
+ from .new_inference import BaseInference
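Review note: with this re-export, downstream recipes import the new accelerate-based entry points from the package rather than the legacy modules, e.g. (assuming the repository root is on PYTHONPATH):

# BaseTrainer / BaseInference resolve to new_trainer.py / new_inference.py, not base_trainer.py / base_inference.py.
from models.base import BaseTrainer, BaseInference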
models/base/base_dataset.py ADDED
@@ -0,0 +1,344 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import numpy as np
8
+ import torch.utils.data
9
+ from torch.nn.utils.rnn import pad_sequence
10
+ from utils.data_utils import *
11
+ from processors.acoustic_extractor import cal_normalized_mel
12
+ from text import text_to_sequence
13
+ from text.text_token_collation import phoneIDCollation
14
+
15
+
16
+ class BaseDataset(torch.utils.data.Dataset):
17
+ def __init__(self, cfg, dataset, is_valid=False):
18
+ """
19
+ Args:
20
+ cfg: config
21
+ dataset: dataset name
22
+ is_valid: whether to use train or valid dataset
23
+ """
24
+
25
+ assert isinstance(dataset, str)
26
+
27
+ # self.data_root = processed_data_dir
28
+ self.cfg = cfg
29
+
30
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
31
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
32
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
33
+ self.metadata = self.get_metadata()
34
+
35
+ """
36
+ load spk2id and utt2spk from json file
37
+ spk2id: {spk1: 0, spk2: 1, ...}
38
+ utt2spk: {dataset_uid: spk1, ...}
39
+ """
40
+ if cfg.preprocess.use_spkid:
41
+ spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
42
+ with open(spk2id_path, "r") as f:
43
+ self.spk2id = json.load(f)
44
+
45
+ utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
46
+ self.utt2spk = dict()
47
+ with open(utt2spk_path, "r") as f:
48
+ for line in f.readlines():
49
+ utt, spk = line.strip().split("\t")
50
+ self.utt2spk[utt] = spk
51
+
52
+ if cfg.preprocess.use_uv:
53
+ self.utt2uv_path = {}
54
+ for utt_info in self.metadata:
55
+ dataset = utt_info["Dataset"]
56
+ uid = utt_info["Uid"]
57
+ utt = "{}_{}".format(dataset, uid)
58
+ self.utt2uv_path[utt] = os.path.join(
59
+ cfg.preprocess.processed_dir,
60
+ dataset,
61
+ cfg.preprocess.uv_dir,
62
+ uid + ".npy",
63
+ )
64
+
65
+ if cfg.preprocess.use_frame_pitch:
66
+ self.utt2frame_pitch_path = {}
67
+ for utt_info in self.metadata:
68
+ dataset = utt_info["Dataset"]
69
+ uid = utt_info["Uid"]
70
+ utt = "{}_{}".format(dataset, uid)
71
+
72
+ self.utt2frame_pitch_path[utt] = os.path.join(
73
+ cfg.preprocess.processed_dir,
74
+ dataset,
75
+ cfg.preprocess.pitch_dir,
76
+ uid + ".npy",
77
+ )
78
+
79
+ if cfg.preprocess.use_frame_energy:
80
+ self.utt2frame_energy_path = {}
81
+ for utt_info in self.metadata:
82
+ dataset = utt_info["Dataset"]
83
+ uid = utt_info["Uid"]
84
+ utt = "{}_{}".format(dataset, uid)
85
+
86
+ self.utt2frame_energy_path[utt] = os.path.join(
87
+ cfg.preprocess.processed_dir,
88
+ dataset,
89
+ cfg.preprocess.energy_dir,
90
+ uid + ".npy",
91
+ )
92
+
93
+ if cfg.preprocess.use_mel:
94
+ self.utt2mel_path = {}
95
+ for utt_info in self.metadata:
96
+ dataset = utt_info["Dataset"]
97
+ uid = utt_info["Uid"]
98
+ utt = "{}_{}".format(dataset, uid)
99
+
100
+ self.utt2mel_path[utt] = os.path.join(
101
+ cfg.preprocess.processed_dir,
102
+ dataset,
103
+ cfg.preprocess.mel_dir,
104
+ uid + ".npy",
105
+ )
106
+
107
+ if cfg.preprocess.use_linear:
108
+ self.utt2linear_path = {}
109
+ for utt_info in self.metadata:
110
+ dataset = utt_info["Dataset"]
111
+ uid = utt_info["Uid"]
112
+ utt = "{}_{}".format(dataset, uid)
113
+
114
+ self.utt2linear_path[utt] = os.path.join(
115
+ cfg.preprocess.processed_dir,
116
+ dataset,
117
+ cfg.preprocess.linear_dir,
118
+ uid + ".npy",
119
+ )
120
+
121
+ if cfg.preprocess.use_audio:
122
+ self.utt2audio_path = {}
123
+ for utt_info in self.metadata:
124
+ dataset = utt_info["Dataset"]
125
+ uid = utt_info["Uid"]
126
+ utt = "{}_{}".format(dataset, uid)
127
+
128
+ self.utt2audio_path[utt] = os.path.join(
129
+ cfg.preprocess.processed_dir,
130
+ dataset,
131
+ cfg.preprocess.audio_dir,
132
+ uid + ".npy",
133
+ )
134
+ elif cfg.preprocess.use_label:
135
+ self.utt2label_path = {}
136
+ for utt_info in self.metadata:
137
+ dataset = utt_info["Dataset"]
138
+ uid = utt_info["Uid"]
139
+ utt = "{}_{}".format(dataset, uid)
140
+
141
+ self.utt2label_path[utt] = os.path.join(
142
+ cfg.preprocess.processed_dir,
143
+ dataset,
144
+ cfg.preprocess.label_dir,
145
+ uid + ".npy",
146
+ )
147
+ elif cfg.preprocess.use_one_hot:
148
+ self.utt2one_hot_path = {}
149
+ for utt_info in self.metadata:
150
+ dataset = utt_info["Dataset"]
151
+ uid = utt_info["Uid"]
152
+ utt = "{}_{}".format(dataset, uid)
153
+
154
+ self.utt2one_hot_path[utt] = os.path.join(
155
+ cfg.preprocess.processed_dir,
156
+ dataset,
157
+ cfg.preprocess.one_hot_dir,
158
+ uid + ".npy",
159
+ )
160
+
161
+ if cfg.preprocess.use_text or cfg.preprocess.use_phone:
162
+ self.utt2seq = {}
163
+ for utt_info in self.metadata:
164
+ dataset = utt_info["Dataset"]
165
+ uid = utt_info["Uid"]
166
+ utt = "{}_{}".format(dataset, uid)
167
+
168
+ if cfg.preprocess.use_text:
169
+ text = utt_info["Text"]
170
+ sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
171
+ elif cfg.preprocess.use_phone:
172
+ # load the phoneme sequence from the phone file
173
+ phone_path = os.path.join(
174
+ processed_data_dir, cfg.preprocess.phone_dir, uid + ".phone"
175
+ )
176
+ with open(phone_path, "r") as fin:
177
+ phones = fin.readlines()
178
+ assert len(phones) == 1
179
+ phones = phones[0].strip()
180
+ phones_seq = phones.split(" ")
181
+
182
+ phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
183
+ sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
184
+
185
+ self.utt2seq[utt] = sequence
186
+
187
+ def get_metadata(self):
188
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
189
+ metadata = json.load(f)
190
+
191
+ return metadata
192
+
193
+ def get_dataset_name(self):
194
+ return self.metadata[0]["Dataset"]
195
+
196
+ def __getitem__(self, index):
197
+ utt_info = self.metadata[index]
198
+
199
+ dataset = utt_info["Dataset"]
200
+ uid = utt_info["Uid"]
201
+ utt = "{}_{}".format(dataset, uid)
202
+
203
+ single_feature = dict()
204
+
205
+ if self.cfg.preprocess.use_spkid:
206
+ single_feature["spk_id"] = np.array(
207
+ [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
208
+ )
209
+
210
+ if self.cfg.preprocess.use_mel:
211
+ mel = np.load(self.utt2mel_path[utt])
212
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
213
+ if self.cfg.preprocess.use_min_max_norm_mel:
214
+ # do mel norm
215
+ mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
216
+
217
+ if "target_len" not in single_feature.keys():
218
+ single_feature["target_len"] = mel.shape[1]
219
+ single_feature["mel"] = mel.T # [T, n_mels]
220
+
221
+ if self.cfg.preprocess.use_linear:
222
+ linear = np.load(self.utt2linear_path[utt])
223
+ if "target_len" not in single_feature.keys():
224
+ single_feature["target_len"] = linear.shape[1]
225
+ single_feature["linear"] = linear.T # [T, n_linear]
226
+
227
+ if self.cfg.preprocess.use_frame_pitch:
228
+ frame_pitch_path = self.utt2frame_pitch_path[utt]
229
+ frame_pitch = np.load(frame_pitch_path)
230
+ if "target_len" not in single_feature.keys():
231
+ single_feature["target_len"] = len(frame_pitch)
232
+ aligned_frame_pitch = align_length(
233
+ frame_pitch, single_feature["target_len"]
234
+ )
235
+ single_feature["frame_pitch"] = aligned_frame_pitch
236
+
237
+ if self.cfg.preprocess.use_uv:
238
+ frame_uv_path = self.utt2uv_path[utt]
239
+ frame_uv = np.load(frame_uv_path)
240
+ aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
241
+ aligned_frame_uv = [
242
+ 0 if frame_uv else 1 for frame_uv in aligned_frame_uv
243
+ ]
244
+ aligned_frame_uv = np.array(aligned_frame_uv)
245
+ single_feature["frame_uv"] = aligned_frame_uv
246
+
247
+ if self.cfg.preprocess.use_frame_energy:
248
+ frame_energy_path = self.utt2frame_energy_path[utt]
249
+ frame_energy = np.load(frame_energy_path)
250
+ if "target_len" not in single_feature.keys():
251
+ single_feature["target_len"] = len(frame_energy)
252
+ aligned_frame_energy = align_length(
253
+ frame_energy, single_feature["target_len"]
254
+ )
255
+ single_feature["frame_energy"] = aligned_frame_energy
256
+
257
+ if self.cfg.preprocess.use_audio:
258
+ audio = np.load(self.utt2audio_path[utt])
259
+ single_feature["audio"] = audio
260
+ single_feature["audio_len"] = audio.shape[0]
261
+
262
+ if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
263
+ single_feature["phone_seq"] = np.array(self.utt2seq[utt])
264
+ single_feature["phone_len"] = len(self.utt2seq[utt])
265
+
266
+ return single_feature
267
+
268
+ def __len__(self):
269
+ return len(self.metadata)
270
+
271
+
272
+ class BaseCollator(object):
273
+ """Zero-pads model inputs and targets based on number of frames per step"""
274
+
275
+ def __init__(self, cfg):
276
+ self.cfg = cfg
277
+
278
+ def __call__(self, batch):
279
+ packed_batch_features = dict()
280
+
281
+ # mel: [b, T, n_mels]
282
+ # frame_pitch, frame_energy: [1, T]
283
+ # target_len: [1]
284
+ # spk_id: [b, 1]
285
+ # mask: [b, T, 1]
286
+
287
+ for key in batch[0].keys():
288
+ if key == "target_len":
289
+ packed_batch_features["target_len"] = torch.LongTensor(
290
+ [b["target_len"] for b in batch]
291
+ )
292
+ masks = [
293
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
294
+ ]
295
+ packed_batch_features["mask"] = pad_sequence(
296
+ masks, batch_first=True, padding_value=0
297
+ )
298
+ elif key == "phone_len":
299
+ packed_batch_features["phone_len"] = torch.LongTensor(
300
+ [b["phone_len"] for b in batch]
301
+ )
302
+ masks = [
303
+ torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
304
+ ]
305
+ packed_batch_features["phn_mask"] = pad_sequence(
306
+ masks, batch_first=True, padding_value=0
307
+ )
308
+ elif key == "audio_len":
309
+ packed_batch_features["audio_len"] = torch.LongTensor(
310
+ [b["audio_len"] for b in batch]
311
+ )
312
+ masks = [
313
+ torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
314
+ ]
315
+ else:
316
+ values = [torch.from_numpy(b[key]) for b in batch]
317
+ packed_batch_features[key] = pad_sequence(
318
+ values, batch_first=True, padding_value=0
319
+ )
320
+ return packed_batch_features
321
+
322
+
323
+ class BaseTestDataset(torch.utils.data.Dataset):
324
+ def __init__(self, cfg, args):
325
+ raise NotImplementedError
326
+
327
+ def get_metadata(self):
328
+ raise NotImplementedError
329
+
330
+ def __getitem__(self, index):
331
+ raise NotImplementedError
332
+
333
+ def __len__(self):
334
+ return len(self.metadata)
335
+
336
+
337
+ class BaseTestCollator(object):
338
+ """Zero-pads model inputs and targets based on number of frames per step"""
339
+
340
+ def __init__(self, cfg):
341
+ raise NotImplementedError
342
+
343
+ def __call__(self, batch):
344
+ raise NotImplementedError
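Review note: BaseDataset reads its metadata from the train/valid metafile under <processed_dir>/<dataset> (cfg.preprocess.train_file / valid_file); per the docstring above, spk2id is a JSON dict and utt2spk a tab-separated "<dataset>_<uid>\t<speaker>" file. A hedged sketch of how the dataset and collator are usually combined into a padded-batch loader; the cfg object is assumed to be a full Amphion config carrying the preprocess/train fields referenced in __init__, and the dataset name is illustrative:

from torch.utils.data import DataLoader
from models.base.base_dataset import BaseDataset, BaseCollator

def build_train_loader(cfg, dataset_name="opencpop"):
    dataset = BaseDataset(cfg, dataset_name, is_valid=False)
    collator = BaseCollator(cfg)  # zero-pads every per-utterance feature to the batch maximum
    return DataLoader(
        dataset,
        batch_size=cfg.train.batch_size,
        shuffle=True,
        collate_fn=collator,  # yields e.g. batch["mel"]: [B, T, n_mels], batch["mask"]: [B, T, 1]
    )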
models/base/base_inference.py ADDED
@@ -0,0 +1,220 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import os
8
+ import re
9
+ import time
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+ from tqdm import tqdm
15
+
16
+ from models.vocoders.vocoder_inference import synthesis
17
+ from torch.utils.data import DataLoader
18
+ from utils.util import set_all_random_seed
19
+ from utils.util import load_config
20
+
21
+
22
+ def parse_vocoder(vocoder_dir):
23
+ r"""Parse vocoder config"""
24
+ vocoder_dir = os.path.abspath(vocoder_dir)
25
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
26
+ ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
27
+ ckpt_path = str(ckpt_list[0])
28
+ vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
29
+ vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
30
+ return vocoder_cfg, ckpt_path
31
+
32
+
33
+ class BaseInference(object):
34
+ def __init__(self, cfg, args):
35
+ self.cfg = cfg
36
+ self.args = args
37
+ self.model_type = cfg.model_type
38
+ self.avg_rtf = list()
39
+ set_all_random_seed(10086)
40
+ os.makedirs(args.output_dir, exist_ok=True)
41
+
42
+ if torch.cuda.is_available():
43
+ self.device = torch.device("cuda")
44
+ else:
45
+ self.device = torch.device("cpu")
46
+ torch.set_num_threads(10)  # limit the number of CPU threads used for inference
47
+
48
+ # Load acoustic model
49
+ self.model = self.create_model().to(self.device)
50
+ state_dict = self.load_state_dict()
51
+ self.load_model(state_dict)
52
+ self.model.eval()
53
+
54
+ # Load vocoder model if necessary
55
+ if self.args.checkpoint_dir_vocoder is not None:
56
+ self.get_vocoder_info()
57
+
58
+ def create_model(self):
59
+ raise NotImplementedError
60
+
61
+ def load_state_dict(self):
62
+ self.checkpoint_file = self.args.checkpoint_file
63
+ if self.checkpoint_file is None:
64
+ assert self.args.checkpoint_dir is not None
65
+ checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
66
+ checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
67
+ self.checkpoint_file = os.path.join(
68
+ self.args.checkpoint_dir, checkpoint_filename
69
+ )
70
+
71
+ self.checkpoint_dir = os.path.split(self.checkpoint_file)[0]
72
+
73
+ print("Restore acoustic model from {}".format(self.checkpoint_file))
74
+ raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
75
+ self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0]
76
+
77
+ return raw_state_dict
78
+
79
+ def load_model(self, model):
80
+ raise NotImplementedError
81
+
82
+ def get_vocoder_info(self):
83
+ self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
84
+ self.vocoder_cfg = os.path.join(
85
+ os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
86
+ )
87
+ self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
88
+ self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1]
89
+ self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0]
90
+
91
+ def build_test_utt_data(self):
92
+ raise NotImplementedError
93
+
94
+ def build_testdata_loader(self, args, target_speaker=None):
95
+ datasets, collate = self.build_test_dataset()
96
+ self.test_dataset = datasets(self.cfg, args, target_speaker)
97
+ self.test_collate = collate(self.cfg)
98
+ self.test_batch_size = min(
99
+ self.cfg.train.batch_size, len(self.test_dataset.metadata)
100
+ )
101
+ test_loader = DataLoader(
102
+ self.test_dataset,
103
+ collate_fn=self.test_collate,
104
+ num_workers=self.args.num_workers,
105
+ batch_size=self.test_batch_size,
106
+ shuffle=False,
107
+ )
108
+ return test_loader
109
+
110
+ def inference_each_batch(self, batch_data):
111
+ raise NotImplementedError
112
+
113
+ def inference_for_batches(self, args, target_speaker=None):
114
+ ###### Construct test_batch ######
115
+ loader = self.build_testdata_loader(args, target_speaker)
116
+
117
+ n_batch = len(loader)
118
+ now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
119
+ print(
120
+ "Model eval time: {}, batch_size = {}, n_batch = {}".format(
121
+ now, self.test_batch_size, n_batch
122
+ )
123
+ )
124
+ self.model.eval()
125
+
126
+ ###### Inference for each batch ######
127
+ pred_res = []
128
+ with torch.no_grad():
129
+ for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)):
130
+ # Put the data to device
131
+ for k, v in batch_data.items():
132
+ batch_data[k] = batch_data[k].to(self.device)
133
+
134
+ y_pred, stats = self.inference_each_batch(batch_data)
135
+
136
+ pred_res += y_pred
137
+
138
+ return pred_res
139
+
140
+ def inference(self, feature):
141
+ raise NotImplementedError
142
+
143
+ def synthesis_by_vocoder(self, pred):
144
+ audios_pred = synthesis(
145
+ self.vocoder_cfg,
146
+ self.checkpoint_dir_vocoder,
147
+ len(pred),
148
+ pred,
149
+ )
150
+ return audios_pred
151
+
152
+ def __call__(self, utt):
153
+ feature = self.build_test_utt_data(utt)
154
+ start_time = time.time()
155
+ with torch.no_grad():
156
+ outputs = self.inference(feature)[0]
157
+ time_used = time.time() - start_time
158
+ rtf = time_used / (
159
+ outputs.shape[1]
160
+ * self.cfg.preprocess.hop_size
161
+ / self.cfg.preprocess.sample_rate
162
+ )
163
+ print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
164
+ self.avg_rtf.append(rtf)
165
+ audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
166
+ return audios
167
+
168
+
169
+ def base_parser():
170
+ parser = argparse.ArgumentParser()
171
+ parser.add_argument(
172
+ "--config", default="config.json", help="json files for configurations."
173
+ )
174
+ parser.add_argument("--use_ddp_inference", default=False)
175
+ parser.add_argument("--n_workers", default=1, type=int)
176
+ parser.add_argument("--local_rank", default=-1, type=int)
177
+ parser.add_argument(
178
+ "--batch_size", default=1, type=int, help="Batch size for inference"
179
+ )
180
+ parser.add_argument(
181
+ "--num_workers",
182
+ default=1,
183
+ type=int,
184
+ help="Worker number for inference dataloader",
185
+ )
186
+ parser.add_argument(
187
+ "--checkpoint_dir",
188
+ type=str,
189
+ default=None,
190
+ help="Checkpoint dir including model file and configuration",
191
+ )
192
+ parser.add_argument(
193
+ "--checkpoint_file", help="checkpoint file", type=str, default=None
194
+ )
195
+ parser.add_argument(
196
+ "--test_list", help="test utterance list for testing", type=str, default=None
197
+ )
198
+ parser.add_argument(
199
+ "--checkpoint_dir_vocoder",
200
+ help="Vocoder's checkpoint dir including model file and configuration",
201
+ type=str,
202
+ default=None,
203
+ )
204
+ parser.add_argument(
205
+ "--output_dir",
206
+ type=str,
207
+ default=None,
208
+ help="Output dir for saving generated results",
209
+ )
210
+ return parser
211
+
212
+
213
+ if __name__ == "__main__":
214
+ parser = base_parser()
215
+ args = parser.parse_args()
216
+ cfg = load_config(args.config)
217
+
218
+ # Build inference
219
+ inference = BaseInference(cfg, args)
220
+ inference()
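Review note: parse_vocoder() assumes the vocoder directory contains an args.json plus checkpoint files whose stems are integer training steps, and it picks the newest one. A hedged sketch of the expected layout and call (the path is illustrative):

# vocoder_dir/
#   args.json    <- vocoder config, loaded with load_config(..., lowercase=True)
#   100000.pt    <- checkpoints named by training step
#   200000.pt    <- int(stem) sorting with reverse=True selects this one
from models.base.base_inference import parse_vocoder

vocoder_cfg, ckpt_path = parse_vocoder("ckpts/vocoder/bigvgan")
print(ckpt_path)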
models/base/base_sampler.py ADDED
@@ -0,0 +1,136 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import random
8
+
9
+ from torch.utils.data import ConcatDataset, Dataset
10
+ from torch.utils.data.sampler import (
11
+ BatchSampler,
12
+ RandomSampler,
13
+ Sampler,
14
+ SequentialSampler,
15
+ )
16
+
17
+
18
+ class ScheduledSampler(Sampler):
19
+ """A sampler that samples data from a given concat-dataset.
20
+
21
+ Args:
22
+ concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
23
+ batch_size (int): batch size
24
+ holistic_shuffle (bool): whether to shuffle the whole dataset or not
25
+ logger (logging.Logger): logger to print warning message
26
+
27
+ Usage:
28
+ For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
29
+ >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]])))
30
+ [3, 4, 5, 0, 1, 2, 6, 7, 8]
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ concat_dataset,
36
+ batch_size,
37
+ holistic_shuffle,
38
+ logger=None,
39
+ loader_type="train",
40
+ ):
41
+ if not isinstance(concat_dataset, ConcatDataset):
42
+ raise ValueError(
43
+ "concat_dataset must be an instance of ConcatDataset, but got {}".format(
44
+ type(concat_dataset)
45
+ )
46
+ )
47
+ if not isinstance(batch_size, int):
48
+ raise ValueError(
49
+ "batch_size must be an integer, but got {}".format(type(batch_size))
50
+ )
51
+ if not isinstance(holistic_shuffle, bool):
52
+ raise ValueError(
53
+ "holistic_shuffle must be a boolean, but got {}".format(
54
+ type(holistic_shuffle)
55
+ )
56
+ )
57
+
58
+ self.concat_dataset = concat_dataset
59
+ self.batch_size = batch_size
60
+ self.holistic_shuffle = holistic_shuffle
61
+
62
+ affected_dataset_name = []
63
+ affected_dataset_len = []
64
+ for dataset in concat_dataset.datasets:
65
+ dataset_len = len(dataset)
66
+ dataset_name = dataset.get_dataset_name()
67
+ if dataset_len < batch_size:
68
+ affected_dataset_name.append(dataset_name)
69
+ affected_dataset_len.append(dataset_len)
70
+
71
+ self.type = loader_type
72
+ for dataset_name, dataset_len in zip(
73
+ affected_dataset_name, affected_dataset_len
74
+ ):
75
+ if not loader_type == "valid":
76
+ logger.warning(
77
+ "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
78
+ loader_type, dataset_name, dataset_len, batch_size
79
+ )
80
+ )
81
+
82
+ def __len__(self):
83
+ # the number of batches with drop last
84
+ num_of_batches = sum(
85
+ [
86
+ math.floor(len(dataset) / self.batch_size)
87
+ for dataset in self.concat_dataset.datasets
88
+ ]
89
+ )
90
+ # if samples are not enough for one batch, we don't drop last
91
+ if self.type == "valid" and num_of_batches < 1:
92
+ return len(self.concat_dataset)
93
+ return num_of_batches * self.batch_size
94
+
95
+ def __iter__(self):
96
+ iters = []
97
+ for dataset in self.concat_dataset.datasets:
98
+ iters.append(
99
+ SequentialSampler(dataset).__iter__()
100
+ if not self.holistic_shuffle
101
+ else RandomSampler(dataset).__iter__()
102
+ )
103
+ # e.g. [0, 200, 400]
104
+ init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
105
+ output_batches = []
106
+ for dataset_idx in range(len(self.concat_dataset.datasets)):
107
+ cur_batch = []
108
+ for idx in iters[dataset_idx]:
109
+ cur_batch.append(idx + init_indices[dataset_idx])
110
+ if len(cur_batch) == self.batch_size:
111
+ output_batches.append(cur_batch)
112
+ cur_batch = []
113
+ # if loader_type is valid, we don't need to drop last
114
+ if self.type == "valid" and len(cur_batch) > 0:
115
+ output_batches.append(cur_batch)
116
+
117
+ # force drop last in training
118
+ random.shuffle(output_batches)
119
+ output_indices = [item for sublist in output_batches for item in sublist]
120
+ return iter(output_indices)
121
+
122
+
123
+ def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
124
+ sampler = ScheduledSampler(
125
+ concat_dataset,
126
+ cfg.train.batch_size,
127
+ cfg.train.sampler.holistic_shuffle,
128
+ logger,
129
+ loader_type,
130
+ )
131
+ batch_sampler = BatchSampler(
132
+ sampler,
133
+ cfg.train.batch_size,
134
+ cfg.train.sampler.drop_last if not loader_type == "valid" else False,
135
+ )
136
+ return sampler, batch_sampler
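Review note: ScheduledSampler keeps every batch inside a single sub-dataset of the ConcatDataset and only shuffles the batch order. A hedged, self-contained sketch with a toy config that carries just the fields build_samplers() reads:

import logging
from types import SimpleNamespace
from torch.utils.data import ConcatDataset, DataLoader, Dataset
from models.base.base_sampler import build_samplers

class ToyDataset(Dataset):
    def __init__(self, name, n):
        self.name, self.n = name, n
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        return idx
    def get_dataset_name(self):  # required by ScheduledSampler's small-dataset check
        return self.name

cfg = SimpleNamespace(
    train=SimpleNamespace(
        batch_size=4,
        sampler=SimpleNamespace(holistic_shuffle=False, drop_last=True),
    )
)
concat = ConcatDataset([ToyDataset("a", 10), ToyDataset("b", 6)])
sampler, batch_sampler = build_samplers(concat, cfg, logging.getLogger(__name__), "train")
loader = DataLoader(concat, batch_sampler=batch_sampler)  # every batch is drawn from one sub-dataset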
models/base/base_trainer.py ADDED
@@ -0,0 +1,348 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import collections
7
+ import json
8
+ import os
9
+ import sys
10
+ import time
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ from torch.nn.parallel import DistributedDataParallel
15
+ from torch.utils.data import ConcatDataset, DataLoader
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ from models.base.base_sampler import BatchSampler
19
+ from utils.util import (
20
+ Logger,
21
+ remove_older_ckpt,
22
+ save_config,
23
+ set_all_random_seed,
24
+ ValueWindow,
25
+ )
26
+
27
+
28
+ class BaseTrainer(object):
29
+ def __init__(self, args, cfg):
30
+ self.args = args
31
+ self.log_dir = args.log_dir
32
+ self.cfg = cfg
33
+
34
+ self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
35
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
36
+ if not cfg.train.ddp or args.local_rank == 0:
37
+ self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
38
+ self.logger = self.build_logger()
39
+ self.time_window = ValueWindow(50)
40
+
41
+ self.step = 0
42
+ self.epoch = -1
43
+ self.max_epochs = self.cfg.train.epochs
44
+ self.max_steps = self.cfg.train.max_steps
45
+
46
+ # set random seed & init distributed training
47
+ set_all_random_seed(self.cfg.train.random_seed)
48
+ if cfg.train.ddp:
49
+ dist.init_process_group(backend="nccl")
50
+
51
+ if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
52
+ self.singers = self.build_singers_lut()
53
+
54
+ # setup data_loader
55
+ self.data_loader = self.build_data_loader()
56
+
57
+ # setup model & enable distributed training
58
+ self.model = self.build_model()
59
+ print(self.model)
60
+
61
+ if isinstance(self.model, dict):
62
+ for key, value in self.model.items():
63
+ value.cuda(self.args.local_rank)
64
+ if key == "PQMF":
65
+ continue
66
+ if cfg.train.ddp:
67
+ self.model[key] = DistributedDataParallel(
68
+ value, device_ids=[self.args.local_rank]
69
+ )
70
+ else:
71
+ self.model.cuda(self.args.local_rank)
72
+ if cfg.train.ddp:
73
+ self.model = DistributedDataParallel(
74
+ self.model, device_ids=[self.args.local_rank]
75
+ )
76
+
77
+ # create criterion
78
+ self.criterion = self.build_criterion()
79
+ if isinstance(self.criterion, dict):
80
+ for key, value in self.criterion.items():
81
+ self.criterion[key].cuda(args.local_rank)
82
+ else:
83
+ self.criterion.cuda(self.args.local_rank)
84
+
85
+ # optimizer
86
+ self.optimizer = self.build_optimizer()
87
+ self.scheduler = self.build_scheduler()
88
+
89
+ # save config file
90
+ self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")
91
+
92
+ def build_logger(self):
93
+ log_file = os.path.join(self.checkpoint_dir, "train.log")
94
+ logger = Logger(log_file, level=self.args.log_level).logger
95
+
96
+ return logger
97
+
98
+ def build_dataset(self):
99
+ raise NotImplementedError
100
+
101
+ def build_data_loader(self):
102
+ Dataset, Collator = self.build_dataset()
103
+ # build dataset instance for each dataset and combine them by ConcatDataset
104
+ datasets_list = []
105
+ for dataset in self.cfg.dataset:
106
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
107
+ datasets_list.append(subdataset)
108
+ train_dataset = ConcatDataset(datasets_list)
109
+
110
+ train_collate = Collator(self.cfg)
111
+ # TODO: multi-GPU training
112
+ if self.cfg.train.ddp:
113
+ raise NotImplementedError("DDP is not supported yet.")
114
+
115
+ # sampler will provide indices to batch_sampler, which will perform batching and yield batch indices
116
+ batch_sampler = BatchSampler(
117
+ cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list
118
+ )
119
+
120
+ # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
121
+ train_loader = DataLoader(
122
+ train_dataset,
123
+ collate_fn=train_collate,
124
+ num_workers=self.args.num_workers,
125
+ batch_sampler=batch_sampler,
126
+ pin_memory=False,
127
+ )
128
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
129
+ datasets_list = []
130
+ for dataset in self.cfg.dataset:
131
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
132
+ datasets_list.append(subdataset)
133
+ valid_dataset = ConcatDataset(datasets_list)
134
+ valid_collate = Collator(self.cfg)
135
+ batch_sampler = BatchSampler(
136
+ cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list
137
+ )
138
+ valid_loader = DataLoader(
139
+ valid_dataset,
140
+ collate_fn=valid_collate,
141
+ num_workers=1,
142
+ batch_sampler=batch_sampler,
143
+ )
144
+ else:
145
+ raise NotImplementedError("DDP is not supported yet.")
146
+ # valid_loader = None
147
+ data_loader = {"train": train_loader, "valid": valid_loader}
148
+ return data_loader
149
+
150
+ def build_singers_lut(self):
151
+ # combine singers
152
+ if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)):
153
+ singers = collections.OrderedDict()
154
+ else:
155
+ with open(
156
+ os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r"
157
+ ) as singer_file:
158
+ singers = json.load(singer_file)
159
+ singer_count = len(singers)
160
+ for dataset in self.cfg.dataset:
161
+ singer_lut_path = os.path.join(
162
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
163
+ )
164
+ with open(singer_lut_path, "r") as singer_lut_path:
165
+ singer_lut = json.load(singer_lut_path)
166
+ for singer in singer_lut.keys():
167
+ if singer not in singers:
168
+ singers[singer] = singer_count
169
+ singer_count += 1
170
+ with open(
171
+ os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w"
172
+ ) as singer_file:
173
+ json.dump(singers, singer_file, indent=4, ensure_ascii=False)
174
+ print(
175
+ "singers have been dumped to {}".format(
176
+ os.path.join(self.log_dir, self.cfg.preprocess.spk2id)
177
+ )
178
+ )
179
+ return singers
180
+
181
+ def build_model(self):
182
+ raise NotImplementedError()
183
+
184
+ def build_optimizer(self):
185
+ raise NotImplementedError
186
+
187
+ def build_scheduler(self):
188
+ raise NotImplementedError()
189
+
190
+ def build_criterion(self):
191
+ raise NotImplementedError
192
+
193
+ def get_state_dict(self):
194
+ raise NotImplementedError
195
+
196
+ def save_config_file(self):
197
+ save_config(self.config_save_path, self.cfg)
198
+
199
+ # TODO, save without module.
200
+ def save_checkpoint(self, state_dict, saved_model_path):
201
+ torch.save(state_dict, saved_model_path)
202
+
203
+ def load_checkpoint(self):
204
+ checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint")
205
+ assert os.path.exists(checkpoint_path)
206
+ checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
207
+ model_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
208
+ assert os.path.exists(model_path)
209
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
210
+ self.logger.info(f"Re(store) from {model_path}")
211
+ checkpoint = torch.load(model_path, map_location="cpu")
212
+ return checkpoint
213
+
214
+ def load_model(self, checkpoint):
215
+ raise NotImplementedError
216
+
217
+ def restore(self):
218
+ checkpoint = self.load_checkpoint()
219
+ self.load_model(checkpoint)
220
+
221
+ def train_step(self, data):
222
+ raise NotImplementedError(
223
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
224
+ f"your sub-class of {self.__class__.__name__}. "
225
+ )
226
+
227
+ @torch.no_grad()
228
+ def eval_step(self):
229
+ raise NotImplementedError(
230
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
231
+ f"your sub-class of {self.__class__.__name__}. "
232
+ )
233
+
234
+ def write_summary(self, losses, stats):
235
+ raise NotImplementedError(
236
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
237
+ f"your sub-class of {self.__class__.__name__}. "
238
+ )
239
+
240
+ def write_valid_summary(self, losses, stats):
241
+ raise NotImplementedError(
242
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
243
+ f"your sub-class of {self.__class__.__name__}. "
244
+ )
245
+
246
+ def echo_log(self, losses, mode="Training"):
247
+ message = [
248
+ "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
249
+ mode, self.epoch + 1, self.step, self.time_window.average
250
+ )
251
+ ]
252
+
253
+ for key in sorted(losses.keys()):
254
+ if isinstance(losses[key], dict):
255
+ for k, v in losses[key].items():
256
+ message.append(
257
+ str(k).split("/")[-1] + "=" + str(round(float(v), 5))
258
+ )
259
+ else:
260
+ message.append(
261
+ str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
262
+ )
263
+ self.logger.info(", ".join(message))
264
+
265
+ def eval_epoch(self):
266
+ self.logger.info("Validation...")
267
+ valid_losses = {}
268
+ for i, batch_data in enumerate(self.data_loader["valid"]):
269
+ for k, v in batch_data.items():
270
+ if isinstance(v, torch.Tensor):
271
+ batch_data[k] = v.cuda()
272
+ valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i)
273
+ for key in valid_loss:
274
+ if key not in valid_losses:
275
+ valid_losses[key] = 0
276
+ valid_losses[key] += valid_loss[key]
277
+
278
+ # Add mel and audio to the Tensorboard
279
+ # Average loss
280
+ for key in valid_losses:
281
+ valid_losses[key] /= i + 1
282
+ self.echo_log(valid_losses, "Valid")
283
+ return valid_losses, valid_stats
284
+
285
+ def train_epoch(self):
286
+ for i, batch_data in enumerate(self.data_loader["train"]):
287
+ start_time = time.time()
288
+ # Put the data to cuda device
289
+ for k, v in batch_data.items():
290
+ if isinstance(v, torch.Tensor):
291
+ batch_data[k] = v.cuda(self.args.local_rank)
292
+
293
+ # Training step
294
+ train_losses, train_stats, total_loss = self.train_step(batch_data)
295
+ self.time_window.append(time.time() - start_time)
296
+
297
+ if self.args.local_rank == 0 or not self.cfg.train.ddp:
298
+ if self.step % self.args.stdout_interval == 0:
299
+ self.echo_log(train_losses, "Training")
300
+
301
+ if self.step % self.cfg.train.save_summary_steps == 0:
302
+ self.logger.info(f"Save summary as step {self.step}")
303
+ self.write_summary(train_losses, train_stats)
304
+
305
+ if (
306
+ self.step % self.cfg.train.save_checkpoints_steps == 0
307
+ and self.step != 0
308
+ ):
309
+ saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
310
+ self.step, total_loss
311
+ )
312
+ saved_model_path = os.path.join(
313
+ self.checkpoint_dir, saved_model_name
314
+ )
315
+ saved_state_dict = self.get_state_dict()
316
+ self.save_checkpoint(saved_state_dict, saved_model_path)
317
+ self.save_config_file()
318
+ # keep max n models
319
+ remove_older_ckpt(
320
+ saved_model_name,
321
+ self.checkpoint_dir,
322
+ max_to_keep=self.cfg.train.keep_checkpoint_max,
323
+ )
324
+
325
+ if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
326
+ if isinstance(self.model, dict):
327
+ for key in self.model.keys():
328
+ self.model[key].eval()
329
+ else:
330
+ self.model.eval()
331
+ # Evaluate one epoch and get average loss
332
+ valid_losses, valid_stats = self.eval_epoch()
333
+ if isinstance(self.model, dict):
334
+ for key in self.model.keys():
335
+ self.model[key].train()
336
+ else:
337
+ self.model.train()
338
+ # Write validation losses to summary.
339
+ self.write_valid_summary(valid_losses, valid_stats)
340
+ self.step += 1
341
+
342
+ def train(self):
343
+ for epoch in range(max(0, self.epoch), self.max_epochs):
344
+ self.train_epoch()
345
+ self.epoch += 1
346
+ if self.step > self.max_steps:
347
+ self.logger.info("Training finished!")
348
+ break
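Review note: this (legacy) BaseTrainer drives everything from the builders it calls in __init__ plus train_step / eval_step. A hedged skeleton of what a concrete trainer must provide; the bodies below are placeholders, not Amphion's actual implementations:

import torch
from models.base.base_trainer import BaseTrainer

class MyTrainer(BaseTrainer):
    def build_dataset(self):
        # must return (DatasetClass, CollatorClass); one dataset instance is built per entry of cfg.dataset
        raise NotImplementedError

    def build_model(self):
        return torch.nn.Linear(80, 80)  # placeholder acoustic model

    def build_optimizer(self):
        return torch.optim.AdamW(self.model.parameters(), lr=1e-4)

    def build_scheduler(self):
        return torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.999)

    def build_criterion(self):
        return torch.nn.L1Loss()

    def train_step(self, data):
        # must return (losses_dict, stats_dict, total_loss) as consumed by train_epoch()
        raise NotImplementedError

    # eval_step, get_state_dict, load_model, write_summary and write_valid_summary
    # follow the same pattern and are omitted here.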
models/base/new_dataset.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ from abc import abstractmethod
9
+ from pathlib import Path
10
+
11
+ import json5
12
+ import torch
13
+ import yaml
14
+
15
+
16
+ # TODO: for training and validating
17
+ class BaseDataset(torch.utils.data.Dataset):
18
+ r"""Base dataset for training and validating."""
19
+
20
+ def __init__(self, args, cfg, is_valid=False):
21
+ pass
22
+
23
+
24
+ class BaseTestDataset(torch.utils.data.Dataset):
25
+ r"""Test dataset for inference."""
26
+
27
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
28
+ assert infer_type in ["from_dataset", "from_file"]
29
+
30
+ self.args = args
31
+ self.cfg = cfg
32
+ self.infer_type = infer_type
33
+
34
+ @abstractmethod
35
+ def __getitem__(self, index):
36
+ pass
37
+
38
+ def __len__(self):
39
+ return len(self.metadata)
40
+
41
+ def get_metadata(self):
42
+ path = Path(self.args.source)
43
+ if path.suffix == ".json" or path.suffix == ".jsonc":
44
+ metadata = json5.load(open(self.args.source, "r"))
45
+ elif path.suffix == ".yaml" or path.suffix == ".yml":
46
+ metadata = yaml.full_load(open(self.args.source, "r"))
47
+ else:
48
+ raise ValueError(f"Unsupported file type: {path.suffix}")
49
+
50
+ return metadata
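Review note: BaseTestDataset.get_metadata() only dispatches on the source-file suffix (.json/.jsonc via json5, .yaml/.yml via yaml); the entry fields in the sketch below are purely illustrative, since the real schema is defined by each concrete test dataset:

import json
from types import SimpleNamespace
from models.base.new_dataset import BaseTestDataset

class DummyTestDataset(BaseTestDataset):
    def __getitem__(self, index):
        return self.metadata[index]

with open("/tmp/source.json", "w") as f:
    json.dump([{"Dataset": "demo", "Uid": "0001", "Path": "/tmp/0001.wav"}], f)

ds = DummyTestDataset(args=SimpleNamespace(source="/tmp/source.json"), cfg=None, infer_type="from_file")
ds.metadata = ds.get_metadata()
print(len(ds))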
models/base/new_inference.py ADDED
@@ -0,0 +1,249 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import random
8
+ import re
9
+ import time
10
+ from abc import abstractmethod
11
+ from pathlib import Path
12
+
13
+ import accelerate
14
+ import json5
15
+ import numpy as np
16
+ import torch
17
+ from accelerate.logging import get_logger
18
+ from torch.utils.data import DataLoader
19
+
20
+ from models.vocoders.vocoder_inference import synthesis
21
+ from utils.io import save_audio
22
+ from utils.util import load_config
23
+ from utils.audio_slicer import is_silence
24
+
25
+ EPS = 1.0e-12
26
+
27
+
28
+ class BaseInference(object):
29
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
30
+ super().__init__()
31
+
32
+ start = time.monotonic_ns()
33
+ self.args = args
34
+ self.cfg = cfg
35
+
36
+ assert infer_type in ["from_dataset", "from_file"]
37
+ self.infer_type = infer_type
38
+
39
+ # init with accelerate
40
+ self.accelerator = accelerate.Accelerator()
41
+ self.accelerator.wait_for_everyone()
42
+
43
+ # Use accelerate logger for distributed inference
44
+ with self.accelerator.main_process_first():
45
+ self.logger = get_logger("inference", log_level=args.log_level)
46
+
47
+ # Log some info
48
+ self.logger.info("=" * 56)
49
+ self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
50
+ self.logger.info("=" * 56)
51
+ self.logger.info("\n")
52
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
53
+
54
+ self.acoustics_dir = args.acoustics_dir
55
+ self.logger.debug(f"Acoustic dir: {args.acoustics_dir}")
56
+ self.vocoder_dir = args.vocoder_dir
57
+ self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
58
+ # should be in svc inferencer
59
+ # self.target_singer = args.target_singer
60
+ # self.logger.info(f"Target singers: {args.target_singer}")
61
+ # self.trans_key = args.trans_key
62
+ # self.logger.info(f"Trans key: {args.trans_key}")
63
+
64
+ os.makedirs(args.output_dir, exist_ok=True)
65
+
66
+ # set random seed
67
+ with self.accelerator.main_process_first():
68
+ start = time.monotonic_ns()
69
+ self._set_random_seed(self.cfg.train.random_seed)
70
+ end = time.monotonic_ns()
71
+ self.logger.debug(
72
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
73
+ )
74
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
75
+
76
+ # setup data_loader
77
+ with self.accelerator.main_process_first():
78
+ self.logger.info("Building dataset...")
79
+ start = time.monotonic_ns()
80
+ self.test_dataloader = self._build_dataloader()
81
+ end = time.monotonic_ns()
82
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
83
+
84
+ # setup model
85
+ with self.accelerator.main_process_first():
86
+ self.logger.info("Building model...")
87
+ start = time.monotonic_ns()
88
+ self.model = self._build_model()
89
+ end = time.monotonic_ns()
90
+ # self.logger.debug(self.model)
91
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
92
+
93
+ # init with accelerate
94
+ self.logger.info("Initializing accelerate...")
95
+ start = time.monotonic_ns()
96
+ self.accelerator = accelerate.Accelerator()
97
+ self.model = self.accelerator.prepare(self.model)
98
+ end = time.monotonic_ns()
99
+ self.accelerator.wait_for_everyone()
100
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
101
+
102
+ with self.accelerator.main_process_first():
103
+ self.logger.info("Loading checkpoint...")
104
+ start = time.monotonic_ns()
105
+ # TODO: for now, assume only the latest checkpoint is used
106
+ self.__load_model(os.path.join(args.acoustics_dir, "checkpoint"))
107
+ end = time.monotonic_ns()
108
+ self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
109
+
110
+ self.model.eval()
111
+ self.accelerator.wait_for_everyone()
112
+
113
+ ### Abstract methods ###
114
+ @abstractmethod
115
+ def _build_test_dataset(self):
116
+ pass
117
+
118
+ @abstractmethod
119
+ def _build_model(self):
120
+ pass
121
+
122
+ @abstractmethod
123
+ @torch.inference_mode()
124
+ def _inference_each_batch(self, batch_data):
125
+ pass
126
+
127
+ ### Abstract methods end ###
128
+
129
+ @torch.inference_mode()
130
+ def inference(self):
131
+ for i, batch in enumerate(self.test_dataloader):
132
+ y_pred = self._inference_each_batch(batch).cpu()
133
+ mel_min, mel_max = self.test_dataset.target_mel_extrema
134
+ y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min
135
+ y_ls = y_pred.chunk(self.test_batch_size)
136
+ tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
137
+ j = 0
138
+ for it, l in zip(y_ls, tgt_ls):
139
+ l = l.item()
140
+ it = it.squeeze(0)[:l]
141
+ uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
142
+ torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
143
+ j += 1
144
+
145
+ vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
146
+
147
+ res = synthesis(
148
+ cfg=vocoder_cfg,
149
+ vocoder_weight_file=vocoder_ckpt,
150
+ n_samples=None,
151
+ pred=[
152
+ torch.load(
153
+ os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"]))
154
+ ).numpy(force=True)
155
+ for i in self.test_dataset.metadata
156
+ ],
157
+ )
158
+
159
+ output_audio_files = []
160
+ for it, wav in zip(self.test_dataset.metadata, res):
161
+ uid = it["Uid"]
162
+ file = os.path.join(self.args.output_dir, f"{uid}.wav")
163
+ output_audio_files.append(file)
164
+
165
+ wav = wav.numpy(force=True)
166
+ save_audio(
167
+ file,
168
+ wav,
169
+ self.cfg.preprocess.sample_rate,
170
+ add_silence=False,
171
+ turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
172
+ )
173
+ os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))
174
+
175
+ return sorted(output_audio_files)
176
+
177
+ # TODO: LEGACY CODE
178
+ def _build_dataloader(self):
179
+ datasets, collate = self._build_test_dataset()
180
+ self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
181
+ self.test_collate = collate(self.cfg)
182
+ self.test_batch_size = min(
183
+ self.cfg.train.batch_size, len(self.test_dataset.metadata)
184
+ )
185
+ test_dataloader = DataLoader(
186
+ self.test_dataset,
187
+ collate_fn=self.test_collate,
188
+ num_workers=1,
189
+ batch_size=self.test_batch_size,
190
+ shuffle=False,
191
+ )
192
+ return test_dataloader
193
+
194
+ def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None):
195
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
196
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
197
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
198
+ method after** ``accelerator.prepare()``.
199
+ """
200
+ if checkpoint_path is None:
201
+ ls = []
202
+ for i in Path(checkpoint_dir).iterdir():
203
+ if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)):
204
+ ls.append(i)
205
+ ls.sort(
206
+ key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True
207
+ )
208
+ checkpoint_path = ls[0]
209
+ else:
210
+ checkpoint_path = Path(checkpoint_path)
211
+ self.accelerator.load_state(str(checkpoint_path))
212
+ # set epoch and step
213
+ self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1])
214
+ self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1])
215
+ return str(checkpoint_path)
216
+
217
+ @staticmethod
218
+ def _set_random_seed(seed):
219
+ r"""Set random seed for all possible random modules."""
220
+ random.seed(seed)
221
+ np.random.seed(seed)
222
+ torch.random.manual_seed(seed)
223
+
224
+ @staticmethod
225
+ def _parse_vocoder(vocoder_dir):
226
+ r"""Parse vocoder config"""
227
+ vocoder_dir = os.path.abspath(vocoder_dir)
228
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
229
+ ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
230
+ ckpt_path = str(ckpt_list[0])
231
+ vocoder_cfg = load_config(
232
+ os.path.join(vocoder_dir, "args.json"), lowercase=True
233
+ )
234
+ return vocoder_cfg, ckpt_path
235
+
236
+ @staticmethod
237
+ def __count_parameters(model):
238
+ return sum(p.numel() for p in model.parameters())
239
+
240
+ def __dump_cfg(self, path):
241
+ os.makedirs(os.path.dirname(path), exist_ok=True)
242
+ json5.dump(
243
+ self.cfg,
244
+ open(path, "w"),
245
+ indent=4,
246
+ sort_keys=True,
247
+ ensure_ascii=False,
248
+ quote_keys=True,
249
+ )
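Review note: __load_model() assumes checkpoints are directories (presumably written by accelerator.save_state() in the trainer) named "epoch-E_step-S_loss-L"; it sorts them by the epoch field and re-derives self.epoch / self.step from the stem. A hedged sketch of that parsing (the path is illustrative):

from pathlib import Path

stem = Path("ckpts/svc/epoch-0012_step-0034000_loss-0.2718").stem
epoch = int(stem.split("_")[-3].split("-")[-1])  # -> 12, same expression as __load_model
step = int(stem.split("_")[-2].split("-")[-1])   # -> 34000
print(epoch, step)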
models/base/new_trainer.py ADDED
@@ -0,0 +1,722 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ import random
9
+ import shutil
10
+ import time
11
+ from abc import abstractmethod
12
+ from pathlib import Path
13
+
14
+ import accelerate
15
+ import json5
16
+ import numpy as np
17
+ import torch
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import ProjectConfiguration
20
+ from torch.utils.data import ConcatDataset, DataLoader
21
+ from tqdm import tqdm
22
+
23
+ from models.base.base_sampler import build_samplers
24
+ from optimizer.optimizers import NoamLR
25
+
26
+
27
+ class BaseTrainer(object):
28
+ r"""The base trainer for all tasks. Any trainer should inherit from this class."""
29
+
30
+ def __init__(self, args=None, cfg=None):
31
+ super().__init__()
32
+
33
+ self.args = args
34
+ self.cfg = cfg
35
+
36
+ cfg.exp_name = args.exp_name
37
+
38
+ # init with accelerate
39
+ self._init_accelerator()
40
+ self.accelerator.wait_for_everyone()
41
+
42
+ # Use accelerate logger for distributed training
43
+ with self.accelerator.main_process_first():
44
+ self.logger = get_logger(args.exp_name, log_level=args.log_level)
45
+
46
+ # Log some info
47
+ self.logger.info("=" * 56)
48
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
49
+ self.logger.info("=" * 56)
50
+ self.logger.info("\n")
51
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
52
+ self.logger.info(f"Experiment name: {args.exp_name}")
53
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
54
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
55
+ if self.accelerator.is_main_process:
56
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
57
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
58
+
59
+ # init counts
60
+ self.batch_count: int = 0
61
+ self.step: int = 0
62
+ self.epoch: int = 0
63
+ self.max_epoch = (
64
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
65
+ )
66
+ self.logger.info(
67
+ "Max epoch: {}".format(
68
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
69
+ )
70
+ )
71
+
72
+ # Check values
73
+ if self.accelerator.is_main_process:
74
+ self.__check_basic_configs()
75
+ # Set runtime configs
76
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
77
+ self.checkpoints_path = [
78
+ [] for _ in range(len(self.save_checkpoint_stride))
79
+ ]
80
+ self.keep_last = [
81
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
82
+ ]
83
+ self.run_eval = self.cfg.train.run_eval
84
+
85
+ # set random seed
86
+ with self.accelerator.main_process_first():
87
+ start = time.monotonic_ns()
88
+ self._set_random_seed(self.cfg.train.random_seed)
89
+ end = time.monotonic_ns()
90
+ self.logger.debug(
91
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
92
+ )
93
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
94
+
95
+ # setup data_loader
96
+ with self.accelerator.main_process_first():
97
+ self.logger.info("Building dataset...")
98
+ start = time.monotonic_ns()
99
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
100
+ end = time.monotonic_ns()
101
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
102
+
103
+ # setup model
104
+ with self.accelerator.main_process_first():
105
+ self.logger.info("Building model...")
106
+ start = time.monotonic_ns()
107
+ self.model = self._build_model()
108
+ end = time.monotonic_ns()
109
+ self.logger.debug(self.model)
110
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
111
+ self.logger.info(
112
+ f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
113
+ )
114
+ # optimizer & scheduler
115
+ with self.accelerator.main_process_first():
116
+ self.logger.info("Building optimizer and scheduler...")
117
+ start = time.monotonic_ns()
118
+ self.optimizer = self.__build_optimizer()
119
+ self.scheduler = self.__build_scheduler()
120
+ end = time.monotonic_ns()
121
+ self.logger.info(
122
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
123
+ )
124
+
125
+ # accelerate prepare
126
+ self.logger.info("Initializing accelerate...")
127
+ start = time.monotonic_ns()
128
+ (
129
+ self.train_dataloader,
130
+ self.valid_dataloader,
131
+ self.model,
132
+ self.optimizer,
133
+ self.scheduler,
134
+ ) = self.accelerator.prepare(
135
+ self.train_dataloader,
136
+ self.valid_dataloader,
137
+ self.model,
138
+ self.optimizer,
139
+ self.scheduler,
140
+ )
141
+ end = time.monotonic_ns()
142
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
143
+
144
+ # create criterion
145
+ with self.accelerator.main_process_first():
146
+ self.logger.info("Building criterion...")
147
+ start = time.monotonic_ns()
148
+ self.criterion = self._build_criterion()
149
+ end = time.monotonic_ns()
150
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
151
+
152
+ # Resume or Finetune
153
+ with self.accelerator.main_process_first():
154
+ if args.resume:
155
+ ## Automatically resume according to the current experimental name
156
+ self.logger.info("Resuming from {}...".format(self.checkpoint_dir))
157
+ start = time.monotonic_ns()
158
+ ckpt_path = self.__load_model(
159
+ checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
160
+ )
161
+ end = time.monotonic_ns()
162
+ self.logger.info(
163
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
164
+ )
165
+ self.checkpoints_path = json.load(
166
+ open(os.path.join(ckpt_path, "ckpts.json"), "r")
167
+ )
168
+ elif args.resume_from_ckpt_path and args.resume_from_ckpt_path != "":
169
+ ## Resume from the given checkpoint path
170
+ if not os.path.exists(args.resume_from_ckpt_path):
171
+ raise ValueError(
172
+ "[Error] The resumed checkpoint path {} does not exist.".format(
173
+ args.resume_from_ckpt_path
174
+ )
175
+ )
176
+
177
+ self.logger.info(
178
+ "Resuming from {}...".format(args.resume_from_ckpt_path)
179
+ )
180
+ start = time.monotonic_ns()
181
+ ckpt_path = self.__load_model(
182
+ checkpoint_path=args.resume_from_ckpt_path,
183
+ resume_type=args.resume_type,
184
+ )
185
+ end = time.monotonic_ns()
186
+ self.logger.info(
187
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
188
+ )
189
+
190
+ # save config file path
191
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
192
+
193
+ ### Following are abstract methods that should be implemented in child classes ###
194
+ @abstractmethod
195
+ def _build_dataset(self):
196
+ r"""Build dataset for model training/validating/evaluating."""
197
+ pass
198
+
199
+ @staticmethod
200
+ @abstractmethod
201
+ def _build_criterion():
202
+ r"""Build criterion function for model loss calculation."""
203
+ pass
204
+
205
+ @abstractmethod
206
+ def _build_model(self):
207
+ r"""Build model for training/validating/evaluating."""
208
+ pass
209
+
210
+ @abstractmethod
211
+ def _forward_step(self, batch):
212
+ r"""One forward step of the neural network. This abstract method aims to
213
+ unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation.
214
+ However, for special cases that use a different forward step pattern for
215
+ training and validating, you could just override this method with ``pass`` and
216
+ implement ``_train_step`` and ``_valid_step`` separately.
217
+ """
218
+ pass
219
+
220
+ @abstractmethod
221
+ def _save_auxiliary_states(self):
222
+ r"""To save some auxiliary states when saving model's ckpt"""
223
+ pass
224
+
225
+ ### Abstract methods end ###
226
+
227
+ ### THIS IS MAIN ENTRY ###
228
+ def train_loop(self):
229
+ r"""Training loop. The public entry of training process."""
230
+ # Wait for everyone to prepare before we move on
231
+ self.accelerator.wait_for_everyone()
232
+ # dump config file
233
+ if self.accelerator.is_main_process:
234
+ self.__dump_cfg(self.config_save_path)
235
+ self.model.train()
236
+ self.optimizer.zero_grad()
237
+ # Wait to ensure good to go
238
+ self.accelerator.wait_for_everyone()
239
+ while self.epoch < self.max_epoch:
240
+ self.logger.info("\n")
241
+ self.logger.info("-" * 32)
242
+ self.logger.info("Epoch {}: ".format(self.epoch))
243
+
244
+ ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
245
+ ### It's inconvenient for models with multiple losses
246
+ # Do training & validating epoch
247
+ train_loss = self._train_epoch()
248
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
249
+ valid_loss = self._valid_epoch()
250
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
251
+ self.accelerator.log(
252
+ {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
253
+ step=self.epoch,
254
+ )
255
+
256
+ self.accelerator.wait_for_everyone()
257
+ # TODO: what is scheduler?
258
+ self.scheduler.step(valid_loss)  # FIXME: is stepping once per epoch on valid_loss correct?
259
+
260
+ # Check if hit save_checkpoint_stride and run_eval
261
+ run_eval = False
262
+ if self.accelerator.is_main_process:
263
+ save_checkpoint = False
264
+ hit_dix = []
265
+ for i, num in enumerate(self.save_checkpoint_stride):
266
+ if self.epoch % num == 0:
267
+ save_checkpoint = True
268
+ hit_dix.append(i)
269
+ run_eval |= self.run_eval[i]
270
+
271
+ self.accelerator.wait_for_everyone()
272
+ if self.accelerator.is_main_process and save_checkpoint:
273
+ path = os.path.join(
274
+ self.checkpoint_dir,
275
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
276
+ self.epoch, self.step, train_loss
277
+ ),
278
+ )
279
+ self.tmp_checkpoint_save_path = path
280
+ self.accelerator.save_state(path)
281
+ print(f"save checkpoint in {path}")
282
+ json.dump(
283
+ self.checkpoints_path,
284
+ open(os.path.join(path, "ckpts.json"), "w"),
285
+ ensure_ascii=False,
286
+ indent=4,
287
+ )
288
+ self._save_auxiliary_states()
289
+
290
+ # Remove old checkpoints
291
+ to_remove = []
292
+ for idx in hit_dix:
293
+ self.checkpoints_path[idx].append(path)
294
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
295
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
296
+
297
+ # Search conflicts
298
+ total = set()
299
+ for i in self.checkpoints_path:
300
+ total |= set(i)
301
+ do_remove = set()
302
+ for idx, path in to_remove[::-1]:
303
+ if path in total:
304
+ self.checkpoints_path[idx].insert(0, path)
305
+ else:
306
+ do_remove.add(path)
307
+
308
+ # Remove old checkpoints
309
+ for path in do_remove:
310
+ shutil.rmtree(path, ignore_errors=True)
311
+ self.logger.debug(f"Remove old checkpoint: {path}")
312
+
313
+ self.accelerator.wait_for_everyone()
314
+ if run_eval:
315
+ # TODO: run evaluation
316
+ pass
317
+
318
+ # Update info for each epoch
319
+ self.epoch += 1
320
+
321
+ # Finish training and save final checkpoint
322
+ self.accelerator.wait_for_everyone()
323
+ if self.accelerator.is_main_process:
324
+ self.accelerator.save_state(
325
+ os.path.join(
326
+ self.checkpoint_dir,
327
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
328
+ self.epoch, self.step, valid_loss
329
+ ),
330
+ )
331
+ )
332
+ self._save_auxiliary_states()
333
+
334
+ self.accelerator.end_training()
335
+
336
+ ### Following are methods that can be used directly in child classes ###
337
+ def _train_epoch(self):
338
+ r"""Training epoch. Should return average loss of a batch (sample) over
339
+ one epoch. See ``train_loop`` for usage.
340
+ """
341
+ self.model.train()
342
+ epoch_sum_loss: float = 0.0
343
+ epoch_step: int = 0
344
+ for batch in tqdm(
345
+ self.train_dataloader,
346
+ desc=f"Training Epoch {self.epoch}",
347
+ unit="batch",
348
+ colour="GREEN",
349
+ leave=False,
350
+ dynamic_ncols=True,
351
+ smoothing=0.04,
352
+ disable=not self.accelerator.is_main_process,
353
+ ):
354
+ # Do training step and BP
355
+ with self.accelerator.accumulate(self.model):
356
+ loss = self._train_step(batch)
357
+ self.accelerator.backward(loss)
358
+ self.optimizer.step()
359
+ self.optimizer.zero_grad()
360
+ self.batch_count += 1
361
+
362
+ # Update info for each step
363
+ # TODO: step means BP counts or batch counts?
364
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
365
+ epoch_sum_loss += loss
366
+ self.accelerator.log(
367
+ {
368
+ "Step/Train Loss": loss,
369
+ "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
370
+ },
371
+ step=self.step,
372
+ )
373
+ self.step += 1
374
+ epoch_step += 1
375
+
376
+ self.accelerator.wait_for_everyone()
377
+ return (
378
+ epoch_sum_loss
379
+ / len(self.train_dataloader)
380
+ * self.cfg.train.gradient_accumulation_step
381
+ )
382
+
383
+ @torch.inference_mode()
384
+ def _valid_epoch(self):
385
+ r"""Validation epoch. Should return average loss of a batch (sample) over
386
+ one epoch. See ``train_loop`` for usage.
387
+ """
388
+ self.model.eval()
389
+ epoch_sum_loss = 0.0
390
+ for batch in tqdm(
391
+ self.valid_dataloader,
392
+ desc=f"Validating Epoch {self.epoch}",
393
+ unit="batch",
394
+ colour="GREEN",
395
+ leave=False,
396
+ dynamic_ncols=True,
397
+ smoothing=0.04,
398
+ disable=not self.accelerator.is_main_process,
399
+ ):
400
+ batch_loss = self._valid_step(batch)
401
+ epoch_sum_loss += batch_loss.item()
402
+
403
+ self.accelerator.wait_for_everyone()
404
+ return epoch_sum_loss / len(self.valid_dataloader)
405
+
406
+ def _train_step(self, batch):
407
+ r"""Training forward step. Should return average loss of a sample over
408
+ one batch. Invoking ``_forward_step`` is recommended except for special cases.
409
+ See ``_train_epoch`` for usage.
410
+ """
411
+ return self._forward_step(batch)
412
+
413
+ @torch.inference_mode()
414
+ def _valid_step(self, batch):
415
+ r"""Validation forward step. Should return average loss of a sample over
416
+ one batch. Invoking ``_forward_step`` is recommended except for special cases.
417
+ See ``_valid_epoch`` for usage.
418
+ """
419
+ return self._forward_step(batch)
420
+
421
+ def __load_model(
422
+ self,
423
+ checkpoint_dir: str = None,
424
+ checkpoint_path: str = None,
425
+ resume_type: str = "",
426
+ ):
427
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
428
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
429
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
430
+ method after** ``accelerator.prepare()``.
431
+ """
432
+ if checkpoint_path is None:
433
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
434
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
435
+ checkpoint_path = ls[0]
436
+ self.logger.info("Resume from {}...".format(checkpoint_path))
437
+
438
+ if resume_type in ["resume", ""]:
439
+ # Load all the things, including model weights, optimizer, scheduler, and random states.
440
+ self.accelerator.load_state(input_dir=checkpoint_path)
441
+
442
+ # set epoch and step
443
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
444
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
445
+
446
+ elif resume_type == "finetune":
447
+ # Load only the model weights
448
+ accelerate.load_checkpoint_and_dispatch(
449
+ self.accelerator.unwrap_model(self.model),
450
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
451
+ )
452
+ self.logger.info("Load model weights for finetune...")
453
+
454
+ else:
455
+ raise ValueError("Resume_type must be `resume` or `finetune`.")
456
+
457
+ return checkpoint_path
458
+
459
+ # TODO: LEGACY CODE
460
+ def _build_dataloader(self):
461
+ Dataset, Collator = self._build_dataset()
462
+
463
+ # build dataset instance for each dataset and combine them by ConcatDataset
464
+ datasets_list = []
465
+ for dataset in self.cfg.dataset:
466
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
467
+ datasets_list.append(subdataset)
468
+ train_dataset = ConcatDataset(datasets_list)
469
+ train_collate = Collator(self.cfg)
470
+ _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
471
+ self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
472
+ self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
473
+ # TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
474
+ train_loader = DataLoader(
475
+ train_dataset,
476
+ collate_fn=train_collate,
477
+ batch_sampler=batch_sampler,
478
+ num_workers=self.cfg.train.dataloader.num_worker,
479
+ pin_memory=self.cfg.train.dataloader.pin_memory,
480
+ )
481
+
482
+ # Build valid dataloader
483
+ datasets_list = []
484
+ for dataset in self.cfg.dataset:
485
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
486
+ datasets_list.append(subdataset)
487
+ valid_dataset = ConcatDataset(datasets_list)
488
+ valid_collate = Collator(self.cfg)
489
+ _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
490
+ self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
491
+ self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
492
+ valid_loader = DataLoader(
493
+ valid_dataset,
494
+ collate_fn=valid_collate,
495
+ batch_sampler=batch_sampler,
496
+ num_workers=self.cfg.train.dataloader.num_worker,
497
+ pin_memory=self.cfg.train.dataloader.pin_memory,
498
+ )
499
+ return train_loader, valid_loader
500
+
501
+ @staticmethod
502
+ def _set_random_seed(seed):
503
+ r"""Set random seed for all possible random modules."""
504
+ random.seed(seed)
505
+ np.random.seed(seed)
506
+ torch.random.manual_seed(seed)
507
+
508
+ def _check_nan(self, loss, y_pred, y_gt):
509
+ if torch.any(torch.isnan(loss)):
510
+ self.logger.fatal("Fatal Error: Training is aborted since the loss contains NaN!")
511
+ self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
512
+ if torch.any(torch.isnan(y_pred)):
513
+ self.logger.error(
514
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
515
+ )
516
+ else:
517
+ self.logger.debug(
518
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
519
+ )
520
+ if torch.any(torch.isnan(y_gt)):
521
+ self.logger.error(
522
+ f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True
523
+ )
524
+ else:
525
+ self.logger.debug(
526
+ f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True
527
+ )
528
+ if torch.any(torch.isnan(y_pred)):
529
+ self.logger.error(f"y_pred: {y_pred}", in_order=True)
530
+ else:
531
+ self.logger.debug(f"y_pred: {y_pred}", in_order=True)
532
+ if torch.any(torch.isnan(y_gt)):
533
+ self.logger.error(f"y_gt: {y_gt}", in_order=True)
534
+ else:
535
+ self.logger.debug(f"y_gt: {y_gt}", in_order=True)
536
+
537
+ # TODO: still OK to save tracking?
538
+ self.accelerator.end_training()
539
+ raise RuntimeError("Loss has NaN! See log for more info.")
540
+
541
+ ### Protected methods end ###
542
+
543
+ ## Following are private methods ##
544
+ ## !!! These are inconvenient for GAN-based model training. It'd be better to move these to svc_trainer.py if needed.
545
+ def __build_optimizer(self):
546
+ r"""Build optimizer for model."""
547
+ # Make case-insensitive matching
548
+ if self.cfg.train.optimizer.lower() == "adadelta":
549
+ optimizer = torch.optim.Adadelta(
550
+ self.model.parameters(), **self.cfg.train.adadelta
551
+ )
552
+ self.logger.info("Using Adadelta optimizer.")
553
+ elif self.cfg.train.optimizer.lower() == "adagrad":
554
+ optimizer = torch.optim.Adagrad(
555
+ self.model.parameters(), **self.cfg.train.adagrad
556
+ )
557
+ self.logger.info("Using Adagrad optimizer.")
558
+ elif self.cfg.train.optimizer.lower() == "adam":
559
+ optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
560
+ self.logger.info("Using Adam optimizer.")
561
+ elif self.cfg.train.optimizer.lower() == "adamw":
562
+ optimizer = torch.optim.AdamW(
563
+ self.model.parameters(), **self.cfg.train.adamw
564
+ )
565
+ elif self.cfg.train.optimizer.lower() == "sparseadam":
566
+ optimizer = torch.optim.SparseAdam(
567
+ self.model.parameters(), **self.cfg.train.sparseadam
568
+ )
569
+ elif self.cfg.train.optimizer.lower() == "adamax":
570
+ optimizer = torch.optim.Adamax(
571
+ self.model.parameters(), **self.cfg.train.adamax
572
+ )
573
+ elif self.cfg.train.optimizer.lower() == "asgd":
574
+ optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
575
+ elif self.cfg.train.optimizer.lower() == "lbfgs":
576
+ optimizer = torch.optim.LBFGS(
577
+ self.model.parameters(), **self.cfg.train.lbfgs
578
+ )
579
+ elif self.cfg.train.optimizer.lower() == "nadam":
580
+ optimizer = torch.optim.NAdam(
581
+ self.model.parameters(), **self.cfg.train.nadam
582
+ )
583
+ elif self.cfg.train.optimizer.lower() == "radam":
584
+ optimizer = torch.optim.RAdam(
585
+ self.model.parameters(), **self.cfg.train.radam
586
+ )
587
+ elif self.cfg.train.optimizer.lower() == "rmsprop":
588
+ optimizer = torch.optim.RMSprop(
589
+ self.model.parameters(), **self.cfg.train.rmsprop
590
+ )
591
+ elif self.cfg.train.optimizer.lower() == "rprop":
592
+ optimizer = torch.optim.Rprop(
593
+ self.model.parameters(), **self.cfg.train.rprop
594
+ )
595
+ elif self.cfg.train.optimizer.lower() == "sgd":
596
+ optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
597
+ else:
598
+ raise NotImplementedError(
599
+ f"Optimizer {self.cfg.train.optimizer} not supported yet!"
600
+ )
601
+ return optimizer
602
+
603
+ def __build_scheduler(self):
604
+ r"""Build scheduler for optimizer."""
605
+ # Make case-insensitive matching
606
+ if self.cfg.train.scheduler.lower() == "lambdalr":
607
+ scheduler = torch.optim.lr_scheduler.LambdaLR(
608
+ self.optimizer, **self.cfg.train.lambdalr
609
+ )
610
+ elif self.cfg.train.scheduler.lower() == "multiplicativelr":
611
+ scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
612
+ self.optimizer, **self.cfg.train.multiplicativelr
613
+ )
614
+ elif self.cfg.train.scheduler.lower() == "steplr":
615
+ scheduler = torch.optim.lr_scheduler.StepLR(
616
+ self.optimizer, **self.cfg.train.steplr
617
+ )
618
+ elif self.cfg.train.scheduler.lower() == "multisteplr":
619
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(
620
+ self.optimizer, **self.cfg.train.multisteplr
621
+ )
622
+ elif self.cfg.train.scheduler.lower() == "constantlr":
623
+ scheduler = torch.optim.lr_scheduler.ConstantLR(
624
+ self.optimizer, **self.cfg.train.constantlr
625
+ )
626
+ elif self.cfg.train.scheduler.lower() == "linearlr":
627
+ scheduler = torch.optim.lr_scheduler.LinearLR(
628
+ self.optimizer, **self.cfg.train.linearlr
629
+ )
630
+ elif self.cfg.train.scheduler.lower() == "exponentiallr":
631
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(
632
+ self.optimizer, **self.cfg.train.exponentiallr
633
+ )
634
+ elif self.cfg.train.scheduler.lower() == "polynomiallr":
635
+ scheduler = torch.optim.lr_scheduler.PolynomialLR(
636
+ self.optimizer, **self.cfg.train.polynomiallr
637
+ )
638
+ elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
639
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
640
+ self.optimizer, **self.cfg.train.cosineannealinglr
641
+ )
642
+ elif self.cfg.train.scheduler.lower() == "sequentiallr":
643
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
644
+ self.optimizer, **self.cfg.train.sequentiallr
645
+ )
646
+ elif self.cfg.train.scheduler.lower() == "reducelronplateau":
647
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
648
+ self.optimizer, **self.cfg.train.reducelronplateau
649
+ )
650
+ elif self.cfg.train.scheduler.lower() == "cycliclr":
651
+ scheduler = torch.optim.lr_scheduler.CyclicLR(
652
+ self.optimizer, **self.cfg.train.cycliclr
653
+ )
654
+ elif self.cfg.train.scheduler.lower() == "onecyclelr":
655
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
656
+ self.optimizer, **self.cfg.train.onecyclelr
657
+ )
658
+ elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
659
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
660
+ self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
661
+ )
662
+ elif self.cfg.train.scheduler.lower() == "noamlr":
663
+ scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
664
+ else:
665
+ raise NotImplementedError(
666
+ f"Scheduler {self.cfg.train.scheduler} not supported yet!"
667
+ )
668
+ return scheduler
669
+
670
+ def _init_accelerator(self):
671
+ self.exp_dir = os.path.join(
672
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
673
+ )
674
+ project_config = ProjectConfiguration(
675
+ project_dir=self.exp_dir,
676
+ logging_dir=os.path.join(self.exp_dir, "log"),
677
+ )
678
+ self.accelerator = accelerate.Accelerator(
679
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
680
+ log_with=self.cfg.train.tracker,
681
+ project_config=project_config,
682
+ )
683
+ if self.accelerator.is_main_process:
684
+ os.makedirs(project_config.project_dir, exist_ok=True)
685
+ os.makedirs(project_config.logging_dir, exist_ok=True)
686
+ with self.accelerator.main_process_first():
687
+ self.accelerator.init_trackers(self.args.exp_name)
688
+
689
+ def __check_basic_configs(self):
690
+ if self.cfg.train.gradient_accumulation_step <= 0:
691
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
692
+ self.logger.error(
693
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
694
+ )
695
+ self.accelerator.end_training()
696
+ raise ValueError(
697
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
698
+ )
699
+ # TODO: check other values
700
+
701
+ @staticmethod
702
+ def __count_parameters(model):
703
+ model_param = 0.0
704
+ if isinstance(model, dict):
705
+ for key, value in model.items():
706
+ model_param += sum(p.numel() for p in model[key].parameters())
707
+ else:
708
+ model_param = sum(p.numel() for p in model.parameters())
709
+ return model_param
710
+
711
+ def __dump_cfg(self, path):
712
+ os.makedirs(os.path.dirname(path), exist_ok=True)
713
+ json5.dump(
714
+ self.cfg,
715
+ open(path, "w"),
716
+ indent=4,
717
+ sort_keys=True,
718
+ ensure_ascii=False,
719
+ quote_keys=True,
720
+ )
721
+
722
+ ### Private methods end ###
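To make the abstract surface of BaseTrainer above concrete, the following is a minimal illustrative sketch (not part of this commit) of a subclass filling in the five abstract methods. `ToyDataset`, `ToyCollator`, and `ToyTrainer` are hypothetical names that follow the `Dataset(cfg, dataset_name, is_valid)` / `Collator(cfg)` protocol used by `_build_dataloader`; the real SVC implementation of the same pattern appears in models/svc/base/svc_trainer.py later in this diff.

import torch
import torch.nn as nn
from torch.utils.data import Dataset

from models.base.new_trainer import BaseTrainer


class ToyDataset(Dataset):
    """Hypothetical stand-in for a real feature dataset."""

    def __init__(self, cfg, dataset_name, is_valid=False):
        self.items = [torch.randn(8) for _ in range(16)]

    def __getitem__(self, index):
        return {"input": self.items[index], "target": torch.zeros(8)}

    def __len__(self):
        return len(self.items)


class ToyCollator:
    """Hypothetical collator; real ones pad variable-length features."""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        return {
            "input": torch.stack([b["input"] for b in batch]),
            "target": torch.stack([b["target"] for b in batch]),
        }


class ToyTrainer(BaseTrainer):
    def _build_dataset(self):
        # BaseTrainer instantiates Dataset(cfg, dataset_name, is_valid) for every
        # entry in cfg.dataset and concatenates them (see _build_dataloader above).
        return ToyDataset, ToyCollator

    @staticmethod
    def _build_criterion():
        return nn.MSELoss()

    def _build_model(self):
        return nn.Linear(8, 8)

    def _forward_step(self, batch):
        # Shared by _train_step and _valid_step; must return a scalar loss.
        return self.criterion(self.model(batch["input"]), batch["target"])

    def _save_auxiliary_states(self):
        # Nothing beyond the accelerate checkpoint needs to be saved here.
        pass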
models/svc/__init__.py ADDED
File without changes
models/svc/base/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .svc_inference import SVCInference
7
+ from .svc_trainer import SVCTrainer
models/svc/base/svc_dataset.py ADDED
@@ -0,0 +1,437 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import torch
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ import json
10
+ import os
11
+ import numpy as np
12
+ from utils.data_utils import *
13
+ from processors.acoustic_extractor import cal_normalized_mel, load_mel_extrema
14
+ from processors.content_extractor import (
15
+ ContentvecExtractor,
16
+ WhisperExtractor,
17
+ WenetExtractor,
18
+ )
19
+ from models.base.base_dataset import (
20
+ BaseCollator,
21
+ BaseDataset,
22
+ )
23
+ from models.base.new_dataset import BaseTestDataset
24
+
25
+ EPS = 1.0e-12
26
+
27
+
28
+ class SVCDataset(BaseDataset):
29
+ def __init__(self, cfg, dataset, is_valid=False):
30
+ BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid)
31
+
32
+ cfg = self.cfg
33
+
34
+ if cfg.model.condition_encoder.use_whisper:
35
+ self.whisper_aligner = WhisperExtractor(self.cfg)
36
+ self.utt2whisper_path = load_content_feature_path(
37
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
38
+ )
39
+
40
+ if cfg.model.condition_encoder.use_contentvec:
41
+ self.contentvec_aligner = ContentvecExtractor(self.cfg)
42
+ self.utt2contentVec_path = load_content_feature_path(
43
+ self.metadata,
44
+ cfg.preprocess.processed_dir,
45
+ cfg.preprocess.contentvec_dir,
46
+ )
47
+
48
+ if cfg.model.condition_encoder.use_mert:
49
+ self.utt2mert_path = load_content_feature_path(
50
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
51
+ )
52
+ if cfg.model.condition_encoder.use_wenet:
53
+ self.wenet_aligner = WenetExtractor(self.cfg)
54
+ self.utt2wenet_path = load_content_feature_path(
55
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
56
+ )
57
+
58
+ def __getitem__(self, index):
59
+ single_feature = BaseDataset.__getitem__(self, index)
60
+
61
+ utt_info = self.metadata[index]
62
+ dataset = utt_info["Dataset"]
63
+ uid = utt_info["Uid"]
64
+ utt = "{}_{}".format(dataset, uid)
65
+
66
+ if self.cfg.model.condition_encoder.use_whisper:
67
+ assert "target_len" in single_feature.keys()
68
+ aligned_whisper_feat = self.whisper_aligner.offline_align(
69
+ np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
70
+ )
71
+ single_feature["whisper_feat"] = aligned_whisper_feat
72
+
73
+ if self.cfg.model.condition_encoder.use_contentvec:
74
+ assert "target_len" in single_feature.keys()
75
+ aligned_contentvec = self.contentvec_aligner.offline_align(
76
+ np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
77
+ )
78
+ single_feature["contentvec_feat"] = aligned_contentvec
79
+
80
+ if self.cfg.model.condition_encoder.use_mert:
81
+ assert "target_len" in single_feature.keys()
82
+ aligned_mert_feat = align_content_feature_length(
83
+ np.load(self.utt2mert_path[utt]),
84
+ single_feature["target_len"],
85
+ source_hop=self.cfg.preprocess.mert_hop_size,
86
+ )
87
+ single_feature["mert_feat"] = aligned_mert_feat
88
+
89
+ if self.cfg.model.condition_encoder.use_wenet:
90
+ assert "target_len" in single_feature.keys()
91
+ aligned_wenet_feat = self.wenet_aligner.offline_align(
92
+ np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
93
+ )
94
+ single_feature["wenet_feat"] = aligned_wenet_feat
95
+
96
+ # print(single_feature.keys())
97
+ # for k, v in single_feature.items():
98
+ # if type(v) in [torch.Tensor, np.ndarray]:
99
+ # print(k, v.shape)
100
+ # else:
101
+ # print(k, v)
102
+ # exit()
103
+
104
+ return self.clip_if_too_long(single_feature)
105
+
106
+ def __len__(self):
107
+ return len(self.metadata)
108
+
109
+ def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
110
+ """
111
+ ending_ts: to avoid invalid whisper features for over 30s audios
112
+ 2812 = 30 * 24000 // 256
113
+ """
114
+ ts = max(feature_seq_len - max_seq_len, 0)
115
+ ts = min(ts, ending_ts - max_seq_len)
116
+
117
+ start = random.randint(0, ts)
118
+ end = start + max_seq_len
119
+ return start, end
120
+
121
+ def clip_if_too_long(self, sample, max_seq_len=512):
122
+ """
123
+ sample :
124
+ {
125
+ 'spk_id': (1,),
126
+ 'target_len': int
127
+ 'mel': (seq_len, dim),
128
+ 'frame_pitch': (seq_len,)
129
+ 'frame_energy': (seq_len,)
130
+ 'content_vector_feat': (seq_len, dim)
131
+ }
132
+ """
133
+
134
+ if sample["target_len"] <= max_seq_len:
135
+ return sample
136
+
137
+ start, end = self.random_select(sample["target_len"], max_seq_len)
138
+ sample["target_len"] = end - start
139
+
140
+ for k in sample.keys():
141
+ if k == "audio":
142
+ # audio should be clipped in hop_size scale
143
+ sample[k] = sample[k][
144
+ start
145
+ * self.cfg.preprocess.hop_size : end
146
+ * self.cfg.preprocess.hop_size
147
+ ]
148
+ elif k == "audio_len":
149
+ sample[k] = (end - start) * self.cfg.preprocess.hop_size
150
+ elif k not in ["spk_id", "target_len"]:
151
+ sample[k] = sample[k][start:end]
152
+
153
+ return sample
154
+
155
+
156
+ class SVCCollator(BaseCollator):
157
+ """Zero-pads model inputs and targets based on number of frames per step"""
158
+
159
+ def __init__(self, cfg):
160
+ BaseCollator.__init__(self, cfg)
161
+
162
+ def __call__(self, batch):
163
+ parsed_batch_features = BaseCollator.__call__(self, batch)
164
+ return parsed_batch_features
165
+
166
+
167
+ class SVCTestDataset(BaseTestDataset):
168
+ def __init__(self, args, cfg, infer_type):
169
+ BaseTestDataset.__init__(self, args, cfg, infer_type)
170
+ self.metadata = self.get_metadata()
171
+
172
+ target_singer = args.target_singer
173
+ self.cfg = cfg
174
+ self.trans_key = args.trans_key
175
+ assert type(target_singer) == str
176
+
177
+ self.target_singer = target_singer.split("_")[-1]
178
+ self.target_dataset = target_singer.replace(
179
+ "_{}".format(self.target_singer), ""
180
+ )
181
+ if cfg.preprocess.mel_min_max_norm:
182
+ self.target_mel_extrema = load_mel_extrema(
183
+ cfg.preprocess, self.target_dataset
184
+ )
185
+ self.target_mel_extrema = torch.as_tensor(
186
+ self.target_mel_extrema[0]
187
+ ), torch.as_tensor(self.target_mel_extrema[1])
188
+
189
+ ######### Load source acoustic features #########
190
+ if cfg.preprocess.use_spkid:
191
+ spk2id_path = os.path.join(args.acoustics_dir, cfg.preprocess.spk2id)
192
+ # utt2sp_path = os.path.join(self.data_root, cfg.preprocess.utt2spk)
193
+
194
+ with open(spk2id_path, "r") as f:
195
+ self.spk2id = json.load(f)
196
+ # print("self.spk2id", self.spk2id)
197
+
198
+ if cfg.preprocess.use_uv:
199
+ self.utt2uv_path = {
200
+ f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
201
+ cfg.preprocess.processed_dir,
202
+ utt_info["Dataset"],
203
+ cfg.preprocess.uv_dir,
204
+ utt_info["Uid"] + ".npy",
205
+ )
206
+ for utt_info in self.metadata
207
+ }
208
+
209
+ if cfg.preprocess.use_frame_pitch:
210
+ self.utt2frame_pitch_path = {
211
+ f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
212
+ cfg.preprocess.processed_dir,
213
+ utt_info["Dataset"],
214
+ cfg.preprocess.pitch_dir,
215
+ utt_info["Uid"] + ".npy",
216
+ )
217
+ for utt_info in self.metadata
218
+ }
219
+
220
+ # Target F0 median
221
+ target_f0_statistics_path = os.path.join(
222
+ cfg.preprocess.processed_dir,
223
+ self.target_dataset,
224
+ cfg.preprocess.pitch_dir,
225
+ "statistics.json",
226
+ )
227
+ self.target_pitch_median = json.load(open(target_f0_statistics_path, "r"))[
228
+ f"{self.target_dataset}_{self.target_singer}"
229
+ ]["voiced_positions"]["median"]
230
+
231
+ # Source F0 median (if infer from file)
232
+ if infer_type == "from_file":
233
+ source_audio_name = cfg.inference.source_audio_name
234
+ source_f0_statistics_path = os.path.join(
235
+ cfg.preprocess.processed_dir,
236
+ source_audio_name,
237
+ cfg.preprocess.pitch_dir,
238
+ "statistics.json",
239
+ )
240
+ self.source_pitch_median = json.load(
241
+ open(source_f0_statistics_path, "r")
242
+ )[f"{source_audio_name}_{source_audio_name}"]["voiced_positions"][
243
+ "median"
244
+ ]
245
+ else:
246
+ self.source_pitch_median = None
247
+
248
+ if cfg.preprocess.use_frame_energy:
249
+ self.utt2frame_energy_path = {
250
+ f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
251
+ cfg.preprocess.processed_dir,
252
+ utt_info["Dataset"],
253
+ cfg.preprocess.energy_dir,
254
+ utt_info["Uid"] + ".npy",
255
+ )
256
+ for utt_info in self.metadata
257
+ }
258
+
259
+ if cfg.preprocess.use_mel:
260
+ self.utt2mel_path = {
261
+ f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
262
+ cfg.preprocess.processed_dir,
263
+ utt_info["Dataset"],
264
+ cfg.preprocess.mel_dir,
265
+ utt_info["Uid"] + ".npy",
266
+ )
267
+ for utt_info in self.metadata
268
+ }
269
+
270
+ ######### Load source content features' path #########
271
+ if cfg.model.condition_encoder.use_whisper:
272
+ self.whisper_aligner = WhisperExtractor(cfg)
273
+ self.utt2whisper_path = load_content_feature_path(
274
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
275
+ )
276
+
277
+ if cfg.model.condition_encoder.use_contentvec:
278
+ self.contentvec_aligner = ContentvecExtractor(cfg)
279
+ self.utt2contentVec_path = load_content_feature_path(
280
+ self.metadata,
281
+ cfg.preprocess.processed_dir,
282
+ cfg.preprocess.contentvec_dir,
283
+ )
284
+
285
+ if cfg.model.condition_encoder.use_mert:
286
+ self.utt2mert_path = load_content_feature_path(
287
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
288
+ )
289
+ if cfg.model.condition_encoder.use_wenet:
290
+ self.wenet_aligner = WenetExtractor(cfg)
291
+ self.utt2wenet_path = load_content_feature_path(
292
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
293
+ )
294
+
295
+ def __getitem__(self, index):
296
+ single_feature = {}
297
+
298
+ utt_info = self.metadata[index]
299
+ dataset = utt_info["Dataset"]
300
+ uid = utt_info["Uid"]
301
+ utt = "{}_{}".format(dataset, uid)
302
+
303
+ source_dataset = self.metadata[index]["Dataset"]
304
+
305
+ if self.cfg.preprocess.use_spkid:
306
+ single_feature["spk_id"] = np.array(
307
+ [self.spk2id[f"{self.target_dataset}_{self.target_singer}"]],
308
+ dtype=np.int32,
309
+ )
310
+
311
+ ######### Get Acoustic Features Item #########
312
+ if self.cfg.preprocess.use_mel:
313
+ mel = np.load(self.utt2mel_path[utt])
314
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
315
+ if self.cfg.preprocess.use_min_max_norm_mel:
316
+ # mel norm
317
+ mel = cal_normalized_mel(mel, source_dataset, self.cfg.preprocess)
318
+
319
+ if "target_len" not in single_feature.keys():
320
+ single_feature["target_len"] = mel.shape[1]
321
+ single_feature["mel"] = mel.T # [T, n_mels]
322
+
323
+ if self.cfg.preprocess.use_frame_pitch:
324
+ frame_pitch_path = self.utt2frame_pitch_path[utt]
325
+ frame_pitch = np.load(frame_pitch_path)
326
+
327
+ if self.trans_key:
328
+ try:
329
+ self.trans_key = int(self.trans_key)
330
+ except:
331
+ pass
332
+ if type(self.trans_key) == int:
333
+ frame_pitch = transpose_key(frame_pitch, self.trans_key)
334
+ elif self.trans_key:
335
+ assert self.target_singer
336
+
337
+ frame_pitch = pitch_shift_to_target(
338
+ frame_pitch, self.target_pitch_median, self.source_pitch_median
339
+ )
340
+
341
+ if "target_len" not in single_feature.keys():
342
+ single_feature["target_len"] = len(frame_pitch)
343
+ aligned_frame_pitch = align_length(
344
+ frame_pitch, single_feature["target_len"]
345
+ )
346
+ single_feature["frame_pitch"] = aligned_frame_pitch
347
+
348
+ if self.cfg.preprocess.use_uv:
349
+ frame_uv_path = self.utt2uv_path[utt]
350
+ frame_uv = np.load(frame_uv_path)
351
+ aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
352
+ aligned_frame_uv = [
353
+ 0 if frame_uv else 1 for frame_uv in aligned_frame_uv
354
+ ]
355
+ aligned_frame_uv = np.array(aligned_frame_uv)
356
+ single_feature["frame_uv"] = aligned_frame_uv
357
+
358
+ if self.cfg.preprocess.use_frame_energy:
359
+ frame_energy_path = self.utt2frame_energy_path[utt]
360
+ frame_energy = np.load(frame_energy_path)
361
+ if "target_len" not in single_feature.keys():
362
+ single_feature["target_len"] = len(frame_energy)
363
+ aligned_frame_energy = align_length(
364
+ frame_energy, single_feature["target_len"]
365
+ )
366
+ single_feature["frame_energy"] = aligned_frame_energy
367
+
368
+ ######### Get Content Features Item #########
369
+ if self.cfg.model.condition_encoder.use_whisper:
370
+ assert "target_len" in single_feature.keys()
371
+ aligned_whisper_feat = self.whisper_aligner.offline_align(
372
+ np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
373
+ )
374
+ single_feature["whisper_feat"] = aligned_whisper_feat
375
+
376
+ if self.cfg.model.condition_encoder.use_contentvec:
377
+ assert "target_len" in single_feature.keys()
378
+ aligned_contentvec = self.contentvec_aligner.offline_align(
379
+ np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
380
+ )
381
+ single_feature["contentvec_feat"] = aligned_contentvec
382
+
383
+ if self.cfg.model.condition_encoder.use_mert:
384
+ assert "target_len" in single_feature.keys()
385
+ aligned_mert_feat = align_content_feature_length(
386
+ np.load(self.utt2mert_path[utt]),
387
+ single_feature["target_len"],
388
+ source_hop=self.cfg.preprocess.mert_hop_size,
389
+ )
390
+ single_feature["mert_feat"] = aligned_mert_feat
391
+
392
+ if self.cfg.model.condition_encoder.use_wenet:
393
+ assert "target_len" in single_feature.keys()
394
+ aligned_wenet_feat = self.wenet_aligner.offline_align(
395
+ np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
396
+ )
397
+ single_feature["wenet_feat"] = aligned_wenet_feat
398
+
399
+ return single_feature
400
+
401
+ def __len__(self):
402
+ return len(self.metadata)
403
+
404
+
405
+ class SVCTestCollator:
406
+ """Zero-pads model inputs and targets based on number of frames per step"""
407
+
408
+ def __init__(self, cfg):
409
+ self.cfg = cfg
410
+
411
+ def __call__(self, batch):
412
+ packed_batch_features = dict()
413
+
414
+ # mel: [b, T, n_mels]
415
+ # frame_pitch, frame_energy: [1, T]
416
+ # target_len: [1]
417
+ # spk_id: [b, 1]
418
+ # mask: [b, T, 1]
419
+
420
+ for key in batch[0].keys():
421
+ if key == "target_len":
422
+ packed_batch_features["target_len"] = torch.LongTensor(
423
+ [b["target_len"] for b in batch]
424
+ )
425
+ masks = [
426
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
427
+ ]
428
+ packed_batch_features["mask"] = pad_sequence(
429
+ masks, batch_first=True, padding_value=0
430
+ )
431
+ else:
432
+ values = [torch.from_numpy(b[key]) for b in batch]
433
+ packed_batch_features[key] = pad_sequence(
434
+ values, batch_first=True, padding_value=0
435
+ )
436
+
437
+ return packed_batch_features
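As a quick illustration (not part of the commit) of the zero-padding convention used by the collators above, the snippet below pads two variable-length mel items with toy shapes and builds the corresponding frame mask from `target_len`, exactly as `SVCTestCollator.__call__` does.

import torch
from torch.nn.utils.rnn import pad_sequence

batch = [
    {"mel": torch.randn(100, 80), "target_len": 100},
    {"mel": torch.randn(60, 80), "target_len": 60},
]

mels = pad_sequence([b["mel"] for b in batch], batch_first=True, padding_value=0)
target_len = torch.LongTensor([b["target_len"] for b in batch])
mask = pad_sequence(
    [torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch],
    batch_first=True,
    padding_value=0,
)

print(mels.shape)   # torch.Size([2, 100, 80]); the short item is zero-padded
print(target_len)   # tensor([100,  60])
print(mask.shape)   # torch.Size([2, 100, 1]); zeros mark padded frames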
models/svc/base/svc_inference.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from models.base.new_inference import BaseInference
7
+ from models.svc.base.svc_dataset import SVCTestCollator, SVCTestDataset
8
+
9
+
10
+ class SVCInference(BaseInference):
11
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
12
+ BaseInference.__init__(self, args, cfg, infer_type)
13
+
14
+ def _build_test_dataset(self):
15
+ return SVCTestDataset, SVCTestCollator
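For context only, a hedged sketch (not part of the commit) of how the (dataset, collator) pair returned by `_build_test_dataset` could be turned into a test DataLoader. The actual wiring lives in `BaseInference` in models/base/new_inference.py, which is outside this 50-file view, so the loader settings below (batch size 1, no shuffling) are assumptions.

from torch.utils.data import DataLoader

from models.svc.base.svc_dataset import SVCTestCollator, SVCTestDataset


def build_test_loader(args, cfg, infer_type="from_dataset"):
    # Mirrors the (Dataset, Collator) contract of _build_test_dataset above.
    dataset = SVCTestDataset(args, cfg, infer_type)
    return DataLoader(
        dataset,
        batch_size=1,      # assumed: inference is typically done utterance by utterance
        shuffle=False,
        collate_fn=SVCTestCollator(cfg),
    )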
models/svc/base/svc_trainer.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from models.base.new_trainer import BaseTrainer
13
+ from models.svc.base.svc_dataset import SVCCollator, SVCDataset
14
+
15
+
16
+ class SVCTrainer(BaseTrainer):
17
+ r"""The base trainer for all SVC models. It inherits from BaseTrainer and implements
18
+ ``_build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this
19
+ class, and implement ``_build_model``, ``_forward_step``.
20
+ """
21
+
22
+ def __init__(self, args=None, cfg=None):
23
+ self.args = args
24
+ self.cfg = cfg
25
+
26
+ self._init_accelerator()
27
+
28
+ # Only for SVC tasks
29
+ with self.accelerator.main_process_first():
30
+ self.singers = self._build_singer_lut()
31
+
32
+ # Super init
33
+ BaseTrainer.__init__(self, args, cfg)
34
+
35
+ # Only for SVC tasks
36
+ self.task_type = "SVC"
37
+ self.logger.info("Task type: {}".format(self.task_type))
38
+
39
+ ### Following are methods only for SVC tasks ###
40
+ # TODO: LEGACY CODE, NEED TO BE REFACTORED
41
+ def _build_dataset(self):
42
+ return SVCDataset, SVCCollator
43
+
44
+ @staticmethod
45
+ def _build_criterion():
46
+ criterion = nn.MSELoss(reduction="none")
47
+ return criterion
48
+
49
+ @staticmethod
50
+ def _compute_loss(criterion, y_pred, y_gt, loss_mask):
51
+ """
52
+ Args:
53
+ criterion: MSELoss(reduction='none')
54
+ y_pred, y_gt: (bs, seq_len, D)
55
+ loss_mask: (bs, seq_len, 1)
56
+ Returns:
57
+ loss: Tensor of shape []
58
+ """
59
+
60
+ # (bs, seq_len, D)
61
+ loss = criterion(y_pred, y_gt)
62
+ # expand loss_mask to (bs, seq_len, D)
63
+ loss_mask = loss_mask.repeat(1, 1, loss.shape[-1])
64
+
65
+ loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask)
66
+ return loss
67
+
68
+ def _save_auxiliary_states(self):
69
+ """
70
+ To save the singer's look-up table in the checkpoint saving path
71
+ """
72
+ with open(
73
+ os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id), "w"
74
+ ) as f:
75
+ json.dump(self.singers, f, indent=4, ensure_ascii=False)
76
+
77
+ def _build_singer_lut(self):
78
+ resumed_singer_path = None
79
+ if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "":
80
+ resumed_singer_path = os.path.join(
81
+ self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id
82
+ )
83
+ if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
84
+ resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
85
+
86
+ if resumed_singer_path:
87
+ with open(resumed_singer_path, "r") as f:
88
+ singers = json.load(f)
89
+ else:
90
+ singers = dict()
91
+
92
+ for dataset in self.cfg.dataset:
93
+ singer_lut_path = os.path.join(
94
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
95
+ )
96
+ with open(singer_lut_path, "r") as singer_lut_path:
97
+ singer_lut = json.load(singer_lut_path)
98
+ for singer in singer_lut.keys():
99
+ if singer not in singers:
100
+ singers[singer] = len(singers)
101
+
102
+ with open(
103
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
104
+ ) as singer_file:
105
+ json.dump(singers, singer_file, indent=4, ensure_ascii=False)
106
+ print(
107
+ "singers have been dumped to {}".format(
108
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
109
+ )
110
+ )
111
+ return singers
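A small self-contained check (not part of the commit) of the masked MSE in `SVCTrainer._compute_loss` above: with `reduction="none"`, the element-wise loss is averaged only over unmasked entries, so padded frames contribute nothing.

import torch
import torch.nn as nn

criterion = nn.MSELoss(reduction="none")
y_pred = torch.zeros(1, 4, 2)   # (bs, seq_len, D)
y_gt = torch.ones(1, 4, 2)
loss_mask = torch.tensor([[[1.0], [1.0], [1.0], [0.0]]])  # last frame is padding

loss = criterion(y_pred, y_gt)                      # (1, 4, 2), all ones
loss_mask = loss_mask.repeat(1, 1, loss.shape[-1])  # (1, 4, 2)
loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask)
print(loss)  # tensor(1.) -- averaged over the 6 unmasked entries only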
models/svc/comosvc/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
models/svc/comosvc/comosvc.py ADDED
@@ -0,0 +1,377 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Adapted from https://github.com/zhenye234/CoMoSpeech"""
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import copy
11
+ import numpy as np
12
+ import math
13
+ from tqdm.auto import tqdm
14
+
15
+ from utils.ssim import SSIM
16
+
17
+ from models.svc.transformer.conformer import Conformer, BaseModule
18
+ from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
19
+ from models.svc.comosvc.utils import slice_segments, rand_ids_segments
20
+
21
+
22
+ class Consistency(nn.Module):
23
+ def __init__(self, cfg, distill=False):
24
+ super().__init__()
25
+ self.cfg = cfg
26
+ # self.denoise_fn = GradLogPEstimator2d(96)
27
+ self.denoise_fn = DiffusionWrapper(self.cfg)
28
+ self.cfg = cfg.model.comosvc
29
+ self.teacher = not distill
30
+ self.P_mean = self.cfg.P_mean
31
+ self.P_std = self.cfg.P_std
32
+ self.sigma_data = self.cfg.sigma_data
33
+ self.sigma_min = self.cfg.sigma_min
34
+ self.sigma_max = self.cfg.sigma_max
35
+ self.rho = self.cfg.rho
36
+ self.N = self.cfg.n_timesteps
37
+ self.ssim_loss = SSIM()
38
+
39
+ # Time step discretization
40
+ step_indices = torch.arange(self.N)
41
+ # karras boundaries formula
42
+ t_steps = (
43
+ self.sigma_min ** (1 / self.rho)
44
+ + step_indices
45
+ / (self.N - 1)
46
+ * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
47
+ ) ** self.rho
48
+ self.t_steps = torch.cat(
49
+ [torch.zeros_like(t_steps[:1]), self.round_sigma(t_steps)]
50
+ )
51
+
52
+ def init_consistency_training(self):
53
+ self.denoise_fn_ema = copy.deepcopy(self.denoise_fn)
54
+ self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn)
55
+
56
+ def EDMPrecond(self, x, sigma, cond, denoise_fn, mask, spk=None):
57
+ """
58
+ karras diffusion reverse process
59
+
60
+ Args:
61
+ x: noisy mel-spectrogram [B x n_mel x L]
62
+ sigma: noise level [B x 1 x 1]
63
+ cond: output of conformer encoder [B x n_mel x L]
64
+ denoise_fn: denoiser neural network e.g. DilatedCNN
65
+ mask: mask of padded frames [B x n_mel x L]
66
+
67
+ Returns:
68
+ denoised mel-spectrogram [B x n_mel x L]
69
+ """
70
+ sigma = sigma.reshape(-1, 1, 1)
71
+
72
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
73
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
74
+ c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
75
+ c_noise = sigma.log() / 4
76
+
77
+ x_in = c_in * x
78
+ x_in = x_in.transpose(1, 2)
79
+ x = x.transpose(1, 2)
80
+ cond = cond.transpose(1, 2)
81
+ F_x = denoise_fn(x_in, c_noise.squeeze(), cond)
82
+ # F_x = denoise_fn((c_in * x), mask, cond, c_noise.flatten())
83
+ D_x = c_skip * x + c_out * (F_x)
84
+ D_x = D_x.transpose(1, 2)
85
+ return D_x
86
+
87
+ def EDMLoss(self, x_start, cond, mask):
88
+ """
89
+ compute loss for EDM model
90
+
91
+ Args:
92
+ x_start: ground truth mel-spectrogram [B x n_mel x L]
93
+ cond: output of conformer encoder [B x n_mel x L]
94
+ mask: mask of padded frames [B x n_mel x L]
95
+ """
96
+ rnd_normal = torch.randn([x_start.shape[0], 1, 1], device=x_start.device)
97
+ sigma = (rnd_normal * self.P_std + self.P_mean).exp()
98
+ weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
99
+
100
+ # follow Grad-TTS, start from Gaussian noise with mean cond and std I
101
+ noise = (torch.randn_like(x_start) + cond) * sigma
102
+ D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn, mask)
103
+ loss = weight * ((D_yn - x_start) ** 2)
104
+ loss = torch.sum(loss * mask) / torch.sum(mask)
105
+ return loss
106
+
107
+ def round_sigma(self, sigma):
108
+ return torch.as_tensor(sigma)
109
+
110
+ def edm_sampler(
111
+ self,
112
+ latents,
113
+ cond,
114
+ nonpadding,
115
+ num_steps=50,
116
+ sigma_min=0.002,
117
+ sigma_max=80,
118
+ rho=7,
119
+ S_churn=0,
120
+ S_min=0,
121
+ S_max=float("inf"),
122
+ S_noise=1,
123
+ # S_churn=40 ,S_min=0.05,S_max=50,S_noise=1.003,# S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
124
+ # S_churn=30 ,S_min=0.01,S_max=30,S_noise=1.007,
125
+ # S_churn=30 ,S_min=0.01,S_max=1,S_noise=1.007,
126
+ # S_churn=80 ,S_min=0.05,S_max=50,S_noise=1.003,
127
+ ):
128
+ """
129
+ karras diffusion sampler
130
+
131
+ Args:
132
+ latents: noisy mel-spectrogram [B x n_mel x L]
133
+ cond: output of conformer encoder [B x n_mel x L]
134
+ nonpadding: mask of padded frames [B x n_mel x L]
135
+ num_steps: number of steps for diffusion inference
136
+
137
+ Returns:
138
+ denoised mel-spectrogram [B x n_mel x L]
139
+ """
140
+ # Time step discretization.
141
+ step_indices = torch.arange(num_steps, device=latents.device)
142
+
143
+ num_steps = num_steps + 1
144
+ t_steps = (
145
+ sigma_max ** (1 / rho)
146
+ + step_indices
147
+ / (num_steps - 1)
148
+ * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
149
+ ) ** rho
150
+ t_steps = torch.cat([self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])
151
+
152
+ # Main sampling loop.
153
+ x_next = latents * t_steps[0]
154
+ # wrap in tqdm for progress bar
155
+ bar = tqdm(enumerate(zip(t_steps[:-1], t_steps[1:])))
156
+ for i, (t_cur, t_next) in bar:
157
+ x_cur = x_next
158
+ # Increase noise temporarily.
159
+ gamma = (
160
+ min(S_churn / num_steps, np.sqrt(2) - 1)
161
+ if S_min <= t_cur <= S_max
162
+ else 0
163
+ )
164
+ t_hat = self.round_sigma(t_cur + gamma * t_cur)
165
+ t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device)
166
+ t[:, 0, 0] = t_hat
167
+ t_hat = t
168
+ x_hat = x_cur + (
169
+ t_hat**2 - t_cur**2
170
+ ).sqrt() * S_noise * torch.randn_like(x_cur)
171
+ # Euler step.
172
+ denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn, nonpadding)
173
+ d_cur = (x_hat - denoised) / t_hat
174
+ x_next = x_hat + (t_next - t_hat) * d_cur
175
+
176
+ return x_next
177
+
178
+ def CTLoss_D(self, y, cond, mask):
179
+ """
180
+ compute loss for consistency distillation
181
+
182
+ Args:
183
+ y: ground truth mel-spectrogram [B x n_mel x L]
184
+ cond: output of conformer encoder [B x n_mel x L]
185
+ mask: mask of padded frames [B x n_mel x L]
186
+ """
187
+ with torch.no_grad():
188
+ mu = 0.95
189
+ for p, ema_p in zip(
190
+ self.denoise_fn.parameters(), self.denoise_fn_ema.parameters()
191
+ ):
192
+ ema_p.mul_(mu).add_(p, alpha=1 - mu)
193
+
194
+ n = torch.randint(1, self.N, (y.shape[0],))
195
+ z = torch.randn_like(y) + cond
196
+
197
+ tn_1 = self.t_steps[n + 1].reshape(-1, 1, 1).to(y.device)
198
+ f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn, mask)
199
+
200
+ with torch.no_grad():
201
+ tn = self.t_steps[n].reshape(-1, 1, 1).to(y.device)
202
+
203
+ # euler step
204
+ x_hat = y + tn_1 * z
205
+ denoised = self.EDMPrecond(
206
+ x_hat, tn_1, cond, self.denoise_fn_pretrained, mask
207
+ )
208
+ d_cur = (x_hat - denoised) / tn_1
209
+ y_tn = x_hat + (tn - tn_1) * d_cur
210
+
211
+ f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema, mask)
212
+
213
+ # loss = (f_theta - f_theta_ema.detach()) ** 2
214
+ # loss = torch.sum(loss * mask) / torch.sum(mask)
215
+ loss = self.ssim_loss(f_theta, f_theta_ema.detach())
216
+ loss = torch.sum(loss * mask) / torch.sum(mask)
217
+
218
+ return loss
219
+
220
+ def get_t_steps(self, N):
221
+ N = N + 1
222
+ step_indices = torch.arange(N) # , device=latents.device)
223
+ t_steps = (
224
+ self.sigma_min ** (1 / self.rho)
225
+ + step_indices
226
+ / (N - 1)
227
+ * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
228
+ ) ** self.rho
229
+
230
+ return t_steps.flip(0)
231
+
232
+ def CT_sampler(self, latents, cond, nonpadding, t_steps=1):
233
+ """
234
+ consistency distillation sampler
235
+
236
+ Args:
237
+ latents: noisy mel-spectrogram [B x n_mel x L]
238
+ cond: output of conformer encoder [B x n_mel x L]
239
+ nonpadding: mask of padded frames [B x n_mel x L]
240
+ t_steps: number of steps for diffusion inference
241
+
242
+ Returns:
243
+ denoised mel-spectrogram [B x n_mel x L]
244
+ """
245
+ # one-step
246
+ if t_steps == 1:
247
+ t_steps = [80]
248
+ # multi-step
249
+ else:
250
+ t_steps = self.get_t_steps(t_steps)
251
+
252
+ t_steps = torch.as_tensor(t_steps).to(latents.device)
253
+ latents = latents * t_steps[0]
254
+ _t = torch.zeros((latents.shape[0], 1, 1), device=latents.device)
255
+ _t[:, 0, 0] = t_steps
256
+ x = self.EDMPrecond(latents, _t, cond, self.denoise_fn_ema, nonpadding)
257
+
258
+ for t in t_steps[1:-1]:
259
+ z = torch.randn_like(x) + cond
260
+ x_tn = x + (t**2 - self.sigma_min**2).sqrt() * z
261
+ _t = torch.zeros((x.shape[0], 1, 1), device=x.device)
262
+ _t[:, 0, 0] = t
263
+ t = _t
264
+ print(t)
265
+ x = self.EDMPrecond(x_tn, t, cond, self.denoise_fn_ema, nonpadding)
266
+ return x
267
+
268
+ def forward(self, x, nonpadding, cond, t_steps=1, infer=False):
269
+ """
270
+ calculate loss or sample mel-spectrogram
271
+
272
+ Args:
273
+ x:
274
+ training: ground truth mel-spectrogram [B x n_mel x L]
275
+ inference: output of encoder [B x n_mel x L]
276
+ """
277
+ if self.teacher: # teacher model -- karras diffusion
278
+ if not infer:
279
+ loss = self.EDMLoss(x, cond, nonpadding)
280
+ return loss
281
+ else:
282
+ shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
283
+ x = torch.randn(shape, device=x.device) + cond
284
+ x = self.edm_sampler(x, cond, nonpadding, t_steps)
285
+
286
+ return x
287
+ else: # Consistency distillation
288
+ if not infer:
289
+ loss = self.CTLoss_D(x, cond, nonpadding)
290
+ return loss
291
+
292
+ else:
293
+ shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
294
+ x = torch.randn(shape, device=x.device) + cond
295
+ x = self.CT_sampler(x, cond, nonpadding, t_steps=1)
296
+
297
+ return x
298
+
299
+
300
+ class ComoSVC(BaseModule):
301
+ def __init__(self, cfg):
302
+ super().__init__()
303
+ self.cfg = cfg
304
+ self.cfg.model.comosvc.n_mel = self.cfg.preprocess.n_mel
305
+ self.distill = self.cfg.model.comosvc.distill
306
+ self.encoder = Conformer(self.cfg.model.comosvc)
307
+ self.decoder = Consistency(self.cfg, distill=self.distill)
308
+ self.ssim_loss = SSIM()
309
+
310
+ @torch.no_grad()
311
+ def forward(self, x_mask, x, n_timesteps, temperature=1.0):
312
+ """
313
+ Generates mel-spectrogram from pitch, content vector, energy. Returns:
314
+ 1. encoder outputs (from conformer)
315
+ 2. decoder outputs (from diffusion-based decoder)
316
+
317
+ Args:
318
+ x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
319
+ x : output of encoder framework. [B x L x d_condition]
320
+ n_timesteps : number of steps to use for reverse diffusion in decoder.
321
+ temperature : controls variance of terminal distribution.
322
+ """
323
+
324
+ # Get encoder_outputs `mu_x`
325
+ mu_x = self.encoder(x, x_mask)
326
+ encoder_outputs = mu_x
327
+
328
+ mu_x = mu_x.transpose(1, 2)
329
+ x_mask = x_mask.transpose(1, 2)
330
+
331
+ # Generate sample by performing reverse dynamics
332
+ decoder_outputs = self.decoder(
333
+ mu_x, x_mask, mu_x, t_steps=n_timesteps, infer=True
334
+ )
335
+ decoder_outputs = decoder_outputs.transpose(1, 2)
336
+ return encoder_outputs, decoder_outputs
337
+
338
+ def compute_loss(self, x_mask, x, mel, out_size=None, skip_diff=False):
339
+ """
340
+ Computes 2 losses:
341
+ 1. prior loss: loss between mel-spectrogram and encoder outputs.
342
+ 2. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
343
+
344
+ Args:
345
+ x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
346
+ x : output of encoder framework. [B x L x d_condition]
347
+ mel : ground truth mel-spectrogram. [B x L x n_mel]
348
+ """
349
+
350
+ mu_x = self.encoder(x, x_mask)
351
+ # prior loss
352
+ prior_loss = torch.sum(
353
+ 0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * x_mask
354
+ )
355
+ prior_loss = prior_loss / (torch.sum(x_mask) * self.cfg.model.comosvc.n_mel)
356
+ # ssim loss
357
+ ssim_loss = self.ssim_loss(mu_x, mel)
358
+ ssim_loss = torch.sum(ssim_loss * x_mask) / torch.sum(x_mask)
359
+
360
+ x_mask = x_mask.transpose(1, 2)
361
+ mu_x = mu_x.transpose(1, 2)
362
+ mel = mel.transpose(1, 2)
363
+ if not self.distill and skip_diff:
364
+ diff_loss = prior_loss.clone()
365
+ diff_loss.fill_(0)
366
+
367
+ # Cut a small segment of mel-spectrogram in order to increase batch size
368
+ else:
369
+ if self.distill:
370
+ mu_y = mu_x.detach()
371
+ else:
372
+ mu_y = mu_x
373
+ mask_y = x_mask
374
+
375
+ diff_loss = self.decoder(mel, mask_y, mu_y, infer=False)
376
+
377
+ return ssim_loss, prior_loss, diff_loss
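To make the noise schedule used by `Consistency` above easier to follow, here is an illustrative sketch (not part of the commit) of the Karras time-step discretization that `__init__` builds (and that `edm_sampler` rebuilds in descending order). `sigma_min=0.002`, `sigma_max=80`, and `rho=7` follow the defaults in the `edm_sampler` signature; `N=25` is an arbitrary choice for the example.

import torch

sigma_min, sigma_max, rho, N = 0.002, 80.0, 7, 25

step_indices = torch.arange(N)
t_steps = (
    sigma_min ** (1 / rho)
    + step_indices / (N - 1) * (sigma_max ** (1 / rho) - sigma_min ** (1 / rho))
) ** rho

# Ascending sigmas: densely spaced near sigma_min, coarse near sigma_max.
print(t_steps[0].item(), t_steps[-1].item())  # ~0.002 ... ~80.0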
models/svc/comosvc/comosvc_inference.py ADDED
@@ -0,0 +1,39 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from models.svc.base import SVCInference
9
+ from modules.encoder.condition_encoder import ConditionEncoder
10
+ from models.svc.comosvc.comosvc import ComoSVC
11
+
12
+
13
+ class ComoSVCInference(SVCInference):
14
+ def __init__(self, args, cfg, infer_type="from_dataset"):
15
+ SVCInference.__init__(self, args, cfg, infer_type)
16
+
17
+ def _build_model(self):
18
+ # TODO: sort out the config
19
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
20
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
21
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
22
+ self.acoustic_mapper = ComoSVC(self.cfg)
23
+ if self.cfg.model.comosvc.distill:
24
+ self.acoustic_mapper.decoder.init_consistency_training()
25
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
26
+ return model
27
+
28
+ def _inference_each_batch(self, batch_data):
29
+ device = self.accelerator.device
30
+ for k, v in batch_data.items():
31
+ batch_data[k] = v.to(device)
32
+
33
+ cond = self.condition_encoder(batch_data)
34
+ mask = batch_data["mask"]
35
+ encoder_pred, decoder_pred = self.acoustic_mapper(
36
+ mask, cond, self.cfg.inference.comosvc.inference_steps
37
+ )
38
+
39
+ return decoder_pred
models/svc/comosvc/comosvc_trainer.py ADDED
@@ -0,0 +1,295 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import os
8
+ import json5
9
+ from collections import OrderedDict
10
+ from tqdm import tqdm
11
+ import json
12
+ import shutil
13
+
14
+ from models.svc.base import SVCTrainer
15
+ from modules.encoder.condition_encoder import ConditionEncoder
16
+ from models.svc.comosvc.comosvc import ComoSVC
17
+
18
+
19
+ class ComoSVCTrainer(SVCTrainer):
20
+ r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
21
+ implements ``_build_model`` and ``_forward_step`` methods.
22
+ """
23
+
24
+ def __init__(self, args=None, cfg=None):
25
+ SVCTrainer.__init__(self, args, cfg)
26
+ self.distill = cfg.model.comosvc.distill
27
+ self.skip_diff = True
28
+ if self.distill: # and args.resume is None:
29
+ self.teacher_model_path = cfg.model.teacher_model_path
30
+ self.teacher_state_dict = self._load_teacher_state_dict()
31
+ self._load_teacher_model(self.teacher_state_dict)
32
+ self.acoustic_mapper.decoder.init_consistency_training()
33
+
34
+ ### Following are methods only for comoSVC models ###
35
+ def _load_teacher_state_dict(self):
36
+ self.checkpoint_file = self.teacher_model_path
37
+ print("Load teacher acoustic model from {}".format(self.checkpoint_file))
38
+ raw_state_dict = torch.load(self.checkpoint_file)  # optionally: map_location=self.device
39
+ return raw_state_dict
40
+
41
+ def _load_teacher_model(self, state_dict):
42
+ raw_dict = state_dict
43
+ clean_dict = OrderedDict()
44
+ for k, v in raw_dict.items():
45
+ if k.startswith("module."):
46
+ clean_dict[k[7:]] = v
47
+ else:
48
+ clean_dict[k] = v
49
+ self.model.load_state_dict(clean_dict)
50
+
51
+ def _build_model(self):
52
+ r"""Build the model for training. This function is called in ``__init__`` function."""
53
+
54
+ # TODO: sort out the config
55
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
56
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
57
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
58
+ self.acoustic_mapper = ComoSVC(self.cfg)
59
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
60
+ return model
61
+
62
+ def _forward_step(self, batch):
63
+ r"""Forward step for training and inference. This function is called
64
+ in ``_train_step`` & ``_test_step`` function.
65
+ """
66
+ loss = {}
67
+ mask = batch["mask"]
68
+ mel_input = batch["mel"]
69
+ cond = self.condition_encoder(batch)
70
+ if self.distill:
71
+ cond = cond.detach()
72
+ self.skip_diff = True if self.step < self.cfg.train.fast_steps else False
73
+ ssim_loss, prior_loss, diff_loss = self.acoustic_mapper.compute_loss(
74
+ mask, cond, mel_input, skip_diff=self.skip_diff
75
+ )
76
+ if self.distill:
77
+ loss["distil_loss"] = diff_loss
78
+ else:
79
+ loss["ssim_loss_encoder"] = ssim_loss
80
+ loss["prior_loss_encoder"] = prior_loss
81
+ loss["diffusion_loss_decoder"] = diff_loss
82
+
83
+ return loss
84
+
85
+ def _train_epoch(self):
86
+ r"""Training epoch. Should return average loss of a batch (sample) over
87
+ one epoch. See ``train_loop`` for usage.
88
+ """
89
+ self.model.train()
90
+ epoch_sum_loss: float = 0.0
91
+ epoch_step: int = 0
92
+ for batch in tqdm(
93
+ self.train_dataloader,
94
+ desc=f"Training Epoch {self.epoch}",
95
+ unit="batch",
96
+ colour="GREEN",
97
+ leave=False,
98
+ dynamic_ncols=True,
99
+ smoothing=0.04,
100
+ disable=not self.accelerator.is_main_process,
101
+ ):
102
+ # Do training step and BP
103
+ with self.accelerator.accumulate(self.model):
104
+ loss = self._train_step(batch)
105
+ total_loss = 0
106
+ for k, v in loss.items():
107
+ total_loss += v
108
+ self.accelerator.backward(total_loss)
109
+ enc_grad_norm = torch.nn.utils.clip_grad_norm_(
110
+ self.acoustic_mapper.encoder.parameters(), max_norm=1
111
+ )
112
+ dec_grad_norm = torch.nn.utils.clip_grad_norm_(
113
+ self.acoustic_mapper.decoder.parameters(), max_norm=1
114
+ )
115
+ self.optimizer.step()
116
+ self.optimizer.zero_grad()
117
+ self.batch_count += 1
118
+
119
+ # Update info for each step
120
+ # TODO: step means BP counts or batch counts?
121
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
122
+ epoch_sum_loss += total_loss
123
+ log_info = {}
124
+ for k, v in loss.items():
125
+ key = "Step/Train Loss/{}".format(k)
126
+ log_info[key] = v
127
+ log_info["Step/Learning Rate"]: self.optimizer.param_groups[0]["lr"]
128
+ self.accelerator.log(
129
+ log_info,
130
+ step=self.step,
131
+ )
132
+ self.step += 1
133
+ epoch_step += 1
134
+
135
+ self.accelerator.wait_for_everyone()
136
+ return (
137
+ epoch_sum_loss
138
+ / len(self.train_dataloader)
139
+ * self.cfg.train.gradient_accumulation_step,
140
+ loss,
141
+ )
142
+
143
+ def train_loop(self):
144
+ r"""Training loop. The public entry of training process."""
145
+ # Wait everyone to prepare before we move on
146
+ self.accelerator.wait_for_everyone()
147
+ # dump config file
148
+ if self.accelerator.is_main_process:
149
+ self.__dump_cfg(self.config_save_path)
150
+ self.model.train()
151
+ self.optimizer.zero_grad()
152
+ # Wait to ensure good to go
153
+ self.accelerator.wait_for_everyone()
154
+ while self.epoch < self.max_epoch:
155
+ self.logger.info("\n")
156
+ self.logger.info("-" * 32)
157
+ self.logger.info("Epoch {}: ".format(self.epoch))
158
+
159
+ ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
160
+ ### It's inconvenient for the model with multiple losses
161
+ # Do training & validating epoch
162
+ train_loss, loss = self._train_epoch()
163
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
164
+ for k, v in loss.items():
165
+ self.logger.info(" |- Train/Loss/{}: {:.6f}".format(k, v))
166
+ valid_loss = self._valid_epoch()
167
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
168
+ self.accelerator.log(
169
+ {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
170
+ step=self.epoch,
171
+ )
172
+
173
+ self.accelerator.wait_for_everyone()
174
+ # TODO: what is scheduler?
175
+ self.scheduler.step(valid_loss) # FIXME: use epoch track correct?
176
+
177
+ # Check if hit save_checkpoint_stride and run_eval
178
+ run_eval = False
179
+ if self.accelerator.is_main_process:
180
+ save_checkpoint = False
181
+ hit_idx = []
182
+ for i, num in enumerate(self.save_checkpoint_stride):
183
+ if self.epoch % num == 0:
184
+ save_checkpoint = True
185
+ hit_idx.append(i)
186
+ run_eval |= self.run_eval[i]
187
+
188
+ self.accelerator.wait_for_everyone()
189
+ if (
190
+ self.accelerator.is_main_process
191
+ and save_checkpoint
192
+ and (self.distill or not self.skip_diff)
193
+ ):
194
+ path = os.path.join(
195
+ self.checkpoint_dir,
196
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
197
+ self.epoch, self.step, train_loss
198
+ ),
199
+ )
200
+ self.accelerator.save_state(path)
201
+ json.dump(
202
+ self.checkpoints_path,
203
+ open(os.path.join(path, "ckpts.json"), "w"),
204
+ ensure_ascii=False,
205
+ indent=4,
206
+ )
207
+
208
+ # Remove old checkpoints
209
+ to_remove = []
210
+ for idx in hit_idx:
211
+ self.checkpoints_path[idx].append(path)
212
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
213
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
214
+
215
+ # Search conflicts
216
+ total = set()
217
+ for i in self.checkpoints_path:
218
+ total |= set(i)
219
+ do_remove = set()
220
+ for idx, path in to_remove[::-1]:
221
+ if path in total:
222
+ self.checkpoints_path[idx].insert(0, path)
223
+ else:
224
+ do_remove.add(path)
225
+
226
+ # Remove old checkpoints
227
+ for path in do_remove:
228
+ shutil.rmtree(path, ignore_errors=True)
229
+ self.logger.debug(f"Remove old checkpoint: {path}")
230
+
231
+ self.accelerator.wait_for_everyone()
232
+ if run_eval:
233
+ # TODO: run evaluation
234
+ pass
235
+
236
+ # Update info for each epoch
237
+ self.epoch += 1
238
+
239
+ # Finish training and save final checkpoint
240
+ self.accelerator.wait_for_everyone()
241
+ if self.accelerator.is_main_process:
242
+ self.accelerator.save_state(
243
+ os.path.join(
244
+ self.checkpoint_dir,
245
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
246
+ self.epoch, self.step, valid_loss
247
+ ),
248
+ )
249
+ )
250
+ self.accelerator.end_training()
251
+
252
+ @torch.inference_mode()
253
+ def _valid_epoch(self):
254
+ r"""Testing epoch. Should return average loss of a batch (sample) over
255
+ one epoch. See ``train_loop`` for usage.
256
+ """
257
+ self.model.eval()
258
+ epoch_sum_loss = 0.0
259
+ for batch in tqdm(
260
+ self.valid_dataloader,
261
+ desc=f"Validating Epoch {self.epoch}",
262
+ unit="batch",
263
+ colour="GREEN",
264
+ leave=False,
265
+ dynamic_ncols=True,
266
+ smoothing=0.04,
267
+ disable=not self.accelerator.is_main_process,
268
+ ):
269
+ batch_loss = self._valid_step(batch)
270
+ for k, v in batch_loss.items():
271
+ epoch_sum_loss += v
272
+
273
+ self.accelerator.wait_for_everyone()
274
+ return epoch_sum_loss / len(self.valid_dataloader)
275
+
276
+ @staticmethod
277
+ def __count_parameters(model):
278
+ model_param = 0.0
279
+ if isinstance(model, dict):
280
+ for key, value in model.items():
281
+ model_param += sum(p.numel() for p in model[key].parameters())
282
+ else:
283
+ model_param = sum(p.numel() for p in model.parameters())
284
+ return model_param
285
+
286
+ def __dump_cfg(self, path):
287
+ os.makedirs(os.path.dirname(path), exist_ok=True)
288
+ json5.dump(
289
+ self.cfg,
290
+ open(path, "w"),
291
+ indent=4,
292
+ sort_keys=True,
293
+ ensure_ascii=False,
294
+ quote_keys=True,
295
+ )
models/svc/comosvc/utils.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+
9
+ def slice_segments(x, ids_str, segment_size=200):
10
+ ret = torch.zeros_like(x[:, :, :segment_size])
11
+ for i in range(x.size(0)):
12
+ idx_str = ids_str[i]
13
+ idx_end = idx_str + segment_size
14
+ ret[i] = x[i, :, idx_str:idx_end]
15
+ return ret
16
+
17
+
18
+ def rand_ids_segments(lengths, segment_size=200):
19
+ b = lengths.shape[0]
20
+ ids_str_max = lengths - segment_size
21
+ ids_str = (torch.rand([b]).to(device=lengths.device) * ids_str_max).to(
22
+ dtype=torch.long
23
+ )
24
+ return ids_str
25
+
26
+
27
+ def fix_len_compatibility(length, num_downsamplings_in_unet=2):
28
+ while True:
29
+ if length % (2**num_downsamplings_in_unet) == 0:
30
+ return length
31
+ length += 1
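These helpers pick fixed-length training windows from padded, channels-first feature tensors. A minimal usage sketch (shapes are hypothetical, and the import path is an assumption based on this file's location):

```python
import torch

from models.svc.comosvc.utils import rand_ids_segments, slice_segments

# Hypothetical batch: 2 items, 80 mel bins, padded to 1000 frames.
x = torch.randn(2, 80, 1000)         # [B, C, T]
lengths = torch.tensor([1000, 700])  # true frame counts per item

ids = rand_ids_segments(lengths, segment_size=200)  # one random start index per item
segs = slice_segments(x, ids, segment_size=200)     # [2, 80, 200]
```

`fix_len_compatibility` rounds a frame length up to a multiple of `2**num_downsamplings_in_unet`, so U-Net style down/upsampling reproduces the original length.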
models/svc/diffusion/__init__.py ADDED
File without changes
models/svc/diffusion/diffusion_inference.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
8
+
9
+ from models.svc.base import SVCInference
10
+ from models.svc.diffusion.diffusion_inference_pipeline import DiffusionInferencePipeline
11
+ from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
12
+ from modules.encoder.condition_encoder import ConditionEncoder
13
+
14
+
15
+ class DiffusionInference(SVCInference):
16
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
17
+ SVCInference.__init__(self, args, cfg, infer_type)
18
+
19
+ settings = {
20
+ **cfg.model.diffusion.scheduler_settings,
21
+ **cfg.inference.diffusion.scheduler_settings,
22
+ }
23
+ settings.pop("num_inference_timesteps")
24
+
25
+ if cfg.inference.diffusion.scheduler.lower() == "ddpm":
26
+ self.scheduler = DDPMScheduler(**settings)
27
+ self.logger.info("Using DDPM scheduler.")
28
+ elif cfg.inference.diffusion.scheduler.lower() == "ddim":
29
+ self.scheduler = DDIMScheduler(**settings)
30
+ self.logger.info("Using DDIM scheduler.")
31
+ elif cfg.inference.diffusion.scheduler.lower() == "pndm":
32
+ self.scheduler = PNDMScheduler(**settings)
33
+ self.logger.info("Using PNDM scheduler.")
34
+ else:
35
+ raise NotImplementedError(
36
+ "Unsupported scheduler type: {}".format(
37
+ cfg.inference.diffusion.scheduler.lower()
38
+ )
39
+ )
40
+
41
+ self.pipeline = DiffusionInferencePipeline(
42
+ self.model[1],
43
+ self.scheduler,
44
+ cfg.inference.diffusion.scheduler_settings.num_inference_timesteps,
45
+ )
46
+
47
+ def _build_model(self):
48
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
49
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
50
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
51
+ self.acoustic_mapper = DiffusionWrapper(self.cfg)
52
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
53
+ return model
54
+
55
+ def _inference_each_batch(self, batch_data):
56
+ device = self.accelerator.device
57
+ for k, v in batch_data.items():
58
+ batch_data[k] = v.to(device)
59
+
60
+ conditioner = self.model[0](batch_data)
61
+ noise = torch.randn_like(batch_data["mel"], device=device)
62
+ y_pred = self.pipeline(noise, conditioner)
63
+ return y_pred
models/svc/diffusion/diffusion_inference_pipeline.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+
9
+
10
+ class DiffusionInferencePipeline(DiffusionPipeline):
11
+ def __init__(self, network, scheduler, num_inference_timesteps=1000):
12
+ super().__init__()
13
+
14
+ self.register_modules(network=network, scheduler=scheduler)
15
+ self.num_inference_timesteps = num_inference_timesteps
16
+
17
+ @torch.inference_mode()
18
+ def __call__(
19
+ self,
20
+ initial_noise: torch.Tensor,
21
+ conditioner: torch.Tensor = None,
22
+ ):
23
+ r"""
24
+ Args:
25
+ initial_noise: The initial noise to be denoised.
26
+ conditioner: The conditioning tensor.
27
+ Note: the number of denoising steps is ``num_inference_timesteps``, fixed at
28
+ construction; more steps usually lead to higher quality at the expense of slower inference.
29
+ """
30
+
31
+ mel = initial_noise
32
+ batch_size = mel.size(0)
33
+ self.scheduler.set_timesteps(self.num_inference_timesteps)
34
+
35
+ for t in self.progress_bar(self.scheduler.timesteps):
36
+ timestep = torch.full((batch_size,), t, device=mel.device, dtype=torch.long)
37
+
38
+ # 1. predict noise model_output
39
+ model_output = self.network(mel, timestep, conditioner)
40
+
41
+ # 2. denoise, compute previous step: x_t -> x_t-1
42
+ mel = self.scheduler.step(model_output, t, mel).prev_sample
43
+
44
+ # 3. clamp
45
+ mel = mel.clamp(-1.0, 1.0)
46
+
47
+ return mel
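The pipeline above is a thin wrapper around a diffusers scheduler loop. The sketch below spells out the same reverse process with a stand-in noise predictor, so the control flow can be read without the surrounding classes (everything here is illustrative, not repository code):

```python
import torch
from diffusers import DDPMScheduler

def dummy_network(x, t, cond):
    # Stands in for the trained acoustic model; it should predict the injected noise.
    return torch.zeros_like(x)

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(50)              # number of denoising steps at inference

mel = torch.randn(2, 400, 80)            # start from pure noise, [B, T, n_mel]
cond = torch.randn(2, 400, 384)          # conditioner from the condition encoder

for t in scheduler.timesteps:
    timestep = torch.full((mel.size(0),), t, dtype=torch.long)
    model_output = dummy_network(mel, timestep, cond)
    mel = scheduler.step(model_output, t, mel).prev_sample

mel = mel.clamp(-1.0, 1.0)
```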
models/svc/diffusion/diffusion_trainer.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from diffusers import DDPMScheduler
8
+
9
+ from models.svc.base import SVCTrainer
10
+ from modules.encoder.condition_encoder import ConditionEncoder
11
+ from .diffusion_wrapper import DiffusionWrapper
12
+
13
+
14
+ class DiffusionTrainer(SVCTrainer):
15
+ r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
16
+ implements ``_build_model`` and ``_forward_step`` methods.
17
+ """
18
+
19
+ def __init__(self, args=None, cfg=None):
20
+ SVCTrainer.__init__(self, args, cfg)
21
+
22
+ # Only for SVC tasks using diffusion
23
+ self.noise_scheduler = DDPMScheduler(
24
+ **self.cfg.model.diffusion.scheduler_settings,
25
+ )
26
+ self.diffusion_timesteps = (
27
+ self.cfg.model.diffusion.scheduler_settings.num_train_timesteps
28
+ )
29
+
30
+ ### Following are methods only for diffusion models ###
31
+ def _build_model(self):
32
+ r"""Build the model for training. This function is called in ``__init__`` function."""
33
+
34
+ # TODO: sort out the config
35
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
36
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
37
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
38
+ self.acoustic_mapper = DiffusionWrapper(self.cfg)
39
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
40
+
41
+ num_of_params_encoder = self.count_parameters(self.condition_encoder)
42
+ num_of_params_am = self.count_parameters(self.acoustic_mapper)
43
+ num_of_params = num_of_params_encoder + num_of_params_am
44
+ log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format(
45
+ num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6
46
+ )
47
+ self.logger.info(log)
48
+
49
+ return model
50
+
51
+ def count_parameters(self, model):
52
+ model_param = 0.0
53
+ if isinstance(model, dict):
54
+ for key, value in model.items():
55
+ model_param += sum(p.numel() for p in model[key].parameters())
56
+ else:
57
+ model_param = sum(p.numel() for p in model.parameters())
58
+ return model_param
59
+
60
+ def _forward_step(self, batch):
61
+ r"""Forward step for training and inference. This function is called
62
+ in ``_train_step`` & ``_test_step`` function.
63
+ """
64
+
65
+ device = self.accelerator.device
66
+
67
+ mel_input = batch["mel"]
68
+ noise = torch.randn_like(mel_input, device=device, dtype=torch.float32)
69
+ batch_size = mel_input.size(0)
70
+ timesteps = torch.randint(
71
+ 0,
72
+ self.diffusion_timesteps,
73
+ (batch_size,),
74
+ device=device,
75
+ dtype=torch.long,
76
+ )
77
+
78
+ noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps)
79
+ conditioner = self.condition_encoder(batch)
80
+
81
+ y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner)
82
+
83
+ # TODO: Predict noise or gt should be configurable
84
+ loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"])
85
+ self._check_nan(loss, y_pred, noise)
86
+
87
+ # FIXME: Clarify that we should not divide it with batch size here
88
+ return loss
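The forward step above is the standard denoising-diffusion training objective: corrupt the target mel with scheduler noise at a random timestep, then regress the injected noise. A condensed, self-contained sketch under the same assumptions (dummy model, hypothetical shapes, illustrative loss normalization; the real code goes through the condition encoder and `_compute_loss`):

```python
import torch
import torch.nn.functional as F
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

mel = torch.randn(4, 400, 80)     # ground-truth mel, [B, T, n_mel]
cond = torch.randn(4, 400, 384)   # conditioner (content / pitch / energy / speaker)
mask = torch.ones(4, 400, 1)      # padding mask

noise = torch.randn_like(mel)
t = torch.randint(0, scheduler.config.num_train_timesteps, (mel.size(0),))
noisy_mel = scheduler.add_noise(mel, noise, t)

def dummy_model(x, t, c):         # stands in for DiffusionWrapper
    return torch.zeros_like(x)

pred = dummy_model(noisy_mel, t, cond)
# Masked L1 against the injected noise; normalization here is illustrative.
loss = (F.l1_loss(pred, noise, reduction="none") * mask).sum() / mask.sum()
```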
models/svc/diffusion/diffusion_wrapper.py ADDED
@@ -0,0 +1,73 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch.nn as nn
7
+
8
+ from modules.diffusion import BiDilConv
9
+ from modules.encoder.position_encoder import PositionEncoder
10
+
11
+
12
+ class DiffusionWrapper(nn.Module):
13
+ def __init__(self, cfg):
14
+ super().__init__()
15
+
16
+ self.cfg = cfg
17
+ self.diff_cfg = cfg.model.diffusion
18
+
19
+ self.diff_encoder = PositionEncoder(
20
+ d_raw_emb=self.diff_cfg.step_encoder.dim_raw_embedding,
21
+ d_out=self.diff_cfg.bidilconv.base_channel,
22
+ d_mlp=self.diff_cfg.step_encoder.dim_hidden_layer,
23
+ activation_function=self.diff_cfg.step_encoder.activation,
24
+ n_layer=self.diff_cfg.step_encoder.num_layer,
25
+ max_period=self.diff_cfg.step_encoder.max_period,
26
+ )
27
+
28
+ # FIXME: Only support BiDilConv now for debug
29
+ if self.diff_cfg.model_type.lower() == "bidilconv":
30
+ self.neural_network = BiDilConv(
31
+ input_channel=self.cfg.preprocess.n_mel, **self.diff_cfg.bidilconv
32
+ )
33
+ else:
34
+ raise ValueError(
35
+ f"Unsupported diffusion model type: {self.diff_cfg.model_type}"
36
+ )
37
+
38
+ def forward(self, x, t, c):
39
+ """
40
+ Args:
41
+ x: [N, T, mel_band] of mel spectrogram
42
+ t: Diffusion time step with shape of [N]
43
+ c: [N, T, conditioner_size] of conditioner
44
+
45
+ Returns:
46
+ [N, T, mel_band] of mel spectrogram
47
+ """
48
+
49
+ assert (
50
+ x.size()[:-1] == c.size()[:-1]
51
+ ), "x mismatch with c, got \n x: {} \n c: {}".format(x.size(), c.size())
52
+ assert x.size(0) == t.size(
53
+ 0
54
+ ), "x mismatch with t, got \n x: {} \n t: {}".format(x.size(), t.size())
55
+ assert t.dim() == 1, "t must be 1D tensor, got {}".format(t.dim())
56
+
57
+ N, T, mel_band = x.size()
58
+
59
+ x = x.transpose(1, 2).contiguous() # [N, mel_band, T]
60
+ c = c.transpose(1, 2).contiguous() # [N, conditioner_size, T]
61
+ t = self.diff_encoder(t).contiguous() # [N, base_channel]
62
+
63
+ h = self.neural_network(x, t, c)
64
+ h = h.transpose(1, 2).contiguous() # [N, T, mel_band]
65
+
66
+ assert h.size() == (
67
+ N,
68
+ T,
69
+ mel_band,
70
+ ), "h mismatch with input x, got \n h: {} \n x: {}".format(
71
+ h.size(), (N, T, mel_band)
72
+ )
73
+ return h
models/svc/transformer/__init__.py ADDED
File without changes
models/svc/transformer/conformer.py ADDED
@@ -0,0 +1,405 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import torch
8
+ import numpy as np
9
+ import torch.nn as nn
10
+ from utils.util import convert_pad_shape
11
+
12
+
13
+ class BaseModule(torch.nn.Module):
14
+ def __init__(self):
15
+ super(BaseModule, self).__init__()
16
+
17
+ @property
18
+ def nparams(self):
19
+ """
20
+ Returns number of trainable parameters of the module.
21
+ """
22
+ num_params = 0
23
+ for name, param in self.named_parameters():
24
+ if param.requires_grad:
25
+ num_params += np.prod(param.detach().cpu().numpy().shape)
26
+ return num_params
27
+
28
+ def relocate_input(self, x: list):
29
+ """
30
+ Relocates provided tensors to the same device set for the module.
31
+ """
32
+ device = next(self.parameters()).device
33
+ for i in range(len(x)):
34
+ if isinstance(x[i], torch.Tensor) and x[i].device != device:
35
+ x[i] = x[i].to(device)
36
+ return x
37
+
38
+
39
+ class LayerNorm(BaseModule):
40
+ def __init__(self, channels, eps=1e-4):
41
+ super(LayerNorm, self).__init__()
42
+ self.channels = channels
43
+ self.eps = eps
44
+
45
+ self.gamma = torch.nn.Parameter(torch.ones(channels))
46
+ self.beta = torch.nn.Parameter(torch.zeros(channels))
47
+
48
+ def forward(self, x):
49
+ n_dims = len(x.shape)
50
+ mean = torch.mean(x, 1, keepdim=True)
51
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
52
+
53
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
54
+
55
+ shape = [1, -1] + [1] * (n_dims - 2)
56
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
57
+ return x
58
+
59
+
60
+ class ConvReluNorm(BaseModule):
61
+ def __init__(
62
+ self,
63
+ in_channels,
64
+ hidden_channels,
65
+ out_channels,
66
+ kernel_size,
67
+ n_layers,
68
+ p_dropout,
69
+ eps=1e-5,
70
+ ):
71
+ super(ConvReluNorm, self).__init__()
72
+ self.in_channels = in_channels
73
+ self.hidden_channels = hidden_channels
74
+ self.out_channels = out_channels
75
+ self.kernel_size = kernel_size
76
+ self.n_layers = n_layers
77
+ self.p_dropout = p_dropout
78
+ self.eps = eps
79
+
80
+ self.conv_layers = torch.nn.ModuleList()
81
+ self.conv_layers.append(
82
+ torch.nn.Conv1d(
83
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
84
+ )
85
+ )
86
+ self.relu_drop = torch.nn.Sequential(
87
+ torch.nn.ReLU(), torch.nn.Dropout(p_dropout)
88
+ )
89
+ for _ in range(n_layers - 1):
90
+ self.conv_layers.append(
91
+ torch.nn.Conv1d(
92
+ hidden_channels,
93
+ hidden_channels,
94
+ kernel_size,
95
+ padding=kernel_size // 2,
96
+ )
97
+ )
98
+ self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
99
+ self.proj.weight.data.zero_()
100
+ self.proj.bias.data.zero_()
101
+
102
+ def forward(self, x, x_mask):
103
+ for i in range(self.n_layers):
104
+ x = self.conv_layers[i](x * x_mask)
105
+ x = self.instance_norm(x, x_mask)
106
+ x = self.relu_drop(x)
107
+ x = self.proj(x)
108
+ return x * x_mask
109
+
110
+ def instance_norm(self, x, mask, return_mean_std=False):
111
+ mean, std = self.calc_mean_std(x, mask)
112
+ x = (x - mean) / std
113
+ if return_mean_std:
114
+ return x, mean, std
115
+ else:
116
+ return x
117
+
118
+ def calc_mean_std(self, x, mask=None):
119
+ x = x * mask
120
+ B, C = x.shape[:2]
121
+ mn = x.view(B, C, -1).mean(-1)
122
+ sd = (x.view(B, C, -1).var(-1) + self.eps).sqrt()
123
+ mn = mn.view(B, C, *((len(x.shape) - 2) * [1]))
124
+ sd = sd.view(B, C, *((len(x.shape) - 2) * [1]))
125
+ return mn, sd
126
+
127
+
128
+ class MultiHeadAttention(BaseModule):
129
+ def __init__(
130
+ self,
131
+ channels,
132
+ out_channels,
133
+ n_heads,
134
+ window_size=None,
135
+ heads_share=True,
136
+ p_dropout=0.0,
137
+ proximal_bias=False,
138
+ proximal_init=False,
139
+ ):
140
+ super(MultiHeadAttention, self).__init__()
141
+ assert channels % n_heads == 0
142
+
143
+ self.channels = channels
144
+ self.out_channels = out_channels
145
+ self.n_heads = n_heads
146
+ self.window_size = window_size
147
+ self.heads_share = heads_share
148
+ self.proximal_bias = proximal_bias
149
+ self.p_dropout = p_dropout
150
+ self.attn = None
151
+
152
+ self.k_channels = channels // n_heads
153
+ self.conv_q = torch.nn.Conv1d(channels, channels, 1)
154
+ self.conv_k = torch.nn.Conv1d(channels, channels, 1)
155
+ self.conv_v = torch.nn.Conv1d(channels, channels, 1)
156
+ if window_size is not None:
157
+ n_heads_rel = 1 if heads_share else n_heads
158
+ rel_stddev = self.k_channels**-0.5
159
+ self.emb_rel_k = torch.nn.Parameter(
160
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
161
+ * rel_stddev
162
+ )
163
+ self.emb_rel_v = torch.nn.Parameter(
164
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
165
+ * rel_stddev
166
+ )
167
+ self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
168
+ self.drop = torch.nn.Dropout(p_dropout)
169
+
170
+ torch.nn.init.xavier_uniform_(self.conv_q.weight)
171
+ torch.nn.init.xavier_uniform_(self.conv_k.weight)
172
+ if proximal_init:
173
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
174
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
175
+ torch.nn.init.xavier_uniform_(self.conv_v.weight)
176
+
177
+ def forward(self, x, c, attn_mask=None):
178
+ q = self.conv_q(x)
179
+ k = self.conv_k(c)
180
+ v = self.conv_v(c)
181
+
182
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
183
+
184
+ x = self.conv_o(x)
185
+ return x
186
+
187
+ def attention(self, query, key, value, mask=None):
188
+ b, d, t_s, t_t = (*key.size(), query.size(2))
189
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
190
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
191
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
192
+
193
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
194
+ if self.window_size is not None:
195
+ assert (
196
+ t_s == t_t
197
+ ), "Relative attention is only available for self-attention."
198
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
199
+ rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
200
+ rel_logits = self._relative_position_to_absolute_position(rel_logits)
201
+ scores_local = rel_logits / math.sqrt(self.k_channels)
202
+ scores = scores + scores_local
203
+ if self.proximal_bias:
204
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
205
+ scores = scores + self._attention_bias_proximal(t_s).to(
206
+ device=scores.device, dtype=scores.dtype
207
+ )
208
+ if mask is not None:
209
+ scores = scores.masked_fill(mask == 0, -1e4)
210
+ p_attn = torch.nn.functional.softmax(scores, dim=-1)
211
+ p_attn = self.drop(p_attn)
212
+ output = torch.matmul(p_attn, value)
213
+ if self.window_size is not None:
214
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
215
+ value_relative_embeddings = self._get_relative_embeddings(
216
+ self.emb_rel_v, t_s
217
+ )
218
+ output = output + self._matmul_with_relative_values(
219
+ relative_weights, value_relative_embeddings
220
+ )
221
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t)
222
+ return output, p_attn
223
+
224
+ def _matmul_with_relative_values(self, x, y):
225
+ ret = torch.matmul(x, y.unsqueeze(0))
226
+ return ret
227
+
228
+ def _matmul_with_relative_keys(self, x, y):
229
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
230
+ return ret
231
+
232
+ def _get_relative_embeddings(self, relative_embeddings, length):
233
+ pad_length = max(length - (self.window_size + 1), 0)
234
+ slice_start_position = max((self.window_size + 1) - length, 0)
235
+ slice_end_position = slice_start_position + 2 * length - 1
236
+ if pad_length > 0:
237
+ padded_relative_embeddings = torch.nn.functional.pad(
238
+ relative_embeddings,
239
+ convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
240
+ )
241
+ else:
242
+ padded_relative_embeddings = relative_embeddings
243
+ used_relative_embeddings = padded_relative_embeddings[
244
+ :, slice_start_position:slice_end_position
245
+ ]
246
+ return used_relative_embeddings
247
+
248
+ def _relative_position_to_absolute_position(self, x):
249
+ batch, heads, length, _ = x.size()
250
+ x = torch.nn.functional.pad(
251
+ x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])
252
+ )
253
+ x_flat = x.view([batch, heads, length * 2 * length])
254
+ x_flat = torch.nn.functional.pad(
255
+ x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
256
+ )
257
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
258
+ :, :, :length, length - 1 :
259
+ ]
260
+ return x_final
261
+
262
+ def _absolute_position_to_relative_position(self, x):
263
+ batch, heads, length, _ = x.size()
264
+ x = torch.nn.functional.pad(
265
+ x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
266
+ )
267
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
268
+ x_flat = torch.nn.functional.pad(
269
+ x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])
270
+ )
271
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
272
+ return x_final
273
+
274
+ def _attention_bias_proximal(self, length):
275
+ r = torch.arange(length, dtype=torch.float32)
276
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
277
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
278
+
279
+
280
+ class FFN(BaseModule):
281
+ def __init__(
282
+ self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0
283
+ ):
284
+ super(FFN, self).__init__()
285
+ self.in_channels = in_channels
286
+ self.out_channels = out_channels
287
+ self.filter_channels = filter_channels
288
+ self.kernel_size = kernel_size
289
+ self.p_dropout = p_dropout
290
+
291
+ self.conv_1 = torch.nn.Conv1d(
292
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
293
+ )
294
+ self.conv_2 = torch.nn.Conv1d(
295
+ filter_channels, out_channels, kernel_size, padding=kernel_size // 2
296
+ )
297
+ self.drop = torch.nn.Dropout(p_dropout)
298
+
299
+ def forward(self, x, x_mask):
300
+ x = self.conv_1(x * x_mask)
301
+ x = torch.relu(x)
302
+ x = self.drop(x)
303
+ x = self.conv_2(x * x_mask)
304
+ return x * x_mask
305
+
306
+
307
+ class Encoder(BaseModule):
308
+ def __init__(
309
+ self,
310
+ hidden_channels,
311
+ filter_channels,
312
+ n_heads=2,
313
+ n_layers=6,
314
+ kernel_size=3,
315
+ p_dropout=0.1,
316
+ window_size=4,
317
+ **kwargs
318
+ ):
319
+ super(Encoder, self).__init__()
320
+ self.hidden_channels = hidden_channels
321
+ self.filter_channels = filter_channels
322
+ self.n_heads = n_heads
323
+ self.n_layers = n_layers
324
+ self.kernel_size = kernel_size
325
+ self.p_dropout = p_dropout
326
+ self.window_size = window_size
327
+
328
+ self.drop = torch.nn.Dropout(p_dropout)
329
+ self.attn_layers = torch.nn.ModuleList()
330
+ self.norm_layers_1 = torch.nn.ModuleList()
331
+ self.ffn_layers = torch.nn.ModuleList()
332
+ self.norm_layers_2 = torch.nn.ModuleList()
333
+ for _ in range(self.n_layers):
334
+ self.attn_layers.append(
335
+ MultiHeadAttention(
336
+ hidden_channels,
337
+ hidden_channels,
338
+ n_heads,
339
+ window_size=window_size,
340
+ p_dropout=p_dropout,
341
+ )
342
+ )
343
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
344
+ self.ffn_layers.append(
345
+ FFN(
346
+ hidden_channels,
347
+ hidden_channels,
348
+ filter_channels,
349
+ kernel_size,
350
+ p_dropout=p_dropout,
351
+ )
352
+ )
353
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
354
+
355
+ def forward(self, x, x_mask):
356
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
357
+ for i in range(self.n_layers):
358
+ x = x * x_mask
359
+ y = self.attn_layers[i](x, x, attn_mask)
360
+ y = self.drop(y)
361
+ x = self.norm_layers_1[i](x + y)
362
+ y = self.ffn_layers[i](x, x_mask)
363
+ y = self.drop(y)
364
+ x = self.norm_layers_2[i](x + y)
365
+ x = x * x_mask
366
+ return x
367
+
368
+
369
+ class Conformer(BaseModule):
370
+ def __init__(self, cfg):
371
+ super().__init__()
372
+ self.cfg = cfg
373
+ self.n_heads = self.cfg.n_heads
374
+ self.n_layers = self.cfg.n_layers
375
+ self.hidden_channels = self.cfg.input_dim
376
+ self.filter_channels = self.cfg.filter_channels
377
+ self.output_dim = self.cfg.output_dim
378
+ self.dropout = self.cfg.dropout
379
+
380
+ self.conformer_encoder = Encoder(
381
+ self.hidden_channels,
382
+ self.filter_channels,
383
+ n_heads=self.n_heads,
384
+ n_layers=self.n_layers,
385
+ kernel_size=3,
386
+ p_dropout=self.dropout,
387
+ window_size=4,
388
+ )
389
+ self.projection = nn.Conv1d(self.hidden_channels, self.output_dim, 1)
390
+
391
+ def forward(self, x, x_mask):
392
+ """
393
+ Args:
394
+ x: (N, seq_len, input_dim)
395
+ Returns:
396
+ output: (N, seq_len, output_dim)
397
+ """
398
+ # (N, seq_len, d_model)
399
+ x = x.transpose(1, 2)
400
+ x_mask = x_mask.transpose(1, 2)
401
+ output = self.conformer_encoder(x, x_mask)
402
+ # (N, seq_len, output_dim)
403
+ output = self.projection(output)
404
+ output = output.transpose(1, 2)
405
+ return output
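A quick shape check for the relative-position encoder above. It assumes this module (and its `utils.util.convert_pad_shape` dependency) is importable from the repository; the sizes are arbitrary:

```python
import torch
from models.svc.transformer.conformer import Encoder

enc = Encoder(hidden_channels=192, filter_channels=768, n_heads=2, n_layers=6)
x = torch.randn(2, 192, 100)    # [B, C, T]: channels-first inside the encoder
x_mask = torch.ones(2, 1, 100)  # [B, 1, T]: 1 for real frames, 0 for padding
y = enc(x, x_mask)
assert y.shape == (2, 192, 100)
```

The `Conformer` wrapper takes `[N, seq_len, input_dim]` and transposes to this channels-first layout internally before projecting to `output_dim`.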
models/svc/transformer/transformer.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
10
+
11
+
12
+ class Transformer(nn.Module):
13
+ def __init__(self, cfg):
14
+ super().__init__()
15
+ self.cfg = cfg
16
+
17
+ dropout = self.cfg.dropout
18
+ nhead = self.cfg.n_heads
19
+ nlayers = self.cfg.n_layers
20
+ input_dim = self.cfg.input_dim
21
+ output_dim = self.cfg.output_dim
22
+
23
+ d_model = input_dim
24
+ self.pos_encoder = PositionalEncoding(d_model, dropout)
25
+ encoder_layers = TransformerEncoderLayer(
26
+ d_model, nhead, dropout=dropout, batch_first=True
27
+ )
28
+ self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
29
+
30
+ self.output_mlp = nn.Linear(d_model, output_dim)
31
+
32
+ def forward(self, x, mask=None):
33
+ """
34
+ Args:
35
+ x: (N, seq_len, input_dim)
36
+ Returns:
37
+ output: (N, seq_len, output_dim)
38
+ """
39
+ # (N, seq_len, d_model)
40
+ src = self.pos_encoder(x)
41
+ # model_stats["pos_embedding"] = x
42
+ # (N, seq_len, d_model)
43
+ output = self.transformer_encoder(src)
44
+ # (N, seq_len, output_dim)
45
+ output = self.output_mlp(output)
46
+ return output
47
+
48
+
49
+ class PositionalEncoding(nn.Module):
50
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
51
+ super().__init__()
52
+ self.dropout = nn.Dropout(p=dropout)
53
+
54
+ position = torch.arange(max_len).unsqueeze(1)
55
+ div_term = torch.exp(
56
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
57
+ )
58
+
59
+ # Assume that x is (seq_len, N, d)
60
+ # pe = torch.zeros(max_len, 1, d_model)
61
+ # pe[:, 0, 0::2] = torch.sin(position * div_term)
62
+ # pe[:, 0, 1::2] = torch.cos(position * div_term)
63
+
64
+ # Assume that x in (N, seq_len, d)
65
+ pe = torch.zeros(1, max_len, d_model)
66
+ pe[0, :, 0::2] = torch.sin(position * div_term)
67
+ pe[0, :, 1::2] = torch.cos(position * div_term)
68
+
69
+ self.register_buffer("pe", pe)
70
+
71
+ def forward(self, x):
72
+ """
73
+ Args:
74
+ x: Tensor, shape [N, seq_len, d]
75
+ """
76
+ # Old: Assume that x is (seq_len, N, d), and self.pe is (max_len, 1, d_model)
77
+ # x = x + self.pe[: x.size(0)]
78
+
79
+ # Now: self.pe is (1, max_len, d)
80
+ x = x + self.pe[:, : x.size(1), :]
81
+
82
+ return self.dropout(x)
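The positional encoding above is the standard sinusoidal scheme, PE(pos, 2i) = sin(pos / 10000^(2i/d)) and PE(pos, 2i+1) = cos(pos / 10000^(2i/d)), stored as a `(1, max_len, d)` buffer so it broadcasts over batch-first input. A minimal shape check (the import path is an assumption based on this file's location):

```python
import torch
from models.svc.transformer.transformer import PositionalEncoding

pe = PositionalEncoding(d_model=256, dropout=0.0, max_len=5000)
x = torch.zeros(4, 100, 256)     # (N, seq_len, d)
y = pe(x)
assert y.shape == (4, 100, 256)  # same shape, sinusoids added along the time axis
```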
models/svc/transformer/transformer_inference.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import time
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+ import torch.nn as nn
12
+ from collections import OrderedDict
13
+
14
+ from models.svc.base import SVCInference
15
+ from modules.encoder.condition_encoder import ConditionEncoder
16
+ from models.svc.transformer.transformer import Transformer
17
+ from models.svc.transformer.conformer import Conformer
18
+
19
+
20
+ class TransformerInference(SVCInference):
21
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
22
+ SVCInference.__init__(self, args, cfg, infer_type)
23
+
24
+ def _build_model(self):
25
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
26
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
27
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
28
+ if self.cfg.model.transformer.type == "transformer":
29
+ self.acoustic_mapper = Transformer(self.cfg.model.transformer)
30
+ elif self.cfg.model.transformer.type == "conformer":
31
+ self.acoustic_mapper = Conformer(self.cfg.model.transformer)
32
+ else:
33
+ raise NotImplementedError
34
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
35
+ return model
36
+
37
+ def _inference_each_batch(self, batch_data):
38
+ device = self.accelerator.device
39
+ for k, v in batch_data.items():
40
+ batch_data[k] = v.to(device)
41
+
42
+ condition = self.condition_encoder(batch_data)
43
+ y_pred = self.acoustic_mapper(condition, batch_data["mask"])
44
+
45
+ return y_pred
models/svc/transformer/transformer_trainer.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from models.svc.base import SVCTrainer
9
+ from modules.encoder.condition_encoder import ConditionEncoder
10
+ from models.svc.transformer.transformer import Transformer
11
+ from models.svc.transformer.conformer import Conformer
12
+ from utils.ssim import SSIM
13
+
14
+
15
+ class TransformerTrainer(SVCTrainer):
16
+ def __init__(self, args, cfg):
17
+ SVCTrainer.__init__(self, args, cfg)
18
+ self.ssim_loss = SSIM()
19
+
20
+ def _build_model(self):
21
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
22
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
23
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
24
+ if self.cfg.model.transformer.type == "transformer":
25
+ self.acoustic_mapper = Transformer(self.cfg.model.transformer)
26
+ elif self.cfg.model.transformer.type == "conformer":
27
+ self.acoustic_mapper = Conformer(self.cfg.model.transformer)
28
+ else:
29
+ raise NotImplementedError
30
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
31
+ return model
32
+
33
+ def _forward_step(self, batch):
34
+ total_loss = 0
35
+ device = self.accelerator.device
36
+ mel = batch["mel"]
37
+ mask = batch["mask"]
38
+
39
+ condition = self.condition_encoder(batch)
40
+ mel_pred = self.acoustic_mapper(condition, mask)
41
+
42
+ l1_loss = torch.sum(torch.abs(mel_pred - mel) * batch["mask"]) / torch.sum(
43
+ batch["mask"]
44
+ )
45
+ self._check_nan(l1_loss, mel_pred, mel)
46
+ total_loss += l1_loss
47
+ ssim_loss = self.ssim_loss(mel_pred, mel)
48
+ ssim_loss = torch.sum(ssim_loss * batch["mask"]) / torch.sum(batch["mask"])
49
+ self._check_nan(ssim_loss, mel_pred, mel)
50
+ total_loss += ssim_loss
51
+
52
+ return total_loss
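The objective above is a masked L1 reconstruction term plus a masked SSIM term, each averaged over non-padded positions (my notation; the SSIM term is whatever `utils.ssim.SSIM` returns per element, commonly 1 − SSIM):

```latex
\mathcal{L}
  = \frac{\sum \lvert \hat{y}-y \rvert \odot m}{\sum m}
  + \frac{\sum \mathrm{SSIM}(\hat{y}, y) \odot m}{\sum m}
```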
models/svc/vits/__init__.py ADDED
File without changes
models/svc/vits/vits.py ADDED
@@ -0,0 +1,271 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ # This code is modified from https://github.com/svc-develop-team/so-vits-svc/blob/4.1-Stable/models.py
7
+ import copy
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+
12
+ from utils.util import *
13
+ from utils.f0 import f0_to_coarse
14
+
15
+ from modules.transformer.attentions import Encoder
16
+ from models.tts.vits.vits import ResidualCouplingBlock, PosteriorEncoder
17
+ from models.vocoders.gan.generator.bigvgan import BigVGAN
18
+ from models.vocoders.gan.generator.hifigan import HiFiGAN
19
+ from models.vocoders.gan.generator.nsfhifigan import NSFHiFiGAN
20
+ from models.vocoders.gan.generator.melgan import MelGAN
21
+ from models.vocoders.gan.generator.apnet import APNet
22
+ from modules.encoder.condition_encoder import ConditionEncoder
23
+
24
+
25
+ def slice_pitch_segments(x, ids_str, segment_size=4):
26
+ ret = torch.zeros_like(x[:, :segment_size])
27
+ for i in range(x.size(0)):
28
+ idx_str = ids_str[i]
29
+ idx_end = idx_str + segment_size
30
+ ret[i] = x[i, idx_str:idx_end]
31
+ return ret
32
+
33
+
34
+ def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4):
35
+ b, d, t = x.size()
36
+ if x_lengths is None:
37
+ x_lengths = t
38
+ ids_str_max = x_lengths - segment_size + 1
39
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
40
+ ret = slice_segments(x, ids_str, segment_size)
41
+ ret_pitch = slice_pitch_segments(pitch, ids_str, segment_size)
42
+ return ret, ret_pitch, ids_str
43
+
44
+
45
+ class ContentEncoder(nn.Module):
46
+ def __init__(
47
+ self,
48
+ out_channels,
49
+ hidden_channels,
50
+ kernel_size,
51
+ n_layers,
52
+ gin_channels=0,
53
+ filter_channels=None,
54
+ n_heads=None,
55
+ p_dropout=None,
56
+ ):
57
+ super().__init__()
58
+ self.out_channels = out_channels
59
+ self.hidden_channels = hidden_channels
60
+ self.kernel_size = kernel_size
61
+ self.n_layers = n_layers
62
+ self.gin_channels = gin_channels
63
+
64
+ self.f0_emb = nn.Embedding(256, hidden_channels)
65
+
66
+ self.enc_ = Encoder(
67
+ hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
68
+ )
69
+
70
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
71
+
72
+ # condition_encoder ver.
73
+ def forward(self, x, x_mask, noise_scale=1):
74
+ x = self.enc_(x * x_mask, x_mask)
75
+ stats = self.proj(x) * x_mask
76
+ m, logs = torch.split(stats, self.out_channels, dim=1)
77
+ z = (m + torch.randn_like(m) * torch.exp(logs) * noise_scale) * x_mask
78
+
79
+ return z, m, logs, x_mask
80
+
81
+
82
+ class SynthesizerTrn(nn.Module):
83
+ """
84
+ Synthesizer for Training
85
+ """
86
+
87
+ def __init__(self, spec_channels, segment_size, cfg):
88
+ super().__init__()
89
+ self.spec_channels = spec_channels
90
+ self.segment_size = segment_size
91
+ self.cfg = cfg
92
+ self.inter_channels = cfg.model.vits.inter_channels
93
+ self.hidden_channels = cfg.model.vits.hidden_channels
94
+ self.filter_channels = cfg.model.vits.filter_channels
95
+ self.n_heads = cfg.model.vits.n_heads
96
+ self.n_layers = cfg.model.vits.n_layers
97
+ self.kernel_size = cfg.model.vits.kernel_size
98
+ self.p_dropout = cfg.model.vits.p_dropout
99
+ self.ssl_dim = cfg.model.vits.ssl_dim
100
+ self.n_flow_layer = cfg.model.vits.n_flow_layer
101
+ self.gin_channels = cfg.model.vits.gin_channels
102
+ self.n_speakers = cfg.model.vits.n_speakers
103
+
104
+ # f0
105
+ self.n_bins = cfg.preprocess.pitch_bin
106
+ self.f0_min = cfg.preprocess.f0_min
107
+ self.f0_max = cfg.preprocess.f0_max
108
+
109
+ # TODO: sort out the config
110
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
111
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
112
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
113
+
114
+ self.emb_g = nn.Embedding(self.n_speakers, self.gin_channels)
115
+
116
+ self.enc_p = ContentEncoder(
117
+ self.inter_channels,
118
+ self.hidden_channels,
119
+ filter_channels=self.filter_channels,
120
+ n_heads=self.n_heads,
121
+ n_layers=self.n_layers,
122
+ kernel_size=self.kernel_size,
123
+ p_dropout=self.p_dropout,
124
+ )
125
+
126
+ assert cfg.model.generator in [
127
+ "bigvgan",
128
+ "hifigan",
129
+ "melgan",
130
+ "nsfhifigan",
131
+ "apnet",
132
+ ]
133
+ self.dec_name = cfg.model.generator
134
+ temp_cfg = copy.deepcopy(cfg)
135
+ temp_cfg.preprocess.n_mel = self.inter_channels
136
+ if cfg.model.generator == "bigvgan":
137
+ temp_cfg.model.bigvgan = cfg.model.generator_config.bigvgan
138
+ self.dec = BigVGAN(temp_cfg)
139
+ elif cfg.model.generator == "hifigan":
140
+ temp_cfg.model.hifigan = cfg.model.generator_config.hifigan
141
+ self.dec = HiFiGAN(temp_cfg)
142
+ elif cfg.model.generator == "melgan":
143
+ temp_cfg.model.melgan = cfg.model.generator_config.melgan
144
+ self.dec = MelGAN(temp_cfg)
145
+ elif cfg.model.generator == "nsfhifigan":
146
+ temp_cfg.model.nsfhifigan = cfg.model.generator_config.nsfhifigan
147
+ self.dec = NSFHiFiGAN(temp_cfg) # TODO: nsf need f0
148
+ elif cfg.model.generator == "apnet":
149
+ temp_cfg.model.apnet = cfg.model.generator_config.apnet
150
+ self.dec = APNet(temp_cfg)
151
+
152
+ self.enc_q = PosteriorEncoder(
153
+ self.spec_channels,
154
+ self.inter_channels,
155
+ self.hidden_channels,
156
+ 5,
157
+ 1,
158
+ 16,
159
+ gin_channels=self.gin_channels,
160
+ )
161
+
162
+ self.flow = ResidualCouplingBlock(
163
+ self.inter_channels,
164
+ self.hidden_channels,
165
+ 5,
166
+ 1,
167
+ self.n_flow_layer,
168
+ gin_channels=self.gin_channels,
169
+ )
170
+
171
+ def forward(self, data):
172
+ """VitsSVC forward function.
173
+
174
+ Args:
175
+ data (dict): condition data & audio data, including:
176
+ B: batch size, T: target length
177
+ {
178
+ "spk_id": [B, singer_table_size]
179
+ "target_len": [B]
180
+ "mask": [B, T, 1]
181
+ "mel": [B, T, n_mel]
182
+ "linear": [B, T, n_fft // 2 + 1]
183
+ "frame_pitch": [B, T]
184
+ "frame_uv": [B, T]
185
+ "audio": [B, audio_len]
186
+ "audio_len": [B]
187
+ "contentvec_feat": [B, T, contentvec_dim]
188
+ "whisper_feat": [B, T, whisper_dim]
189
+ ...
190
+ }
191
+ """
192
+
193
+ # TODO: elegantly handle the dimensions
194
+ c = data["contentvec_feat"].transpose(1, 2)
195
+ spec = data["linear"].transpose(1, 2)
196
+
197
+ g = data["spk_id"]
198
+ g = self.emb_g(g).transpose(1, 2)
199
+
200
+ c_lengths = data["target_len"]
201
+ spec_lengths = data["target_len"]
202
+ f0 = data["frame_pitch"]
203
+
204
+ x_mask = torch.unsqueeze(sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
205
+ # condition_encoder ver.
206
+ x = self.condition_encoder(data).transpose(1, 2)
207
+
208
+ # prior encoder
209
+ z_ptemp, m_p, logs_p, _ = self.enc_p(x, x_mask)
210
+ # posterior encoder
211
+ z, m_q, logs_q, spec_mask = self.enc_q(spec, spec_lengths, g=g)
212
+
213
+ # flow
214
+ z_p = self.flow(z, spec_mask, g=g)
215
+ z_slice, pitch_slice, ids_slice = rand_slice_segments_with_pitch(
216
+ z, f0, spec_lengths, self.segment_size
217
+ )
218
+
219
+ if self.dec_name == "nsfhifigan":
220
+ o = self.dec(z_slice, f0=f0.float())
221
+ elif self.dec_name == "apnet":
222
+ _, _, _, _, o = self.dec(z_slice)
223
+ else:
224
+ o = self.dec(z_slice)
225
+
226
+ outputs = {
227
+ "y_hat": o,
228
+ "ids_slice": ids_slice,
229
+ "x_mask": x_mask,
230
+ "z_mask": data["mask"].transpose(1, 2),
231
+ "z": z,
232
+ "z_p": z_p,
233
+ "m_p": m_p,
234
+ "logs_p": logs_p,
235
+ "m_q": m_q,
236
+ "logs_q": logs_q,
237
+ }
238
+ return outputs
239
+
240
+ @torch.no_grad()
241
+ def infer(self, data, noise_scale=0.35, seed=52468):
242
+ # c, f0, uv, g
243
+ c = data["contentvec_feat"].transpose(1, 2)
244
+ f0 = data["frame_pitch"]
245
+ g = data["spk_id"]
246
+
247
+ if c.device == torch.device("cuda"):
248
+ torch.cuda.manual_seed_all(seed)
249
+ else:
250
+ torch.manual_seed(seed)
251
+
252
+ c_lengths = (torch.ones(c.size(0)) * c.size(-1)).to(c.device)
253
+
254
+ if g.dim() == 1:
255
+ g = g.unsqueeze(0)
256
+ g = self.emb_g(g).transpose(1, 2)
257
+
258
+ x_mask = torch.unsqueeze(sequence_mask(c_lengths, c.size(2)), 1).to(c.dtype)
259
+ # condition_encoder ver.
260
+ x = self.condition_encoder(data).transpose(1, 2)
261
+
262
+ z_p, m_p, logs_p, c_mask = self.enc_p(x, x_mask, noise_scale=noise_scale)
263
+ z = self.flow(z_p, c_mask, g=g, reverse=True)
264
+
265
+ if self.dec_name == "nsfhifigan":
266
+ o = self.dec(z * c_mask, f0=f0)
267
+ elif self.dec_name == "apnet":
268
+ _, _, _, _, o = self.dec(z * c_mask)
269
+ else:
270
+ o = self.dec(z * c_mask)
271
+ return o, f0
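During training, `forward` above decodes only a short random window of the latent (via `rand_slice_segments_with_pitch` at the top of this file) so the waveform generator fits in memory; the same start index is used for the latent and for f0 so they stay time-aligned. An illustrative restatement with plain tensors (all shapes hypothetical):

```python
import torch

B, C, T, seg = 2, 192, 400, 32
z = torch.randn(B, C, T)           # frame-level latent from the posterior encoder
f0 = torch.rand(B, T) * 400.0      # frame-level pitch track, aligned with z in time
lengths = torch.tensor([400, 350])

ids_str = (torch.rand(B) * (lengths - seg + 1).float()).long().tolist()
z_slice = torch.stack([z[i, :, s:s + seg] for i, s in enumerate(ids_str)])   # [B, C, seg]
f0_slice = torch.stack([f0[i, s:s + seg] for i, s in enumerate(ids_str)])    # [B, seg]
```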
models/svc/vits/vits_inference.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ import time
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+ import torch
12
+
13
+ from models.svc.base import SVCInference
14
+ from models.svc.vits.vits import SynthesizerTrn
15
+
16
+ from models.svc.base.svc_dataset import SVCTestDataset, SVCTestCollator
17
+ from utils.io import save_audio
18
+ from utils.audio_slicer import is_silence
19
+
20
+
21
+ class VitsInference(SVCInference):
22
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
23
+ SVCInference.__init__(self, args, cfg, infer_type)
24
+
25
+ def _build_model(self):
26
+ net_g = SynthesizerTrn(
27
+ self.cfg.preprocess.n_fft // 2 + 1,
28
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
29
+ self.cfg,
30
+ )
31
+ self.model = net_g
32
+ return net_g
33
+
34
+ def build_save_dir(self, dataset, speaker):
35
+ save_dir = os.path.join(
36
+ self.args.output_dir,
37
+ "svc_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
38
+ )
39
+ if dataset is not None:
40
+ save_dir = os.path.join(save_dir, "data_{}".format(dataset))
41
+ if speaker != -1:
42
+ save_dir = os.path.join(
43
+ save_dir,
44
+ "spk_{}".format(speaker),
45
+ )
46
+ os.makedirs(save_dir, exist_ok=True)
47
+ print("Saving to ", save_dir)
48
+ return save_dir
49
+
50
+ @torch.inference_mode()
51
+ def inference(self):
52
+ res = []
53
+ for i, batch in enumerate(self.test_dataloader):
54
+ pred_audio_list = self._inference_each_batch(batch)
55
+ for it, wav in zip(self.test_dataset.metadata, pred_audio_list):
56
+ uid = it["Uid"]
57
+ file = os.path.join(self.args.output_dir, f"{uid}.wav")
58
+
59
+ wav = wav.numpy(force=True)
60
+ save_audio(
61
+ file,
62
+ wav,
63
+ self.cfg.preprocess.sample_rate,
64
+ add_silence=False,
65
+ turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
66
+ )
67
+ res.append(file)
68
+ return res
69
+
70
+ def _inference_each_batch(self, batch_data, noise_scale=0.667):
71
+ device = self.accelerator.device
72
+ pred_res = []
73
+ self.model.eval()
74
+ with torch.no_grad():
75
+ # Put the data to device
76
+ # device = self.accelerator.device
77
+ for k, v in batch_data.items():
78
+ batch_data[k] = v.to(device)
79
+
80
+ audios, f0 = self.model.infer(batch_data, noise_scale=noise_scale)
81
+
82
+ pred_res.extend(audios)
83
+
84
+ return pred_res
models/svc/vits/vits_trainer.py ADDED
@@ -0,0 +1,483 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from torch.optim.lr_scheduler import ExponentialLR
8
+ from tqdm import tqdm
9
+
10
+ # from models.svc.base import SVCTrainer
11
+ from models.svc.base.svc_dataset import SVCCollator, SVCDataset
12
+ from models.svc.vits.vits import *
13
+ from models.tts.base import TTSTrainer
14
+
15
+ from utils.mel import mel_spectrogram_torch
16
+ import json
17
+
18
+ from models.vocoders.gan.discriminator.mpd import (
19
+ MultiPeriodDiscriminator_vits as MultiPeriodDiscriminator,
20
+ )
21
+
22
+
23
+ class VitsSVCTrainer(TTSTrainer):
24
+ def __init__(self, args, cfg):
25
+ self.args = args
26
+ self.cfg = cfg
27
+ self._init_accelerator()
28
+ # Only for SVC tasks
29
+ with self.accelerator.main_process_first():
30
+ self.singers = self._build_singer_lut()
31
+ TTSTrainer.__init__(self, args, cfg)
32
+
33
+ def _build_model(self):
34
+ net_g = SynthesizerTrn(
35
+ self.cfg.preprocess.n_fft // 2 + 1,
36
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
37
+ # directly use cfg
38
+ self.cfg,
39
+ )
40
+ net_d = MultiPeriodDiscriminator(self.cfg.model.vits.use_spectral_norm)
41
+ model = {"generator": net_g, "discriminator": net_d}
42
+
43
+ return model
44
+
45
+ def _build_dataset(self):
46
+ return SVCDataset, SVCCollator
47
+
48
+ def _build_optimizer(self):
49
+ optimizer_g = torch.optim.AdamW(
50
+ self.model["generator"].parameters(),
51
+ self.cfg.train.learning_rate,
52
+ betas=self.cfg.train.AdamW.betas,
53
+ eps=self.cfg.train.AdamW.eps,
54
+ )
55
+ optimizer_d = torch.optim.AdamW(
56
+ self.model["discriminator"].parameters(),
57
+ self.cfg.train.learning_rate,
58
+ betas=self.cfg.train.AdamW.betas,
59
+ eps=self.cfg.train.AdamW.eps,
60
+ )
61
+ optimizer = {"optimizer_g": optimizer_g, "optimizer_d": optimizer_d}
62
+
63
+ return optimizer
64
+
65
+ def _build_scheduler(self):
66
+ scheduler_g = ExponentialLR(
67
+ self.optimizer["optimizer_g"],
68
+ gamma=self.cfg.train.lr_decay,
69
+ last_epoch=self.epoch - 1,
70
+ )
71
+ scheduler_d = ExponentialLR(
72
+ self.optimizer["optimizer_d"],
73
+ gamma=self.cfg.train.lr_decay,
74
+ last_epoch=self.epoch - 1,
75
+ )
76
+
77
+ scheduler = {"scheduler_g": scheduler_g, "scheduler_d": scheduler_d}
78
+ return scheduler
79
+
80
+ def _build_criterion(self):
81
+ class GeneratorLoss(nn.Module):
82
+ def __init__(self, cfg):
83
+ super(GeneratorLoss, self).__init__()
84
+ self.cfg = cfg
85
+ self.l1_loss = nn.L1Loss()
86
+
87
+ def generator_loss(self, disc_outputs):
88
+ loss = 0
89
+ gen_losses = []
90
+ for dg in disc_outputs:
91
+ dg = dg.float()
92
+ l = torch.mean((1 - dg) ** 2)
93
+ gen_losses.append(l)
94
+ loss += l
95
+
96
+ return loss, gen_losses
97
+
98
+ def feature_loss(self, fmap_r, fmap_g):
99
+ loss = 0
100
+ for dr, dg in zip(fmap_r, fmap_g):
101
+ for rl, gl in zip(dr, dg):
102
+ rl = rl.float().detach()
103
+ gl = gl.float()
104
+ loss += torch.mean(torch.abs(rl - gl))
105
+
106
+ return loss * 2
107
+
108
+ def kl_loss(self, z_p, logs_q, m_p, logs_p, z_mask):
109
+ """
110
+ z_p, logs_q: [b, h, t_t]
111
+ m_p, logs_p: [b, h, t_t]
112
+ """
113
+ z_p = z_p.float()
114
+ logs_q = logs_q.float()
115
+ m_p = m_p.float()
116
+ logs_p = logs_p.float()
117
+ z_mask = z_mask.float()
118
+
119
+ kl = logs_p - logs_q - 0.5
120
+ kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
121
+ kl = torch.sum(kl * z_mask)
122
+ l = kl / torch.sum(z_mask)
123
+ return l
124
+
125
+ def forward(
126
+ self,
127
+ outputs_g,
128
+ outputs_d,
129
+ y_mel,
130
+ y_hat_mel,
131
+ ):
132
+ loss_g = {}
133
+
134
+ # mel loss
135
+ loss_mel = self.l1_loss(y_mel, y_hat_mel) * self.cfg.train.c_mel
136
+ loss_g["loss_mel"] = loss_mel
137
+
138
+ # kl loss
139
+ loss_kl = (
140
+ self.kl_loss(
141
+ outputs_g["z_p"],
142
+ outputs_g["logs_q"],
143
+ outputs_g["m_p"],
144
+ outputs_g["logs_p"],
145
+ outputs_g["z_mask"],
146
+ )
147
+ * self.cfg.train.c_kl
148
+ )
149
+ loss_g["loss_kl"] = loss_kl
150
+
151
+ # feature loss
152
+ loss_fm = self.feature_loss(outputs_d["fmap_rs"], outputs_d["fmap_gs"])
153
+ loss_g["loss_fm"] = loss_fm
154
+
155
+ # gan loss
156
+ loss_gen, losses_gen = self.generator_loss(outputs_d["y_d_hat_g"])
157
+ loss_g["loss_gen"] = loss_gen
158
+ loss_g["loss_gen_all"] = loss_mel + loss_kl + loss_fm + loss_gen
159
+
160
+ return loss_g
161
+
162
+ class DiscriminatorLoss(nn.Module):
163
+ def __init__(self, cfg):
164
+ super(DiscriminatorLoss, self).__init__()
165
+ self.cfg = cfg
166
+ self.l1Loss = torch.nn.L1Loss(reduction="mean")
167
+
168
+ def __call__(self, disc_real_outputs, disc_generated_outputs):
169
+ loss_d = {}
170
+
171
+ loss = 0
172
+ r_losses = []
173
+ g_losses = []
174
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
175
+ dr = dr.float()
176
+ dg = dg.float()
177
+ r_loss = torch.mean((1 - dr) ** 2)
178
+ g_loss = torch.mean(dg**2)
179
+ loss += r_loss + g_loss
180
+ r_losses.append(r_loss.item())
181
+ g_losses.append(g_loss.item())
182
+
183
+ loss_d["loss_disc_all"] = loss
184
+
185
+ return loss_d
186
+
187
+ criterion = {
188
+ "generator": GeneratorLoss(self.cfg),
189
+ "discriminator": DiscriminatorLoss(self.cfg),
190
+ }
191
+ return criterion
192
+
193
+ # Keep legacy unchanged
194
+ def write_summary(
195
+ self,
196
+ losses,
197
+ stats,
198
+ images={},
199
+ audios={},
200
+ audio_sampling_rate=24000,
201
+ tag="train",
202
+ ):
203
+ for key, value in losses.items():
204
+ self.sw.add_scalar(tag + "/" + key, value, self.step)
205
+ self.sw.add_scalar(
206
+ "learning_rate",
207
+ self.optimizer["optimizer_g"].param_groups[0]["lr"],
208
+ self.step,
209
+ )
210
+
211
+ if len(images) != 0:
212
+ for key, value in images.items():
213
+ self.sw.add_image(key, value, self.global_step, batchformats="HWC")
214
+ if len(audios) != 0:
215
+ for key, value in audios.items():
216
+ self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
217
+
218
+ def write_valid_summary(
219
+ self, losses, stats, images={}, audios={}, audio_sampling_rate=24000, tag="val"
220
+ ):
221
+ for key, value in losses.items():
222
+ self.sw.add_scalar(tag + "/" + key, value, self.step)
223
+
224
+ if len(images) != 0:
225
+ for key, value in images.items():
226
+ self.sw.add_image(key, value, self.global_step, batchformats="HWC")
227
+ if len(audios) != 0:
228
+ for key, value in audios.items():
229
+ self.sw.add_audio(key, value, self.global_step, audio_sampling_rate)
230
+
231
+ def _get_state_dict(self):
232
+ state_dict = {
233
+ "generator": self.model["generator"].state_dict(),
234
+ "discriminator": self.model["discriminator"].state_dict(),
235
+ "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
236
+ "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
237
+ "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
238
+ "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
239
+ "step": self.step,
240
+ "epoch": self.epoch,
241
+ "batch_size": self.cfg.train.batch_size,
242
+ }
243
+ return state_dict
244
+
245
+ def get_state_dict(self):
246
+ state_dict = {
247
+ "generator": self.model["generator"].state_dict(),
248
+ "discriminator": self.model["discriminator"].state_dict(),
249
+ "optimizer_g": self.optimizer["optimizer_g"].state_dict(),
250
+ "optimizer_d": self.optimizer["optimizer_d"].state_dict(),
251
+ "scheduler_g": self.scheduler["scheduler_g"].state_dict(),
252
+ "scheduler_d": self.scheduler["scheduler_d"].state_dict(),
253
+ "step": self.step,
254
+ "epoch": self.epoch,
255
+ "batch_size": self.cfg.train.batch_size,
256
+ }
257
+ return state_dict
258
+
259
+ def load_model(self, checkpoint):
260
+ self.step = checkpoint["step"]
261
+ self.epoch = checkpoint["epoch"]
262
+ self.model["generator"].load_state_dict(checkpoint["generator"])
263
+ self.model["discriminator"].load_state_dict(checkpoint["discriminator"])
264
+ self.optimizer["optimizer_g"].load_state_dict(checkpoint["optimizer_g"])
265
+ self.optimizer["optimizer_d"].load_state_dict(checkpoint["optimizer_d"])
266
+ self.scheduler["scheduler_g"].load_state_dict(checkpoint["scheduler_g"])
267
+ self.scheduler["scheduler_d"].load_state_dict(checkpoint["scheduler_d"])
268
+
269
+ @torch.inference_mode()
270
+ def _valid_step(self, batch):
271
+ r"""Testing forward step. Should return average loss of a sample over
272
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
273
+ See ``_test_epoch`` for usage.
274
+ """
275
+
276
+ valid_losses = {}
277
+ total_loss = 0
278
+ valid_stats = {}
279
+
280
+ # Discriminator
281
+ # Generator output
282
+ outputs_g = self.model["generator"](batch)
283
+
284
+ y_mel = slice_segments(
285
+ batch["mel"].transpose(1, 2),
286
+ outputs_g["ids_slice"],
287
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
288
+ )
289
+ y_hat_mel = mel_spectrogram_torch(
290
+ outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
291
+ )
292
+ y = slice_segments(
293
+ batch["audio"].unsqueeze(1),
294
+ outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
295
+ self.cfg.preprocess.segment_size,
296
+ )
297
+
298
+ # Discriminator output
299
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
300
+ ## Discriminator loss
301
+ loss_d = self.criterion["discriminator"](
302
+ outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
303
+ )
304
+ valid_losses.update(loss_d)
305
+
306
+ ## Generator
307
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
308
+ loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
309
+ valid_losses.update(loss_g)
310
+
311
+ for item in valid_losses:
312
+ valid_losses[item] = valid_losses[item].item()
313
+
314
+ total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
315
+
316
+ return (
317
+ total_loss.item(),
318
+ valid_losses,
319
+ valid_stats,
320
+ )
321
+
322
+ def _train_step(self, batch):
323
+ r"""Forward step for training and inference. This function is called
324
+ in ``_train_step`` & ``_test_step`` function.
325
+ """
326
+
327
+ train_losses = {}
328
+ total_loss = 0
329
+ training_stats = {}
330
+
331
+ ## Train Discriminator
332
+ # Generator output
333
+ outputs_g = self.model["generator"](batch)
334
+
335
+ y_mel = slice_segments(
336
+ batch["mel"].transpose(1, 2),
337
+ outputs_g["ids_slice"],
338
+ self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
339
+ )
340
+ y_hat_mel = mel_spectrogram_torch(
341
+ outputs_g["y_hat"].squeeze(1), self.cfg.preprocess
342
+ )
343
+
344
+ y = slice_segments(
345
+ # [1, 168418] -> [1, 1, 168418]
346
+ batch["audio"].unsqueeze(1),
347
+ outputs_g["ids_slice"] * self.cfg.preprocess.hop_size,
348
+ self.cfg.preprocess.segment_size,
349
+ )
350
+
351
+ # Discriminator output
352
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"].detach())
353
+ # Discriminator loss
354
+ loss_d = self.criterion["discriminator"](
355
+ outputs_d["y_d_hat_r"], outputs_d["y_d_hat_g"]
356
+ )
357
+ train_losses.update(loss_d)
358
+
359
+ # BP and Grad Updated
360
+ self.optimizer["optimizer_d"].zero_grad()
361
+ self.accelerator.backward(loss_d["loss_disc_all"])
362
+ self.optimizer["optimizer_d"].step()
363
+
364
+ ## Train Generator
365
+ outputs_d = self.model["discriminator"](y, outputs_g["y_hat"])
366
+ loss_g = self.criterion["generator"](outputs_g, outputs_d, y_mel, y_hat_mel)
367
+ train_losses.update(loss_g)
368
+
369
+ # BP and Grad Updated
370
+ self.optimizer["optimizer_g"].zero_grad()
371
+ self.accelerator.backward(loss_g["loss_gen_all"])
372
+ self.optimizer["optimizer_g"].step()
373
+
374
+ for item in train_losses:
375
+ train_losses[item] = train_losses[item].item()
376
+
377
+ total_loss = loss_g["loss_gen_all"] + loss_d["loss_disc_all"]
378
+
379
+ return (
380
+ total_loss.item(),
381
+ train_losses,
382
+ training_stats,
383
+ )
384
+
385
+ def _train_epoch(self):
386
+ r"""Training epoch. Should return average loss of a batch (sample) over
387
+ one epoch. See ``train_loop`` for usage.
388
+ """
389
+ epoch_sum_loss: float = 0.0
390
+ epoch_losses: dict = {}
391
+ epoch_step: int = 0
392
+ for batch in tqdm(
393
+ self.train_dataloader,
394
+ desc=f"Training Epoch {self.epoch}",
395
+ unit="batch",
396
+ colour="GREEN",
397
+ leave=False,
398
+ dynamic_ncols=True,
399
+ smoothing=0.04,
400
+ disable=not self.accelerator.is_main_process,
401
+ ):
402
+ # Do training step and BP
403
+ with self.accelerator.accumulate(self.model):
404
+ total_loss, train_losses, training_stats = self._train_step(batch)
405
+ self.batch_count += 1
406
+
407
+ # Update info for each step
408
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
409
+ epoch_sum_loss += total_loss
410
+ for key, value in train_losses.items():
411
+ if key not in epoch_losses.keys():
412
+ epoch_losses[key] = value
413
+ else:
414
+ epoch_losses[key] += value
415
+
416
+ self.accelerator.log(
417
+ {
418
+ "Step/Generator Loss": train_losses["loss_gen_all"],
419
+ "Step/Discriminator Loss": train_losses["loss_disc_all"],
420
+ "Step/Generator Learning Rate": self.optimizer[
421
+ "optimizer_d"
422
+ ].param_groups[0]["lr"],
423
+ "Step/Discriminator Learning Rate": self.optimizer[
424
+ "optimizer_g"
425
+ ].param_groups[0]["lr"],
426
+ },
427
+ step=self.step,
428
+ )
429
+ self.step += 1
430
+ epoch_step += 1
431
+
432
+ self.accelerator.wait_for_everyone()
433
+
434
+ epoch_sum_loss = (
435
+ epoch_sum_loss
436
+ / len(self.train_dataloader)
437
+ * self.cfg.train.gradient_accumulation_step
438
+ )
439
+
440
+ for key in epoch_losses.keys():
441
+ epoch_losses[key] = (
442
+ epoch_losses[key]
443
+ / len(self.train_dataloader)
444
+ * self.cfg.train.gradient_accumulation_step
445
+ )
446
+
447
+ return epoch_sum_loss, epoch_losses
448
+
449
+ def _build_singer_lut(self):
450
+ resumed_singer_path = None
451
+ if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "":
452
+ resumed_singer_path = os.path.join(
453
+ self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id
454
+ )
455
+ if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
456
+ resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
457
+
458
+ if resumed_singer_path:
459
+ with open(resumed_singer_path, "r") as f:
460
+ singers = json.load(f)
461
+ else:
462
+ singers = dict()
463
+
464
+ for dataset in self.cfg.dataset:
465
+ singer_lut_path = os.path.join(
466
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
467
+ )
468
+ with open(singer_lut_path, "r") as singer_lut_path:
469
+ singer_lut = json.load(singer_lut_path)
470
+ for singer in singer_lut.keys():
471
+ if singer not in singers:
472
+ singers[singer] = len(singers)
473
+
474
+ with open(
475
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
476
+ ) as singer_file:
477
+ json.dump(singers, singer_file, indent=4, ensure_ascii=False)
478
+ print(
479
+ "singers have been dumped to {}".format(
480
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
481
+ )
482
+ )
483
+ return singers
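The discriminator-then-generator ordering in `_train_step` above is the standard two-optimizer GAN recipe. The sketch below restates it with placeholder names (`real_audio`, `d_loss_fn`, `g_loss_fn` are stand-ins, not symbols from this commit):

def gan_step(generator, discriminator, opt_g, opt_d, d_loss_fn, g_loss_fn, batch, real_audio):
    outputs_g = generator(batch)                      # e.g. SynthesizerTrn forward

    # 1) Discriminator update: real audio vs. detached generated audio.
    outputs_d = discriminator(real_audio, outputs_g["y_hat"].detach())
    loss_d = d_loss_fn(outputs_d)
    opt_d.zero_grad(); loss_d.backward(); opt_d.step()

    # 2) Generator update: fresh discriminator pass so gradients reach the generator.
    outputs_d = discriminator(real_audio, outputs_g["y_hat"])
    loss_g = g_loss_fn(outputs_g, outputs_d)
    opt_g.zero_grad(); loss_g.backward(); opt_g.step()
    return loss_d.item(), loss_g.item()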
models/tta/autoencoder/__init__.py ADDED
File without changes
models/tta/autoencoder/autoencoder.py ADDED
@@ -0,0 +1,405 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from modules.distributions.distributions import DiagonalGaussianDistribution
11
+
12
+
13
+ def nonlinearity(x):
14
+ # swish
15
+ return x * torch.sigmoid(x)
16
+
17
+
18
+ def Normalize(in_channels):
19
+ return torch.nn.GroupNorm(
20
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
21
+ )
22
+
23
+
24
+ class Upsample2d(nn.Module):
25
+ def __init__(self, in_channels, with_conv):
26
+ super().__init__()
27
+ self.with_conv = with_conv
28
+ if self.with_conv:
29
+ self.conv = torch.nn.Conv2d(
30
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
31
+ )
32
+
33
+ def forward(self, x):
34
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
35
+ if self.with_conv:
36
+ x = self.conv(x)
37
+ return x
38
+
39
+
40
+ class Upsample1d(Upsample2d):
41
+ def __init__(self, in_channels, with_conv):
42
+ super().__init__(in_channels, with_conv)
43
+ if self.with_conv:
44
+ self.conv = torch.nn.Conv1d(
45
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
46
+ )
47
+
48
+
49
+ class Downsample2d(nn.Module):
50
+ def __init__(self, in_channels, with_conv):
51
+ super().__init__()
52
+ self.with_conv = with_conv
53
+ if self.with_conv:
54
+ # no asymmetric padding in torch conv, must do it ourselves
55
+ self.conv = torch.nn.Conv2d(
56
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
57
+ )
58
+ self.pad = (0, 1, 0, 1)
59
+ else:
60
+ self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
61
+
62
+ def forward(self, x):
63
+ if self.with_conv: # bp: check self.avgpool and self.pad
64
+ x = torch.nn.functional.pad(x, self.pad, mode="constant", value=0)
65
+ x = self.conv(x)
66
+ else:
67
+ x = self.avg_pool(x)
68
+ return x
69
+
70
+
71
+ class Downsample1d(Downsample2d):
72
+ def __init__(self, in_channels, with_conv):
73
+ super().__init__(in_channels, with_conv)
74
+ if self.with_conv:
75
+ # no asymmetric padding in torch conv, must do it ourselves
76
+ # TODO: can we replace it just with conv2d with padding 1?
77
+ self.conv = torch.nn.Conv1d(
78
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
79
+ )
80
+ self.pad = (1, 1)
81
+ else:
82
+ self.avg_pool = nn.AvgPool1d(kernel_size=2, stride=2)
83
+
84
+
85
+ class ResnetBlock(nn.Module):
86
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
87
+ super().__init__()
88
+ self.in_channels = in_channels
89
+ out_channels = in_channels if out_channels is None else out_channels
90
+ self.out_channels = out_channels
91
+ self.use_conv_shortcut = conv_shortcut
92
+
93
+ self.norm1 = Normalize(in_channels)
94
+ self.conv1 = torch.nn.Conv2d(
95
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
96
+ )
97
+
98
+ self.norm2 = Normalize(out_channels)
99
+ self.dropout = torch.nn.Dropout(dropout)
100
+ self.conv2 = torch.nn.Conv2d(
101
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
102
+ )
103
+ if self.in_channels != self.out_channels:
104
+ if self.use_conv_shortcut:
105
+ self.conv_shortcut = torch.nn.Conv2d(
106
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
107
+ )
108
+ else:
109
+ self.nin_shortcut = torch.nn.Conv2d(
110
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
111
+ )
112
+
113
+ def forward(self, x):
114
+ h = x
115
+ h = self.norm1(h)
116
+ h = nonlinearity(h)
117
+ h = self.conv1(h)
118
+
119
+ h = self.norm2(h)
120
+ h = nonlinearity(h)
121
+ h = self.dropout(h)
122
+ h = self.conv2(h)
123
+
124
+ if self.in_channels != self.out_channels:
125
+ if self.use_conv_shortcut:
126
+ x = self.conv_shortcut(x)
127
+ else:
128
+ x = self.nin_shortcut(x)
129
+
130
+ return x + h
131
+
132
+
133
+ class ResnetBlock1d(ResnetBlock):
134
+ def __init__(
135
+ self,
136
+ *,
137
+ in_channels,
138
+ out_channels=None,
139
+ conv_shortcut=False,
140
+ dropout,
141
+ temb_channels=512
142
+ ):
143
+ super().__init__(
144
+ in_channels=in_channels,
145
+ out_channels=out_channels,
146
+ conv_shortcut=conv_shortcut,
147
+ dropout=dropout,
148
+ )
149
+
150
+ self.conv1 = torch.nn.Conv1d(
151
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
152
+ )
153
+ self.conv2 = torch.nn.Conv1d(
154
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
155
+ )
156
+ if self.in_channels != self.out_channels:
157
+ if self.use_conv_shortcut:
158
+ self.conv_shortcut = torch.nn.Conv1d(
159
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
160
+ )
161
+ else:
162
+ self.nin_shortcut = torch.nn.Conv1d(
163
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
164
+ )
165
+
166
+
167
+ class Encoder2d(nn.Module):
168
+ def __init__(
169
+ self,
170
+ *,
171
+ ch,
172
+ ch_mult=(1, 2, 4, 8),
173
+ num_res_blocks,
174
+ dropout=0.0,
175
+ resamp_with_conv=True,
176
+ in_channels,
177
+ z_channels,
178
+ double_z=True,
179
+ **ignore_kwargs
180
+ ):
181
+ super().__init__()
182
+ self.ch = ch
183
+ self.num_resolutions = len(ch_mult)
184
+ self.num_res_blocks = num_res_blocks
185
+ self.in_channels = in_channels
186
+
187
+ # downsampling
188
+ self.conv_in = torch.nn.Conv2d(
189
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
190
+ )
191
+
192
+ in_ch_mult = (1,) + tuple(ch_mult)
193
+ self.down = nn.ModuleList()
194
+ for i_level in range(self.num_resolutions):
195
+ block = nn.ModuleList()
196
+ block_in = ch * in_ch_mult[i_level]
197
+ block_out = ch * ch_mult[i_level]
198
+ for i_block in range(self.num_res_blocks):
199
+ block.append(
200
+ ResnetBlock(
201
+ in_channels=block_in, out_channels=block_out, dropout=dropout
202
+ )
203
+ )
204
+ block_in = block_out
205
+ down = nn.Module()
206
+ down.block = block
207
+ if i_level != self.num_resolutions - 1:
208
+ down.downsample = Downsample2d(block_in, resamp_with_conv)
209
+ self.down.append(down)
210
+
211
+ # middle
212
+ self.mid = nn.Module()
213
+ self.mid.block_1 = ResnetBlock(
214
+ in_channels=block_in, out_channels=block_in, dropout=dropout
215
+ )
216
+ self.mid.block_2 = ResnetBlock(
217
+ in_channels=block_in, out_channels=block_in, dropout=dropout
218
+ )
219
+
220
+ # end
221
+ self.norm_out = Normalize(block_in)
222
+ self.conv_out = torch.nn.Conv2d(
223
+ block_in,
224
+ 2 * z_channels if double_z else z_channels,
225
+ kernel_size=3,
226
+ stride=1,
227
+ padding=1,
228
+ )
229
+
230
+ def forward(self, x):
231
+ # downsampling
232
+ hs = [self.conv_in(x)]
233
+ for i_level in range(self.num_resolutions):
234
+ for i_block in range(self.num_res_blocks):
235
+ h = self.down[i_level].block[i_block](hs[-1])
236
+ hs.append(h)
237
+ if i_level != self.num_resolutions - 1:
238
+ hs.append(self.down[i_level].downsample(hs[-1]))
239
+
240
+ # middle
241
+ h = hs[-1]
242
+ h = self.mid.block_1(h)
243
+ h = self.mid.block_2(h)
244
+
245
+ # end
246
+ h = self.norm_out(h)
247
+ h = nonlinearity(h)
248
+ h = self.conv_out(h)
249
+ return h
250
+
251
+
252
+ # TODO: Encoder1d
253
+ class Encoder1d(Encoder2d):
254
+ ...
255
+
256
+
257
+ class Decoder2d(nn.Module):
258
+ def __init__(
259
+ self,
260
+ *,
261
+ ch,
262
+ out_ch,
263
+ ch_mult=(1, 2, 4, 8),
264
+ num_res_blocks,
265
+ dropout=0.0,
266
+ resamp_with_conv=True,
267
+ in_channels,
268
+ z_channels,
269
+ give_pre_end=False,
270
+ **ignorekwargs
271
+ ):
272
+ super().__init__()
273
+ self.ch = ch
274
+ self.num_resolutions = len(ch_mult)
275
+ self.num_res_blocks = num_res_blocks
276
+ self.in_channels = in_channels
277
+ self.give_pre_end = give_pre_end
278
+
279
+ # compute in_ch_mult, block_in and curr_res at lowest res
280
+ in_ch_mult = (1,) + tuple(ch_mult)
281
+ block_in = ch * ch_mult[self.num_resolutions - 1]
282
+ # self.z_shape = (1,z_channels,curr_res,curr_res)
283
+ # print("Working with z of shape {} = {} dimensions.".format(
284
+ # self.z_shape, np.prod(self.z_shape)))
285
+
286
+ # z to block_in
287
+ self.conv_in = torch.nn.Conv2d(
288
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
289
+ )
290
+
291
+ # middle
292
+ self.mid = nn.Module()
293
+ self.mid.block_1 = ResnetBlock(
294
+ in_channels=block_in, out_channels=block_in, dropout=dropout
295
+ )
296
+ self.mid.block_2 = ResnetBlock(
297
+ in_channels=block_in, out_channels=block_in, dropout=dropout
298
+ )
299
+
300
+ # upsampling
301
+ self.up = nn.ModuleList()
302
+ for i_level in reversed(range(self.num_resolutions)):
303
+ block = nn.ModuleList()
304
+ attn = nn.ModuleList()
305
+ block_out = ch * ch_mult[i_level]
306
+ for i_block in range(self.num_res_blocks + 1):
307
+ block.append(
308
+ ResnetBlock(
309
+ in_channels=block_in, out_channels=block_out, dropout=dropout
310
+ )
311
+ )
312
+ block_in = block_out
313
+ up = nn.Module()
314
+ up.block = block
315
+ up.attn = attn
316
+ if i_level != 0:
317
+ up.upsample = Upsample2d(block_in, resamp_with_conv)
318
+ self.up.insert(0, up) # prepend to get consistent order
319
+
320
+ # end
321
+ self.norm_out = Normalize(block_in)
322
+ self.conv_out = torch.nn.Conv2d(
323
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
324
+ )
325
+
326
+ def forward(self, z):
327
+ self.last_z_shape = z.shape
328
+
329
+ # z to block_in
330
+ h = self.conv_in(z)
331
+
332
+ # middle
333
+ h = self.mid.block_1(h)
334
+ h = self.mid.block_2(h)
335
+
336
+ # upsampling
337
+ for i_level in reversed(range(self.num_resolutions)):
338
+ for i_block in range(self.num_res_blocks + 1):
339
+ h = self.up[i_level].block[i_block](h)
340
+ if i_level != 0:
341
+ h = self.up[i_level].upsample(h)
342
+
343
+ # end
344
+ if self.give_pre_end:
345
+ return h
346
+
347
+ h = self.norm_out(h)
348
+ h = nonlinearity(h)
349
+ h = self.conv_out(h)
350
+ return h
351
+
352
+
353
+ # TODO: decoder1d
354
+ class Decoder1d(Decoder2d):
355
+ ...
356
+
357
+
358
+ class AutoencoderKL(nn.Module):
359
+ def __init__(self, cfg):
360
+ super().__init__()
361
+ self.cfg = cfg
362
+ self.encoder = Encoder2d(
363
+ ch=cfg.ch,
364
+ ch_mult=cfg.ch_mult,
365
+ num_res_blocks=cfg.num_res_blocks,
366
+ in_channels=cfg.in_channels,
367
+ z_channels=cfg.z_channels,
368
+ double_z=cfg.double_z,
369
+ )
370
+ self.decoder = Decoder2d(
371
+ ch=cfg.ch,
372
+ ch_mult=cfg.ch_mult,
373
+ num_res_blocks=cfg.num_res_blocks,
374
+ out_ch=cfg.out_ch,
375
+ z_channels=cfg.z_channels,
376
+ in_channels=None,
377
+ )
378
+ assert self.cfg.double_z
379
+
380
+ self.quant_conv = torch.nn.Conv2d(2 * cfg.z_channels, 2 * cfg.z_channels, 1)
381
+ self.post_quant_conv = torch.nn.Conv2d(cfg.z_channels, cfg.z_channels, 1)
382
+ self.embed_dim = cfg.z_channels
383
+
384
+ def encode(self, x):
385
+ h = self.encoder(x)
386
+ moments = self.quant_conv(h)
387
+ posterior = DiagonalGaussianDistribution(moments)
388
+ return posterior
389
+
390
+ def decode(self, z):
391
+ z = self.post_quant_conv(z)
392
+ dec = self.decoder(z)
393
+ return dec
394
+
395
+ def forward(self, input, sample_posterior=True):
396
+ posterior = self.encode(input)
397
+ if sample_posterior:
398
+ z = posterior.sample()
399
+ else:
400
+ z = posterior.mode()
401
+ dec = self.decode(z)
402
+ return dec, posterior
403
+
404
+ def get_last_layer(self):
405
+ return self.decoder.conv_out.weight
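A round-trip sketch for AutoencoderKL (not part of the commit). The config fields mirror the constructor arguments used above; the values are illustrative, the input follows the (B, 1, n_mels, T) layout used by the trainer, and `modules.distributions` is assumed available as imported at the top of the file.

import torch
from types import SimpleNamespace
from models.tta.autoencoder.autoencoder import AutoencoderKL

cfg = SimpleNamespace(
    ch=64, ch_mult=(1, 2, 4), num_res_blocks=2,
    in_channels=1, out_ch=1, z_channels=4, double_z=True,
)
vae = AutoencoderKL(cfg)
mel = torch.randn(2, 1, 80, 624)       # (B, 1, n_mels, T); 80 and 624 are divisible by 4
recon, posterior = vae(mel)            # recon: (2, 1, 80, 624)
z = posterior.sample()                 # latent: (2, 4, 20, 156), spatial size / 2**(len(ch_mult)-1)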
models/tta/autoencoder/autoencoder_dataset.py ADDED
@@ -0,0 +1,114 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import torch
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ from utils.data_utils import *
10
+ from models.base.base_dataset import (
11
+ BaseCollator,
12
+ BaseDataset,
13
+ BaseTestDataset,
14
+ BaseTestCollator,
15
+ )
16
+ import librosa
17
+
18
+
19
+ class AutoencoderKLDataset(BaseDataset):
20
+ def __init__(self, cfg, dataset, is_valid=False):
21
+ BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid)
22
+
23
+ cfg = self.cfg
24
+
25
+ # utt2melspec
26
+ if cfg.preprocess.use_melspec:
27
+ self.utt2melspec_path = {}
28
+ for utt_info in self.metadata:
29
+ dataset = utt_info["Dataset"]
30
+ uid = utt_info["Uid"]
31
+ utt = "{}_{}".format(dataset, uid)
32
+
33
+ self.utt2melspec_path[utt] = os.path.join(
34
+ cfg.preprocess.processed_dir,
35
+ dataset,
36
+ cfg.preprocess.melspec_dir,
37
+ uid + ".npy",
38
+ )
39
+
40
+ # utt2wav
41
+ if cfg.preprocess.use_wav:
42
+ self.utt2wav_path = {}
43
+ for utt_info in self.metadata:
44
+ dataset = utt_info["Dataset"]
45
+ uid = utt_info["Uid"]
46
+ utt = "{}_{}".format(dataset, uid)
47
+
48
+ self.utt2wav_path[utt] = os.path.join(
49
+ cfg.preprocess.processed_dir,
50
+ dataset,
51
+ cfg.preprocess.wav_dir,
52
+ uid + ".wav",
53
+ )
54
+
55
+ def __getitem__(self, index):
56
+ # melspec: (n_mels, T)
57
+ # wav: (T,)
58
+
59
+ single_feature = BaseDataset.__getitem__(self, index)
60
+
61
+ utt_info = self.metadata[index]
62
+ dataset = utt_info["Dataset"]
63
+ uid = utt_info["Uid"]
64
+ utt = "{}_{}".format(dataset, uid)
65
+
66
+ if self.cfg.preprocess.use_melspec:
67
+ single_feature["melspec"] = np.load(self.utt2melspec_path[utt])
68
+
69
+ if self.cfg.preprocess.use_wav:
70
+ wav, sr = librosa.load(
71
+ self.utt2wav_path[utt], sr=16000
72
+ ) # hard coding for 16KHz...
73
+ single_feature["wav"] = wav
74
+
75
+ return single_feature
76
+
77
+ def __len__(self):
78
+ return len(self.metadata)
82
+
83
+
84
+ class AutoencoderKLCollator(BaseCollator):
85
+ def __init__(self, cfg):
86
+ BaseCollator.__init__(self, cfg)
87
+
88
+ def __call__(self, batch):
89
+ # mel: (B, n_mels, T)
90
+ # wav (option): (B, T)
91
+
92
+ packed_batch_features = dict()
93
+
94
+ for key in batch[0].keys():
95
+ if key == "melspec":
96
+ packed_batch_features["melspec"] = torch.from_numpy(
97
+ np.array([b["melspec"][:, :624] for b in batch])
98
+ )
99
+
100
+ if key == "wav":
101
+ values = [torch.from_numpy(b[key]) for b in batch]
102
+ packed_batch_features[key] = pad_sequence(
103
+ values, batch_first=True, padding_value=0
104
+ )
105
+
106
+ return packed_batch_features
107
+
108
+
109
+ class AutoencoderKLTestDataset(BaseTestDataset):
110
+ ...
111
+
112
+
113
+ class AutoencoderKLTestCollator(BaseTestCollator):
114
+ ...
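For reference, the cropping and padding performed by `AutoencoderKLCollator.__call__` can be reproduced in isolation (illustrative tensors, not part of the commit):

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

batch = [
    {"melspec": np.random.randn(80, 700).astype(np.float32), "wav": np.zeros(16000, dtype=np.float32)},
    {"melspec": np.random.randn(80, 650).astype(np.float32), "wav": np.zeros(12000, dtype=np.float32)},
]
mels = torch.from_numpy(np.array([b["melspec"][:, :624] for b in batch]))     # (2, 80, 624)
wavs = pad_sequence([torch.from_numpy(b["wav"]) for b in batch],
                    batch_first=True, padding_value=0)                        # (2, 16000)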
models/tta/autoencoder/autoencoder_loss.py ADDED
@@ -0,0 +1,305 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import functools
9
+ import torch.nn.functional as F
10
+
11
+
12
+ def hinge_d_loss(logits_real, logits_fake):
13
+ loss_real = torch.mean(F.relu(1.0 - logits_real))
14
+ loss_fake = torch.mean(F.relu(1.0 + logits_fake))
15
+ d_loss = 0.5 * (loss_real + loss_fake)
16
+ return d_loss
17
+
18
+
19
+ def vanilla_d_loss(logits_real, logits_fake):
20
+ d_loss = 0.5 * (
21
+ torch.mean(F.softplus(-logits_real)) + torch.mean(F.softplus(logits_fake))
22
+ )
23
+ return d_loss
24
+
25
+
26
+ def adopt_weight(weight, global_step, threshold=0, value=0.0):
27
+ if global_step < threshold:
28
+ weight = value
29
+ return weight
30
+
31
+
32
+ class ActNorm(nn.Module):
33
+ def __init__(
34
+ self, num_features, logdet=False, affine=True, allow_reverse_init=False
35
+ ):
36
+ assert affine
37
+ super().__init__()
38
+ self.logdet = logdet
39
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
40
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
41
+ self.allow_reverse_init = allow_reverse_init
42
+
43
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
44
+
45
+ def initialize(self, input):
46
+ with torch.no_grad():
47
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
48
+ mean = (
49
+ flatten.mean(1)
50
+ .unsqueeze(1)
51
+ .unsqueeze(2)
52
+ .unsqueeze(3)
53
+ .permute(1, 0, 2, 3)
54
+ )
55
+ std = (
56
+ flatten.std(1)
57
+ .unsqueeze(1)
58
+ .unsqueeze(2)
59
+ .unsqueeze(3)
60
+ .permute(1, 0, 2, 3)
61
+ )
62
+
63
+ self.loc.data.copy_(-mean)
64
+ self.scale.data.copy_(1 / (std + 1e-6))
65
+
66
+ def forward(self, input, reverse=False):
67
+ if reverse:
68
+ return self.reverse(input)
69
+ if len(input.shape) == 2:
70
+ input = input[:, :, None, None]
71
+ squeeze = True
72
+ else:
73
+ squeeze = False
74
+
75
+ _, _, height, width = input.shape
76
+
77
+ if self.training and self.initialized.item() == 0:
78
+ self.initialize(input)
79
+ self.initialized.fill_(1)
80
+
81
+ h = self.scale * (input + self.loc)
82
+
83
+ if squeeze:
84
+ h = h.squeeze(-1).squeeze(-1)
85
+
86
+ if self.logdet:
87
+ log_abs = torch.log(torch.abs(self.scale))
88
+ logdet = height * width * torch.sum(log_abs)
89
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
90
+ return h, logdet
91
+
92
+ return h
93
+
94
+ def reverse(self, output):
95
+ if self.training and self.initialized.item() == 0:
96
+ if not self.allow_reverse_init:
97
+ raise RuntimeError(
98
+ "Initializing ActNorm in reverse direction is "
99
+ "disabled by default. Use allow_reverse_init=True to enable."
100
+ )
101
+ else:
102
+ self.initialize(output)
103
+ self.initialized.fill_(1)
104
+
105
+ if len(output.shape) == 2:
106
+ output = output[:, :, None, None]
107
+ squeeze = True
108
+ else:
109
+ squeeze = False
110
+
111
+ h = output / self.scale - self.loc
112
+
113
+ if squeeze:
114
+ h = h.squeeze(-1).squeeze(-1)
115
+ return h
116
+
117
+
118
+ def weights_init(m):
119
+ classname = m.__class__.__name__
120
+ if classname.find("Conv") != -1:
121
+ nn.init.normal_(m.weight.data, 0.0, 0.02)
122
+ elif classname.find("BatchNorm") != -1:
123
+ nn.init.normal_(m.weight.data, 1.0, 0.02)
124
+ nn.init.constant_(m.bias.data, 0)
125
+
126
+
127
+ class NLayerDiscriminator(nn.Module):
128
+ """Defines a PatchGAN discriminator as in Pix2Pix
129
+ --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
130
+ """
131
+
132
+ def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
133
+ """Construct a PatchGAN discriminator
134
+ Parameters:
135
+ input_nc (int) -- the number of channels in input images
136
+ ndf (int) -- the number of filters in the last conv layer
137
+ n_layers (int) -- the number of conv layers in the discriminator
138
+ use_actnorm (bool) -- use ActNorm instead of BatchNorm2d as the normalization layer
139
+ """
140
+ super(NLayerDiscriminator, self).__init__()
141
+ if not use_actnorm:
142
+ norm_layer = nn.BatchNorm2d
143
+ else:
144
+ norm_layer = ActNorm
145
+ if (
146
+ type(norm_layer) == functools.partial
147
+ ): # no need to use bias as BatchNorm2d has affine parameters
148
+ use_bias = norm_layer.func != nn.BatchNorm2d
149
+ else:
150
+ use_bias = norm_layer != nn.BatchNorm2d
151
+
152
+ kw = 4
153
+ padw = 1
154
+ sequence = [
155
+ nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
156
+ nn.LeakyReLU(0.2, True),
157
+ ]
158
+ nf_mult = 1
159
+ nf_mult_prev = 1
160
+ for n in range(1, n_layers): # gradually increase the number of filters
161
+ nf_mult_prev = nf_mult
162
+ nf_mult = min(2**n, 8)
163
+ sequence += [
164
+ nn.Conv2d(
165
+ ndf * nf_mult_prev,
166
+ ndf * nf_mult,
167
+ kernel_size=kw,
168
+ stride=2,
169
+ padding=padw,
170
+ bias=use_bias,
171
+ ),
172
+ norm_layer(ndf * nf_mult),
173
+ nn.LeakyReLU(0.2, True),
174
+ ]
175
+
176
+ nf_mult_prev = nf_mult
177
+ nf_mult = min(2**n_layers, 8)
178
+ sequence += [
179
+ nn.Conv2d(
180
+ ndf * nf_mult_prev,
181
+ ndf * nf_mult,
182
+ kernel_size=kw,
183
+ stride=1,
184
+ padding=padw,
185
+ bias=use_bias,
186
+ ),
187
+ norm_layer(ndf * nf_mult),
188
+ nn.LeakyReLU(0.2, True),
189
+ ]
190
+
191
+ sequence += [
192
+ nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
193
+ ] # output 1 channel prediction map
194
+ self.main = nn.Sequential(*sequence)
195
+
196
+ def forward(self, input):
197
+ """Standard forward."""
198
+ return self.main(input)
199
+
200
+
201
+ class AutoencoderLossWithDiscriminator(nn.Module):
202
+ def __init__(self, cfg):
203
+ super().__init__()
204
+ self.cfg = cfg
205
+ self.kl_weight = cfg.kl_weight
206
+ self.logvar = nn.Parameter(torch.ones(size=()) * cfg.logvar_init)
207
+
208
+ self.discriminator = NLayerDiscriminator(
209
+ input_nc=cfg.disc_in_channels,
210
+ n_layers=cfg.disc_num_layers,
211
+ use_actnorm=cfg.use_actnorm,
212
+ ).apply(weights_init)
213
+
214
+ self.discriminator_iter_start = cfg.disc_start
215
+ self.discriminator_weight = cfg.disc_weight
216
+ self.disc_factor = cfg.disc_factor
217
+ self.disc_loss = hinge_d_loss
218
+
219
+ def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
220
+ nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
221
+ g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
222
+
223
+ d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
224
+ d_weight = torch.clamp(
225
+ d_weight, self.cfg.min_adapt_d_weight, self.cfg.max_adapt_d_weight
226
+ ).detach()
227
+ d_weight = d_weight * self.discriminator_weight
228
+ return d_weight
229
+
230
+ def forward(
231
+ self,
232
+ inputs,
233
+ reconstructions,
234
+ posteriors,
235
+ optimizer_idx,
236
+ global_step,
237
+ last_layer,
238
+ split="train",
239
+ weights=None,
240
+ ):
241
+ rec_loss = torch.abs(
242
+ inputs.contiguous() - reconstructions.contiguous()
243
+ ) # l1 loss
244
+ nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
245
+ weighted_nll_loss = nll_loss
246
+ if weights is not None:
247
+ weighted_nll_loss = weights * nll_loss
248
+ # weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
249
+ weighted_nll_loss = torch.mean(weighted_nll_loss)
250
+ # nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
251
+ nll_loss = torch.mean(nll_loss)
252
+ kl_loss = posteriors.kl()
253
+ kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
254
+ # ? kl_loss = torch.mean(kl_loss)
255
+
256
+ # now the GAN part
257
+ if optimizer_idx == 0:
258
+ logits_fake = self.discriminator(reconstructions.contiguous())
259
+ g_loss = -torch.mean(logits_fake)
260
+
261
+ if self.disc_factor > 0.0:
262
+ try:
263
+ d_weight = self.calculate_adaptive_weight(
264
+ nll_loss, g_loss, last_layer=last_layer
265
+ )
266
+ except RuntimeError:
267
+ assert not self.training
268
+ d_weight = torch.tensor(0.0)
269
+ else:
270
+ d_weight = torch.tensor(0.0)
271
+
272
+ disc_factor = adopt_weight(
273
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
274
+ )
275
+
276
+ total_loss = (
277
+ weighted_nll_loss
278
+ + self.kl_weight * kl_loss
279
+ + d_weight * disc_factor * g_loss
280
+ )
281
+
282
+ return {
283
+ "loss": total_loss,
284
+ "kl_loss": kl_loss,
285
+ "rec_loss": rec_loss.mean(),
286
+ "nll_loss": nll_loss,
287
+ "g_loss": g_loss,
288
+ "d_weight": d_weight,
289
+ "disc_factor": torch.tensor(disc_factor),
290
+ }
291
+
292
+ if optimizer_idx == 1:
293
+ logits_real = self.discriminator(inputs.contiguous().detach())
294
+ logits_fake = self.discriminator(reconstructions.contiguous().detach())
295
+
296
+ disc_factor = adopt_weight(
297
+ self.disc_factor, global_step, threshold=self.discriminator_iter_start
298
+ )
299
+ d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
300
+
301
+ return {
302
+ "d_loss": d_loss,
303
+ "logits_real": logits_real.mean(),
304
+ "logits_fake": logits_fake.mean(),
305
+ }
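Two scheduling details above are easy to miss: `adopt_weight` keeps the adversarial terms at zero until `global_step` reaches `disc_start`, and `calculate_adaptive_weight` scales the generator-adversarial term by ||grad(nll)|| / (||grad(g)|| + 1e-4), clamped and multiplied by `disc_weight`. A quick check of the warm-up behaviour (not part of the commit; the `disc_start` value is illustrative):

from models.tta.autoencoder.autoencoder_loss import adopt_weight

disc_start = 50_000
assert adopt_weight(1.0, global_step=10_000, threshold=disc_start) == 0.0   # GAN terms off
assert adopt_weight(1.0, global_step=60_000, threshold=disc_start) == 1.0   # GAN terms on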
models/tta/autoencoder/autoencoder_trainer.py ADDED
@@ -0,0 +1,187 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from models.base.base_trainer import BaseTrainer
8
+ from models.tta.autoencoder.autoencoder_dataset import (
9
+ AutoencoderKLDataset,
10
+ AutoencoderKLCollator,
11
+ )
12
+ from models.tta.autoencoder.autoencoder import AutoencoderKL
13
+ from models.tta.autoencoder.autoencoder_loss import AutoencoderLossWithDiscriminator
14
+ from torch.optim import Adam, AdamW
15
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
16
+ from torch.nn import MSELoss, L1Loss
17
+ import torch.nn.functional as F
18
+ from torch.utils.data import ConcatDataset, DataLoader
19
+
20
+
21
+ class AutoencoderKLTrainer(BaseTrainer):
22
+ def __init__(self, args, cfg):
23
+ BaseTrainer.__init__(self, args, cfg)
24
+ self.cfg = cfg
25
+ self.save_config_file()
26
+
27
+ def build_dataset(self):
28
+ return AutoencoderKLDataset, AutoencoderKLCollator
29
+
30
+ def build_optimizer(self):
31
+ opt_ae = torch.optim.AdamW(self.model.parameters(), **self.cfg.train.adam)
32
+ opt_disc = torch.optim.AdamW(
33
+ self.criterion.discriminator.parameters(), **self.cfg.train.adam
34
+ )
35
+ optimizer = {"opt_ae": opt_ae, "opt_disc": opt_disc}
36
+ return optimizer
37
+
38
+ def build_data_loader(self):
39
+ Dataset, Collator = self.build_dataset()
40
+ # build dataset instance for each dataset and combine them by ConcatDataset
41
+ datasets_list = []
42
+ for dataset in self.cfg.dataset:
43
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
44
+ datasets_list.append(subdataset)
45
+ train_dataset = ConcatDataset(datasets_list)
46
+
47
+ train_collate = Collator(self.cfg)
48
+
49
+ # build the train loader directly with batch_size (no custom sampler or batch_sampler here)
50
+ train_loader = DataLoader(
51
+ train_dataset,
52
+ collate_fn=train_collate,
53
+ num_workers=self.args.num_workers,
54
+ batch_size=self.cfg.train.batch_size,
55
+ pin_memory=False,
56
+ )
57
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
58
+ datasets_list = []
59
+ for dataset in self.cfg.dataset:
60
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
61
+ datasets_list.append(subdataset)
62
+ valid_dataset = ConcatDataset(datasets_list)
63
+ valid_collate = Collator(self.cfg)
64
+
65
+ valid_loader = DataLoader(
66
+ valid_dataset,
67
+ collate_fn=valid_collate,
68
+ num_workers=1,
69
+ batch_size=self.cfg.train.batch_size,
70
+ )
71
+ else:
72
+ raise NotImplementedError("DDP is not supported yet.")
73
+ # valid_loader = None
74
+ data_loader = {"train": train_loader, "valid": valid_loader}
75
+ return data_loader
76
+
77
+ # TODO: check it...
78
+ def build_scheduler(self):
79
+ return None
80
+ # return ReduceLROnPlateau(self.optimizer["opt_ae"], **self.cfg.train.lronPlateau)
81
+
82
+ def write_summary(self, losses, stats):
83
+ for key, value in losses.items():
84
+ self.sw.add_scalar(key, value, self.step)
85
+
86
+ def write_valid_summary(self, losses, stats):
87
+ for key, value in losses.items():
88
+ self.sw.add_scalar(key, value, self.step)
89
+
90
+ def build_criterion(self):
91
+ return AutoencoderLossWithDiscriminator(self.cfg.model.loss)
92
+
93
+ def get_state_dict(self):
94
+ if self.scheduler != None:
95
+ state_dict = {
96
+ "model": self.model.state_dict(),
97
+ "optimizer_ae": self.optimizer["opt_ae"].state_dict(),
98
+ "optimizer_disc": self.optimizer["opt_disc"].state_dict(),
99
+ "scheduler": self.scheduler.state_dict(),
100
+ "step": self.step,
101
+ "epoch": self.epoch,
102
+ "batch_size": self.cfg.train.batch_size,
103
+ }
104
+ else:
105
+ state_dict = {
106
+ "model": self.model.state_dict(),
107
+ "optimizer_ae": self.optimizer["opt_ae"].state_dict(),
108
+ "optimizer_disc": self.optimizer["opt_disc"].state_dict(),
109
+ "step": self.step,
110
+ "epoch": self.epoch,
111
+ "batch_size": self.cfg.train.batch_size,
112
+ }
113
+ return state_dict
114
+
115
+ def load_model(self, checkpoint):
116
+ self.step = checkpoint["step"]
117
+ self.epoch = checkpoint["epoch"]
118
+
119
+ self.model.load_state_dict(checkpoint["model"])
120
+ self.optimizer["opt_ae"].load_state_dict(checkpoint["optimizer_ae"])
121
+ self.optimizer["opt_disc"].load_state_dict(checkpoint["optimizer_disc"])
122
+ if self.scheduler != None:
123
+ self.scheduler.load_state_dict(checkpoint["scheduler"])
124
+
125
+ def build_model(self):
126
+ self.model = AutoencoderKL(self.cfg.model.autoencoderkl)
127
+ return self.model
128
+
129
+ # TODO: train step
130
+ def train_step(self, data):
131
+ global_step = self.step
132
+ optimizer_idx = global_step % 2
133
+
134
+ train_losses = {}
135
+ total_loss = 0
136
+ train_states = {}
137
+
138
+ inputs = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
139
+ reconstructions, posterior = self.model(inputs)
140
+ # train_stats.update(stat)
141
+
142
+ train_losses = self.criterion(
143
+ inputs=inputs,
144
+ reconstructions=reconstructions,
145
+ posteriors=posterior,
146
+ optimizer_idx=optimizer_idx,
147
+ global_step=global_step,
148
+ last_layer=self.model.get_last_layer(),
149
+ split="train",
150
+ )
151
+
152
+ if optimizer_idx == 0:
153
+ total_loss = train_losses["loss"]
154
+ self.optimizer["opt_ae"].zero_grad()
155
+ total_loss.backward()
156
+ self.optimizer["opt_ae"].step()
157
+
158
+ else:
159
+ total_loss = train_losses["d_loss"]
160
+ self.optimizer["opt_disc"].zero_grad()
161
+ total_loss.backward()
162
+ self.optimizer["opt_disc"].step()
163
+
164
+ for item in train_losses:
165
+ train_losses[item] = train_losses[item].item()
166
+
167
+ return train_losses, train_states, total_loss.item()
168
+
169
+ # TODO: eval step
170
+ @torch.no_grad()
171
+ def eval_step(self, data, index):
172
+ valid_loss = {}
173
+ total_valid_loss = 0
174
+ valid_stats = {}
175
+
176
+ inputs = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
177
+ reconstructions, posterior = self.model(inputs)
178
+
179
+ loss = F.l1_loss(inputs, reconstructions)
180
+ valid_loss["loss"] = loss
181
+
182
+ total_valid_loss += loss
183
+
184
+ for item in valid_loss:
185
+ valid_loss[item] = valid_loss[item].item()
186
+
187
+ return valid_loss, valid_stats, total_valid_loss.item()
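Note that `train_step` above alternates objectives by step parity instead of running both optimizers on every batch; illustratively (not part of the commit):

for step in range(4):
    optimizer_idx = step % 2
    # even steps: reconstruction + KL + adversarial generator loss -> opt_ae
    # odd steps:  hinge discriminator loss                         -> opt_disc
    print(step, "opt_ae" if optimizer_idx == 0 else "opt_disc")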
models/tta/ldm/__init__.py ADDED
File without changes
models/tta/ldm/attention.py ADDED
@@ -0,0 +1,329 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from inspect import isfunction
7
+ import math
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn, einsum
11
+ from einops import rearrange, repeat
12
+
13
+
14
+ class CheckpointFunction(torch.autograd.Function):
15
+ @staticmethod
16
+ def forward(ctx, run_function, length, *args):
17
+ ctx.run_function = run_function
18
+ ctx.input_tensors = list(args[:length])
19
+ ctx.input_params = list(args[length:])
20
+
21
+ with torch.no_grad():
22
+ output_tensors = ctx.run_function(*ctx.input_tensors)
23
+ return output_tensors
24
+
25
+ @staticmethod
26
+ def backward(ctx, *output_grads):
27
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
28
+ with torch.enable_grad():
29
+ # Fixes a bug where the first op in run_function modifies the
30
+ # Tensor storage in place, which is not allowed for detach()'d
31
+ # Tensors.
32
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
33
+ output_tensors = ctx.run_function(*shallow_copies)
34
+ input_grads = torch.autograd.grad(
35
+ output_tensors,
36
+ ctx.input_tensors + ctx.input_params,
37
+ output_grads,
38
+ allow_unused=True,
39
+ )
40
+ del ctx.input_tensors
41
+ del ctx.input_params
42
+ del output_tensors
43
+ return (None, None) + input_grads
44
+
45
+
46
+ def checkpoint(func, inputs, params, flag):
47
+ """
48
+ Evaluate a function without caching intermediate activations, allowing for
49
+ reduced memory at the expense of extra compute in the backward pass.
50
+ :param func: the function to evaluate.
51
+ :param inputs: the argument sequence to pass to `func`.
52
+ :param params: a sequence of parameters `func` depends on but does not
53
+ explicitly take as arguments.
54
+ :param flag: if False, disable gradient checkpointing.
55
+ """
56
+ if flag:
57
+ args = tuple(inputs) + tuple(params)
58
+ return CheckpointFunction.apply(func, len(inputs), *args)
59
+ else:
60
+ return func(*inputs)
61
+
62
+
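# Editorial note (not part of the commit): BasicTransformerBlock.forward below calls
#     checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
# trading a second forward pass during backprop for lower activation memory.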
63
+ def exists(val):
64
+ return val is not None
65
+
66
+
67
+ def uniq(arr):
68
+ return {el: True for el in arr}.keys()
69
+
70
+
71
+ def default(val, d):
72
+ if exists(val):
73
+ return val
74
+ return d() if isfunction(d) else d
75
+
76
+
77
+ def max_neg_value(t):
78
+ return -torch.finfo(t.dtype).max
79
+
80
+
81
+ def init_(tensor):
82
+ dim = tensor.shape[-1]
83
+ std = 1 / math.sqrt(dim)
84
+ tensor.uniform_(-std, std)
85
+ return tensor
86
+
87
+
88
+ # feedforward
89
+ class GEGLU(nn.Module):
90
+ def __init__(self, dim_in, dim_out):
91
+ super().__init__()
92
+ self.proj = nn.Linear(dim_in, dim_out * 2)
93
+
94
+ def forward(self, x):
95
+ x, gate = self.proj(x).chunk(2, dim=-1)
96
+ return x * F.gelu(gate)
97
+
98
+
99
+ class FeedForward(nn.Module):
100
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
101
+ super().__init__()
102
+ inner_dim = int(dim * mult)
103
+ dim_out = default(dim_out, dim)
104
+ project_in = (
105
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
106
+ if not glu
107
+ else GEGLU(dim, inner_dim)
108
+ )
109
+
110
+ self.net = nn.Sequential(
111
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
112
+ )
113
+
114
+ def forward(self, x):
115
+ return self.net(x)
116
+
117
+
118
+ def zero_module(module):
119
+ """
120
+ Zero out the parameters of a module and return it.
121
+ """
122
+ for p in module.parameters():
123
+ p.detach().zero_()
124
+ return module
125
+
126
+
127
+ def Normalize(in_channels):
128
+ return torch.nn.GroupNorm(
129
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
130
+ )
131
+
132
+
133
+ class LinearAttention(nn.Module):
134
+ def __init__(self, dim, heads=4, dim_head=32):
135
+ super().__init__()
136
+ self.heads = heads
137
+ hidden_dim = dim_head * heads
138
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
139
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
140
+
141
+ def forward(self, x):
142
+ b, c, h, w = x.shape
143
+ qkv = self.to_qkv(x)
144
+ q, k, v = rearrange(
145
+ qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
146
+ )
147
+ k = k.softmax(dim=-1)
148
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
149
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
150
+ out = rearrange(
151
+ out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
152
+ )
153
+ return self.to_out(out)
154
+
155
+
156
+ class SpatialSelfAttention(nn.Module):
157
+ def __init__(self, in_channels):
158
+ super().__init__()
159
+ self.in_channels = in_channels
160
+
161
+ self.norm = Normalize(in_channels)
162
+ self.q = torch.nn.Conv2d(
163
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
164
+ )
165
+ self.k = torch.nn.Conv2d(
166
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
167
+ )
168
+ self.v = torch.nn.Conv2d(
169
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
170
+ )
171
+ self.proj_out = torch.nn.Conv2d(
172
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
173
+ )
174
+
175
+ def forward(self, x):
176
+ h_ = x
177
+ h_ = self.norm(h_)
178
+ q = self.q(h_)
179
+ k = self.k(h_)
180
+ v = self.v(h_)
181
+
182
+ # compute attention
183
+ b, c, h, w = q.shape
184
+ q = rearrange(q, "b c h w -> b (h w) c")
185
+ k = rearrange(k, "b c h w -> b c (h w)")
186
+ w_ = torch.einsum("bij,bjk->bik", q, k)
187
+
188
+ w_ = w_ * (int(c) ** (-0.5))
189
+ w_ = torch.nn.functional.softmax(w_, dim=2)
190
+
191
+ # attend to values
192
+ v = rearrange(v, "b c h w -> b c (h w)")
193
+ w_ = rearrange(w_, "b i j -> b j i")
194
+ h_ = torch.einsum("bij,bjk->bik", v, w_)
195
+ h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
196
+ h_ = self.proj_out(h_)
197
+
198
+ return x + h_
199
+
200
+
201
+ class CrossAttention(nn.Module):
202
+ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
203
+ super().__init__()
204
+ inner_dim = dim_head * heads
205
+ context_dim = default(context_dim, query_dim)
206
+
207
+ self.scale = dim_head**-0.5
208
+ self.heads = heads
209
+
210
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
211
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
212
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
213
+
214
+ self.to_out = nn.Sequential(
215
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
216
+ )
217
+
218
+ def forward(self, x, context=None, mask=None):
219
+ h = self.heads
220
+
221
+ q = self.to_q(x)
222
+ context = default(context, x)
223
+ k = self.to_k(context)
224
+ v = self.to_v(context)
225
+
226
+ q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
227
+
228
+ sim = einsum("b i d, b j d -> b i j", q, k) * self.scale
229
+
230
+ if exists(mask):
231
+ mask = rearrange(mask, "b ... -> b (...)")
232
+ max_neg_value = -torch.finfo(sim.dtype).max
233
+ mask = repeat(mask, "b j -> (b h) () j", h=h)
234
+ sim.masked_fill_(~mask, max_neg_value)
235
+
236
+ # attention, what we cannot get enough of
237
+ attn = sim.softmax(dim=-1)
238
+
239
+ out = einsum("b i j, b j d -> b i d", attn, v)
240
+ out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
241
+ return self.to_out(out)
242
+
243
+
244
+ class BasicTransformerBlock(nn.Module):
245
+ def __init__(
246
+ self,
247
+ dim,
248
+ n_heads,
249
+ d_head,
250
+ dropout=0.0,
251
+ context_dim=None,
252
+ gated_ff=True,
253
+ checkpoint=True,
254
+ ):
255
+ super().__init__()
256
+ self.attn1 = CrossAttention(
257
+ query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
258
+ ) # is a self-attention
259
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
260
+ self.attn2 = CrossAttention(
261
+ query_dim=dim,
262
+ context_dim=context_dim,
263
+ heads=n_heads,
264
+ dim_head=d_head,
265
+ dropout=dropout,
266
+ ) # is self-attn if context is none
267
+ self.norm1 = nn.LayerNorm(dim)
268
+ self.norm2 = nn.LayerNorm(dim)
269
+ self.norm3 = nn.LayerNorm(dim)
270
+ self.checkpoint = checkpoint
271
+
272
+ def forward(self, x, context=None):
273
+ return checkpoint(
274
+ self._forward, (x, context), self.parameters(), self.checkpoint
275
+ )
276
+
277
+ def _forward(self, x, context=None):
278
+ x = self.attn1(self.norm1(x)) + x
279
+ x = self.attn2(self.norm2(x), context=context) + x
280
+ x = self.ff(self.norm3(x)) + x
281
+ return x
282
+
283
+
284
+ class SpatialTransformer(nn.Module):
285
+ """
286
+ Transformer block for image-like data.
287
+ First, project the input (aka embedding)
288
+ and reshape to b, t, d.
289
+ Then apply standard transformer action.
290
+ Finally, reshape to image
291
+ """
292
+
293
+ def __init__(
294
+ self, in_channels, n_heads, d_head, depth=1, dropout=0.0, context_dim=None
295
+ ):
296
+ super().__init__()
297
+ self.in_channels = in_channels
298
+ inner_dim = n_heads * d_head
299
+ self.norm = Normalize(in_channels)
300
+
301
+ self.proj_in = nn.Conv2d(
302
+ in_channels, inner_dim, kernel_size=1, stride=1, padding=0
303
+ )
304
+
305
+ self.transformer_blocks = nn.ModuleList(
306
+ [
307
+ BasicTransformerBlock(
308
+ inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
309
+ )
310
+ for d in range(depth)
311
+ ]
312
+ )
313
+
314
+ self.proj_out = zero_module(
315
+ nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
316
+ )
317
+
318
+ def forward(self, x, context=None):
319
+ # note: if no context is given, cross-attention defaults to self-attention
320
+ b, c, h, w = x.shape
321
+ x_in = x
322
+ x = self.norm(x)
323
+ x = self.proj_in(x)
324
+ x = rearrange(x, "b c h w -> b (h w) c")
325
+ for block in self.transformer_blocks:
326
+ x = block(x, context=context)
327
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
328
+ x = self.proj_out(x)
329
+ return x + x_in
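For orientation, a rough usage sketch of the block above. All shapes and dimensions here are illustrative assumptions rather than values from an Amphion config, and it assumes the module is importable as models.tta.ldm.attention with its dependencies (einops) installed. With no context the transformer reduces to self-attention over the flattened h*w positions; with a context tensor, attn2 cross-attends to it.

import torch
from models.tta.ldm.attention import SpatialTransformer

# Hypothetical sizes: 64-channel feature map, 4 heads of width 16, T5-sized context.
st = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, context_dim=768)
x = torch.randn(2, 64, 10, 78)   # (B, C, H, W) latent feature map
ctx = torch.randn(2, 20, 768)    # (B, T, context_dim) text-encoder states
y_self = st(x)                   # context=None -> attention falls back to self-attention
y_cross = st(x, context=ctx)     # cross-attention over the 20 context tokens
assert y_self.shape == y_cross.shape == x.shape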
models/tta/ldm/audioldm.py ADDED
@@ -0,0 +1,926 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from abc import abstractmethod
7
+ from functools import partial
8
+ import math
9
+ from typing import Iterable
10
+
11
+ import os
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import numpy as np
16
+ from einops import repeat
17
+
18
+ from models.tta.ldm.attention import SpatialTransformer
19
+
20
+ # from attention import SpatialTransformer
21
+
22
+
23
+ class CheckpointFunction(torch.autograd.Function):
24
+ @staticmethod
25
+ def forward(ctx, run_function, length, *args):
26
+ ctx.run_function = run_function
27
+ ctx.input_tensors = list(args[:length])
28
+ ctx.input_params = list(args[length:])
29
+
30
+ with torch.no_grad():
31
+ output_tensors = ctx.run_function(*ctx.input_tensors)
32
+ return output_tensors
33
+
34
+ @staticmethod
35
+ def backward(ctx, *output_grads):
36
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
37
+ with torch.enable_grad():
38
+ # Fixes a bug where the first op in run_function modifies the
39
+ # Tensor storage in place, which is not allowed for detach()'d
40
+ # Tensors.
41
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
42
+ output_tensors = ctx.run_function(*shallow_copies)
43
+ input_grads = torch.autograd.grad(
44
+ output_tensors,
45
+ ctx.input_tensors + ctx.input_params,
46
+ output_grads,
47
+ allow_unused=True,
48
+ )
49
+ del ctx.input_tensors
50
+ del ctx.input_params
51
+ del output_tensors
52
+ return (None, None) + input_grads
53
+
54
+
55
+ def checkpoint(func, inputs, params, flag):
56
+ """
57
+ Evaluate a function without caching intermediate activations, allowing for
58
+ reduced memory at the expense of extra compute in the backward pass.
59
+ :param func: the function to evaluate.
60
+ :param inputs: the argument sequence to pass to `func`.
61
+ :param params: a sequence of parameters `func` depends on but does not
62
+ explicitly take as arguments.
63
+ :param flag: if False, disable gradient checkpointing.
64
+ """
65
+ if flag:
66
+ args = tuple(inputs) + tuple(params)
67
+ return CheckpointFunction.apply(func, len(inputs), *args)
68
+ else:
69
+ return func(*inputs)
70
+
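# Illustrative note: `checkpoint` trades compute for memory. With flag=False it is just
# func(*inputs); with flag=True the activations inside func are discarded after the
# forward pass and recomputed during backward. The calling convention used throughout
# this file is, e.g.:
#     out = checkpoint(block._forward, (x, emb), block.parameters(), use_checkpoint)
# Both settings produce the same result; the flag only changes the memory/compute trade-off.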
71
+
72
+ def zero_module(module):
73
+ """
74
+ Zero out the parameters of a module and return it.
75
+ """
76
+ for p in module.parameters():
77
+ p.detach().zero_()
78
+ return module
79
+
80
+
81
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
82
+ """
83
+ Create sinusoidal timestep embeddings.
84
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
85
+ These may be fractional.
86
+ :param dim: the dimension of the output.
87
+ :param max_period: controls the minimum frequency of the embeddings.
88
+ :return: an [N x dim] Tensor of positional embeddings.
89
+ """
90
+ if not repeat_only:
91
+ half = dim // 2
92
+ freqs = torch.exp(
93
+ -math.log(max_period)
94
+ * torch.arange(start=0, end=half, dtype=torch.float32)
95
+ / half
96
+ ).to(device=timesteps.device)
97
+ args = timesteps[:, None].float() * freqs[None]
98
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
99
+ if dim % 2:
100
+ embedding = torch.cat(
101
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
102
+ )
103
+ else:
104
+ embedding = repeat(timesteps, "b -> b d", d=dim)
105
+ return embedding
106
+
107
+
108
+ class GroupNorm32(nn.GroupNorm):
109
+ def forward(self, x):
110
+ return super().forward(x.float()).type(x.dtype)
111
+
112
+
113
+ def normalization(channels):
114
+ """
115
+ Make a standard normalization layer.
116
+ :param channels: number of input channels.
117
+ :return: an nn.Module for normalization.
118
+ """
119
+ return GroupNorm32(32, channels)
120
+
121
+
122
+ def count_flops_attn(model, _x, y):
123
+ """
124
+ A counter for the `thop` package to count the operations in an
125
+ attention operation.
126
+ Meant to be used like:
127
+ macs, params = thop.profile(
128
+ model,
129
+ inputs=(inputs, timestamps),
130
+ custom_ops={QKVAttention: QKVAttention.count_flops},
131
+ )
132
+ """
133
+ b, c, *spatial = y[0].shape
134
+ num_spatial = int(np.prod(spatial))
135
+ # We perform two matmuls with the same number of ops.
136
+ # The first computes the weight matrix, the second computes
137
+ # the combination of the value vectors.
138
+ matmul_ops = 2 * b * (num_spatial**2) * c
139
+ model.total_ops += torch.DoubleTensor([matmul_ops])
140
+
141
+
142
+ def conv_nd(dims, *args, **kwargs):
143
+ """
144
+ Create a 1D, 2D, or 3D convolution module.
145
+ """
146
+ if dims == 1:
147
+ return nn.Conv1d(*args, **kwargs)
148
+ elif dims == 2:
149
+ return nn.Conv2d(*args, **kwargs)
150
+ elif dims == 3:
151
+ return nn.Conv3d(*args, **kwargs)
152
+ raise ValueError(f"unsupported dimensions: {dims}")
153
+
154
+
155
+ def avg_pool_nd(dims, *args, **kwargs):
156
+ """
157
+ Create a 1D, 2D, or 3D average pooling module.
158
+ """
159
+ if dims == 1:
160
+ return nn.AvgPool1d(*args, **kwargs)
161
+ elif dims == 2:
162
+ return nn.AvgPool2d(*args, **kwargs)
163
+ elif dims == 3:
164
+ return nn.AvgPool3d(*args, **kwargs)
165
+ raise ValueError(f"unsupported dimensions: {dims}")
166
+
167
+
168
+ class QKVAttention(nn.Module):
169
+ """
170
+ A module which performs QKV attention and splits in a different order.
171
+ """
172
+
173
+ def __init__(self, n_heads):
174
+ super().__init__()
175
+ self.n_heads = n_heads
176
+
177
+ def forward(self, qkv):
178
+ """
179
+ Apply QKV attention.
180
+ :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
181
+ :return: an [N x (H * C) x T] tensor after attention.
182
+ """
183
+
184
+ bs, width, length = qkv.shape
185
+ assert width % (3 * self.n_heads) == 0
186
+ ch = width // (3 * self.n_heads)
187
+ q, k, v = qkv.chunk(3, dim=1) # [N x (H * C) x T]
188
+ scale = 1 / math.sqrt(math.sqrt(ch))
189
+ weight = torch.einsum(
190
+ "bct,bcs->bts",
191
+ (q * scale).view(bs * self.n_heads, ch, length),
192
+ (k * scale).view(bs * self.n_heads, ch, length),
193
+ ) # More stable with f16 than dividing afterwards
194
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
195
+ a = torch.einsum(
196
+ "bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length)
197
+ )
198
+ return a.reshape(bs, -1, length)
199
+
200
+ @staticmethod
201
+ def count_flops(model, _x, y):
202
+ return count_flops_attn(model, _x, y)
203
+
204
+
205
+ class QKVAttentionLegacy(nn.Module):
206
+ """
207
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
208
+ """
209
+
210
+ def __init__(self, n_heads):
211
+ super().__init__()
212
+ self.n_heads = n_heads
213
+
214
+ def forward(self, qkv):
215
+ """
216
+ Apply QKV attention.
217
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
218
+ :return: an [N x (H * C) x T] tensor after attention.
219
+ """
220
+ bs, width, length = qkv.shape
221
+ assert width % (3 * self.n_heads) == 0
222
+ ch = width // (3 * self.n_heads)
223
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
224
+ scale = 1 / math.sqrt(math.sqrt(ch))
225
+ weight = torch.einsum(
226
+ "bct,bcs->bts", q * scale, k * scale
227
+ ) # More stable with f16 than dividing afterwards
228
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
229
+ a = torch.einsum("bts,bcs->bct", weight, v)
230
+ return a.reshape(bs, -1, length)
231
+
232
+ @staticmethod
233
+ def count_flops(model, _x, y):
234
+ return count_flops_attn(model, _x, y)
235
+
236
+
237
+ class AttentionPool2d(nn.Module):
238
+ """
239
+ Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
240
+ """
241
+
242
+ def __init__(
243
+ self,
244
+ spacial_dim: int,
245
+ embed_dim: int,
246
+ num_heads_channels: int,
247
+ output_dim: int = None,
248
+ ):
249
+ super().__init__()
250
+ self.positional_embedding = nn.Parameter(
251
+ torch.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5
252
+ )
253
+ self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
254
+ self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
255
+ self.num_heads = embed_dim // num_heads_channels
256
+ self.attention = QKVAttention(self.num_heads)
257
+
258
+ def forward(self, x):
259
+ b, c, *_spatial = x.shape
260
+ x = x.reshape(b, c, -1) # NC(HW)
261
+ x = torch.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
262
+ x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
263
+ x = self.qkv_proj(x)
264
+ x = self.attention(x)
265
+ x = self.c_proj(x)
266
+ return x[:, :, 0]
267
+
268
+
269
+ class TimestepBlock(nn.Module):
270
+ """
271
+ Any module where forward() takes timestep embeddings as a second argument.
272
+ """
273
+
274
+ @abstractmethod
275
+ def forward(self, x, emb):
276
+ """
277
+ Apply the module to `x` given `emb` timestep embeddings.
278
+ """
279
+
280
+
281
+ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
282
+ """
283
+ A sequential module that passes timestep embeddings to the children that
284
+ support it as an extra input.
285
+ """
286
+
287
+ def forward(self, x, emb, context=None):
288
+ for layer in self:
289
+ if isinstance(layer, TimestepBlock):
290
+ x = layer(x, emb)
291
+ elif isinstance(layer, SpatialTransformer):
292
+ x = layer(x, context)
293
+ else:
294
+ x = layer(x)
295
+ return x
296
+
297
+
298
+ class Upsample(nn.Module):
299
+ """
300
+ An upsampling layer with an optional convolution.
301
+ :param channels: channels in the inputs and outputs.
302
+ :param use_conv: a bool determining if a convolution is applied.
303
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
304
+ upsampling occurs in the inner-two dimensions.
305
+ """
306
+
307
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
308
+ super().__init__()
309
+ self.channels = channels
310
+ self.out_channels = out_channels or channels
311
+ self.use_conv = use_conv
312
+ self.dims = dims
313
+ if use_conv:
314
+ self.conv = conv_nd(
315
+ dims, self.channels, self.out_channels, 3, padding=padding
316
+ )
317
+
318
+ def forward(self, x):
319
+ assert x.shape[1] == self.channels
320
+ if self.dims == 3:
321
+ x = F.interpolate(
322
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
323
+ )
324
+ else:
325
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
326
+ if self.use_conv:
327
+ x = self.conv(x)
328
+ return x
329
+
330
+
331
+ class TransposedUpsample(nn.Module):
332
+ "Learned 2x upsampling without padding"
333
+
334
+ def __init__(self, channels, out_channels=None, ks=5):
335
+ super().__init__()
336
+ self.channels = channels
337
+ self.out_channels = out_channels or channels
338
+
339
+ self.up = nn.ConvTranspose2d(
340
+ self.channels, self.out_channels, kernel_size=ks, stride=2
341
+ )
342
+
343
+ def forward(self, x):
344
+ return self.up(x)
345
+
346
+
347
+ class Downsample(nn.Module):
348
+ """
349
+ A downsampling layer with an optional convolution.
350
+ :param channels: channels in the inputs and outputs.
351
+ :param use_conv: a bool determining if a convolution is applied.
352
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
353
+ downsampling occurs in the inner-two dimensions.
354
+ """
355
+
356
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
357
+ super().__init__()
358
+ self.channels = channels
359
+ self.out_channels = out_channels or channels
360
+ self.use_conv = use_conv
361
+ self.dims = dims
362
+ stride = 2 if dims != 3 else (1, 2, 2)
363
+ if use_conv:
364
+ self.op = conv_nd(
365
+ dims,
366
+ self.channels,
367
+ self.out_channels,
368
+ 3,
369
+ stride=stride,
370
+ padding=padding,
371
+ )
372
+ else:
373
+ assert self.channels == self.out_channels
374
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
375
+
376
+ def forward(self, x):
377
+ assert x.shape[1] == self.channels
378
+ return self.op(x)
379
+
380
+
381
+ class ResBlock(TimestepBlock):
382
+ """
383
+ A residual block that can optionally change the number of channels.
384
+ :param channels: the number of input channels.
385
+ :param emb_channels: the number of timestep embedding channels.
386
+ :param dropout: the rate of dropout.
387
+ :param out_channels: if specified, the number of out channels.
388
+ :param use_conv: if True and out_channels is specified, use a spatial
389
+ convolution instead of a smaller 1x1 convolution to change the
390
+ channels in the skip connection.
391
+ :param dims: determines if the signal is 1D, 2D, or 3D.
392
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
393
+ :param up: if True, use this block for upsampling.
394
+ :param down: if True, use this block for downsampling.
395
+ """
396
+
397
+ def __init__(
398
+ self,
399
+ channels,
400
+ emb_channels,
401
+ dropout,
402
+ out_channels=None,
403
+ use_conv=False,
404
+ use_scale_shift_norm=False,
405
+ dims=2,
406
+ use_checkpoint=False,
407
+ up=False,
408
+ down=False,
409
+ ):
410
+ super().__init__()
411
+ self.channels = channels
412
+ self.emb_channels = emb_channels
413
+ self.dropout = dropout
414
+ self.out_channels = out_channels or channels
415
+ self.use_conv = use_conv
416
+ self.use_checkpoint = use_checkpoint
417
+ self.use_scale_shift_norm = use_scale_shift_norm
418
+
419
+ self.in_layers = nn.Sequential(
420
+ normalization(channels),
421
+ nn.SiLU(),
422
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
423
+ )
424
+
425
+ self.updown = up or down
426
+
427
+ if up:
428
+ self.h_upd = Upsample(channels, False, dims)
429
+ self.x_upd = Upsample(channels, False, dims)
430
+ elif down:
431
+ self.h_upd = Downsample(channels, False, dims)
432
+ self.x_upd = Downsample(channels, False, dims)
433
+ else:
434
+ self.h_upd = self.x_upd = nn.Identity()
435
+
436
+ self.emb_layers = nn.Sequential(
437
+ nn.SiLU(),
438
+ nn.Linear(
439
+ emb_channels,
440
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
441
+ ),
442
+ )
443
+ self.out_layers = nn.Sequential(
444
+ normalization(self.out_channels),
445
+ nn.SiLU(),
446
+ nn.Dropout(p=dropout),
447
+ zero_module(
448
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
449
+ ),
450
+ )
451
+
452
+ if self.out_channels == channels:
453
+ self.skip_connection = nn.Identity()
454
+ elif use_conv:
455
+ self.skip_connection = conv_nd(
456
+ dims, channels, self.out_channels, 3, padding=1
457
+ )
458
+ else:
459
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
460
+
461
+ def forward(self, x, emb):
462
+ """
463
+ Apply the block to a Tensor, conditioned on a timestep embedding.
464
+ :param x: an [N x C x ...] Tensor of features.
465
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
466
+ :return: an [N x C x ...] Tensor of outputs.
467
+ """
468
+ return checkpoint(
469
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
470
+ )
471
+
472
+ def _forward(self, x, emb):
473
+ if self.updown:
474
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
475
+ h = in_rest(x)
476
+ h = self.h_upd(h)
477
+ x = self.x_upd(x)
478
+ h = in_conv(h)
479
+ else:
480
+ h = self.in_layers(x)
481
+ emb_out = self.emb_layers(emb).type(h.dtype)
482
+ while len(emb_out.shape) < len(h.shape):
483
+ emb_out = emb_out[..., None]
484
+ if self.use_scale_shift_norm:
485
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
486
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
487
+ h = out_norm(h) * (1 + scale) + shift
488
+ h = out_rest(h)
489
+ else:
490
+ h = h + emb_out
491
+ h = self.out_layers(h)
492
+ return self.skip_connection(x) + h
493
+
494
+
495
+ class AttentionBlock(nn.Module):
496
+ """
497
+ An attention block that allows spatial positions to attend to each other.
498
+ Originally ported from here, but adapted to the N-d case.
499
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
500
+ """
501
+
502
+ def __init__(
503
+ self,
504
+ channels,
505
+ num_heads=1,
506
+ num_head_channels=-1,
507
+ use_checkpoint=False,
508
+ use_new_attention_order=False,
509
+ ):
510
+ super().__init__()
511
+ self.channels = channels
512
+ if num_head_channels == -1:
513
+ self.num_heads = num_heads
514
+ else:
515
+ assert (
516
+ channels % num_head_channels == 0
517
+ ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
518
+ self.num_heads = channels // num_head_channels
519
+ self.use_checkpoint = use_checkpoint
520
+ self.norm = normalization(channels)
521
+ self.qkv = conv_nd(1, channels, channels * 3, 1)
522
+ if use_new_attention_order:
523
+ # split qkv before split heads
524
+ self.attention = QKVAttention(self.num_heads)
525
+ else:
526
+ # split heads before split qkv
527
+ self.attention = QKVAttentionLegacy(self.num_heads)
528
+
529
+ self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
530
+
531
+ def forward(self, x):
532
+ return checkpoint(
533
+ self._forward, (x,), self.parameters(), True
534
+ ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
535
+ # return pt_checkpoint(self._forward, x) # pytorch
536
+
537
+ def _forward(self, x):
538
+ b, c, *spatial = x.shape
539
+ x = x.reshape(b, c, -1)
540
+ qkv = self.qkv(self.norm(x))
541
+ h = self.attention(qkv)
542
+ h = self.proj_out(h)
543
+ return (x + h).reshape(b, c, *spatial)
544
+
545
+
546
+ class UNetModel(nn.Module):
547
+ """
548
+ The full UNet model with attention and timestep embedding.
549
+ :param in_channels: channels in the input Tensor.
550
+ :param model_channels: base channel count for the model.
551
+ :param out_channels: channels in the output Tensor.
552
+ :param num_res_blocks: number of residual blocks per downsample.
553
+ :param attention_resolutions: a collection of downsample rates at which
554
+ attention will take place. May be a set, list, or tuple.
555
+ For example, if this contains 4, then at 4x downsampling, attention
556
+ will be used.
557
+ :param dropout: the dropout probability.
558
+ :param channel_mult: channel multiplier for each level of the UNet.
559
+ :param conv_resample: if True, use learned convolutions for upsampling and
560
+ downsampling.
561
+ :param dims: determines if the signal is 1D, 2D, or 3D.
562
+ :param num_classes: if specified (as an int), then this model will be
563
+ class-conditional with `num_classes` classes.
564
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
565
+ :param num_heads: the number of attention heads in each attention layer.
566
+ :param num_heads_channels: if specified, ignore num_heads and instead use
567
+ a fixed channel width per attention head.
568
+ :param num_heads_upsample: works with num_heads to set a different number
569
+ of heads for upsampling. Deprecated.
570
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
571
+ :param resblock_updown: use residual blocks for up/downsampling.
572
+ :param use_new_attention_order: use a different attention pattern for potentially
573
+ increased efficiency.
574
+ """
575
+
576
+ def __init__(
577
+ self,
578
+ image_size,
579
+ in_channels,
580
+ model_channels,
581
+ out_channels,
582
+ num_res_blocks,
583
+ attention_resolutions,
584
+ dropout=0,
585
+ channel_mult=(1, 2, 4, 8),
586
+ conv_resample=True,
587
+ dims=2,
588
+ num_classes=None,
589
+ use_checkpoint=False,
590
+ use_fp16=False,
591
+ num_heads=-1,
592
+ num_head_channels=-1,
593
+ num_heads_upsample=-1,
594
+ use_scale_shift_norm=False,
595
+ resblock_updown=False,
596
+ use_new_attention_order=False,
597
+ use_spatial_transformer=False, # custom transformer support
598
+ transformer_depth=1, # custom transformer support
599
+ context_dim=None, # custom transformer support
600
+ n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
601
+ legacy=True,
602
+ ):
603
+ super().__init__()
604
+ if use_spatial_transformer:
605
+ assert (
606
+ context_dim is not None
607
+ ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."
608
+
609
+ if context_dim is not None:
610
+ assert (
611
+ use_spatial_transformer
612
+ ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
613
+ from omegaconf.listconfig import ListConfig
614
+
615
+ if type(context_dim) == ListConfig:
616
+ context_dim = list(context_dim)
617
+
618
+ if num_heads_upsample == -1:
619
+ num_heads_upsample = num_heads
620
+
621
+ if num_heads == -1:
622
+ assert (
623
+ num_head_channels != -1
624
+ ), "Either num_heads or num_head_channels has to be set"
625
+
626
+ if num_head_channels == -1:
627
+ assert (
628
+ num_heads != -1
629
+ ), "Either num_heads or num_head_channels has to be set"
630
+
631
+ self.image_size = image_size
632
+ self.in_channels = in_channels
633
+ self.model_channels = model_channels
634
+ self.out_channels = out_channels
635
+ self.num_res_blocks = num_res_blocks
636
+ self.attention_resolutions = attention_resolutions
637
+ self.dropout = dropout
638
+ self.channel_mult = channel_mult
639
+ self.conv_resample = conv_resample
640
+ self.num_classes = num_classes
641
+ self.use_checkpoint = use_checkpoint
642
+ self.dtype = torch.float16 if use_fp16 else torch.float32
643
+ self.num_heads = num_heads
644
+ self.num_head_channels = num_head_channels
645
+ self.num_heads_upsample = num_heads_upsample
646
+ self.predict_codebook_ids = n_embed is not None
647
+
648
+ time_embed_dim = model_channels * 4
649
+ self.time_embed = nn.Sequential(
650
+ nn.Linear(model_channels, time_embed_dim),
651
+ nn.SiLU(),
652
+ nn.Linear(time_embed_dim, time_embed_dim),
653
+ )
654
+
655
+ if self.num_classes is not None:
656
+ self.label_emb = nn.Embedding(num_classes, time_embed_dim)
657
+
658
+ self.input_blocks = nn.ModuleList(
659
+ [
660
+ TimestepEmbedSequential(
661
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
662
+ )
663
+ ]
664
+ )
665
+ self._feature_size = model_channels
666
+ input_block_chans = [model_channels]
667
+ ch = model_channels
668
+ ds = 1
669
+ for level, mult in enumerate(channel_mult):
670
+ for _ in range(num_res_blocks):
671
+ layers = [
672
+ ResBlock(
673
+ ch,
674
+ time_embed_dim,
675
+ dropout,
676
+ out_channels=mult * model_channels,
677
+ dims=dims,
678
+ use_checkpoint=use_checkpoint,
679
+ use_scale_shift_norm=use_scale_shift_norm,
680
+ )
681
+ ]
682
+ ch = mult * model_channels
683
+ if ds in attention_resolutions:
684
+ if num_head_channels == -1:
685
+ dim_head = ch // num_heads
686
+ else:
687
+ num_heads = ch // num_head_channels
688
+ dim_head = num_head_channels
689
+ if legacy:
690
+ # num_heads = 1
691
+ dim_head = (
692
+ ch // num_heads
693
+ if use_spatial_transformer
694
+ else num_head_channels
695
+ )
696
+ layers.append(
697
+ AttentionBlock(
698
+ ch,
699
+ use_checkpoint=use_checkpoint,
700
+ num_heads=num_heads,
701
+ num_head_channels=dim_head,
702
+ use_new_attention_order=use_new_attention_order,
703
+ )
704
+ if not use_spatial_transformer
705
+ else SpatialTransformer(
706
+ ch,
707
+ num_heads,
708
+ dim_head,
709
+ depth=transformer_depth,
710
+ context_dim=context_dim,
711
+ )
712
+ )
713
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
714
+ self._feature_size += ch
715
+ input_block_chans.append(ch)
716
+ if level != len(channel_mult) - 1:
717
+ out_ch = ch
718
+ self.input_blocks.append(
719
+ TimestepEmbedSequential(
720
+ ResBlock(
721
+ ch,
722
+ time_embed_dim,
723
+ dropout,
724
+ out_channels=out_ch,
725
+ dims=dims,
726
+ use_checkpoint=use_checkpoint,
727
+ use_scale_shift_norm=use_scale_shift_norm,
728
+ down=True,
729
+ )
730
+ if resblock_updown
731
+ else Downsample(
732
+ ch, conv_resample, dims=dims, out_channels=out_ch
733
+ )
734
+ )
735
+ )
736
+ ch = out_ch
737
+ input_block_chans.append(ch)
738
+ ds *= 2
739
+ self._feature_size += ch
740
+
741
+ if num_head_channels == -1:
742
+ dim_head = ch // num_heads
743
+ else:
744
+ num_heads = ch // num_head_channels
745
+ dim_head = num_head_channels
746
+ if legacy:
747
+ # num_heads = 1
748
+ dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
749
+ self.middle_block = TimestepEmbedSequential(
750
+ ResBlock(
751
+ ch,
752
+ time_embed_dim,
753
+ dropout,
754
+ dims=dims,
755
+ use_checkpoint=use_checkpoint,
756
+ use_scale_shift_norm=use_scale_shift_norm,
757
+ ),
758
+ AttentionBlock(
759
+ ch,
760
+ use_checkpoint=use_checkpoint,
761
+ num_heads=num_heads,
762
+ num_head_channels=dim_head,
763
+ use_new_attention_order=use_new_attention_order,
764
+ )
765
+ if not use_spatial_transformer
766
+ else SpatialTransformer(
767
+ ch,
768
+ num_heads,
769
+ dim_head,
770
+ depth=transformer_depth,
771
+ context_dim=context_dim,
772
+ ),
773
+ ResBlock(
774
+ ch,
775
+ time_embed_dim,
776
+ dropout,
777
+ dims=dims,
778
+ use_checkpoint=use_checkpoint,
779
+ use_scale_shift_norm=use_scale_shift_norm,
780
+ ),
781
+ )
782
+ self._feature_size += ch
783
+
784
+ self.output_blocks = nn.ModuleList([])
785
+ for level, mult in list(enumerate(channel_mult))[::-1]:
786
+ for i in range(num_res_blocks + 1):
787
+ ich = input_block_chans.pop()
788
+ layers = [
789
+ ResBlock(
790
+ ch + ich,
791
+ time_embed_dim,
792
+ dropout,
793
+ out_channels=model_channels * mult,
794
+ dims=dims,
795
+ use_checkpoint=use_checkpoint,
796
+ use_scale_shift_norm=use_scale_shift_norm,
797
+ )
798
+ ]
799
+ ch = model_channels * mult
800
+ if ds in attention_resolutions:
801
+ if num_head_channels == -1:
802
+ dim_head = ch // num_heads
803
+ else:
804
+ num_heads = ch // num_head_channels
805
+ dim_head = num_head_channels
806
+ if legacy:
807
+ # num_heads = 1
808
+ dim_head = (
809
+ ch // num_heads
810
+ if use_spatial_transformer
811
+ else num_head_channels
812
+ )
813
+ layers.append(
814
+ AttentionBlock(
815
+ ch,
816
+ use_checkpoint=use_checkpoint,
817
+ num_heads=num_heads_upsample,
818
+ num_head_channels=dim_head,
819
+ use_new_attention_order=use_new_attention_order,
820
+ )
821
+ if not use_spatial_transformer
822
+ else SpatialTransformer(
823
+ ch,
824
+ num_heads,
825
+ dim_head,
826
+ depth=transformer_depth,
827
+ context_dim=context_dim,
828
+ )
829
+ )
830
+ if level and i == num_res_blocks:
831
+ out_ch = ch
832
+ layers.append(
833
+ ResBlock(
834
+ ch,
835
+ time_embed_dim,
836
+ dropout,
837
+ out_channels=out_ch,
838
+ dims=dims,
839
+ use_checkpoint=use_checkpoint,
840
+ use_scale_shift_norm=use_scale_shift_norm,
841
+ up=True,
842
+ )
843
+ if resblock_updown
844
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
845
+ )
846
+ ds //= 2
847
+ self.output_blocks.append(TimestepEmbedSequential(*layers))
848
+ self._feature_size += ch
849
+
850
+ self.out = nn.Sequential(
851
+ normalization(ch),
852
+ nn.SiLU(),
853
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
854
+ )
855
+ if self.predict_codebook_ids:
856
+ self.id_predictor = nn.Sequential(
857
+ normalization(ch),
858
+ conv_nd(dims, model_channels, n_embed, 1),
859
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
860
+ )
861
+
862
+ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
863
+ """
864
+ Apply the model to an input batch.
865
+ :param x: an [N x C x ...] Tensor of inputs.
866
+ :param timesteps: a 1-D batch of timesteps.
867
+ :param context: conditioning plugged in via crossattn
868
+ :param y: an [N] Tensor of labels, if class-conditional.
869
+ :return: an [N x C x ...] Tensor of outputs.
870
+ """
871
+ assert (y is not None) == (
872
+ self.num_classes is not None
873
+ ), "must specify y if and only if the model is class-conditional"
874
+ hs = []
875
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
876
+ emb = self.time_embed(t_emb)
877
+
878
+ if self.num_classes is not None:
879
+ assert y.shape == (x.shape[0],)
880
+ emb = emb + self.label_emb(y)
881
+
882
+ h = x.type(self.dtype)
883
+ for module in self.input_blocks:
884
+ h = module(h, emb, context)
885
+ hs.append(h)
886
+ h = self.middle_block(h, emb, context)
887
+ for module in self.output_blocks:
888
+ # print(h.shape, hs[-1].shape)
889
+ if h.shape != hs[-1].shape:
890
+ if h.shape[-1] > hs[-1].shape[-1]:
891
+ h = h[:, :, :, : hs[-1].shape[-1]]
892
+ if h.shape[-2] > hs[-1].shape[-2]:
893
+ h = h[:, :, : hs[-1].shape[-2], :]
894
+ h = torch.cat([h, hs.pop()], dim=1)
895
+ h = module(h, emb, context)
896
+ # print(h.shape)
897
+ h = h.type(x.dtype)
898
+ if self.predict_codebook_ids:
899
+ return self.id_predictor(h)
900
+ else:
901
+ return self.out(h)
902
+
903
+
904
+ class AudioLDM(nn.Module):
905
+ def __init__(self, cfg):
906
+ super().__init__()
907
+ self.cfg = cfg
908
+ self.unet = UNetModel(
909
+ image_size=cfg.image_size,
910
+ in_channels=cfg.in_channels,
911
+ out_channels=cfg.out_channels,
912
+ model_channels=cfg.model_channels,
913
+ attention_resolutions=cfg.attention_resolutions,
914
+ num_res_blocks=cfg.num_res_blocks,
915
+ channel_mult=cfg.channel_mult,
916
+ num_heads=cfg.num_heads,
917
+ use_spatial_transformer=cfg.use_spatial_transformer,
918
+ transformer_depth=cfg.transformer_depth,
919
+ context_dim=cfg.context_dim,
920
+ use_checkpoint=cfg.use_checkpoint,
921
+ legacy=cfg.legacy,
922
+ )
923
+
924
+ def forward(self, x, timesteps=None, context=None, y=None):
925
+ x = self.unet(x=x, timesteps=timesteps, context=context, y=y)
926
+ return x
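As a sanity check of the wiring above, a toy forward pass can be run with a hand-made config. The hyper-parameters below are deliberately tiny, made-up values, not the ones shipped with Amphion, and the sketch assumes the repository's dependencies (einops, omegaconf) are installed.

from types import SimpleNamespace
import torch
from models.tta.ldm.audioldm import AudioLDM

cfg = SimpleNamespace(
    image_size=32, in_channels=4, out_channels=4, model_channels=32,
    attention_resolutions=[2], num_res_blocks=1, channel_mult=(1, 2),
    num_heads=4, use_spatial_transformer=True, transformer_depth=1,
    context_dim=768, use_checkpoint=False, legacy=False,
)
model = AudioLDM(cfg)
x = torch.randn(2, 4, 8, 16)              # latent "image" (B, C, H, W)
t = torch.randint(0, 1000, (2,))          # diffusion timesteps
ctx = torch.randn(2, 12, 768)             # text-encoder hidden states (B, T, 768)
eps = model(x, timesteps=t, context=ctx)  # predicted noise, same shape as x
assert eps.shape == x.shape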
models/tta/ldm/audioldm_dataset.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import torch
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ from utils.data_utils import *
10
+
11
+
12
+ from models.base.base_dataset import (
13
+ BaseCollator,
14
+ BaseDataset,
15
+ BaseTestDataset,
16
+ BaseTestCollator,
17
+ )
18
+ import librosa
19
+
20
+ from transformers import AutoTokenizer
21
+
22
+
23
+ class AudioLDMDataset(BaseDataset):
24
+ def __init__(self, cfg, dataset, is_valid=False):
25
+ BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid)
26
+
27
+ self.cfg = cfg
28
+
29
+ # utt2melspec
30
+ if cfg.preprocess.use_melspec:
31
+ self.utt2melspec_path = {}
32
+ for utt_info in self.metadata:
33
+ dataset = utt_info["Dataset"]
34
+ uid = utt_info["Uid"]
35
+ utt = "{}_{}".format(dataset, uid)
36
+
37
+ self.utt2melspec_path[utt] = os.path.join(
38
+ cfg.preprocess.processed_dir,
39
+ dataset,
40
+ cfg.preprocess.melspec_dir,
41
+ uid + ".npy",
42
+ )
43
+
44
+ # utt2wav
45
+ if cfg.preprocess.use_wav:
46
+ self.utt2wav_path = {}
47
+ for utt_info in self.metadata:
48
+ dataset = utt_info["Dataset"]
49
+ uid = utt_info["Uid"]
50
+ utt = "{}_{}".format(dataset, uid)
51
+
52
+ self.utt2wav_path[utt] = os.path.join(
53
+ cfg.preprocess.processed_dir,
54
+ dataset,
55
+ cfg.preprocess.wav_dir,
56
+ uid + ".wav",
57
+ )
58
+
59
+ # utt2caption
60
+ if cfg.preprocess.use_caption:
61
+ self.utt2caption = {}
62
+ for utt_info in self.metadata:
63
+ dataset = utt_info["Dataset"]
64
+ uid = utt_info["Uid"]
65
+ utt = "{}_{}".format(dataset, uid)
66
+
67
+ self.utt2caption[utt] = utt_info["Caption"]
68
+
69
+ def __getitem__(self, index):
70
+ # melspec: (n_mels, T)
71
+ # wav: (T,)
72
+
73
+ single_feature = BaseDataset.__getitem__(self, index)
74
+
75
+ utt_info = self.metadata[index]
76
+ dataset = utt_info["Dataset"]
77
+ uid = utt_info["Uid"]
78
+ utt = "{}_{}".format(dataset, uid)
79
+
80
+ if self.cfg.preprocess.use_melspec:
81
+ single_feature["melspec"] = np.load(self.utt2melspec_path[utt])
82
+
83
+ if self.cfg.preprocess.use_wav:
84
+ wav, sr = librosa.load(
85
+ self.utt2wav_path[utt], sr=16000
86
+ ) # hard-coded to 16 kHz
87
+ single_feature["wav"] = wav
88
+
89
+ if self.cfg.preprocess.use_caption:
90
+ cond_mask = np.random.choice(
91
+ [1, 0],
92
+ p=[
93
+ self.cfg.preprocess.cond_mask_prob,
94
+ 1 - self.cfg.preprocess.cond_mask_prob,
95
+ ],
96
+ ) # (0.1, 0.9)
97
+ if cond_mask:
98
+ single_feature["caption"] = ""
99
+ else:
100
+ single_feature["caption"] = self.utt2caption[utt]
101
+
102
+ return single_feature
103
+
104
+ def __len__(self):
105
+ return len(self.metadata)
106
+
107
+
108
+ class AudioLDMCollator(BaseCollator):
109
+ def __init__(self, cfg):
110
+ BaseCollator.__init__(self, cfg)
111
+
112
+ self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
113
+
114
+ def __call__(self, batch):
115
+ # mel: (B, n_mels, T)
116
+ # wav (option): (B, T)
117
+ # text_input_ids: (B, L)
118
+ # text_attention_mask: (B, L)
119
+
120
+ packed_batch_features = dict()
121
+
122
+ for key in batch[0].keys():
123
+ if key == "melspec":
124
+ packed_batch_features["melspec"] = torch.from_numpy(
125
+ np.array([b["melspec"][:, :624] for b in batch])
126
+ )
127
+
128
+ if key == "wav":
129
+ values = [torch.from_numpy(b[key]) for b in batch]
130
+ packed_batch_features[key] = pad_sequence(
131
+ values, batch_first=True, padding_value=0
132
+ )
133
+
134
+ if key == "caption":
135
+ captions = [b[key] for b in batch]
136
+ text_input = self.tokenizer(
137
+ captions, return_tensors="pt", truncation=True, padding="longest"
138
+ )
139
+ text_input_ids = text_input["input_ids"]
140
+ text_attention_mask = text_input["attention_mask"]
141
+
142
+ packed_batch_features["text_input_ids"] = text_input_ids
143
+ packed_batch_features["text_attention_mask"] = text_attention_mask
144
+
145
+ return packed_batch_features
146
+
147
+
148
+ class AudioLDMTestDataset(BaseTestDataset):
149
+ ...
150
+
151
+
152
+ class AudioLDMTestCollator(BaseTestCollator):
153
+ ...
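The collator's caption branch is easiest to see in isolation. A minimal sketch, assuming the same t5-base tokenizer is available; the captions are invented, and the empty string is what a dropped caption (cond_mask == 1) looks like after the dataset's classifier-free-guidance masking.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
captions = ["a dog barking in the rain", ""]   # second caption was masked out
enc = tokenizer(captions, return_tensors="pt", truncation=True, padding="longest")
text_input_ids = enc["input_ids"]              # (B, L)
text_attention_mask = enc["attention_mask"]    # (B, L), zeros over the padding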
models/tta/ldm/audioldm_inference.py ADDED
@@ -0,0 +1,193 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import time
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+ import torch.nn as nn
12
+ from collections import OrderedDict
13
+ import json
14
+
15
+ from models.tta.autoencoder.autoencoder import AutoencoderKL
16
+ from models.tta.ldm.inference_utils.vocoder import Generator
17
+ from models.tta.ldm.audioldm import AudioLDM
18
+ from transformers import T5EncoderModel, AutoTokenizer
19
+ from diffusers import PNDMScheduler
20
+
21
+ import matplotlib.pyplot as plt
22
+ from scipy.io.wavfile import write
23
+
24
+
25
+ class AttrDict(dict):
26
+ def __init__(self, *args, **kwargs):
27
+ super(AttrDict, self).__init__(*args, **kwargs)
28
+ self.__dict__ = self
29
+
30
+
31
+ class AudioLDMInference:
32
+ def __init__(self, args, cfg):
33
+ self.cfg = cfg
34
+ self.args = args
35
+
36
+ self.build_autoencoderkl()
37
+ self.build_textencoder()
38
+
39
+ self.model = self.build_model()
40
+ self.load_state_dict()
41
+
42
+ self.build_vocoder()
43
+
44
+ self.out_path = self.args.output_dir
45
+ self.out_mel_path = os.path.join(self.out_path, "mel")
46
+ self.out_wav_path = os.path.join(self.out_path, "wav")
47
+ os.makedirs(self.out_mel_path, exist_ok=True)
48
+ os.makedirs(self.out_wav_path, exist_ok=True)
49
+
50
+ def build_autoencoderkl(self):
51
+ self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl)
52
+ self.autoencoder_path = self.cfg.model.autoencoder_path
53
+ checkpoint = torch.load(self.autoencoder_path, map_location="cpu")
54
+ self.autoencoderkl.load_state_dict(checkpoint["model"])
55
+ self.autoencoderkl.cuda(self.args.local_rank)
56
+ self.autoencoderkl.requires_grad_(requires_grad=False)
57
+ self.autoencoderkl.eval()
58
+
59
+ def build_textencoder(self):
60
+ self.tokenizer = AutoTokenizer.from_pretrained("t5-base", model_max_length=512)
61
+ self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
62
+ self.text_encoder.cuda(self.args.local_rank)
63
+ self.text_encoder.requires_grad_(requires_grad=False)
64
+ self.text_encoder.eval()
65
+
66
+ def build_vocoder(self):
67
+ config_file = os.path.join(self.args.vocoder_config_path)
68
+ with open(config_file) as f:
69
+ data = f.read()
70
+ json_config = json.loads(data)
71
+ h = AttrDict(json_config)
72
+ self.vocoder = Generator(h).to(self.args.local_rank)
73
+ checkpoint_dict = torch.load(
74
+ self.args.vocoder_path, map_location=self.args.local_rank
75
+ )
76
+ self.vocoder.load_state_dict(checkpoint_dict["generator"])
77
+
78
+ def build_model(self):
79
+ self.model = AudioLDM(self.cfg.model.audioldm)
80
+ return self.model
81
+
82
+ def load_state_dict(self):
83
+ self.checkpoint_path = self.args.checkpoint_path
84
+ checkpoint = torch.load(self.checkpoint_path, map_location="cpu")
85
+ self.model.load_state_dict(checkpoint["model"])
86
+ self.model.cuda(self.args.local_rank)
87
+
88
+ def get_text_embedding(self):
89
+ text = self.args.text
90
+
91
+ prompt = [text]
92
+
93
+ text_input = self.tokenizer(
94
+ prompt,
95
+ max_length=self.tokenizer.model_max_length,
96
+ truncation=True,
97
+ padding="do_not_pad",
98
+ return_tensors="pt",
99
+ )
100
+ text_embeddings = self.text_encoder(
101
+ text_input.input_ids.to(self.args.local_rank)
102
+ )[0]
103
+
104
+ max_length = text_input.input_ids.shape[-1]
105
+ uncond_input = self.tokenizer(
106
+ [""] * 1, padding="max_length", max_length=max_length, return_tensors="pt"
107
+ )
108
+ uncond_embeddings = self.text_encoder(
109
+ uncond_input.input_ids.to(self.args.local_rank)
110
+ )[0]
111
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
112
+
113
+ return text_embeddings
114
+
115
+ def inference(self):
116
+ text_embeddings = self.get_text_embedding()
117
+ print(text_embeddings.shape)
118
+
119
+ num_steps = self.args.num_steps
120
+ guidance_scale = self.args.guidance_scale
121
+
122
+ noise_scheduler = PNDMScheduler(
123
+ num_train_timesteps=1000,
124
+ beta_start=0.00085,
125
+ beta_end=0.012,
126
+ beta_schedule="scaled_linear",
127
+ skip_prk_steps=True,
128
+ set_alpha_to_one=False,
129
+ steps_offset=1,
130
+ prediction_type="epsilon",
131
+ )
132
+
133
+ noise_scheduler.set_timesteps(num_steps)
134
+
135
+ latents = torch.randn(
136
+ (
137
+ 1,
138
+ self.cfg.model.autoencoderkl.z_channels,
139
+ 80 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)),
140
+ 624 // (2 ** (len(self.cfg.model.autoencoderkl.ch_mult) - 1)),
141
+ )
142
+ ).to(self.args.local_rank)
143
+
144
+ self.model.eval()
145
+ for t in tqdm(noise_scheduler.timesteps):
146
+ t = t.to(self.args.local_rank)
147
+
148
+ # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
149
+ latent_model_input = torch.cat([latents] * 2)
150
+
151
+ latent_model_input = noise_scheduler.scale_model_input(
152
+ latent_model_input, timestep=t
153
+ )
154
+ # print(latent_model_input.shape)
155
+
156
+ # predict the noise residual
157
+ with torch.no_grad():
158
+ noise_pred = self.model(
159
+ latent_model_input, torch.cat([t.unsqueeze(0)] * 2), text_embeddings
160
+ )
161
+
162
+ # perform guidance
163
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
164
+ noise_pred = noise_pred_uncond + guidance_scale * (
165
+ noise_pred_text - noise_pred_uncond
166
+ )
167
+
168
+ # compute the previous noisy sample x_t -> x_t-1
169
+ latents = noise_scheduler.step(noise_pred, t, latents).prev_sample
170
+ # print(latents.shape)
171
+
172
+ latents_out = latents
173
+ print(latents_out.shape)
174
+
175
+ with torch.no_grad():
176
+ mel_out = self.autoencoderkl.decode(latents_out)
177
+ print(mel_out.shape)
178
+
179
+ melspec = mel_out[0, 0].cpu().detach().numpy()
180
+ plt.imsave(os.path.join(self.out_mel_path, self.args.text + ".png"), melspec)
181
+
182
+ self.vocoder.eval()
183
+ self.vocoder.remove_weight_norm()
184
+ with torch.no_grad():
185
+ melspec = np.expand_dims(melspec, 0)
186
+ melspec = torch.FloatTensor(melspec).to(self.args.local_rank)
187
+
188
+ y = self.vocoder(melspec)
189
+ audio = y.squeeze()
190
+ audio = audio * 32768.0
191
+ audio = audio.cpu().numpy().astype("int16")
192
+
193
+ write(os.path.join(self.out_wav_path, self.args.text + ".wav"), 16000, audio)
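The guidance step in the loop above implements the usual classifier-free-guidance combination, eps = eps_uncond + s * (eps_text - eps_uncond): s = 1 recovers the purely conditional prediction, and larger s pushes the sample further toward the text prompt. A one-line numeric check (values made up):

import torch
eps_uncond, eps_text, s = torch.tensor(0.2), torch.tensor(0.8), 2.5
eps = eps_uncond + s * (eps_text - eps_uncond)   # tensor(1.7000)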
models/tta/ldm/audioldm_trainer.py ADDED
@@ -0,0 +1,251 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from models.base.base_trainer import BaseTrainer
7
+ from diffusers import DDPMScheduler
8
+ from models.tta.ldm.audioldm_dataset import AudioLDMDataset, AudioLDMCollator
9
+ from models.tta.autoencoder.autoencoder import AutoencoderKL
10
+ from models.tta.ldm.audioldm import AudioLDM, UNetModel
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import MSELoss, L1Loss
14
+ import torch.nn.functional as F
15
+ from torch.utils.data import ConcatDataset, DataLoader
16
+
17
+ from transformers import T5EncoderModel
18
+ from diffusers import DDPMScheduler
19
+
20
+
21
+ class AudioLDMTrainer(BaseTrainer):
22
+ def __init__(self, args, cfg):
23
+ BaseTrainer.__init__(self, args, cfg)
24
+ self.cfg = cfg
25
+
26
+ self.build_autoencoderkl()
27
+ self.build_textencoder()
28
+ self.noise_scheduler = self.build_noise_scheduler()
29
+
30
+ self.save_config_file()
31
+
32
+ def build_autoencoderkl(self):
33
+ self.autoencoderkl = AutoencoderKL(self.cfg.model.autoencoderkl)
34
+ self.autoencoder_path = self.cfg.model.autoencoder_path
35
+ checkpoint = torch.load(self.autoencoder_path, map_location="cpu")
36
+ self.autoencoderkl.load_state_dict(checkpoint["model"])
37
+ self.autoencoderkl.cuda(self.args.local_rank)
38
+ self.autoencoderkl.requires_grad_(requires_grad=False)
39
+ self.autoencoderkl.eval()
40
+
41
+ def build_textencoder(self):
42
+ self.text_encoder = T5EncoderModel.from_pretrained("t5-base")
43
+ self.text_encoder.cuda(self.args.local_rank)
44
+ self.text_encoder.requires_grad_(requires_grad=False)
45
+ self.text_encoder.eval()
46
+
47
+ def build_noise_scheduler(self):
48
+ noise_scheduler = DDPMScheduler(
49
+ num_train_timesteps=self.cfg.model.noise_scheduler.num_train_timesteps,
50
+ beta_start=self.cfg.model.noise_scheduler.beta_start,
51
+ beta_end=self.cfg.model.noise_scheduler.beta_end,
52
+ beta_schedule=self.cfg.model.noise_scheduler.beta_schedule,
53
+ clip_sample=self.cfg.model.noise_scheduler.clip_sample,
54
+ # steps_offset=self.cfg.model.noise_scheduler.steps_offset,
55
+ # set_alpha_to_one=self.cfg.model.noise_scheduler.set_alpha_to_one,
56
+ # skip_prk_steps=self.cfg.model.noise_scheduler.skip_prk_steps,
57
+ prediction_type=self.cfg.model.noise_scheduler.prediction_type,
58
+ )
59
+ return noise_scheduler
60
+
61
+ def build_dataset(self):
62
+ return AudioLDMDataset, AudioLDMCollator
63
+
64
+ def build_data_loader(self):
65
+ Dataset, Collator = self.build_dataset()
66
+ # build dataset instance for each dataset and combine them by ConcatDataset
67
+ datasets_list = []
68
+ for dataset in self.cfg.dataset:
69
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
70
+ datasets_list.append(subdataset)
71
+ train_dataset = ConcatDataset(datasets_list)
72
+
73
+ train_collate = Collator(self.cfg)
74
+
75
+ # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
76
+ train_loader = DataLoader(
77
+ train_dataset,
78
+ collate_fn=train_collate,
79
+ num_workers=self.args.num_workers,
80
+ batch_size=self.cfg.train.batch_size,
81
+ pin_memory=False,
82
+ )
83
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
84
+ datasets_list = []
85
+ for dataset in self.cfg.dataset:
86
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
87
+ datasets_list.append(subdataset)
88
+ valid_dataset = ConcatDataset(datasets_list)
89
+ valid_collate = Collator(self.cfg)
90
+
91
+ valid_loader = DataLoader(
92
+ valid_dataset,
93
+ collate_fn=valid_collate,
94
+ num_workers=1,
95
+ batch_size=self.cfg.train.batch_size,
96
+ )
97
+ else:
98
+ raise NotImplementedError("DDP is not supported yet.")
99
+ # valid_loader = None
100
+ data_loader = {"train": train_loader, "valid": valid_loader}
101
+ return data_loader
102
+
103
+ def build_optimizer(self):
104
+ optimizer = torch.optim.AdamW(self.model.parameters(), **self.cfg.train.adam)
105
+ return optimizer
106
+
107
+ # TODO: check it...
108
+ def build_scheduler(self):
109
+ return None
110
+ # return ReduceLROnPlateau(self.optimizer["opt_ae"], **self.cfg.train.lronPlateau)
111
+
112
+ def write_summary(self, losses, stats):
113
+ for key, value in losses.items():
114
+ self.sw.add_scalar(key, value, self.step)
115
+
116
+ def write_valid_summary(self, losses, stats):
117
+ for key, value in losses.items():
118
+ self.sw.add_scalar(key, value, self.step)
119
+
120
+ def build_criterion(self):
121
+ criterion = nn.MSELoss(reduction="mean")
122
+ return criterion
123
+
124
+ def get_state_dict(self):
125
+ if self.scheduler != None:
126
+ state_dict = {
127
+ "model": self.model.state_dict(),
128
+ "optimizer": self.optimizer.state_dict(),
129
+ "scheduler": self.scheduler.state_dict(),
130
+ "step": self.step,
131
+ "epoch": self.epoch,
132
+ "batch_size": self.cfg.train.batch_size,
133
+ }
134
+ else:
135
+ state_dict = {
136
+ "model": self.model.state_dict(),
137
+ "optimizer": self.optimizer.state_dict(),
138
+ "step": self.step,
139
+ "epoch": self.epoch,
140
+ "batch_size": self.cfg.train.batch_size,
141
+ }
142
+ return state_dict
143
+
144
+ def load_model(self, checkpoint):
145
+ self.step = checkpoint["step"]
146
+ self.epoch = checkpoint["epoch"]
147
+
148
+ self.model.load_state_dict(checkpoint["model"])
149
+ self.optimizer.load_state_dict(checkpoint["optimizer"])
150
+ if self.scheduler != None:
151
+ self.scheduler.load_state_dict(checkpoint["scheduler"])
152
+
153
+ def build_model(self):
154
+ self.model = AudioLDM(self.cfg.model.audioldm)
155
+ return self.model
156
+
157
+ @torch.no_grad()
158
+ def mel_to_latent(self, melspec):
159
+ posterior = self.autoencoderkl.encode(melspec)
160
+ latent = posterior.sample() # (B, 4, 5, 78)
161
+ return latent
162
+
163
+ @torch.no_grad()
164
+ def get_text_embedding(self, text_input_ids, text_attention_mask):
165
+ text_embedding = self.text_encoder(
166
+ input_ids=text_input_ids, attention_mask=text_attention_mask
167
+ ).last_hidden_state
168
+ return text_embedding # (B, T, 768)
169
+
170
+ def train_step(self, data):
171
+ train_losses = {}
172
+ total_loss = 0
173
+ train_stats = {}
174
+
175
+ melspec = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
176
+ latents = self.mel_to_latent(melspec)
177
+
178
+ text_embedding = self.get_text_embedding(
179
+ data["text_input_ids"], data["text_attention_mask"]
180
+ )
181
+
182
+ noise = torch.randn_like(latents).float()
183
+
184
+ bsz = latents.shape[0]
185
+ timesteps = torch.randint(
186
+ 0,
187
+ self.cfg.model.noise_scheduler.num_train_timesteps,
188
+ (bsz,),
189
+ device=latents.device,
190
+ )
191
+ timesteps = timesteps.long()
192
+
193
+ with torch.no_grad():
194
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
195
+
196
+ model_pred = self.model(
197
+ noisy_latents, timesteps=timesteps, context=text_embedding
198
+ )
199
+
200
+ loss = self.criterion(model_pred, noise)
201
+
202
+ train_losses["loss"] = loss
203
+ total_loss += loss
204
+
205
+ self.optimizer.zero_grad()
206
+ total_loss.backward()
207
+ self.optimizer.step()
208
+
209
+ for item in train_losses:
210
+ train_losses[item] = train_losses[item].item()
211
+
212
+ return train_losses, train_stats, total_loss.item()
213
+
214
+ # TODO: eval step
215
+ @torch.no_grad()
216
+ def eval_step(self, data, index):
217
+ valid_loss = {}
218
+ total_valid_loss = 0
219
+ valid_stats = {}
220
+
221
+ melspec = data["melspec"].unsqueeze(1) # (B, 80, T) -> (B, 1, 80, T)
222
+ latents = self.mel_to_latent(melspec)
223
+
224
+ text_embedding = self.get_text_embedding(
225
+ data["text_input_ids"], data["text_attention_mask"]
226
+ )
227
+
228
+ noise = torch.randn_like(latents).float()
229
+
230
+ bsz = latents.shape[0]
231
+ timesteps = torch.randint(
232
+ 0,
233
+ self.cfg.model.noise_scheduler.num_train_timesteps,
234
+ (bsz,),
235
+ device=latents.device,
236
+ )
237
+ timesteps = timesteps.long()
238
+
239
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
240
+
241
+ model_pred = self.model(noisy_latents, timesteps, text_embedding)
242
+
243
+ loss = self.criterion(model_pred, noise)
244
+ valid_loss["loss"] = loss
245
+
246
+ total_valid_loss += loss
247
+
248
+ for item in valid_loss:
249
+ valid_loss[item] = valid_loss[item].item()
250
+
251
+ return valid_loss, valid_stats, total_valid_loss.item()
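Stripped of the config plumbing, train_step above is the standard epsilon-prediction objective. A minimal sketch with made-up shapes, assuming diffusers' DDPMScheduler as used in build_noise_scheduler; the commented lines stand in for the UNet call and are not runnable on their own.

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
latents = torch.randn(4, 4, 10, 78)              # encoded mel latents (B, C, H, W)
noise = torch.randn_like(latents)
t = torch.randint(0, 1000, (4,))
noisy = scheduler.add_noise(latents, noise, t)   # sample from q(x_t | x_0)
# model_pred = unet(noisy, timesteps=t, context=text_embedding)
# loss = torch.nn.functional.mse_loss(model_pred, noise)   # epsilon-prediction MSE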
models/tta/ldm/inference_utils/utils.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import glob
7
+ import os
8
+ import matplotlib
9
+ import torch
10
+ from torch.nn.utils import weight_norm
11
+
12
+ matplotlib.use("Agg")
13
+ import matplotlib.pylab as plt
14
+
15
+
16
+ def plot_spectrogram(spectrogram):
17
+ fig, ax = plt.subplots(figsize=(10, 2))
18
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
19
+ plt.colorbar(im, ax=ax)
20
+
21
+ fig.canvas.draw()
22
+ plt.close()
23
+
24
+ return fig
25
+
26
+
27
+ def init_weights(m, mean=0.0, std=0.01):
28
+ classname = m.__class__.__name__
29
+ if classname.find("Conv") != -1:
30
+ m.weight.data.normal_(mean, std)
31
+
32
+
33
+ def apply_weight_norm(m):
34
+ classname = m.__class__.__name__
35
+ if classname.find("Conv") != -1:
36
+ weight_norm(m)
37
+
38
+
39
+ def get_padding(kernel_size, dilation=1):
40
+ return int((kernel_size * dilation - dilation) / 2)
41
+
42
+
43
+ def load_checkpoint(filepath, device):
44
+ assert os.path.isfile(filepath)
45
+ print("Loading '{}'".format(filepath))
46
+ checkpoint_dict = torch.load(filepath, map_location=device)
47
+ print("Complete.")
48
+ return checkpoint_dict
49
+
50
+
51
+ def save_checkpoint(filepath, obj):
52
+ print("Saving checkpoint to {}".format(filepath))
53
+ torch.save(obj, filepath)
54
+ print("Complete.")
55
+
56
+
57
+ def scan_checkpoint(cp_dir, prefix):
58
+ pattern = os.path.join(cp_dir, prefix + "????????")
59
+ cp_list = glob.glob(pattern)
60
+ if len(cp_list) == 0:
61
+ return None
62
+ return sorted(cp_list)[-1]
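A quick note on get_padding, since the vocoder below leans on it for "same"-length convolutions; the numbers are just worked arithmetic, not config values.

# (kernel_size * dilation - dilation) // 2 keeps the length unchanged for stride-1 convs:
assert (7 * 1 - 1) // 2 == 3    # matches Conv1d(..., 7, 1, padding=3) in Generator
assert (3 * 5 - 5) // 2 == 5    # matches the dilated convs in ResBlock1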
models/tta/ldm/inference_utils/vocoder.py ADDED
@@ -0,0 +1,408 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torch.nn as nn
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
11
+ from models.tta.ldm.inference_utils.utils import get_padding, init_weights
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
+
16
+ class ResBlock1(torch.nn.Module):
17
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
18
+ super(ResBlock1, self).__init__()
19
+ self.h = h
20
+ self.convs1 = nn.ModuleList(
21
+ [
22
+ weight_norm(
23
+ Conv1d(
24
+ channels,
25
+ channels,
26
+ kernel_size,
27
+ 1,
28
+ dilation=dilation[0],
29
+ padding=get_padding(kernel_size, dilation[0]),
30
+ )
31
+ ),
32
+ weight_norm(
33
+ Conv1d(
34
+ channels,
35
+ channels,
36
+ kernel_size,
37
+ 1,
38
+ dilation=dilation[1],
39
+ padding=get_padding(kernel_size, dilation[1]),
40
+ )
41
+ ),
42
+ weight_norm(
43
+ Conv1d(
44
+ channels,
45
+ channels,
46
+ kernel_size,
47
+ 1,
48
+ dilation=dilation[2],
49
+ padding=get_padding(kernel_size, dilation[2]),
50
+ )
51
+ ),
52
+ ]
53
+ )
54
+ self.convs1.apply(init_weights)
55
+
56
+ self.convs2 = nn.ModuleList(
57
+ [
58
+ weight_norm(
59
+ Conv1d(
60
+ channels,
61
+ channels,
62
+ kernel_size,
63
+ 1,
64
+ dilation=1,
65
+ padding=get_padding(kernel_size, 1),
66
+ )
67
+ ),
68
+ weight_norm(
69
+ Conv1d(
70
+ channels,
71
+ channels,
72
+ kernel_size,
73
+ 1,
74
+ dilation=1,
75
+ padding=get_padding(kernel_size, 1),
76
+ )
77
+ ),
78
+ weight_norm(
79
+ Conv1d(
80
+ channels,
81
+ channels,
82
+ kernel_size,
83
+ 1,
84
+ dilation=1,
85
+ padding=get_padding(kernel_size, 1),
86
+ )
87
+ ),
88
+ ]
89
+ )
90
+ self.convs2.apply(init_weights)
91
+
92
+ def forward(self, x):
93
+ for c1, c2 in zip(self.convs1, self.convs2):
94
+ xt = F.leaky_relu(x, LRELU_SLOPE)
95
+ xt = c1(xt)
96
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
97
+ xt = c2(xt)
98
+ x = xt + x
99
+ return x
100
+
101
+ def remove_weight_norm(self):
102
+ for l in self.convs1:
103
+ remove_weight_norm(l)
104
+ for l in self.convs2:
105
+ remove_weight_norm(l)
106
+
107
+
108
+ class ResBlock2(torch.nn.Module):
109
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
110
+ super(ResBlock2, self).__init__()
111
+ self.h = h
112
+ self.convs = nn.ModuleList(
113
+ [
114
+ weight_norm(
115
+ Conv1d(
116
+ channels,
117
+ channels,
118
+ kernel_size,
119
+ 1,
120
+ dilation=dilation[0],
121
+ padding=get_padding(kernel_size, dilation[0]),
122
+ )
123
+ ),
124
+ weight_norm(
125
+ Conv1d(
126
+ channels,
127
+ channels,
128
+ kernel_size,
129
+ 1,
130
+ dilation=dilation[1],
131
+ padding=get_padding(kernel_size, dilation[1]),
132
+ )
133
+ ),
134
+ ]
135
+ )
136
+ self.convs.apply(init_weights)
137
+
138
+ def forward(self, x):
139
+ for c in self.convs:
140
+ xt = F.leaky_relu(x, LRELU_SLOPE)
141
+ xt = c(xt)
142
+ x = xt + x
143
+ return x
144
+
145
+ def remove_weight_norm(self):
146
+ for l in self.convs:
147
+ remove_weight_norm(l)
148
+
149
+
150
+ class Generator(torch.nn.Module):
151
+ def __init__(self, h):
152
+ super(Generator, self).__init__()
153
+ self.h = h
154
+ self.num_kernels = len(h.resblock_kernel_sizes)
155
+ self.num_upsamples = len(h.upsample_rates)
156
+ self.conv_pre = weight_norm(
157
+ Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
158
+ )
159
+ resblock = ResBlock1 if h.resblock == "1" else ResBlock2
160
+
161
+ self.ups = nn.ModuleList()
162
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
163
+ self.ups.append(
164
+ weight_norm(
165
+ ConvTranspose1d(
166
+ h.upsample_initial_channel // (2**i),
167
+ h.upsample_initial_channel // (2 ** (i + 1)),
168
+ k,
169
+ u,
170
+ padding=(k - u) // 2,
171
+ )
172
+ )
173
+ )
174
+
175
+ self.resblocks = nn.ModuleList()
176
+ for i in range(len(self.ups)):
177
+ ch = h.upsample_initial_channel // (2 ** (i + 1))
178
+ for j, (k, d) in enumerate(
179
+ zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
180
+ ):
181
+ self.resblocks.append(resblock(h, ch, k, d))
182
+
183
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
184
+ self.ups.apply(init_weights)
185
+ self.conv_post.apply(init_weights)
186
+
187
+ def forward(self, x):
188
+ x = self.conv_pre(x)
189
+ for i in range(self.num_upsamples):
190
+ x = F.leaky_relu(x, LRELU_SLOPE)
191
+ x = self.ups[i](x)
192
+ xs = None
193
+ for j in range(self.num_kernels):
194
+ if xs is None:
195
+ xs = self.resblocks[i * self.num_kernels + j](x)
196
+ else:
197
+ xs += self.resblocks[i * self.num_kernels + j](x)
198
+ x = xs / self.num_kernels
199
+ x = F.leaky_relu(x)
200
+ x = self.conv_post(x)
201
+ x = torch.tanh(x)
202
+
203
+ return x
204
+
205
+ def remove_weight_norm(self):
206
+ print("Removing weight norm...")
207
+ for l in self.ups:
208
+ remove_weight_norm(l)
209
+ for l in self.resblocks:
210
+ l.remove_weight_norm()
211
+ remove_weight_norm(self.conv_pre)
212
+ remove_weight_norm(self.conv_post)
213
+
214
+
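The Generator above follows the standard HiFi-GAN layout: a pre-convolution over 80 mel bins, transposed-convolution upsamplers interleaved with multi-receptive-field ResBlocks, and a tanh post-convolution. A minimal usage sketch (illustrative only; the hyperparameters below are the common HiFi-GAN V1 values, assumed here rather than read from this commit's config):

    from types import SimpleNamespace
    import torch

    h = SimpleNamespace(
        resblock="1",
        upsample_rates=[8, 8, 2, 2],
        upsample_kernel_sizes=[16, 16, 4, 4],
        upsample_initial_channel=512,
        resblock_kernel_sizes=[3, 7, 11],
        resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    )
    vocoder = Generator(h).eval()
    mel = torch.randn(1, 80, 100)   # (batch, mel_bins, frames); conv_pre expects 80 bins
    with torch.no_grad():
        wav = vocoder(mel)          # -> (1, 1, 100 * 8 * 8 * 2 * 2) = (1, 1, 25600)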
215
+ class DiscriminatorP(torch.nn.Module):
216
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
217
+ super(DiscriminatorP, self).__init__()
218
+ self.period = period
219
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
220
+ self.convs = nn.ModuleList(
221
+ [
222
+ norm_f(
223
+ Conv2d(
224
+ 1,
225
+ 32,
226
+ (kernel_size, 1),
227
+ (stride, 1),
228
+ padding=(get_padding(5, 1), 0),
229
+ )
230
+ ),
231
+ norm_f(
232
+ Conv2d(
233
+ 32,
234
+ 128,
235
+ (kernel_size, 1),
236
+ (stride, 1),
237
+ padding=(get_padding(5, 1), 0),
238
+ )
239
+ ),
240
+ norm_f(
241
+ Conv2d(
242
+ 128,
243
+ 512,
244
+ (kernel_size, 1),
245
+ (stride, 1),
246
+ padding=(get_padding(5, 1), 0),
247
+ )
248
+ ),
249
+ norm_f(
250
+ Conv2d(
251
+ 512,
252
+ 1024,
253
+ (kernel_size, 1),
254
+ (stride, 1),
255
+ padding=(get_padding(5, 1), 0),
256
+ )
257
+ ),
258
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
259
+ ]
260
+ )
261
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
262
+
263
+ def forward(self, x):
264
+ fmap = []
265
+
266
+ # 1d to 2d
267
+ b, c, t = x.shape
268
+ if t % self.period != 0: # pad first
269
+ n_pad = self.period - (t % self.period)
270
+ x = F.pad(x, (0, n_pad), "reflect")
271
+ t = t + n_pad
272
+ x = x.view(b, c, t // self.period, self.period)
273
+
274
+ for l in self.convs:
275
+ x = l(x)
276
+ x = F.leaky_relu(x, LRELU_SLOPE)
277
+ fmap.append(x)
278
+ x = self.conv_post(x)
279
+ fmap.append(x)
280
+ x = torch.flatten(x, 1, -1)
281
+
282
+ return x, fmap
283
+
284
+
285
+ class MultiPeriodDiscriminator(torch.nn.Module):
286
+ def __init__(self):
287
+ super(MultiPeriodDiscriminator, self).__init__()
288
+ self.discriminators = nn.ModuleList(
289
+ [
290
+ DiscriminatorP(2),
291
+ DiscriminatorP(3),
292
+ DiscriminatorP(5),
293
+ DiscriminatorP(7),
294
+ DiscriminatorP(11),
295
+ ]
296
+ )
297
+
298
+ def forward(self, y, y_hat):
299
+ y_d_rs = []
300
+ y_d_gs = []
301
+ fmap_rs = []
302
+ fmap_gs = []
303
+ for i, d in enumerate(self.discriminators):
304
+ y_d_r, fmap_r = d(y)
305
+ y_d_g, fmap_g = d(y_hat)
306
+ y_d_rs.append(y_d_r)
307
+ fmap_rs.append(fmap_r)
308
+ y_d_gs.append(y_d_g)
309
+ fmap_gs.append(fmap_g)
310
+
311
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
312
+
313
+
314
+ class DiscriminatorS(torch.nn.Module):
315
+ def __init__(self, use_spectral_norm=False):
316
+ super(DiscriminatorS, self).__init__()
317
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
318
+ self.convs = nn.ModuleList(
319
+ [
320
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
321
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
322
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
323
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
324
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
325
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
326
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
327
+ ]
328
+ )
329
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
330
+
331
+ def forward(self, x):
332
+ fmap = []
333
+ for l in self.convs:
334
+ x = l(x)
335
+ x = F.leaky_relu(x, LRELU_SLOPE)
336
+ fmap.append(x)
337
+ x = self.conv_post(x)
338
+ fmap.append(x)
339
+ x = torch.flatten(x, 1, -1)
340
+
341
+ return x, fmap
342
+
343
+
344
+ class MultiScaleDiscriminator(torch.nn.Module):
345
+ def __init__(self):
346
+ super(MultiScaleDiscriminator, self).__init__()
347
+ self.discriminators = nn.ModuleList(
348
+ [
349
+ DiscriminatorS(use_spectral_norm=True),
350
+ DiscriminatorS(),
351
+ DiscriminatorS(),
352
+ ]
353
+ )
354
+ self.meanpools = nn.ModuleList(
355
+ [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
356
+ )
357
+
358
+ def forward(self, y, y_hat):
359
+ y_d_rs = []
360
+ y_d_gs = []
361
+ fmap_rs = []
362
+ fmap_gs = []
363
+ for i, d in enumerate(self.discriminators):
364
+ if i != 0:
365
+ y = self.meanpools[i - 1](y)
366
+ y_hat = self.meanpools[i - 1](y_hat)
367
+ y_d_r, fmap_r = d(y)
368
+ y_d_g, fmap_g = d(y_hat)
369
+ y_d_rs.append(y_d_r)
370
+ fmap_rs.append(fmap_r)
371
+ y_d_gs.append(y_d_g)
372
+ fmap_gs.append(fmap_g)
373
+
374
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
375
+
376
+
377
+ def feature_loss(fmap_r, fmap_g):
378
+ loss = 0
379
+ for dr, dg in zip(fmap_r, fmap_g):
380
+ for rl, gl in zip(dr, dg):
381
+ loss += torch.mean(torch.abs(rl - gl))
382
+
383
+ return loss * 2
384
+
385
+
386
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
387
+ loss = 0
388
+ r_losses = []
389
+ g_losses = []
390
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
391
+ r_loss = torch.mean((1 - dr) ** 2)
392
+ g_loss = torch.mean(dg**2)
393
+ loss += r_loss + g_loss
394
+ r_losses.append(r_loss.item())
395
+ g_losses.append(g_loss.item())
396
+
397
+ return loss, r_losses, g_losses
398
+
399
+
400
+ def generator_loss(disc_outputs):
401
+ loss = 0
402
+ gen_losses = []
403
+ for dg in disc_outputs:
404
+ l = torch.mean((1 - dg) ** 2)
405
+ gen_losses.append(l)
406
+ loss += l
407
+
408
+ return loss, gen_losses
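For orientation, the discriminators and loss helpers above combine in the usual HiFi-GAN fashion: least-squares adversarial terms from discriminator_loss and generator_loss, plus the L1 feature-matching term from feature_loss over discriminator activations. A minimal sketch (illustrative only; random tensors stand in for real and generated audio, and only the multi-period branch is shown):

    import torch

    mpd = MultiPeriodDiscriminator()
    y = torch.randn(1, 1, 8192)        # real waveform
    y_hat = torch.randn(1, 1, 8192)    # generated waveform

    # Discriminator step: detach the generated audio so only the discriminator updates.
    real_out, fake_out, _, _ = mpd(y, y_hat.detach())
    loss_disc, _, _ = discriminator_loss(real_out, fake_out)

    # Generator step: adversarial term plus feature matching.
    real_out, fake_out, fmap_r, fmap_g = mpd(y, y_hat)
    loss_gen, _ = generator_loss(fake_out)
    loss_fm = feature_loss(fmap_r, fmap_g)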
models/tts/naturalspeech2/ns2_dataset.py CHANGED
@@ -21,13 +21,11 @@ class NS2Dataset(torch.utils.data.Dataset):
21
  assert isinstance(dataset, str)
22
 
23
  processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
24
- # for example: /home/v-detaixin/LibriTTS/processed_data; train-full
25
 
26
  meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
27
  # train.json
28
 
29
  self.metafile_path = os.path.join(processed_data_dir, meta_file)
30
- # /home/v-detaixin/LibriTTS/processed_data/train-full/train.json
31
 
32
  self.metadata = self.get_metadata()
33
 
 
models/vocoders/autoregressive/autoregressive_vocoder_dataset.py ADDED
File without changes
models/vocoders/autoregressive/autoregressive_vocoder_inference.py ADDED
File without changes