RMSnow committed
Commit df2accb
1 Parent(s): 9f923d1

init and interface

This view is limited to 50 files because it contains too many changes. See the raw diff for the remaining files.
Files changed (50)
  1. .gitattributes +1 -1
  2. .gitignore +18 -0
  3. app.py +78 -0
  4. ckpts/svc/vocalist_l1_contentvec+whisper/args.json +256 -0
  5. ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/optimizer.bin +3 -0
  6. ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/pytorch_model.bin +3 -0
  7. ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/random_states_0.pkl +3 -0
  8. ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/singers.json +17 -0
  9. ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_max.npy +3 -0
  10. ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_min.npy +3 -0
  11. ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/meta_info.json +31 -0
  12. ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/pitches/statistics.json +242 -0
  13. ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.0 +3 -0
  14. ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.1 +3 -0
  15. ckpts/svc/vocalist_l1_contentvec+whisper/singers.json +17 -0
  16. egs/svc/MultipleContentsSVC/README.md +153 -0
  17. egs/svc/MultipleContentsSVC/exp_config.json +126 -0
  18. egs/svc/MultipleContentsSVC/run.sh +1 -0
  19. egs/svc/README.md +34 -0
  20. egs/svc/_template/run.sh +150 -0
  21. inference.py +258 -0
  22. models/__init__.py +0 -0
  23. models/base/__init__.py +7 -0
  24. models/base/base_dataset.py +350 -0
  25. models/base/base_inference.py +220 -0
  26. models/base/base_sampler.py +136 -0
  27. models/base/base_trainer.py +348 -0
  28. models/base/new_dataset.py +50 -0
  29. models/base/new_inference.py +249 -0
  30. models/base/new_trainer.py +722 -0
  31. models/svc/__init__.py +0 -0
  32. models/svc/base/__init__.py +7 -0
  33. models/svc/base/svc_dataset.py +425 -0
  34. models/svc/base/svc_inference.py +15 -0
  35. models/svc/base/svc_trainer.py +111 -0
  36. models/svc/comosvc/__init__.py +4 -0
  37. models/svc/comosvc/comosvc.py +377 -0
  38. models/svc/comosvc/comosvc_inference.py +39 -0
  39. models/svc/comosvc/comosvc_trainer.py +295 -0
  40. models/svc/comosvc/utils.py +31 -0
  41. models/svc/diffusion/__init__.py +0 -0
  42. models/svc/diffusion/diffusion_inference.py +63 -0
  43. models/svc/diffusion/diffusion_inference_pipeline.py +47 -0
  44. models/svc/diffusion/diffusion_trainer.py +88 -0
  45. models/svc/diffusion/diffusion_wrapper.py +73 -0
  46. models/svc/transformer/__init__.py +0 -0
  47. models/svc/transformer/conformer.py +405 -0
  48. models/svc/transformer/transformer.py +82 -0
  49. models/svc/transformer/transformer_inference.py +45 -0
  50. models/svc/transformer/transformer_trainer.py +52 -0
.gitattributes CHANGED
@@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
+ __pycache__
+ flagged
+ result
+
+ # Developing mode
+ _*.sh
+ _*.json
+ *.lst
+ yard*
+ *.out
+ evaluation/evalset_selection
+ mfa
+ egs/svc/*wavmark
+ egs/svc/custom
+ egs/svc/*/dev*
+ egs/svc/dev_exp_config.json
+ bins/svc/demo*
+ bins/svc/preprocess_custom.py
app.py ADDED
@@ -0,0 +1,78 @@
+ import gradio as gr
+
+
+ SUPPORTED_TARGET_SINGERS = {
+     "Adele": "vocalist_l1_Adele",
+     "Beyonce": "vocalist_l1_Beyonce",
+     "Bruno Mars": "vocalist_l1_BrunoMars",
+     "John Mayer": "vocalist_l1_JohnMayer",
+     "Michael Jackson": "vocalist_l1_MichaelJackson",
+     "Taylor Swift": "vocalist_l1_TaylorSwift",
+     "Jacky Cheung 张学友": "vocalist_l1_张学友",
+     "Jian Li 李健": "vocalist_l1_李健",
+     "Feng Wang 汪峰": "vocalist_l1_汪峰",
+     "Faye Wong 王菲": "vocalist_l1_王菲",
+     "Yijie Shi 石倚洁": "vocalist_l1_石倚洁",
+     "Tsai Chin 蔡琴": "vocalist_l1_蔡琴",
+     "Ying Na 那英": "vocalist_l1_那英",
+     "Eason Chan 陈奕迅": "vocalist_l1_陈奕迅",
+     "David Tao 陶喆": "vocalist_l1_陶喆",
+ }
+
+
+ def svc_inference(
+     source_audio,
+     target_singer,
+     diffusion_steps=1000,
+     key_shift_mode="auto",
+     key_shift_num=0,
+ ):
+     pass
+
+
+ demo_inputs = [
+     gr.Audio(
+         sources=["upload", "microphone"],
+         label="Upload (or record) a song you want to listen to",
+     ),
+     gr.Radio(
+         choices=list(SUPPORTED_TARGET_SINGERS.keys()),
+         label="Target Singer",
+         value="Jian Li 李健",
+     ),
+     gr.Slider(
+         1,
+         1000,
+         value=1000,
+         step=1,
+         label="Diffusion Inference Steps",
+         info="As the step number increases, the synthesis quality improves while the inference speed decreases",
+     ),
+     gr.Radio(
+         choices=["Auto Shift", "Key Shift"],
+         value="Auto Shift",
+         label="Pitch Shift Control",
+         info='If you want to control the specific pitch shift value, choose "Key Shift"',
+     ),
+     gr.Slider(
+         -6,
+         6,
+         value=0,
+         step=1,
+         label="Key Shift Values",
+         info='How many semitones you want to transpose. This parameter takes effect only if you choose "Key Shift"',
+     ),
+ ]
+
+ demo_outputs = gr.Audio(label="")
+
+
+ demo = gr.Interface(
+     fn=svc_inference,
+     inputs=demo_inputs,
+     outputs=demo_outputs,
+     title="Amphion Singing Voice Conversion",
+ )
+
+ if __name__ == "__main__":
+     demo.launch(show_api=False)
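Note that the Gradio callback `svc_inference` is left as a stub (`pass`) in this commit. Below is a minimal sketch of how it might later be wired up; `run_svc_pipeline` is a hypothetical placeholder, and the argument names only mirror the CLI of the `inference.py` added later in this commit (`target_singer`, `trans_key`, `diffusion_inference_steps`). None of this wiring is part of the commit itself.

```python
# Illustrative sketch only: the commit leaves svc_inference() unimplemented, and
# the real wiring may differ. run_svc_pipeline() is a hypothetical placeholder.
def run_svc_pipeline(source, target_singer, trans_key, diffusion_inference_steps):
    raise NotImplementedError("Placeholder for the actual conversion pipeline.")


def svc_inference(
    source_audio,            # (sample_rate, waveform) tuple from gr.Audio
    target_singer,           # display name chosen in the UI
    diffusion_steps=1000,
    key_shift_mode="Auto Shift",
    key_shift_num=0,
):
    singer_id = SUPPORTED_TARGET_SINGERS[target_singer]
    trans_key = "autoshift" if key_shift_mode == "Auto Shift" else int(key_shift_num)
    return run_svc_pipeline(
        source=source_audio,
        target_singer=singer_id,
        trans_key=trans_key,
        diffusion_inference_steps=diffusion_steps,
    )
```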
ckpts/svc/vocalist_l1_contentvec+whisper/args.json ADDED
@@ -0,0 +1,256 @@
1
+ {
2
+ "base_config": "config/diffusion.json",
3
+ "dataset": [
4
+ "vocalist_l1",
5
+ ],
6
+ "exp_name": "vocalist_l1_contentvec+whisper",
7
+ "inference": {
8
+ "diffusion": {
9
+ "scheduler": "pndm",
10
+ "scheduler_settings": {
11
+ "num_inference_timesteps": 1000,
12
+ },
13
+ },
14
+ },
15
+ "model": {
16
+ "condition_encoder": {
17
+ "content_encoder_dim": 384,
18
+ "contentvec_dim": 256,
19
+ "f0_max": 1100,
20
+ "f0_min": 50,
21
+ "input_loudness_dim": 1,
22
+ "input_melody_dim": 1,
23
+ "merge_mode": "add",
24
+ "mert_dim": 256,
25
+ "n_bins_loudness": 256,
26
+ "n_bins_melody": 256,
27
+ "output_content_dim": 384,
28
+ "output_loudness_dim": 384,
29
+ "output_melody_dim": 384,
30
+ "output_singer_dim": 384,
31
+ "pitch_max": 1100,
32
+ "pitch_min": 50,
33
+ "singer_table_size": 512,
34
+ "use_conformer_for_content_features": false,
35
+ "use_contentvec": true,
36
+ "use_log_f0": true,
37
+ "use_log_loudness": true,
38
+ "use_mert": false,
39
+ "use_singer_encoder": true,
40
+ "use_spkid": true,
41
+ "use_wenet": false,
42
+ "use_whisper": true,
43
+ "wenet_dim": 512,
44
+ "whisper_dim": 1024,
45
+ },
46
+ "diffusion": {
47
+ "bidilconv": {
48
+ "base_channel": 384,
49
+ "conditioner_size": 384,
50
+ "conv_kernel_size": 3,
51
+ "dilation_cycle_length": 4,
52
+ "n_res_block": 20,
53
+ },
54
+ "model_type": "bidilconv",
55
+ "scheduler": "ddpm",
56
+ "scheduler_settings": {
57
+ "beta_end": 0.02,
58
+ "beta_schedule": "linear",
59
+ "beta_start": 0.0001,
60
+ "num_train_timesteps": 1000,
61
+ },
62
+ "step_encoder": {
63
+ "activation": "SiLU",
64
+ "dim_hidden_layer": 512,
65
+ "dim_raw_embedding": 128,
66
+ "max_period": 10000,
67
+ "num_layer": 2,
68
+ },
69
+ "unet2d": {
70
+ "down_block_types": [
71
+ "CrossAttnDownBlock2D",
72
+ "CrossAttnDownBlock2D",
73
+ "CrossAttnDownBlock2D",
74
+ "DownBlock2D",
75
+ ],
76
+ "in_channels": 1,
77
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
78
+ "only_cross_attention": false,
79
+ "out_channels": 1,
80
+ "up_block_types": [
81
+ "UpBlock2D",
82
+ "CrossAttnUpBlock2D",
83
+ "CrossAttnUpBlock2D",
84
+ "CrossAttnUpBlock2D",
85
+ ],
86
+ },
87
+ },
88
+ },
89
+ "model_type": "DiffWaveNetSVC",
90
+ "preprocess": {
91
+ "audio_dir": "audios",
92
+ "bits": 8,
93
+ "content_feature_batch_size": 16,
94
+ "contentvec_batch_size": 1,
95
+ "contentvec_dir": "contentvec",
96
+ "contentvec_file": "pretrained/contentvec/checkpoint_best_legacy_500.pt",
97
+ "contentvec_frameshift": 0.02,
98
+ "contentvec_sample_rate": 16000,
99
+ "dur_dir": "durs",
100
+ "duration_dir": "duration",
101
+ "emo2id": "emo2id.json",
102
+ "energy_dir": "energys",
103
+ "extract_audio": false,
104
+ "extract_contentvec_feature": true,
105
+ "extract_energy": true,
106
+ "extract_label": false,
107
+ "extract_mcep": false,
108
+ "extract_mel": true,
109
+ "extract_mert_feature": false,
110
+ "extract_pitch": true,
111
+ "extract_uv": true,
112
+ "extract_wenet_feature": false,
113
+ "extract_whisper_feature": true,
114
+ "f0_max": 1100,
115
+ "f0_min": 50,
116
+ "file_lst": "file.lst",
117
+ "fmax": 12000,
118
+ "fmin": 0,
119
+ "hop_size": 256,
120
+ "is_label": true,
121
+ "is_mu_law": true,
122
+ "lab_dir": "labs",
123
+ "label_dir": "labels",
124
+ "mcep_dir": "mcep",
125
+ "mel_dir": "mels",
126
+ "mel_min_max_norm": true,
127
+ "mel_min_max_stats_dir": "mel_min_max_stats",
128
+ "mert_dir": "mert",
129
+ "mert_feature_layer": -1,
130
+ "mert_frameshit": 0.01333,
131
+ "mert_hop_size": 320,
132
+ "mert_model": "m-a-p/MERT-v1-330M",
133
+ "min_level_db": -115,
134
+ "mu_law_norm": false,
135
+ "n_fft": 1024,
136
+ "n_mel": 100,
137
+ "num_silent_frames": 8,
138
+ "num_workers": 8,
139
+ "phone_seq_file": "phone_seq_file",
140
+ "pin_memory": true,
141
+ "pitch_bin": 256,
142
+ "pitch_dir": "pitches",
143
+ "pitch_extractor": "parselmouth",
144
+ "pitch_max": 1100.0,
145
+ "pitch_min": 50.0,
146
+ "processed_dir": "ckpts/svc/vocalist_l1_contentvec+whisper/data",
147
+ "ref_level_db": 20,
148
+ "sample_rate": 24000,
149
+ "spk2id": "singers.json",
150
+ "train_file": "train.json",
151
+ "trim_fft_size": 512,
152
+ "trim_hop_size": 128,
153
+ "trim_silence": false,
154
+ "trim_top_db": 30,
155
+ "trimmed_wav_dir": "trimmed_wavs",
156
+ "use_audio": false,
157
+ "use_contentvec": true,
158
+ "use_dur": false,
159
+ "use_emoid": false,
160
+ "use_frame_duration": false,
161
+ "use_frame_energy": true,
162
+ "use_frame_pitch": true,
163
+ "use_lab": false,
164
+ "use_label": false,
165
+ "use_log_scale_energy": false,
166
+ "use_log_scale_pitch": false,
167
+ "use_mel": true,
168
+ "use_mert": false,
169
+ "use_min_max_norm_mel": true,
170
+ "use_one_hot": false,
171
+ "use_phn_seq": false,
172
+ "use_phone_duration": false,
173
+ "use_phone_energy": false,
174
+ "use_phone_pitch": false,
175
+ "use_spkid": true,
176
+ "use_uv": true,
177
+ "use_wav": false,
178
+ "use_wenet": false,
179
+ "use_whisper": true,
180
+ "utt2emo": "utt2emo",
181
+ "utt2spk": "utt2singer",
182
+ "uv_dir": "uvs",
183
+ "valid_file": "test.json",
184
+ "wav_dir": "wavs",
185
+ "wenet_batch_size": 1,
186
+ "wenet_config": "pretrained/wenet/20220506_u2pp_conformer_exp/train.yaml",
187
+ "wenet_dir": "wenet",
188
+ "wenet_downsample_rate": 4,
189
+ "wenet_frameshift": 0.01,
190
+ "wenet_model_path": "pretrained/wenet/20220506_u2pp_conformer_exp/final.pt",
191
+ "wenet_sample_rate": 16000,
192
+ "whisper_batch_size": 30,
193
+ "whisper_dir": "whisper",
194
+ "whisper_downsample_rate": 2,
195
+ "whisper_frameshift": 0.01,
196
+ "whisper_model": "medium",
197
+ "whisper_model_path": "pretrained/whisper/medium.pt",
198
+ "win_size": 1024,
199
+ },
200
+ "supported_model_type": [
201
+ "Fastspeech2",
202
+ "DiffSVC",
203
+ "Transformer",
204
+ "EDM",
205
+ "CD",
206
+ ],
207
+ "train": {
208
+ "adamw": {
209
+ "lr": 0.0004,
210
+ },
211
+ "batch_size": 32,
212
+ "dataloader": {
213
+ "num_worker": 8,
214
+ "pin_memory": true,
215
+ },
216
+ "ddp": true,
217
+ "epochs": 50000,
218
+ "gradient_accumulation_step": 1,
219
+ "keep_checkpoint_max": 5,
220
+ "keep_last": [
221
+ 5,
222
+ -1,
223
+ ],
224
+ "max_epoch": -1,
225
+ "max_steps": 1000000,
226
+ "multi_speaker_training": false,
227
+ "optimizer": "AdamW",
228
+ "random_seed": 10086,
229
+ "reducelronplateau": {
230
+ "factor": 0.8,
231
+ "min_lr": 0.0001,
232
+ "patience": 10,
233
+ },
234
+ "run_eval": [
235
+ false,
236
+ true,
237
+ ],
238
+ "sampler": {
239
+ "drop_last": true,
240
+ "holistic_shuffle": false,
241
+ },
242
+ "save_checkpoint_stride": [
243
+ 3,
244
+ 10,
245
+ ],
246
+ "save_checkpoints_steps": 10000,
247
+ "save_summary_steps": 500,
248
+ "scheduler": "ReduceLROnPlateau",
249
+ "total_training_steps": 50000,
250
+ "tracker": [
251
+ "tensorboard",
252
+ ],
253
+ "valid_interval": 10000,
254
+ },
255
+ "use_custom_dataset": true,
256
+ }
ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/optimizer.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:836af10b834c7aec9209eb19ce43559e6ef1e3a59bd6468e90cadbc9a18749ef
+ size 249512389
ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d54eed12bef331095fc367f196d07c5061d5cb72dd6fe0e1e4453b997bf1d68d
+ size 124755137
ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/random_states_0.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6798ddffadcd7d5405a77e667c674c474e4fef0cba817fdd300c7c985c1e82fe
+ size 14599
ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/singers.json ADDED
@@ -0,0 +1,17 @@
+ {
+     "vocalist_l1_Adele": 0,
+     "vocalist_l1_Beyonce": 1,
+     "vocalist_l1_BrunoMars": 2,
+     "vocalist_l1_JohnMayer": 3,
+     "vocalist_l1_MichaelJackson": 4,
+     "vocalist_l1_TaylorSwift": 5,
+     "vocalist_l1_张学友": 6,
+     "vocalist_l1_李健": 7,
+     "vocalist_l1_汪峰": 8,
+     "vocalist_l1_王菲": 9,
+     "vocalist_l1_石倚洁": 10,
+     "vocalist_l1_蔡琴": 11,
+     "vocalist_l1_那英": 12,
+     "vocalist_l1_陈奕迅": 13,
+     "vocalist_l1_陶喆": 14
+ }
ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_max.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:04131849378aa4f525a701909f743c303f8d56571682572b888046ead9f3e2ab
+ size 528
ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_min.npy ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef4895ebef0e9949a6e623315bdc8a68490ba95d2f81b2be9f5146f904203016
+ size 528
ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/meta_info.json ADDED
@@ -0,0 +1,31 @@
+ {
+     "dataset": "vocalist_l1",
+     "train": {
+         "size": 3180,
+         "hours": 6.1643
+     },
+     "test": {
+         "size": 114,
+         "hours": 0.2224
+     },
+     "singers": {
+         "size": 15,
+         "training_minutes": {
+             "vocalist_l1_陶喆": 45.51,
+             "vocalist_l1_陈奕迅": 43.36,
+             "vocalist_l1_汪峰": 41.08,
+             "vocalist_l1_李健": 38.9,
+             "vocalist_l1_JohnMayer": 30.83,
+             "vocalist_l1_Adele": 27.23,
+             "vocalist_l1_那英": 27.02,
+             "vocalist_l1_石倚洁": 24.93,
+             "vocalist_l1_张学友": 18.31,
+             "vocalist_l1_TaylorSwift": 18.31,
+             "vocalist_l1_王菲": 16.78,
+             "vocalist_l1_MichaelJackson": 15.13,
+             "vocalist_l1_蔡琴": 10.12,
+             "vocalist_l1_BrunoMars": 6.29,
+             "vocalist_l1_Beyonce": 6.06
+         }
+     }
+ }
ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/pitches/statistics.json ADDED
@@ -0,0 +1,242 @@
1
+ {
2
+ "vocalist_l1_Adele": {
3
+ "voiced_positions": {
4
+ "mean": 336.5038018286193,
5
+ "std": 100.2148774476881,
6
+ "median": 332.98363792619296,
7
+ "min": 59.99838412340723,
8
+ "max": 1099.849325287837
9
+ },
10
+ "total_positions": {
11
+ "mean": 231.79366581704338,
12
+ "std": 176.6042850107386,
13
+ "median": 273.2844263775394,
14
+ "min": 0.0,
15
+ "max": 1099.849325287837
16
+ }
17
+ },
18
+ "vocalist_l1_Beyonce": {
19
+ "voiced_positions": {
20
+ "mean": 357.5678927636881,
21
+ "std": 130.1132620135807,
22
+ "median": 318.2981879228934,
23
+ "min": 70.29719673914867,
24
+ "max": 1050.354470112099
25
+ },
26
+ "total_positions": {
27
+ "mean": 267.5248026267327,
28
+ "std": 191.71600807951046,
29
+ "median": 261.91981963774066,
30
+ "min": 0.0,
31
+ "max": 1050.354470112099
32
+ }
33
+ },
34
+ "vocalist_l1_BrunoMars": {
35
+ "voiced_positions": {
36
+ "mean": 330.92612740814315,
37
+ "std": 86.51034158515388,
38
+ "median": 324.65585832605217,
39
+ "min": 58.74277302450286,
40
+ "max": 999.2818302992808
41
+ },
42
+ "total_positions": {
43
+ "mean": 237.26076288057826,
44
+ "std": 166.09898203490803,
45
+ "median": 286.3097386522132,
46
+ "min": 0.0,
47
+ "max": 999.2818302992808
48
+ }
49
+ },
50
+ "vocalist_l1_JohnMayer": {
51
+ "voiced_positions": {
52
+ "mean": 218.3531239166661,
53
+ "std": 77.89887175223768,
54
+ "median": 200.19060542586652,
55
+ "min": 53.371912740674716,
56
+ "max": 1098.1986774161685
57
+ },
58
+ "total_positions": {
59
+ "mean": 112.95331907131244,
60
+ "std": 122.65534824070893,
61
+ "median": 124.71389285965317,
62
+ "min": 0.0,
63
+ "max": 1098.1986774161685
64
+ }
65
+ },
66
+ "vocalist_l1_MichaelJackson": {
67
+ "voiced_positions": {
68
+ "mean": 293.4663654519906,
69
+ "std": 89.02211325650234,
70
+ "median": 284.4323483619402,
71
+ "min": 61.14507754070825,
72
+ "max": 1096.4247902272325
73
+ },
74
+ "total_positions": {
75
+ "mean": 172.1013565770682,
76
+ "std": 159.79551912957191,
77
+ "median": 212.82938711725973,
78
+ "min": 0.0,
79
+ "max": 1096.4247902272325
80
+ }
81
+ },
82
+ "vocalist_l1_TaylorSwift": {
83
+ "voiced_positions": {
84
+ "mean": 302.5346928039029,
85
+ "std": 87.1724728626562,
86
+ "median": 286.91670244246586,
87
+ "min": 51.31173137207717,
88
+ "max": 1098.9374311806605
89
+ },
90
+ "total_positions": {
91
+ "mean": 169.90968097339214,
92
+ "std": 163.7133164876362,
93
+ "median": 220.90943653386546,
94
+ "min": 0.0,
95
+ "max": 1098.9374311806605
96
+ }
97
+ },
98
+ "vocalist_l1_张学友": {
99
+ "voiced_positions": {
100
+ "mean": 233.6845479691867,
101
+ "std": 66.47140810463938,
102
+ "median": 228.28695118043396,
103
+ "min": 51.65338480121057,
104
+ "max": 1094.4381927885959
105
+ },
106
+ "total_positions": {
107
+ "mean": 167.79543637603194,
108
+ "std": 119.28338415844308,
109
+ "median": 194.81504136428546,
110
+ "min": 0.0,
111
+ "max": 1094.4381927885959
112
+ }
113
+ },
114
+ "vocalist_l1_李健": {
115
+ "voiced_positions": {
116
+ "mean": 234.98401896504657,
117
+ "std": 71.3955175177514,
118
+ "median": 221.86415264367847,
119
+ "min": 54.070687769392585,
120
+ "max": 1096.3342286660531
121
+ },
122
+ "total_positions": {
123
+ "mean": 148.74760079412246,
124
+ "std": 126.70486473504008,
125
+ "median": 180.21374566147688,
126
+ "min": 0.0,
127
+ "max": 1096.3342286660531
128
+ }
129
+ },
130
+ "vocalist_l1_汪峰": {
131
+ "voiced_positions": {
132
+ "mean": 284.27752567207864,
133
+ "std": 78.51774150654873,
134
+ "median": 278.26186808969493,
135
+ "min": 54.30945929095861,
136
+ "max": 1053.6870553733015
137
+ },
138
+ "total_positions": {
139
+ "mean": 172.41584497486713,
140
+ "std": 151.74272125914902,
141
+ "median": 216.27534661524862,
142
+ "min": 0.0,
143
+ "max": 1053.6870553733015
144
+ }
145
+ },
146
+ "vocalist_l1_王菲": {
147
+ "voiced_positions": {
148
+ "mean": 339.1661679865587,
149
+ "std": 86.86768172635271,
150
+ "median": 327.4151031268507,
151
+ "min": 51.21299842481366,
152
+ "max": 1096.7044574066776
153
+ },
154
+ "total_positions": {
155
+ "mean": 217.726880186,
156
+ "std": 176.8748978138034,
157
+ "median": 277.8608050501477,
158
+ "min": 0.0,
159
+ "max": 1096.7044574066776
160
+ }
161
+ },
162
+ "vocalist_l1_石倚洁": {
163
+ "voiced_positions": {
164
+ "mean": 279.67710779262256,
165
+ "std": 87.82306577322389,
166
+ "median": 271.13024912248443,
167
+ "min": 59.604772357481075,
168
+ "max": 1098.0574674417153
169
+ },
170
+ "total_positions": {
171
+ "mean": 205.49634806008135,
172
+ "std": 144.6064344590865,
173
+ "median": 234.19454400899718,
174
+ "min": 0.0,
175
+ "max": 1098.0574674417153
176
+ }
177
+ },
178
+ "vocalist_l1_蔡琴": {
179
+ "voiced_positions": {
180
+ "mean": 258.9105806499278,
181
+ "std": 67.4079737418162,
182
+ "median": 250.29778287949176,
183
+ "min": 54.81875790199644,
184
+ "max": 930.3733192171918
185
+ },
186
+ "total_positions": {
187
+ "mean": 197.64675891035662,
188
+ "std": 124.80889987119957,
189
+ "median": 228.14775033720753,
190
+ "min": 0.0,
191
+ "max": 930.3733192171918
192
+ }
193
+ },
194
+ "vocalist_l1_那英": {
195
+ "voiced_positions": {
196
+ "mean": 358.98655838013195,
197
+ "std": 91.30591323348871,
198
+ "median": 346.95185476261275,
199
+ "min": 71.62879029165369,
200
+ "max": 1085.4349856526985
201
+ },
202
+ "total_positions": {
203
+ "mean": 243.83317702162077,
204
+ "std": 183.68660712060583,
205
+ "median": 294.9745603259994,
206
+ "min": 0.0,
207
+ "max": 1085.4349856526985
208
+ }
209
+ },
210
+ "vocalist_l1_陈奕迅": {
211
+ "voiced_positions": {
212
+ "mean": 222.0124146654594,
213
+ "std": 68.65002654904572,
214
+ "median": 218.9200565540147,
215
+ "min": 50.48503062529368,
216
+ "max": 1084.6336454006018
217
+ },
218
+ "total_positions": {
219
+ "mean": 154.2275169157727,
220
+ "std": 117.16740631313343,
221
+ "median": 176.89315636838086,
222
+ "min": 0.0,
223
+ "max": 1084.6336454006018
224
+ }
225
+ },
226
+ "vocalist_l1_陶喆": {
227
+ "voiced_positions": {
228
+ "mean": 242.58206762395713,
229
+ "std": 69.61805791083957,
230
+ "median": 227.5222796096177,
231
+ "min": 50.44809060945403,
232
+ "max": 1098.4942623171203
233
+ },
234
+ "total_positions": {
235
+ "mean": 171.59040988406485,
236
+ "std": 124.93911390018495,
237
+ "median": 204.4328861811408,
238
+ "min": 0.0,
239
+ "max": 1098.4942623171203
240
+ }
241
+ }
242
+ }
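The per-singer pitch statistics above (mean, median, etc. of the voiced F0) are the kind of information an automatic key selection can rely on at conversion time. As an illustration only, and not necessarily Amphion's exact "autoshift" rule, one common recipe matches the source's median voiced F0 to the target singer's median and rounds the difference to whole semitones; the function and variable names below are hypothetical, while the statistics path is the file added above.

```python
import json
import math

STATS = "ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/pitches/statistics.json"


def auto_key_shift(source_median_f0_hz: float, target_singer: str) -> int:
    """Illustrative median-matching key shift (in semitones)."""
    with open(STATS) as f:
        stats = json.load(f)
    target_median = stats[target_singer]["voiced_positions"]["median"]
    # A shift of k semitones scales F0 by 2 ** (k / 12).
    semitones = 12.0 * math.log2(target_median / source_median_f0_hz)
    return round(semitones)

# Example: a source sung around 220 Hz converted to vocalist_l1_Adele
# (median ~333 Hz) would be shifted up by about 7 semitones.
```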
ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.0 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d7f490fd0c97876e24bfc44413365ded7ff5d22c1c79f0dac0b754f3b32df76f
+ size 88
ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.1 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e01bcf2fa621ba563b70568c18fe0742d0f48cafae83a6e8beb0bb6d1f6d146d
+ size 77413046
ckpts/svc/vocalist_l1_contentvec+whisper/singers.json ADDED
@@ -0,0 +1,17 @@
+ {
+     "vocalist_l1_Adele": 0,
+     "vocalist_l1_Beyonce": 1,
+     "vocalist_l1_BrunoMars": 2,
+     "vocalist_l1_JohnMayer": 3,
+     "vocalist_l1_MichaelJackson": 4,
+     "vocalist_l1_TaylorSwift": 5,
+     "vocalist_l1_张学友": 6,
+     "vocalist_l1_李健": 7,
+     "vocalist_l1_汪峰": 8,
+     "vocalist_l1_王菲": 9,
+     "vocalist_l1_石倚洁": 10,
+     "vocalist_l1_蔡琴": 11,
+     "vocalist_l1_那英": 12,
+     "vocalist_l1_陈奕迅": 13,
+     "vocalist_l1_陶喆": 14
+ }
egs/svc/MultipleContentsSVC/README.md ADDED
@@ -0,0 +1,153 @@
+ # Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion
+
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2310.11160)
+ [![demo](https://img.shields.io/badge/SVC-Demo-red)](https://www.zhangxueyao.com/data/MultipleContentsSVC/index.html)
+
+ <br>
+ <div align="center">
+ <img src="../../../imgs/svc/MultipleContentsSVC.png" width="85%">
+ </div>
+ <br>
+
+ This is the official implementation of the paper "[Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion](https://arxiv.org/abs/2310.11160)" (NeurIPS 2023 Workshop on Machine Learning for Audio). Specifically,
+
+ - The multiple content features are from [Whisper](https://github.com/openai/whisper) and [ContentVec](https://github.com/auspicious3000/contentvec).
+ - The acoustic model is based on a Bidirectional Non-Causal Dilated CNN (called `DiffWaveNetSVC` in Amphion), which is similar to [WaveNet](https://arxiv.org/pdf/1609.03499.pdf), [DiffWave](https://openreview.net/forum?id=a-xFK8Ymz5J), and [DiffSVC](https://ieeexplore.ieee.org/document/9688219).
+ - The vocoder uses the [BigVGAN](https://github.com/NVIDIA/BigVGAN) architecture and is fine-tuned on over 120 hours of singing voice data.
+
+ There are four stages in total:
+
+ 1. Data preparation
+ 2. Features extraction
+ 3. Training
+ 4. Inference/conversion
+
+ > **NOTE:** You need to run every command of this recipe in the `Amphion` root path:
+ > ```bash
+ > cd Amphion
+ > ```
+
+ ## 1. Data Preparation
+
+ ### Dataset Download
+
+ By default, we utilize five datasets for training: M4Singer, Opencpop, OpenSinger, SVCC, and VCTK. How to download them is detailed [here](../../datasets/README.md).
+
+ ### Configuration
+
+ Specify the dataset paths in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets.
+
+ ```json
+     "dataset": [
+         "m4singer",
+         "opencpop",
+         "opensinger",
+         "svcc",
+         "vctk"
+     ],
+     "dataset_path": {
+         // TODO: Fill in your dataset path
+         "m4singer": "[M4Singer dataset path]",
+         "opencpop": "[Opencpop dataset path]",
+         "opensinger": "[OpenSinger dataset path]",
+         "svcc": "[SVCC dataset path]",
+         "vctk": "[VCTK dataset path]"
+     },
+ ```
+
+ ## 2. Features Extraction
+
+ ### Content-based Pretrained Models Download
+
+ By default, we utilize Whisper and ContentVec to extract content features. How to download them is detailed [here](../../../pretrained/README.md).
+
+ ### Configuration
+
+ Specify the dataset path and the output path for saving the processed data and the training model in `exp_config.json`:
+
+ ```json
+     // TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc"
+     "log_dir": "ckpts/svc",
+     "preprocess": {
+         // TODO: Fill in the output data path. The default value is "Amphion/data"
+         "processed_dir": "data",
+         ...
+     },
+ ```
+
+ ### Run
+
+ Run `run.sh` as the preprocessing stage (set `--stage 1`).
+
+ ```bash
+ sh egs/svc/MultipleContentsSVC/run.sh --stage 1
+ ```
+
+ > **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, for example, `--gpu "1"`.
+
+ ## 3. Training
+
+ ### Configuration
+
+ We provide the default hyperparameters in `exp_config.json`. They work on a single NVIDIA 24GB GPU. You can adjust them based on your GPU machines.
+
+ ```json
+     "train": {
+         "batch_size": 32,
+         ...
+         "adamw": {
+             "lr": 2.0e-4
+         },
+         ...
+     }
+ ```
+
+ ### Run
+
+ Run `run.sh` as the training stage (set `--stage 2`). Specify an experiment name to run the following command. The TensorBoard logs and checkpoints will be saved in `Amphion/ckpts/svc/[YourExptName]`.
+
+ ```bash
+ sh egs/svc/MultipleContentsSVC/run.sh --stage 2 --name [YourExptName]
+ ```
+
+ > **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, for example, `--gpu "0,1,2,3"`.
+
+ ## 4. Inference/Conversion
+
+ ### Pretrained Vocoder Download
+
+ We fine-tune the official BigVGAN pretrained model with over 120 hours of singing voice data. The benefits of fine-tuning have been investigated in our paper (see this [demo page](https://www.zhangxueyao.com/data/MultipleContentsSVC/vocoder.html)). The final pretrained singing voice vocoder is released [here](../../../pretrained/README.md#amphion-singing-bigvgan) (called `Amphion Singing BigVGAN`).
+
+ ### Run
+
+ For inference/conversion, you need to specify the following configurations when running `run.sh`:
+
+ | Parameters | Description | Example |
+ | --- | --- | --- |
+ | `--infer_expt_dir` | The experiment directory which contains `checkpoint`. | `Amphion/ckpts/svc/[YourExptName]` |
+ | `--infer_output_dir` | The output directory to save inferred audios. | `Amphion/ckpts/svc/[YourExptName]/result` |
+ | `--infer_source_file` or `--infer_source_audio_dir` | The inference source (can be a JSON file or a directory). | The `infer_source_file` could be `Amphion/data/[YourDataset]/test.json`, and the `infer_source_audio_dir` is a folder which includes several audio files (*.wav, *.mp3, or *.flac). |
+ | `--infer_target_speaker` | The target speaker you want to convert into. You can refer to `Amphion/ckpts/svc/[YourExptName]/singers.json` to choose a trained speaker. | For the Opencpop dataset, the speaker name would be `opencpop_female1`. |
+ | `--infer_key_shift` | How many semitones you want to transpose. | `"autoshift"` (by default), `3`, `-3`, etc. |
+
+ For example, if you want to make `opencpop_female1` sing the songs in `[Your Audios Folder]`, just run:
+
+ ```bash
+ sh egs/svc/MultipleContentsSVC/run.sh --stage 3 --gpu "0" \
+     --infer_expt_dir Amphion/ckpts/svc/[YourExptName] \
+     --infer_output_dir Amphion/ckpts/svc/[YourExptName]/result \
+     --infer_source_audio_dir [Your Audios Folder] \
+     --infer_target_speaker "opencpop_female1" \
+     --infer_key_shift "autoshift"
+ ```
+
+ ## Citations
+
+ ```bibtex
+ @article{zhang2023leveraging,
+     title={Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion},
+     author={Zhang, Xueyao and Gu, Yicheng and Chen, Haopeng and Fang, Zihao and Zou, Lexiao and Xue, Liumeng and Wu, Zhizheng},
+     journal={Machine Learning for Audio Workshop, NeurIPS 2023},
+     year={2023}
+ }
+ ```
egs/svc/MultipleContentsSVC/exp_config.json ADDED
@@ -0,0 +1,126 @@
+ {
+     "base_config": "config/diffusion.json",
+     "model_type": "DiffWaveNetSVC",
+     "dataset": [
+         "m4singer",
+         "opencpop",
+         "opensinger",
+         "svcc",
+         "vctk"
+     ],
+     "dataset_path": {
+         // TODO: Fill in your dataset path
+         "m4singer": "[M4Singer dataset path]",
+         "opencpop": "[Opencpop dataset path]",
+         "opensinger": "[OpenSinger dataset path]",
+         "svcc": "[SVCC dataset path]",
+         "vctk": "[VCTK dataset path]"
+     },
+     // TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc"
+     "log_dir": "ckpts/svc",
+     "preprocess": {
+         // TODO: Fill in the output data path. The default value is "Amphion/data"
+         "processed_dir": "data",
+         // Config for features extraction
+         "extract_mel": true,
+         "extract_pitch": true,
+         "extract_energy": true,
+         "extract_whisper_feature": true,
+         "extract_contentvec_feature": true,
+         "extract_wenet_feature": false,
+         "whisper_batch_size": 30, // decrease it if your GPU is out of memory
+         "contentvec_batch_size": 1,
+         // Fill in the content-based pretrained model's path
+         "contentvec_file": "pretrained/contentvec/checkpoint_best_legacy_500.pt",
+         "wenet_model_path": "pretrained/wenet/20220506_u2pp_conformer_exp/final.pt",
+         "wenet_config": "pretrained/wenet/20220506_u2pp_conformer_exp/train.yaml",
+         "whisper_model": "medium",
+         "whisper_model_path": "pretrained/whisper/medium.pt",
+         // Config for features usage
+         "use_mel": true,
+         "use_min_max_norm_mel": true,
+         "use_frame_pitch": true,
+         "use_frame_energy": true,
+         "use_spkid": true,
+         "use_whisper": true,
+         "use_contentvec": true,
+         "use_wenet": false,
+         "n_mel": 100,
+         "sample_rate": 24000
+     },
+     "model": {
+         "condition_encoder": {
+             // Config for features usage
+             "use_whisper": true,
+             "use_contentvec": true,
+             "use_wenet": false,
+             "whisper_dim": 1024,
+             "contentvec_dim": 256,
+             "wenet_dim": 512,
+             "use_singer_encoder": false,
+             "pitch_min": 50,
+             "pitch_max": 1100
+         },
+         "diffusion": {
+             "scheduler": "ddpm",
+             "scheduler_settings": {
+                 "num_train_timesteps": 1000,
+                 "beta_start": 1.0e-4,
+                 "beta_end": 0.02,
+                 "beta_schedule": "linear"
+             },
+             // Diffusion steps encoder
+             "step_encoder": {
+                 "dim_raw_embedding": 128,
+                 "dim_hidden_layer": 512,
+                 "activation": "SiLU",
+                 "num_layer": 2,
+                 "max_period": 10000
+             },
+             // Diffusion decoder
+             "model_type": "bidilconv",
+             // bidilconv, unet2d, TODO: unet1d
+             "bidilconv": {
+                 "base_channel": 512,
+                 "n_res_block": 40,
+                 "conv_kernel_size": 3,
+                 "dilation_cycle_length": 4,
+                 // specially, 1 means no dilation
+                 "conditioner_size": 384
+             }
+         }
+     },
+     "train": {
+         "batch_size": 32,
+         "gradient_accumulation_step": 1,
+         "max_epoch": -1, // -1 means no limit
+         "save_checkpoint_stride": [
+             3,
+             50
+         ],
+         "keep_last": [
+             3,
+             2
+         ],
+         "run_eval": [
+             true,
+             true
+         ],
+         "adamw": {
+             "lr": 2.0e-4
+         },
+         "reducelronplateau": {
+             "factor": 0.8,
+             "patience": 30,
+             "min_lr": 1.0e-4
+         },
+         "dataloader": {
+             "num_worker": 8,
+             "pin_memory": true
+         },
+         "sampler": {
+             "holistic_shuffle": false,
+             "drop_last": true
+         }
+     }
+ }
egs/svc/MultipleContentsSVC/run.sh ADDED
@@ -0,0 +1 @@
+ ../_template/run.sh
egs/svc/README.md ADDED
@@ -0,0 +1,34 @@
+ # Amphion Singing Voice Conversion (SVC) Recipe
+
+ ## Quick Start
+
+ We provide a **[beginner recipe](MultipleContentsSVC)** to demonstrate how to train a cutting-edge SVC model. Specifically, it is also an official implementation of the paper "[Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion](https://arxiv.org/abs/2310.11160)" (NeurIPS 2023 Workshop on Machine Learning for Audio). Some demos can be seen [here](https://www.zhangxueyao.com/data/MultipleContentsSVC/index.html).
+
+ ## Supported Model Architectures
+
+ The main idea of SVC is to first disentangle the speaker-agnostic representations from the source audio, and then inject the desired speaker information to synthesize the target, which usually utilizes an acoustic decoder and a subsequent waveform synthesizer (vocoder), as shown below (an illustrative code sketch follows at the end of this recipe):
+
+ <br>
+ <div align="center">
+ <img src="../../imgs/svc/pipeline.png" width="70%">
+ </div>
+ <br>
+
+ Until now, Amphion SVC has supported the following features and models:
+
+ - **Speaker-agnostic Representations**:
+   - Content Features: Sourcing from [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), and [ContentVec](https://github.com/auspicious3000/contentvec).
+   - Prosody Features: F0 and energy.
+ - **Speaker Embeddings**:
+   - Speaker Look-Up Table.
+   - Reference Encoder (👨‍💻 developing): It can be used for zero-shot SVC.
+ - **Acoustic Decoders**:
+   - Diffusion-based models:
+     - **[DiffWaveNetSVC](MultipleContentsSVC)**: The encoder is based on a Bidirectional Non-Causal Dilated CNN, which is similar to [WaveNet](https://arxiv.org/pdf/1609.03499.pdf), [DiffWave](https://openreview.net/forum?id=a-xFK8Ymz5J), and [DiffSVC](https://ieeexplore.ieee.org/document/9688219).
+     - **[DiffComoSVC](DiffComoSVC)** (👨‍💻 developing): The diffusion framework is based on the [Consistency Model](https://proceedings.mlr.press/v202/song23a.html). It can significantly accelerate the inference process of the diffusion model.
+   - Transformer-based models:
+     - **[TransformerSVC](TransformerSVC)**: Encoder-only and Non-autoregressive Transformer Architecture.
+   - VAE- and Flow-based models:
+     - **[VitsSVC]()** (👨‍💻 developing): It is designed as a [VITS](https://arxiv.org/abs/2106.06103)-like model whose textual input is replaced by the content features, which is similar to [so-vits-svc](https://github.com/svc-develop-team/so-vits-svc).
+ - **Waveform Synthesizers (Vocoders)**:
+   - The supported vocoders can be seen in the [Amphion Vocoder Recipe](../vocoder/README.md).
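The following is the conceptual sketch referenced above. Every callable is supplied by the caller because none of these names are Amphion APIs; the structure simply mirrors the described pipeline (disentangle speaker-agnostic content and prosody, inject the target speaker, decode acoustics, then vocode).

```python
# Conceptual sketch of the SVC pipeline described in this recipe README.
# All helpers are placeholders passed in by the caller, not Amphion APIs.

def convert(
    source_wav,
    target_singer_id,
    *,
    extract_content_features,   # e.g., Whisper / ContentVec / WeNet features
    extract_prosody_features,   # frame-level F0 and energy
    lookup_speaker_embedding,   # speaker look-up table or reference encoder
    acoustic_decoder,           # e.g., a DiffWaveNetSVC-style model
    vocoder,                    # e.g., a BigVGAN-style vocoder
):
    # 1. Speaker-agnostic representations from the source audio
    content = extract_content_features(source_wav)
    f0, energy = extract_prosody_features(source_wav)

    # 2. Inject the desired speaker information
    spk_emb = lookup_speaker_embedding(target_singer_id)

    # 3. Acoustic decoder predicts the target acoustic features (e.g., mel spectrogram)
    mel = acoustic_decoder(content, f0, energy, spk_emb)

    # 4. Waveform synthesizer (vocoder) produces the converted audio
    return vocoder(mel)
```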
egs/svc/_template/run.sh ADDED
@@ -0,0 +1,150 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ ######## Build Experiment Environment ###########
7
+ exp_dir=$(cd `dirname $0`; pwd)
8
+ work_dir=$(dirname $(dirname $(dirname $exp_dir)))
9
+
10
+ export WORK_DIR=$work_dir
11
+ export PYTHONPATH=$work_dir
12
+ export PYTHONIOENCODING=UTF-8
13
+
14
+ ######## Parse the Given Parameters from the Commond ###########
15
+ options=$(getopt -o c:n:s --long gpu:,config:,name:,stage:,resume:,resume_from_ckpt_path:,resume_type:,infer_expt_dir:,infer_output_dir:,infer_source_file:,infer_source_audio_dir:,infer_target_speaker:,infer_key_shift:,infer_vocoder_dir: -- "$@")
16
+ eval set -- "$options"
17
+
18
+ while true; do
19
+ case $1 in
20
+ # Experimental Configuration File
21
+ -c | --config) shift; exp_config=$1 ; shift ;;
22
+ # Experimental Name
23
+ -n | --name) shift; exp_name=$1 ; shift ;;
24
+ # Running Stage
25
+ -s | --stage) shift; running_stage=$1 ; shift ;;
26
+ # Visible GPU machines. The default value is "0".
27
+ --gpu) shift; gpu=$1 ; shift ;;
28
+
29
+ # [Only for Training] Resume configuration
30
+ --resume) shift; resume=$1 ; shift ;;
31
+ # [Only for Training] The specific checkpoint path that you want to resume from.
32
+ --resume_from_ckpt_path) shift; resume_from_ckpt_path=$1 ; shift ;;
33
+ # [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights.
34
+ --resume_type) shift; resume_type=$1 ; shift ;;
35
+
36
+ # [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
37
+ --infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
38
+ # [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result"
39
+ --infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
40
+ # [Only for Inference] The inference source (can be a json file or a dir). For example, the source_file can be "[Your path to save processed data]/[YourDataset]/test.json", and the source_audio_dir can be "$work_dir/source_audio" which includes several audio files (*.wav, *.mp3 or *.flac).
41
+ --infer_source_file) shift; infer_source_file=$1 ; shift ;;
42
+ --infer_source_audio_dir) shift; infer_source_audio_dir=$1 ; shift ;;
43
+ # [Only for Inference] Specify the target speaker you want to convert into. You can refer to "[Your path to save logs and checkpoints]/[Your Expt Name]/singers.json". In this singer look-up table, you can see the usable speaker names (all the keys of the dictionary). For example, for opencpop dataset, the speaker name would be "opencpop_female1".
44
+ --infer_target_speaker) shift; infer_target_speaker=$1 ; shift ;;
45
+ # [Only for Inference] For advanced users, you can modify the trans_key parameters into an integer (which means the semitones you want to transpose). Its default value is "autoshift".
46
+ --infer_key_shift) shift; infer_key_shift=$1 ; shift ;;
47
+ # [Only for Inference] The vocoder dir. Its default value is Amphion/pretrained/bigvgan. See Amphion/pretrained/README.md to download the pretrained BigVGAN vocoders.
48
+ --infer_vocoder_dir) shift; infer_vocoder_dir=$1 ; shift ;;
49
+
50
+ --) shift ; break ;;
51
+ *) echo "Invalid option: $1" exit 1 ;;
52
+ esac
53
+ done
54
+
55
+
56
+ ### Value check ###
57
+ if [ -z "$running_stage" ]; then
58
+ echo "[Error] Please specify the running stage"
59
+ exit 1
60
+ fi
61
+
62
+ if [ -z "$exp_config" ]; then
63
+ exp_config="${exp_dir}"/exp_config.json
64
+ fi
65
+ echo "Exprimental Configuration File: $exp_config"
66
+
67
+ if [ -z "$gpu" ]; then
68
+ gpu="0"
69
+ fi
70
+
71
+ ######## Features Extraction ###########
72
+ if [ $running_stage -eq 1 ]; then
73
+ CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/svc/preprocess.py \
74
+ --config $exp_config \
75
+ --num_workers 4
76
+ fi
77
+
78
+ ######## Training ###########
79
+ if [ $running_stage -eq 2 ]; then
80
+ if [ -z "$exp_name" ]; then
81
+ echo "[Error] Please specify the experiments name"
82
+ exit 1
83
+ fi
84
+ echo "Exprimental Name: $exp_name"
85
+
86
+ if [ "$resume" = true ]; then
87
+ echo "Automatically resume from the experimental dir..."
88
+ CUDA_VISIBLE_DEVICES="$gpu" accelerate launch "${work_dir}"/bins/svc/train.py \
89
+ --config "$exp_config" \
90
+ --exp_name "$exp_name" \
91
+ --log_level info \
92
+ --resume
93
+ else
94
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch "${work_dir}"/bins/svc/train.py \
95
+ --config "$exp_config" \
96
+ --exp_name "$exp_name" \
97
+ --log_level info \
98
+ --resume_from_ckpt_path "$resume_from_ckpt_path" \
99
+ --resume_type "$resume_type"
100
+ fi
101
+ fi
102
+
103
+ ######## Inference/Conversion ###########
104
+ if [ $running_stage -eq 3 ]; then
105
+ if [ -z "$infer_expt_dir" ]; then
106
+ echo "[Error] Please specify the experimental directionary. The value is like [Your path to save logs and checkpoints]/[YourExptName]"
107
+ exit 1
108
+ fi
109
+
110
+ if [ -z "$infer_output_dir" ]; then
111
+ infer_output_dir="$expt_dir/result"
112
+ fi
113
+
114
+ if [ -z "$infer_source_file" ] && [ -z "$infer_source_audio_dir" ]; then
115
+ echo "[Error] Please specify the source file/dir. The inference source (can be a json file or a dir). For example, the source_file can be "[Your path to save processed data]/[YourDataset]/test.json", and the source_audio_dir should include several audio files (*.wav, *.mp3 or *.flac)."
116
+ exit 1
117
+ fi
118
+
119
+ if [ -z "$infer_source_file" ]; then
120
+ infer_source=$infer_source_audio_dir
121
+ fi
122
+
123
+ if [ -z "$infer_source_audio_dir" ]; then
124
+ infer_source=$infer_source_file
125
+ fi
126
+
127
+ if [ -z "$infer_target_speaker" ]; then
128
+ echo "[Error] Please specify the target speaker. You can refer to "[Your path to save logs and checkpoints]/[Your Expt Name]/singers.json". In this singer look-up table, you can see the usable speaker names (all the keys of the dictionary). For example, for opencpop dataset, the speaker name would be "opencpop_female1""
129
+ exit 1
130
+ fi
131
+
132
+ if [ -z "$infer_key_shift" ]; then
133
+ infer_key_shift="autoshift"
134
+ fi
135
+
136
+ if [ -z "$infer_vocoder_dir" ]; then
137
+ infer_vocoder_dir="$work_dir"/pretrained/bigvgan
138
+ echo "[Warning] You don't specify the infer_vocoder_dir. It is set $infer_vocoder_dir by default. Make sure that you have followed Amphoion/pretrained/README.md to download the pretrained BigVGAN vocoder checkpoint."
139
+ fi
140
+
141
+ CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/svc/inference.py \
142
+ --config $exp_config \
143
+ --acoustics_dir $infer_expt_dir \
144
+ --vocoder_dir $infer_vocoder_dir \
145
+ --target_singer $infer_target_speaker \
146
+ --trans_key $infer_key_shift \
147
+ --source $infer_source \
148
+ --output_dir $infer_output_dir \
149
+ --log_level debug
150
+ fi
inference.py ADDED
@@ -0,0 +1,258 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import os
8
+ import glob
9
+ from tqdm import tqdm
10
+ import json
11
+ import torch
12
+ import time
13
+
14
+ from models.svc.diffusion.diffusion_inference import DiffusionInference
15
+ from models.svc.comosvc.comosvc_inference import ComoSVCInference
16
+ from models.svc.transformer.transformer_inference import TransformerInference
17
+ from utils.util import load_config
18
+ from utils.audio_slicer import split_audio, merge_segments_encodec
19
+ from processors import acoustic_extractor, content_extractor
20
+
21
+
22
+ def build_inference(args, cfg, infer_type="from_dataset"):
23
+ supported_inference = {
24
+ "DiffWaveNetSVC": DiffusionInference,
25
+ "DiffComoSVC": ComoSVCInference,
26
+ "TransformerSVC": TransformerInference,
27
+ }
28
+
29
+ inference_class = supported_inference[cfg.model_type]
30
+ return inference_class(args, cfg, infer_type)
31
+
32
+
33
+ def prepare_for_audio_file(args, cfg, num_workers=1):
34
+ preprocess_path = cfg.preprocess.processed_dir
35
+ audio_name = cfg.inference.source_audio_name
36
+ temp_audio_dir = os.path.join(preprocess_path, audio_name)
37
+
38
+ ### eval file
39
+ t = time.time()
40
+ eval_file = prepare_source_eval_file(cfg, temp_audio_dir, audio_name)
41
+ args.source = eval_file
42
+ with open(eval_file, "r") as f:
43
+ metadata = json.load(f)
44
+ print("Prepare for meta eval data: {:.1f}s".format(time.time() - t))
45
+
46
+ ### acoustic features
47
+ t = time.time()
48
+ acoustic_extractor.extract_utt_acoustic_features_serial(
49
+ metadata, temp_audio_dir, cfg
50
+ )
51
+ acoustic_extractor.cal_mel_min_max(
52
+ dataset=audio_name, output_path=preprocess_path, cfg=cfg, metadata=metadata
53
+ )
54
+ acoustic_extractor.cal_pitch_statistics_svc(
55
+ dataset=audio_name, output_path=preprocess_path, cfg=cfg, metadata=metadata
56
+ )
57
+ print("Prepare for acoustic features: {:.1f}s".format(time.time() - t))
58
+
59
+ ### content features
60
+ t = time.time()
61
+ content_extractor.extract_utt_content_features_dataloader(
62
+ cfg, metadata, num_workers
63
+ )
64
+ print("Prepare for content features: {:.1f}s".format(time.time() - t))
65
+ return args, cfg, temp_audio_dir
66
+
67
+
68
+ def merge_for_audio_segments(audio_files, args, cfg):
69
+ audio_name = cfg.inference.source_audio_name
70
+ target_singer_name = args.target_singer
71
+
72
+ merge_segments_encodec(
73
+ wav_files=audio_files,
74
+ fs=cfg.preprocess.sample_rate,
75
+ output_path=os.path.join(
76
+ args.output_dir, "{}_{}.wav".format(audio_name, target_singer_name)
77
+ ),
78
+ overlap_duration=cfg.inference.segments_overlap_duration,
79
+ )
80
+
81
+ for tmp_file in audio_files:
82
+ os.remove(tmp_file)
83
+
84
+
85
+ def prepare_source_eval_file(cfg, temp_audio_dir, audio_name):
86
+ """
87
+ Prepare the eval file (json) for an audio
88
+ """
89
+
90
+ audio_chunks_results = split_audio(
91
+ wav_file=cfg.inference.source_audio_path,
92
+ target_sr=cfg.preprocess.sample_rate,
93
+ output_dir=os.path.join(temp_audio_dir, "wavs"),
94
+ max_duration_of_segment=cfg.inference.segments_max_duration,
95
+ overlap_duration=cfg.inference.segments_overlap_duration,
96
+ )
97
+
98
+ metadata = []
99
+ for i, res in enumerate(audio_chunks_results):
100
+ res["index"] = i
101
+ res["Dataset"] = audio_name
102
+ res["Singer"] = audio_name
103
+ res["Uid"] = "{}_{}".format(audio_name, res["Uid"])
104
+ metadata.append(res)
105
+
106
+ eval_file = os.path.join(temp_audio_dir, "eval.json")
107
+ with open(eval_file, "w") as f:
108
+ json.dump(metadata, f, indent=4, ensure_ascii=False, sort_keys=True)
109
+
110
+ return eval_file
111
+
112
+
113
+ def cuda_relevant(deterministic=False):
114
+ torch.cuda.empty_cache()
115
+ # TF32 on Ampere and above
116
+ torch.backends.cuda.matmul.allow_tf32 = True
117
+ torch.backends.cudnn.enabled = True
118
+ torch.backends.cudnn.allow_tf32 = True
119
+ # Deterministic
120
+ torch.backends.cudnn.deterministic = deterministic
121
+ torch.backends.cudnn.benchmark = not deterministic
122
+ torch.use_deterministic_algorithms(deterministic)
123
+
124
+
125
+ def infer(args, cfg, infer_type):
126
+ # Build inference
127
+ t = time.time()
128
+ trainer = build_inference(args, cfg, infer_type)
129
+ print("Model Init: {:.1f}s".format(time.time() - t))
130
+
131
+ # Run inference
132
+ t = time.time()
133
+ output_audio_files = trainer.inference()
134
+ print("Model inference: {:.1f}s".format(time.time() - t))
135
+ return output_audio_files
136
+
137
+
138
+ def build_parser():
139
+ r"""Build argument parser for inference.py.
140
+ Anything else should be put in an extra config YAML file.
141
+ """
142
+
143
+ parser = argparse.ArgumentParser()
144
+ parser.add_argument(
145
+ "--config",
146
+ type=str,
147
+ required=True,
148
+ help="JSON/YAML file for configurations.",
149
+ )
150
+ parser.add_argument(
151
+ "--acoustics_dir",
152
+ type=str,
153
+ help="Acoustics model checkpoint directory. If a directory is given, "
154
+ "search for the latest checkpoint dir in the directory. If a specific "
155
+ "checkpoint dir is given, directly load the checkpoint.",
156
+ )
157
+ parser.add_argument(
158
+ "--vocoder_dir",
159
+ type=str,
160
+ required=True,
161
+ help="Vocoder checkpoint directory. Searching behavior is the same as "
162
+ "the acoustics one.",
163
+ )
164
+ parser.add_argument(
165
+ "--target_singer",
166
+ type=str,
167
+ required=True,
168
+ help="convert to a specific singer (e.g. --target_singers singer_id).",
169
+ )
170
+ parser.add_argument(
171
+ "--trans_key",
172
+ default=0,
173
+ help="0: no pitch shift; autoshift: pitch shift; int: key shift.",
174
+ )
175
+ parser.add_argument(
176
+ "--source",
177
+ type=str,
178
+ default="source_audio",
179
+ help="Source audio file or directory. If a JSON file is given, "
180
+ "inference from dataset is applied. If a directory is given, "
181
+ "inference from all wav/flac/mp3 audio files in the directory is applied. "
182
+ "Default: inference from all wav/flac/mp3 audio files in ./source_audio",
183
+ )
184
+ parser.add_argument(
185
+ "--output_dir",
186
+ type=str,
187
+ default="conversion_results",
188
+ help="Output directory. Default: ./conversion_results",
189
+ )
190
+ parser.add_argument(
191
+ "--log_level",
192
+ type=str,
193
+ default="warning",
194
+ help="Logging level. Default: warning",
195
+ )
196
+ parser.add_argument(
197
+ "--keep_cache",
198
+ action="store_true",
199
+ default=True,
200
+ help="Keep cache files. Only applicable to inference from files.",
201
+ )
202
+ parser.add_argument(
203
+ "--diffusion_inference_steps",
204
+ type=int,
205
+ default=1000,
206
+ help="Number of inference steps. Only applicable to diffusion inference.",
207
+ )
208
+ return parser
209
+
210
+
211
+ def main():
212
+ ### Parse arguments and config
213
+ args = build_parser().parse_args()
214
+ cfg = load_config(args.config)
215
+
216
+ # CUDA settings
217
+ cuda_relevant()
218
+
219
+ if os.path.isdir(args.source):
220
+ ### Infer from file
221
+
222
+ # Get all the source audio files (.wav, .flac, .mp3)
223
+ source_audio_dir = args.source
224
+ audio_list = []
225
+ for suffix in ["wav", "flac", "mp3"]:
226
+ audio_list += glob.glob(
227
+ os.path.join(source_audio_dir, "**/*.{}".format(suffix)), recursive=True
228
+ )
229
+ print("There are {} source audios: ".format(len(audio_list)))
230
+
231
+ # Infer for every file as dataset
232
+ output_root_path = args.output_dir
233
+ for audio_path in tqdm(audio_list):
234
+ audio_name = audio_path.split("/")[-1].split(".")[0]
235
+ args.output_dir = os.path.join(output_root_path, audio_name)
236
+ print("\n{}\nConversion for {}...\n".format("*" * 10, audio_name))
237
+
238
+ cfg.inference.source_audio_path = audio_path
239
+ cfg.inference.source_audio_name = audio_name
240
+ cfg.inference.segments_max_duration = 10.0
241
+ cfg.inference.segments_overlap_duration = 1.0
242
+
243
+ # Prepare metadata and features
244
+ args, cfg, cache_dir = prepare_for_audio_file(args, cfg)
245
+
246
+ # Infer from file
247
+ output_audio_files = infer(args, cfg, infer_type="from_file")
248
+
249
+ # Merge the split segments
250
+ merge_for_audio_segments(output_audio_files, args, cfg)
251
+
252
+ # Keep or remove caches
253
+ if not args.keep_cache:
254
+ os.removedirs(cache_dir)
255
+
256
+ else:
257
+ ### Infer from dataset
258
+ infer(args, cfg, infer_type="from_dataset")
models/__init__.py ADDED
File without changes
models/base/__init__.py ADDED
@@ -0,0 +1,7 @@
+ # Copyright (c) 2023 Amphion.
+ #
+ # This source code is licensed under the MIT license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .new_trainer import BaseTrainer
+ from .new_inference import BaseInference
models/base/base_dataset.py ADDED
@@ -0,0 +1,350 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import numpy as np
8
+ import torch.utils.data
9
+ from torch.nn.utils.rnn import pad_sequence
10
+ from utils.data_utils import *
11
+ from processors.acoustic_extractor import cal_normalized_mel
12
+ from text import text_to_sequence
13
+ from text.text_token_collation import phoneIDCollation
14
+
15
+
16
+ class BaseDataset(torch.utils.data.Dataset):
17
+ def __init__(self, cfg, dataset, is_valid=False):
18
+ """
19
+ Args:
20
+ cfg: config
21
+ dataset: dataset name
22
+ is_valid: whether to use train or valid dataset
23
+ """
24
+
25
+ assert isinstance(dataset, str)
26
+
27
+ # self.data_root = processed_data_dir
28
+ self.cfg = cfg
29
+
30
+ processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
31
+ meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
32
+ self.metafile_path = os.path.join(processed_data_dir, meta_file)
33
+ self.metadata = self.get_metadata()
34
+
35
+
36
+
37
+ '''
38
+ load spk2id and utt2spk from json file
39
+ spk2id: {spk1: 0, spk2: 1, ...}
40
+ utt2spk: {dataset_uid: spk1, ...}
41
+ '''
42
+ if cfg.preprocess.use_spkid:
43
+ spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
44
+ with open(spk2id_path, "r") as f:
45
+ self.spk2id = json.load(f)
46
+
47
+ utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
48
+ self.utt2spk = dict()
49
+ with open(utt2spk_path, "r") as f:
50
+ for line in f.readlines():
51
+ utt, spk = line.strip().split('\t')
52
+ self.utt2spk[utt] = spk
53
+
54
+
55
+ if cfg.preprocess.use_uv:
56
+ self.utt2uv_path = {}
57
+ for utt_info in self.metadata:
58
+ dataset = utt_info["Dataset"]
59
+ uid = utt_info["Uid"]
60
+ utt = "{}_{}".format(dataset, uid)
61
+ self.utt2uv_path[utt] = os.path.join(
62
+ cfg.preprocess.processed_dir,
63
+ dataset,
64
+ cfg.preprocess.uv_dir,
65
+ uid + ".npy",
66
+ )
67
+
68
+ if cfg.preprocess.use_frame_pitch:
69
+ self.utt2frame_pitch_path = {}
70
+ for utt_info in self.metadata:
71
+ dataset = utt_info["Dataset"]
72
+ uid = utt_info["Uid"]
73
+ utt = "{}_{}".format(dataset, uid)
74
+
75
+ self.utt2frame_pitch_path[utt] = os.path.join(
76
+ cfg.preprocess.processed_dir,
77
+ dataset,
78
+ cfg.preprocess.pitch_dir,
79
+ uid + ".npy",
80
+ )
81
+
82
+ if cfg.preprocess.use_frame_energy:
83
+ self.utt2frame_energy_path = {}
84
+ for utt_info in self.metadata:
85
+ dataset = utt_info["Dataset"]
86
+ uid = utt_info["Uid"]
87
+ utt = "{}_{}".format(dataset, uid)
88
+
89
+ self.utt2frame_energy_path[utt] = os.path.join(
90
+ cfg.preprocess.processed_dir,
91
+ dataset,
92
+ cfg.preprocess.energy_dir,
93
+ uid + ".npy",
94
+ )
95
+
96
+ if cfg.preprocess.use_mel:
97
+ self.utt2mel_path = {}
98
+ for utt_info in self.metadata:
99
+ dataset = utt_info["Dataset"]
100
+ uid = utt_info["Uid"]
101
+ utt = "{}_{}".format(dataset, uid)
102
+
103
+ self.utt2mel_path[utt] = os.path.join(
104
+ cfg.preprocess.processed_dir,
105
+ dataset,
106
+ cfg.preprocess.mel_dir,
107
+ uid + ".npy",
108
+ )
109
+
110
+ if cfg.preprocess.use_linear:
111
+ self.utt2linear_path = {}
112
+ for utt_info in self.metadata:
113
+ dataset = utt_info["Dataset"]
114
+ uid = utt_info["Uid"]
115
+ utt = "{}_{}".format(dataset, uid)
116
+
117
+ self.utt2linear_path[utt] = os.path.join(
118
+ cfg.preprocess.processed_dir,
119
+ dataset,
120
+ cfg.preprocess.linear_dir,
121
+ uid + ".npy",
122
+ )
123
+
124
+ if cfg.preprocess.use_audio:
125
+ self.utt2audio_path = {}
126
+ for utt_info in self.metadata:
127
+ dataset = utt_info["Dataset"]
128
+ uid = utt_info["Uid"]
129
+ utt = "{}_{}".format(dataset, uid)
130
+
131
+ self.utt2audio_path[utt] = os.path.join(
132
+ cfg.preprocess.processed_dir,
133
+ dataset,
134
+ cfg.preprocess.audio_dir,
135
+ uid + ".npy",
136
+ )
137
+ elif cfg.preprocess.use_label:
138
+ self.utt2label_path = {}
139
+ for utt_info in self.metadata:
140
+ dataset = utt_info["Dataset"]
141
+ uid = utt_info["Uid"]
142
+ utt = "{}_{}".format(dataset, uid)
143
+
144
+ self.utt2label_path[utt] = os.path.join(
145
+ cfg.preprocess.processed_dir,
146
+ dataset,
147
+ cfg.preprocess.label_dir,
148
+ uid + ".npy",
149
+ )
150
+ elif cfg.preprocess.use_one_hot:
151
+ self.utt2one_hot_path = {}
152
+ for utt_info in self.metadata:
153
+ dataset = utt_info["Dataset"]
154
+ uid = utt_info["Uid"]
155
+ utt = "{}_{}".format(dataset, uid)
156
+
157
+ self.utt2one_hot_path[utt] = os.path.join(
158
+ cfg.preprocess.processed_dir,
159
+ dataset,
160
+ cfg.preprocess.one_hot_dir,
161
+ uid + ".npy",
162
+ )
163
+
164
+ if cfg.preprocess.use_text or cfg.preprocess.use_phone:
165
+ self.utt2seq = {}
166
+ for utt_info in self.metadata:
167
+ dataset = utt_info["Dataset"]
168
+ uid = utt_info["Uid"]
169
+ utt = "{}_{}".format(dataset, uid)
170
+
171
+ if cfg.preprocess.use_text:
172
+ text = utt_info["Text"]
173
+ sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
174
+ elif cfg.preprocess.use_phone:
175
+ # load the phoneme sequence from the phone file
176
+ phone_path = os.path.join(processed_data_dir,
177
+ cfg.preprocess.phone_dir,
178
+ uid+'.phone'
179
+ )
180
+ with open(phone_path, 'r') as fin:
181
+ phones = fin.readlines()
182
+ assert len(phones) == 1
183
+ phones = phones[0].strip()
184
+ phones_seq = phones.split(' ')
185
+
186
+ phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
187
+ sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
188
+
189
+ self.utt2seq[utt] = sequence
190
+
191
+
192
+ def get_metadata(self):
193
+ with open(self.metafile_path, "r", encoding="utf-8") as f:
194
+ metadata = json.load(f)
195
+
196
+ return metadata
197
+
198
+ def get_dataset_name(self):
199
+ return self.metadata[0]["Dataset"]
200
+
201
+ def __getitem__(self, index):
202
+ utt_info = self.metadata[index]
203
+
204
+ dataset = utt_info["Dataset"]
205
+ uid = utt_info["Uid"]
206
+ utt = "{}_{}".format(dataset, uid)
207
+
208
+ single_feature = dict()
209
+
210
+ if self.cfg.preprocess.use_spkid:
211
+ single_feature["spk_id"] = np.array(
212
+ [self.spk2id[self.utt2spk[utt]]], dtype=np.int32
213
+ )
214
+
215
+ if self.cfg.preprocess.use_mel:
216
+ mel = np.load(self.utt2mel_path[utt])
217
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
218
+ if self.cfg.preprocess.use_min_max_norm_mel:
219
+ # do mel norm
220
+ mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
221
+
222
+ if "target_len" not in single_feature.keys():
223
+ single_feature["target_len"] = mel.shape[1]
224
+ single_feature["mel"] = mel.T # [T, n_mels]
225
+
226
+ if self.cfg.preprocess.use_linear:
227
+ linear = np.load(self.utt2linear_path[utt])
228
+ if "target_len" not in single_feature.keys():
229
+ single_feature["target_len"] = linear.shape[1]
230
+ single_feature["linear"] = linear.T # [T, n_linear]
231
+
232
+ if self.cfg.preprocess.use_frame_pitch:
233
+ frame_pitch_path = self.utt2frame_pitch_path[utt]
234
+ frame_pitch = np.load(frame_pitch_path)
235
+ if "target_len" not in single_feature.keys():
236
+ single_feature["target_len"] = len(frame_pitch)
237
+ aligned_frame_pitch = align_length(
238
+ frame_pitch, single_feature["target_len"]
239
+ )
240
+ single_feature["frame_pitch"] = aligned_frame_pitch
241
+
242
+ if self.cfg.preprocess.use_uv:
243
+ frame_uv_path = self.utt2uv_path[utt]
244
+ frame_uv = np.load(frame_uv_path)
245
+ aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
246
+ aligned_frame_uv = [
247
+ 0 if frame_uv else 1 for frame_uv in aligned_frame_uv
248
+ ]
249
+ aligned_frame_uv = np.array(aligned_frame_uv)
250
+ single_feature["frame_uv"] = aligned_frame_uv
251
+
252
+ if self.cfg.preprocess.use_frame_energy:
253
+ frame_energy_path = self.utt2frame_energy_path[utt]
254
+ frame_energy = np.load(frame_energy_path)
255
+ if "target_len" not in single_feature.keys():
256
+ single_feature["target_len"] = len(frame_energy)
257
+ aligned_frame_energy = align_length(
258
+ frame_energy, single_feature["target_len"]
259
+ )
260
+ single_feature["frame_energy"] = aligned_frame_energy
261
+
262
+ if self.cfg.preprocess.use_audio:
263
+ audio = np.load(self.utt2audio_path[utt])
264
+ single_feature["audio"] = audio
265
+ single_feature["audio_len"] = audio.shape[0]
266
+
267
+ if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
268
+ single_feature["phone_seq"] = np.array(self.utt2seq[utt])
269
+ single_feature["phone_len"] = len(self.utt2seq[utt])
270
+
271
+ return single_feature
272
+
273
+ def __len__(self):
274
+ return len(self.metadata)
275
+
276
+
277
+ class BaseCollator(object):
278
+ """Zero-pads model inputs and targets based on number of frames per step"""
279
+
280
+ def __init__(self, cfg):
281
+ self.cfg = cfg
282
+
283
+ def __call__(self, batch):
284
+ packed_batch_features = dict()
285
+
286
+ # mel: [b, T, n_mels]
287
+ # frame_pitch, frame_energy: [1, T]
288
+ # target_len: [1]
289
+ # spk_id: [b, 1]
290
+ # mask: [b, T, 1]
291
+
292
+ for key in batch[0].keys():
293
+ if key == "target_len":
294
+ packed_batch_features["target_len"] = torch.LongTensor(
295
+ [b["target_len"] for b in batch]
296
+ )
297
+ masks = [
298
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
299
+ ]
300
+ packed_batch_features["mask"] = pad_sequence(
301
+ masks, batch_first=True, padding_value=0
302
+ )
303
+ elif key == "phone_len":
304
+ packed_batch_features["phone_len"] = torch.LongTensor(
305
+ [b["phone_len"] for b in batch]
306
+ )
307
+ masks = [
308
+ torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
309
+ ]
310
+ packed_batch_features["phn_mask"] = pad_sequence(
311
+ masks, batch_first=True, padding_value=0
312
+ )
313
+ elif key == "audio_len":
314
+ packed_batch_features["audio_len"] = torch.LongTensor(
315
+ [b["audio_len"] for b in batch]
316
+ )
317
+ masks = [
318
+ torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
319
+ ]
320
+ else:
321
+ values = [torch.from_numpy(b[key]) for b in batch]
322
+ packed_batch_features[key] = pad_sequence(
323
+ values, batch_first=True, padding_value=0
324
+ )
325
+ return packed_batch_features
326
+
327
+
328
+ class BaseTestDataset(torch.utils.data.Dataset):
329
+ def __init__(self, cfg, args):
330
+ raise NotImplementedError
331
+
332
+
333
+ def get_metadata(self):
334
+ raise NotImplementedError
335
+
336
+ def __getitem__(self, index):
337
+ raise NotImplementedError
338
+
339
+ def __len__(self):
340
+ return len(self.metadata)
341
+
342
+
343
+ class BaseTestCollator(object):
344
+ """Zero-pads model inputs and targets based on number of frames per step"""
345
+
346
+ def __init__(self, cfg):
347
+ raise NotImplementedError
348
+
349
+ def __call__(self, batch):
350
+ raise NotImplementedError
models/base/base_inference.py ADDED
@@ -0,0 +1,220 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import argparse
7
+ import os
8
+ import re
9
+ import time
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+ from tqdm import tqdm
15
+
16
+ from models.vocoders.vocoder_inference import synthesis
17
+ from torch.utils.data import DataLoader
18
+ from utils.util import set_all_random_seed
19
+ from utils.util import load_config
20
+
21
+
22
+ def parse_vocoder(vocoder_dir):
23
+ r"""Parse vocoder config"""
24
+ vocoder_dir = os.path.abspath(vocoder_dir)
25
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
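+ # Vocoder checkpoints are named by step number, so sort the stems numerically (descending) and use the newest one.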
26
+ ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
27
+ ckpt_path = str(ckpt_list[0])
28
+ vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
29
+ vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
30
+ return vocoder_cfg, ckpt_path
31
+
32
+
33
+ class BaseInference(object):
34
+ def __init__(self, cfg, args):
35
+ self.cfg = cfg
36
+ self.args = args
37
+ self.model_type = cfg.model_type
38
+ self.avg_rtf = list()
39
+ set_all_random_seed(10086)
40
+ os.makedirs(args.output_dir, exist_ok=True)
41
+
42
+ if torch.cuda.is_available():
43
+ self.device = torch.device("cuda")
44
+ else:
45
+ self.device = torch.device("cpu")
46
+ torch.set_num_threads(10)  # CPU inference: cap the number of threads used.
47
+
48
+ # Load acoustic model
49
+ self.model = self.create_model().to(self.device)
50
+ state_dict = self.load_state_dict()
51
+ self.load_model(state_dict)
52
+ self.model.eval()
53
+
54
+ # Load vocoder model if necessary
55
+ if self.args.checkpoint_dir_vocoder is not None:
56
+ self.get_vocoder_info()
57
+
58
+ def create_model(self):
59
+ raise NotImplementedError
60
+
61
+ def load_state_dict(self):
62
+ self.checkpoint_file = self.args.checkpoint_file
63
+ if self.checkpoint_file is None:
64
+ assert self.args.checkpoint_dir is not None
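+ # The "checkpoint" index file lists saved models; its last line names the most recent checkpoint to restore.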
65
+ checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
66
+ checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
67
+ self.checkpoint_file = os.path.join(
68
+ self.args.checkpoint_dir, checkpoint_filename
69
+ )
70
+
71
+ self.checkpoint_dir = os.path.split(self.checkpoint_file)[0]
72
+
73
+ print("Restore acoustic model from {}".format(self.checkpoint_file))
74
+ raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
75
+ self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0]
76
+
77
+ return raw_state_dict
78
+
79
+ def load_model(self, model):
80
+ raise NotImplementedError
81
+
82
+ def get_vocoder_info(self):
83
+ self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
84
+ self.vocoder_cfg = os.path.join(
85
+ os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
86
+ )
87
+ self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
88
+ self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1]
89
+ self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0]
90
+
91
+ def build_test_utt_data(self):
92
+ raise NotImplementedError
93
+
94
+ def build_testdata_loader(self, args, target_speaker=None):
95
+ datasets, collate = self.build_test_dataset()
96
+ self.test_dataset = datasets(self.cfg, args, target_speaker)
97
+ self.test_collate = collate(self.cfg)
98
+ self.test_batch_size = min(
99
+ self.cfg.train.batch_size, len(self.test_dataset.metadata)
100
+ )
101
+ test_loader = DataLoader(
102
+ self.test_dataset,
103
+ collate_fn=self.test_collate,
104
+ num_workers=self.args.num_workers,
105
+ batch_size=self.test_batch_size,
106
+ shuffle=False,
107
+ )
108
+ return test_loader
109
+
110
+ def inference_each_batch(self, batch_data):
111
+ raise NotImplementedError
112
+
113
+ def inference_for_batches(self, args, target_speaker=None):
114
+ ###### Construct test_batch ######
115
+ loader = self.build_testdata_loader(args, target_speaker)
116
+
117
+ n_batch = len(loader)
118
+ now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
119
+ print(
120
+ "Model eval time: {}, batch_size = {}, n_batch = {}".format(
121
+ now, self.test_batch_size, n_batch
122
+ )
123
+ )
124
+ self.model.eval()
125
+
126
+ ###### Inference for each batch ######
127
+ pred_res = []
128
+ with torch.no_grad():
129
+ for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)):
130
+ # Put the data to device
131
+ for k, v in batch_data.items():
132
+ batch_data[k] = batch_data[k].to(self.device)
133
+
134
+ y_pred, stats = self.inference_each_batch(batch_data)
135
+
136
+ pred_res += y_pred
137
+
138
+ return pred_res
139
+
140
+ def inference(self, feature):
141
+ raise NotImplementedError
142
+
143
+ def synthesis_by_vocoder(self, pred):
144
+ audios_pred = synthesis(
145
+ self.vocoder_cfg,
146
+ self.checkpoint_dir_vocoder,
147
+ len(pred),
148
+ pred,
149
+ )
150
+ return audios_pred
151
+
152
+ def __call__(self, utt):
153
+ feature = self.build_test_utt_data(utt)
154
+ start_time = time.time()
155
+ with torch.no_grad():
156
+ outputs = self.inference(feature)[0]
157
+ time_used = time.time() - start_time
158
+ rtf = time_used / (
159
+ outputs.shape[1]
160
+ * self.cfg.preprocess.hop_size
161
+ / self.cfg.preprocess.sample_rate
162
+ )
163
+ print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
164
+ self.avg_rtf.append(rtf)
165
+ audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
166
+ return audios
167
+
168
+
169
+ def base_parser():
170
+ parser = argparse.ArgumentParser()
171
+ parser.add_argument(
172
+ "--config", default="config.json", help="json files for configurations."
173
+ )
174
+ parser.add_argument("--use_ddp_inference", default=False)
175
+ parser.add_argument("--n_workers", default=1, type=int)
176
+ parser.add_argument("--local_rank", default=-1, type=int)
177
+ parser.add_argument(
178
+ "--batch_size", default=1, type=int, help="Batch size for inference"
179
+ )
180
+ parser.add_argument(
181
+ "--num_workers",
182
+ default=1,
183
+ type=int,
184
+ help="Worker number for inference dataloader",
185
+ )
186
+ parser.add_argument(
187
+ "--checkpoint_dir",
188
+ type=str,
189
+ default=None,
190
+ help="Checkpoint dir including model file and configuration",
191
+ )
192
+ parser.add_argument(
193
+ "--checkpoint_file", help="checkpoint file", type=str, default=None
194
+ )
195
+ parser.add_argument(
196
+ "--test_list", help="test utterance list for testing", type=str, default=None
197
+ )
198
+ parser.add_argument(
199
+ "--checkpoint_dir_vocoder",
200
+ help="Vocoder's checkpoint dir including model file and configuration",
201
+ type=str,
202
+ default=None,
203
+ )
204
+ parser.add_argument(
205
+ "--output_dir",
206
+ type=str,
207
+ default=None,
208
+ help="Output dir for saving generated results",
209
+ )
210
+ return parser
211
+
212
+
213
+ if __name__ == "__main__":
214
+ parser = base_parser()
215
+ args = parser.parse_args()
216
+ cfg = load_config(args.config)
217
+
218
+ # Build inference
219
+ inference = BaseInference(cfg, args)
220
+ inference()
models/base/base_sampler.py ADDED
@@ -0,0 +1,136 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import random
8
+
9
+ from torch.utils.data import ConcatDataset, Dataset
10
+ from torch.utils.data.sampler import (
11
+ BatchSampler,
12
+ RandomSampler,
13
+ Sampler,
14
+ SequentialSampler,
15
+ )
16
+
17
+
18
+ class ScheduledSampler(Sampler):
19
+ """A sampler that samples data from a given concat-dataset.
20
+
21
+ Args:
22
+ concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
23
+ batch_size (int): batch size
24
+ holistic_shuffle (bool): whether to shuffle the whole dataset or not
25
+ logger (logging.Logger): logger to print warning message
26
+
27
+ Usage:
28
+ For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
29
+ >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]])))
30
+ [3, 4, 5, 0, 1, 2, 6, 7, 8]
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ concat_dataset,
36
+ batch_size,
37
+ holistic_shuffle,
38
+ logger=None,
39
+ loader_type="train",
40
+ ):
41
+ if not isinstance(concat_dataset, ConcatDataset):
42
+ raise ValueError(
43
+ "concat_dataset must be an instance of ConcatDataset, but got {}".format(
44
+ type(concat_dataset)
45
+ )
46
+ )
47
+ if not isinstance(batch_size, int):
48
+ raise ValueError(
49
+ "batch_size must be an integer, but got {}".format(type(batch_size))
50
+ )
51
+ if not isinstance(holistic_shuffle, bool):
52
+ raise ValueError(
53
+ "holistic_shuffle must be a boolean, but got {}".format(
54
+ type(holistic_shuffle)
55
+ )
56
+ )
57
+
58
+ self.concat_dataset = concat_dataset
59
+ self.batch_size = batch_size
60
+ self.holistic_shuffle = holistic_shuffle
61
+
62
+ affected_dataset_name = []
63
+ affected_dataset_len = []
64
+ for dataset in concat_dataset.datasets:
65
+ dataset_len = len(dataset)
66
+ dataset_name = dataset.get_dataset_name()
67
+ if dataset_len < batch_size:
68
+ affected_dataset_name.append(dataset_name)
69
+ affected_dataset_len.append(dataset_len)
70
+
71
+ self.type = loader_type
72
+ for dataset_name, dataset_len in zip(
73
+ affected_dataset_name, affected_dataset_len
74
+ ):
75
+ if not loader_type == "valid":
76
+ logger.warning(
77
+ "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
78
+ loader_type, dataset_name, dataset_len, batch_size
79
+ )
80
+ )
81
+
82
+ def __len__(self):
83
+ # the number of batches with drop last
84
+ num_of_batches = sum(
85
+ [
86
+ math.floor(len(dataset) / self.batch_size)
87
+ for dataset in self.concat_dataset.datasets
88
+ ]
89
+ )
90
+ # if samples are not enough for one batch, we don't drop last
91
+ if self.type == "valid" and num_of_batches < 1:
92
+ return len(self.concat_dataset)
93
+ return num_of_batches * self.batch_size
94
+
95
+ def __iter__(self):
96
+ iters = []
97
+ for dataset in self.concat_dataset.datasets:
98
+ iters.append(
99
+ SequentialSampler(dataset).__iter__()
100
+ if not self.holistic_shuffle
101
+ else RandomSampler(dataset).__iter__()
102
+ )
103
+ # e.g. [0, 200, 400]
104
+ init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
105
+ output_batches = []
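+ # Batches are formed within each sub-dataset (indices offset by the dataset's start position), so a single batch never mixes datasets; the batch order is shuffled afterwards.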
106
+ for dataset_idx in range(len(self.concat_dataset.datasets)):
107
+ cur_batch = []
108
+ for idx in iters[dataset_idx]:
109
+ cur_batch.append(idx + init_indices[dataset_idx])
110
+ if len(cur_batch) == self.batch_size:
111
+ output_batches.append(cur_batch)
112
+ cur_batch = []
113
+ # if loader_type is valid, we don't need to drop last
114
+ if self.type == "valid" and len(cur_batch) > 0:
115
+ output_batches.append(cur_batch)
116
+
117
+ # force drop last in training
118
+ random.shuffle(output_batches)
119
+ output_indices = [item for sublist in output_batches for item in sublist]
120
+ return iter(output_indices)
121
+
122
+
123
+ def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
124
+ sampler = ScheduledSampler(
125
+ concat_dataset,
126
+ cfg.train.batch_size,
127
+ cfg.train.sampler.holistic_shuffle,
128
+ logger,
129
+ loader_type,
130
+ )
131
+ batch_sampler = BatchSampler(
132
+ sampler,
133
+ cfg.train.batch_size,
134
+ cfg.train.sampler.drop_last if not loader_type == "valid" else False,
135
+ )
136
+ return sampler, batch_sampler
models/base/base_trainer.py ADDED
@@ -0,0 +1,348 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import collections
7
+ import json
8
+ import os
9
+ import sys
10
+ import time
11
+
12
+ import torch
13
+ import torch.distributed as dist
14
+ from torch.nn.parallel import DistributedDataParallel
15
+ from torch.utils.data import ConcatDataset, DataLoader
16
+ from torch.utils.tensorboard import SummaryWriter
17
+
18
+ from models.base.base_sampler import BatchSampler
19
+ from utils.util import (
20
+ Logger,
21
+ remove_older_ckpt,
22
+ save_config,
23
+ set_all_random_seed,
24
+ ValueWindow,
25
+ )
26
+
27
+
28
+ class BaseTrainer(object):
29
+ def __init__(self, args, cfg):
30
+ self.args = args
31
+ self.log_dir = args.log_dir
32
+ self.cfg = cfg
33
+
34
+ self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
35
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
36
+ if not cfg.train.ddp or args.local_rank == 0:
37
+ self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
38
+ self.logger = self.build_logger()
39
+ self.time_window = ValueWindow(50)
40
+
41
+ self.step = 0
42
+ self.epoch = -1
43
+ self.max_epochs = self.cfg.train.epochs
44
+ self.max_steps = self.cfg.train.max_steps
45
+
46
+ # set random seed & init distributed training
47
+ set_all_random_seed(self.cfg.train.random_seed)
48
+ if cfg.train.ddp:
49
+ dist.init_process_group(backend="nccl")
50
+
51
+ if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
52
+ self.singers = self.build_singers_lut()
53
+
54
+ # setup data_loader
55
+ self.data_loader = self.build_data_loader()
56
+
57
+ # setup model & enable distributed training
58
+ self.model = self.build_model()
59
+ print(self.model)
60
+
61
+ if isinstance(self.model, dict):
62
+ for key, value in self.model.items():
63
+ value.cuda(self.args.local_rank)
64
+ if key == "PQMF":
65
+ continue
66
+ if cfg.train.ddp:
67
+ self.model[key] = DistributedDataParallel(
68
+ value, device_ids=[self.args.local_rank]
69
+ )
70
+ else:
71
+ self.model.cuda(self.args.local_rank)
72
+ if cfg.train.ddp:
73
+ self.model = DistributedDataParallel(
74
+ self.model, device_ids=[self.args.local_rank]
75
+ )
76
+
77
+ # create criterion
78
+ self.criterion = self.build_criterion()
79
+ if isinstance(self.criterion, dict):
80
+ for key, value in self.criterion.items():
81
+ self.criterion[key].cuda(args.local_rank)
82
+ else:
83
+ self.criterion.cuda(self.args.local_rank)
84
+
85
+ # optimizer
86
+ self.optimizer = self.build_optimizer()
87
+ self.scheduler = self.build_scheduler()
88
+
89
+ # save config file
90
+ self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")
91
+
92
+ def build_logger(self):
93
+ log_file = os.path.join(self.checkpoint_dir, "train.log")
94
+ logger = Logger(log_file, level=self.args.log_level).logger
95
+
96
+ return logger
97
+
98
+ def build_dataset(self):
99
+ raise NotImplementedError
100
+
101
+ def build_data_loader(self):
102
+ Dataset, Collator = self.build_dataset()
103
+ # build dataset instance for each dataset and combine them by ConcatDataset
104
+ datasets_list = []
105
+ for dataset in self.cfg.dataset:
106
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
107
+ datasets_list.append(subdataset)
108
+ train_dataset = ConcatDataset(datasets_list)
109
+
110
+ train_collate = Collator(self.cfg)
111
+ # TODO: multi-GPU training
112
+ if self.cfg.train.ddp:
113
+ raise NotImplementedError("DDP is not supported yet.")
114
+
115
+ # sampler will provide indices to batch_sampler, which will perform batching and yield batch indices
116
+ batch_sampler = BatchSampler(
117
+ cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list
118
+ )
119
+
120
+ # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
121
+ train_loader = DataLoader(
122
+ train_dataset,
123
+ collate_fn=train_collate,
124
+ num_workers=self.args.num_workers,
125
+ batch_sampler=batch_sampler,
126
+ pin_memory=False,
127
+ )
128
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
129
+ datasets_list = []
130
+ for dataset in self.cfg.dataset:
131
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
132
+ datasets_list.append(subdataset)
133
+ valid_dataset = ConcatDataset(datasets_list)
134
+ valid_collate = Collator(self.cfg)
135
+ batch_sampler = BatchSampler(
136
+ cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list
137
+ )
138
+ valid_loader = DataLoader(
139
+ valid_dataset,
140
+ collate_fn=valid_collate,
141
+ num_workers=1,
142
+ batch_sampler=batch_sampler,
143
+ )
144
+ else:
145
+ raise NotImplementedError("DDP is not supported yet.")
146
+ # valid_loader = None
147
+ data_loader = {"train": train_loader, "valid": valid_loader}
148
+ return data_loader
149
+
150
+ def build_singers_lut(self):
151
+ # combine singers
152
+ if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)):
153
+ singers = collections.OrderedDict()
154
+ else:
155
+ with open(
156
+ os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r"
157
+ ) as singer_file:
158
+ singers = json.load(singer_file)
159
+ singer_count = len(singers)
160
+ for dataset in self.cfg.dataset:
161
+ singer_lut_path = os.path.join(
162
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
163
+ )
164
+ with open(singer_lut_path, "r") as singer_lut_path:
165
+ singer_lut = json.load(singer_lut_path)
166
+ for singer in singer_lut.keys():
167
+ if singer not in singers:
168
+ singers[singer] = singer_count
169
+ singer_count += 1
170
+ with open(
171
+ os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w"
172
+ ) as singer_file:
173
+ json.dump(singers, singer_file, indent=4, ensure_ascii=False)
174
+ print(
175
+ "singers have been dumped to {}".format(
176
+ os.path.join(self.log_dir, self.cfg.preprocess.spk2id)
177
+ )
178
+ )
179
+ return singers
180
+
181
+ def build_model(self):
182
+ raise NotImplementedError()
183
+
184
+ def build_optimizer(self):
185
+ raise NotImplementedError
186
+
187
+ def build_scheduler(self):
188
+ raise NotImplementedError()
189
+
190
+ def build_criterion(self):
191
+ raise NotImplementedError
192
+
193
+ def get_state_dict(self):
194
+ raise NotImplementedError
195
+
196
+ def save_config_file(self):
197
+ save_config(self.config_save_path, self.cfg)
198
+
199
+ # TODO, save without module.
200
+ def save_checkpoint(self, state_dict, saved_model_path):
201
+ torch.save(state_dict, saved_model_path)
202
+
203
+ def load_checkpoint(self):
204
+ checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint")
205
+ assert os.path.exists(checkpoint_path)
206
+ checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
207
+ model_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
208
+ assert os.path.exists(model_path)
209
+ if not self.cfg.train.ddp or self.args.local_rank == 0:
210
+ self.logger.info(f"Re(store) from {model_path}")
211
+ checkpoint = torch.load(model_path, map_location="cpu")
212
+ return checkpoint
213
+
214
+ def load_model(self, checkpoint):
215
+ raise NotImplementedError
216
+
217
+ def restore(self):
218
+ checkpoint = self.load_checkpoint()
219
+ self.load_model(checkpoint)
220
+
221
+ def train_step(self, data):
222
+ raise NotImplementedError(
223
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
224
+ f"your sub-class of {self.__class__.__name__}. "
225
+ )
226
+
227
+ @torch.no_grad()
228
+ def eval_step(self):
229
+ raise NotImplementedError(
230
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
231
+ f"your sub-class of {self.__class__.__name__}. "
232
+ )
233
+
234
+ def write_summary(self, losses, stats):
235
+ raise NotImplementedError(
236
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
237
+ f"your sub-class of {self.__class__.__name__}. "
238
+ )
239
+
240
+ def write_valid_summary(self, losses, stats):
241
+ raise NotImplementedError(
242
+ f"Need to implement function {sys._getframe().f_code.co_name} in "
243
+ f"your sub-class of {self.__class__.__name__}. "
244
+ )
245
+
246
+ def echo_log(self, losses, mode="Training"):
247
+ message = [
248
+ "{} - Epoch {} Step {}: [{:.3f} s/step]".format(
249
+ mode, self.epoch + 1, self.step, self.time_window.average
250
+ )
251
+ ]
252
+
253
+ for key in sorted(losses.keys()):
254
+ if isinstance(losses[key], dict):
255
+ for k, v in losses[key].items():
256
+ message.append(
257
+ str(k).split("/")[-1] + "=" + str(round(float(v), 5))
258
+ )
259
+ else:
260
+ message.append(
261
+ str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
262
+ )
263
+ self.logger.info(", ".join(message))
264
+
265
+ def eval_epoch(self):
266
+ self.logger.info("Validation...")
267
+ valid_losses = {}
268
+ for i, batch_data in enumerate(self.data_loader["valid"]):
269
+ for k, v in batch_data.items():
270
+ if isinstance(v, torch.Tensor):
271
+ batch_data[k] = v.cuda()
272
+ valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i)
273
+ for key in valid_loss:
274
+ if key not in valid_losses:
275
+ valid_losses[key] = 0
276
+ valid_losses[key] += valid_loss[key]
277
+
278
+ # Add mel and audio to the Tensorboard
279
+ # Average loss
280
+ for key in valid_losses:
281
+ valid_losses[key] /= i + 1
282
+ self.echo_log(valid_losses, "Valid")
283
+ return valid_losses, valid_stats
284
+
285
+ def train_epoch(self):
286
+ for i, batch_data in enumerate(self.data_loader["train"]):
287
+ start_time = time.time()
288
+ # Put the data to cuda device
289
+ for k, v in batch_data.items():
290
+ if isinstance(v, torch.Tensor):
291
+ batch_data[k] = v.cuda(self.args.local_rank)
292
+
293
+ # Training step
294
+ train_losses, train_stats, total_loss = self.train_step(batch_data)
295
+ self.time_window.append(time.time() - start_time)
296
+
297
+ if self.args.local_rank == 0 or not self.cfg.train.ddp:
298
+ if self.step % self.args.stdout_interval == 0:
299
+ self.echo_log(train_losses, "Training")
300
+
301
+ if self.step % self.cfg.train.save_summary_steps == 0:
302
+ self.logger.info(f"Save summary as step {self.step}")
303
+ self.write_summary(train_losses, train_stats)
304
+
305
+ if (
306
+ self.step % self.cfg.train.save_checkpoints_steps == 0
307
+ and self.step != 0
308
+ ):
309
+ saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
310
+ self.step, total_loss
311
+ )
312
+ saved_model_path = os.path.join(
313
+ self.checkpoint_dir, saved_model_name
314
+ )
315
+ saved_state_dict = self.get_state_dict()
316
+ self.save_checkpoint(saved_state_dict, saved_model_path)
317
+ self.save_config_file()
318
+ # keep max n models
319
+ remove_older_ckpt(
320
+ saved_model_name,
321
+ self.checkpoint_dir,
322
+ max_to_keep=self.cfg.train.keep_checkpoint_max,
323
+ )
324
+
325
+ if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
326
+ if isinstance(self.model, dict):
327
+ for key in self.model.keys():
328
+ self.model[key].eval()
329
+ else:
330
+ self.model.eval()
331
+ # Evaluate one epoch and get average loss
332
+ valid_losses, valid_stats = self.eval_epoch()
333
+ if isinstance(self.model, dict):
334
+ for key in self.model.keys():
335
+ self.model[key].train()
336
+ else:
337
+ self.model.train()
338
+ # Write validation losses to summary.
339
+ self.write_valid_summary(valid_losses, valid_stats)
340
+ self.step += 1
341
+
342
+ def train(self):
343
+ for epoch in range(max(0, self.epoch), self.max_epochs):
344
+ self.train_epoch()
345
+ self.epoch += 1
346
+ if self.step > self.max_steps:
347
+ self.logger.info("Training finished!")
348
+ break
models/base/new_dataset.py ADDED
@@ -0,0 +1,50 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ from abc import abstractmethod
9
+ from pathlib import Path
10
+
11
+ import json5
12
+ import torch
13
+ import yaml
14
+
15
+
16
+ # TODO: for training and validating
17
+ class BaseDataset(torch.utils.data.Dataset):
18
+ r"""Base dataset for training and validating."""
19
+
20
+ def __init__(self, args, cfg, is_valid=False):
21
+ pass
22
+
23
+
24
+ class BaseTestDataset(torch.utils.data.Dataset):
25
+ r"""Test dataset for inference."""
26
+
27
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
28
+ assert infer_type in ["from_dataset", "from_file"]
29
+
30
+ self.args = args
31
+ self.cfg = cfg
32
+ self.infer_type = infer_type
33
+
34
+ @abstractmethod
35
+ def __getitem__(self, index):
36
+ pass
37
+
38
+ def __len__(self):
39
+ return len(self.metadata)
40
+
41
+ def get_metadata(self):
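+ # The test metadata may be provided as JSON/JSONC or YAML; dispatch on the source file's suffix.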
42
+ path = Path(self.args.source)
43
+ if path.suffix == ".json" or path.suffix == ".jsonc":
44
+ metadata = json5.load(open(self.args.source, "r"))
45
+ elif path.suffix == ".yaml" or path.suffix == ".yml":
46
+ metadata = yaml.full_load(open(self.args.source, "r"))
47
+ else:
48
+ raise ValueError(f"Unsupported file type: {path.suffix}")
49
+
50
+ return metadata
models/base/new_inference.py ADDED
@@ -0,0 +1,249 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import random
8
+ import re
9
+ import time
10
+ from abc import abstractmethod
11
+ from pathlib import Path
12
+
13
+ import accelerate
14
+ import json5
15
+ import numpy as np
16
+ import torch
17
+ from accelerate.logging import get_logger
18
+ from torch.utils.data import DataLoader
19
+
20
+ from models.vocoders.vocoder_inference import synthesis
21
+ from utils.io import save_audio
22
+ from utils.util import load_config
23
+ from utils.audio_slicer import is_silence
24
+
25
+ EPS = 1.0e-12
26
+
27
+
28
+ class BaseInference(object):
29
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
30
+ super().__init__()
31
+
32
+ start = time.monotonic_ns()
33
+ self.args = args
34
+ self.cfg = cfg
35
+
36
+ assert infer_type in ["from_dataset", "from_file"]
37
+ self.infer_type = infer_type
38
+
39
+ # init with accelerate
40
+ self.accelerator = accelerate.Accelerator()
41
+ self.accelerator.wait_for_everyone()
42
+
43
+ # Use accelerate logger for distributed inference
44
+ with self.accelerator.main_process_first():
45
+ self.logger = get_logger("inference", log_level=args.log_level)
46
+
47
+ # Log some info
48
+ self.logger.info("=" * 56)
49
+ self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
50
+ self.logger.info("=" * 56)
51
+ self.logger.info("\n")
52
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
53
+
54
+ self.acoustics_dir = args.acoustics_dir
55
+ self.logger.debug(f"Acoustic dir: {args.acoustics_dir}")
56
+ self.vocoder_dir = args.vocoder_dir
57
+ self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
58
+ # should be in svc inferencer
59
+ # self.target_singer = args.target_singer
60
+ # self.logger.info(f"Target singers: {args.target_singer}")
61
+ # self.trans_key = args.trans_key
62
+ # self.logger.info(f"Trans key: {args.trans_key}")
63
+
64
+ os.makedirs(args.output_dir, exist_ok=True)
65
+
66
+ # set random seed
67
+ with self.accelerator.main_process_first():
68
+ start = time.monotonic_ns()
69
+ self._set_random_seed(self.cfg.train.random_seed)
70
+ end = time.monotonic_ns()
71
+ self.logger.debug(
72
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
73
+ )
74
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
75
+
76
+ # setup data_loader
77
+ with self.accelerator.main_process_first():
78
+ self.logger.info("Building dataset...")
79
+ start = time.monotonic_ns()
80
+ self.test_dataloader = self._build_dataloader()
81
+ end = time.monotonic_ns()
82
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
83
+
84
+ # setup model
85
+ with self.accelerator.main_process_first():
86
+ self.logger.info("Building model...")
87
+ start = time.monotonic_ns()
88
+ self.model = self._build_model()
89
+ end = time.monotonic_ns()
90
+ # self.logger.debug(self.model)
91
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
92
+
93
+ # init with accelerate
94
+ self.logger.info("Initializing accelerate...")
95
+ start = time.monotonic_ns()
96
+ self.accelerator = accelerate.Accelerator()
97
+ self.model = self.accelerator.prepare(self.model)
98
+ end = time.monotonic_ns()
99
+ self.accelerator.wait_for_everyone()
100
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
101
+
102
+ with self.accelerator.main_process_first():
103
+ self.logger.info("Loading checkpoint...")
104
+ start = time.monotonic_ns()
105
+ # TODO: Also, suppose only use latest one yet
106
+ self.__load_model(os.path.join(args.acoustics_dir, "checkpoint"))
107
+ end = time.monotonic_ns()
108
+ self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
109
+
110
+ self.model.eval()
111
+ self.accelerator.wait_for_everyone()
112
+
113
+ ### Abstract methods ###
114
+ @abstractmethod
115
+ def _build_test_dataset(self):
116
+ pass
117
+
118
+ @abstractmethod
119
+ def _build_model(self):
120
+ pass
121
+
122
+ @abstractmethod
123
+ @torch.inference_mode()
124
+ def _inference_each_batch(self, batch_data):
125
+ pass
126
+
127
+ ### Abstract methods end ###
128
+
129
+ @torch.inference_mode()
130
+ def inference(self):
131
+ for i, batch in enumerate(self.test_dataloader):
132
+ y_pred = self._inference_each_batch(batch).cpu()
133
+ mel_min, mel_max = self.test_dataset.target_mel_extrema
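+ # De-normalize predictions from [-1, 1] back to the target mel's min/max range (EPS guards against a zero-width range).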
134
+ y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min
135
+ y_ls = y_pred.chunk(self.test_batch_size)
136
+ tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
137
+ j = 0
138
+ for it, l in zip(y_ls, tgt_ls):
139
+ l = l.item()
140
+ it = it.squeeze(0)[:l]
141
+ uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
142
+ torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
143
+ j += 1
144
+
145
+ vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
146
+
147
+ res = synthesis(
148
+ cfg=vocoder_cfg,
149
+ vocoder_weight_file=vocoder_ckpt,
150
+ n_samples=None,
151
+ pred=[
152
+ torch.load(
153
+ os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"]))
154
+ ).numpy(force=True)
155
+ for i in self.test_dataset.metadata
156
+ ],
157
+ )
158
+
159
+ output_audio_files = []
160
+ for it, wav in zip(self.test_dataset.metadata, res):
161
+ uid = it["Uid"]
162
+ file = os.path.join(self.args.output_dir, f"{uid}.wav")
163
+ output_audio_files.append(file)
164
+
165
+ wav = wav.numpy(force=True)
166
+ save_audio(
167
+ file,
168
+ wav,
169
+ self.cfg.preprocess.sample_rate,
170
+ add_silence=False,
171
+ turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
172
+ )
173
+ os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))
174
+
175
+ return sorted(output_audio_files)
176
+
177
+ # TODO: LEGACY CODE
178
+ def _build_dataloader(self):
179
+ datasets, collate = self._build_test_dataset()
180
+ self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
181
+ self.test_collate = collate(self.cfg)
182
+ self.test_batch_size = min(
183
+ self.cfg.train.batch_size, len(self.test_dataset.metadata)
184
+ )
185
+ test_dataloader = DataLoader(
186
+ self.test_dataset,
187
+ collate_fn=self.test_collate,
188
+ num_workers=1,
189
+ batch_size=self.test_batch_size,
190
+ shuffle=False,
191
+ )
192
+ return test_dataloader
193
+
194
+ def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None):
195
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
196
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
197
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
198
+ method after** ``accelerator.prepare()``.
199
+ """
200
+ if checkpoint_path is None:
201
+ ls = []
202
+ for i in Path(checkpoint_dir).iterdir():
203
+ if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)):
204
+ ls.append(i)
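+ # Sort candidate checkpoints by the epoch number encoded in the directory name and load the most recent one.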
205
+ ls.sort(
206
+ key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True
207
+ )
208
+ checkpoint_path = ls[0]
209
+ else:
210
+ checkpoint_path = Path(checkpoint_path)
211
+ self.accelerator.load_state(str(checkpoint_path))
212
+ # set epoch and step
213
+ self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1])
214
+ self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1])
215
+ return str(checkpoint_path)
216
+
217
+ @staticmethod
218
+ def _set_random_seed(seed):
219
+ r"""Set random seed for all possible random modules."""
220
+ random.seed(seed)
221
+ np.random.seed(seed)
222
+ torch.random.manual_seed(seed)
223
+
224
+ @staticmethod
225
+ def _parse_vocoder(vocoder_dir):
226
+ r"""Parse vocoder config"""
227
+ vocoder_dir = os.path.abspath(vocoder_dir)
228
+ ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
229
+ ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
230
+ ckpt_path = str(ckpt_list[0])
231
+ vocoder_cfg = load_config(
232
+ os.path.join(vocoder_dir, "args.json"), lowercase=True
233
+ )
234
+ return vocoder_cfg, ckpt_path
235
+
236
+ @staticmethod
237
+ def __count_parameters(model):
238
+ return sum(p.numel() for p in model.parameters())
239
+
240
+ def __dump_cfg(self, path):
241
+ os.makedirs(os.path.dirname(path), exist_ok=True)
242
+ json5.dump(
243
+ self.cfg,
244
+ open(path, "w"),
245
+ indent=4,
246
+ sort_keys=True,
247
+ ensure_ascii=False,
248
+ quote_keys=True,
249
+ )
models/base/new_trainer.py ADDED
@@ -0,0 +1,722 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+ import random
9
+ import shutil
10
+ import time
11
+ from abc import abstractmethod
12
+ from pathlib import Path
13
+
14
+ import accelerate
15
+ import json5
16
+ import numpy as np
17
+ import torch
18
+ from accelerate.logging import get_logger
19
+ from accelerate.utils import ProjectConfiguration
20
+ from torch.utils.data import ConcatDataset, DataLoader
21
+ from tqdm import tqdm
22
+
23
+ from models.base.base_sampler import build_samplers
24
+ from optimizer.optimizers import NoamLR
25
+
26
+
27
+ class BaseTrainer(object):
28
+ r"""The base trainer for all tasks. Any trainer should inherit from this class."""
29
+
30
+ def __init__(self, args=None, cfg=None):
31
+ super().__init__()
32
+
33
+ self.args = args
34
+ self.cfg = cfg
35
+
36
+ cfg.exp_name = args.exp_name
37
+
38
+ # init with accelerate
39
+ self._init_accelerator()
40
+ self.accelerator.wait_for_everyone()
41
+
42
+ # Use accelerate logger for distributed training
43
+ with self.accelerator.main_process_first():
44
+ self.logger = get_logger(args.exp_name, log_level=args.log_level)
45
+
46
+ # Log some info
47
+ self.logger.info("=" * 56)
48
+ self.logger.info("||\t\t" + "New training process started." + "\t\t||")
49
+ self.logger.info("=" * 56)
50
+ self.logger.info("\n")
51
+ self.logger.debug(f"Using {args.log_level.upper()} logging level.")
52
+ self.logger.info(f"Experiment name: {args.exp_name}")
53
+ self.logger.info(f"Experiment directory: {self.exp_dir}")
54
+ self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
55
+ if self.accelerator.is_main_process:
56
+ os.makedirs(self.checkpoint_dir, exist_ok=True)
57
+ self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
58
+
59
+ # init counts
60
+ self.batch_count: int = 0
61
+ self.step: int = 0
62
+ self.epoch: int = 0
63
+ self.max_epoch = (
64
+ self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
65
+ )
66
+ self.logger.info(
67
+ "Max epoch: {}".format(
68
+ self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
69
+ )
70
+ )
71
+
72
+ # Check values
73
+ if self.accelerator.is_main_process:
74
+ self.__check_basic_configs()
75
+ # Set runtime configs
76
+ self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
77
+ self.checkpoints_path = [
78
+ [] for _ in range(len(self.save_checkpoint_stride))
79
+ ]
80
+ self.keep_last = [
81
+ i if i > 0 else float("inf") for i in self.cfg.train.keep_last
82
+ ]
83
+ self.run_eval = self.cfg.train.run_eval
84
+
85
+ # set random seed
86
+ with self.accelerator.main_process_first():
87
+ start = time.monotonic_ns()
88
+ self._set_random_seed(self.cfg.train.random_seed)
89
+ end = time.monotonic_ns()
90
+ self.logger.debug(
91
+ f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
92
+ )
93
+ self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
94
+
95
+ # setup data_loader
96
+ with self.accelerator.main_process_first():
97
+ self.logger.info("Building dataset...")
98
+ start = time.monotonic_ns()
99
+ self.train_dataloader, self.valid_dataloader = self._build_dataloader()
100
+ end = time.monotonic_ns()
101
+ self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
102
+
103
+ # setup model
104
+ with self.accelerator.main_process_first():
105
+ self.logger.info("Building model...")
106
+ start = time.monotonic_ns()
107
+ self.model = self._build_model()
108
+ end = time.monotonic_ns()
109
+ self.logger.debug(self.model)
110
+ self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
111
+ self.logger.info(
112
+ f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
113
+ )
114
+ # optimizer & scheduler
115
+ with self.accelerator.main_process_first():
116
+ self.logger.info("Building optimizer and scheduler...")
117
+ start = time.monotonic_ns()
118
+ self.optimizer = self.__build_optimizer()
119
+ self.scheduler = self.__build_scheduler()
120
+ end = time.monotonic_ns()
121
+ self.logger.info(
122
+ f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
123
+ )
124
+
125
+ # accelerate prepare
126
+ self.logger.info("Initializing accelerate...")
127
+ start = time.monotonic_ns()
128
+ (
129
+ self.train_dataloader,
130
+ self.valid_dataloader,
131
+ self.model,
132
+ self.optimizer,
133
+ self.scheduler,
134
+ ) = self.accelerator.prepare(
135
+ self.train_dataloader,
136
+ self.valid_dataloader,
137
+ self.model,
138
+ self.optimizer,
139
+ self.scheduler,
140
+ )
141
+ end = time.monotonic_ns()
142
+ self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
143
+
144
+ # create criterion
145
+ with self.accelerator.main_process_first():
146
+ self.logger.info("Building criterion...")
147
+ start = time.monotonic_ns()
148
+ self.criterion = self._build_criterion()
149
+ end = time.monotonic_ns()
150
+ self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
151
+
152
+ # Resume or Finetune
153
+ with self.accelerator.main_process_first():
154
+ if args.resume:
155
+ ## Automatically resume according to the current experiment name
156
+ self.logger.info("Resuming from {}...".format(self.checkpoint_dir))
157
+ start = time.monotonic_ns()
158
+ ckpt_path = self.__load_model(
159
+ checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
160
+ )
161
+ end = time.monotonic_ns()
162
+ self.logger.info(
163
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
164
+ )
165
+ self.checkpoints_path = json.load(
166
+ open(os.path.join(ckpt_path, "ckpts.json"), "r")
167
+ )
168
+ elif args.resume_from_ckpt_path and args.resume_from_ckpt_path != "":
169
+ ## Resume from the given checkpoint path
170
+ if not os.path.exists(args.resume_from_ckpt_path):
171
+ raise ValueError(
+ "[Error] The resumed checkpoint path {} does not exist.".format(
173
+ args.resume_from_ckpt_path
174
+ )
175
+ )
176
+
177
+ self.logger.info(
178
+ "Resuming from {}...".format(args.resume_from_ckpt_path)
179
+ )
180
+ start = time.monotonic_ns()
181
+ ckpt_path = self.__load_model(
182
+ checkpoint_path=args.resume_from_ckpt_path,
183
+ resume_type=args.resume_type,
184
+ )
185
+ end = time.monotonic_ns()
186
+ self.logger.info(
187
+ f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
188
+ )
189
+
190
+ # save config file path
191
+ self.config_save_path = os.path.join(self.exp_dir, "args.json")
192
+
193
+ ### Following are abstract methods that should be implemented in child classes ###
194
+ @abstractmethod
195
+ def _build_dataset(self):
196
+ r"""Build dataset for model training/validating/evaluating."""
197
+ pass
198
+
199
+ @staticmethod
200
+ @abstractmethod
201
+ def _build_criterion():
202
+ r"""Build criterion function for model loss calculation."""
203
+ pass
204
+
205
+ @abstractmethod
206
+ def _build_model(self):
207
+ r"""Build model for training/validating/evaluating."""
208
+ pass
209
+
210
+ @abstractmethod
211
+ def _forward_step(self, batch):
212
+ r"""One forward step of the neural network. This abstract method is trying to
213
+ unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation.
214
+ However, for special case that using different forward step pattern for
215
+ training and validating, you could just override this method with ``pass`` and
216
+ implement ``_train_step`` and ``_valid_step`` separately.
217
+ """
218
+ pass
219
+
220
+ @abstractmethod
221
+ def _save_auxiliary_states(self):
222
+ r"""To save some auxiliary states when saving model's ckpt"""
223
+ pass
224
+
225
+ ### Abstract methods end ###
226
+
227
+ ### THIS IS MAIN ENTRY ###
228
+ def train_loop(self):
229
+ r"""Training loop. The public entry of training process."""
230
+ # Wait everyone to prepare before we move on
231
+ self.accelerator.wait_for_everyone()
232
+ # dump config file
233
+ if self.accelerator.is_main_process:
234
+ self.__dump_cfg(self.config_save_path)
235
+ self.model.train()
236
+ self.optimizer.zero_grad()
237
+ # Wait to ensure good to go
238
+ self.accelerator.wait_for_everyone()
239
+ while self.epoch < self.max_epoch:
240
+ self.logger.info("\n")
241
+ self.logger.info("-" * 32)
242
+ self.logger.info("Epoch {}: ".format(self.epoch))
243
+
244
+ ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
245
+ ### The single scalar return value is inconvenient for models with multiple losses
246
+ # Do training & validating epoch
247
+ train_loss = self._train_epoch()
248
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
249
+ valid_loss = self._valid_epoch()
250
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
251
+ self.accelerator.log(
252
+ {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
253
+ step=self.epoch,
254
+ )
255
+
256
+ self.accelerator.wait_for_everyone()
257
+ # TODO: what is scheduler?
258
+ self.scheduler.step(valid_loss) # FIXME: use epoch track correct?
259
+
260
+ # Check if hit save_checkpoint_stride and run_eval
261
+ run_eval = False
262
+ if self.accelerator.is_main_process:
263
+ save_checkpoint = False
264
+ hit_dix = []
265
+ for i, num in enumerate(self.save_checkpoint_stride):
266
+ if self.epoch % num == 0:
267
+ save_checkpoint = True
268
+ hit_dix.append(i)
269
+ run_eval |= self.run_eval[i]
270
+
271
+ self.accelerator.wait_for_everyone()
272
+ if self.accelerator.is_main_process and save_checkpoint:
273
+ path = os.path.join(
274
+ self.checkpoint_dir,
275
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
276
+ self.epoch, self.step, train_loss
277
+ ),
278
+ )
279
+ self.tmp_checkpoint_save_path = path
280
+ self.accelerator.save_state(path)
281
+ print(f"save checkpoint in {path}")
282
+ json.dump(
283
+ self.checkpoints_path,
284
+ open(os.path.join(path, "ckpts.json"), "w"),
285
+ ensure_ascii=False,
286
+ indent=4,
287
+ )
288
+ self._save_auxiliary_states()
289
+
290
+ # Remove old checkpoints
291
+ to_remove = []
292
+ for idx in hit_dix:
293
+ self.checkpoints_path[idx].append(path)
294
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
295
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
296
+
297
+ # Search conflicts
298
+ total = set()
299
+ for i in self.checkpoints_path:
300
+ total |= set(i)
301
+ do_remove = set()
302
+ for idx, path in to_remove[::-1]:
303
+ if path in total:
304
+ self.checkpoints_path[idx].insert(0, path)
305
+ else:
306
+ do_remove.add(path)
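+ # A checkpoint may be referenced by several stride lists; it is only deleted once no
+ # list still keeps it (paths still present in `total` are re-inserted instead).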
307
+
308
+ # Remove old checkpoints
309
+ for path in do_remove:
310
+ shutil.rmtree(path, ignore_errors=True)
311
+ self.logger.debug(f"Remove old checkpoint: {path}")
312
+
313
+ self.accelerator.wait_for_everyone()
314
+ if run_eval:
315
+ # TODO: run evaluation
316
+ pass
317
+
318
+ # Update info for each epoch
319
+ self.epoch += 1
320
+
321
+ # Finish training and save final checkpoint
322
+ self.accelerator.wait_for_everyone()
323
+ if self.accelerator.is_main_process:
324
+ self.accelerator.save_state(
325
+ os.path.join(
326
+ self.checkpoint_dir,
327
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
328
+ self.epoch, self.step, valid_loss
329
+ ),
330
+ )
331
+ )
332
+ self._save_auxiliary_states()
333
+
334
+ self.accelerator.end_training()
335
+
336
+ ### Following are methods that can be used directly in child classes ###
337
+ def _train_epoch(self):
338
+ r"""Training epoch. Should return average loss of a batch (sample) over
339
+ one epoch. See ``train_loop`` for usage.
340
+ """
341
+ self.model.train()
342
+ epoch_sum_loss: float = 0.0
343
+ epoch_step: int = 0
344
+ for batch in tqdm(
345
+ self.train_dataloader,
346
+ desc=f"Training Epoch {self.epoch}",
347
+ unit="batch",
348
+ colour="GREEN",
349
+ leave=False,
350
+ dynamic_ncols=True,
351
+ smoothing=0.04,
352
+ disable=not self.accelerator.is_main_process,
353
+ ):
354
+ # Do training step and BP
355
+ with self.accelerator.accumulate(self.model):
356
+ loss = self._train_step(batch)
357
+ self.accelerator.backward(loss)
358
+ self.optimizer.step()
359
+ self.optimizer.zero_grad()
360
+ self.batch_count += 1
361
+
362
+ # Update info for each step
363
+ # TODO: step means BP counts or batch counts?
364
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
365
+ epoch_sum_loss += loss
366
+ self.accelerator.log(
367
+ {
368
+ "Step/Train Loss": loss,
369
+ "Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
370
+ },
371
+ step=self.step,
372
+ )
373
+ self.step += 1
374
+ epoch_step += 1
375
+
376
+ self.accelerator.wait_for_everyone()
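+ # epoch_sum_loss accumulates one value per optimizer update (every
+ # gradient_accumulation_step batches), so the mean loss per update is
+ # sum / (num_batches / gradient_accumulation_step), as computed below.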
377
+ return (
378
+ epoch_sum_loss
379
+ / len(self.train_dataloader)
380
+ * self.cfg.train.gradient_accumulation_step
381
+ )
382
+
383
+ @torch.inference_mode()
384
+ def _valid_epoch(self):
385
+ r"""Testing epoch. Should return average loss of a batch (sample) over
386
+ one epoch. See ``train_loop`` for usage.
387
+ """
388
+ self.model.eval()
389
+ epoch_sum_loss = 0.0
390
+ for batch in tqdm(
391
+ self.valid_dataloader,
392
+ desc=f"Validating Epoch {self.epoch}",
393
+ unit="batch",
394
+ colour="GREEN",
395
+ leave=False,
396
+ dynamic_ncols=True,
397
+ smoothing=0.04,
398
+ disable=not self.accelerator.is_main_process,
399
+ ):
400
+ batch_loss = self._valid_step(batch)
401
+ epoch_sum_loss += batch_loss.item()
402
+
403
+ self.accelerator.wait_for_everyone()
404
+ return epoch_sum_loss / len(self.valid_dataloader)
405
+
406
+ def _train_step(self, batch):
407
+ r"""Training forward step. Should return average loss of a sample over
408
+ one batch. Invoking ``_forward_step`` is recommended except for special cases.
409
+ See ``_train_epoch`` for usage.
410
+ """
411
+ return self._forward_step(batch)
412
+
413
+ @torch.inference_mode()
414
+ def _valid_step(self, batch):
415
+ r"""Testing forward step. Should return average loss of a sample over
416
+ one batch. Provoke ``_forward_step`` is recommended except for special case.
417
+ See ``_test_epoch`` for usage.
418
+ """
419
+ return self._forward_step(batch)
420
+
421
+ def __load_model(
422
+ self,
423
+ checkpoint_dir: str = None,
424
+ checkpoint_path: str = None,
425
+ resume_type: str = "",
426
+ ):
427
+ r"""Load model from checkpoint. If checkpoint_path is None, it will
428
+ load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
429
+ None, it will load the checkpoint specified by checkpoint_path. **Only use this
430
+ method after** ``accelerator.prepare()``.
431
+ """
432
+ if checkpoint_path is None:
433
+ ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
434
+ ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
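+ # Checkpoint directories are named "epoch-XXXX_step-XXXXXXX_loss-X.XXXXXX", so sorting
+ # on the epoch field in descending order makes ls[0] the most recent checkpoint.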
435
+ checkpoint_path = ls[0]
436
+ self.logger.info("Resume from {}...".format(checkpoint_path))
437
+
438
+ if resume_type in ["resume", ""]:
439
+ # Load all the things, including model weights, optimizer, scheduler, and random states.
440
+ self.accelerator.load_state(input_dir=checkpoint_path)
441
+
442
+ # set epoch and step
443
+ self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
444
+ self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
445
+
446
+ elif resume_type == "finetune":
447
+ # Load only the model weights
448
+ accelerate.load_checkpoint_and_dispatch(
449
+ self.accelerator.unwrap_model(self.model),
450
+ os.path.join(checkpoint_path, "pytorch_model.bin"),
451
+ )
452
+ self.logger.info("Load model weights for finetune...")
453
+
454
+ else:
455
+ raise ValueError("Resume_type must be `resume` or `finetune`.")
456
+
457
+ return checkpoint_path
458
+
459
+ # TODO: LEGACY CODE
460
+ def _build_dataloader(self):
461
+ Dataset, Collator = self._build_dataset()
462
+
463
+ # build dataset instance for each dataset and combine them by ConcatDataset
464
+ datasets_list = []
465
+ for dataset in self.cfg.dataset:
466
+ subdataset = Dataset(self.cfg, dataset, is_valid=False)
467
+ datasets_list.append(subdataset)
468
+ train_dataset = ConcatDataset(datasets_list)
469
+ train_collate = Collator(self.cfg)
470
+ _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
471
+ self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
472
+ self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
473
+ # TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
474
+ train_loader = DataLoader(
475
+ train_dataset,
476
+ collate_fn=train_collate,
477
+ batch_sampler=batch_sampler,
478
+ num_workers=self.cfg.train.dataloader.num_worker,
479
+ pin_memory=self.cfg.train.dataloader.pin_memory,
480
+ )
481
+
482
+ # Build valid dataloader
483
+ datasets_list = []
484
+ for dataset in self.cfg.dataset:
485
+ subdataset = Dataset(self.cfg, dataset, is_valid=True)
486
+ datasets_list.append(subdataset)
487
+ valid_dataset = ConcatDataset(datasets_list)
488
+ valid_collate = Collator(self.cfg)
489
+ _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
490
+ self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
491
+ self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
492
+ valid_loader = DataLoader(
493
+ valid_dataset,
494
+ collate_fn=valid_collate,
495
+ batch_sampler=batch_sampler,
496
+ num_workers=self.cfg.train.dataloader.num_worker,
497
+ pin_memory=self.cfg.train.dataloader.pin_memory,
498
+ )
499
+ return train_loader, valid_loader
500
+
501
+ @staticmethod
502
+ def _set_random_seed(seed):
503
+ r"""Set random seed for all possible random modules."""
504
+ random.seed(seed)
505
+ np.random.seed(seed)
506
+ torch.random.manual_seed(seed)
507
+
508
+ def _check_nan(self, loss, y_pred, y_gt):
509
+ if torch.any(torch.isnan(loss)):
510
+ self.logger.fatal("Fatal Error: Training stopped because the loss contains NaN!")
511
+ self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
512
+ if torch.any(torch.isnan(y_pred)):
513
+ self.logger.error(
514
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
515
+ )
516
+ else:
517
+ self.logger.debug(
518
+ f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
519
+ )
520
+ if torch.any(torch.isnan(y_gt)):
521
+ self.logger.error(
522
+ f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True
523
+ )
524
+ else:
525
+ self.logger.debug(
526
+ f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True
527
+ )
528
+ if torch.any(torch.isnan(y_pred)):
529
+ self.logger.error(f"y_pred: {y_pred}", in_order=True)
530
+ else:
531
+ self.logger.debug(f"y_pred: {y_pred}", in_order=True)
532
+ if torch.any(torch.isnan(y_gt)):
533
+ self.logger.error(f"y_gt: {y_gt}", in_order=True)
534
+ else:
535
+ self.logger.debug(f"y_gt: {y_gt}", in_order=True)
536
+
537
+ # TODO: still OK to save tracking?
538
+ self.accelerator.end_training()
539
+ raise RuntimeError("Loss has Nan! See log for more info.")
540
+
541
+ ### Protected methods end ###
542
+
543
+ ## Following are private methods ##
544
+ ## !!! These are inconvenient for GAN-based model training. It'd be better to move these to svc_trainer.py if needed.
545
+ def __build_optimizer(self):
546
+ r"""Build optimizer for model."""
547
+ # Make case-insensitive matching
548
+ if self.cfg.train.optimizer.lower() == "adadelta":
549
+ optimizer = torch.optim.Adadelta(
550
+ self.model.parameters(), **self.cfg.train.adadelta
551
+ )
552
+ self.logger.info("Using Adadelta optimizer.")
553
+ elif self.cfg.train.optimizer.lower() == "adagrad":
554
+ optimizer = torch.optim.Adagrad(
555
+ self.model.parameters(), **self.cfg.train.adagrad
556
+ )
557
+ self.logger.info("Using Adagrad optimizer.")
558
+ elif self.cfg.train.optimizer.lower() == "adam":
559
+ optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
560
+ self.logger.info("Using Adam optimizer.")
561
+ elif self.cfg.train.optimizer.lower() == "adamw":
562
+ optimizer = torch.optim.AdamW(
563
+ self.model.parameters(), **self.cfg.train.adamw
564
+ )
565
+ elif self.cfg.train.optimizer.lower() == "sparseadam":
566
+ optimizer = torch.optim.SparseAdam(
567
+ self.model.parameters(), **self.cfg.train.sparseadam
568
+ )
569
+ elif self.cfg.train.optimizer.lower() == "adamax":
570
+ optimizer = torch.optim.Adamax(
571
+ self.model.parameters(), **self.cfg.train.adamax
572
+ )
573
+ elif self.cfg.train.optimizer.lower() == "asgd":
574
+ optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
575
+ elif self.cfg.train.optimizer.lower() == "lbfgs":
576
+ optimizer = torch.optim.LBFGS(
577
+ self.model.parameters(), **self.cfg.train.lbfgs
578
+ )
579
+ elif self.cfg.train.optimizer.lower() == "nadam":
580
+ optimizer = torch.optim.NAdam(
581
+ self.model.parameters(), **self.cfg.train.nadam
582
+ )
583
+ elif self.cfg.train.optimizer.lower() == "radam":
584
+ optimizer = torch.optim.RAdam(
585
+ self.model.parameters(), **self.cfg.train.radam
586
+ )
587
+ elif self.cfg.train.optimizer.lower() == "rmsprop":
588
+ optimizer = torch.optim.RMSprop(
589
+ self.model.parameters(), **self.cfg.train.rmsprop
590
+ )
591
+ elif self.cfg.train.optimizer.lower() == "rprop":
592
+ optimizer = torch.optim.Rprop(
593
+ self.model.parameters(), **self.cfg.train.rprop
594
+ )
595
+ elif self.cfg.train.optimizer.lower() == "sgd":
596
+ optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
597
+ else:
598
+ raise NotImplementedError(
599
+ f"Optimizer {self.cfg.train.optimizer} not supported yet!"
600
+ )
601
+ return optimizer
602
+
603
+ def __build_scheduler(self):
604
+ r"""Build scheduler for optimizer."""
605
+ # Make case-insensitive matching
606
+ if self.cfg.train.scheduler.lower() == "lambdalr":
607
+ scheduler = torch.optim.lr_scheduler.LambdaLR(
608
+ self.optimizer, **self.cfg.train.lambdalr
609
+ )
610
+ elif self.cfg.train.scheduler.lower() == "multiplicativelr":
611
+ scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
612
+ self.optimizer, **self.cfg.train.multiplicativelr
613
+ )
614
+ elif self.cfg.train.scheduler.lower() == "steplr":
615
+ scheduler = torch.optim.lr_scheduler.StepLR(
616
+ self.optimizer, **self.cfg.train.steplr
617
+ )
618
+ elif self.cfg.train.scheduler.lower() == "multisteplr":
619
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(
620
+ self.optimizer, **self.cfg.train.multisteplr
621
+ )
622
+ elif self.cfg.train.scheduler.lower() == "constantlr":
623
+ scheduler = torch.optim.lr_scheduler.ConstantLR(
624
+ self.optimizer, **self.cfg.train.constantlr
625
+ )
626
+ elif self.cfg.train.scheduler.lower() == "linearlr":
627
+ scheduler = torch.optim.lr_scheduler.LinearLR(
628
+ self.optimizer, **self.cfg.train.linearlr
629
+ )
630
+ elif self.cfg.train.scheduler.lower() == "exponentiallr":
631
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(
632
+ self.optimizer, **self.cfg.train.exponentiallr
633
+ )
634
+ elif self.cfg.train.scheduler.lower() == "polynomiallr":
635
+ scheduler = torch.optim.lr_scheduler.PolynomialLR(
636
+ self.optimizer, **self.cfg.train.polynomiallr
637
+ )
638
+ elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
639
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
640
+ self.optimizer, **self.cfg.train.cosineannealinglr
641
+ )
642
+ elif self.cfg.train.scheduler.lower() == "sequentiallr":
643
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
644
+ self.optimizer, **self.cfg.train.sequentiallr
645
+ )
646
+ elif self.cfg.train.scheduler.lower() == "reducelronplateau":
647
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
648
+ self.optimizer, **self.cfg.train.reducelronplateau
649
+ )
650
+ elif self.cfg.train.scheduler.lower() == "cycliclr":
651
+ scheduler = torch.optim.lr_scheduler.CyclicLR(
652
+ self.optimizer, **self.cfg.train.cycliclr
653
+ )
654
+ elif self.cfg.train.scheduler.lower() == "onecyclelr":
655
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
656
+ self.optimizer, **self.cfg.train.onecyclelr
657
+ )
658
+ elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
659
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
660
+ self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
661
+ )
662
+ elif self.cfg.train.scheduler.lower() == "noamlr":
663
+ scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
664
+ else:
665
+ raise NotImplementedError(
666
+ f"Scheduler {self.cfg.train.scheduler} not supported yet!"
667
+ )
668
+ return scheduler
669
+
670
+ def _init_accelerator(self):
671
+ self.exp_dir = os.path.join(
672
+ os.path.abspath(self.cfg.log_dir), self.args.exp_name
673
+ )
674
+ project_config = ProjectConfiguration(
675
+ project_dir=self.exp_dir,
676
+ logging_dir=os.path.join(self.exp_dir, "log"),
677
+ )
678
+ self.accelerator = accelerate.Accelerator(
679
+ gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
680
+ log_with=self.cfg.train.tracker,
681
+ project_config=project_config,
682
+ )
683
+ if self.accelerator.is_main_process:
684
+ os.makedirs(project_config.project_dir, exist_ok=True)
685
+ os.makedirs(project_config.logging_dir, exist_ok=True)
686
+ with self.accelerator.main_process_first():
687
+ self.accelerator.init_trackers(self.args.exp_name)
688
+
689
+ def __check_basic_configs(self):
690
+ if self.cfg.train.gradient_accumulation_step <= 0:
691
+ self.logger.fatal("Invalid gradient_accumulation_step value!")
692
+ self.logger.error(
693
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
694
+ )
695
+ self.accelerator.end_training()
696
+ raise ValueError(
697
+ f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
698
+ )
699
+ # TODO: check other values
700
+
701
+ @staticmethod
702
+ def __count_parameters(model):
703
+ model_param = 0.0
704
+ if isinstance(model, dict):
705
+ for key, value in model.items():
706
+ model_param += sum(p.numel() for p in model[key].parameters())
707
+ else:
708
+ model_param = sum(p.numel() for p in model.parameters())
709
+ return model_param
710
+
711
+ def __dump_cfg(self, path):
712
+ os.makedirs(os.path.dirname(path), exist_ok=True)
713
+ json5.dump(
714
+ self.cfg,
715
+ open(path, "w"),
716
+ indent=4,
717
+ sort_keys=True,
718
+ ensure_ascii=False,
719
+ quote_keys=True,
720
+ )
721
+
722
+ ### Private methods end ###
models/svc/__init__.py ADDED
File without changes
models/svc/base/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from .svc_inference import SVCInference
7
+ from .svc_trainer import SVCTrainer
models/svc/base/svc_dataset.py ADDED
@@ -0,0 +1,425 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import random
7
+ import torch
8
+ from torch.nn.utils.rnn import pad_sequence
9
+ import json
10
+ import os
11
+ import numpy as np
12
+ from utils.data_utils import *
13
+ from processors.acoustic_extractor import cal_normalized_mel, load_mel_extrema
14
+ from processors.content_extractor import (
15
+ ContentvecExtractor,
16
+ WhisperExtractor,
17
+ WenetExtractor,
18
+ )
19
+ from models.base.base_dataset import (
20
+ BaseCollator,
21
+ BaseDataset,
22
+ )
23
+ from models.base.new_dataset import BaseTestDataset
24
+
25
+ EPS = 1.0e-12
26
+
27
+
28
+ class SVCDataset(BaseDataset):
29
+ def __init__(self, cfg, dataset, is_valid=False):
30
+ BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid)
31
+
32
+ cfg = self.cfg
33
+
34
+ if cfg.model.condition_encoder.use_whisper:
35
+ self.whisper_aligner = WhisperExtractor(self.cfg)
36
+ self.utt2whisper_path = load_content_feature_path(
37
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
38
+ )
39
+
40
+ if cfg.model.condition_encoder.use_contentvec:
41
+ self.contentvec_aligner = ContentvecExtractor(self.cfg)
42
+ self.utt2contentVec_path = load_content_feature_path(
43
+ self.metadata,
44
+ cfg.preprocess.processed_dir,
45
+ cfg.preprocess.contentvec_dir,
46
+ )
47
+
48
+ if cfg.model.condition_encoder.use_mert:
49
+ self.utt2mert_path = load_content_feature_path(
50
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
51
+ )
52
+ if cfg.model.condition_encoder.use_wenet:
53
+ self.wenet_aligner = WenetExtractor(self.cfg)
54
+ self.utt2wenet_path = load_content_feature_path(
55
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
56
+ )
57
+
58
+ def __getitem__(self, index):
59
+ single_feature = BaseDataset.__getitem__(self, index)
60
+
61
+ utt_info = self.metadata[index]
62
+ dataset = utt_info["Dataset"]
63
+ uid = utt_info["Uid"]
64
+ utt = "{}_{}".format(dataset, uid)
65
+
66
+ if self.cfg.model.condition_encoder.use_whisper:
67
+ assert "target_len" in single_feature.keys()
68
+ aligned_whisper_feat = self.whisper_aligner.offline_align(
69
+ np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
70
+ )
71
+ single_feature["whisper_feat"] = aligned_whisper_feat
72
+
73
+ if self.cfg.model.condition_encoder.use_contentvec:
74
+ assert "target_len" in single_feature.keys()
75
+ aligned_contentvec = self.contentvec_aligner.offline_align(
76
+ np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
77
+ )
78
+ single_feature["contentvec_feat"] = aligned_contentvec
79
+
80
+ if self.cfg.model.condition_encoder.use_mert:
81
+ assert "target_len" in single_feature.keys()
82
+ aligned_mert_feat = align_content_feature_length(
83
+ np.load(self.utt2mert_path[utt]),
84
+ single_feature["target_len"],
85
+ source_hop=self.cfg.preprocess.mert_hop_size,
86
+ )
87
+ single_feature["mert_feat"] = aligned_mert_feat
88
+
89
+ if self.cfg.model.condition_encoder.use_wenet:
90
+ assert "target_len" in single_feature.keys()
91
+ aligned_wenet_feat = self.wenet_aligner.offline_align(
92
+ np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
93
+ )
94
+ single_feature["wenet_feat"] = aligned_wenet_feat
95
+
96
+ # print(single_feature.keys())
97
+ # for k, v in single_feature.items():
98
+ # if type(v) in [torch.Tensor, np.ndarray]:
99
+ # print(k, v.shape)
100
+ # else:
101
+ # print(k, v)
102
+ # exit()
103
+
104
+ return self.clip_if_too_long(single_feature)
105
+
106
+ def __len__(self):
107
+ return len(self.metadata)
108
+
109
+ def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
110
+ """
111
+ ending_ts: to avoid invalid whisper features for over 30s audios
112
+ 2812 = 30 * 24000 // 256
113
+ """
114
+ ts = max(feature_seq_len - max_seq_len, 0)
115
+ ts = min(ts, ending_ts - max_seq_len)
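+ # ts is the largest valid start frame: the sampled window may not run past the end of
+ # the sequence, nor past ending_ts (the ~30 s Whisper limit at 24 kHz with hop 256).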
116
+
117
+ start = random.randint(0, ts)
118
+ end = start + max_seq_len
119
+ return start, end
120
+
121
+ def clip_if_too_long(self, sample, max_seq_len=512):
122
+ """
123
+ sample :
124
+ {
125
+ 'spk_id': (1,),
126
+ 'target_len': int
127
+ 'mel': (seq_len, dim),
128
+ 'frame_pitch': (seq_len,)
129
+ 'frame_energy': (seq_len,)
130
+ 'content_vector_feat': (seq_len, dim)
131
+ }
132
+ """
133
+ if sample["target_len"] <= max_seq_len:
134
+ return sample
135
+
136
+ start, end = self.random_select(sample["target_len"], max_seq_len)
137
+ sample["target_len"] = end - start
138
+
139
+ for k in sample.keys():
140
+ if k not in ["spk_id", "target_len"]:
141
+ sample[k] = sample[k][start:end]
142
+
143
+ return sample
144
+
145
+
146
+ class SVCCollator(BaseCollator):
147
+ """Zero-pads model inputs and targets based on number of frames per step"""
148
+
149
+ def __init__(self, cfg):
150
+ BaseCollator.__init__(self, cfg)
151
+
152
+ def __call__(self, batch):
153
+ parsed_batch_features = BaseCollator.__call__(self, batch)
154
+ return parsed_batch_features
155
+
156
+
157
+ class SVCTestDataset(BaseTestDataset):
158
+ def __init__(self, args, cfg, infer_type):
159
+ BaseTestDataset.__init__(self, args, cfg, infer_type)
160
+ self.metadata = self.get_metadata()
161
+
162
+ target_singer = args.target_singer
163
+ self.cfg = cfg
164
+ self.trans_key = args.trans_key
165
+ assert type(target_singer) == str
166
+
167
+ self.target_singer = target_singer.split("_")[-1]
168
+ self.target_dataset = target_singer.replace(
169
+ "_{}".format(self.target_singer), ""
170
+ )
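+ # target_singer is expected to be "<dataset>_<singer>": the singer id is the last
+ # "_"-separated field and the dataset name is everything before it.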
171
+
172
+ self.target_mel_extrema = load_mel_extrema(cfg.preprocess, self.target_dataset)
173
+ self.target_mel_extrema = torch.as_tensor(
174
+ self.target_mel_extrema[0]
175
+ ), torch.as_tensor(self.target_mel_extrema[1])
176
+
177
+ ######### Load source acoustic features #########
178
+ if cfg.preprocess.use_spkid:
179
+ spk2id_path = os.path.join(args.acoustics_dir, cfg.preprocess.spk2id)
180
+ # utt2sp_path = os.path.join(self.data_root, cfg.preprocess.utt2spk)
181
+
182
+ with open(spk2id_path, "r") as f:
183
+ self.spk2id = json.load(f)
184
+ # print("self.spk2id", self.spk2id)
185
+
186
+ if cfg.preprocess.use_uv:
187
+ self.utt2uv_path = {
188
+ f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
189
+ cfg.preprocess.processed_dir,
190
+ utt_info["Dataset"],
191
+ cfg.preprocess.uv_dir,
192
+ utt_info["Uid"] + ".npy",
193
+ )
194
+ for utt_info in self.metadata
195
+ }
196
+
197
+ if cfg.preprocess.use_frame_pitch:
198
+ self.utt2frame_pitch_path = {
199
+ f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
200
+ cfg.preprocess.processed_dir,
201
+ utt_info["Dataset"],
202
+ cfg.preprocess.pitch_dir,
203
+ utt_info["Uid"] + ".npy",
204
+ )
205
+ for utt_info in self.metadata
206
+ }
207
+
208
+ # Target F0 median
209
+ target_f0_statistics_path = os.path.join(
210
+ cfg.preprocess.processed_dir,
211
+ self.target_dataset,
212
+ cfg.preprocess.pitch_dir,
213
+ "statistics.json",
214
+ )
215
+ self.target_pitch_median = json.load(open(target_f0_statistics_path, "r"))[
216
+ f"{self.target_dataset}_{self.target_singer}"
217
+ ]["voiced_positions"]["median"]
218
+
219
+ # Source F0 median (if infer from file)
220
+ if infer_type == "from_file":
221
+ source_audio_name = cfg.inference.source_audio_name
222
+ source_f0_statistics_path = os.path.join(
223
+ cfg.preprocess.processed_dir,
224
+ source_audio_name,
225
+ cfg.preprocess.pitch_dir,
226
+ "statistics.json",
227
+ )
228
+ self.source_pitch_median = json.load(
229
+ open(source_f0_statistics_path, "r")
230
+ )[f"{source_audio_name}_{source_audio_name}"]["voiced_positions"][
231
+ "median"
232
+ ]
233
+ else:
234
+ self.source_pitch_median = None
235
+
236
+ if cfg.preprocess.use_frame_energy:
237
+ self.utt2frame_energy_path = {
238
+ f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
239
+ cfg.preprocess.processed_dir,
240
+ utt_info["Dataset"],
241
+ cfg.preprocess.energy_dir,
242
+ utt_info["Uid"] + ".npy",
243
+ )
244
+ for utt_info in self.metadata
245
+ }
246
+
247
+ if cfg.preprocess.use_mel:
248
+ self.utt2mel_path = {
249
+ f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
250
+ cfg.preprocess.processed_dir,
251
+ utt_info["Dataset"],
252
+ cfg.preprocess.mel_dir,
253
+ utt_info["Uid"] + ".npy",
254
+ )
255
+ for utt_info in self.metadata
256
+ }
257
+
258
+ ######### Load source content features' path #########
259
+ if cfg.model.condition_encoder.use_whisper:
260
+ self.whisper_aligner = WhisperExtractor(cfg)
261
+ self.utt2whisper_path = load_content_feature_path(
262
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
263
+ )
264
+
265
+ if cfg.model.condition_encoder.use_contentvec:
266
+ self.contentvec_aligner = ContentvecExtractor(cfg)
267
+ self.utt2contentVec_path = load_content_feature_path(
268
+ self.metadata,
269
+ cfg.preprocess.processed_dir,
270
+ cfg.preprocess.contentvec_dir,
271
+ )
272
+
273
+ if cfg.model.condition_encoder.use_mert:
274
+ self.utt2mert_path = load_content_feature_path(
275
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
276
+ )
277
+ if cfg.model.condition_encoder.use_wenet:
278
+ self.wenet_aligner = WenetExtractor(cfg)
279
+ self.utt2wenet_path = load_content_feature_path(
280
+ self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
281
+ )
282
+
283
+ def __getitem__(self, index):
284
+ single_feature = {}
285
+
286
+ utt_info = self.metadata[index]
287
+ dataset = utt_info["Dataset"]
288
+ uid = utt_info["Uid"]
289
+ utt = "{}_{}".format(dataset, uid)
290
+
291
+ source_dataset = self.metadata[index]["Dataset"]
292
+
293
+ if self.cfg.preprocess.use_spkid:
294
+ single_feature["spk_id"] = np.array(
295
+ [self.spk2id[f"{self.target_dataset}_{self.target_singer}"]],
296
+ dtype=np.int32,
297
+ )
298
+
299
+ ######### Get Acoustic Features Item #########
300
+ if self.cfg.preprocess.use_mel:
301
+ mel = np.load(self.utt2mel_path[utt])
302
+ assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
303
+ if self.cfg.preprocess.use_min_max_norm_mel:
304
+ # mel norm
305
+ mel = cal_normalized_mel(mel, source_dataset, self.cfg.preprocess)
306
+
307
+ if "target_len" not in single_feature.keys():
308
+ single_feature["target_len"] = mel.shape[1]
309
+ single_feature["mel"] = mel.T # [T, n_mels]
310
+
311
+ if self.cfg.preprocess.use_frame_pitch:
312
+ frame_pitch_path = self.utt2frame_pitch_path[utt]
313
+ frame_pitch = np.load(frame_pitch_path)
314
+
315
+ if self.trans_key:
316
+ try:
317
+ self.trans_key = int(self.trans_key)
318
+ except:
319
+ pass
320
+ if type(self.trans_key) == int:
321
+ frame_pitch = transpose_key(frame_pitch, self.trans_key)
322
+ elif self.trans_key:
323
+ assert self.target_singer
324
+
325
+ frame_pitch = pitch_shift_to_target(
326
+ frame_pitch, self.target_pitch_median, self.source_pitch_median
327
+ )
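+ # trans_key semantics: an integer transposes the source F0 by that many semitones;
+ # any other truthy value shifts the source F0 so its voiced median matches the
+ # target singer's voiced-F0 median.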
328
+
329
+ if "target_len" not in single_feature.keys():
330
+ single_feature["target_len"] = len(frame_pitch)
331
+ aligned_frame_pitch = align_length(
332
+ frame_pitch, single_feature["target_len"]
333
+ )
334
+ single_feature["frame_pitch"] = aligned_frame_pitch
335
+
336
+ if self.cfg.preprocess.use_uv:
337
+ frame_uv_path = self.utt2uv_path[utt]
338
+ frame_uv = np.load(frame_uv_path)
339
+ aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
340
+ aligned_frame_uv = [
341
+ 0 if frame_uv else 1 for frame_uv in aligned_frame_uv
342
+ ]
343
+ aligned_frame_uv = np.array(aligned_frame_uv)
344
+ single_feature["frame_uv"] = aligned_frame_uv
345
+
346
+ if self.cfg.preprocess.use_frame_energy:
347
+ frame_energy_path = self.utt2frame_energy_path[utt]
348
+ frame_energy = np.load(frame_energy_path)
349
+ if "target_len" not in single_feature.keys():
350
+ single_feature["target_len"] = len(frame_energy)
351
+ aligned_frame_energy = align_length(
352
+ frame_energy, single_feature["target_len"]
353
+ )
354
+ single_feature["frame_energy"] = aligned_frame_energy
355
+
356
+ ######### Get Content Features Item #########
357
+ if self.cfg.model.condition_encoder.use_whisper:
358
+ assert "target_len" in single_feature.keys()
359
+ aligned_whisper_feat = self.whisper_aligner.offline_align(
360
+ np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
361
+ )
362
+ single_feature["whisper_feat"] = aligned_whisper_feat
363
+
364
+ if self.cfg.model.condition_encoder.use_contentvec:
365
+ assert "target_len" in single_feature.keys()
366
+ aligned_contentvec = self.contentvec_aligner.offline_align(
367
+ np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
368
+ )
369
+ single_feature["contentvec_feat"] = aligned_contentvec
370
+
371
+ if self.cfg.model.condition_encoder.use_mert:
372
+ assert "target_len" in single_feature.keys()
373
+ aligned_mert_feat = align_content_feature_length(
374
+ np.load(self.utt2mert_path[utt]),
375
+ single_feature["target_len"],
376
+ source_hop=self.cfg.preprocess.mert_hop_size,
377
+ )
378
+ single_feature["mert_feat"] = aligned_mert_feat
379
+
380
+ if self.cfg.model.condition_encoder.use_wenet:
381
+ assert "target_len" in single_feature.keys()
382
+ aligned_wenet_feat = self.wenet_aligner.offline_align(
383
+ np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
384
+ )
385
+ single_feature["wenet_feat"] = aligned_wenet_feat
386
+
387
+ return single_feature
388
+
389
+ def __len__(self):
390
+ return len(self.metadata)
391
+
392
+
393
+ class SVCTestCollator:
394
+ """Zero-pads model inputs and targets based on number of frames per step"""
395
+
396
+ def __init__(self, cfg):
397
+ self.cfg = cfg
398
+
399
+ def __call__(self, batch):
400
+ packed_batch_features = dict()
401
+
402
+ # mel: [b, T, n_mels]
403
+ # frame_pitch, frame_energy: [1, T]
404
+ # target_len: [1]
405
+ # spk_id: [b, 1]
406
+ # mask: [b, T, 1]
407
+
408
+ for key in batch[0].keys():
409
+ if key == "target_len":
410
+ packed_batch_features["target_len"] = torch.LongTensor(
411
+ [b["target_len"] for b in batch]
412
+ )
413
+ masks = [
414
+ torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
415
+ ]
416
+ packed_batch_features["mask"] = pad_sequence(
417
+ masks, batch_first=True, padding_value=0
418
+ )
419
+ else:
420
+ values = [torch.from_numpy(b[key]) for b in batch]
421
+ packed_batch_features[key] = pad_sequence(
422
+ values, batch_first=True, padding_value=0
423
+ )
424
+
425
+ return packed_batch_features
models/svc/base/svc_inference.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ from models.base.new_inference import BaseInference
7
+ from models.svc.base.svc_dataset import SVCTestCollator, SVCTestDataset
8
+
9
+
10
+ class SVCInference(BaseInference):
11
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
12
+ BaseInference.__init__(self, args, cfg, infer_type)
13
+
14
+ def _build_test_dataset(self):
15
+ return SVCTestDataset, SVCTestCollator
models/svc/base/svc_trainer.py ADDED
@@ -0,0 +1,111 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import json
7
+ import os
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from models.base.new_trainer import BaseTrainer
13
+ from models.svc.base.svc_dataset import SVCCollator, SVCDataset
14
+
15
+
16
+ class SVCTrainer(BaseTrainer):
17
+ r"""The base trainer for all SVC models. It inherits from BaseTrainer and implements
18
+ ``_build_criterion``, ``_build_dataset``, and ``_build_singer_lut`` methods. You can inherit from this
19
+ class, and implement ``_build_model``, ``_forward_step``.
20
+ """
21
+
22
+ def __init__(self, args=None, cfg=None):
23
+ self.args = args
24
+ self.cfg = cfg
25
+
26
+ self._init_accelerator()
27
+
28
+ # Only for SVC tasks
29
+ with self.accelerator.main_process_first():
30
+ self.singers = self._build_singer_lut()
31
+
32
+ # Super init
33
+ BaseTrainer.__init__(self, args, cfg)
34
+
35
+ # Only for SVC tasks
36
+ self.task_type = "SVC"
37
+ self.logger.info("Task type: {}".format(self.task_type))
38
+
39
+ ### Following are methods only for SVC tasks ###
40
+ # TODO: LEGACY CODE, NEED TO BE REFACTORED
41
+ def _build_dataset(self):
42
+ return SVCDataset, SVCCollator
43
+
44
+ @staticmethod
45
+ def _build_criterion():
46
+ criterion = nn.MSELoss(reduction="none")
47
+ return criterion
48
+
49
+ @staticmethod
50
+ def _compute_loss(criterion, y_pred, y_gt, loss_mask):
51
+ """
52
+ Args:
53
+ criterion: MSELoss(reduction='none')
54
+ y_pred, y_gt: (bs, seq_len, D)
55
+ loss_mask: (bs, seq_len, 1)
56
+ Returns:
57
+ loss: Tensor of shape []
58
+ """
59
+
60
+ # (bs, seq_len, D)
61
+ loss = criterion(y_pred, y_gt)
62
+ # expand loss_mask to (bs, seq_len, D)
63
+ loss_mask = loss_mask.repeat(1, 1, loss.shape[-1])
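+ # Masked mean: sum the per-element MSE over valid (non-padded) positions only and
+ # normalize by the number of valid elements, not by bs * seq_len * D.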
64
+
65
+ loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask)
66
+ return loss
67
+
68
+ def _save_auxiliary_states(self):
69
+ """
70
+ To save the singer's look-up table in the checkpoint saving path
71
+ """
72
+ with open(
73
+ os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id), "w"
74
+ ) as f:
75
+ json.dump(self.singers, f, indent=4, ensure_ascii=False)
76
+
77
+ def _build_singer_lut(self):
78
+ resumed_singer_path = None
79
+ if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "":
80
+ resumed_singer_path = os.path.join(
81
+ self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id
82
+ )
83
+ if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
84
+ resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
85
+
86
+ if resumed_singer_path:
87
+ with open(resumed_singer_path, "r") as f:
88
+ singers = json.load(f)
89
+ else:
90
+ singers = dict()
91
+
92
+ for dataset in self.cfg.dataset:
93
+ singer_lut_path = os.path.join(
94
+ self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
95
+ )
96
+ with open(singer_lut_path, "r") as singer_lut_file:
97
+ singer_lut = json.load(singer_lut_file)
98
+ for singer in singer_lut.keys():
99
+ if singer not in singers:
100
+ singers[singer] = len(singers)
101
+
102
+ with open(
103
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
104
+ ) as singer_file:
105
+ json.dump(singers, singer_file, indent=4, ensure_ascii=False)
106
+ print(
107
+ "singers have been dumped to {}".format(
108
+ os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
109
+ )
110
+ )
111
+ return singers
models/svc/comosvc/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
models/svc/comosvc/comosvc.py ADDED
@@ -0,0 +1,377 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ """Adapted from https://github.com/zhenye234/CoMoSpeech"""
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import copy
11
+ import numpy as np
12
+ import math
13
+ from tqdm.auto import tqdm
14
+
15
+ from utils.ssim import SSIM
16
+
17
+ from models.svc.transformer.conformer import Conformer, BaseModule
18
+ from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
19
+ from models.svc.comosvc.utils import slice_segments, rand_ids_segments
20
+
21
+
22
+ class Consistency(nn.Module):
23
+ def __init__(self, cfg, distill=False):
24
+ super().__init__()
25
+ self.cfg = cfg
26
+ # self.denoise_fn = GradLogPEstimator2d(96)
27
+ self.denoise_fn = DiffusionWrapper(self.cfg)
28
+ self.cfg = cfg.model.comosvc
29
+ self.teacher = not distill
30
+ self.P_mean = self.cfg.P_mean
31
+ self.P_std = self.cfg.P_std
32
+ self.sigma_data = self.cfg.sigma_data
33
+ self.sigma_min = self.cfg.sigma_min
34
+ self.sigma_max = self.cfg.sigma_max
35
+ self.rho = self.cfg.rho
36
+ self.N = self.cfg.n_timesteps
37
+ self.ssim_loss = SSIM()
38
+
39
+ # Time step discretization
40
+ step_indices = torch.arange(self.N)
41
+ # karras boundaries formula
42
+ t_steps = (
43
+ self.sigma_min ** (1 / self.rho)
44
+ + step_indices
45
+ / (self.N - 1)
46
+ * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
47
+ ) ** self.rho
48
+ self.t_steps = torch.cat(
49
+ [torch.zeros_like(t_steps[:1]), self.round_sigma(t_steps)]
50
+ )
51
+
52
+ def init_consistency_training(self):
53
+ self.denoise_fn_ema = copy.deepcopy(self.denoise_fn)
54
+ self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn)
55
+
56
+ def EDMPrecond(self, x, sigma, cond, denoise_fn, mask, spk=None):
57
+ """
58
+ karras diffusion reverse process
59
+
60
+ Args:
61
+ x: noisy mel-spectrogram [B x n_mel x L]
62
+ sigma: noise level [B x 1 x 1]
63
+ cond: output of conformer encoder [B x n_mel x L]
64
+ denoise_fn: denoiser neural network e.g. DilatedCNN
65
+ mask: mask of padded frames [B x n_mel x L]
66
+
67
+ Returns:
68
+ denoised mel-spectrogram [B x n_mel x L]
69
+ """
70
+ sigma = sigma.reshape(-1, 1, 1)
71
+
72
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
73
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
74
+ c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
75
+ c_noise = sigma.log() / 4
76
+
77
+ x_in = c_in * x
78
+ x_in = x_in.transpose(1, 2)
79
+ x = x.transpose(1, 2)
80
+ cond = cond.transpose(1, 2)
81
+ F_x = denoise_fn(x_in, c_noise.squeeze(), cond)
82
+ # F_x = denoise_fn((c_in * x), mask, cond, c_noise.flatten())
83
+ D_x = c_skip * x + c_out * (F_x)
84
+ D_x = D_x.transpose(1, 2)
85
+ return D_x
86
+
87
+ def EDMLoss(self, x_start, cond, mask):
88
+ """
89
+ compute loss for EDM model
90
+
91
+ Args:
92
+ x_start: ground truth mel-spectrogram [B x n_mel x L]
93
+ cond: output of conformer encoder [B x n_mel x L]
94
+ mask: mask of padded frames [B x n_mel x L]
95
+ """
96
+ rnd_normal = torch.randn([x_start.shape[0], 1, 1], device=x_start.device)
97
+ sigma = (rnd_normal * self.P_std + self.P_mean).exp()
98
+ weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
99
+
100
+ # follow Grad-TTS, start from Gaussian noise with mean cond and std I
101
+ noise = (torch.randn_like(x_start) + cond) * sigma
102
+ D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn, mask)
103
+ loss = weight * ((D_yn - x_start) ** 2)
104
+ loss = torch.sum(loss * mask) / torch.sum(mask)
105
+ return loss
106
+
107
+ def round_sigma(self, sigma):
108
+ return torch.as_tensor(sigma)
109
+
110
+ def edm_sampler(
111
+ self,
112
+ latents,
113
+ cond,
114
+ nonpadding,
115
+ num_steps=50,
116
+ sigma_min=0.002,
117
+ sigma_max=80,
118
+ rho=7,
119
+ S_churn=0,
120
+ S_min=0,
121
+ S_max=float("inf"),
122
+ S_noise=1,
123
+ # S_churn=40 ,S_min=0.05,S_max=50,S_noise=1.003,# S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
124
+ # S_churn=30 ,S_min=0.01,S_max=30,S_noise=1.007,
125
+ # S_churn=30 ,S_min=0.01,S_max=1,S_noise=1.007,
126
+ # S_churn=80 ,S_min=0.05,S_max=50,S_noise=1.003,
127
+ ):
128
+ """
129
+ karras diffusion sampler
130
+
131
+ Args:
132
+ latents: noisy mel-spectrogram [B x n_mel x L]
133
+ cond: output of conformer encoder [B x n_mel x L]
134
+ nonpadding: mask of padded frames [B x n_mel x L]
135
+ num_steps: number of steps for diffusion inference
136
+
137
+ Returns:
138
+ denoised mel-spectrogram [B x n_mel x L]
139
+ """
140
+ # Time step discretization.
141
+ step_indices = torch.arange(num_steps, device=latents.device)
142
+
143
+ num_steps = num_steps + 1
144
+ t_steps = (
145
+ sigma_max ** (1 / rho)
146
+ + step_indices
147
+ / (num_steps - 1)
148
+ * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
149
+ ) ** rho
150
+ t_steps = torch.cat([self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])
151
+
152
+ # Main sampling loop.
153
+ x_next = latents * t_steps[0]
154
+ # wrap in tqdm for progress bar
155
+ bar = tqdm(enumerate(zip(t_steps[:-1], t_steps[1:])))
156
+ for i, (t_cur, t_next) in bar:
157
+ x_cur = x_next
158
+ # Increase noise temporarily.
159
+ gamma = (
160
+ min(S_churn / num_steps, np.sqrt(2) - 1)
161
+ if S_min <= t_cur <= S_max
162
+ else 0
163
+ )
164
+ t_hat = self.round_sigma(t_cur + gamma * t_cur)
165
+ t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device)
166
+ t[:, 0, 0] = t_hat
167
+ t_hat = t
168
+ x_hat = x_cur + (
169
+ t_hat**2 - t_cur**2
170
+ ).sqrt() * S_noise * torch.randn_like(x_cur)
171
+ # Euler step.
172
+ denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn, nonpadding)
173
+ d_cur = (x_hat - denoised) / t_hat
174
+ x_next = x_hat + (t_next - t_hat) * d_cur
175
+
176
+ return x_next
177
+
178
+ def CTLoss_D(self, y, cond, mask):
179
+ """
180
+ compute loss for consistency distillation
181
+
182
+ Args:
183
+ y: ground truth mel-spectrogram [B x n_mel x L]
184
+ cond: output of conformer encoder [B x n_mel x L]
185
+ mask: mask of padded frames [B x n_mel x L]
186
+ """
187
+ with torch.no_grad():
188
+ mu = 0.95
189
+ for p, ema_p in zip(
190
+ self.denoise_fn.parameters(), self.denoise_fn_ema.parameters()
191
+ ):
192
+ ema_p.mul_(mu).add_(p, alpha=1 - mu)
193
+
194
+ n = torch.randint(1, self.N, (y.shape[0],))
195
+ z = torch.randn_like(y) + cond
196
+
197
+ tn_1 = self.t_steps[n + 1].reshape(-1, 1, 1).to(y.device)
198
+ f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn, mask)
199
+
200
+ with torch.no_grad():
201
+ tn = self.t_steps[n].reshape(-1, 1, 1).to(y.device)
202
+
203
+ # euler step
204
+ x_hat = y + tn_1 * z
205
+ denoised = self.EDMPrecond(
206
+ x_hat, tn_1, cond, self.denoise_fn_pretrained, mask
207
+ )
208
+ d_cur = (x_hat - denoised) / tn_1
209
+ y_tn = x_hat + (tn - tn_1) * d_cur
210
+
211
+ f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema, mask)
212
+
213
+ # loss = (f_theta - f_theta_ema.detach()) ** 2
214
+ # loss = torch.sum(loss * mask) / torch.sum(mask)
215
+ loss = self.ssim_loss(f_theta, f_theta_ema.detach())
216
+ loss = torch.sum(loss * mask) / torch.sum(mask)
217
+
218
+ return loss
219
+
220
+ def get_t_steps(self, N):
221
+ N = N + 1
222
+ step_indices = torch.arange(N) # , device=latents.device)
223
+ t_steps = (
224
+ self.sigma_min ** (1 / self.rho)
225
+ + step_indices
226
+ / (N - 1)
227
+ * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
228
+ ) ** self.rho
229
+
230
+ return t_steps.flip(0)
231
+
232
+ def CT_sampler(self, latents, cond, nonpadding, t_steps=1):
233
+ """
234
+ consistency distillation sampler
235
+
236
+ Args:
237
+ latents: noisy mel-spectrogram [B x n_mel x L]
238
+ cond: output of conformer encoder [B x n_mel x L]
239
+ nonpadding: mask of padded frames [B x n_mel x L]
240
+ t_steps: number of steps for diffusion inference
241
+
242
+ Returns:
243
+ denoised mel-spectrogram [B x n_mel x L]
244
+ """
245
+ # one-step
246
+ if t_steps == 1:
247
+ t_steps = [80]
248
+ # multi-step
249
+ else:
250
+ t_steps = self.get_t_steps(t_steps)
251
+
252
+ t_steps = torch.as_tensor(t_steps).to(latents.device)
253
+ latents = latents * t_steps[0]
254
+ _t = torch.zeros((latents.shape[0], 1, 1), device=latents.device)
255
+ _t[:, 0, 0] = t_steps
256
+ x = self.EDMPrecond(latents, _t, cond, self.denoise_fn_ema, nonpadding)
257
+
258
+ for t in t_steps[1:-1]:
259
+ z = torch.randn_like(x) + cond
260
+ x_tn = x + (t**2 - self.sigma_min**2).sqrt() * z
261
+ _t = torch.zeros((x.shape[0], 1, 1), device=x.device)
262
+ _t[:, 0, 0] = t
263
+ t = _t
264
+ print(t)
265
+ x = self.EDMPrecond(x_tn, t, cond, self.denoise_fn_ema, nonpadding)
266
+ return x
267
+
268
+ def forward(self, x, nonpadding, cond, t_steps=1, infer=False):
269
+ """
270
+ calculate loss or sample mel-spectrogram
271
+
272
+ Args:
273
+ x:
274
+ training: ground truth mel-spectrogram [B x n_mel x L]
275
+ inference: output of encoder [B x n_mel x L]
276
+ """
277
+ if self.teacher: # teacher model -- karras diffusion
278
+ if not infer:
279
+ loss = self.EDMLoss(x, cond, nonpadding)
280
+ return loss
281
+ else:
282
+ shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
283
+ x = torch.randn(shape, device=x.device) + cond
284
+ x = self.edm_sampler(x, cond, nonpadding, t_steps)
285
+
286
+ return x
287
+ else: # Consistency distillation
288
+ if not infer:
289
+ loss = self.CTLoss_D(x, cond, nonpadding)
290
+ return loss
291
+
292
+ else:
293
+ shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
294
+ x = torch.randn(shape, device=x.device) + cond
295
+ x = self.CT_sampler(x, cond, nonpadding, t_steps=1)
296
+
297
+ return x
298
+
299
+
300
+ class ComoSVC(BaseModule):
301
+ def __init__(self, cfg):
302
+ super().__init__()
303
+ self.cfg = cfg
304
+ self.cfg.model.comosvc.n_mel = self.cfg.preprocess.n_mel
305
+ self.distill = self.cfg.model.comosvc.distill
306
+ self.encoder = Conformer(self.cfg.model.comosvc)
307
+ self.decoder = Consistency(self.cfg, distill=self.distill)
308
+ self.ssim_loss = SSIM()
309
+
310
+ @torch.no_grad()
311
+ def forward(self, x_mask, x, n_timesteps, temperature=1.0):
312
+ """
313
+ Generates mel-spectrogram from pitch, content vector, energy. Returns:
314
+ 1. encoder outputs (from conformer)
315
+ 2. decoder outputs (from diffusion-based decoder)
316
+
317
+ Args:
318
+ x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
319
+ x : output of encoder framework. [B x L x d_condition]
320
+ n_timesteps : number of steps to use for reverse diffusion in decoder.
321
+ temperature : controls variance of terminal distribution.
322
+ """
323
+
324
+ # Get encoder_outputs `mu_x`
325
+ mu_x = self.encoder(x, x_mask)
326
+ encoder_outputs = mu_x
327
+
328
+ mu_x = mu_x.transpose(1, 2)
329
+ x_mask = x_mask.transpose(1, 2)
330
+
331
+ # Generate sample by performing reverse dynamics
332
+ decoder_outputs = self.decoder(
333
+ mu_x, x_mask, mu_x, t_steps=n_timesteps, infer=True
334
+ )
335
+ decoder_outputs = decoder_outputs.transpose(1, 2)
336
+ return encoder_outputs, decoder_outputs
337
+
338
+ def compute_loss(self, x_mask, x, mel, out_size=None, skip_diff=False):
339
+ """
340
+ Computes 2 losses:
341
+ 1. prior loss: loss between mel-spectrogram and encoder outputs.
342
+ 2. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
343
+
344
+ Args:
345
+ x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
346
+ x : output of encoder framework. [B x L x d_condition]
347
+ mel : ground truth mel-spectrogram. [B x L x n_mel]
348
+ """
349
+
350
+ mu_x = self.encoder(x, x_mask)
351
+ # prior loss
352
+ prior_loss = torch.sum(
353
+ 0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * x_mask
354
+ )
355
+ prior_loss = prior_loss / (torch.sum(x_mask) * self.cfg.model.comosvc.n_mel)
356
+ # ssim loss
357
+ ssim_loss = self.ssim_loss(mu_x, mel)
358
+ ssim_loss = torch.sum(ssim_loss * x_mask) / torch.sum(x_mask)
359
+
360
+ x_mask = x_mask.transpose(1, 2)
361
+ mu_x = mu_x.transpose(1, 2)
362
+ mel = mel.transpose(1, 2)
363
+ if not self.distill and skip_diff:
364
+ diff_loss = prior_loss.clone()
365
+ diff_loss.fill_(0)
366
+
367
+ # Cut a small segment of mel-spectrogram in order to increase batch size
368
+ else:
369
+ if self.distill:
370
+ mu_y = mu_x.detach()
371
+ else:
372
+ mu_y = mu_x
373
+ mask_y = x_mask
374
+
375
+ diff_loss = self.decoder(mel, mask_y, mu_y, infer=False)
376
+
377
+ return ssim_loss, prior_loss, diff_loss
models/svc/comosvc/comosvc_inference.py ADDED
@@ -0,0 +1,39 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from models.svc.base import SVCInference
9
+ from modules.encoder.condition_encoder import ConditionEncoder
10
+ from models.svc.comosvc.comosvc import ComoSVC
11
+
12
+
13
+ class ComoSVCInference(SVCInference):
14
+ def __init__(self, args, cfg, infer_type="from_dataset"):
15
+ SVCInference.__init__(self, args, cfg, infer_type)
16
+
17
+ def _build_model(self):
18
+ # TODO: sort out the config
19
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
20
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
21
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
22
+ self.acoustic_mapper = ComoSVC(self.cfg)
23
+ if self.cfg.model.comosvc.distill:
24
+ self.acoustic_mapper.decoder.init_consistency_training()
25
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
26
+ return model
27
+
28
+ def _inference_each_batch(self, batch_data):
29
+ device = self.accelerator.device
30
+ for k, v in batch_data.items():
31
+ batch_data[k] = v.to(device)
32
+
33
+ cond = self.condition_encoder(batch_data)
34
+ mask = batch_data["mask"]
35
+ encoder_pred, decoder_pred = self.acoustic_mapper(
36
+ mask, cond, self.cfg.inference.comosvc.inference_steps
37
+ )
38
+
39
+ return decoder_pred
models/svc/comosvc/comosvc_trainer.py ADDED
@@ -0,0 +1,295 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import os
8
+ import json5
9
+ from collections import OrderedDict
10
+ from tqdm import tqdm
11
+ import json
12
+ import shutil
13
+
14
+ from models.svc.base import SVCTrainer
15
+ from modules.encoder.condition_encoder import ConditionEncoder
16
+ from models.svc.comosvc.comosvc import ComoSVC
17
+
18
+
19
+ class ComoSVCTrainer(SVCTrainer):
20
+ r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
21
+ implements ``_build_model`` and ``_forward_step`` methods.
22
+ """
23
+
24
+ def __init__(self, args=None, cfg=None):
25
+ SVCTrainer.__init__(self, args, cfg)
26
+ self.distill = cfg.model.comosvc.distill
27
+ self.skip_diff = True
28
+ if self.distill: # and args.resume is None:
29
+ self.teacher_model_path = cfg.model.teacher_model_path
30
+ self.teacher_state_dict = self._load_teacher_state_dict()
31
+ self._load_teacher_model(self.teacher_state_dict)
32
+ self.acoustic_mapper.decoder.init_consistency_training()
33
+
34
+ ### Following are methods only for comoSVC models ###
35
+ def _load_teacher_state_dict(self):
36
+ self.checkpoint_file = self.teacher_model_path
37
+ print("Load teacher acoustic model from {}".format(self.checkpoint_file))
38
+ raw_state_dict = torch.load(self.checkpoint_file) # , map_location=self.device)
39
+ return raw_state_dict
40
+
41
+ def _load_teacher_model(self, state_dict):
42
+ raw_dict = state_dict
43
+ clean_dict = OrderedDict()
44
+ for k, v in raw_dict.items():
45
+ if k.startswith("module."):
46
+ clean_dict[k[7:]] = v
47
+ else:
48
+ clean_dict[k] = v
49
+ self.model.load_state_dict(clean_dict)
50
+
51
+ def _build_model(self):
52
+ r"""Build the model for training. This function is called in ``__init__`` function."""
53
+
54
+ # TODO: sort out the config
55
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
56
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
57
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
58
+ self.acoustic_mapper = ComoSVC(self.cfg)
59
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
60
+ return model
61
+
62
+ def _forward_step(self, batch):
63
+ r"""Forward step for training and inference. This function is called
64
+ in ``_train_step`` & ``_test_step`` function.
65
+ """
66
+ loss = {}
67
+ mask = batch["mask"]
68
+ mel_input = batch["mel"]
69
+ cond = self.condition_encoder(batch)
70
+ if self.distill:
71
+ cond = cond.detach()
72
+ self.skip_diff = self.step < self.cfg.train.fast_steps
73
+ ssim_loss, prior_loss, diff_loss = self.acoustic_mapper.compute_loss(
74
+ mask, cond, mel_input, skip_diff=self.skip_diff
75
+ )
76
+ if self.distill:
77
+ loss["distil_loss"] = diff_loss
78
+ else:
79
+ loss["ssim_loss_encoder"] = ssim_loss
80
+ loss["prior_loss_encoder"] = prior_loss
81
+ loss["diffusion_loss_decoder"] = diff_loss
82
+
83
+ return loss
84
+
85
+ def _train_epoch(self):
86
+ r"""Training epoch. Should return average loss of a batch (sample) over
87
+ one epoch. See ``train_loop`` for usage.
88
+ """
89
+ self.model.train()
90
+ epoch_sum_loss: float = 0.0
91
+ epoch_step: int = 0
92
+ for batch in tqdm(
93
+ self.train_dataloader,
94
+ desc=f"Training Epoch {self.epoch}",
95
+ unit="batch",
96
+ colour="GREEN",
97
+ leave=False,
98
+ dynamic_ncols=True,
99
+ smoothing=0.04,
100
+ disable=not self.accelerator.is_main_process,
101
+ ):
102
+ # Do training step and BP
103
+ with self.accelerator.accumulate(self.model):
104
+ loss = self._train_step(batch)
105
+ total_loss = 0
106
+ for k, v in loss.items():
107
+ total_loss += v
108
+ self.accelerator.backward(total_loss)
109
+ enc_grad_norm = torch.nn.utils.clip_grad_norm_(
110
+ self.acoustic_mapper.encoder.parameters(), max_norm=1
111
+ )
112
+ dec_grad_norm = torch.nn.utils.clip_grad_norm_(
113
+ self.acoustic_mapper.decoder.parameters(), max_norm=1
114
+ )
115
+ self.optimizer.step()
116
+ self.optimizer.zero_grad()
117
+ self.batch_count += 1
118
+
119
+ # Update info for each step
120
+ # TODO: should `step` count optimizer updates (BP steps) or raw batches?
121
+ if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
122
+ epoch_sum_loss += total_loss
123
+ log_info = {}
124
+ for k, v in loss.items():
125
+ key = "Step/Train Loss/{}".format(k)
126
+ log_info[key] = v
127
+ log_info["Step/Learning Rate"]: self.optimizer.param_groups[0]["lr"]
128
+ self.accelerator.log(
129
+ log_info,
130
+ step=self.step,
131
+ )
132
+ self.step += 1
133
+ epoch_step += 1
134
+
135
+ self.accelerator.wait_for_everyone()
136
+ return (
137
+ epoch_sum_loss
138
+ / len(self.train_dataloader)
139
+ * self.cfg.train.gradient_accumulation_step,
140
+ loss,
141
+ )
142
+
143
+ def train_loop(self):
144
+ r"""Training loop. The public entry of training process."""
145
+ # Wait everyone to prepare before we move on
146
+ self.accelerator.wait_for_everyone()
147
+ # dump config file
148
+ if self.accelerator.is_main_process:
149
+ self.__dump_cfg(self.config_save_path)
150
+ self.model.train()
151
+ self.optimizer.zero_grad()
152
+ # Wait to ensure good to go
153
+ self.accelerator.wait_for_everyone()
154
+ while self.epoch < self.max_epoch:
155
+ self.logger.info("\n")
156
+ self.logger.info("-" * 32)
157
+ self.logger.info("Epoch {}: ".format(self.epoch))
158
+
159
+ ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
160
+ ### The current return format is inconvenient for models with multiple losses
161
+ # Do training & validating epoch
162
+ train_loss, loss = self._train_epoch()
163
+ self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
164
+ for k, v in loss.items():
165
+ self.logger.info(" |- Train/Loss/{}: {:.6f}".format(k, v))
166
+ valid_loss = self._valid_epoch()
167
+ self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
168
+ self.accelerator.log(
169
+ {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
170
+ step=self.epoch,
171
+ )
172
+
173
+ self.accelerator.wait_for_everyone()
174
+ # TODO: what is scheduler?
175
+ self.scheduler.step(valid_loss)  # FIXME: is stepping once per epoch on valid_loss correct?
176
+
177
+ # Check if hit save_checkpoint_stride and run_eval
178
+ run_eval = False
179
+ if self.accelerator.is_main_process:
180
+ save_checkpoint = False
181
+ hit_dix = []
182
+ for i, num in enumerate(self.save_checkpoint_stride):
183
+ if self.epoch % num == 0:
184
+ save_checkpoint = True
185
+ hit_dix.append(i)
186
+ run_eval |= self.run_eval[i]
187
+
188
+ self.accelerator.wait_for_everyone()
189
+ if (
190
+ self.accelerator.is_main_process
191
+ and save_checkpoint
192
+ and (self.distill or not self.skip_diff)
193
+ ):
194
+ path = os.path.join(
195
+ self.checkpoint_dir,
196
+ "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
197
+ self.epoch, self.step, train_loss
198
+ ),
199
+ )
200
+ self.accelerator.save_state(path)
201
+ json.dump(
202
+ self.checkpoints_path,
203
+ open(os.path.join(path, "ckpts.json"), "w"),
204
+ ensure_ascii=False,
205
+ indent=4,
206
+ )
207
+
208
+ # Remove old checkpoints
209
+ to_remove = []
210
+ for idx in hit_dix:
211
+ self.checkpoints_path[idx].append(path)
212
+ while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
213
+ to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
214
+
215
+ # Search conflicts
216
+ total = set()
217
+ for i in self.checkpoints_path:
218
+ total |= set(i)
219
+ do_remove = set()
220
+ for idx, path in to_remove[::-1]:
221
+ if path in total:
222
+ self.checkpoints_path[idx].insert(0, path)
223
+ else:
224
+ do_remove.add(path)
225
+
226
+ # Remove old checkpoints
227
+ for path in do_remove:
228
+ shutil.rmtree(path, ignore_errors=True)
229
+ self.logger.debug(f"Remove old checkpoint: {path}")
230
+
231
+ self.accelerator.wait_for_everyone()
232
+ if run_eval:
233
+ # TODO: run evaluation
234
+ pass
235
+
236
+ # Update info for each epoch
237
+ self.epoch += 1
238
+
239
+ # Finish training and save final checkpoint
240
+ self.accelerator.wait_for_everyone()
241
+ if self.accelerator.is_main_process:
242
+ self.accelerator.save_state(
243
+ os.path.join(
244
+ self.checkpoint_dir,
245
+ "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
246
+ self.epoch, self.step, valid_loss
247
+ ),
248
+ )
249
+ )
250
+ self.accelerator.end_training()
251
+
252
+ @torch.inference_mode()
253
+ def _valid_epoch(self):
254
+ r"""Testing epoch. Should return average loss of a batch (sample) over
255
+ one epoch. See ``train_loop`` for usage.
256
+ """
257
+ self.model.eval()
258
+ epoch_sum_loss = 0.0
259
+ for batch in tqdm(
260
+ self.valid_dataloader,
261
+ desc=f"Validating Epoch {self.epoch}",
262
+ unit="batch",
263
+ colour="GREEN",
264
+ leave=False,
265
+ dynamic_ncols=True,
266
+ smoothing=0.04,
267
+ disable=not self.accelerator.is_main_process,
268
+ ):
269
+ batch_loss = self._valid_step(batch)
270
+ for k, v in batch_loss.items():
271
+ epoch_sum_loss += v
272
+
273
+ self.accelerator.wait_for_everyone()
274
+ return epoch_sum_loss / len(self.valid_dataloader)
275
+
276
+ @staticmethod
277
+ def __count_parameters(model):
278
+ model_param = 0.0
279
+ if isinstance(model, dict):
280
+ for key, value in model.items():
281
+ model_param += sum(p.numel() for p in model[key].parameters())
282
+ else:
283
+ model_param = sum(p.numel() for p in model.parameters())
284
+ return model_param
285
+
286
+ def __dump_cfg(self, path):
287
+ os.makedirs(os.path.dirname(path), exist_ok=True)
288
+ json5.dump(
289
+ self.cfg,
290
+ open(path, "w"),
291
+ indent=4,
292
+ sort_keys=True,
293
+ ensure_ascii=False,
294
+ quote_keys=True,
295
+ )
models/svc/comosvc/utils.py ADDED
@@ -0,0 +1,31 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+
9
+ def slice_segments(x, ids_str, segment_size=200):
10
+ ret = torch.zeros_like(x[:, :, :segment_size])
11
+ for i in range(x.size(0)):
12
+ idx_str = ids_str[i]
13
+ idx_end = idx_str + segment_size
14
+ ret[i] = x[i, :, idx_str:idx_end]
15
+ return ret
16
+
17
+
18
+ def rand_ids_segments(lengths, segment_size=200):
19
+ b = lengths.shape[0]
20
+ ids_str_max = lengths - segment_size
21
+ ids_str = (torch.rand([b]).to(device=lengths.device) * ids_str_max).to(
22
+ dtype=torch.long
23
+ )
24
+ return ids_str
25
+
26
+
27
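+ # Round length up to the nearest multiple of 2**num_downsamplings_in_unet so the U-Net's downsampling path divides it evenly.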
+ def fix_len_compatibility(length, num_downsamplings_in_unet=2):
28
+ while True:
29
+ if length % (2**num_downsamplings_in_unet) == 0:
30
+ return length
31
+ length += 1
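A minimal usage sketch for the two segment helpers above (tensor sizes are illustrative; a channel-first mel batch of shape [N, n_mel, T] is assumed):

import torch
from models.svc.comosvc.utils import rand_ids_segments, slice_segments

mel = torch.randn(4, 80, 300)                 # [N, n_mel, T], channel-first
lengths = torch.tensor([300, 280, 260, 240])  # valid frames per sample

ids = rand_ids_segments(lengths, segment_size=200)     # one random start frame per sample
segments = slice_segments(mel, ids, segment_size=200)  # [4, 80, 200]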
models/svc/diffusion/__init__.py ADDED
File without changes
models/svc/diffusion/diffusion_inference.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
8
+
9
+ from models.svc.base import SVCInference
10
+ from models.svc.diffusion.diffusion_inference_pipeline import DiffusionInferencePipeline
11
+ from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
12
+ from modules.encoder.condition_encoder import ConditionEncoder
13
+
14
+
15
+ class DiffusionInference(SVCInference):
16
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
17
+ SVCInference.__init__(self, args, cfg, infer_type)
18
+
19
+ settings = {
20
+ **cfg.model.diffusion.scheduler_settings,
21
+ **cfg.inference.diffusion.scheduler_settings,
22
+ }
23
+ settings.pop("num_inference_timesteps")
24
+
25
+ if cfg.inference.diffusion.scheduler.lower() == "ddpm":
26
+ self.scheduler = DDPMScheduler(**settings)
27
+ self.logger.info("Using DDPM scheduler.")
28
+ elif cfg.inference.diffusion.scheduler.lower() == "ddim":
29
+ self.scheduler = DDIMScheduler(**settings)
30
+ self.logger.info("Using DDIM scheduler.")
31
+ elif cfg.inference.diffusion.scheduler.lower() == "pndm":
32
+ self.scheduler = PNDMScheduler(**settings)
33
+ self.logger.info("Using PNDM scheduler.")
34
+ else:
35
+ raise NotImplementedError(
36
+ "Unsupported scheduler type: {}".format(
37
+ cfg.inference.diffusion.scheduler.lower()
38
+ )
39
+ )
40
+
41
+ self.pipeline = DiffusionInferencePipeline(
42
+ self.model[1],
43
+ self.scheduler,
44
+ args.diffusion_inference_steps,
45
+ )
46
+
47
+ def _build_model(self):
48
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
49
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
50
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
51
+ self.acoustic_mapper = DiffusionWrapper(self.cfg)
52
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
53
+ return model
54
+
55
+ def _inference_each_batch(self, batch_data):
56
+ device = self.accelerator.device
57
+ for k, v in batch_data.items():
58
+ batch_data[k] = v.to(device)
59
+
60
+ conditioner = self.model[0](batch_data)
61
+ noise = torch.randn_like(batch_data["mel"], device=device)
62
+ y_pred = self.pipeline(noise, conditioner)
63
+ return y_pred
models/svc/diffusion/diffusion_inference_pipeline.py ADDED
@@ -0,0 +1,47 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from diffusers import DiffusionPipeline
8
+
9
+
10
+ class DiffusionInferencePipeline(DiffusionPipeline):
11
+ def __init__(self, network, scheduler, num_inference_timesteps=1000):
12
+ super().__init__()
13
+
14
+ self.register_modules(network=network, scheduler=scheduler)
15
+ self.num_inference_timesteps = num_inference_timesteps
16
+
17
+ @torch.inference_mode()
18
+ def __call__(
19
+ self,
20
+ initial_noise: torch.Tensor,
21
+ conditioner: torch.Tensor = None,
22
+ ):
23
+ r"""
24
+ Args:
25
+ initial_noise: The initial noise to be denoised.
26
+ conditioner: The conditioning tensor for the denoiser.
+ The number of denoising steps is set by ``num_inference_timesteps`` in
+ ``__init__``; more steps usually lead to higher quality at the expense
+ of slower inference.
29
+ """
30
+
31
+ mel = initial_noise
32
+ batch_size = mel.size(0)
33
+ self.scheduler.set_timesteps(self.num_inference_timesteps)
34
+
35
+ for t in self.progress_bar(self.scheduler.timesteps):
36
+ timestep = torch.full((batch_size,), t, device=mel.device, dtype=torch.long)
37
+
38
+ # 1. predict noise model_output
39
+ model_output = self.network(mel, timestep, conditioner)
40
+
41
+ # 2. denoise, compute previous step: x_t -> x_t-1
42
+ mel = self.scheduler.step(model_output, t, mel).prev_sample
43
+
44
+ # 3. clamp
45
+ mel = mel.clamp(-1.0, 1.0)
46
+
47
+ return mel
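A smoke-test sketch for the pipeline above, using a stand-in denoiser with the same (mel, timestep, conditioner) call signature as DiffusionWrapper; the scheduler settings and tensor shapes are illustrative, not the project's defaults:

import torch
import torch.nn as nn
from diffusers import DDIMScheduler
from models.svc.diffusion.diffusion_inference_pipeline import DiffusionInferencePipeline

class ZeroDenoiser(nn.Module):
    # Stand-in network: always predicts zero noise.
    def forward(self, mel, timestep, conditioner):
        return torch.zeros_like(mel)

scheduler = DDIMScheduler(num_train_timesteps=1000)
pipeline = DiffusionInferencePipeline(ZeroDenoiser(), scheduler, num_inference_timesteps=50)

noise = torch.randn(2, 128, 100)        # [N, T, n_mel]
conditioner = torch.randn(2, 128, 384)  # [N, T, conditioner_size]
mel = pipeline(noise, conditioner)      # denoised mel, clamped to [-1, 1]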
models/svc/diffusion/diffusion_trainer.py ADDED
@@ -0,0 +1,88 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ from diffusers import DDPMScheduler
8
+
9
+ from models.svc.base import SVCTrainer
10
+ from modules.encoder.condition_encoder import ConditionEncoder
11
+ from .diffusion_wrapper import DiffusionWrapper
12
+
13
+
14
+ class DiffusionTrainer(SVCTrainer):
15
+ r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
16
+ implements ``_build_model`` and ``_forward_step`` methods.
17
+ """
18
+
19
+ def __init__(self, args=None, cfg=None):
20
+ SVCTrainer.__init__(self, args, cfg)
21
+
22
+ # Only for SVC tasks using diffusion
23
+ self.noise_scheduler = DDPMScheduler(
24
+ **self.cfg.model.diffusion.scheduler_settings,
25
+ )
26
+ self.diffusion_timesteps = (
27
+ self.cfg.model.diffusion.scheduler_settings.num_train_timesteps
28
+ )
29
+
30
+ ### Following are methods only for diffusion models ###
31
+ def _build_model(self):
32
+ r"""Build the model for training. This function is called in ``__init__`` function."""
33
+
34
+ # TODO: sort out the config
35
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
36
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
37
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
38
+ self.acoustic_mapper = DiffusionWrapper(self.cfg)
39
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
40
+
41
+ num_of_params_encoder = self.count_parameters(self.condition_encoder)
42
+ num_of_params_am = self.count_parameters(self.acoustic_mapper)
43
+ num_of_params = num_of_params_encoder + num_of_params_am
44
+ log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format(
45
+ num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6
46
+ )
47
+ self.logger.info(log)
48
+
49
+ return model
50
+
51
+ def count_parameters(self, model):
52
+ model_param = 0.0
53
+ if isinstance(model, dict):
54
+ for key, value in model.items():
55
+ model_param += sum(p.numel() for p in model[key].parameters())
56
+ else:
57
+ model_param = sum(p.numel() for p in model.parameters())
58
+ return model_param
59
+
60
+ def _forward_step(self, batch):
61
+ r"""Forward step for training and inference. This function is called
62
+ in ``_train_step`` & ``_test_step`` function.
63
+ """
64
+
65
+ device = self.accelerator.device
66
+
67
+ mel_input = batch["mel"]
68
+ noise = torch.randn_like(mel_input, device=device, dtype=torch.float32)
69
+ batch_size = mel_input.size(0)
70
+ timesteps = torch.randint(
71
+ 0,
72
+ self.diffusion_timesteps,
73
+ (batch_size,),
74
+ device=device,
75
+ dtype=torch.long,
76
+ )
77
+
78
+ noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps)
79
+ conditioner = self.condition_encoder(batch)
80
+
81
+ y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner)
82
+
83
+ # TODO: Predict noise or gt should be configurable
84
+ loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"])
85
+ self._check_nan(loss, y_pred, noise)
86
+
87
+ # FIXME: Clarify that we should not divide it with batch size here
88
+ return loss
models/svc/diffusion/diffusion_wrapper.py ADDED
@@ -0,0 +1,73 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch.nn as nn
7
+
8
+ from modules.diffusion import BiDilConv
9
+ from modules.encoder.position_encoder import PositionEncoder
10
+
11
+
12
+ class DiffusionWrapper(nn.Module):
13
+ def __init__(self, cfg):
14
+ super().__init__()
15
+
16
+ self.cfg = cfg
17
+ self.diff_cfg = cfg.model.diffusion
18
+
19
+ self.diff_encoder = PositionEncoder(
20
+ d_raw_emb=self.diff_cfg.step_encoder.dim_raw_embedding,
21
+ d_out=self.diff_cfg.bidilconv.base_channel,
22
+ d_mlp=self.diff_cfg.step_encoder.dim_hidden_layer,
23
+ activation_function=self.diff_cfg.step_encoder.activation,
24
+ n_layer=self.diff_cfg.step_encoder.num_layer,
25
+ max_period=self.diff_cfg.step_encoder.max_period,
26
+ )
27
+
28
+ # FIXME: Only support BiDilConv now for debug
29
+ if self.diff_cfg.model_type.lower() == "bidilconv":
30
+ self.neural_network = BiDilConv(
31
+ input_channel=self.cfg.preprocess.n_mel, **self.diff_cfg.bidilconv
32
+ )
33
+ else:
34
+ raise ValueError(
35
+ f"Unsupported diffusion model type: {self.diff_cfg.model_type}"
36
+ )
37
+
38
+ def forward(self, x, t, c):
39
+ """
40
+ Args:
41
+ x: [N, T, mel_band] of mel spectrogram
42
+ t: Diffusion time step with shape of [N]
43
+ c: [N, T, conditioner_size] of conditioner
44
+
45
+ Returns:
46
+ [N, T, mel_band] of mel spectrogram
47
+ """
48
+
49
+ assert (
50
+ x.size()[:-1] == c.size()[:-1]
51
+ ), "x mismatch with c, got \n x: {} \n c: {}".format(x.size(), c.size())
52
+ assert x.size(0) == t.size(
53
+ 0
54
+ ), "x mismatch with t, got \n x: {} \n t: {}".format(x.size(), t.size())
55
+ assert t.dim() == 1, "t must be 1D tensor, got {}".format(t.dim())
56
+
57
+ N, T, mel_band = x.size()
58
+
59
+ x = x.transpose(1, 2).contiguous() # [N, mel_band, T]
60
+ c = c.transpose(1, 2).contiguous() # [N, conditioner_size, T]
61
+ t = self.diff_encoder(t).contiguous() # [N, base_channel]
62
+
63
+ h = self.neural_network(x, t, c)
64
+ h = h.transpose(1, 2).contiguous() # [N, T, mel_band]
65
+
66
+ assert h.size() == (
67
+ N,
68
+ T,
69
+ mel_band,
70
+ ), "h mismatch with input x, got \n h: {} \n x: {}".format(
71
+ h.size(), (N, T, mel_band)
72
+ )
73
+ return h
models/svc/transformer/__init__.py ADDED
File without changes
models/svc/transformer/conformer.py ADDED
@@ -0,0 +1,405 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import torch
8
+ import numpy as np
9
+ import torch.nn as nn
10
+ from utils.util import convert_pad_shape
11
+
12
+
13
+ class BaseModule(torch.nn.Module):
14
+ def __init__(self):
15
+ super(BaseModule, self).__init__()
16
+
17
+ @property
18
+ def nparams(self):
19
+ """
20
+ Returns number of trainable parameters of the module.
21
+ """
22
+ num_params = 0
23
+ for name, param in self.named_parameters():
24
+ if param.requires_grad:
25
+ num_params += np.prod(param.detach().cpu().numpy().shape)
26
+ return num_params
27
+
28
+ def relocate_input(self, x: list):
29
+ """
30
+ Relocates provided tensors to the same device set for the module.
31
+ """
32
+ device = next(self.parameters()).device
33
+ for i in range(len(x)):
34
+ if isinstance(x[i], torch.Tensor) and x[i].device != device:
35
+ x[i] = x[i].to(device)
36
+ return x
37
+
38
+
39
+ class LayerNorm(BaseModule):
40
+ def __init__(self, channels, eps=1e-4):
41
+ super(LayerNorm, self).__init__()
42
+ self.channels = channels
43
+ self.eps = eps
44
+
45
+ self.gamma = torch.nn.Parameter(torch.ones(channels))
46
+ self.beta = torch.nn.Parameter(torch.zeros(channels))
47
+
48
+ def forward(self, x):
49
+ n_dims = len(x.shape)
50
+ mean = torch.mean(x, 1, keepdim=True)
51
+ variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
52
+
53
+ x = (x - mean) * torch.rsqrt(variance + self.eps)
54
+
55
+ shape = [1, -1] + [1] * (n_dims - 2)
56
+ x = x * self.gamma.view(*shape) + self.beta.view(*shape)
57
+ return x
58
+
59
+
60
+ class ConvReluNorm(BaseModule):
61
+ def __init__(
62
+ self,
63
+ in_channels,
64
+ hidden_channels,
65
+ out_channels,
66
+ kernel_size,
67
+ n_layers,
68
+ p_dropout,
69
+ eps=1e-5,
70
+ ):
71
+ super(ConvReluNorm, self).__init__()
72
+ self.in_channels = in_channels
73
+ self.hidden_channels = hidden_channels
74
+ self.out_channels = out_channels
75
+ self.kernel_size = kernel_size
76
+ self.n_layers = n_layers
77
+ self.p_dropout = p_dropout
78
+ self.eps = eps
79
+
80
+ self.conv_layers = torch.nn.ModuleList()
81
+ self.conv_layers.append(
82
+ torch.nn.Conv1d(
83
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
84
+ )
85
+ )
86
+ self.relu_drop = torch.nn.Sequential(
87
+ torch.nn.ReLU(), torch.nn.Dropout(p_dropout)
88
+ )
89
+ for _ in range(n_layers - 1):
90
+ self.conv_layers.append(
91
+ torch.nn.Conv1d(
92
+ hidden_channels,
93
+ hidden_channels,
94
+ kernel_size,
95
+ padding=kernel_size // 2,
96
+ )
97
+ )
98
+ self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
99
+ self.proj.weight.data.zero_()
100
+ self.proj.bias.data.zero_()
101
+
102
+ def forward(self, x, x_mask):
103
+ for i in range(self.n_layers):
104
+ x = self.conv_layers[i](x * x_mask)
105
+ x = self.instance_norm(x, x_mask)
106
+ x = self.relu_drop(x)
107
+ x = self.proj(x)
108
+ return x * x_mask
109
+
110
+ def instance_norm(self, x, mask, return_mean_std=False):
111
+ mean, std = self.calc_mean_std(x, mask)
112
+ x = (x - mean) / std
113
+ if return_mean_std:
114
+ return x, mean, std
115
+ else:
116
+ return x
117
+
118
+ def calc_mean_std(self, x, mask=None):
119
+ x = x * mask
120
+ B, C = x.shape[:2]
121
+ mn = x.view(B, C, -1).mean(-1)
122
+ sd = (x.view(B, C, -1).var(-1) + self.eps).sqrt()
123
+ mn = mn.view(B, C, *((len(x.shape) - 2) * [1]))
124
+ sd = sd.view(B, C, *((len(x.shape) - 2) * [1]))
125
+ return mn, sd
126
+
127
+
128
+ class MultiHeadAttention(BaseModule):
129
+ def __init__(
130
+ self,
131
+ channels,
132
+ out_channels,
133
+ n_heads,
134
+ window_size=None,
135
+ heads_share=True,
136
+ p_dropout=0.0,
137
+ proximal_bias=False,
138
+ proximal_init=False,
139
+ ):
140
+ super(MultiHeadAttention, self).__init__()
141
+ assert channels % n_heads == 0
142
+
143
+ self.channels = channels
144
+ self.out_channels = out_channels
145
+ self.n_heads = n_heads
146
+ self.window_size = window_size
147
+ self.heads_share = heads_share
148
+ self.proximal_bias = proximal_bias
149
+ self.p_dropout = p_dropout
150
+ self.attn = None
151
+
152
+ self.k_channels = channels // n_heads
153
+ self.conv_q = torch.nn.Conv1d(channels, channels, 1)
154
+ self.conv_k = torch.nn.Conv1d(channels, channels, 1)
155
+ self.conv_v = torch.nn.Conv1d(channels, channels, 1)
156
+ if window_size is not None:
157
+ n_heads_rel = 1 if heads_share else n_heads
158
+ rel_stddev = self.k_channels**-0.5
159
+ self.emb_rel_k = torch.nn.Parameter(
160
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
161
+ * rel_stddev
162
+ )
163
+ self.emb_rel_v = torch.nn.Parameter(
164
+ torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
165
+ * rel_stddev
166
+ )
167
+ self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
168
+ self.drop = torch.nn.Dropout(p_dropout)
169
+
170
+ torch.nn.init.xavier_uniform_(self.conv_q.weight)
171
+ torch.nn.init.xavier_uniform_(self.conv_k.weight)
172
+ if proximal_init:
173
+ self.conv_k.weight.data.copy_(self.conv_q.weight.data)
174
+ self.conv_k.bias.data.copy_(self.conv_q.bias.data)
175
+ torch.nn.init.xavier_uniform_(self.conv_v.weight)
176
+
177
+ def forward(self, x, c, attn_mask=None):
178
+ q = self.conv_q(x)
179
+ k = self.conv_k(c)
180
+ v = self.conv_v(c)
181
+
182
+ x, self.attn = self.attention(q, k, v, mask=attn_mask)
183
+
184
+ x = self.conv_o(x)
185
+ return x
186
+
187
+ def attention(self, query, key, value, mask=None):
188
+ b, d, t_s, t_t = (*key.size(), query.size(2))
189
+ query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
190
+ key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
191
+ value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
192
+
193
+ scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
194
+ if self.window_size is not None:
195
+ assert (
196
+ t_s == t_t
197
+ ), "Relative attention is only available for self-attention."
198
+ key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
199
+ rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
200
+ rel_logits = self._relative_position_to_absolute_position(rel_logits)
201
+ scores_local = rel_logits / math.sqrt(self.k_channels)
202
+ scores = scores + scores_local
203
+ if self.proximal_bias:
204
+ assert t_s == t_t, "Proximal bias is only available for self-attention."
205
+ scores = scores + self._attention_bias_proximal(t_s).to(
206
+ device=scores.device, dtype=scores.dtype
207
+ )
208
+ if mask is not None:
209
+ scores = scores.masked_fill(mask == 0, -1e4)
210
+ p_attn = torch.nn.functional.softmax(scores, dim=-1)
211
+ p_attn = self.drop(p_attn)
212
+ output = torch.matmul(p_attn, value)
213
+ if self.window_size is not None:
214
+ relative_weights = self._absolute_position_to_relative_position(p_attn)
215
+ value_relative_embeddings = self._get_relative_embeddings(
216
+ self.emb_rel_v, t_s
217
+ )
218
+ output = output + self._matmul_with_relative_values(
219
+ relative_weights, value_relative_embeddings
220
+ )
221
+ output = output.transpose(2, 3).contiguous().view(b, d, t_t)
222
+ return output, p_attn
223
+
224
+ def _matmul_with_relative_values(self, x, y):
225
+ ret = torch.matmul(x, y.unsqueeze(0))
226
+ return ret
227
+
228
+ def _matmul_with_relative_keys(self, x, y):
229
+ ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
230
+ return ret
231
+
232
+ def _get_relative_embeddings(self, relative_embeddings, length):
233
+ pad_length = max(length - (self.window_size + 1), 0)
234
+ slice_start_position = max((self.window_size + 1) - length, 0)
235
+ slice_end_position = slice_start_position + 2 * length - 1
236
+ if pad_length > 0:
237
+ padded_relative_embeddings = torch.nn.functional.pad(
238
+ relative_embeddings,
239
+ convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
240
+ )
241
+ else:
242
+ padded_relative_embeddings = relative_embeddings
243
+ used_relative_embeddings = padded_relative_embeddings[
244
+ :, slice_start_position:slice_end_position
245
+ ]
246
+ return used_relative_embeddings
247
+
248
+ def _relative_position_to_absolute_position(self, x):
249
+ batch, heads, length, _ = x.size()
250
+ x = torch.nn.functional.pad(
251
+ x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])
252
+ )
253
+ x_flat = x.view([batch, heads, length * 2 * length])
254
+ x_flat = torch.nn.functional.pad(
255
+ x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
256
+ )
257
+ x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
258
+ :, :, :length, length - 1 :
259
+ ]
260
+ return x_final
261
+
262
+ def _absolute_position_to_relative_position(self, x):
263
+ batch, heads, length, _ = x.size()
264
+ x = torch.nn.functional.pad(
265
+ x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
266
+ )
267
+ x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
268
+ x_flat = torch.nn.functional.pad(
269
+ x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])
270
+ )
271
+ x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
272
+ return x_final
273
+
274
+ def _attention_bias_proximal(self, length):
275
+ r = torch.arange(length, dtype=torch.float32)
276
+ diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
277
+ return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
278
+
279
+
280
+ class FFN(BaseModule):
281
+ def __init__(
282
+ self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0
283
+ ):
284
+ super(FFN, self).__init__()
285
+ self.in_channels = in_channels
286
+ self.out_channels = out_channels
287
+ self.filter_channels = filter_channels
288
+ self.kernel_size = kernel_size
289
+ self.p_dropout = p_dropout
290
+
291
+ self.conv_1 = torch.nn.Conv1d(
292
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
293
+ )
294
+ self.conv_2 = torch.nn.Conv1d(
295
+ filter_channels, out_channels, kernel_size, padding=kernel_size // 2
296
+ )
297
+ self.drop = torch.nn.Dropout(p_dropout)
298
+
299
+ def forward(self, x, x_mask):
300
+ x = self.conv_1(x * x_mask)
301
+ x = torch.relu(x)
302
+ x = self.drop(x)
303
+ x = self.conv_2(x * x_mask)
304
+ return x * x_mask
305
+
306
+
307
+ class Encoder(BaseModule):
308
+ def __init__(
309
+ self,
310
+ hidden_channels,
311
+ filter_channels,
312
+ n_heads=2,
313
+ n_layers=6,
314
+ kernel_size=3,
315
+ p_dropout=0.1,
316
+ window_size=4,
317
+ **kwargs
318
+ ):
319
+ super(Encoder, self).__init__()
320
+ self.hidden_channels = hidden_channels
321
+ self.filter_channels = filter_channels
322
+ self.n_heads = n_heads
323
+ self.n_layers = n_layers
324
+ self.kernel_size = kernel_size
325
+ self.p_dropout = p_dropout
326
+ self.window_size = window_size
327
+
328
+ self.drop = torch.nn.Dropout(p_dropout)
329
+ self.attn_layers = torch.nn.ModuleList()
330
+ self.norm_layers_1 = torch.nn.ModuleList()
331
+ self.ffn_layers = torch.nn.ModuleList()
332
+ self.norm_layers_2 = torch.nn.ModuleList()
333
+ for _ in range(self.n_layers):
334
+ self.attn_layers.append(
335
+ MultiHeadAttention(
336
+ hidden_channels,
337
+ hidden_channels,
338
+ n_heads,
339
+ window_size=window_size,
340
+ p_dropout=p_dropout,
341
+ )
342
+ )
343
+ self.norm_layers_1.append(LayerNorm(hidden_channels))
344
+ self.ffn_layers.append(
345
+ FFN(
346
+ hidden_channels,
347
+ hidden_channels,
348
+ filter_channels,
349
+ kernel_size,
350
+ p_dropout=p_dropout,
351
+ )
352
+ )
353
+ self.norm_layers_2.append(LayerNorm(hidden_channels))
354
+
355
+ def forward(self, x, x_mask):
356
+ attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
357
+ for i in range(self.n_layers):
358
+ x = x * x_mask
359
+ y = self.attn_layers[i](x, x, attn_mask)
360
+ y = self.drop(y)
361
+ x = self.norm_layers_1[i](x + y)
362
+ y = self.ffn_layers[i](x, x_mask)
363
+ y = self.drop(y)
364
+ x = self.norm_layers_2[i](x + y)
365
+ x = x * x_mask
366
+ return x
367
+
368
+
369
+ class Conformer(BaseModule):
370
+ def __init__(self, cfg):
371
+ super().__init__()
372
+ self.cfg = cfg
373
+ self.n_heads = self.cfg.n_heads
374
+ self.n_layers = self.cfg.n_layers
375
+ self.hidden_channels = self.cfg.input_dim
376
+ self.filter_channels = self.cfg.filter_channels
377
+ self.output_dim = self.cfg.output_dim
378
+ self.dropout = self.cfg.dropout
379
+
380
+ self.conformer_encoder = Encoder(
381
+ self.hidden_channels,
382
+ self.filter_channels,
383
+ n_heads=self.n_heads,
384
+ n_layers=self.n_layers,
385
+ kernel_size=3,
386
+ p_dropout=self.dropout,
387
+ window_size=4,
388
+ )
389
+ self.projection = nn.Conv1d(self.hidden_channels, self.output_dim, 1)
390
+
391
+ def forward(self, x, x_mask):
392
+ """
393
+ Args:
394
+ x: (N, seq_len, input_dim)
395
+ Returns:
396
+ output: (N, seq_len, output_dim)
397
+ """
398
+ # (N, seq_len, d_model)
399
+ x = x.transpose(1, 2)
400
+ x_mask = x_mask.transpose(1, 2)
401
+ output = self.conformer_encoder(x, x_mask)
402
+ # (N, seq_len, output_dim)
403
+ output = self.projection(output)
404
+ output = output.transpose(1, 2)
405
+ return output
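A shape-check sketch for the Conformer wrapper above; SimpleNamespace stands in for the real config object, and the field values are arbitrary:

import torch
from types import SimpleNamespace
from models.svc.transformer.conformer import Conformer

cfg = SimpleNamespace(
    input_dim=256, output_dim=100, n_heads=2,
    n_layers=2, filter_channels=1024, dropout=0.1,
)
model = Conformer(cfg)

x = torch.randn(4, 120, 256)   # (N, seq_len, input_dim)
mask = torch.ones(4, 120, 1)   # 1.0 for valid frames, 0.0 for padding
y = model(x, mask)             # (N, seq_len, output_dim) == (4, 120, 100)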
models/svc/transformer/transformer.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
10
+
11
+
12
+ class Transformer(nn.Module):
13
+ def __init__(self, cfg):
14
+ super().__init__()
15
+ self.cfg = cfg
16
+
17
+ dropout = self.cfg.dropout
18
+ nhead = self.cfg.n_heads
19
+ nlayers = self.cfg.n_layers
20
+ input_dim = self.cfg.input_dim
21
+ output_dim = self.cfg.output_dim
22
+
23
+ d_model = input_dim
24
+ self.pos_encoder = PositionalEncoding(d_model, dropout)
25
+ encoder_layers = TransformerEncoderLayer(
26
+ d_model, nhead, dropout=dropout, batch_first=True
27
+ )
28
+ self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
29
+
30
+ self.output_mlp = nn.Linear(d_model, output_dim)
31
+
32
+ def forward(self, x, mask=None):
33
+ """
34
+ Args:
35
+ x: (N, seq_len, input_dim)
36
+ Returns:
37
+ output: (N, seq_len, output_dim)
38
+ """
39
+ # (N, seq_len, d_model)
40
+ src = self.pos_encoder(x)
41
+ # model_stats["pos_embedding"] = x
42
+ # (N, seq_len, d_model)
43
+ output = self.transformer_encoder(src)
44
+ # (N, seq_len, output_dim)
45
+ output = self.output_mlp(output)
46
+ return output
47
+
48
+
49
+ class PositionalEncoding(nn.Module):
50
+ def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
51
+ super().__init__()
52
+ self.dropout = nn.Dropout(p=dropout)
53
+
54
+ position = torch.arange(max_len).unsqueeze(1)
55
+ div_term = torch.exp(
56
+ torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
57
+ )
58
+
59
+ # Assume that x is (seq_len, N, d)
60
+ # pe = torch.zeros(max_len, 1, d_model)
61
+ # pe[:, 0, 0::2] = torch.sin(position * div_term)
62
+ # pe[:, 0, 1::2] = torch.cos(position * div_term)
63
+
64
+ # Assume that x in (N, seq_len, d)
65
+ pe = torch.zeros(1, max_len, d_model)
66
+ pe[0, :, 0::2] = torch.sin(position * div_term)
67
+ pe[0, :, 1::2] = torch.cos(position * div_term)
68
+
69
+ self.register_buffer("pe", pe)
70
+
71
+ def forward(self, x):
72
+ """
73
+ Args:
74
+ x: Tensor, shape [N, seq_len, d]
75
+ """
76
+ # Old: Assume that x is (seq_len, N, d), and self.pe is (max_len, 1, d_model)
77
+ # x = x + self.pe[: x.size(0)]
78
+
79
+ # Now: self.pe is (1, max_len, d)
80
+ x = x + self.pe[:, : x.size(1), :]
81
+
82
+ return self.dropout(x)
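A quick sketch of the batch-first positional encoding above; the model width and sequence length are arbitrary:

import torch
from models.svc.transformer.transformer import PositionalEncoding

pos_enc = PositionalEncoding(d_model=256, dropout=0.0)
x = torch.zeros(8, 120, 256)   # (N, seq_len, d)
y = pos_enc(x)                 # same shape, with sinusoidal positions added per time step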
models/svc/transformer/transformer_inference.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import time
8
+ import numpy as np
9
+ import torch
10
+ from tqdm import tqdm
11
+ import torch.nn as nn
12
+ from collections import OrderedDict
13
+
14
+ from models.svc.base import SVCInference
15
+ from modules.encoder.condition_encoder import ConditionEncoder
16
+ from models.svc.transformer.transformer import Transformer
17
+ from models.svc.transformer.conformer import Conformer
18
+
19
+
20
+ class TransformerInference(SVCInference):
21
+ def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
22
+ SVCInference.__init__(self, args, cfg, infer_type)
23
+
24
+ def _build_model(self):
25
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
26
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
27
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
28
+ if self.cfg.model.transformer.type == "transformer":
29
+ self.acoustic_mapper = Transformer(self.cfg.model.transformer)
30
+ elif self.cfg.model.transformer.type == "conformer":
31
+ self.acoustic_mapper = Conformer(self.cfg.model.transformer)
32
+ else:
33
+ raise NotImplementedError
34
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
35
+ return model
36
+
37
+ def _inference_each_batch(self, batch_data):
38
+ device = self.accelerator.device
39
+ for k, v in batch_data.items():
40
+ batch_data[k] = v.to(device)
41
+
42
+ condition = self.condition_encoder(batch_data)
43
+ y_pred = self.acoustic_mapper(condition, batch_data["mask"])
44
+
45
+ return y_pred
models/svc/transformer/transformer_trainer.py ADDED
@@ -0,0 +1,52 @@
1
+ # Copyright (c) 2023 Amphion.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+
8
+ from models.svc.base import SVCTrainer
9
+ from modules.encoder.condition_encoder import ConditionEncoder
10
+ from models.svc.transformer.transformer import Transformer
11
+ from models.svc.transformer.conformer import Conformer
12
+ from utils.ssim import SSIM
13
+
14
+
15
+ class TransformerTrainer(SVCTrainer):
16
+ def __init__(self, args, cfg):
17
+ SVCTrainer.__init__(self, args, cfg)
18
+ self.ssim_loss = SSIM()
19
+
20
+ def _build_model(self):
21
+ self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
22
+ self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
23
+ self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
24
+ if self.cfg.model.transformer.type == "transformer":
25
+ self.acoustic_mapper = Transformer(self.cfg.model.transformer)
26
+ elif self.cfg.model.transformer.type == "conformer":
27
+ self.acoustic_mapper = Conformer(self.cfg.model.transformer)
28
+ else:
29
+ raise NotImplementedError
30
+ model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
31
+ return model
32
+
33
+ def _forward_step(self, batch):
34
+ total_loss = 0
35
+ device = self.accelerator.device
36
+ mel = batch["mel"]
37
+ mask = batch["mask"]
38
+
39
+ condition = self.condition_encoder(batch)
40
+ mel_pred = self.acoustic_mapper(condition, mask)
41
+
42
+ l1_loss = torch.sum(torch.abs(mel_pred - mel) * batch["mask"]) / torch.sum(
43
+ batch["mask"]
44
+ )
45
+ self._check_nan(l1_loss, mel_pred, mel)
46
+ total_loss += l1_loss
47
+ ssim_loss = self.ssim_loss(mel_pred, mel)
48
+ ssim_loss = torch.sum(ssim_loss * batch["mask"]) / torch.sum(batch["mask"])
49
+ self._check_nan(ssim_loss, mel_pred, mel)
50
+ total_loss += ssim_loss
51
+
52
+ return total_loss