Spaces:
Runtime error
Runtime error
init and interface
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -1
- .gitignore +18 -0
- app.py +78 -0
- ckpts/svc/vocalist_l1_contentvec+whisper/args.json +256 -0
- ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/optimizer.bin +3 -0
- ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/pytorch_model.bin +3 -0
- ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/random_states_0.pkl +3 -0
- ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/singers.json +17 -0
- ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_max.npy +3 -0
- ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_min.npy +3 -0
- ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/meta_info.json +31 -0
- ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/pitches/statistics.json +242 -0
- ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.0 +3 -0
- ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.1 +3 -0
- ckpts/svc/vocalist_l1_contentvec+whisper/singers.json +17 -0
- egs/svc/MultipleContentsSVC/README.md +153 -0
- egs/svc/MultipleContentsSVC/exp_config.json +126 -0
- egs/svc/MultipleContentsSVC/run.sh +1 -0
- egs/svc/README.md +34 -0
- egs/svc/_template/run.sh +150 -0
- inference.py +258 -0
- models/__init__.py +0 -0
- models/base/__init__.py +7 -0
- models/base/base_dataset.py +350 -0
- models/base/base_inference.py +220 -0
- models/base/base_sampler.py +136 -0
- models/base/base_trainer.py +348 -0
- models/base/new_dataset.py +50 -0
- models/base/new_inference.py +249 -0
- models/base/new_trainer.py +722 -0
- models/svc/__init__.py +0 -0
- models/svc/base/__init__.py +7 -0
- models/svc/base/svc_dataset.py +425 -0
- models/svc/base/svc_inference.py +15 -0
- models/svc/base/svc_trainer.py +111 -0
- models/svc/comosvc/__init__.py +4 -0
- models/svc/comosvc/comosvc.py +377 -0
- models/svc/comosvc/comosvc_inference.py +39 -0
- models/svc/comosvc/comosvc_trainer.py +295 -0
- models/svc/comosvc/utils.py +31 -0
- models/svc/diffusion/__init__.py +0 -0
- models/svc/diffusion/diffusion_inference.py +63 -0
- models/svc/diffusion/diffusion_inference_pipeline.py +47 -0
- models/svc/diffusion/diffusion_trainer.py +88 -0
- models/svc/diffusion/diffusion_wrapper.py +73 -0
- models/svc/transformer/__init__.py +0 -0
- models/svc/transformer/conformer.py +405 -0
- models/svc/transformer/transformer.py +82 -0
- models/svc/transformer/transformer_inference.py +45 -0
- models/svc/transformer/transformer_trainer.py +52 -0
.gitattributes
CHANGED
@@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
32 |
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
flagged
|
3 |
+
result
|
4 |
+
|
5 |
+
# Developing mode
|
6 |
+
_*.sh
|
7 |
+
_*.json
|
8 |
+
*.lst
|
9 |
+
yard*
|
10 |
+
*.out
|
11 |
+
evaluation/evalset_selection
|
12 |
+
mfa
|
13 |
+
egs/svc/*wavmark
|
14 |
+
egs/svc/custom
|
15 |
+
egs/svc/*/dev*
|
16 |
+
egs/svc/dev_exp_config.json
|
17 |
+
bins/svc/demo*
|
18 |
+
bins/svc/preprocess_custom.py
|
app.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
|
3 |
+
|
4 |
+
SUPPORTED_TARGET_SINGERS = {
|
5 |
+
"Adele": "vocalist_l1_Adele",
|
6 |
+
"Beyonce": "vocalist_l1_Beyonce",
|
7 |
+
"Bruno Mars": "vocalist_l1_BrunoMars",
|
8 |
+
"John Mayer": "vocalist_l1_JohnMayer",
|
9 |
+
"Michael Jackson": "vocalist_l1_MichaelJackson",
|
10 |
+
"Taylor Swift": "vocalist_l1_TaylorSwift",
|
11 |
+
"Jacky Cheung 张学友": "vocalist_l1_张学友",
|
12 |
+
"Jian Li 李健": "vocalist_l1_李健",
|
13 |
+
"Feng Wang 汪峰": "vocalist_l1_汪峰",
|
14 |
+
"Faye Wong 王菲": "vocalist_l1_王菲",
|
15 |
+
"Yijie Shi 石倚洁": "vocalist_l1_石倚洁",
|
16 |
+
"Tsai Chin 蔡琴": "vocalist_l1_蔡琴",
|
17 |
+
"Ying Na 那英": "vocalist_l1_那英",
|
18 |
+
"Eason Chan 陈奕迅": "vocalist_l1_陈奕迅",
|
19 |
+
"David Tao 陶喆": "vocalist_l1_陶喆",
|
20 |
+
}
|
21 |
+
|
22 |
+
|
23 |
+
def svc_inference(
|
24 |
+
source_audio,
|
25 |
+
target_singer,
|
26 |
+
diffusion_steps=1000,
|
27 |
+
key_shift_mode="auto",
|
28 |
+
key_shift_num=0,
|
29 |
+
):
|
30 |
+
pass
|
31 |
+
|
32 |
+
|
33 |
+
demo_inputs = [
|
34 |
+
gr.Audio(
|
35 |
+
sources=["upload", "microphone"],
|
36 |
+
label="Upload (or record) a song you want to listen",
|
37 |
+
),
|
38 |
+
gr.Radio(
|
39 |
+
choices=list(SUPPORTED_TARGET_SINGERS.keys()),
|
40 |
+
label="Target Singer",
|
41 |
+
value="Jian Li 李健",
|
42 |
+
),
|
43 |
+
gr.Slider(
|
44 |
+
1,
|
45 |
+
1000,
|
46 |
+
value=1000,
|
47 |
+
step=1,
|
48 |
+
label="Diffusion Inference Steps",
|
49 |
+
info="As the step number increases, the synthesis quality will be better while the inference speed will be lower",
|
50 |
+
),
|
51 |
+
gr.Radio(
|
52 |
+
choices=["Auto Shift", "Key Shift"],
|
53 |
+
value="Auto Shift",
|
54 |
+
label="Pitch Shift Control",
|
55 |
+
info='If you want to control the specific pitch shift value, you need to choose "Key Shift"',
|
56 |
+
),
|
57 |
+
gr.Slider(
|
58 |
+
-6,
|
59 |
+
6,
|
60 |
+
value=0,
|
61 |
+
step=1,
|
62 |
+
label="Key Shift Values",
|
63 |
+
info='How many semitones you want to transpose. This parameter will work only if you choose "Key Shift"',
|
64 |
+
),
|
65 |
+
]
|
66 |
+
|
67 |
+
demo_outputs = gr.Audio(label="")
|
68 |
+
|
69 |
+
|
70 |
+
demo = gr.Interface(
|
71 |
+
fn=svc_inference,
|
72 |
+
inputs=demo_inputs,
|
73 |
+
outputs=demo_outputs,
|
74 |
+
title="Amphion Singing Voice Conversion",
|
75 |
+
)
|
76 |
+
|
77 |
+
if __name__ == "__main__":
|
78 |
+
demo.launch(show_api=False)
|
ckpts/svc/vocalist_l1_contentvec+whisper/args.json
ADDED
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/diffusion.json",
|
3 |
+
"dataset": [
|
4 |
+
"vocalist_l1",
|
5 |
+
],
|
6 |
+
"exp_name": "vocalist_l1_contentvec+whisper",
|
7 |
+
"inference": {
|
8 |
+
"diffusion": {
|
9 |
+
"scheduler": "pndm",
|
10 |
+
"scheduler_settings": {
|
11 |
+
"num_inference_timesteps": 1000,
|
12 |
+
},
|
13 |
+
},
|
14 |
+
},
|
15 |
+
"model": {
|
16 |
+
"condition_encoder": {
|
17 |
+
"content_encoder_dim": 384,
|
18 |
+
"contentvec_dim": 256,
|
19 |
+
"f0_max": 1100,
|
20 |
+
"f0_min": 50,
|
21 |
+
"input_loudness_dim": 1,
|
22 |
+
"input_melody_dim": 1,
|
23 |
+
"merge_mode": "add",
|
24 |
+
"mert_dim": 256,
|
25 |
+
"n_bins_loudness": 256,
|
26 |
+
"n_bins_melody": 256,
|
27 |
+
"output_content_dim": 384,
|
28 |
+
"output_loudness_dim": 384,
|
29 |
+
"output_melody_dim": 384,
|
30 |
+
"output_singer_dim": 384,
|
31 |
+
"pitch_max": 1100,
|
32 |
+
"pitch_min": 50,
|
33 |
+
"singer_table_size": 512,
|
34 |
+
"use_conformer_for_content_features": false,
|
35 |
+
"use_contentvec": true,
|
36 |
+
"use_log_f0": true,
|
37 |
+
"use_log_loudness": true,
|
38 |
+
"use_mert": false,
|
39 |
+
"use_singer_encoder": true,
|
40 |
+
"use_spkid": true,
|
41 |
+
"use_wenet": false,
|
42 |
+
"use_whisper": true,
|
43 |
+
"wenet_dim": 512,
|
44 |
+
"whisper_dim": 1024,
|
45 |
+
},
|
46 |
+
"diffusion": {
|
47 |
+
"bidilconv": {
|
48 |
+
"base_channel": 384,
|
49 |
+
"conditioner_size": 384,
|
50 |
+
"conv_kernel_size": 3,
|
51 |
+
"dilation_cycle_length": 4,
|
52 |
+
"n_res_block": 20,
|
53 |
+
},
|
54 |
+
"model_type": "bidilconv",
|
55 |
+
"scheduler": "ddpm",
|
56 |
+
"scheduler_settings": {
|
57 |
+
"beta_end": 0.02,
|
58 |
+
"beta_schedule": "linear",
|
59 |
+
"beta_start": 0.0001,
|
60 |
+
"num_train_timesteps": 1000,
|
61 |
+
},
|
62 |
+
"step_encoder": {
|
63 |
+
"activation": "SiLU",
|
64 |
+
"dim_hidden_layer": 512,
|
65 |
+
"dim_raw_embedding": 128,
|
66 |
+
"max_period": 10000,
|
67 |
+
"num_layer": 2,
|
68 |
+
},
|
69 |
+
"unet2d": {
|
70 |
+
"down_block_types": [
|
71 |
+
"CrossAttnDownBlock2D",
|
72 |
+
"CrossAttnDownBlock2D",
|
73 |
+
"CrossAttnDownBlock2D",
|
74 |
+
"DownBlock2D",
|
75 |
+
],
|
76 |
+
"in_channels": 1,
|
77 |
+
"mid_block_type": "UNetMidBlock2DCrossAttn",
|
78 |
+
"only_cross_attention": false,
|
79 |
+
"out_channels": 1,
|
80 |
+
"up_block_types": [
|
81 |
+
"UpBlock2D",
|
82 |
+
"CrossAttnUpBlock2D",
|
83 |
+
"CrossAttnUpBlock2D",
|
84 |
+
"CrossAttnUpBlock2D",
|
85 |
+
],
|
86 |
+
},
|
87 |
+
},
|
88 |
+
},
|
89 |
+
"model_type": "DiffWaveNetSVC",
|
90 |
+
"preprocess": {
|
91 |
+
"audio_dir": "audios",
|
92 |
+
"bits": 8,
|
93 |
+
"content_feature_batch_size": 16,
|
94 |
+
"contentvec_batch_size": 1,
|
95 |
+
"contentvec_dir": "contentvec",
|
96 |
+
"contentvec_file": "pretrained/contentvec/checkpoint_best_legacy_500.pt",
|
97 |
+
"contentvec_frameshift": 0.02,
|
98 |
+
"contentvec_sample_rate": 16000,
|
99 |
+
"dur_dir": "durs",
|
100 |
+
"duration_dir": "duration",
|
101 |
+
"emo2id": "emo2id.json",
|
102 |
+
"energy_dir": "energys",
|
103 |
+
"extract_audio": false,
|
104 |
+
"extract_contentvec_feature": true,
|
105 |
+
"extract_energy": true,
|
106 |
+
"extract_label": false,
|
107 |
+
"extract_mcep": false,
|
108 |
+
"extract_mel": true,
|
109 |
+
"extract_mert_feature": false,
|
110 |
+
"extract_pitch": true,
|
111 |
+
"extract_uv": true,
|
112 |
+
"extract_wenet_feature": false,
|
113 |
+
"extract_whisper_feature": true,
|
114 |
+
"f0_max": 1100,
|
115 |
+
"f0_min": 50,
|
116 |
+
"file_lst": "file.lst",
|
117 |
+
"fmax": 12000,
|
118 |
+
"fmin": 0,
|
119 |
+
"hop_size": 256,
|
120 |
+
"is_label": true,
|
121 |
+
"is_mu_law": true,
|
122 |
+
"lab_dir": "labs",
|
123 |
+
"label_dir": "labels",
|
124 |
+
"mcep_dir": "mcep",
|
125 |
+
"mel_dir": "mels",
|
126 |
+
"mel_min_max_norm": true,
|
127 |
+
"mel_min_max_stats_dir": "mel_min_max_stats",
|
128 |
+
"mert_dir": "mert",
|
129 |
+
"mert_feature_layer": -1,
|
130 |
+
"mert_frameshit": 0.01333,
|
131 |
+
"mert_hop_size": 320,
|
132 |
+
"mert_model": "m-a-p/MERT-v1-330M",
|
133 |
+
"min_level_db": -115,
|
134 |
+
"mu_law_norm": false,
|
135 |
+
"n_fft": 1024,
|
136 |
+
"n_mel": 100,
|
137 |
+
"num_silent_frames": 8,
|
138 |
+
"num_workers": 8,
|
139 |
+
"phone_seq_file": "phone_seq_file",
|
140 |
+
"pin_memory": true,
|
141 |
+
"pitch_bin": 256,
|
142 |
+
"pitch_dir": "pitches",
|
143 |
+
"pitch_extractor": "parselmouth",
|
144 |
+
"pitch_max": 1100.0,
|
145 |
+
"pitch_min": 50.0,
|
146 |
+
"processed_dir": "ckpts/svc/vocalist_l1_contentvec+whisper/data",
|
147 |
+
"ref_level_db": 20,
|
148 |
+
"sample_rate": 24000,
|
149 |
+
"spk2id": "singers.json",
|
150 |
+
"train_file": "train.json",
|
151 |
+
"trim_fft_size": 512,
|
152 |
+
"trim_hop_size": 128,
|
153 |
+
"trim_silence": false,
|
154 |
+
"trim_top_db": 30,
|
155 |
+
"trimmed_wav_dir": "trimmed_wavs",
|
156 |
+
"use_audio": false,
|
157 |
+
"use_contentvec": true,
|
158 |
+
"use_dur": false,
|
159 |
+
"use_emoid": false,
|
160 |
+
"use_frame_duration": false,
|
161 |
+
"use_frame_energy": true,
|
162 |
+
"use_frame_pitch": true,
|
163 |
+
"use_lab": false,
|
164 |
+
"use_label": false,
|
165 |
+
"use_log_scale_energy": false,
|
166 |
+
"use_log_scale_pitch": false,
|
167 |
+
"use_mel": true,
|
168 |
+
"use_mert": false,
|
169 |
+
"use_min_max_norm_mel": true,
|
170 |
+
"use_one_hot": false,
|
171 |
+
"use_phn_seq": false,
|
172 |
+
"use_phone_duration": false,
|
173 |
+
"use_phone_energy": false,
|
174 |
+
"use_phone_pitch": false,
|
175 |
+
"use_spkid": true,
|
176 |
+
"use_uv": true,
|
177 |
+
"use_wav": false,
|
178 |
+
"use_wenet": false,
|
179 |
+
"use_whisper": true,
|
180 |
+
"utt2emo": "utt2emo",
|
181 |
+
"utt2spk": "utt2singer",
|
182 |
+
"uv_dir": "uvs",
|
183 |
+
"valid_file": "test.json",
|
184 |
+
"wav_dir": "wavs",
|
185 |
+
"wenet_batch_size": 1,
|
186 |
+
"wenet_config": "pretrained/wenet/20220506_u2pp_conformer_exp/train.yaml",
|
187 |
+
"wenet_dir": "wenet",
|
188 |
+
"wenet_downsample_rate": 4,
|
189 |
+
"wenet_frameshift": 0.01,
|
190 |
+
"wenet_model_path": "pretrained/wenet/20220506_u2pp_conformer_exp/final.pt",
|
191 |
+
"wenet_sample_rate": 16000,
|
192 |
+
"whisper_batch_size": 30,
|
193 |
+
"whisper_dir": "whisper",
|
194 |
+
"whisper_downsample_rate": 2,
|
195 |
+
"whisper_frameshift": 0.01,
|
196 |
+
"whisper_model": "medium",
|
197 |
+
"whisper_model_path": "pretrained/whisper/medium.pt",
|
198 |
+
"win_size": 1024,
|
199 |
+
},
|
200 |
+
"supported_model_type": [
|
201 |
+
"Fastspeech2",
|
202 |
+
"DiffSVC",
|
203 |
+
"Transformer",
|
204 |
+
"EDM",
|
205 |
+
"CD",
|
206 |
+
],
|
207 |
+
"train": {
|
208 |
+
"adamw": {
|
209 |
+
"lr": 0.0004,
|
210 |
+
},
|
211 |
+
"batch_size": 32,
|
212 |
+
"dataloader": {
|
213 |
+
"num_worker": 8,
|
214 |
+
"pin_memory": true,
|
215 |
+
},
|
216 |
+
"ddp": true,
|
217 |
+
"epochs": 50000,
|
218 |
+
"gradient_accumulation_step": 1,
|
219 |
+
"keep_checkpoint_max": 5,
|
220 |
+
"keep_last": [
|
221 |
+
5,
|
222 |
+
-1,
|
223 |
+
],
|
224 |
+
"max_epoch": -1,
|
225 |
+
"max_steps": 1000000,
|
226 |
+
"multi_speaker_training": false,
|
227 |
+
"optimizer": "AdamW",
|
228 |
+
"random_seed": 10086,
|
229 |
+
"reducelronplateau": {
|
230 |
+
"factor": 0.8,
|
231 |
+
"min_lr": 0.0001,
|
232 |
+
"patience": 10,
|
233 |
+
},
|
234 |
+
"run_eval": [
|
235 |
+
false,
|
236 |
+
true,
|
237 |
+
],
|
238 |
+
"sampler": {
|
239 |
+
"drop_last": true,
|
240 |
+
"holistic_shuffle": false,
|
241 |
+
},
|
242 |
+
"save_checkpoint_stride": [
|
243 |
+
3,
|
244 |
+
10,
|
245 |
+
],
|
246 |
+
"save_checkpoints_steps": 10000,
|
247 |
+
"save_summary_steps": 500,
|
248 |
+
"scheduler": "ReduceLROnPlateau",
|
249 |
+
"total_training_steps": 50000,
|
250 |
+
"tracker": [
|
251 |
+
"tensorboard",
|
252 |
+
],
|
253 |
+
"valid_interval": 10000,
|
254 |
+
},
|
255 |
+
"use_custom_dataset": true,
|
256 |
+
}
|
ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/optimizer.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:836af10b834c7aec9209eb19ce43559e6ef1e3a59bd6468e90cadbc9a18749ef
|
3 |
+
size 249512389
|
ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d54eed12bef331095fc367f196d07c5061d5cb72dd6fe0e1e4453b997bf1d68d
|
3 |
+
size 124755137
|
ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/random_states_0.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6798ddffadcd7d5405a77e667c674c474e4fef0cba817fdd300c7c985c1e82fe
|
3 |
+
size 14599
|
ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/singers.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"vocalist_l1_Adele": 0,
|
3 |
+
"vocalist_l1_Beyonce": 1,
|
4 |
+
"vocalist_l1_BrunoMars": 2,
|
5 |
+
"vocalist_l1_JohnMayer": 3,
|
6 |
+
"vocalist_l1_MichaelJackson": 4,
|
7 |
+
"vocalist_l1_TaylorSwift": 5,
|
8 |
+
"vocalist_l1_张学友": 6,
|
9 |
+
"vocalist_l1_李健": 7,
|
10 |
+
"vocalist_l1_汪峰": 8,
|
11 |
+
"vocalist_l1_王菲": 9,
|
12 |
+
"vocalist_l1_石倚洁": 10,
|
13 |
+
"vocalist_l1_蔡琴": 11,
|
14 |
+
"vocalist_l1_那英": 12,
|
15 |
+
"vocalist_l1_陈奕迅": 13,
|
16 |
+
"vocalist_l1_陶喆": 14
|
17 |
+
}
|
ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_max.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:04131849378aa4f525a701909f743c303f8d56571682572b888046ead9f3e2ab
|
3 |
+
size 528
|
ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_min.npy
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ef4895ebef0e9949a6e623315bdc8a68490ba95d2f81b2be9f5146f904203016
|
3 |
+
size 528
|
ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/meta_info.json
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset": "vocalist_l1",
|
3 |
+
"train": {
|
4 |
+
"size": 3180,
|
5 |
+
"hours": 6.1643
|
6 |
+
},
|
7 |
+
"test": {
|
8 |
+
"size": 114,
|
9 |
+
"hours": 0.2224
|
10 |
+
},
|
11 |
+
"singers": {
|
12 |
+
"size": 15,
|
13 |
+
"training_minutes": {
|
14 |
+
"vocalist_l1_陶喆": 45.51,
|
15 |
+
"vocalist_l1_陈奕迅": 43.36,
|
16 |
+
"vocalist_l1_汪峰": 41.08,
|
17 |
+
"vocalist_l1_李健": 38.9,
|
18 |
+
"vocalist_l1_JohnMayer": 30.83,
|
19 |
+
"vocalist_l1_Adele": 27.23,
|
20 |
+
"vocalist_l1_那英": 27.02,
|
21 |
+
"vocalist_l1_石倚洁": 24.93,
|
22 |
+
"vocalist_l1_张学友": 18.31,
|
23 |
+
"vocalist_l1_TaylorSwift": 18.31,
|
24 |
+
"vocalist_l1_王菲": 16.78,
|
25 |
+
"vocalist_l1_MichaelJackson": 15.13,
|
26 |
+
"vocalist_l1_蔡琴": 10.12,
|
27 |
+
"vocalist_l1_BrunoMars": 6.29,
|
28 |
+
"vocalist_l1_Beyonce": 6.06
|
29 |
+
}
|
30 |
+
}
|
31 |
+
}
|
ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/pitches/statistics.json
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"vocalist_l1_Adele": {
|
3 |
+
"voiced_positions": {
|
4 |
+
"mean": 336.5038018286193,
|
5 |
+
"std": 100.2148774476881,
|
6 |
+
"median": 332.98363792619296,
|
7 |
+
"min": 59.99838412340723,
|
8 |
+
"max": 1099.849325287837
|
9 |
+
},
|
10 |
+
"total_positions": {
|
11 |
+
"mean": 231.79366581704338,
|
12 |
+
"std": 176.6042850107386,
|
13 |
+
"median": 273.2844263775394,
|
14 |
+
"min": 0.0,
|
15 |
+
"max": 1099.849325287837
|
16 |
+
}
|
17 |
+
},
|
18 |
+
"vocalist_l1_Beyonce": {
|
19 |
+
"voiced_positions": {
|
20 |
+
"mean": 357.5678927636881,
|
21 |
+
"std": 130.1132620135807,
|
22 |
+
"median": 318.2981879228934,
|
23 |
+
"min": 70.29719673914867,
|
24 |
+
"max": 1050.354470112099
|
25 |
+
},
|
26 |
+
"total_positions": {
|
27 |
+
"mean": 267.5248026267327,
|
28 |
+
"std": 191.71600807951046,
|
29 |
+
"median": 261.91981963774066,
|
30 |
+
"min": 0.0,
|
31 |
+
"max": 1050.354470112099
|
32 |
+
}
|
33 |
+
},
|
34 |
+
"vocalist_l1_BrunoMars": {
|
35 |
+
"voiced_positions": {
|
36 |
+
"mean": 330.92612740814315,
|
37 |
+
"std": 86.51034158515388,
|
38 |
+
"median": 324.65585832605217,
|
39 |
+
"min": 58.74277302450286,
|
40 |
+
"max": 999.2818302992808
|
41 |
+
},
|
42 |
+
"total_positions": {
|
43 |
+
"mean": 237.26076288057826,
|
44 |
+
"std": 166.09898203490803,
|
45 |
+
"median": 286.3097386522132,
|
46 |
+
"min": 0.0,
|
47 |
+
"max": 999.2818302992808
|
48 |
+
}
|
49 |
+
},
|
50 |
+
"vocalist_l1_JohnMayer": {
|
51 |
+
"voiced_positions": {
|
52 |
+
"mean": 218.3531239166661,
|
53 |
+
"std": 77.89887175223768,
|
54 |
+
"median": 200.19060542586652,
|
55 |
+
"min": 53.371912740674716,
|
56 |
+
"max": 1098.1986774161685
|
57 |
+
},
|
58 |
+
"total_positions": {
|
59 |
+
"mean": 112.95331907131244,
|
60 |
+
"std": 122.65534824070893,
|
61 |
+
"median": 124.71389285965317,
|
62 |
+
"min": 0.0,
|
63 |
+
"max": 1098.1986774161685
|
64 |
+
}
|
65 |
+
},
|
66 |
+
"vocalist_l1_MichaelJackson": {
|
67 |
+
"voiced_positions": {
|
68 |
+
"mean": 293.4663654519906,
|
69 |
+
"std": 89.02211325650234,
|
70 |
+
"median": 284.4323483619402,
|
71 |
+
"min": 61.14507754070825,
|
72 |
+
"max": 1096.4247902272325
|
73 |
+
},
|
74 |
+
"total_positions": {
|
75 |
+
"mean": 172.1013565770682,
|
76 |
+
"std": 159.79551912957191,
|
77 |
+
"median": 212.82938711725973,
|
78 |
+
"min": 0.0,
|
79 |
+
"max": 1096.4247902272325
|
80 |
+
}
|
81 |
+
},
|
82 |
+
"vocalist_l1_TaylorSwift": {
|
83 |
+
"voiced_positions": {
|
84 |
+
"mean": 302.5346928039029,
|
85 |
+
"std": 87.1724728626562,
|
86 |
+
"median": 286.91670244246586,
|
87 |
+
"min": 51.31173137207717,
|
88 |
+
"max": 1098.9374311806605
|
89 |
+
},
|
90 |
+
"total_positions": {
|
91 |
+
"mean": 169.90968097339214,
|
92 |
+
"std": 163.7133164876362,
|
93 |
+
"median": 220.90943653386546,
|
94 |
+
"min": 0.0,
|
95 |
+
"max": 1098.9374311806605
|
96 |
+
}
|
97 |
+
},
|
98 |
+
"vocalist_l1_张学友": {
|
99 |
+
"voiced_positions": {
|
100 |
+
"mean": 233.6845479691867,
|
101 |
+
"std": 66.47140810463938,
|
102 |
+
"median": 228.28695118043396,
|
103 |
+
"min": 51.65338480121057,
|
104 |
+
"max": 1094.4381927885959
|
105 |
+
},
|
106 |
+
"total_positions": {
|
107 |
+
"mean": 167.79543637603194,
|
108 |
+
"std": 119.28338415844308,
|
109 |
+
"median": 194.81504136428546,
|
110 |
+
"min": 0.0,
|
111 |
+
"max": 1094.4381927885959
|
112 |
+
}
|
113 |
+
},
|
114 |
+
"vocalist_l1_李健": {
|
115 |
+
"voiced_positions": {
|
116 |
+
"mean": 234.98401896504657,
|
117 |
+
"std": 71.3955175177514,
|
118 |
+
"median": 221.86415264367847,
|
119 |
+
"min": 54.070687769392585,
|
120 |
+
"max": 1096.3342286660531
|
121 |
+
},
|
122 |
+
"total_positions": {
|
123 |
+
"mean": 148.74760079412246,
|
124 |
+
"std": 126.70486473504008,
|
125 |
+
"median": 180.21374566147688,
|
126 |
+
"min": 0.0,
|
127 |
+
"max": 1096.3342286660531
|
128 |
+
}
|
129 |
+
},
|
130 |
+
"vocalist_l1_汪峰": {
|
131 |
+
"voiced_positions": {
|
132 |
+
"mean": 284.27752567207864,
|
133 |
+
"std": 78.51774150654873,
|
134 |
+
"median": 278.26186808969493,
|
135 |
+
"min": 54.30945929095861,
|
136 |
+
"max": 1053.6870553733015
|
137 |
+
},
|
138 |
+
"total_positions": {
|
139 |
+
"mean": 172.41584497486713,
|
140 |
+
"std": 151.74272125914902,
|
141 |
+
"median": 216.27534661524862,
|
142 |
+
"min": 0.0,
|
143 |
+
"max": 1053.6870553733015
|
144 |
+
}
|
145 |
+
},
|
146 |
+
"vocalist_l1_王菲": {
|
147 |
+
"voiced_positions": {
|
148 |
+
"mean": 339.1661679865587,
|
149 |
+
"std": 86.86768172635271,
|
150 |
+
"median": 327.4151031268507,
|
151 |
+
"min": 51.21299842481366,
|
152 |
+
"max": 1096.7044574066776
|
153 |
+
},
|
154 |
+
"total_positions": {
|
155 |
+
"mean": 217.726880186,
|
156 |
+
"std": 176.8748978138034,
|
157 |
+
"median": 277.8608050501477,
|
158 |
+
"min": 0.0,
|
159 |
+
"max": 1096.7044574066776
|
160 |
+
}
|
161 |
+
},
|
162 |
+
"vocalist_l1_石倚洁": {
|
163 |
+
"voiced_positions": {
|
164 |
+
"mean": 279.67710779262256,
|
165 |
+
"std": 87.82306577322389,
|
166 |
+
"median": 271.13024912248443,
|
167 |
+
"min": 59.604772357481075,
|
168 |
+
"max": 1098.0574674417153
|
169 |
+
},
|
170 |
+
"total_positions": {
|
171 |
+
"mean": 205.49634806008135,
|
172 |
+
"std": 144.6064344590865,
|
173 |
+
"median": 234.19454400899718,
|
174 |
+
"min": 0.0,
|
175 |
+
"max": 1098.0574674417153
|
176 |
+
}
|
177 |
+
},
|
178 |
+
"vocalist_l1_蔡琴": {
|
179 |
+
"voiced_positions": {
|
180 |
+
"mean": 258.9105806499278,
|
181 |
+
"std": 67.4079737418162,
|
182 |
+
"median": 250.29778287949176,
|
183 |
+
"min": 54.81875790199644,
|
184 |
+
"max": 930.3733192171918
|
185 |
+
},
|
186 |
+
"total_positions": {
|
187 |
+
"mean": 197.64675891035662,
|
188 |
+
"std": 124.80889987119957,
|
189 |
+
"median": 228.14775033720753,
|
190 |
+
"min": 0.0,
|
191 |
+
"max": 930.3733192171918
|
192 |
+
}
|
193 |
+
},
|
194 |
+
"vocalist_l1_那英": {
|
195 |
+
"voiced_positions": {
|
196 |
+
"mean": 358.98655838013195,
|
197 |
+
"std": 91.30591323348871,
|
198 |
+
"median": 346.95185476261275,
|
199 |
+
"min": 71.62879029165369,
|
200 |
+
"max": 1085.4349856526985
|
201 |
+
},
|
202 |
+
"total_positions": {
|
203 |
+
"mean": 243.83317702162077,
|
204 |
+
"std": 183.68660712060583,
|
205 |
+
"median": 294.9745603259994,
|
206 |
+
"min": 0.0,
|
207 |
+
"max": 1085.4349856526985
|
208 |
+
}
|
209 |
+
},
|
210 |
+
"vocalist_l1_陈奕迅": {
|
211 |
+
"voiced_positions": {
|
212 |
+
"mean": 222.0124146654594,
|
213 |
+
"std": 68.65002654904572,
|
214 |
+
"median": 218.9200565540147,
|
215 |
+
"min": 50.48503062529368,
|
216 |
+
"max": 1084.6336454006018
|
217 |
+
},
|
218 |
+
"total_positions": {
|
219 |
+
"mean": 154.2275169157727,
|
220 |
+
"std": 117.16740631313343,
|
221 |
+
"median": 176.89315636838086,
|
222 |
+
"min": 0.0,
|
223 |
+
"max": 1084.6336454006018
|
224 |
+
}
|
225 |
+
},
|
226 |
+
"vocalist_l1_陶喆": {
|
227 |
+
"voiced_positions": {
|
228 |
+
"mean": 242.58206762395713,
|
229 |
+
"std": 69.61805791083957,
|
230 |
+
"median": 227.5222796096177,
|
231 |
+
"min": 50.44809060945403,
|
232 |
+
"max": 1098.4942623171203
|
233 |
+
},
|
234 |
+
"total_positions": {
|
235 |
+
"mean": 171.59040988406485,
|
236 |
+
"std": 124.93911390018495,
|
237 |
+
"median": 204.4328861811408,
|
238 |
+
"min": 0.0,
|
239 |
+
"max": 1098.4942623171203
|
240 |
+
}
|
241 |
+
}
|
242 |
+
}
|
ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.0
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d7f490fd0c97876e24bfc44413365ded7ff5d22c1c79f0dac0b754f3b32df76f
|
3 |
+
size 88
|
ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.1
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e01bcf2fa621ba563b70568c18fe0742d0f48cafae83a6e8beb0bb6d1f6d146d
|
3 |
+
size 77413046
|
ckpts/svc/vocalist_l1_contentvec+whisper/singers.json
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"vocalist_l1_Adele": 0,
|
3 |
+
"vocalist_l1_Beyonce": 1,
|
4 |
+
"vocalist_l1_BrunoMars": 2,
|
5 |
+
"vocalist_l1_JohnMayer": 3,
|
6 |
+
"vocalist_l1_MichaelJackson": 4,
|
7 |
+
"vocalist_l1_TaylorSwift": 5,
|
8 |
+
"vocalist_l1_张学友": 6,
|
9 |
+
"vocalist_l1_李健": 7,
|
10 |
+
"vocalist_l1_汪峰": 8,
|
11 |
+
"vocalist_l1_王菲": 9,
|
12 |
+
"vocalist_l1_石倚洁": 10,
|
13 |
+
"vocalist_l1_蔡琴": 11,
|
14 |
+
"vocalist_l1_那英": 12,
|
15 |
+
"vocalist_l1_陈奕迅": 13,
|
16 |
+
"vocalist_l1_陶喆": 14
|
17 |
+
}
|
egs/svc/MultipleContentsSVC/README.md
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion
|
2 |
+
|
3 |
+
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2310.11160)
|
4 |
+
[![demo](https://img.shields.io/badge/SVC-Demo-red)](https://www.zhangxueyao.com/data/MultipleContentsSVC/index.html)
|
5 |
+
|
6 |
+
<br>
|
7 |
+
<div align="center">
|
8 |
+
<img src="../../../imgs/svc/MultipleContentsSVC.png" width="85%">
|
9 |
+
</div>
|
10 |
+
<br>
|
11 |
+
|
12 |
+
This is the official implementation of the paper "[Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion](https://arxiv.org/abs/2310.11160)" (NeurIPS 2023 Workshop on Machine Learning for Audio). Specially,
|
13 |
+
|
14 |
+
- The muptile content features are from [Whipser](https://github.com/wenet-e2e/wenet) and [ContentVec](https://github.com/auspicious3000/contentvec).
|
15 |
+
- The acoustic model is based on Bidirectional Non-Causal Dilated CNN (called `DiffWaveNetSVC` in Amphion), which is similar to [WaveNet](https://arxiv.org/pdf/1609.03499.pdf), [DiffWave](https://openreview.net/forum?id=a-xFK8Ymz5J), and [DiffSVC](https://ieeexplore.ieee.org/document/9688219).
|
16 |
+
- The vocoder is [BigVGAN](https://github.com/NVIDIA/BigVGAN) architecture and we fine-tuned it in over 120 hours singing voice data.
|
17 |
+
|
18 |
+
There are four stages in total:
|
19 |
+
|
20 |
+
1. Data preparation
|
21 |
+
2. Features extraction
|
22 |
+
3. Training
|
23 |
+
4. Inference/conversion
|
24 |
+
|
25 |
+
> **NOTE:** You need to run every command of this recipe in the `Amphion` root path:
|
26 |
+
> ```bash
|
27 |
+
> cd Amphion
|
28 |
+
> ```
|
29 |
+
|
30 |
+
## 1. Data Preparation
|
31 |
+
|
32 |
+
### Dataset Download
|
33 |
+
|
34 |
+
By default, we utilize the five datasets for training: M4Singer, Opencpop, OpenSinger, SVCC, and VCTK. How to download them is detailed [here](../../datasets/README.md).
|
35 |
+
|
36 |
+
### Configuration
|
37 |
+
|
38 |
+
Specify the dataset paths in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets.
|
39 |
+
|
40 |
+
```json
|
41 |
+
"dataset": [
|
42 |
+
"m4singer",
|
43 |
+
"opencpop",
|
44 |
+
"opensinger",
|
45 |
+
"svcc",
|
46 |
+
"vctk"
|
47 |
+
],
|
48 |
+
"dataset_path": {
|
49 |
+
// TODO: Fill in your dataset path
|
50 |
+
"m4singer": "[M4Singer dataset path]",
|
51 |
+
"opencpop": "[Opencpop dataset path]",
|
52 |
+
"opensinger": "[OpenSinger dataset path]",
|
53 |
+
"svcc": "[SVCC dataset path]",
|
54 |
+
"vctk": "[VCTK dataset path]"
|
55 |
+
},
|
56 |
+
```
|
57 |
+
|
58 |
+
## 2. Features Extraction
|
59 |
+
|
60 |
+
### Content-based Pretrained Models Download
|
61 |
+
|
62 |
+
By default, we utilize the Whisper and ContentVec to extract content features. How to download them is detailed [here](../../../pretrained/README.md).
|
63 |
+
|
64 |
+
### Configuration
|
65 |
+
|
66 |
+
Specify the dataset path and the output path for saving the processed data and the training model in `exp_config.json`:
|
67 |
+
|
68 |
+
```json
|
69 |
+
// TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc"
|
70 |
+
"log_dir": "ckpts/svc",
|
71 |
+
"preprocess": {
|
72 |
+
// TODO: Fill in the output data path. The default value is "Amphion/data"
|
73 |
+
"processed_dir": "data",
|
74 |
+
...
|
75 |
+
},
|
76 |
+
```
|
77 |
+
|
78 |
+
### Run
|
79 |
+
|
80 |
+
Run the `run.sh` as the preproces stage (set `--stage 1`).
|
81 |
+
|
82 |
+
```bash
|
83 |
+
sh egs/svc/MultipleContentsSVC/run.sh --stage 1
|
84 |
+
```
|
85 |
+
|
86 |
+
> **NOTE:** The `CUDA_VISIBLE_DEVICES` is set as `"0"` in default. You can change it when running `run.sh` by specifying such as `--gpu "1"`.
|
87 |
+
|
88 |
+
## 3. Training
|
89 |
+
|
90 |
+
### Configuration
|
91 |
+
|
92 |
+
We provide the default hyparameters in the `exp_config.json`. They can work on single NVIDIA-24g GPU. You can adjust them based on you GPU machines.
|
93 |
+
|
94 |
+
```json
|
95 |
+
"train": {
|
96 |
+
"batch_size": 32,
|
97 |
+
...
|
98 |
+
"adamw": {
|
99 |
+
"lr": 2.0e-4
|
100 |
+
},
|
101 |
+
...
|
102 |
+
}
|
103 |
+
```
|
104 |
+
|
105 |
+
### Run
|
106 |
+
|
107 |
+
Run the `run.sh` as the training stage (set `--stage 2`). Specify a experimental name to run the following command. The tensorboard logs and checkpoints will be saved in `Amphion/ckpts/svc/[YourExptName]`.
|
108 |
+
|
109 |
+
```bash
|
110 |
+
sh egs/svc/MultipleContentsSVC/run.sh --stage 2 --name [YourExptName]
|
111 |
+
```
|
112 |
+
|
113 |
+
> **NOTE:** The `CUDA_VISIBLE_DEVICES` is set as `"0"` in default. You can change it when running `run.sh` by specifying such as `--gpu "0,1,2,3"`.
|
114 |
+
|
115 |
+
## 4. Inference/Conversion
|
116 |
+
|
117 |
+
### Pretrained Vocoder Download
|
118 |
+
|
119 |
+
We fine-tune the official BigVGAN pretrained model with over 120 hours singing voice data. The benifits of fine-tuning has been investigated in our paper (see this [demo page](https://www.zhangxueyao.com/data/MultipleContentsSVC/vocoder.html)). The final pretrained singing voice vocoder is released [here](../../../pretrained/README.md#amphion-singing-bigvgan) (called `Amphion Singing BigVGAN`).
|
120 |
+
|
121 |
+
### Run
|
122 |
+
|
123 |
+
For inference/conversion, you need to specify the following configurations when running `run.sh`:
|
124 |
+
|
125 |
+
| Parameters | Description | Example |
|
126 |
+
| --------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
127 |
+
| `--infer_expt_dir` | The experimental directory which contains `checkpoint` | `Amphion/ckpts/svc/[YourExptName]` |
|
128 |
+
| `--infer_output_dir` | The output directory to save inferred audios. | `Amphion/ckpts/svc/[YourExptName]/result` |
|
129 |
+
| `--infer_source_file` or `--infer_source_audio_dir` | The inference source (can be a json file or a dir). | The `infer_source_file` could be `Amphion/data/[YourDataset]/test.json`, and the `infer_source_audio_dir` is a folder which includes several audio files (*.wav, *.mp3 or *.flac). |
|
130 |
+
| `--infer_target_speaker` | The target speaker you want to convert into. You can refer to `Amphion/ckpts/svc/[YourExptName]/singers.json` to choose a trained speaker. | For opencpop dataset, the speaker name would be `opencpop_female1`. |
|
131 |
+
| `--infer_key_shift` | How many semitones you want to transpose. | `"autoshfit"` (by default), `3`, `-3`, etc. |
|
132 |
+
|
133 |
+
For example, if you want to make `opencpop_female1` sing the songs in the `[Your Audios Folder]`, just run:
|
134 |
+
|
135 |
+
```bash
|
136 |
+
sh egs/svc/MultipleContentsSVC/run.sh --stage 3 --gpu "0" \
|
137 |
+
--infer_expt_dir Amphion/ckpts/svc/[YourExptName] \
|
138 |
+
--infer_output_dir Amphion/ckpts/svc/[YourExptName]/result \
|
139 |
+
--infer_source_audio_dir [Your Audios Folder] \
|
140 |
+
--infer_target_speaker "opencpop_female1" \
|
141 |
+
--infer_key_shift "autoshift"
|
142 |
+
```
|
143 |
+
|
144 |
+
## Citations
|
145 |
+
|
146 |
+
```bibtex
|
147 |
+
@article{zhang2023leveraging,
|
148 |
+
title={Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion},
|
149 |
+
author={Zhang, Xueyao and Gu, Yicheng and Chen, Haopeng and Fang, Zihao and Zou, Lexiao and Xue, Liumeng and Wu, Zhizheng},
|
150 |
+
journal={Machine Learning for Audio Worshop, NeurIPS 2023},
|
151 |
+
year={2023}
|
152 |
+
}
|
153 |
+
```
|
egs/svc/MultipleContentsSVC/exp_config.json
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"base_config": "config/diffusion.json",
|
3 |
+
"model_type": "DiffWaveNetSVC",
|
4 |
+
"dataset": [
|
5 |
+
"m4singer",
|
6 |
+
"opencpop",
|
7 |
+
"opensinger",
|
8 |
+
"svcc",
|
9 |
+
"vctk"
|
10 |
+
],
|
11 |
+
"dataset_path": {
|
12 |
+
// TODO: Fill in your dataset path
|
13 |
+
"m4singer": "[M4Singer dataset path]",
|
14 |
+
"opencpop": "[Opencpop dataset path]",
|
15 |
+
"opensinger": "[OpenSinger dataset path]",
|
16 |
+
"svcc": "[SVCC dataset path]",
|
17 |
+
"vctk": "[VCTK dataset path]"
|
18 |
+
},
|
19 |
+
// TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc"
|
20 |
+
"log_dir": "ckpts/svc",
|
21 |
+
"preprocess": {
|
22 |
+
// TODO: Fill in the output data path. The default value is "Amphion/data"
|
23 |
+
"processed_dir": "data",
|
24 |
+
// Config for features extraction
|
25 |
+
"extract_mel": true,
|
26 |
+
"extract_pitch": true,
|
27 |
+
"extract_energy": true,
|
28 |
+
"extract_whisper_feature": true,
|
29 |
+
"extract_contentvec_feature": true,
|
30 |
+
"extract_wenet_feature": false,
|
31 |
+
"whisper_batch_size": 30, // decrease it if your GPU is out of memory
|
32 |
+
"contentvec_batch_size": 1,
|
33 |
+
// Fill in the content-based pretrained model's path
|
34 |
+
"contentvec_file": "pretrained/contentvec/checkpoint_best_legacy_500.pt",
|
35 |
+
"wenet_model_path": "pretrained/wenet/20220506_u2pp_conformer_exp/final.pt",
|
36 |
+
"wenet_config": "pretrained/wenet/20220506_u2pp_conformer_exp/train.yaml",
|
37 |
+
"whisper_model": "medium",
|
38 |
+
"whisper_model_path": "pretrained/whisper/medium.pt",
|
39 |
+
// Config for features usage
|
40 |
+
"use_mel": true,
|
41 |
+
"use_min_max_norm_mel": true,
|
42 |
+
"use_frame_pitch": true,
|
43 |
+
"use_frame_energy": true,
|
44 |
+
"use_spkid": true,
|
45 |
+
"use_whisper": true,
|
46 |
+
"use_contentvec": true,
|
47 |
+
"use_wenet": false,
|
48 |
+
"n_mel": 100,
|
49 |
+
"sample_rate": 24000
|
50 |
+
},
|
51 |
+
"model": {
|
52 |
+
"condition_encoder": {
|
53 |
+
// Config for features usage
|
54 |
+
"use_whisper": true,
|
55 |
+
"use_contentvec": true,
|
56 |
+
"use_wenet": false,
|
57 |
+
"whisper_dim": 1024,
|
58 |
+
"contentvec_dim": 256,
|
59 |
+
"wenet_dim": 512,
|
60 |
+
"use_singer_encoder": false,
|
61 |
+
"pitch_min": 50,
|
62 |
+
"pitch_max": 1100
|
63 |
+
},
|
64 |
+
"diffusion": {
|
65 |
+
"scheduler": "ddpm",
|
66 |
+
"scheduler_settings": {
|
67 |
+
"num_train_timesteps": 1000,
|
68 |
+
"beta_start": 1.0e-4,
|
69 |
+
"beta_end": 0.02,
|
70 |
+
"beta_schedule": "linear"
|
71 |
+
},
|
72 |
+
// Diffusion steps encoder
|
73 |
+
"step_encoder": {
|
74 |
+
"dim_raw_embedding": 128,
|
75 |
+
"dim_hidden_layer": 512,
|
76 |
+
"activation": "SiLU",
|
77 |
+
"num_layer": 2,
|
78 |
+
"max_period": 10000
|
79 |
+
},
|
80 |
+
// Diffusion decoder
|
81 |
+
"model_type": "bidilconv",
|
82 |
+
// bidilconv, unet2d, TODO: unet1d
|
83 |
+
"bidilconv": {
|
84 |
+
"base_channel": 512,
|
85 |
+
"n_res_block": 40,
|
86 |
+
"conv_kernel_size": 3,
|
87 |
+
"dilation_cycle_length": 4,
|
88 |
+
// specially, 1 means no dilation
|
89 |
+
"conditioner_size": 384
|
90 |
+
}
|
91 |
+
}
|
92 |
+
},
|
93 |
+
"train": {
|
94 |
+
"batch_size": 32,
|
95 |
+
"gradient_accumulation_step": 1,
|
96 |
+
"max_epoch": -1, // -1 means no limit
|
97 |
+
"save_checkpoint_stride": [
|
98 |
+
3,
|
99 |
+
50
|
100 |
+
],
|
101 |
+
"keep_last": [
|
102 |
+
3,
|
103 |
+
2
|
104 |
+
],
|
105 |
+
"run_eval": [
|
106 |
+
true,
|
107 |
+
true
|
108 |
+
],
|
109 |
+
"adamw": {
|
110 |
+
"lr": 2.0e-4
|
111 |
+
},
|
112 |
+
"reducelronplateau": {
|
113 |
+
"factor": 0.8,
|
114 |
+
"patience": 30,
|
115 |
+
"min_lr": 1.0e-4
|
116 |
+
},
|
117 |
+
"dataloader": {
|
118 |
+
"num_worker": 8,
|
119 |
+
"pin_memory": true
|
120 |
+
},
|
121 |
+
"sampler": {
|
122 |
+
"holistic_shuffle": false,
|
123 |
+
"drop_last": true
|
124 |
+
}
|
125 |
+
}
|
126 |
+
}
|
egs/svc/MultipleContentsSVC/run.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
../_template/run.sh
|
egs/svc/README.md
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Amphion Singing Voice Conversion (SVC) Recipe
|
2 |
+
|
3 |
+
## Quick Start
|
4 |
+
|
5 |
+
We provide a **[beginner recipe](MultipleContentsSVC)** to demonstrate how to train a cutting edge SVC model. Specifically, it is also an official implementation of the paper "[Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion](https://arxiv.org/abs/2310.11160)" (NeurIPS 2023 Workshop on Machine Learning for Audio). Some demos can be seen [here](https://www.zhangxueyao.com/data/MultipleContentsSVC/index.html).
|
6 |
+
|
7 |
+
## Supported Model Architectures
|
8 |
+
|
9 |
+
The main idea of SVC is to first disentangle the speaker-agnostic representations from the source audio, and then inject the desired speaker information to synthesize the target, which usually utilizes an acoustic decoder and a subsequent waveform synthesizer (vocoder):
|
10 |
+
|
11 |
+
<br>
|
12 |
+
<div align="center">
|
13 |
+
<img src="../../imgs/svc/pipeline.png" width="70%">
|
14 |
+
</div>
|
15 |
+
<br>
|
16 |
+
|
17 |
+
Until now, Amphion SVC has supported the following features and models:
|
18 |
+
|
19 |
+
- **Speaker-agnostic Representations**:
|
20 |
+
- Content Features: Sourcing from [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), and [ContentVec](https://github.com/auspicious3000/contentvec).
|
21 |
+
- Prosody Features: F0 and energy.
|
22 |
+
- **Speaker Embeddings**:
|
23 |
+
- Speaker Look-Up Table.
|
24 |
+
- Reference Encoder (👨💻 developing): It can be used for zero-shot SVC.
|
25 |
+
- **Acoustic Decoders**:
|
26 |
+
- Diffusion-based models:
|
27 |
+
- **[DiffWaveNetSVC](MultipleContentsSVC)**: The encoder is based on Bidirectional Non-Causal Dilated CNN, which is similar to [WaveNet](https://arxiv.org/pdf/1609.03499.pdf), [DiffWave](https://openreview.net/forum?id=a-xFK8Ymz5J), and [DiffSVC](https://ieeexplore.ieee.org/document/9688219).
|
28 |
+
- **[DiffComoSVC](DiffComoSVC)** (👨💻 developing): The diffusion framework is based on [Consistency Model](https://proceedings.mlr.press/v202/song23a.html). It can significantly accelerate the inference process of the diffusion model.
|
29 |
+
- Transformer-based models:
|
30 |
+
- **[TransformerSVC](TransformerSVC)**: Encoder-only and Non-autoregressive Transformer Architecture.
|
31 |
+
- VAE- and Flow-based models:
|
32 |
+
- **[VitsSVC]()** (👨💻 developing): It is designed as a [VITS](https://arxiv.org/abs/2106.06103)-like model whose textual input is replaced by the content features, which is similar to [so-vits-svc](https://github.com/svc-develop-team/so-vits-svc).
|
33 |
+
- **Waveform Synthesizers (Vocoders)**:
|
34 |
+
- The supported vocoders can be seen in [Amphion Vocoder Recipe](../vocoder/README.md).
|
egs/svc/_template/run.sh
ADDED
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
######## Build Experiment Environment ###########
|
7 |
+
exp_dir=$(cd `dirname $0`; pwd)
|
8 |
+
work_dir=$(dirname $(dirname $(dirname $exp_dir)))
|
9 |
+
|
10 |
+
export WORK_DIR=$work_dir
|
11 |
+
export PYTHONPATH=$work_dir
|
12 |
+
export PYTHONIOENCODING=UTF-8
|
13 |
+
|
14 |
+
######## Parse the Given Parameters from the Commond ###########
|
15 |
+
options=$(getopt -o c:n:s --long gpu:,config:,name:,stage:,resume:,resume_from_ckpt_path:,resume_type:,infer_expt_dir:,infer_output_dir:,infer_source_file:,infer_source_audio_dir:,infer_target_speaker:,infer_key_shift:,infer_vocoder_dir: -- "$@")
|
16 |
+
eval set -- "$options"
|
17 |
+
|
18 |
+
while true; do
|
19 |
+
case $1 in
|
20 |
+
# Experimental Configuration File
|
21 |
+
-c | --config) shift; exp_config=$1 ; shift ;;
|
22 |
+
# Experimental Name
|
23 |
+
-n | --name) shift; exp_name=$1 ; shift ;;
|
24 |
+
# Running Stage
|
25 |
+
-s | --stage) shift; running_stage=$1 ; shift ;;
|
26 |
+
# Visible GPU machines. The default value is "0".
|
27 |
+
--gpu) shift; gpu=$1 ; shift ;;
|
28 |
+
|
29 |
+
# [Only for Training] Resume configuration
|
30 |
+
--resume) shift; resume=$1 ; shift ;;
|
31 |
+
# [Only for Training] The specific checkpoint path that you want to resume from.
|
32 |
+
--resume_from_ckpt_path) shift; resume_from_ckpt_path=$1 ; shift ;;
|
33 |
+
# [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights.
|
34 |
+
--resume_type) shift; resume_type=$1 ; shift ;;
|
35 |
+
|
36 |
+
# [Only for Inference] The experiment dir. The value is like "[Your path to save logs and checkpoints]/[YourExptName]"
|
37 |
+
--infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;;
|
38 |
+
# [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result"
|
39 |
+
--infer_output_dir) shift; infer_output_dir=$1 ; shift ;;
|
40 |
+
# [Only for Inference] The inference source (can be a json file or a dir). For example, the source_file can be "[Your path to save processed data]/[YourDataset]/test.json", and the source_audio_dir can be "$work_dir/source_audio" which includes several audio files (*.wav, *.mp3 or *.flac).
|
41 |
+
--infer_source_file) shift; infer_source_file=$1 ; shift ;;
|
42 |
+
--infer_source_audio_dir) shift; infer_source_audio_dir=$1 ; shift ;;
|
43 |
+
# [Only for Inference] Specify the target speaker you want to convert into. You can refer to "[Your path to save logs and checkpoints]/[Your Expt Name]/singers.json". In this singer look-up table, you can see the usable speaker names (all the keys of the dictionary). For example, for opencpop dataset, the speaker name would be "opencpop_female1".
|
44 |
+
--infer_target_speaker) shift; infer_target_speaker=$1 ; shift ;;
|
45 |
+
# [Only for Inference] For advanced users, you can modify the trans_key parameters into an integer (which means the semitones you want to transpose). Its default value is "autoshift".
|
46 |
+
--infer_key_shift) shift; infer_key_shift=$1 ; shift ;;
|
47 |
+
# [Only for Inference] The vocoder dir. Its default value is Amphion/pretrained/bigvgan. See Amphion/pretrained/README.md to download the pretrained BigVGAN vocoders.
|
48 |
+
--infer_vocoder_dir) shift; infer_vocoder_dir=$1 ; shift ;;
|
49 |
+
|
50 |
+
--) shift ; break ;;
|
51 |
+
*) echo "Invalid option: $1" exit 1 ;;
|
52 |
+
esac
|
53 |
+
done
|
54 |
+
|
55 |
+
|
56 |
+
### Value check ###
|
57 |
+
if [ -z "$running_stage" ]; then
|
58 |
+
echo "[Error] Please specify the running stage"
|
59 |
+
exit 1
|
60 |
+
fi
|
61 |
+
|
62 |
+
if [ -z "$exp_config" ]; then
|
63 |
+
exp_config="${exp_dir}"/exp_config.json
|
64 |
+
fi
|
65 |
+
echo "Exprimental Configuration File: $exp_config"
|
66 |
+
|
67 |
+
if [ -z "$gpu" ]; then
|
68 |
+
gpu="0"
|
69 |
+
fi
|
70 |
+
|
71 |
+
######## Features Extraction ###########
|
72 |
+
if [ $running_stage -eq 1 ]; then
|
73 |
+
CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/svc/preprocess.py \
|
74 |
+
--config $exp_config \
|
75 |
+
--num_workers 4
|
76 |
+
fi
|
77 |
+
|
78 |
+
######## Training ###########
|
79 |
+
if [ $running_stage -eq 2 ]; then
|
80 |
+
if [ -z "$exp_name" ]; then
|
81 |
+
echo "[Error] Please specify the experiments name"
|
82 |
+
exit 1
|
83 |
+
fi
|
84 |
+
echo "Exprimental Name: $exp_name"
|
85 |
+
|
86 |
+
if [ "$resume" = true ]; then
|
87 |
+
echo "Automatically resume from the experimental dir..."
|
88 |
+
CUDA_VISIBLE_DEVICES="$gpu" accelerate launch "${work_dir}"/bins/svc/train.py \
|
89 |
+
--config "$exp_config" \
|
90 |
+
--exp_name "$exp_name" \
|
91 |
+
--log_level info \
|
92 |
+
--resume
|
93 |
+
else
|
94 |
+
CUDA_VISIBLE_DEVICES=$gpu accelerate launch "${work_dir}"/bins/svc/train.py \
|
95 |
+
--config "$exp_config" \
|
96 |
+
--exp_name "$exp_name" \
|
97 |
+
--log_level info \
|
98 |
+
--resume_from_ckpt_path "$resume_from_ckpt_path" \
|
99 |
+
--resume_type "$resume_type"
|
100 |
+
fi
|
101 |
+
fi
|
102 |
+
|
103 |
+
######## Inference/Conversion ###########
|
104 |
+
if [ $running_stage -eq 3 ]; then
|
105 |
+
if [ -z "$infer_expt_dir" ]; then
|
106 |
+
echo "[Error] Please specify the experimental directionary. The value is like [Your path to save logs and checkpoints]/[YourExptName]"
|
107 |
+
exit 1
|
108 |
+
fi
|
109 |
+
|
110 |
+
if [ -z "$infer_output_dir" ]; then
|
111 |
+
infer_output_dir="$expt_dir/result"
|
112 |
+
fi
|
113 |
+
|
114 |
+
if [ -z "$infer_source_file" ] && [ -z "$infer_source_audio_dir" ]; then
|
115 |
+
echo "[Error] Please specify the source file/dir. The inference source (can be a json file or a dir). For example, the source_file can be "[Your path to save processed data]/[YourDataset]/test.json", and the source_audio_dir should include several audio files (*.wav, *.mp3 or *.flac)."
|
116 |
+
exit 1
|
117 |
+
fi
|
118 |
+
|
119 |
+
if [ -z "$infer_source_file" ]; then
|
120 |
+
infer_source=$infer_source_audio_dir
|
121 |
+
fi
|
122 |
+
|
123 |
+
if [ -z "$infer_source_audio_dir" ]; then
|
124 |
+
infer_source=$infer_source_file
|
125 |
+
fi
|
126 |
+
|
127 |
+
if [ -z "$infer_target_speaker" ]; then
|
128 |
+
echo "[Error] Please specify the target speaker. You can refer to "[Your path to save logs and checkpoints]/[Your Expt Name]/singers.json". In this singer look-up table, you can see the usable speaker names (all the keys of the dictionary). For example, for opencpop dataset, the speaker name would be "opencpop_female1""
|
129 |
+
exit 1
|
130 |
+
fi
|
131 |
+
|
132 |
+
if [ -z "$infer_key_shift" ]; then
|
133 |
+
infer_key_shift="autoshift"
|
134 |
+
fi
|
135 |
+
|
136 |
+
if [ -z "$infer_vocoder_dir" ]; then
|
137 |
+
infer_vocoder_dir="$work_dir"/pretrained/bigvgan
|
138 |
+
echo "[Warning] You don't specify the infer_vocoder_dir. It is set $infer_vocoder_dir by default. Make sure that you have followed Amphoion/pretrained/README.md to download the pretrained BigVGAN vocoder checkpoint."
|
139 |
+
fi
|
140 |
+
|
141 |
+
CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/svc/inference.py \
|
142 |
+
--config $exp_config \
|
143 |
+
--acoustics_dir $infer_expt_dir \
|
144 |
+
--vocoder_dir $infer_vocoder_dir \
|
145 |
+
--target_singer $infer_target_speaker \
|
146 |
+
--trans_key $infer_key_shift \
|
147 |
+
--source $infer_source \
|
148 |
+
--output_dir $infer_output_dir \
|
149 |
+
--log_level debug
|
150 |
+
fi
|
inference.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import os
|
8 |
+
import glob
|
9 |
+
from tqdm import tqdm
|
10 |
+
import json
|
11 |
+
import torch
|
12 |
+
import time
|
13 |
+
|
14 |
+
from models.svc.diffusion.diffusion_inference import DiffusionInference
|
15 |
+
from models.svc.comosvc.comosvc_inference import ComoSVCInference
|
16 |
+
from models.svc.transformer.transformer_inference import TransformerInference
|
17 |
+
from utils.util import load_config
|
18 |
+
from utils.audio_slicer import split_audio, merge_segments_encodec
|
19 |
+
from processors import acoustic_extractor, content_extractor
|
20 |
+
|
21 |
+
|
22 |
+
def build_inference(args, cfg, infer_type="from_dataset"):
|
23 |
+
supported_inference = {
|
24 |
+
"DiffWaveNetSVC": DiffusionInference,
|
25 |
+
"DiffComoSVC": ComoSVCInference,
|
26 |
+
"TransformerSVC": TransformerInference,
|
27 |
+
}
|
28 |
+
|
29 |
+
inference_class = supported_inference[cfg.model_type]
|
30 |
+
return inference_class(args, cfg, infer_type)
|
31 |
+
|
32 |
+
|
33 |
+
def prepare_for_audio_file(args, cfg, num_workers=1):
|
34 |
+
preprocess_path = cfg.preprocess.processed_dir
|
35 |
+
audio_name = cfg.inference.source_audio_name
|
36 |
+
temp_audio_dir = os.path.join(preprocess_path, audio_name)
|
37 |
+
|
38 |
+
### eval file
|
39 |
+
t = time.time()
|
40 |
+
eval_file = prepare_source_eval_file(cfg, temp_audio_dir, audio_name)
|
41 |
+
args.source = eval_file
|
42 |
+
with open(eval_file, "r") as f:
|
43 |
+
metadata = json.load(f)
|
44 |
+
print("Prepare for meta eval data: {:.1f}s".format(time.time() - t))
|
45 |
+
|
46 |
+
### acoustic features
|
47 |
+
t = time.time()
|
48 |
+
acoustic_extractor.extract_utt_acoustic_features_serial(
|
49 |
+
metadata, temp_audio_dir, cfg
|
50 |
+
)
|
51 |
+
acoustic_extractor.cal_mel_min_max(
|
52 |
+
dataset=audio_name, output_path=preprocess_path, cfg=cfg, metadata=metadata
|
53 |
+
)
|
54 |
+
acoustic_extractor.cal_pitch_statistics_svc(
|
55 |
+
dataset=audio_name, output_path=preprocess_path, cfg=cfg, metadata=metadata
|
56 |
+
)
|
57 |
+
print("Prepare for acoustic features: {:.1f}s".format(time.time() - t))
|
58 |
+
|
59 |
+
### content features
|
60 |
+
t = time.time()
|
61 |
+
content_extractor.extract_utt_content_features_dataloader(
|
62 |
+
cfg, metadata, num_workers
|
63 |
+
)
|
64 |
+
print("Prepare for content features: {:.1f}s".format(time.time() - t))
|
65 |
+
return args, cfg, temp_audio_dir
|
66 |
+
|
67 |
+
|
68 |
+
def merge_for_audio_segments(audio_files, args, cfg):
|
69 |
+
audio_name = cfg.inference.source_audio_name
|
70 |
+
target_singer_name = args.target_singer
|
71 |
+
|
72 |
+
merge_segments_encodec(
|
73 |
+
wav_files=audio_files,
|
74 |
+
fs=cfg.preprocess.sample_rate,
|
75 |
+
output_path=os.path.join(
|
76 |
+
args.output_dir, "{}_{}.wav".format(audio_name, target_singer_name)
|
77 |
+
),
|
78 |
+
overlap_duration=cfg.inference.segments_overlap_duration,
|
79 |
+
)
|
80 |
+
|
81 |
+
for tmp_file in audio_files:
|
82 |
+
os.remove(tmp_file)
|
83 |
+
|
84 |
+
|
85 |
+
def prepare_source_eval_file(cfg, temp_audio_dir, audio_name):
|
86 |
+
"""
|
87 |
+
Prepare the eval file (json) for an audio
|
88 |
+
"""
|
89 |
+
|
90 |
+
audio_chunks_results = split_audio(
|
91 |
+
wav_file=cfg.inference.source_audio_path,
|
92 |
+
target_sr=cfg.preprocess.sample_rate,
|
93 |
+
output_dir=os.path.join(temp_audio_dir, "wavs"),
|
94 |
+
max_duration_of_segment=cfg.inference.segments_max_duration,
|
95 |
+
overlap_duration=cfg.inference.segments_overlap_duration,
|
96 |
+
)
|
97 |
+
|
98 |
+
metadata = []
|
99 |
+
for i, res in enumerate(audio_chunks_results):
|
100 |
+
res["index"] = i
|
101 |
+
res["Dataset"] = audio_name
|
102 |
+
res["Singer"] = audio_name
|
103 |
+
res["Uid"] = "{}_{}".format(audio_name, res["Uid"])
|
104 |
+
metadata.append(res)
|
105 |
+
|
106 |
+
eval_file = os.path.join(temp_audio_dir, "eval.json")
|
107 |
+
with open(eval_file, "w") as f:
|
108 |
+
json.dump(metadata, f, indent=4, ensure_ascii=False, sort_keys=True)
|
109 |
+
|
110 |
+
return eval_file
|
111 |
+
|
112 |
+
|
113 |
+
def cuda_relevant(deterministic=False):
|
114 |
+
torch.cuda.empty_cache()
|
115 |
+
# TF32 on Ampere and above
|
116 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
117 |
+
torch.backends.cudnn.enabled = True
|
118 |
+
torch.backends.cudnn.allow_tf32 = True
|
119 |
+
# Deterministic
|
120 |
+
torch.backends.cudnn.deterministic = deterministic
|
121 |
+
torch.backends.cudnn.benchmark = not deterministic
|
122 |
+
torch.use_deterministic_algorithms(deterministic)
|
123 |
+
|
124 |
+
|
125 |
+
def infer(args, cfg, infer_type):
|
126 |
+
# Build inference
|
127 |
+
t = time.time()
|
128 |
+
trainer = build_inference(args, cfg, infer_type)
|
129 |
+
print("Model Init: {:.1f}s".format(time.time() - t))
|
130 |
+
|
131 |
+
# Run inference
|
132 |
+
t = time.time()
|
133 |
+
output_audio_files = trainer.inference()
|
134 |
+
print("Model inference: {:.1f}s".format(time.time() - t))
|
135 |
+
return output_audio_files
|
136 |
+
|
137 |
+
|
138 |
+
def build_parser():
|
139 |
+
r"""Build argument parser for inference.py.
|
140 |
+
Anything else should be put in an extra config YAML file.
|
141 |
+
"""
|
142 |
+
|
143 |
+
parser = argparse.ArgumentParser()
|
144 |
+
parser.add_argument(
|
145 |
+
"--config",
|
146 |
+
type=str,
|
147 |
+
required=True,
|
148 |
+
help="JSON/YAML file for configurations.",
|
149 |
+
)
|
150 |
+
parser.add_argument(
|
151 |
+
"--acoustics_dir",
|
152 |
+
type=str,
|
153 |
+
help="Acoustics model checkpoint directory. If a directory is given, "
|
154 |
+
"search for the latest checkpoint dir in the directory. If a specific "
|
155 |
+
"checkpoint dir is given, directly load the checkpoint.",
|
156 |
+
)
|
157 |
+
parser.add_argument(
|
158 |
+
"--vocoder_dir",
|
159 |
+
type=str,
|
160 |
+
required=True,
|
161 |
+
help="Vocoder checkpoint directory. Searching behavior is the same as "
|
162 |
+
"the acoustics one.",
|
163 |
+
)
|
164 |
+
parser.add_argument(
|
165 |
+
"--target_singer",
|
166 |
+
type=str,
|
167 |
+
required=True,
|
168 |
+
help="convert to a specific singer (e.g. --target_singers singer_id).",
|
169 |
+
)
|
170 |
+
parser.add_argument(
|
171 |
+
"--trans_key",
|
172 |
+
default=0,
|
173 |
+
help="0: no pitch shift; autoshift: pitch shift; int: key shift.",
|
174 |
+
)
|
175 |
+
parser.add_argument(
|
176 |
+
"--source",
|
177 |
+
type=str,
|
178 |
+
default="source_audio",
|
179 |
+
help="Source audio file or directory. If a JSON file is given, "
|
180 |
+
"inference from dataset is applied. If a directory is given, "
|
181 |
+
"inference from all wav/flac/mp3 audio files in the directory is applied. "
|
182 |
+
"Default: inference from all wav/flac/mp3 audio files in ./source_audio",
|
183 |
+
)
|
184 |
+
parser.add_argument(
|
185 |
+
"--output_dir",
|
186 |
+
type=str,
|
187 |
+
default="conversion_results",
|
188 |
+
help="Output directory. Default: ./conversion_results",
|
189 |
+
)
|
190 |
+
parser.add_argument(
|
191 |
+
"--log_level",
|
192 |
+
type=str,
|
193 |
+
default="warning",
|
194 |
+
help="Logging level. Default: warning",
|
195 |
+
)
|
196 |
+
parser.add_argument(
|
197 |
+
"--keep_cache",
|
198 |
+
action="store_true",
|
199 |
+
default=True,
|
200 |
+
help="Keep cache files. Only applicable to inference from files.",
|
201 |
+
)
|
202 |
+
parser.add_argument(
|
203 |
+
"--diffusion_inference_steps",
|
204 |
+
type=int,
|
205 |
+
default=1000,
|
206 |
+
help="Number of inference steps. Only applicable to diffusion inference.",
|
207 |
+
)
|
208 |
+
return parser
|
209 |
+
|
210 |
+
|
211 |
+
def main():
|
212 |
+
### Parse arguments and config
|
213 |
+
args = build_parser().parse_args()
|
214 |
+
cfg = load_config(args.config)
|
215 |
+
|
216 |
+
# CUDA settings
|
217 |
+
cuda_relevant()
|
218 |
+
|
219 |
+
if os.path.isdir(args.source):
|
220 |
+
### Infer from file
|
221 |
+
|
222 |
+
# Get all the source audio files (.wav, .flac, .mp3)
|
223 |
+
source_audio_dir = args.source
|
224 |
+
audio_list = []
|
225 |
+
for suffix in ["wav", "flac", "mp3"]:
|
226 |
+
audio_list += glob.glob(
|
227 |
+
os.path.join(source_audio_dir, "**/*.{}".format(suffix)), recursive=True
|
228 |
+
)
|
229 |
+
print("There are {} source audios: ".format(len(audio_list)))
|
230 |
+
|
231 |
+
# Infer for every file as dataset
|
232 |
+
output_root_path = args.output_dir
|
233 |
+
for audio_path in tqdm(audio_list):
|
234 |
+
audio_name = audio_path.split("/")[-1].split(".")[0]
|
235 |
+
args.output_dir = os.path.join(output_root_path, audio_name)
|
236 |
+
print("\n{}\nConversion for {}...\n".format("*" * 10, audio_name))
|
237 |
+
|
238 |
+
cfg.inference.source_audio_path = audio_path
|
239 |
+
cfg.inference.source_audio_name = audio_name
|
240 |
+
cfg.inference.segments_max_duration = 10.0
|
241 |
+
cfg.inference.segments_overlap_duration = 1.0
|
242 |
+
|
243 |
+
# Prepare metadata and features
|
244 |
+
args, cfg, cache_dir = prepare_for_audio_file(args, cfg)
|
245 |
+
|
246 |
+
# Infer from file
|
247 |
+
output_audio_files = infer(args, cfg, infer_type="from_file")
|
248 |
+
|
249 |
+
# Merge the split segments
|
250 |
+
merge_for_audio_segments(output_audio_files, args, cfg)
|
251 |
+
|
252 |
+
# Keep or remove caches
|
253 |
+
if not args.keep_cache:
|
254 |
+
os.removedirs(cache_dir)
|
255 |
+
|
256 |
+
else:
|
257 |
+
### Infer from dataset
|
258 |
+
infer(args, cfg, infer_type="from_dataset")
|
models/__init__.py
ADDED
File without changes
|
models/base/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .new_trainer import BaseTrainer
|
7 |
+
from .new_inference import BaseInference
|
models/base/base_dataset.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import numpy as np
|
8 |
+
import torch.utils.data
|
9 |
+
from torch.nn.utils.rnn import pad_sequence
|
10 |
+
from utils.data_utils import *
|
11 |
+
from processors.acoustic_extractor import cal_normalized_mel
|
12 |
+
from text import text_to_sequence
|
13 |
+
from text.text_token_collation import phoneIDCollation
|
14 |
+
|
15 |
+
|
16 |
+
class BaseDataset(torch.utils.data.Dataset):
|
17 |
+
def __init__(self, cfg, dataset, is_valid=False):
|
18 |
+
"""
|
19 |
+
Args:
|
20 |
+
cfg: config
|
21 |
+
dataset: dataset name
|
22 |
+
is_valid: whether to use train or valid dataset
|
23 |
+
"""
|
24 |
+
|
25 |
+
assert isinstance(dataset, str)
|
26 |
+
|
27 |
+
# self.data_root = processed_data_dir
|
28 |
+
self.cfg = cfg
|
29 |
+
|
30 |
+
processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset)
|
31 |
+
meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file
|
32 |
+
self.metafile_path = os.path.join(processed_data_dir, meta_file)
|
33 |
+
self.metadata = self.get_metadata()
|
34 |
+
|
35 |
+
|
36 |
+
|
37 |
+
'''
|
38 |
+
load spk2id and utt2spk from json file
|
39 |
+
spk2id: {spk1: 0, spk2: 1, ...}
|
40 |
+
utt2spk: {dataset_uid: spk1, ...}
|
41 |
+
'''
|
42 |
+
if cfg.preprocess.use_spkid:
|
43 |
+
spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id)
|
44 |
+
with open(spk2id_path, "r") as f:
|
45 |
+
self.spk2id = json.load(f)
|
46 |
+
|
47 |
+
utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk)
|
48 |
+
self.utt2spk = dict()
|
49 |
+
with open(utt2spk_path, "r") as f:
|
50 |
+
for line in f.readlines():
|
51 |
+
utt, spk = line.strip().split('\t')
|
52 |
+
self.utt2spk[utt] = spk
|
53 |
+
|
54 |
+
|
55 |
+
if cfg.preprocess.use_uv:
|
56 |
+
self.utt2uv_path = {}
|
57 |
+
for utt_info in self.metadata:
|
58 |
+
dataset = utt_info["Dataset"]
|
59 |
+
uid = utt_info["Uid"]
|
60 |
+
utt = "{}_{}".format(dataset, uid)
|
61 |
+
self.utt2uv_path[utt] = os.path.join(
|
62 |
+
cfg.preprocess.processed_dir,
|
63 |
+
dataset,
|
64 |
+
cfg.preprocess.uv_dir,
|
65 |
+
uid + ".npy",
|
66 |
+
)
|
67 |
+
|
68 |
+
if cfg.preprocess.use_frame_pitch:
|
69 |
+
self.utt2frame_pitch_path = {}
|
70 |
+
for utt_info in self.metadata:
|
71 |
+
dataset = utt_info["Dataset"]
|
72 |
+
uid = utt_info["Uid"]
|
73 |
+
utt = "{}_{}".format(dataset, uid)
|
74 |
+
|
75 |
+
self.utt2frame_pitch_path[utt] = os.path.join(
|
76 |
+
cfg.preprocess.processed_dir,
|
77 |
+
dataset,
|
78 |
+
cfg.preprocess.pitch_dir,
|
79 |
+
uid + ".npy",
|
80 |
+
)
|
81 |
+
|
82 |
+
if cfg.preprocess.use_frame_energy:
|
83 |
+
self.utt2frame_energy_path = {}
|
84 |
+
for utt_info in self.metadata:
|
85 |
+
dataset = utt_info["Dataset"]
|
86 |
+
uid = utt_info["Uid"]
|
87 |
+
utt = "{}_{}".format(dataset, uid)
|
88 |
+
|
89 |
+
self.utt2frame_energy_path[utt] = os.path.join(
|
90 |
+
cfg.preprocess.processed_dir,
|
91 |
+
dataset,
|
92 |
+
cfg.preprocess.energy_dir,
|
93 |
+
uid + ".npy",
|
94 |
+
)
|
95 |
+
|
96 |
+
if cfg.preprocess.use_mel:
|
97 |
+
self.utt2mel_path = {}
|
98 |
+
for utt_info in self.metadata:
|
99 |
+
dataset = utt_info["Dataset"]
|
100 |
+
uid = utt_info["Uid"]
|
101 |
+
utt = "{}_{}".format(dataset, uid)
|
102 |
+
|
103 |
+
self.utt2mel_path[utt] = os.path.join(
|
104 |
+
cfg.preprocess.processed_dir,
|
105 |
+
dataset,
|
106 |
+
cfg.preprocess.mel_dir,
|
107 |
+
uid + ".npy",
|
108 |
+
)
|
109 |
+
|
110 |
+
if cfg.preprocess.use_linear:
|
111 |
+
self.utt2linear_path = {}
|
112 |
+
for utt_info in self.metadata:
|
113 |
+
dataset = utt_info["Dataset"]
|
114 |
+
uid = utt_info["Uid"]
|
115 |
+
utt = "{}_{}".format(dataset, uid)
|
116 |
+
|
117 |
+
self.utt2linear_path[utt] = os.path.join(
|
118 |
+
cfg.preprocess.processed_dir,
|
119 |
+
dataset,
|
120 |
+
cfg.preprocess.linear_dir,
|
121 |
+
uid + ".npy",
|
122 |
+
)
|
123 |
+
|
124 |
+
if cfg.preprocess.use_audio:
|
125 |
+
self.utt2audio_path = {}
|
126 |
+
for utt_info in self.metadata:
|
127 |
+
dataset = utt_info["Dataset"]
|
128 |
+
uid = utt_info["Uid"]
|
129 |
+
utt = "{}_{}".format(dataset, uid)
|
130 |
+
|
131 |
+
self.utt2audio_path[utt] = os.path.join(
|
132 |
+
cfg.preprocess.processed_dir,
|
133 |
+
dataset,
|
134 |
+
cfg.preprocess.audio_dir,
|
135 |
+
uid + ".npy",
|
136 |
+
)
|
137 |
+
elif cfg.preprocess.use_label:
|
138 |
+
self.utt2label_path = {}
|
139 |
+
for utt_info in self.metadata:
|
140 |
+
dataset = utt_info["Dataset"]
|
141 |
+
uid = utt_info["Uid"]
|
142 |
+
utt = "{}_{}".format(dataset, uid)
|
143 |
+
|
144 |
+
self.utt2label_path[utt] = os.path.join(
|
145 |
+
cfg.preprocess.processed_dir,
|
146 |
+
dataset,
|
147 |
+
cfg.preprocess.label_dir,
|
148 |
+
uid + ".npy",
|
149 |
+
)
|
150 |
+
elif cfg.preprocess.use_one_hot:
|
151 |
+
self.utt2one_hot_path = {}
|
152 |
+
for utt_info in self.metadata:
|
153 |
+
dataset = utt_info["Dataset"]
|
154 |
+
uid = utt_info["Uid"]
|
155 |
+
utt = "{}_{}".format(dataset, uid)
|
156 |
+
|
157 |
+
self.utt2one_hot_path[utt] = os.path.join(
|
158 |
+
cfg.preprocess.processed_dir,
|
159 |
+
dataset,
|
160 |
+
cfg.preprocess.one_hot_dir,
|
161 |
+
uid + ".npy",
|
162 |
+
)
|
163 |
+
|
164 |
+
if cfg.preprocess.use_text or cfg.preprocess.use_phone:
|
165 |
+
self.utt2seq = {}
|
166 |
+
for utt_info in self.metadata:
|
167 |
+
dataset = utt_info["Dataset"]
|
168 |
+
uid = utt_info["Uid"]
|
169 |
+
utt = "{}_{}".format(dataset, uid)
|
170 |
+
|
171 |
+
if cfg.preprocess.use_text:
|
172 |
+
text = utt_info["Text"]
|
173 |
+
sequence = text_to_sequence(text, cfg.preprocess.text_cleaners)
|
174 |
+
elif cfg.preprocess.use_phone:
|
175 |
+
# load phoneme squence from phone file
|
176 |
+
phone_path = os.path.join(processed_data_dir,
|
177 |
+
cfg.preprocess.phone_dir,
|
178 |
+
uid+'.phone'
|
179 |
+
)
|
180 |
+
with open(phone_path, 'r') as fin:
|
181 |
+
phones = fin.readlines()
|
182 |
+
assert len(phones) == 1
|
183 |
+
phones = phones[0].strip()
|
184 |
+
phones_seq = phones.split(' ')
|
185 |
+
|
186 |
+
phon_id_collator = phoneIDCollation(cfg, dataset=dataset)
|
187 |
+
sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq)
|
188 |
+
|
189 |
+
self.utt2seq[utt] = sequence
|
190 |
+
|
191 |
+
|
192 |
+
def get_metadata(self):
|
193 |
+
with open(self.metafile_path, "r", encoding="utf-8") as f:
|
194 |
+
metadata = json.load(f)
|
195 |
+
|
196 |
+
return metadata
|
197 |
+
|
198 |
+
def get_dataset_name(self):
|
199 |
+
return self.metadata[0]["Dataset"]
|
200 |
+
|
201 |
+
def __getitem__(self, index):
|
202 |
+
utt_info = self.metadata[index]
|
203 |
+
|
204 |
+
dataset = utt_info["Dataset"]
|
205 |
+
uid = utt_info["Uid"]
|
206 |
+
utt = "{}_{}".format(dataset, uid)
|
207 |
+
|
208 |
+
single_feature = dict()
|
209 |
+
|
210 |
+
if self.cfg.preprocess.use_spkid:
|
211 |
+
single_feature["spk_id"] = np.array(
|
212 |
+
[self.spk2id[self.utt2spk[utt]]], dtype=np.int32
|
213 |
+
)
|
214 |
+
|
215 |
+
if self.cfg.preprocess.use_mel:
|
216 |
+
mel = np.load(self.utt2mel_path[utt])
|
217 |
+
assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
|
218 |
+
if self.cfg.preprocess.use_min_max_norm_mel:
|
219 |
+
# do mel norm
|
220 |
+
mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess)
|
221 |
+
|
222 |
+
if "target_len" not in single_feature.keys():
|
223 |
+
single_feature["target_len"] = mel.shape[1]
|
224 |
+
single_feature["mel"] = mel.T # [T, n_mels]
|
225 |
+
|
226 |
+
if self.cfg.preprocess.use_linear:
|
227 |
+
linear = np.load(self.utt2linear_path[utt])
|
228 |
+
if "target_len" not in single_feature.keys():
|
229 |
+
single_feature["target_len"] = linear.shape[1]
|
230 |
+
single_feature["linear"] = linear.T # [T, n_linear]
|
231 |
+
|
232 |
+
if self.cfg.preprocess.use_frame_pitch:
|
233 |
+
frame_pitch_path = self.utt2frame_pitch_path[utt]
|
234 |
+
frame_pitch = np.load(frame_pitch_path)
|
235 |
+
if "target_len" not in single_feature.keys():
|
236 |
+
single_feature["target_len"] = len(frame_pitch)
|
237 |
+
aligned_frame_pitch = align_length(
|
238 |
+
frame_pitch, single_feature["target_len"]
|
239 |
+
)
|
240 |
+
single_feature["frame_pitch"] = aligned_frame_pitch
|
241 |
+
|
242 |
+
if self.cfg.preprocess.use_uv:
|
243 |
+
frame_uv_path = self.utt2uv_path[utt]
|
244 |
+
frame_uv = np.load(frame_uv_path)
|
245 |
+
aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
|
246 |
+
aligned_frame_uv = [
|
247 |
+
0 if frame_uv else 1 for frame_uv in aligned_frame_uv
|
248 |
+
]
|
249 |
+
aligned_frame_uv = np.array(aligned_frame_uv)
|
250 |
+
single_feature["frame_uv"] = aligned_frame_uv
|
251 |
+
|
252 |
+
if self.cfg.preprocess.use_frame_energy:
|
253 |
+
frame_energy_path = self.utt2frame_energy_path[utt]
|
254 |
+
frame_energy = np.load(frame_energy_path)
|
255 |
+
if "target_len" not in single_feature.keys():
|
256 |
+
single_feature["target_len"] = len(frame_energy)
|
257 |
+
aligned_frame_energy = align_length(
|
258 |
+
frame_energy, single_feature["target_len"]
|
259 |
+
)
|
260 |
+
single_feature["frame_energy"] = aligned_frame_energy
|
261 |
+
|
262 |
+
if self.cfg.preprocess.use_audio:
|
263 |
+
audio = np.load(self.utt2audio_path[utt])
|
264 |
+
single_feature["audio"] = audio
|
265 |
+
single_feature["audio_len"] = audio.shape[0]
|
266 |
+
|
267 |
+
if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text:
|
268 |
+
single_feature["phone_seq"] = np.array(self.utt2seq[utt])
|
269 |
+
single_feature["phone_len"] = len(self.utt2seq[utt])
|
270 |
+
|
271 |
+
return single_feature
|
272 |
+
|
273 |
+
def __len__(self):
|
274 |
+
return len(self.metadata)
|
275 |
+
|
276 |
+
|
277 |
+
class BaseCollator(object):
|
278 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
279 |
+
|
280 |
+
def __init__(self, cfg):
|
281 |
+
self.cfg = cfg
|
282 |
+
|
283 |
+
def __call__(self, batch):
|
284 |
+
packed_batch_features = dict()
|
285 |
+
|
286 |
+
# mel: [b, T, n_mels]
|
287 |
+
# frame_pitch, frame_energy: [1, T]
|
288 |
+
# target_len: [1]
|
289 |
+
# spk_id: [b, 1]
|
290 |
+
# mask: [b, T, 1]
|
291 |
+
|
292 |
+
for key in batch[0].keys():
|
293 |
+
if key == "target_len":
|
294 |
+
packed_batch_features["target_len"] = torch.LongTensor(
|
295 |
+
[b["target_len"] for b in batch]
|
296 |
+
)
|
297 |
+
masks = [
|
298 |
+
torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
|
299 |
+
]
|
300 |
+
packed_batch_features["mask"] = pad_sequence(
|
301 |
+
masks, batch_first=True, padding_value=0
|
302 |
+
)
|
303 |
+
elif key == "phone_len":
|
304 |
+
packed_batch_features["phone_len"] = torch.LongTensor(
|
305 |
+
[b["phone_len"] for b in batch]
|
306 |
+
)
|
307 |
+
masks = [
|
308 |
+
torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch
|
309 |
+
]
|
310 |
+
packed_batch_features["phn_mask"] = pad_sequence(
|
311 |
+
masks, batch_first=True, padding_value=0
|
312 |
+
)
|
313 |
+
elif key == "audio_len":
|
314 |
+
packed_batch_features["audio_len"] = torch.LongTensor(
|
315 |
+
[b["audio_len"] for b in batch]
|
316 |
+
)
|
317 |
+
masks = [
|
318 |
+
torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch
|
319 |
+
]
|
320 |
+
else:
|
321 |
+
values = [torch.from_numpy(b[key]) for b in batch]
|
322 |
+
packed_batch_features[key] = pad_sequence(
|
323 |
+
values, batch_first=True, padding_value=0
|
324 |
+
)
|
325 |
+
return packed_batch_features
|
326 |
+
|
327 |
+
|
328 |
+
class BaseTestDataset(torch.utils.data.Dataset):
|
329 |
+
def __init__(self, cfg, args):
|
330 |
+
raise NotImplementedError
|
331 |
+
|
332 |
+
|
333 |
+
def get_metadata(self):
|
334 |
+
raise NotImplementedError
|
335 |
+
|
336 |
+
def __getitem__(self, index):
|
337 |
+
raise NotImplementedError
|
338 |
+
|
339 |
+
def __len__(self):
|
340 |
+
return len(self.metadata)
|
341 |
+
|
342 |
+
|
343 |
+
class BaseTestCollator(object):
|
344 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
345 |
+
|
346 |
+
def __init__(self, cfg):
|
347 |
+
raise NotImplementedError
|
348 |
+
|
349 |
+
def __call__(self, batch):
|
350 |
+
raise NotImplementedError
|
models/base/base_inference.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import argparse
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import time
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import torch
|
13 |
+
from torch.utils.data import DataLoader
|
14 |
+
from tqdm import tqdm
|
15 |
+
|
16 |
+
from models.vocoders.vocoder_inference import synthesis
|
17 |
+
from torch.utils.data import DataLoader
|
18 |
+
from utils.util import set_all_random_seed
|
19 |
+
from utils.util import load_config
|
20 |
+
|
21 |
+
|
22 |
+
def parse_vocoder(vocoder_dir):
|
23 |
+
r"""Parse vocoder config"""
|
24 |
+
vocoder_dir = os.path.abspath(vocoder_dir)
|
25 |
+
ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
|
26 |
+
ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
|
27 |
+
ckpt_path = str(ckpt_list[0])
|
28 |
+
vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True)
|
29 |
+
vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder
|
30 |
+
return vocoder_cfg, ckpt_path
|
31 |
+
|
32 |
+
|
33 |
+
class BaseInference(object):
|
34 |
+
def __init__(self, cfg, args):
|
35 |
+
self.cfg = cfg
|
36 |
+
self.args = args
|
37 |
+
self.model_type = cfg.model_type
|
38 |
+
self.avg_rtf = list()
|
39 |
+
set_all_random_seed(10086)
|
40 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
41 |
+
|
42 |
+
if torch.cuda.is_available():
|
43 |
+
self.device = torch.device("cuda")
|
44 |
+
else:
|
45 |
+
self.device = torch.device("cpu")
|
46 |
+
torch.set_num_threads(10) # inference on 1 core cpu.
|
47 |
+
|
48 |
+
# Load acoustic model
|
49 |
+
self.model = self.create_model().to(self.device)
|
50 |
+
state_dict = self.load_state_dict()
|
51 |
+
self.load_model(state_dict)
|
52 |
+
self.model.eval()
|
53 |
+
|
54 |
+
# Load vocoder model if necessary
|
55 |
+
if self.args.checkpoint_dir_vocoder is not None:
|
56 |
+
self.get_vocoder_info()
|
57 |
+
|
58 |
+
def create_model(self):
|
59 |
+
raise NotImplementedError
|
60 |
+
|
61 |
+
def load_state_dict(self):
|
62 |
+
self.checkpoint_file = self.args.checkpoint_file
|
63 |
+
if self.checkpoint_file is None:
|
64 |
+
assert self.args.checkpoint_dir is not None
|
65 |
+
checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint")
|
66 |
+
checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
|
67 |
+
self.checkpoint_file = os.path.join(
|
68 |
+
self.args.checkpoint_dir, checkpoint_filename
|
69 |
+
)
|
70 |
+
|
71 |
+
self.checkpoint_dir = os.path.split(self.checkpoint_file)[0]
|
72 |
+
|
73 |
+
print("Restore acoustic model from {}".format(self.checkpoint_file))
|
74 |
+
raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device)
|
75 |
+
self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0]
|
76 |
+
|
77 |
+
return raw_state_dict
|
78 |
+
|
79 |
+
def load_model(self, model):
|
80 |
+
raise NotImplementedError
|
81 |
+
|
82 |
+
def get_vocoder_info(self):
|
83 |
+
self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder
|
84 |
+
self.vocoder_cfg = os.path.join(
|
85 |
+
os.path.dirname(self.checkpoint_dir_vocoder), "args.json"
|
86 |
+
)
|
87 |
+
self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True)
|
88 |
+
self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1]
|
89 |
+
self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0]
|
90 |
+
|
91 |
+
def build_test_utt_data(self):
|
92 |
+
raise NotImplementedError
|
93 |
+
|
94 |
+
def build_testdata_loader(self, args, target_speaker=None):
|
95 |
+
datasets, collate = self.build_test_dataset()
|
96 |
+
self.test_dataset = datasets(self.cfg, args, target_speaker)
|
97 |
+
self.test_collate = collate(self.cfg)
|
98 |
+
self.test_batch_size = min(
|
99 |
+
self.cfg.train.batch_size, len(self.test_dataset.metadata)
|
100 |
+
)
|
101 |
+
test_loader = DataLoader(
|
102 |
+
self.test_dataset,
|
103 |
+
collate_fn=self.test_collate,
|
104 |
+
num_workers=self.args.num_workers,
|
105 |
+
batch_size=self.test_batch_size,
|
106 |
+
shuffle=False,
|
107 |
+
)
|
108 |
+
return test_loader
|
109 |
+
|
110 |
+
def inference_each_batch(self, batch_data):
|
111 |
+
raise NotImplementedError
|
112 |
+
|
113 |
+
def inference_for_batches(self, args, target_speaker=None):
|
114 |
+
###### Construct test_batch ######
|
115 |
+
loader = self.build_testdata_loader(args, target_speaker)
|
116 |
+
|
117 |
+
n_batch = len(loader)
|
118 |
+
now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
|
119 |
+
print(
|
120 |
+
"Model eval time: {}, batch_size = {}, n_batch = {}".format(
|
121 |
+
now, self.test_batch_size, n_batch
|
122 |
+
)
|
123 |
+
)
|
124 |
+
self.model.eval()
|
125 |
+
|
126 |
+
###### Inference for each batch ######
|
127 |
+
pred_res = []
|
128 |
+
with torch.no_grad():
|
129 |
+
for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)):
|
130 |
+
# Put the data to device
|
131 |
+
for k, v in batch_data.items():
|
132 |
+
batch_data[k] = batch_data[k].to(self.device)
|
133 |
+
|
134 |
+
y_pred, stats = self.inference_each_batch(batch_data)
|
135 |
+
|
136 |
+
pred_res += y_pred
|
137 |
+
|
138 |
+
return pred_res
|
139 |
+
|
140 |
+
def inference(self, feature):
|
141 |
+
raise NotImplementedError
|
142 |
+
|
143 |
+
def synthesis_by_vocoder(self, pred):
|
144 |
+
audios_pred = synthesis(
|
145 |
+
self.vocoder_cfg,
|
146 |
+
self.checkpoint_dir_vocoder,
|
147 |
+
len(pred),
|
148 |
+
pred,
|
149 |
+
)
|
150 |
+
return audios_pred
|
151 |
+
|
152 |
+
def __call__(self, utt):
|
153 |
+
feature = self.build_test_utt_data(utt)
|
154 |
+
start_time = time.time()
|
155 |
+
with torch.no_grad():
|
156 |
+
outputs = self.inference(feature)[0]
|
157 |
+
time_used = time.time() - start_time
|
158 |
+
rtf = time_used / (
|
159 |
+
outputs.shape[1]
|
160 |
+
* self.cfg.preprocess.hop_size
|
161 |
+
/ self.cfg.preprocess.sample_rate
|
162 |
+
)
|
163 |
+
print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf))
|
164 |
+
self.avg_rtf.append(rtf)
|
165 |
+
audios = outputs.cpu().squeeze().numpy().reshape(-1, 1)
|
166 |
+
return audios
|
167 |
+
|
168 |
+
|
169 |
+
def base_parser():
|
170 |
+
parser = argparse.ArgumentParser()
|
171 |
+
parser.add_argument(
|
172 |
+
"--config", default="config.json", help="json files for configurations."
|
173 |
+
)
|
174 |
+
parser.add_argument("--use_ddp_inference", default=False)
|
175 |
+
parser.add_argument("--n_workers", default=1, type=int)
|
176 |
+
parser.add_argument("--local_rank", default=-1, type=int)
|
177 |
+
parser.add_argument(
|
178 |
+
"--batch_size", default=1, type=int, help="Batch size for inference"
|
179 |
+
)
|
180 |
+
parser.add_argument(
|
181 |
+
"--num_workers",
|
182 |
+
default=1,
|
183 |
+
type=int,
|
184 |
+
help="Worker number for inference dataloader",
|
185 |
+
)
|
186 |
+
parser.add_argument(
|
187 |
+
"--checkpoint_dir",
|
188 |
+
type=str,
|
189 |
+
default=None,
|
190 |
+
help="Checkpoint dir including model file and configuration",
|
191 |
+
)
|
192 |
+
parser.add_argument(
|
193 |
+
"--checkpoint_file", help="checkpoint file", type=str, default=None
|
194 |
+
)
|
195 |
+
parser.add_argument(
|
196 |
+
"--test_list", help="test utterance list for testing", type=str, default=None
|
197 |
+
)
|
198 |
+
parser.add_argument(
|
199 |
+
"--checkpoint_dir_vocoder",
|
200 |
+
help="Vocoder's checkpoint dir including model file and configuration",
|
201 |
+
type=str,
|
202 |
+
default=None,
|
203 |
+
)
|
204 |
+
parser.add_argument(
|
205 |
+
"--output_dir",
|
206 |
+
type=str,
|
207 |
+
default=None,
|
208 |
+
help="Output dir for saving generated results",
|
209 |
+
)
|
210 |
+
return parser
|
211 |
+
|
212 |
+
|
213 |
+
if __name__ == "__main__":
|
214 |
+
parser = base_parser()
|
215 |
+
args = parser.parse_args()
|
216 |
+
cfg = load_config(args.config)
|
217 |
+
|
218 |
+
# Build inference
|
219 |
+
inference = BaseInference(cfg, args)
|
220 |
+
inference()
|
models/base/base_sampler.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import random
|
8 |
+
|
9 |
+
from torch.utils.data import ConcatDataset, Dataset
|
10 |
+
from torch.utils.data.sampler import (
|
11 |
+
BatchSampler,
|
12 |
+
RandomSampler,
|
13 |
+
Sampler,
|
14 |
+
SequentialSampler,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
class ScheduledSampler(Sampler):
|
19 |
+
"""A sampler that samples data from a given concat-dataset.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
|
23 |
+
batch_size (int): batch size
|
24 |
+
holistic_shuffle (bool): whether to shuffle the whole dataset or not
|
25 |
+
logger (logging.Logger): logger to print warning message
|
26 |
+
|
27 |
+
Usage:
|
28 |
+
For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
|
29 |
+
>>> list(ScheduledSampler(ConcatDataset([0, 1, 2], [3, 4, 5], [6, 7, 8]])))
|
30 |
+
[3, 4, 5, 0, 1, 2, 6, 7, 8]
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
concat_dataset,
|
36 |
+
batch_size,
|
37 |
+
holistic_shuffle,
|
38 |
+
logger=None,
|
39 |
+
loader_type="train",
|
40 |
+
):
|
41 |
+
if not isinstance(concat_dataset, ConcatDataset):
|
42 |
+
raise ValueError(
|
43 |
+
"concat_dataset must be an instance of ConcatDataset, but got {}".format(
|
44 |
+
type(concat_dataset)
|
45 |
+
)
|
46 |
+
)
|
47 |
+
if not isinstance(batch_size, int):
|
48 |
+
raise ValueError(
|
49 |
+
"batch_size must be an integer, but got {}".format(type(batch_size))
|
50 |
+
)
|
51 |
+
if not isinstance(holistic_shuffle, bool):
|
52 |
+
raise ValueError(
|
53 |
+
"holistic_shuffle must be a boolean, but got {}".format(
|
54 |
+
type(holistic_shuffle)
|
55 |
+
)
|
56 |
+
)
|
57 |
+
|
58 |
+
self.concat_dataset = concat_dataset
|
59 |
+
self.batch_size = batch_size
|
60 |
+
self.holistic_shuffle = holistic_shuffle
|
61 |
+
|
62 |
+
affected_dataset_name = []
|
63 |
+
affected_dataset_len = []
|
64 |
+
for dataset in concat_dataset.datasets:
|
65 |
+
dataset_len = len(dataset)
|
66 |
+
dataset_name = dataset.get_dataset_name()
|
67 |
+
if dataset_len < batch_size:
|
68 |
+
affected_dataset_name.append(dataset_name)
|
69 |
+
affected_dataset_len.append(dataset_len)
|
70 |
+
|
71 |
+
self.type = loader_type
|
72 |
+
for dataset_name, dataset_len in zip(
|
73 |
+
affected_dataset_name, affected_dataset_len
|
74 |
+
):
|
75 |
+
if not loader_type == "valid":
|
76 |
+
logger.warning(
|
77 |
+
"The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format(
|
78 |
+
loader_type, dataset_name, dataset_len, batch_size
|
79 |
+
)
|
80 |
+
)
|
81 |
+
|
82 |
+
def __len__(self):
|
83 |
+
# the number of batches with drop last
|
84 |
+
num_of_batches = sum(
|
85 |
+
[
|
86 |
+
math.floor(len(dataset) / self.batch_size)
|
87 |
+
for dataset in self.concat_dataset.datasets
|
88 |
+
]
|
89 |
+
)
|
90 |
+
# if samples are not enough for one batch, we don't drop last
|
91 |
+
if self.type == "valid" and num_of_batches < 1:
|
92 |
+
return len(self.concat_dataset)
|
93 |
+
return num_of_batches * self.batch_size
|
94 |
+
|
95 |
+
def __iter__(self):
|
96 |
+
iters = []
|
97 |
+
for dataset in self.concat_dataset.datasets:
|
98 |
+
iters.append(
|
99 |
+
SequentialSampler(dataset).__iter__()
|
100 |
+
if not self.holistic_shuffle
|
101 |
+
else RandomSampler(dataset).__iter__()
|
102 |
+
)
|
103 |
+
# e.g. [0, 200, 400]
|
104 |
+
init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1]
|
105 |
+
output_batches = []
|
106 |
+
for dataset_idx in range(len(self.concat_dataset.datasets)):
|
107 |
+
cur_batch = []
|
108 |
+
for idx in iters[dataset_idx]:
|
109 |
+
cur_batch.append(idx + init_indices[dataset_idx])
|
110 |
+
if len(cur_batch) == self.batch_size:
|
111 |
+
output_batches.append(cur_batch)
|
112 |
+
cur_batch = []
|
113 |
+
# if loader_type is valid, we don't need to drop last
|
114 |
+
if self.type == "valid" and len(cur_batch) > 0:
|
115 |
+
output_batches.append(cur_batch)
|
116 |
+
|
117 |
+
# force drop last in training
|
118 |
+
random.shuffle(output_batches)
|
119 |
+
output_indices = [item for sublist in output_batches for item in sublist]
|
120 |
+
return iter(output_indices)
|
121 |
+
|
122 |
+
|
123 |
+
def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type):
|
124 |
+
sampler = ScheduledSampler(
|
125 |
+
concat_dataset,
|
126 |
+
cfg.train.batch_size,
|
127 |
+
cfg.train.sampler.holistic_shuffle,
|
128 |
+
logger,
|
129 |
+
loader_type,
|
130 |
+
)
|
131 |
+
batch_sampler = BatchSampler(
|
132 |
+
sampler,
|
133 |
+
cfg.train.batch_size,
|
134 |
+
cfg.train.sampler.drop_last if not loader_type == "valid" else False,
|
135 |
+
)
|
136 |
+
return sampler, batch_sampler
|
models/base/base_trainer.py
ADDED
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import collections
|
7 |
+
import json
|
8 |
+
import os
|
9 |
+
import sys
|
10 |
+
import time
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import torch.distributed as dist
|
14 |
+
from torch.nn.parallel import DistributedDataParallel
|
15 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
16 |
+
from torch.utils.tensorboard import SummaryWriter
|
17 |
+
|
18 |
+
from models.base.base_sampler import BatchSampler
|
19 |
+
from utils.util import (
|
20 |
+
Logger,
|
21 |
+
remove_older_ckpt,
|
22 |
+
save_config,
|
23 |
+
set_all_random_seed,
|
24 |
+
ValueWindow,
|
25 |
+
)
|
26 |
+
|
27 |
+
|
28 |
+
class BaseTrainer(object):
|
29 |
+
def __init__(self, args, cfg):
|
30 |
+
self.args = args
|
31 |
+
self.log_dir = args.log_dir
|
32 |
+
self.cfg = cfg
|
33 |
+
|
34 |
+
self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints")
|
35 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
36 |
+
if not cfg.train.ddp or args.local_rank == 0:
|
37 |
+
self.sw = SummaryWriter(os.path.join(args.log_dir, "events"))
|
38 |
+
self.logger = self.build_logger()
|
39 |
+
self.time_window = ValueWindow(50)
|
40 |
+
|
41 |
+
self.step = 0
|
42 |
+
self.epoch = -1
|
43 |
+
self.max_epochs = self.cfg.train.epochs
|
44 |
+
self.max_steps = self.cfg.train.max_steps
|
45 |
+
|
46 |
+
# set random seed & init distributed training
|
47 |
+
set_all_random_seed(self.cfg.train.random_seed)
|
48 |
+
if cfg.train.ddp:
|
49 |
+
dist.init_process_group(backend="nccl")
|
50 |
+
|
51 |
+
if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]:
|
52 |
+
self.singers = self.build_singers_lut()
|
53 |
+
|
54 |
+
# setup data_loader
|
55 |
+
self.data_loader = self.build_data_loader()
|
56 |
+
|
57 |
+
# setup model & enable distributed training
|
58 |
+
self.model = self.build_model()
|
59 |
+
print(self.model)
|
60 |
+
|
61 |
+
if isinstance(self.model, dict):
|
62 |
+
for key, value in self.model.items():
|
63 |
+
value.cuda(self.args.local_rank)
|
64 |
+
if key == "PQMF":
|
65 |
+
continue
|
66 |
+
if cfg.train.ddp:
|
67 |
+
self.model[key] = DistributedDataParallel(
|
68 |
+
value, device_ids=[self.args.local_rank]
|
69 |
+
)
|
70 |
+
else:
|
71 |
+
self.model.cuda(self.args.local_rank)
|
72 |
+
if cfg.train.ddp:
|
73 |
+
self.model = DistributedDataParallel(
|
74 |
+
self.model, device_ids=[self.args.local_rank]
|
75 |
+
)
|
76 |
+
|
77 |
+
# create criterion
|
78 |
+
self.criterion = self.build_criterion()
|
79 |
+
if isinstance(self.criterion, dict):
|
80 |
+
for key, value in self.criterion.items():
|
81 |
+
self.criterion[key].cuda(args.local_rank)
|
82 |
+
else:
|
83 |
+
self.criterion.cuda(self.args.local_rank)
|
84 |
+
|
85 |
+
# optimizer
|
86 |
+
self.optimizer = self.build_optimizer()
|
87 |
+
self.scheduler = self.build_scheduler()
|
88 |
+
|
89 |
+
# save config file
|
90 |
+
self.config_save_path = os.path.join(self.checkpoint_dir, "args.json")
|
91 |
+
|
92 |
+
def build_logger(self):
|
93 |
+
log_file = os.path.join(self.checkpoint_dir, "train.log")
|
94 |
+
logger = Logger(log_file, level=self.args.log_level).logger
|
95 |
+
|
96 |
+
return logger
|
97 |
+
|
98 |
+
def build_dataset(self):
|
99 |
+
raise NotImplementedError
|
100 |
+
|
101 |
+
def build_data_loader(self):
|
102 |
+
Dataset, Collator = self.build_dataset()
|
103 |
+
# build dataset instance for each dataset and combine them by ConcatDataset
|
104 |
+
datasets_list = []
|
105 |
+
for dataset in self.cfg.dataset:
|
106 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=False)
|
107 |
+
datasets_list.append(subdataset)
|
108 |
+
train_dataset = ConcatDataset(datasets_list)
|
109 |
+
|
110 |
+
train_collate = Collator(self.cfg)
|
111 |
+
# TODO: multi-GPU training
|
112 |
+
if self.cfg.train.ddp:
|
113 |
+
raise NotImplementedError("DDP is not supported yet.")
|
114 |
+
|
115 |
+
# sampler will provide indices to batch_sampler, which will perform batching and yield batch indices
|
116 |
+
batch_sampler = BatchSampler(
|
117 |
+
cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list
|
118 |
+
)
|
119 |
+
|
120 |
+
# use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size)
|
121 |
+
train_loader = DataLoader(
|
122 |
+
train_dataset,
|
123 |
+
collate_fn=train_collate,
|
124 |
+
num_workers=self.args.num_workers,
|
125 |
+
batch_sampler=batch_sampler,
|
126 |
+
pin_memory=False,
|
127 |
+
)
|
128 |
+
if not self.cfg.train.ddp or self.args.local_rank == 0:
|
129 |
+
datasets_list = []
|
130 |
+
for dataset in self.cfg.dataset:
|
131 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=True)
|
132 |
+
datasets_list.append(subdataset)
|
133 |
+
valid_dataset = ConcatDataset(datasets_list)
|
134 |
+
valid_collate = Collator(self.cfg)
|
135 |
+
batch_sampler = BatchSampler(
|
136 |
+
cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list
|
137 |
+
)
|
138 |
+
valid_loader = DataLoader(
|
139 |
+
valid_dataset,
|
140 |
+
collate_fn=valid_collate,
|
141 |
+
num_workers=1,
|
142 |
+
batch_sampler=batch_sampler,
|
143 |
+
)
|
144 |
+
else:
|
145 |
+
raise NotImplementedError("DDP is not supported yet.")
|
146 |
+
# valid_loader = None
|
147 |
+
data_loader = {"train": train_loader, "valid": valid_loader}
|
148 |
+
return data_loader
|
149 |
+
|
150 |
+
def build_singers_lut(self):
|
151 |
+
# combine singers
|
152 |
+
if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)):
|
153 |
+
singers = collections.OrderedDict()
|
154 |
+
else:
|
155 |
+
with open(
|
156 |
+
os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r"
|
157 |
+
) as singer_file:
|
158 |
+
singers = json.load(singer_file)
|
159 |
+
singer_count = len(singers)
|
160 |
+
for dataset in self.cfg.dataset:
|
161 |
+
singer_lut_path = os.path.join(
|
162 |
+
self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
|
163 |
+
)
|
164 |
+
with open(singer_lut_path, "r") as singer_lut_path:
|
165 |
+
singer_lut = json.load(singer_lut_path)
|
166 |
+
for singer in singer_lut.keys():
|
167 |
+
if singer not in singers:
|
168 |
+
singers[singer] = singer_count
|
169 |
+
singer_count += 1
|
170 |
+
with open(
|
171 |
+
os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w"
|
172 |
+
) as singer_file:
|
173 |
+
json.dump(singers, singer_file, indent=4, ensure_ascii=False)
|
174 |
+
print(
|
175 |
+
"singers have been dumped to {}".format(
|
176 |
+
os.path.join(self.log_dir, self.cfg.preprocess.spk2id)
|
177 |
+
)
|
178 |
+
)
|
179 |
+
return singers
|
180 |
+
|
181 |
+
def build_model(self):
|
182 |
+
raise NotImplementedError()
|
183 |
+
|
184 |
+
def build_optimizer(self):
|
185 |
+
raise NotImplementedError
|
186 |
+
|
187 |
+
def build_scheduler(self):
|
188 |
+
raise NotImplementedError()
|
189 |
+
|
190 |
+
def build_criterion(self):
|
191 |
+
raise NotImplementedError
|
192 |
+
|
193 |
+
def get_state_dict(self):
|
194 |
+
raise NotImplementedError
|
195 |
+
|
196 |
+
def save_config_file(self):
|
197 |
+
save_config(self.config_save_path, self.cfg)
|
198 |
+
|
199 |
+
# TODO, save without module.
|
200 |
+
def save_checkpoint(self, state_dict, saved_model_path):
|
201 |
+
torch.save(state_dict, saved_model_path)
|
202 |
+
|
203 |
+
def load_checkpoint(self):
|
204 |
+
checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint")
|
205 |
+
assert os.path.exists(checkpoint_path)
|
206 |
+
checkpoint_filename = open(checkpoint_path).readlines()[-1].strip()
|
207 |
+
model_path = os.path.join(self.checkpoint_dir, checkpoint_filename)
|
208 |
+
assert os.path.exists(model_path)
|
209 |
+
if not self.cfg.train.ddp or self.args.local_rank == 0:
|
210 |
+
self.logger.info(f"Re(store) from {model_path}")
|
211 |
+
checkpoint = torch.load(model_path, map_location="cpu")
|
212 |
+
return checkpoint
|
213 |
+
|
214 |
+
def load_model(self, checkpoint):
|
215 |
+
raise NotImplementedError
|
216 |
+
|
217 |
+
def restore(self):
|
218 |
+
checkpoint = self.load_checkpoint()
|
219 |
+
self.load_model(checkpoint)
|
220 |
+
|
221 |
+
def train_step(self, data):
|
222 |
+
raise NotImplementedError(
|
223 |
+
f"Need to implement function {sys._getframe().f_code.co_name} in "
|
224 |
+
f"your sub-class of {self.__class__.__name__}. "
|
225 |
+
)
|
226 |
+
|
227 |
+
@torch.no_grad()
|
228 |
+
def eval_step(self):
|
229 |
+
raise NotImplementedError(
|
230 |
+
f"Need to implement function {sys._getframe().f_code.co_name} in "
|
231 |
+
f"your sub-class of {self.__class__.__name__}. "
|
232 |
+
)
|
233 |
+
|
234 |
+
def write_summary(self, losses, stats):
|
235 |
+
raise NotImplementedError(
|
236 |
+
f"Need to implement function {sys._getframe().f_code.co_name} in "
|
237 |
+
f"your sub-class of {self.__class__.__name__}. "
|
238 |
+
)
|
239 |
+
|
240 |
+
def write_valid_summary(self, losses, stats):
|
241 |
+
raise NotImplementedError(
|
242 |
+
f"Need to implement function {sys._getframe().f_code.co_name} in "
|
243 |
+
f"your sub-class of {self.__class__.__name__}. "
|
244 |
+
)
|
245 |
+
|
246 |
+
def echo_log(self, losses, mode="Training"):
|
247 |
+
message = [
|
248 |
+
"{} - Epoch {} Step {}: [{:.3f} s/step]".format(
|
249 |
+
mode, self.epoch + 1, self.step, self.time_window.average
|
250 |
+
)
|
251 |
+
]
|
252 |
+
|
253 |
+
for key in sorted(losses.keys()):
|
254 |
+
if isinstance(losses[key], dict):
|
255 |
+
for k, v in losses[key].items():
|
256 |
+
message.append(
|
257 |
+
str(k).split("/")[-1] + "=" + str(round(float(v), 5))
|
258 |
+
)
|
259 |
+
else:
|
260 |
+
message.append(
|
261 |
+
str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5))
|
262 |
+
)
|
263 |
+
self.logger.info(", ".join(message))
|
264 |
+
|
265 |
+
def eval_epoch(self):
|
266 |
+
self.logger.info("Validation...")
|
267 |
+
valid_losses = {}
|
268 |
+
for i, batch_data in enumerate(self.data_loader["valid"]):
|
269 |
+
for k, v in batch_data.items():
|
270 |
+
if isinstance(v, torch.Tensor):
|
271 |
+
batch_data[k] = v.cuda()
|
272 |
+
valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i)
|
273 |
+
for key in valid_loss:
|
274 |
+
if key not in valid_losses:
|
275 |
+
valid_losses[key] = 0
|
276 |
+
valid_losses[key] += valid_loss[key]
|
277 |
+
|
278 |
+
# Add mel and audio to the Tensorboard
|
279 |
+
# Average loss
|
280 |
+
for key in valid_losses:
|
281 |
+
valid_losses[key] /= i + 1
|
282 |
+
self.echo_log(valid_losses, "Valid")
|
283 |
+
return valid_losses, valid_stats
|
284 |
+
|
285 |
+
def train_epoch(self):
|
286 |
+
for i, batch_data in enumerate(self.data_loader["train"]):
|
287 |
+
start_time = time.time()
|
288 |
+
# Put the data to cuda device
|
289 |
+
for k, v in batch_data.items():
|
290 |
+
if isinstance(v, torch.Tensor):
|
291 |
+
batch_data[k] = v.cuda(self.args.local_rank)
|
292 |
+
|
293 |
+
# Training step
|
294 |
+
train_losses, train_stats, total_loss = self.train_step(batch_data)
|
295 |
+
self.time_window.append(time.time() - start_time)
|
296 |
+
|
297 |
+
if self.args.local_rank == 0 or not self.cfg.train.ddp:
|
298 |
+
if self.step % self.args.stdout_interval == 0:
|
299 |
+
self.echo_log(train_losses, "Training")
|
300 |
+
|
301 |
+
if self.step % self.cfg.train.save_summary_steps == 0:
|
302 |
+
self.logger.info(f"Save summary as step {self.step}")
|
303 |
+
self.write_summary(train_losses, train_stats)
|
304 |
+
|
305 |
+
if (
|
306 |
+
self.step % self.cfg.train.save_checkpoints_steps == 0
|
307 |
+
and self.step != 0
|
308 |
+
):
|
309 |
+
saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format(
|
310 |
+
self.step, total_loss
|
311 |
+
)
|
312 |
+
saved_model_path = os.path.join(
|
313 |
+
self.checkpoint_dir, saved_model_name
|
314 |
+
)
|
315 |
+
saved_state_dict = self.get_state_dict()
|
316 |
+
self.save_checkpoint(saved_state_dict, saved_model_path)
|
317 |
+
self.save_config_file()
|
318 |
+
# keep max n models
|
319 |
+
remove_older_ckpt(
|
320 |
+
saved_model_name,
|
321 |
+
self.checkpoint_dir,
|
322 |
+
max_to_keep=self.cfg.train.keep_checkpoint_max,
|
323 |
+
)
|
324 |
+
|
325 |
+
if self.step != 0 and self.step % self.cfg.train.valid_interval == 0:
|
326 |
+
if isinstance(self.model, dict):
|
327 |
+
for key in self.model.keys():
|
328 |
+
self.model[key].eval()
|
329 |
+
else:
|
330 |
+
self.model.eval()
|
331 |
+
# Evaluate one epoch and get average loss
|
332 |
+
valid_losses, valid_stats = self.eval_epoch()
|
333 |
+
if isinstance(self.model, dict):
|
334 |
+
for key in self.model.keys():
|
335 |
+
self.model[key].train()
|
336 |
+
else:
|
337 |
+
self.model.train()
|
338 |
+
# Write validation losses to summary.
|
339 |
+
self.write_valid_summary(valid_losses, valid_stats)
|
340 |
+
self.step += 1
|
341 |
+
|
342 |
+
def train(self):
|
343 |
+
for epoch in range(max(0, self.epoch), self.max_epochs):
|
344 |
+
self.train_epoch()
|
345 |
+
self.epoch += 1
|
346 |
+
if self.step > self.max_steps:
|
347 |
+
self.logger.info("Training finished!")
|
348 |
+
break
|
models/base/new_dataset.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
from abc import abstractmethod
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
import json5
|
12 |
+
import torch
|
13 |
+
import yaml
|
14 |
+
|
15 |
+
|
16 |
+
# TODO: for training and validating
|
17 |
+
class BaseDataset(torch.utils.data.Dataset):
|
18 |
+
r"""Base dataset for training and validating."""
|
19 |
+
|
20 |
+
def __init__(self, args, cfg, is_valid=False):
|
21 |
+
pass
|
22 |
+
|
23 |
+
|
24 |
+
class BaseTestDataset(torch.utils.data.Dataset):
|
25 |
+
r"""Test dataset for inference."""
|
26 |
+
|
27 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
28 |
+
assert infer_type in ["from_dataset", "from_file"]
|
29 |
+
|
30 |
+
self.args = args
|
31 |
+
self.cfg = cfg
|
32 |
+
self.infer_type = infer_type
|
33 |
+
|
34 |
+
@abstractmethod
|
35 |
+
def __getitem__(self, index):
|
36 |
+
pass
|
37 |
+
|
38 |
+
def __len__(self):
|
39 |
+
return len(self.metadata)
|
40 |
+
|
41 |
+
def get_metadata(self):
|
42 |
+
path = Path(self.args.source)
|
43 |
+
if path.suffix == ".json" or path.suffix == ".jsonc":
|
44 |
+
metadata = json5.load(open(self.args.source, "r"))
|
45 |
+
elif path.suffix == ".yaml" or path.suffix == ".yml":
|
46 |
+
metadata = yaml.full_load(open(self.args.source, "r"))
|
47 |
+
else:
|
48 |
+
raise ValueError(f"Unsupported file type: {path.suffix}")
|
49 |
+
|
50 |
+
return metadata
|
models/base/new_inference.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import random
|
8 |
+
import re
|
9 |
+
import time
|
10 |
+
from abc import abstractmethod
|
11 |
+
from pathlib import Path
|
12 |
+
|
13 |
+
import accelerate
|
14 |
+
import json5
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
from accelerate.logging import get_logger
|
18 |
+
from torch.utils.data import DataLoader
|
19 |
+
|
20 |
+
from models.vocoders.vocoder_inference import synthesis
|
21 |
+
from utils.io import save_audio
|
22 |
+
from utils.util import load_config
|
23 |
+
from utils.audio_slicer import is_silence
|
24 |
+
|
25 |
+
EPS = 1.0e-12
|
26 |
+
|
27 |
+
|
28 |
+
class BaseInference(object):
|
29 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
30 |
+
super().__init__()
|
31 |
+
|
32 |
+
start = time.monotonic_ns()
|
33 |
+
self.args = args
|
34 |
+
self.cfg = cfg
|
35 |
+
|
36 |
+
assert infer_type in ["from_dataset", "from_file"]
|
37 |
+
self.infer_type = infer_type
|
38 |
+
|
39 |
+
# init with accelerate
|
40 |
+
self.accelerator = accelerate.Accelerator()
|
41 |
+
self.accelerator.wait_for_everyone()
|
42 |
+
|
43 |
+
# Use accelerate logger for distributed inference
|
44 |
+
with self.accelerator.main_process_first():
|
45 |
+
self.logger = get_logger("inference", log_level=args.log_level)
|
46 |
+
|
47 |
+
# Log some info
|
48 |
+
self.logger.info("=" * 56)
|
49 |
+
self.logger.info("||\t\t" + "New inference process started." + "\t\t||")
|
50 |
+
self.logger.info("=" * 56)
|
51 |
+
self.logger.info("\n")
|
52 |
+
self.logger.debug(f"Using {args.log_level.upper()} logging level.")
|
53 |
+
|
54 |
+
self.acoustics_dir = args.acoustics_dir
|
55 |
+
self.logger.debug(f"Acoustic dir: {args.acoustics_dir}")
|
56 |
+
self.vocoder_dir = args.vocoder_dir
|
57 |
+
self.logger.debug(f"Vocoder dir: {args.vocoder_dir}")
|
58 |
+
# should be in svc inferencer
|
59 |
+
# self.target_singer = args.target_singer
|
60 |
+
# self.logger.info(f"Target singers: {args.target_singer}")
|
61 |
+
# self.trans_key = args.trans_key
|
62 |
+
# self.logger.info(f"Trans key: {args.trans_key}")
|
63 |
+
|
64 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
65 |
+
|
66 |
+
# set random seed
|
67 |
+
with self.accelerator.main_process_first():
|
68 |
+
start = time.monotonic_ns()
|
69 |
+
self._set_random_seed(self.cfg.train.random_seed)
|
70 |
+
end = time.monotonic_ns()
|
71 |
+
self.logger.debug(
|
72 |
+
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
|
73 |
+
)
|
74 |
+
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
|
75 |
+
|
76 |
+
# setup data_loader
|
77 |
+
with self.accelerator.main_process_first():
|
78 |
+
self.logger.info("Building dataset...")
|
79 |
+
start = time.monotonic_ns()
|
80 |
+
self.test_dataloader = self._build_dataloader()
|
81 |
+
end = time.monotonic_ns()
|
82 |
+
self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
|
83 |
+
|
84 |
+
# setup model
|
85 |
+
with self.accelerator.main_process_first():
|
86 |
+
self.logger.info("Building model...")
|
87 |
+
start = time.monotonic_ns()
|
88 |
+
self.model = self._build_model()
|
89 |
+
end = time.monotonic_ns()
|
90 |
+
# self.logger.debug(self.model)
|
91 |
+
self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms")
|
92 |
+
|
93 |
+
# init with accelerate
|
94 |
+
self.logger.info("Initializing accelerate...")
|
95 |
+
start = time.monotonic_ns()
|
96 |
+
self.accelerator = accelerate.Accelerator()
|
97 |
+
self.model = self.accelerator.prepare(self.model)
|
98 |
+
end = time.monotonic_ns()
|
99 |
+
self.accelerator.wait_for_everyone()
|
100 |
+
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms")
|
101 |
+
|
102 |
+
with self.accelerator.main_process_first():
|
103 |
+
self.logger.info("Loading checkpoint...")
|
104 |
+
start = time.monotonic_ns()
|
105 |
+
# TODO: Also, suppose only use latest one yet
|
106 |
+
self.__load_model(os.path.join(args.acoustics_dir, "checkpoint"))
|
107 |
+
end = time.monotonic_ns()
|
108 |
+
self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms")
|
109 |
+
|
110 |
+
self.model.eval()
|
111 |
+
self.accelerator.wait_for_everyone()
|
112 |
+
|
113 |
+
### Abstract methods ###
|
114 |
+
@abstractmethod
|
115 |
+
def _build_test_dataset(self):
|
116 |
+
pass
|
117 |
+
|
118 |
+
@abstractmethod
|
119 |
+
def _build_model(self):
|
120 |
+
pass
|
121 |
+
|
122 |
+
@abstractmethod
|
123 |
+
@torch.inference_mode()
|
124 |
+
def _inference_each_batch(self, batch_data):
|
125 |
+
pass
|
126 |
+
|
127 |
+
### Abstract methods end ###
|
128 |
+
|
129 |
+
@torch.inference_mode()
|
130 |
+
def inference(self):
|
131 |
+
for i, batch in enumerate(self.test_dataloader):
|
132 |
+
y_pred = self._inference_each_batch(batch).cpu()
|
133 |
+
mel_min, mel_max = self.test_dataset.target_mel_extrema
|
134 |
+
y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min
|
135 |
+
y_ls = y_pred.chunk(self.test_batch_size)
|
136 |
+
tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size)
|
137 |
+
j = 0
|
138 |
+
for it, l in zip(y_ls, tgt_ls):
|
139 |
+
l = l.item()
|
140 |
+
it = it.squeeze(0)[:l]
|
141 |
+
uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
|
142 |
+
torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt"))
|
143 |
+
j += 1
|
144 |
+
|
145 |
+
vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
|
146 |
+
|
147 |
+
res = synthesis(
|
148 |
+
cfg=vocoder_cfg,
|
149 |
+
vocoder_weight_file=vocoder_ckpt,
|
150 |
+
n_samples=None,
|
151 |
+
pred=[
|
152 |
+
torch.load(
|
153 |
+
os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"]))
|
154 |
+
).numpy(force=True)
|
155 |
+
for i in self.test_dataset.metadata
|
156 |
+
],
|
157 |
+
)
|
158 |
+
|
159 |
+
output_audio_files = []
|
160 |
+
for it, wav in zip(self.test_dataset.metadata, res):
|
161 |
+
uid = it["Uid"]
|
162 |
+
file = os.path.join(self.args.output_dir, f"{uid}.wav")
|
163 |
+
output_audio_files.append(file)
|
164 |
+
|
165 |
+
wav = wav.numpy(force=True)
|
166 |
+
save_audio(
|
167 |
+
file,
|
168 |
+
wav,
|
169 |
+
self.cfg.preprocess.sample_rate,
|
170 |
+
add_silence=False,
|
171 |
+
turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
|
172 |
+
)
|
173 |
+
os.remove(os.path.join(self.args.output_dir, f"{uid}.pt"))
|
174 |
+
|
175 |
+
return sorted(output_audio_files)
|
176 |
+
|
177 |
+
# TODO: LEGACY CODE
|
178 |
+
def _build_dataloader(self):
|
179 |
+
datasets, collate = self._build_test_dataset()
|
180 |
+
self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
|
181 |
+
self.test_collate = collate(self.cfg)
|
182 |
+
self.test_batch_size = min(
|
183 |
+
self.cfg.train.batch_size, len(self.test_dataset.metadata)
|
184 |
+
)
|
185 |
+
test_dataloader = DataLoader(
|
186 |
+
self.test_dataset,
|
187 |
+
collate_fn=self.test_collate,
|
188 |
+
num_workers=1,
|
189 |
+
batch_size=self.test_batch_size,
|
190 |
+
shuffle=False,
|
191 |
+
)
|
192 |
+
return test_dataloader
|
193 |
+
|
194 |
+
def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None):
|
195 |
+
r"""Load model from checkpoint. If checkpoint_path is None, it will
|
196 |
+
load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
|
197 |
+
None, it will load the checkpoint specified by checkpoint_path. **Only use this
|
198 |
+
method after** ``accelerator.prepare()``.
|
199 |
+
"""
|
200 |
+
if checkpoint_path is None:
|
201 |
+
ls = []
|
202 |
+
for i in Path(checkpoint_dir).iterdir():
|
203 |
+
if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)):
|
204 |
+
ls.append(i)
|
205 |
+
ls.sort(
|
206 |
+
key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True
|
207 |
+
)
|
208 |
+
checkpoint_path = ls[0]
|
209 |
+
else:
|
210 |
+
checkpoint_path = Path(checkpoint_path)
|
211 |
+
self.accelerator.load_state(str(checkpoint_path))
|
212 |
+
# set epoch and step
|
213 |
+
self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1])
|
214 |
+
self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1])
|
215 |
+
return str(checkpoint_path)
|
216 |
+
|
217 |
+
@staticmethod
|
218 |
+
def _set_random_seed(seed):
|
219 |
+
r"""Set random seed for all possible random modules."""
|
220 |
+
random.seed(seed)
|
221 |
+
np.random.seed(seed)
|
222 |
+
torch.random.manual_seed(seed)
|
223 |
+
|
224 |
+
@staticmethod
|
225 |
+
def _parse_vocoder(vocoder_dir):
|
226 |
+
r"""Parse vocoder config"""
|
227 |
+
vocoder_dir = os.path.abspath(vocoder_dir)
|
228 |
+
ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")]
|
229 |
+
ckpt_list.sort(key=lambda x: int(x.stem), reverse=True)
|
230 |
+
ckpt_path = str(ckpt_list[0])
|
231 |
+
vocoder_cfg = load_config(
|
232 |
+
os.path.join(vocoder_dir, "args.json"), lowercase=True
|
233 |
+
)
|
234 |
+
return vocoder_cfg, ckpt_path
|
235 |
+
|
236 |
+
@staticmethod
|
237 |
+
def __count_parameters(model):
|
238 |
+
return sum(p.numel() for p in model.parameters())
|
239 |
+
|
240 |
+
def __dump_cfg(self, path):
|
241 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
242 |
+
json5.dump(
|
243 |
+
self.cfg,
|
244 |
+
open(path, "w"),
|
245 |
+
indent=4,
|
246 |
+
sort_keys=True,
|
247 |
+
ensure_ascii=False,
|
248 |
+
quote_keys=True,
|
249 |
+
)
|
models/base/new_trainer.py
ADDED
@@ -0,0 +1,722 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
import random
|
9 |
+
import shutil
|
10 |
+
import time
|
11 |
+
from abc import abstractmethod
|
12 |
+
from pathlib import Path
|
13 |
+
|
14 |
+
import accelerate
|
15 |
+
import json5
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
from accelerate.logging import get_logger
|
19 |
+
from accelerate.utils import ProjectConfiguration
|
20 |
+
from torch.utils.data import ConcatDataset, DataLoader
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
from models.base.base_sampler import build_samplers
|
24 |
+
from optimizer.optimizers import NoamLR
|
25 |
+
|
26 |
+
|
27 |
+
class BaseTrainer(object):
|
28 |
+
r"""The base trainer for all tasks. Any trainer should inherit from this class."""
|
29 |
+
|
30 |
+
def __init__(self, args=None, cfg=None):
|
31 |
+
super().__init__()
|
32 |
+
|
33 |
+
self.args = args
|
34 |
+
self.cfg = cfg
|
35 |
+
|
36 |
+
cfg.exp_name = args.exp_name
|
37 |
+
|
38 |
+
# init with accelerate
|
39 |
+
self._init_accelerator()
|
40 |
+
self.accelerator.wait_for_everyone()
|
41 |
+
|
42 |
+
# Use accelerate logger for distributed training
|
43 |
+
with self.accelerator.main_process_first():
|
44 |
+
self.logger = get_logger(args.exp_name, log_level=args.log_level)
|
45 |
+
|
46 |
+
# Log some info
|
47 |
+
self.logger.info("=" * 56)
|
48 |
+
self.logger.info("||\t\t" + "New training process started." + "\t\t||")
|
49 |
+
self.logger.info("=" * 56)
|
50 |
+
self.logger.info("\n")
|
51 |
+
self.logger.debug(f"Using {args.log_level.upper()} logging level.")
|
52 |
+
self.logger.info(f"Experiment name: {args.exp_name}")
|
53 |
+
self.logger.info(f"Experiment directory: {self.exp_dir}")
|
54 |
+
self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint")
|
55 |
+
if self.accelerator.is_main_process:
|
56 |
+
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
57 |
+
self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}")
|
58 |
+
|
59 |
+
# init counts
|
60 |
+
self.batch_count: int = 0
|
61 |
+
self.step: int = 0
|
62 |
+
self.epoch: int = 0
|
63 |
+
self.max_epoch = (
|
64 |
+
self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf")
|
65 |
+
)
|
66 |
+
self.logger.info(
|
67 |
+
"Max epoch: {}".format(
|
68 |
+
self.max_epoch if self.max_epoch < float("inf") else "Unlimited"
|
69 |
+
)
|
70 |
+
)
|
71 |
+
|
72 |
+
# Check values
|
73 |
+
if self.accelerator.is_main_process:
|
74 |
+
self.__check_basic_configs()
|
75 |
+
# Set runtime configs
|
76 |
+
self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride
|
77 |
+
self.checkpoints_path = [
|
78 |
+
[] for _ in range(len(self.save_checkpoint_stride))
|
79 |
+
]
|
80 |
+
self.keep_last = [
|
81 |
+
i if i > 0 else float("inf") for i in self.cfg.train.keep_last
|
82 |
+
]
|
83 |
+
self.run_eval = self.cfg.train.run_eval
|
84 |
+
|
85 |
+
# set random seed
|
86 |
+
with self.accelerator.main_process_first():
|
87 |
+
start = time.monotonic_ns()
|
88 |
+
self._set_random_seed(self.cfg.train.random_seed)
|
89 |
+
end = time.monotonic_ns()
|
90 |
+
self.logger.debug(
|
91 |
+
f"Setting random seed done in {(end - start) / 1e6:.2f}ms"
|
92 |
+
)
|
93 |
+
self.logger.debug(f"Random seed: {self.cfg.train.random_seed}")
|
94 |
+
|
95 |
+
# setup data_loader
|
96 |
+
with self.accelerator.main_process_first():
|
97 |
+
self.logger.info("Building dataset...")
|
98 |
+
start = time.monotonic_ns()
|
99 |
+
self.train_dataloader, self.valid_dataloader = self._build_dataloader()
|
100 |
+
end = time.monotonic_ns()
|
101 |
+
self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms")
|
102 |
+
|
103 |
+
# setup model
|
104 |
+
with self.accelerator.main_process_first():
|
105 |
+
self.logger.info("Building model...")
|
106 |
+
start = time.monotonic_ns()
|
107 |
+
self.model = self._build_model()
|
108 |
+
end = time.monotonic_ns()
|
109 |
+
self.logger.debug(self.model)
|
110 |
+
self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms")
|
111 |
+
self.logger.info(
|
112 |
+
f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M"
|
113 |
+
)
|
114 |
+
# optimizer & scheduler
|
115 |
+
with self.accelerator.main_process_first():
|
116 |
+
self.logger.info("Building optimizer and scheduler...")
|
117 |
+
start = time.monotonic_ns()
|
118 |
+
self.optimizer = self.__build_optimizer()
|
119 |
+
self.scheduler = self.__build_scheduler()
|
120 |
+
end = time.monotonic_ns()
|
121 |
+
self.logger.info(
|
122 |
+
f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms"
|
123 |
+
)
|
124 |
+
|
125 |
+
# accelerate prepare
|
126 |
+
self.logger.info("Initializing accelerate...")
|
127 |
+
start = time.monotonic_ns()
|
128 |
+
(
|
129 |
+
self.train_dataloader,
|
130 |
+
self.valid_dataloader,
|
131 |
+
self.model,
|
132 |
+
self.optimizer,
|
133 |
+
self.scheduler,
|
134 |
+
) = self.accelerator.prepare(
|
135 |
+
self.train_dataloader,
|
136 |
+
self.valid_dataloader,
|
137 |
+
self.model,
|
138 |
+
self.optimizer,
|
139 |
+
self.scheduler,
|
140 |
+
)
|
141 |
+
end = time.monotonic_ns()
|
142 |
+
self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms")
|
143 |
+
|
144 |
+
# create criterion
|
145 |
+
with self.accelerator.main_process_first():
|
146 |
+
self.logger.info("Building criterion...")
|
147 |
+
start = time.monotonic_ns()
|
148 |
+
self.criterion = self._build_criterion()
|
149 |
+
end = time.monotonic_ns()
|
150 |
+
self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms")
|
151 |
+
|
152 |
+
# Resume or Finetune
|
153 |
+
with self.accelerator.main_process_first():
|
154 |
+
if args.resume:
|
155 |
+
## Automatically resume according to the current exprimental name
|
156 |
+
self.logger.info("Resuming from {}...".format(self.checkpoint_dir))
|
157 |
+
start = time.monotonic_ns()
|
158 |
+
ckpt_path = self.__load_model(
|
159 |
+
checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type
|
160 |
+
)
|
161 |
+
end = time.monotonic_ns()
|
162 |
+
self.logger.info(
|
163 |
+
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
|
164 |
+
)
|
165 |
+
self.checkpoints_path = json.load(
|
166 |
+
open(os.path.join(ckpt_path, "ckpts.json"), "r")
|
167 |
+
)
|
168 |
+
elif args.resume_from_ckpt_path and args.resume_from_ckpt_path != "":
|
169 |
+
## Resume from the given checkpoint path
|
170 |
+
if not os.path.exists(args.resume_from_ckpt_path):
|
171 |
+
raise ValueError(
|
172 |
+
"[Error] The resumed checkpoint path {} don't exist.".format(
|
173 |
+
args.resume_from_ckpt_path
|
174 |
+
)
|
175 |
+
)
|
176 |
+
|
177 |
+
self.logger.info(
|
178 |
+
"Resuming from {}...".format(args.resume_from_ckpt_path)
|
179 |
+
)
|
180 |
+
start = time.monotonic_ns()
|
181 |
+
ckpt_path = self.__load_model(
|
182 |
+
checkpoint_path=args.resume_from_ckpt_path,
|
183 |
+
resume_type=args.resume_type,
|
184 |
+
)
|
185 |
+
end = time.monotonic_ns()
|
186 |
+
self.logger.info(
|
187 |
+
f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms"
|
188 |
+
)
|
189 |
+
|
190 |
+
# save config file path
|
191 |
+
self.config_save_path = os.path.join(self.exp_dir, "args.json")
|
192 |
+
|
193 |
+
### Following are abstract methods that should be implemented in child classes ###
|
194 |
+
@abstractmethod
|
195 |
+
def _build_dataset(self):
|
196 |
+
r"""Build dataset for model training/validating/evaluating."""
|
197 |
+
pass
|
198 |
+
|
199 |
+
@staticmethod
|
200 |
+
@abstractmethod
|
201 |
+
def _build_criterion():
|
202 |
+
r"""Build criterion function for model loss calculation."""
|
203 |
+
pass
|
204 |
+
|
205 |
+
@abstractmethod
|
206 |
+
def _build_model(self):
|
207 |
+
r"""Build model for training/validating/evaluating."""
|
208 |
+
pass
|
209 |
+
|
210 |
+
@abstractmethod
|
211 |
+
def _forward_step(self, batch):
|
212 |
+
r"""One forward step of the neural network. This abstract method is trying to
|
213 |
+
unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation.
|
214 |
+
However, for special case that using different forward step pattern for
|
215 |
+
training and validating, you could just override this method with ``pass`` and
|
216 |
+
implement ``_train_step`` and ``_valid_step`` separately.
|
217 |
+
"""
|
218 |
+
pass
|
219 |
+
|
220 |
+
@abstractmethod
|
221 |
+
def _save_auxiliary_states(self):
|
222 |
+
r"""To save some auxiliary states when saving model's ckpt"""
|
223 |
+
pass
|
224 |
+
|
225 |
+
### Abstract methods end ###
|
226 |
+
|
227 |
+
### THIS IS MAIN ENTRY ###
|
228 |
+
def train_loop(self):
|
229 |
+
r"""Training loop. The public entry of training process."""
|
230 |
+
# Wait everyone to prepare before we move on
|
231 |
+
self.accelerator.wait_for_everyone()
|
232 |
+
# dump config file
|
233 |
+
if self.accelerator.is_main_process:
|
234 |
+
self.__dump_cfg(self.config_save_path)
|
235 |
+
self.model.train()
|
236 |
+
self.optimizer.zero_grad()
|
237 |
+
# Wait to ensure good to go
|
238 |
+
self.accelerator.wait_for_everyone()
|
239 |
+
while self.epoch < self.max_epoch:
|
240 |
+
self.logger.info("\n")
|
241 |
+
self.logger.info("-" * 32)
|
242 |
+
self.logger.info("Epoch {}: ".format(self.epoch))
|
243 |
+
|
244 |
+
### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
|
245 |
+
### It's inconvenient for the model with multiple losses
|
246 |
+
# Do training & validating epoch
|
247 |
+
train_loss = self._train_epoch()
|
248 |
+
self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
|
249 |
+
valid_loss = self._valid_epoch()
|
250 |
+
self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
|
251 |
+
self.accelerator.log(
|
252 |
+
{"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
|
253 |
+
step=self.epoch,
|
254 |
+
)
|
255 |
+
|
256 |
+
self.accelerator.wait_for_everyone()
|
257 |
+
# TODO: what is scheduler?
|
258 |
+
self.scheduler.step(valid_loss) # FIXME: use epoch track correct?
|
259 |
+
|
260 |
+
# Check if hit save_checkpoint_stride and run_eval
|
261 |
+
run_eval = False
|
262 |
+
if self.accelerator.is_main_process:
|
263 |
+
save_checkpoint = False
|
264 |
+
hit_dix = []
|
265 |
+
for i, num in enumerate(self.save_checkpoint_stride):
|
266 |
+
if self.epoch % num == 0:
|
267 |
+
save_checkpoint = True
|
268 |
+
hit_dix.append(i)
|
269 |
+
run_eval |= self.run_eval[i]
|
270 |
+
|
271 |
+
self.accelerator.wait_for_everyone()
|
272 |
+
if self.accelerator.is_main_process and save_checkpoint:
|
273 |
+
path = os.path.join(
|
274 |
+
self.checkpoint_dir,
|
275 |
+
"epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
276 |
+
self.epoch, self.step, train_loss
|
277 |
+
),
|
278 |
+
)
|
279 |
+
self.tmp_checkpoint_save_path = path
|
280 |
+
self.accelerator.save_state(path)
|
281 |
+
print(f"save checkpoint in {path}")
|
282 |
+
json.dump(
|
283 |
+
self.checkpoints_path,
|
284 |
+
open(os.path.join(path, "ckpts.json"), "w"),
|
285 |
+
ensure_ascii=False,
|
286 |
+
indent=4,
|
287 |
+
)
|
288 |
+
self._save_auxiliary_states()
|
289 |
+
|
290 |
+
# Remove old checkpoints
|
291 |
+
to_remove = []
|
292 |
+
for idx in hit_dix:
|
293 |
+
self.checkpoints_path[idx].append(path)
|
294 |
+
while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
|
295 |
+
to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
|
296 |
+
|
297 |
+
# Search conflicts
|
298 |
+
total = set()
|
299 |
+
for i in self.checkpoints_path:
|
300 |
+
total |= set(i)
|
301 |
+
do_remove = set()
|
302 |
+
for idx, path in to_remove[::-1]:
|
303 |
+
if path in total:
|
304 |
+
self.checkpoints_path[idx].insert(0, path)
|
305 |
+
else:
|
306 |
+
do_remove.add(path)
|
307 |
+
|
308 |
+
# Remove old checkpoints
|
309 |
+
for path in do_remove:
|
310 |
+
shutil.rmtree(path, ignore_errors=True)
|
311 |
+
self.logger.debug(f"Remove old checkpoint: {path}")
|
312 |
+
|
313 |
+
self.accelerator.wait_for_everyone()
|
314 |
+
if run_eval:
|
315 |
+
# TODO: run evaluation
|
316 |
+
pass
|
317 |
+
|
318 |
+
# Update info for each epoch
|
319 |
+
self.epoch += 1
|
320 |
+
|
321 |
+
# Finish training and save final checkpoint
|
322 |
+
self.accelerator.wait_for_everyone()
|
323 |
+
if self.accelerator.is_main_process:
|
324 |
+
self.accelerator.save_state(
|
325 |
+
os.path.join(
|
326 |
+
self.checkpoint_dir,
|
327 |
+
"final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
328 |
+
self.epoch, self.step, valid_loss
|
329 |
+
),
|
330 |
+
)
|
331 |
+
)
|
332 |
+
self._save_auxiliary_states()
|
333 |
+
|
334 |
+
self.accelerator.end_training()
|
335 |
+
|
336 |
+
### Following are methods that can be used directly in child classes ###
|
337 |
+
def _train_epoch(self):
|
338 |
+
r"""Training epoch. Should return average loss of a batch (sample) over
|
339 |
+
one epoch. See ``train_loop`` for usage.
|
340 |
+
"""
|
341 |
+
self.model.train()
|
342 |
+
epoch_sum_loss: float = 0.0
|
343 |
+
epoch_step: int = 0
|
344 |
+
for batch in tqdm(
|
345 |
+
self.train_dataloader,
|
346 |
+
desc=f"Training Epoch {self.epoch}",
|
347 |
+
unit="batch",
|
348 |
+
colour="GREEN",
|
349 |
+
leave=False,
|
350 |
+
dynamic_ncols=True,
|
351 |
+
smoothing=0.04,
|
352 |
+
disable=not self.accelerator.is_main_process,
|
353 |
+
):
|
354 |
+
# Do training step and BP
|
355 |
+
with self.accelerator.accumulate(self.model):
|
356 |
+
loss = self._train_step(batch)
|
357 |
+
self.accelerator.backward(loss)
|
358 |
+
self.optimizer.step()
|
359 |
+
self.optimizer.zero_grad()
|
360 |
+
self.batch_count += 1
|
361 |
+
|
362 |
+
# Update info for each step
|
363 |
+
# TODO: step means BP counts or batch counts?
|
364 |
+
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
|
365 |
+
epoch_sum_loss += loss
|
366 |
+
self.accelerator.log(
|
367 |
+
{
|
368 |
+
"Step/Train Loss": loss,
|
369 |
+
"Step/Learning Rate": self.optimizer.param_groups[0]["lr"],
|
370 |
+
},
|
371 |
+
step=self.step,
|
372 |
+
)
|
373 |
+
self.step += 1
|
374 |
+
epoch_step += 1
|
375 |
+
|
376 |
+
self.accelerator.wait_for_everyone()
|
377 |
+
return (
|
378 |
+
epoch_sum_loss
|
379 |
+
/ len(self.train_dataloader)
|
380 |
+
* self.cfg.train.gradient_accumulation_step
|
381 |
+
)
|
382 |
+
|
383 |
+
@torch.inference_mode()
|
384 |
+
def _valid_epoch(self):
|
385 |
+
r"""Testing epoch. Should return average loss of a batch (sample) over
|
386 |
+
one epoch. See ``train_loop`` for usage.
|
387 |
+
"""
|
388 |
+
self.model.eval()
|
389 |
+
epoch_sum_loss = 0.0
|
390 |
+
for batch in tqdm(
|
391 |
+
self.valid_dataloader,
|
392 |
+
desc=f"Validating Epoch {self.epoch}",
|
393 |
+
unit="batch",
|
394 |
+
colour="GREEN",
|
395 |
+
leave=False,
|
396 |
+
dynamic_ncols=True,
|
397 |
+
smoothing=0.04,
|
398 |
+
disable=not self.accelerator.is_main_process,
|
399 |
+
):
|
400 |
+
batch_loss = self._valid_step(batch)
|
401 |
+
epoch_sum_loss += batch_loss.item()
|
402 |
+
|
403 |
+
self.accelerator.wait_for_everyone()
|
404 |
+
return epoch_sum_loss / len(self.valid_dataloader)
|
405 |
+
|
406 |
+
def _train_step(self, batch):
|
407 |
+
r"""Training forward step. Should return average loss of a sample over
|
408 |
+
one batch. Provoke ``_forward_step`` is recommended except for special case.
|
409 |
+
See ``_train_epoch`` for usage.
|
410 |
+
"""
|
411 |
+
return self._forward_step(batch)
|
412 |
+
|
413 |
+
@torch.inference_mode()
|
414 |
+
def _valid_step(self, batch):
|
415 |
+
r"""Testing forward step. Should return average loss of a sample over
|
416 |
+
one batch. Provoke ``_forward_step`` is recommended except for special case.
|
417 |
+
See ``_test_epoch`` for usage.
|
418 |
+
"""
|
419 |
+
return self._forward_step(batch)
|
420 |
+
|
421 |
+
def __load_model(
|
422 |
+
self,
|
423 |
+
checkpoint_dir: str = None,
|
424 |
+
checkpoint_path: str = None,
|
425 |
+
resume_type: str = "",
|
426 |
+
):
|
427 |
+
r"""Load model from checkpoint. If checkpoint_path is None, it will
|
428 |
+
load the latest checkpoint in checkpoint_dir. If checkpoint_path is not
|
429 |
+
None, it will load the checkpoint specified by checkpoint_path. **Only use this
|
430 |
+
method after** ``accelerator.prepare()``.
|
431 |
+
"""
|
432 |
+
if checkpoint_path is None:
|
433 |
+
ls = [str(i) for i in Path(checkpoint_dir).glob("*")]
|
434 |
+
ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
|
435 |
+
checkpoint_path = ls[0]
|
436 |
+
self.logger.info("Resume from {}...".format(checkpoint_path))
|
437 |
+
|
438 |
+
if resume_type in ["resume", ""]:
|
439 |
+
# Load all the things, including model weights, optimizer, scheduler, and random states.
|
440 |
+
self.accelerator.load_state(input_dir=checkpoint_path)
|
441 |
+
|
442 |
+
# set epoch and step
|
443 |
+
self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1
|
444 |
+
self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1
|
445 |
+
|
446 |
+
elif resume_type == "finetune":
|
447 |
+
# Load only the model weights
|
448 |
+
accelerate.load_checkpoint_and_dispatch(
|
449 |
+
self.accelerator.unwrap_model(self.model),
|
450 |
+
os.path.join(checkpoint_path, "pytorch_model.bin"),
|
451 |
+
)
|
452 |
+
self.logger.info("Load model weights for finetune...")
|
453 |
+
|
454 |
+
else:
|
455 |
+
raise ValueError("Resume_type must be `resume` or `finetune`.")
|
456 |
+
|
457 |
+
return checkpoint_path
|
458 |
+
|
459 |
+
# TODO: LEGACY CODE
|
460 |
+
def _build_dataloader(self):
|
461 |
+
Dataset, Collator = self._build_dataset()
|
462 |
+
|
463 |
+
# build dataset instance for each dataset and combine them by ConcatDataset
|
464 |
+
datasets_list = []
|
465 |
+
for dataset in self.cfg.dataset:
|
466 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=False)
|
467 |
+
datasets_list.append(subdataset)
|
468 |
+
train_dataset = ConcatDataset(datasets_list)
|
469 |
+
train_collate = Collator(self.cfg)
|
470 |
+
_, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train")
|
471 |
+
self.logger.debug(f"train batch_sampler: {list(batch_sampler)}")
|
472 |
+
self.logger.debug(f"length: {train_dataset.cumulative_sizes}")
|
473 |
+
# TODO: use config instead of (sampler, shuffle, drop_last, batch_size)
|
474 |
+
train_loader = DataLoader(
|
475 |
+
train_dataset,
|
476 |
+
collate_fn=train_collate,
|
477 |
+
batch_sampler=batch_sampler,
|
478 |
+
num_workers=self.cfg.train.dataloader.num_worker,
|
479 |
+
pin_memory=self.cfg.train.dataloader.pin_memory,
|
480 |
+
)
|
481 |
+
|
482 |
+
# Build valid dataloader
|
483 |
+
datasets_list = []
|
484 |
+
for dataset in self.cfg.dataset:
|
485 |
+
subdataset = Dataset(self.cfg, dataset, is_valid=True)
|
486 |
+
datasets_list.append(subdataset)
|
487 |
+
valid_dataset = ConcatDataset(datasets_list)
|
488 |
+
valid_collate = Collator(self.cfg)
|
489 |
+
_, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid")
|
490 |
+
self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}")
|
491 |
+
self.logger.debug(f"length: {valid_dataset.cumulative_sizes}")
|
492 |
+
valid_loader = DataLoader(
|
493 |
+
valid_dataset,
|
494 |
+
collate_fn=valid_collate,
|
495 |
+
batch_sampler=batch_sampler,
|
496 |
+
num_workers=self.cfg.train.dataloader.num_worker,
|
497 |
+
pin_memory=self.cfg.train.dataloader.pin_memory,
|
498 |
+
)
|
499 |
+
return train_loader, valid_loader
|
500 |
+
|
501 |
+
@staticmethod
|
502 |
+
def _set_random_seed(seed):
|
503 |
+
r"""Set random seed for all possible random modules."""
|
504 |
+
random.seed(seed)
|
505 |
+
np.random.seed(seed)
|
506 |
+
torch.random.manual_seed(seed)
|
507 |
+
|
508 |
+
def _check_nan(self, loss, y_pred, y_gt):
|
509 |
+
if torch.any(torch.isnan(loss)):
|
510 |
+
self.logger.fatal("Fatal Error: Training is down since loss has Nan!")
|
511 |
+
self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True)
|
512 |
+
if torch.any(torch.isnan(y_pred)):
|
513 |
+
self.logger.error(
|
514 |
+
f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
|
515 |
+
)
|
516 |
+
else:
|
517 |
+
self.logger.debug(
|
518 |
+
f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True
|
519 |
+
)
|
520 |
+
if torch.any(torch.isnan(y_gt)):
|
521 |
+
self.logger.error(
|
522 |
+
f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True
|
523 |
+
)
|
524 |
+
else:
|
525 |
+
self.logger.debug(
|
526 |
+
f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True
|
527 |
+
)
|
528 |
+
if torch.any(torch.isnan(y_pred)):
|
529 |
+
self.logger.error(f"y_pred: {y_pred}", in_order=True)
|
530 |
+
else:
|
531 |
+
self.logger.debug(f"y_pred: {y_pred}", in_order=True)
|
532 |
+
if torch.any(torch.isnan(y_gt)):
|
533 |
+
self.logger.error(f"y_gt: {y_gt}", in_order=True)
|
534 |
+
else:
|
535 |
+
self.logger.debug(f"y_gt: {y_gt}", in_order=True)
|
536 |
+
|
537 |
+
# TODO: still OK to save tracking?
|
538 |
+
self.accelerator.end_training()
|
539 |
+
raise RuntimeError("Loss has Nan! See log for more info.")
|
540 |
+
|
541 |
+
### Protected methods end ###
|
542 |
+
|
543 |
+
## Following are private methods ##
|
544 |
+
## !!! These are inconvenient for GAN-based model training. It'd be better to move these to svc_trainer.py if needed.
|
545 |
+
def __build_optimizer(self):
|
546 |
+
r"""Build optimizer for model."""
|
547 |
+
# Make case-insensitive matching
|
548 |
+
if self.cfg.train.optimizer.lower() == "adadelta":
|
549 |
+
optimizer = torch.optim.Adadelta(
|
550 |
+
self.model.parameters(), **self.cfg.train.adadelta
|
551 |
+
)
|
552 |
+
self.logger.info("Using Adadelta optimizer.")
|
553 |
+
elif self.cfg.train.optimizer.lower() == "adagrad":
|
554 |
+
optimizer = torch.optim.Adagrad(
|
555 |
+
self.model.parameters(), **self.cfg.train.adagrad
|
556 |
+
)
|
557 |
+
self.logger.info("Using Adagrad optimizer.")
|
558 |
+
elif self.cfg.train.optimizer.lower() == "adam":
|
559 |
+
optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam)
|
560 |
+
self.logger.info("Using Adam optimizer.")
|
561 |
+
elif self.cfg.train.optimizer.lower() == "adamw":
|
562 |
+
optimizer = torch.optim.AdamW(
|
563 |
+
self.model.parameters(), **self.cfg.train.adamw
|
564 |
+
)
|
565 |
+
elif self.cfg.train.optimizer.lower() == "sparseadam":
|
566 |
+
optimizer = torch.optim.SparseAdam(
|
567 |
+
self.model.parameters(), **self.cfg.train.sparseadam
|
568 |
+
)
|
569 |
+
elif self.cfg.train.optimizer.lower() == "adamax":
|
570 |
+
optimizer = torch.optim.Adamax(
|
571 |
+
self.model.parameters(), **self.cfg.train.adamax
|
572 |
+
)
|
573 |
+
elif self.cfg.train.optimizer.lower() == "asgd":
|
574 |
+
optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd)
|
575 |
+
elif self.cfg.train.optimizer.lower() == "lbfgs":
|
576 |
+
optimizer = torch.optim.LBFGS(
|
577 |
+
self.model.parameters(), **self.cfg.train.lbfgs
|
578 |
+
)
|
579 |
+
elif self.cfg.train.optimizer.lower() == "nadam":
|
580 |
+
optimizer = torch.optim.NAdam(
|
581 |
+
self.model.parameters(), **self.cfg.train.nadam
|
582 |
+
)
|
583 |
+
elif self.cfg.train.optimizer.lower() == "radam":
|
584 |
+
optimizer = torch.optim.RAdam(
|
585 |
+
self.model.parameters(), **self.cfg.train.radam
|
586 |
+
)
|
587 |
+
elif self.cfg.train.optimizer.lower() == "rmsprop":
|
588 |
+
optimizer = torch.optim.RMSprop(
|
589 |
+
self.model.parameters(), **self.cfg.train.rmsprop
|
590 |
+
)
|
591 |
+
elif self.cfg.train.optimizer.lower() == "rprop":
|
592 |
+
optimizer = torch.optim.Rprop(
|
593 |
+
self.model.parameters(), **self.cfg.train.rprop
|
594 |
+
)
|
595 |
+
elif self.cfg.train.optimizer.lower() == "sgd":
|
596 |
+
optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd)
|
597 |
+
else:
|
598 |
+
raise NotImplementedError(
|
599 |
+
f"Optimizer {self.cfg.train.optimizer} not supported yet!"
|
600 |
+
)
|
601 |
+
return optimizer
|
602 |
+
|
603 |
+
def __build_scheduler(self):
|
604 |
+
r"""Build scheduler for optimizer."""
|
605 |
+
# Make case-insensitive matching
|
606 |
+
if self.cfg.train.scheduler.lower() == "lambdalr":
|
607 |
+
scheduler = torch.optim.lr_scheduler.LambdaLR(
|
608 |
+
self.optimizer, **self.cfg.train.lambdalr
|
609 |
+
)
|
610 |
+
elif self.cfg.train.scheduler.lower() == "multiplicativelr":
|
611 |
+
scheduler = torch.optim.lr_scheduler.MultiplicativeLR(
|
612 |
+
self.optimizer, **self.cfg.train.multiplicativelr
|
613 |
+
)
|
614 |
+
elif self.cfg.train.scheduler.lower() == "steplr":
|
615 |
+
scheduler = torch.optim.lr_scheduler.StepLR(
|
616 |
+
self.optimizer, **self.cfg.train.steplr
|
617 |
+
)
|
618 |
+
elif self.cfg.train.scheduler.lower() == "multisteplr":
|
619 |
+
scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
620 |
+
self.optimizer, **self.cfg.train.multisteplr
|
621 |
+
)
|
622 |
+
elif self.cfg.train.scheduler.lower() == "constantlr":
|
623 |
+
scheduler = torch.optim.lr_scheduler.ConstantLR(
|
624 |
+
self.optimizer, **self.cfg.train.constantlr
|
625 |
+
)
|
626 |
+
elif self.cfg.train.scheduler.lower() == "linearlr":
|
627 |
+
scheduler = torch.optim.lr_scheduler.LinearLR(
|
628 |
+
self.optimizer, **self.cfg.train.linearlr
|
629 |
+
)
|
630 |
+
elif self.cfg.train.scheduler.lower() == "exponentiallr":
|
631 |
+
scheduler = torch.optim.lr_scheduler.ExponentialLR(
|
632 |
+
self.optimizer, **self.cfg.train.exponentiallr
|
633 |
+
)
|
634 |
+
elif self.cfg.train.scheduler.lower() == "polynomiallr":
|
635 |
+
scheduler = torch.optim.lr_scheduler.PolynomialLR(
|
636 |
+
self.optimizer, **self.cfg.train.polynomiallr
|
637 |
+
)
|
638 |
+
elif self.cfg.train.scheduler.lower() == "cosineannealinglr":
|
639 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
640 |
+
self.optimizer, **self.cfg.train.cosineannealinglr
|
641 |
+
)
|
642 |
+
elif self.cfg.train.scheduler.lower() == "sequentiallr":
|
643 |
+
scheduler = torch.optim.lr_scheduler.SequentialLR(
|
644 |
+
self.optimizer, **self.cfg.train.sequentiallr
|
645 |
+
)
|
646 |
+
elif self.cfg.train.scheduler.lower() == "reducelronplateau":
|
647 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
648 |
+
self.optimizer, **self.cfg.train.reducelronplateau
|
649 |
+
)
|
650 |
+
elif self.cfg.train.scheduler.lower() == "cycliclr":
|
651 |
+
scheduler = torch.optim.lr_scheduler.CyclicLR(
|
652 |
+
self.optimizer, **self.cfg.train.cycliclr
|
653 |
+
)
|
654 |
+
elif self.cfg.train.scheduler.lower() == "onecyclelr":
|
655 |
+
scheduler = torch.optim.lr_scheduler.OneCycleLR(
|
656 |
+
self.optimizer, **self.cfg.train.onecyclelr
|
657 |
+
)
|
658 |
+
elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts":
|
659 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
|
660 |
+
self.optimizer, **self.cfg.train.cosineannearingwarmrestarts
|
661 |
+
)
|
662 |
+
elif self.cfg.train.scheduler.lower() == "noamlr":
|
663 |
+
scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler)
|
664 |
+
else:
|
665 |
+
raise NotImplementedError(
|
666 |
+
f"Scheduler {self.cfg.train.scheduler} not supported yet!"
|
667 |
+
)
|
668 |
+
return scheduler
|
669 |
+
|
670 |
+
def _init_accelerator(self):
|
671 |
+
self.exp_dir = os.path.join(
|
672 |
+
os.path.abspath(self.cfg.log_dir), self.args.exp_name
|
673 |
+
)
|
674 |
+
project_config = ProjectConfiguration(
|
675 |
+
project_dir=self.exp_dir,
|
676 |
+
logging_dir=os.path.join(self.exp_dir, "log"),
|
677 |
+
)
|
678 |
+
self.accelerator = accelerate.Accelerator(
|
679 |
+
gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step,
|
680 |
+
log_with=self.cfg.train.tracker,
|
681 |
+
project_config=project_config,
|
682 |
+
)
|
683 |
+
if self.accelerator.is_main_process:
|
684 |
+
os.makedirs(project_config.project_dir, exist_ok=True)
|
685 |
+
os.makedirs(project_config.logging_dir, exist_ok=True)
|
686 |
+
with self.accelerator.main_process_first():
|
687 |
+
self.accelerator.init_trackers(self.args.exp_name)
|
688 |
+
|
689 |
+
def __check_basic_configs(self):
|
690 |
+
if self.cfg.train.gradient_accumulation_step <= 0:
|
691 |
+
self.logger.fatal("Invalid gradient_accumulation_step value!")
|
692 |
+
self.logger.error(
|
693 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
694 |
+
)
|
695 |
+
self.accelerator.end_training()
|
696 |
+
raise ValueError(
|
697 |
+
f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive."
|
698 |
+
)
|
699 |
+
# TODO: check other values
|
700 |
+
|
701 |
+
@staticmethod
|
702 |
+
def __count_parameters(model):
|
703 |
+
model_param = 0.0
|
704 |
+
if isinstance(model, dict):
|
705 |
+
for key, value in model.items():
|
706 |
+
model_param += sum(p.numel() for p in model[key].parameters())
|
707 |
+
else:
|
708 |
+
model_param = sum(p.numel() for p in model.parameters())
|
709 |
+
return model_param
|
710 |
+
|
711 |
+
def __dump_cfg(self, path):
|
712 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
713 |
+
json5.dump(
|
714 |
+
self.cfg,
|
715 |
+
open(path, "w"),
|
716 |
+
indent=4,
|
717 |
+
sort_keys=True,
|
718 |
+
ensure_ascii=False,
|
719 |
+
quote_keys=True,
|
720 |
+
)
|
721 |
+
|
722 |
+
### Private methods end ###
|
models/svc/__init__.py
ADDED
File without changes
|
models/svc/base/__init__.py
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .svc_inference import SVCInference
|
7 |
+
from .svc_trainer import SVCTrainer
|
models/svc/base/svc_dataset.py
ADDED
@@ -0,0 +1,425 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import random
|
7 |
+
import torch
|
8 |
+
from torch.nn.utils.rnn import pad_sequence
|
9 |
+
import json
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
from utils.data_utils import *
|
13 |
+
from processors.acoustic_extractor import cal_normalized_mel, load_mel_extrema
|
14 |
+
from processors.content_extractor import (
|
15 |
+
ContentvecExtractor,
|
16 |
+
WhisperExtractor,
|
17 |
+
WenetExtractor,
|
18 |
+
)
|
19 |
+
from models.base.base_dataset import (
|
20 |
+
BaseCollator,
|
21 |
+
BaseDataset,
|
22 |
+
)
|
23 |
+
from models.base.new_dataset import BaseTestDataset
|
24 |
+
|
25 |
+
EPS = 1.0e-12
|
26 |
+
|
27 |
+
|
28 |
+
class SVCDataset(BaseDataset):
|
29 |
+
def __init__(self, cfg, dataset, is_valid=False):
|
30 |
+
BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid)
|
31 |
+
|
32 |
+
cfg = self.cfg
|
33 |
+
|
34 |
+
if cfg.model.condition_encoder.use_whisper:
|
35 |
+
self.whisper_aligner = WhisperExtractor(self.cfg)
|
36 |
+
self.utt2whisper_path = load_content_feature_path(
|
37 |
+
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
|
38 |
+
)
|
39 |
+
|
40 |
+
if cfg.model.condition_encoder.use_contentvec:
|
41 |
+
self.contentvec_aligner = ContentvecExtractor(self.cfg)
|
42 |
+
self.utt2contentVec_path = load_content_feature_path(
|
43 |
+
self.metadata,
|
44 |
+
cfg.preprocess.processed_dir,
|
45 |
+
cfg.preprocess.contentvec_dir,
|
46 |
+
)
|
47 |
+
|
48 |
+
if cfg.model.condition_encoder.use_mert:
|
49 |
+
self.utt2mert_path = load_content_feature_path(
|
50 |
+
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
|
51 |
+
)
|
52 |
+
if cfg.model.condition_encoder.use_wenet:
|
53 |
+
self.wenet_aligner = WenetExtractor(self.cfg)
|
54 |
+
self.utt2wenet_path = load_content_feature_path(
|
55 |
+
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
|
56 |
+
)
|
57 |
+
|
58 |
+
def __getitem__(self, index):
|
59 |
+
single_feature = BaseDataset.__getitem__(self, index)
|
60 |
+
|
61 |
+
utt_info = self.metadata[index]
|
62 |
+
dataset = utt_info["Dataset"]
|
63 |
+
uid = utt_info["Uid"]
|
64 |
+
utt = "{}_{}".format(dataset, uid)
|
65 |
+
|
66 |
+
if self.cfg.model.condition_encoder.use_whisper:
|
67 |
+
assert "target_len" in single_feature.keys()
|
68 |
+
aligned_whisper_feat = self.whisper_aligner.offline_align(
|
69 |
+
np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
|
70 |
+
)
|
71 |
+
single_feature["whisper_feat"] = aligned_whisper_feat
|
72 |
+
|
73 |
+
if self.cfg.model.condition_encoder.use_contentvec:
|
74 |
+
assert "target_len" in single_feature.keys()
|
75 |
+
aligned_contentvec = self.contentvec_aligner.offline_align(
|
76 |
+
np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
|
77 |
+
)
|
78 |
+
single_feature["contentvec_feat"] = aligned_contentvec
|
79 |
+
|
80 |
+
if self.cfg.model.condition_encoder.use_mert:
|
81 |
+
assert "target_len" in single_feature.keys()
|
82 |
+
aligned_mert_feat = align_content_feature_length(
|
83 |
+
np.load(self.utt2mert_path[utt]),
|
84 |
+
single_feature["target_len"],
|
85 |
+
source_hop=self.cfg.preprocess.mert_hop_size,
|
86 |
+
)
|
87 |
+
single_feature["mert_feat"] = aligned_mert_feat
|
88 |
+
|
89 |
+
if self.cfg.model.condition_encoder.use_wenet:
|
90 |
+
assert "target_len" in single_feature.keys()
|
91 |
+
aligned_wenet_feat = self.wenet_aligner.offline_align(
|
92 |
+
np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
|
93 |
+
)
|
94 |
+
single_feature["wenet_feat"] = aligned_wenet_feat
|
95 |
+
|
96 |
+
# print(single_feature.keys())
|
97 |
+
# for k, v in single_feature.items():
|
98 |
+
# if type(v) in [torch.Tensor, np.ndarray]:
|
99 |
+
# print(k, v.shape)
|
100 |
+
# else:
|
101 |
+
# print(k, v)
|
102 |
+
# exit()
|
103 |
+
|
104 |
+
return self.clip_if_too_long(single_feature)
|
105 |
+
|
106 |
+
def __len__(self):
|
107 |
+
return len(self.metadata)
|
108 |
+
|
109 |
+
def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
|
110 |
+
"""
|
111 |
+
ending_ts: to avoid invalid whisper features for over 30s audios
|
112 |
+
2812 = 30 * 24000 // 256
|
113 |
+
"""
|
114 |
+
ts = max(feature_seq_len - max_seq_len, 0)
|
115 |
+
ts = min(ts, ending_ts - max_seq_len)
|
116 |
+
|
117 |
+
start = random.randint(0, ts)
|
118 |
+
end = start + max_seq_len
|
119 |
+
return start, end
|
120 |
+
|
121 |
+
def clip_if_too_long(self, sample, max_seq_len=512):
|
122 |
+
"""
|
123 |
+
sample :
|
124 |
+
{
|
125 |
+
'spk_id': (1,),
|
126 |
+
'target_len': int
|
127 |
+
'mel': (seq_len, dim),
|
128 |
+
'frame_pitch': (seq_len,)
|
129 |
+
'frame_energy': (seq_len,)
|
130 |
+
'content_vector_feat': (seq_len, dim)
|
131 |
+
}
|
132 |
+
"""
|
133 |
+
if sample["target_len"] <= max_seq_len:
|
134 |
+
return sample
|
135 |
+
|
136 |
+
start, end = self.random_select(sample["target_len"], max_seq_len)
|
137 |
+
sample["target_len"] = end - start
|
138 |
+
|
139 |
+
for k in sample.keys():
|
140 |
+
if k not in ["spk_id", "target_len"]:
|
141 |
+
sample[k] = sample[k][start:end]
|
142 |
+
|
143 |
+
return sample
|
144 |
+
|
145 |
+
|
146 |
+
class SVCCollator(BaseCollator):
|
147 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
148 |
+
|
149 |
+
def __init__(self, cfg):
|
150 |
+
BaseCollator.__init__(self, cfg)
|
151 |
+
|
152 |
+
def __call__(self, batch):
|
153 |
+
parsed_batch_features = BaseCollator.__call__(self, batch)
|
154 |
+
return parsed_batch_features
|
155 |
+
|
156 |
+
|
157 |
+
class SVCTestDataset(BaseTestDataset):
|
158 |
+
def __init__(self, args, cfg, infer_type):
|
159 |
+
BaseTestDataset.__init__(self, args, cfg, infer_type)
|
160 |
+
self.metadata = self.get_metadata()
|
161 |
+
|
162 |
+
target_singer = args.target_singer
|
163 |
+
self.cfg = cfg
|
164 |
+
self.trans_key = args.trans_key
|
165 |
+
assert type(target_singer) == str
|
166 |
+
|
167 |
+
self.target_singer = target_singer.split("_")[-1]
|
168 |
+
self.target_dataset = target_singer.replace(
|
169 |
+
"_{}".format(self.target_singer), ""
|
170 |
+
)
|
171 |
+
|
172 |
+
self.target_mel_extrema = load_mel_extrema(cfg.preprocess, self.target_dataset)
|
173 |
+
self.target_mel_extrema = torch.as_tensor(
|
174 |
+
self.target_mel_extrema[0]
|
175 |
+
), torch.as_tensor(self.target_mel_extrema[1])
|
176 |
+
|
177 |
+
######### Load source acoustic features #########
|
178 |
+
if cfg.preprocess.use_spkid:
|
179 |
+
spk2id_path = os.path.join(args.acoustics_dir, cfg.preprocess.spk2id)
|
180 |
+
# utt2sp_path = os.path.join(self.data_root, cfg.preprocess.utt2spk)
|
181 |
+
|
182 |
+
with open(spk2id_path, "r") as f:
|
183 |
+
self.spk2id = json.load(f)
|
184 |
+
# print("self.spk2id", self.spk2id)
|
185 |
+
|
186 |
+
if cfg.preprocess.use_uv:
|
187 |
+
self.utt2uv_path = {
|
188 |
+
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
|
189 |
+
cfg.preprocess.processed_dir,
|
190 |
+
utt_info["Dataset"],
|
191 |
+
cfg.preprocess.uv_dir,
|
192 |
+
utt_info["Uid"] + ".npy",
|
193 |
+
)
|
194 |
+
for utt_info in self.metadata
|
195 |
+
}
|
196 |
+
|
197 |
+
if cfg.preprocess.use_frame_pitch:
|
198 |
+
self.utt2frame_pitch_path = {
|
199 |
+
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
|
200 |
+
cfg.preprocess.processed_dir,
|
201 |
+
utt_info["Dataset"],
|
202 |
+
cfg.preprocess.pitch_dir,
|
203 |
+
utt_info["Uid"] + ".npy",
|
204 |
+
)
|
205 |
+
for utt_info in self.metadata
|
206 |
+
}
|
207 |
+
|
208 |
+
# Target F0 median
|
209 |
+
target_f0_statistics_path = os.path.join(
|
210 |
+
cfg.preprocess.processed_dir,
|
211 |
+
self.target_dataset,
|
212 |
+
cfg.preprocess.pitch_dir,
|
213 |
+
"statistics.json",
|
214 |
+
)
|
215 |
+
self.target_pitch_median = json.load(open(target_f0_statistics_path, "r"))[
|
216 |
+
f"{self.target_dataset}_{self.target_singer}"
|
217 |
+
]["voiced_positions"]["median"]
|
218 |
+
|
219 |
+
# Source F0 median (if infer from file)
|
220 |
+
if infer_type == "from_file":
|
221 |
+
source_audio_name = cfg.inference.source_audio_name
|
222 |
+
source_f0_statistics_path = os.path.join(
|
223 |
+
cfg.preprocess.processed_dir,
|
224 |
+
source_audio_name,
|
225 |
+
cfg.preprocess.pitch_dir,
|
226 |
+
"statistics.json",
|
227 |
+
)
|
228 |
+
self.source_pitch_median = json.load(
|
229 |
+
open(source_f0_statistics_path, "r")
|
230 |
+
)[f"{source_audio_name}_{source_audio_name}"]["voiced_positions"][
|
231 |
+
"median"
|
232 |
+
]
|
233 |
+
else:
|
234 |
+
self.source_pitch_median = None
|
235 |
+
|
236 |
+
if cfg.preprocess.use_frame_energy:
|
237 |
+
self.utt2frame_energy_path = {
|
238 |
+
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
|
239 |
+
cfg.preprocess.processed_dir,
|
240 |
+
utt_info["Dataset"],
|
241 |
+
cfg.preprocess.energy_dir,
|
242 |
+
utt_info["Uid"] + ".npy",
|
243 |
+
)
|
244 |
+
for utt_info in self.metadata
|
245 |
+
}
|
246 |
+
|
247 |
+
if cfg.preprocess.use_mel:
|
248 |
+
self.utt2mel_path = {
|
249 |
+
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join(
|
250 |
+
cfg.preprocess.processed_dir,
|
251 |
+
utt_info["Dataset"],
|
252 |
+
cfg.preprocess.mel_dir,
|
253 |
+
utt_info["Uid"] + ".npy",
|
254 |
+
)
|
255 |
+
for utt_info in self.metadata
|
256 |
+
}
|
257 |
+
|
258 |
+
######### Load source content features' path #########
|
259 |
+
if cfg.model.condition_encoder.use_whisper:
|
260 |
+
self.whisper_aligner = WhisperExtractor(cfg)
|
261 |
+
self.utt2whisper_path = load_content_feature_path(
|
262 |
+
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir
|
263 |
+
)
|
264 |
+
|
265 |
+
if cfg.model.condition_encoder.use_contentvec:
|
266 |
+
self.contentvec_aligner = ContentvecExtractor(cfg)
|
267 |
+
self.utt2contentVec_path = load_content_feature_path(
|
268 |
+
self.metadata,
|
269 |
+
cfg.preprocess.processed_dir,
|
270 |
+
cfg.preprocess.contentvec_dir,
|
271 |
+
)
|
272 |
+
|
273 |
+
if cfg.model.condition_encoder.use_mert:
|
274 |
+
self.utt2mert_path = load_content_feature_path(
|
275 |
+
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir
|
276 |
+
)
|
277 |
+
if cfg.model.condition_encoder.use_wenet:
|
278 |
+
self.wenet_aligner = WenetExtractor(cfg)
|
279 |
+
self.utt2wenet_path = load_content_feature_path(
|
280 |
+
self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir
|
281 |
+
)
|
282 |
+
|
283 |
+
def __getitem__(self, index):
|
284 |
+
single_feature = {}
|
285 |
+
|
286 |
+
utt_info = self.metadata[index]
|
287 |
+
dataset = utt_info["Dataset"]
|
288 |
+
uid = utt_info["Uid"]
|
289 |
+
utt = "{}_{}".format(dataset, uid)
|
290 |
+
|
291 |
+
source_dataset = self.metadata[index]["Dataset"]
|
292 |
+
|
293 |
+
if self.cfg.preprocess.use_spkid:
|
294 |
+
single_feature["spk_id"] = np.array(
|
295 |
+
[self.spk2id[f"{self.target_dataset}_{self.target_singer}"]],
|
296 |
+
dtype=np.int32,
|
297 |
+
)
|
298 |
+
|
299 |
+
######### Get Acoustic Features Item #########
|
300 |
+
if self.cfg.preprocess.use_mel:
|
301 |
+
mel = np.load(self.utt2mel_path[utt])
|
302 |
+
assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T]
|
303 |
+
if self.cfg.preprocess.use_min_max_norm_mel:
|
304 |
+
# mel norm
|
305 |
+
mel = cal_normalized_mel(mel, source_dataset, self.cfg.preprocess)
|
306 |
+
|
307 |
+
if "target_len" not in single_feature.keys():
|
308 |
+
single_feature["target_len"] = mel.shape[1]
|
309 |
+
single_feature["mel"] = mel.T # [T, n_mels]
|
310 |
+
|
311 |
+
if self.cfg.preprocess.use_frame_pitch:
|
312 |
+
frame_pitch_path = self.utt2frame_pitch_path[utt]
|
313 |
+
frame_pitch = np.load(frame_pitch_path)
|
314 |
+
|
315 |
+
if self.trans_key:
|
316 |
+
try:
|
317 |
+
self.trans_key = int(self.trans_key)
|
318 |
+
except:
|
319 |
+
pass
|
320 |
+
if type(self.trans_key) == int:
|
321 |
+
frame_pitch = transpose_key(frame_pitch, self.trans_key)
|
322 |
+
elif self.trans_key:
|
323 |
+
assert self.target_singer
|
324 |
+
|
325 |
+
frame_pitch = pitch_shift_to_target(
|
326 |
+
frame_pitch, self.target_pitch_median, self.source_pitch_median
|
327 |
+
)
|
328 |
+
|
329 |
+
if "target_len" not in single_feature.keys():
|
330 |
+
single_feature["target_len"] = len(frame_pitch)
|
331 |
+
aligned_frame_pitch = align_length(
|
332 |
+
frame_pitch, single_feature["target_len"]
|
333 |
+
)
|
334 |
+
single_feature["frame_pitch"] = aligned_frame_pitch
|
335 |
+
|
336 |
+
if self.cfg.preprocess.use_uv:
|
337 |
+
frame_uv_path = self.utt2uv_path[utt]
|
338 |
+
frame_uv = np.load(frame_uv_path)
|
339 |
+
aligned_frame_uv = align_length(frame_uv, single_feature["target_len"])
|
340 |
+
aligned_frame_uv = [
|
341 |
+
0 if frame_uv else 1 for frame_uv in aligned_frame_uv
|
342 |
+
]
|
343 |
+
aligned_frame_uv = np.array(aligned_frame_uv)
|
344 |
+
single_feature["frame_uv"] = aligned_frame_uv
|
345 |
+
|
346 |
+
if self.cfg.preprocess.use_frame_energy:
|
347 |
+
frame_energy_path = self.utt2frame_energy_path[utt]
|
348 |
+
frame_energy = np.load(frame_energy_path)
|
349 |
+
if "target_len" not in single_feature.keys():
|
350 |
+
single_feature["target_len"] = len(frame_energy)
|
351 |
+
aligned_frame_energy = align_length(
|
352 |
+
frame_energy, single_feature["target_len"]
|
353 |
+
)
|
354 |
+
single_feature["frame_energy"] = aligned_frame_energy
|
355 |
+
|
356 |
+
######### Get Content Features Item #########
|
357 |
+
if self.cfg.model.condition_encoder.use_whisper:
|
358 |
+
assert "target_len" in single_feature.keys()
|
359 |
+
aligned_whisper_feat = self.whisper_aligner.offline_align(
|
360 |
+
np.load(self.utt2whisper_path[utt]), single_feature["target_len"]
|
361 |
+
)
|
362 |
+
single_feature["whisper_feat"] = aligned_whisper_feat
|
363 |
+
|
364 |
+
if self.cfg.model.condition_encoder.use_contentvec:
|
365 |
+
assert "target_len" in single_feature.keys()
|
366 |
+
aligned_contentvec = self.contentvec_aligner.offline_align(
|
367 |
+
np.load(self.utt2contentVec_path[utt]), single_feature["target_len"]
|
368 |
+
)
|
369 |
+
single_feature["contentvec_feat"] = aligned_contentvec
|
370 |
+
|
371 |
+
if self.cfg.model.condition_encoder.use_mert:
|
372 |
+
assert "target_len" in single_feature.keys()
|
373 |
+
aligned_mert_feat = align_content_feature_length(
|
374 |
+
np.load(self.utt2mert_path[utt]),
|
375 |
+
single_feature["target_len"],
|
376 |
+
source_hop=self.cfg.preprocess.mert_hop_size,
|
377 |
+
)
|
378 |
+
single_feature["mert_feat"] = aligned_mert_feat
|
379 |
+
|
380 |
+
if self.cfg.model.condition_encoder.use_wenet:
|
381 |
+
assert "target_len" in single_feature.keys()
|
382 |
+
aligned_wenet_feat = self.wenet_aligner.offline_align(
|
383 |
+
np.load(self.utt2wenet_path[utt]), single_feature["target_len"]
|
384 |
+
)
|
385 |
+
single_feature["wenet_feat"] = aligned_wenet_feat
|
386 |
+
|
387 |
+
return single_feature
|
388 |
+
|
389 |
+
def __len__(self):
|
390 |
+
return len(self.metadata)
|
391 |
+
|
392 |
+
|
393 |
+
class SVCTestCollator:
|
394 |
+
"""Zero-pads model inputs and targets based on number of frames per step"""
|
395 |
+
|
396 |
+
def __init__(self, cfg):
|
397 |
+
self.cfg = cfg
|
398 |
+
|
399 |
+
def __call__(self, batch):
|
400 |
+
packed_batch_features = dict()
|
401 |
+
|
402 |
+
# mel: [b, T, n_mels]
|
403 |
+
# frame_pitch, frame_energy: [1, T]
|
404 |
+
# target_len: [1]
|
405 |
+
# spk_id: [b, 1]
|
406 |
+
# mask: [b, T, 1]
|
407 |
+
|
408 |
+
for key in batch[0].keys():
|
409 |
+
if key == "target_len":
|
410 |
+
packed_batch_features["target_len"] = torch.LongTensor(
|
411 |
+
[b["target_len"] for b in batch]
|
412 |
+
)
|
413 |
+
masks = [
|
414 |
+
torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
|
415 |
+
]
|
416 |
+
packed_batch_features["mask"] = pad_sequence(
|
417 |
+
masks, batch_first=True, padding_value=0
|
418 |
+
)
|
419 |
+
else:
|
420 |
+
values = [torch.from_numpy(b[key]) for b in batch]
|
421 |
+
packed_batch_features[key] = pad_sequence(
|
422 |
+
values, batch_first=True, padding_value=0
|
423 |
+
)
|
424 |
+
|
425 |
+
return packed_batch_features
|
models/svc/base/svc_inference.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from models.base.new_inference import BaseInference
|
7 |
+
from models.svc.base.svc_dataset import SVCTestCollator, SVCTestDataset
|
8 |
+
|
9 |
+
|
10 |
+
class SVCInference(BaseInference):
|
11 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
12 |
+
BaseInference.__init__(self, args, cfg, infer_type)
|
13 |
+
|
14 |
+
def _build_test_dataset(self):
|
15 |
+
return SVCTestDataset, SVCTestCollator
|
models/svc/base/svc_trainer.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import json
|
7 |
+
import os
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
|
12 |
+
from models.base.new_trainer import BaseTrainer
|
13 |
+
from models.svc.base.svc_dataset import SVCCollator, SVCDataset
|
14 |
+
|
15 |
+
|
16 |
+
class SVCTrainer(BaseTrainer):
|
17 |
+
r"""The base trainer for all SVC models. It inherits from BaseTrainer and implements
|
18 |
+
``build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this
|
19 |
+
class, and implement ``_build_model``, ``_forward_step``.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, args=None, cfg=None):
|
23 |
+
self.args = args
|
24 |
+
self.cfg = cfg
|
25 |
+
|
26 |
+
self._init_accelerator()
|
27 |
+
|
28 |
+
# Only for SVC tasks
|
29 |
+
with self.accelerator.main_process_first():
|
30 |
+
self.singers = self._build_singer_lut()
|
31 |
+
|
32 |
+
# Super init
|
33 |
+
BaseTrainer.__init__(self, args, cfg)
|
34 |
+
|
35 |
+
# Only for SVC tasks
|
36 |
+
self.task_type = "SVC"
|
37 |
+
self.logger.info("Task type: {}".format(self.task_type))
|
38 |
+
|
39 |
+
### Following are methods only for SVC tasks ###
|
40 |
+
# TODO: LEGACY CODE, NEED TO BE REFACTORED
|
41 |
+
def _build_dataset(self):
|
42 |
+
return SVCDataset, SVCCollator
|
43 |
+
|
44 |
+
@staticmethod
|
45 |
+
def _build_criterion():
|
46 |
+
criterion = nn.MSELoss(reduction="none")
|
47 |
+
return criterion
|
48 |
+
|
49 |
+
@staticmethod
|
50 |
+
def _compute_loss(criterion, y_pred, y_gt, loss_mask):
|
51 |
+
"""
|
52 |
+
Args:
|
53 |
+
criterion: MSELoss(reduction='none')
|
54 |
+
y_pred, y_gt: (bs, seq_len, D)
|
55 |
+
loss_mask: (bs, seq_len, 1)
|
56 |
+
Returns:
|
57 |
+
loss: Tensor of shape []
|
58 |
+
"""
|
59 |
+
|
60 |
+
# (bs, seq_len, D)
|
61 |
+
loss = criterion(y_pred, y_gt)
|
62 |
+
# expand loss_mask to (bs, seq_len, D)
|
63 |
+
loss_mask = loss_mask.repeat(1, 1, loss.shape[-1])
|
64 |
+
|
65 |
+
loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask)
|
66 |
+
return loss
|
67 |
+
|
68 |
+
def _save_auxiliary_states(self):
|
69 |
+
"""
|
70 |
+
To save the singer's look-up table in the checkpoint saving path
|
71 |
+
"""
|
72 |
+
with open(
|
73 |
+
os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id), "w"
|
74 |
+
) as f:
|
75 |
+
json.dump(self.singers, f, indent=4, ensure_ascii=False)
|
76 |
+
|
77 |
+
def _build_singer_lut(self):
|
78 |
+
resumed_singer_path = None
|
79 |
+
if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "":
|
80 |
+
resumed_singer_path = os.path.join(
|
81 |
+
self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id
|
82 |
+
)
|
83 |
+
if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)):
|
84 |
+
resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
|
85 |
+
|
86 |
+
if resumed_singer_path:
|
87 |
+
with open(resumed_singer_path, "r") as f:
|
88 |
+
singers = json.load(f)
|
89 |
+
else:
|
90 |
+
singers = dict()
|
91 |
+
|
92 |
+
for dataset in self.cfg.dataset:
|
93 |
+
singer_lut_path = os.path.join(
|
94 |
+
self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id
|
95 |
+
)
|
96 |
+
with open(singer_lut_path, "r") as singer_lut_path:
|
97 |
+
singer_lut = json.load(singer_lut_path)
|
98 |
+
for singer in singer_lut.keys():
|
99 |
+
if singer not in singers:
|
100 |
+
singers[singer] = len(singers)
|
101 |
+
|
102 |
+
with open(
|
103 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w"
|
104 |
+
) as singer_file:
|
105 |
+
json.dump(singers, singer_file, indent=4, ensure_ascii=False)
|
106 |
+
print(
|
107 |
+
"singers have been dumped to {}".format(
|
108 |
+
os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
|
109 |
+
)
|
110 |
+
)
|
111 |
+
return singers
|
models/svc/comosvc/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
models/svc/comosvc/comosvc.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
"""Adapted from https://github.com/zhenye234/CoMoSpeech"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import copy
|
11 |
+
import numpy as np
|
12 |
+
import math
|
13 |
+
from tqdm.auto import tqdm
|
14 |
+
|
15 |
+
from utils.ssim import SSIM
|
16 |
+
|
17 |
+
from models.svc.transformer.conformer import Conformer, BaseModule
|
18 |
+
from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
|
19 |
+
from models.svc.comosvc.utils import slice_segments, rand_ids_segments
|
20 |
+
|
21 |
+
|
22 |
+
class Consistency(nn.Module):
|
23 |
+
def __init__(self, cfg, distill=False):
|
24 |
+
super().__init__()
|
25 |
+
self.cfg = cfg
|
26 |
+
# self.denoise_fn = GradLogPEstimator2d(96)
|
27 |
+
self.denoise_fn = DiffusionWrapper(self.cfg)
|
28 |
+
self.cfg = cfg.model.comosvc
|
29 |
+
self.teacher = not distill
|
30 |
+
self.P_mean = self.cfg.P_mean
|
31 |
+
self.P_std = self.cfg.P_std
|
32 |
+
self.sigma_data = self.cfg.sigma_data
|
33 |
+
self.sigma_min = self.cfg.sigma_min
|
34 |
+
self.sigma_max = self.cfg.sigma_max
|
35 |
+
self.rho = self.cfg.rho
|
36 |
+
self.N = self.cfg.n_timesteps
|
37 |
+
self.ssim_loss = SSIM()
|
38 |
+
|
39 |
+
# Time step discretization
|
40 |
+
step_indices = torch.arange(self.N)
|
41 |
+
# karras boundaries formula
|
42 |
+
t_steps = (
|
43 |
+
self.sigma_min ** (1 / self.rho)
|
44 |
+
+ step_indices
|
45 |
+
/ (self.N - 1)
|
46 |
+
* (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
|
47 |
+
) ** self.rho
|
48 |
+
self.t_steps = torch.cat(
|
49 |
+
[torch.zeros_like(t_steps[:1]), self.round_sigma(t_steps)]
|
50 |
+
)
|
51 |
+
|
52 |
+
def init_consistency_training(self):
|
53 |
+
self.denoise_fn_ema = copy.deepcopy(self.denoise_fn)
|
54 |
+
self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn)
|
55 |
+
|
56 |
+
def EDMPrecond(self, x, sigma, cond, denoise_fn, mask, spk=None):
|
57 |
+
"""
|
58 |
+
karras diffusion reverse process
|
59 |
+
|
60 |
+
Args:
|
61 |
+
x: noisy mel-spectrogram [B x n_mel x L]
|
62 |
+
sigma: noise level [B x 1 x 1]
|
63 |
+
cond: output of conformer encoder [B x n_mel x L]
|
64 |
+
denoise_fn: denoiser neural network e.g. DilatedCNN
|
65 |
+
mask: mask of padded frames [B x n_mel x L]
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
denoised mel-spectrogram [B x n_mel x L]
|
69 |
+
"""
|
70 |
+
sigma = sigma.reshape(-1, 1, 1)
|
71 |
+
|
72 |
+
c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
|
73 |
+
c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
|
74 |
+
c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
|
75 |
+
c_noise = sigma.log() / 4
|
76 |
+
|
77 |
+
x_in = c_in * x
|
78 |
+
x_in = x_in.transpose(1, 2)
|
79 |
+
x = x.transpose(1, 2)
|
80 |
+
cond = cond.transpose(1, 2)
|
81 |
+
F_x = denoise_fn(x_in, c_noise.squeeze(), cond)
|
82 |
+
# F_x = denoise_fn((c_in * x), mask, cond, c_noise.flatten())
|
83 |
+
D_x = c_skip * x + c_out * (F_x)
|
84 |
+
D_x = D_x.transpose(1, 2)
|
85 |
+
return D_x
|
86 |
+
|
87 |
+
def EDMLoss(self, x_start, cond, mask):
|
88 |
+
"""
|
89 |
+
compute loss for EDM model
|
90 |
+
|
91 |
+
Args:
|
92 |
+
x_start: ground truth mel-spectrogram [B x n_mel x L]
|
93 |
+
cond: output of conformer encoder [B x n_mel x L]
|
94 |
+
mask: mask of padded frames [B x n_mel x L]
|
95 |
+
"""
|
96 |
+
rnd_normal = torch.randn([x_start.shape[0], 1, 1], device=x_start.device)
|
97 |
+
sigma = (rnd_normal * self.P_std + self.P_mean).exp()
|
98 |
+
weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
|
99 |
+
|
100 |
+
# follow Grad-TTS, start from Gaussian noise with mean cond and std I
|
101 |
+
noise = (torch.randn_like(x_start) + cond) * sigma
|
102 |
+
D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn, mask)
|
103 |
+
loss = weight * ((D_yn - x_start) ** 2)
|
104 |
+
loss = torch.sum(loss * mask) / torch.sum(mask)
|
105 |
+
return loss
|
106 |
+
|
107 |
+
def round_sigma(self, sigma):
|
108 |
+
return torch.as_tensor(sigma)
|
109 |
+
|
110 |
+
def edm_sampler(
|
111 |
+
self,
|
112 |
+
latents,
|
113 |
+
cond,
|
114 |
+
nonpadding,
|
115 |
+
num_steps=50,
|
116 |
+
sigma_min=0.002,
|
117 |
+
sigma_max=80,
|
118 |
+
rho=7,
|
119 |
+
S_churn=0,
|
120 |
+
S_min=0,
|
121 |
+
S_max=float("inf"),
|
122 |
+
S_noise=1,
|
123 |
+
# S_churn=40 ,S_min=0.05,S_max=50,S_noise=1.003,# S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
|
124 |
+
# S_churn=30 ,S_min=0.01,S_max=30,S_noise=1.007,
|
125 |
+
# S_churn=30 ,S_min=0.01,S_max=1,S_noise=1.007,
|
126 |
+
# S_churn=80 ,S_min=0.05,S_max=50,S_noise=1.003,
|
127 |
+
):
|
128 |
+
"""
|
129 |
+
karras diffusion sampler
|
130 |
+
|
131 |
+
Args:
|
132 |
+
latents: noisy mel-spectrogram [B x n_mel x L]
|
133 |
+
cond: output of conformer encoder [B x n_mel x L]
|
134 |
+
nonpadding: mask of padded frames [B x n_mel x L]
|
135 |
+
num_steps: number of steps for diffusion inference
|
136 |
+
|
137 |
+
Returns:
|
138 |
+
denoised mel-spectrogram [B x n_mel x L]
|
139 |
+
"""
|
140 |
+
# Time step discretization.
|
141 |
+
step_indices = torch.arange(num_steps, device=latents.device)
|
142 |
+
|
143 |
+
num_steps = num_steps + 1
|
144 |
+
t_steps = (
|
145 |
+
sigma_max ** (1 / rho)
|
146 |
+
+ step_indices
|
147 |
+
/ (num_steps - 1)
|
148 |
+
* (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
|
149 |
+
) ** rho
|
150 |
+
t_steps = torch.cat([self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])
|
151 |
+
|
152 |
+
# Main sampling loop.
|
153 |
+
x_next = latents * t_steps[0]
|
154 |
+
# wrap in tqdm for progress bar
|
155 |
+
bar = tqdm(enumerate(zip(t_steps[:-1], t_steps[1:])))
|
156 |
+
for i, (t_cur, t_next) in bar:
|
157 |
+
x_cur = x_next
|
158 |
+
# Increase noise temporarily.
|
159 |
+
gamma = (
|
160 |
+
min(S_churn / num_steps, np.sqrt(2) - 1)
|
161 |
+
if S_min <= t_cur <= S_max
|
162 |
+
else 0
|
163 |
+
)
|
164 |
+
t_hat = self.round_sigma(t_cur + gamma * t_cur)
|
165 |
+
t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device)
|
166 |
+
t[:, 0, 0] = t_hat
|
167 |
+
t_hat = t
|
168 |
+
x_hat = x_cur + (
|
169 |
+
t_hat**2 - t_cur**2
|
170 |
+
).sqrt() * S_noise * torch.randn_like(x_cur)
|
171 |
+
# Euler step.
|
172 |
+
denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn, nonpadding)
|
173 |
+
d_cur = (x_hat - denoised) / t_hat
|
174 |
+
x_next = x_hat + (t_next - t_hat) * d_cur
|
175 |
+
|
176 |
+
return x_next
|
177 |
+
|
178 |
+
def CTLoss_D(self, y, cond, mask):
|
179 |
+
"""
|
180 |
+
compute loss for consistency distillation
|
181 |
+
|
182 |
+
Args:
|
183 |
+
y: ground truth mel-spectrogram [B x n_mel x L]
|
184 |
+
cond: output of conformer encoder [B x n_mel x L]
|
185 |
+
mask: mask of padded frames [B x n_mel x L]
|
186 |
+
"""
|
187 |
+
with torch.no_grad():
|
188 |
+
mu = 0.95
|
189 |
+
for p, ema_p in zip(
|
190 |
+
self.denoise_fn.parameters(), self.denoise_fn_ema.parameters()
|
191 |
+
):
|
192 |
+
ema_p.mul_(mu).add_(p, alpha=1 - mu)
|
193 |
+
|
194 |
+
n = torch.randint(1, self.N, (y.shape[0],))
|
195 |
+
z = torch.randn_like(y) + cond
|
196 |
+
|
197 |
+
tn_1 = self.t_steps[n + 1].reshape(-1, 1, 1).to(y.device)
|
198 |
+
f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn, mask)
|
199 |
+
|
200 |
+
with torch.no_grad():
|
201 |
+
tn = self.t_steps[n].reshape(-1, 1, 1).to(y.device)
|
202 |
+
|
203 |
+
# euler step
|
204 |
+
x_hat = y + tn_1 * z
|
205 |
+
denoised = self.EDMPrecond(
|
206 |
+
x_hat, tn_1, cond, self.denoise_fn_pretrained, mask
|
207 |
+
)
|
208 |
+
d_cur = (x_hat - denoised) / tn_1
|
209 |
+
y_tn = x_hat + (tn - tn_1) * d_cur
|
210 |
+
|
211 |
+
f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema, mask)
|
212 |
+
|
213 |
+
# loss = (f_theta - f_theta_ema.detach()) ** 2
|
214 |
+
# loss = torch.sum(loss * mask) / torch.sum(mask)
|
215 |
+
loss = self.ssim_loss(f_theta, f_theta_ema.detach())
|
216 |
+
loss = torch.sum(loss * mask) / torch.sum(mask)
|
217 |
+
|
218 |
+
return loss
|
219 |
+
|
220 |
+
def get_t_steps(self, N):
|
221 |
+
N = N + 1
|
222 |
+
step_indices = torch.arange(N) # , device=latents.device)
|
223 |
+
t_steps = (
|
224 |
+
self.sigma_min ** (1 / self.rho)
|
225 |
+
+ step_indices
|
226 |
+
/ (N - 1)
|
227 |
+
* (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho))
|
228 |
+
) ** self.rho
|
229 |
+
|
230 |
+
return t_steps.flip(0)
|
231 |
+
|
232 |
+
def CT_sampler(self, latents, cond, nonpadding, t_steps=1):
|
233 |
+
"""
|
234 |
+
consistency distillation sampler
|
235 |
+
|
236 |
+
Args:
|
237 |
+
latents: noisy mel-spectrogram [B x n_mel x L]
|
238 |
+
cond: output of conformer encoder [B x n_mel x L]
|
239 |
+
nonpadding: mask of padded frames [B x n_mel x L]
|
240 |
+
t_steps: number of steps for diffusion inference
|
241 |
+
|
242 |
+
Returns:
|
243 |
+
denoised mel-spectrogram [B x n_mel x L]
|
244 |
+
"""
|
245 |
+
# one-step
|
246 |
+
if t_steps == 1:
|
247 |
+
t_steps = [80]
|
248 |
+
# multi-step
|
249 |
+
else:
|
250 |
+
t_steps = self.get_t_steps(t_steps)
|
251 |
+
|
252 |
+
t_steps = torch.as_tensor(t_steps).to(latents.device)
|
253 |
+
latents = latents * t_steps[0]
|
254 |
+
_t = torch.zeros((latents.shape[0], 1, 1), device=latents.device)
|
255 |
+
_t[:, 0, 0] = t_steps
|
256 |
+
x = self.EDMPrecond(latents, _t, cond, self.denoise_fn_ema, nonpadding)
|
257 |
+
|
258 |
+
for t in t_steps[1:-1]:
|
259 |
+
z = torch.randn_like(x) + cond
|
260 |
+
x_tn = x + (t**2 - self.sigma_min**2).sqrt() * z
|
261 |
+
_t = torch.zeros((x.shape[0], 1, 1), device=x.device)
|
262 |
+
_t[:, 0, 0] = t
|
263 |
+
t = _t
|
264 |
+
print(t)
|
265 |
+
x = self.EDMPrecond(x_tn, t, cond, self.denoise_fn_ema, nonpadding)
|
266 |
+
return x
|
267 |
+
|
268 |
+
def forward(self, x, nonpadding, cond, t_steps=1, infer=False):
|
269 |
+
"""
|
270 |
+
calculate loss or sample mel-spectrogram
|
271 |
+
|
272 |
+
Args:
|
273 |
+
x:
|
274 |
+
training: ground truth mel-spectrogram [B x n_mel x L]
|
275 |
+
inference: output of encoder [B x n_mel x L]
|
276 |
+
"""
|
277 |
+
if self.teacher: # teacher model -- karras diffusion
|
278 |
+
if not infer:
|
279 |
+
loss = self.EDMLoss(x, cond, nonpadding)
|
280 |
+
return loss
|
281 |
+
else:
|
282 |
+
shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
|
283 |
+
x = torch.randn(shape, device=x.device) + cond
|
284 |
+
x = self.edm_sampler(x, cond, nonpadding, t_steps)
|
285 |
+
|
286 |
+
return x
|
287 |
+
else: # Consistency distillation
|
288 |
+
if not infer:
|
289 |
+
loss = self.CTLoss_D(x, cond, nonpadding)
|
290 |
+
return loss
|
291 |
+
|
292 |
+
else:
|
293 |
+
shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2])
|
294 |
+
x = torch.randn(shape, device=x.device) + cond
|
295 |
+
x = self.CT_sampler(x, cond, nonpadding, t_steps=1)
|
296 |
+
|
297 |
+
return x
|
298 |
+
|
299 |
+
|
300 |
+
class ComoSVC(BaseModule):
|
301 |
+
def __init__(self, cfg):
|
302 |
+
super().__init__()
|
303 |
+
self.cfg = cfg
|
304 |
+
self.cfg.model.comosvc.n_mel = self.cfg.preprocess.n_mel
|
305 |
+
self.distill = self.cfg.model.comosvc.distill
|
306 |
+
self.encoder = Conformer(self.cfg.model.comosvc)
|
307 |
+
self.decoder = Consistency(self.cfg, distill=self.distill)
|
308 |
+
self.ssim_loss = SSIM()
|
309 |
+
|
310 |
+
@torch.no_grad()
|
311 |
+
def forward(self, x_mask, x, n_timesteps, temperature=1.0):
|
312 |
+
"""
|
313 |
+
Generates mel-spectrogram from pitch, content vector, energy. Returns:
|
314 |
+
1. encoder outputs (from conformer)
|
315 |
+
2. decoder outputs (from diffusion-based decoder)
|
316 |
+
|
317 |
+
Args:
|
318 |
+
x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
|
319 |
+
x : output of encoder framework. [B x L x d_condition]
|
320 |
+
n_timesteps : number of steps to use for reverse diffusion in decoder.
|
321 |
+
temperature : controls variance of terminal distribution.
|
322 |
+
"""
|
323 |
+
|
324 |
+
# Get encoder_outputs `mu_x`
|
325 |
+
mu_x = self.encoder(x, x_mask)
|
326 |
+
encoder_outputs = mu_x
|
327 |
+
|
328 |
+
mu_x = mu_x.transpose(1, 2)
|
329 |
+
x_mask = x_mask.transpose(1, 2)
|
330 |
+
|
331 |
+
# Generate sample by performing reverse dynamics
|
332 |
+
decoder_outputs = self.decoder(
|
333 |
+
mu_x, x_mask, mu_x, t_steps=n_timesteps, infer=True
|
334 |
+
)
|
335 |
+
decoder_outputs = decoder_outputs.transpose(1, 2)
|
336 |
+
return encoder_outputs, decoder_outputs
|
337 |
+
|
338 |
+
def compute_loss(self, x_mask, x, mel, out_size=None, skip_diff=False):
|
339 |
+
"""
|
340 |
+
Computes 2 losses:
|
341 |
+
1. prior loss: loss between mel-spectrogram and encoder outputs.
|
342 |
+
2. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder.
|
343 |
+
|
344 |
+
Args:
|
345 |
+
x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel]
|
346 |
+
x : output of encoder framework. [B x L x d_condition]
|
347 |
+
mel : ground truth mel-spectrogram. [B x L x n_mel]
|
348 |
+
"""
|
349 |
+
|
350 |
+
mu_x = self.encoder(x, x_mask)
|
351 |
+
# prior loss
|
352 |
+
prior_loss = torch.sum(
|
353 |
+
0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * x_mask
|
354 |
+
)
|
355 |
+
prior_loss = prior_loss / (torch.sum(x_mask) * self.cfg.model.comosvc.n_mel)
|
356 |
+
# ssim loss
|
357 |
+
ssim_loss = self.ssim_loss(mu_x, mel)
|
358 |
+
ssim_loss = torch.sum(ssim_loss * x_mask) / torch.sum(x_mask)
|
359 |
+
|
360 |
+
x_mask = x_mask.transpose(1, 2)
|
361 |
+
mu_x = mu_x.transpose(1, 2)
|
362 |
+
mel = mel.transpose(1, 2)
|
363 |
+
if not self.distill and skip_diff:
|
364 |
+
diff_loss = prior_loss.clone()
|
365 |
+
diff_loss.fill_(0)
|
366 |
+
|
367 |
+
# Cut a small segment of mel-spectrogram in order to increase batch size
|
368 |
+
else:
|
369 |
+
if self.distill:
|
370 |
+
mu_y = mu_x.detach()
|
371 |
+
else:
|
372 |
+
mu_y = mu_x
|
373 |
+
mask_y = x_mask
|
374 |
+
|
375 |
+
diff_loss = self.decoder(mel, mask_y, mu_y, infer=False)
|
376 |
+
|
377 |
+
return ssim_loss, prior_loss, diff_loss
|
models/svc/comosvc/comosvc_inference.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from models.svc.base import SVCInference
|
9 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
10 |
+
from models.svc.comosvc.comosvc import ComoSVC
|
11 |
+
|
12 |
+
|
13 |
+
class ComoSVCInference(SVCInference):
|
14 |
+
def __init__(self, args, cfg, infer_type="from_dataset"):
|
15 |
+
SVCInference.__init__(self, args, cfg, infer_type)
|
16 |
+
|
17 |
+
def _build_model(self):
|
18 |
+
# TODO: sort out the config
|
19 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
20 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
21 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
22 |
+
self.acoustic_mapper = ComoSVC(self.cfg)
|
23 |
+
if self.cfg.model.comosvc.distill:
|
24 |
+
self.acoustic_mapper.decoder.init_consistency_training()
|
25 |
+
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
|
26 |
+
return model
|
27 |
+
|
28 |
+
def _inference_each_batch(self, batch_data):
|
29 |
+
device = self.accelerator.device
|
30 |
+
for k, v in batch_data.items():
|
31 |
+
batch_data[k] = v.to(device)
|
32 |
+
|
33 |
+
cond = self.condition_encoder(batch_data)
|
34 |
+
mask = batch_data["mask"]
|
35 |
+
encoder_pred, decoder_pred = self.acoustic_mapper(
|
36 |
+
mask, cond, self.cfg.inference.comosvc.inference_steps
|
37 |
+
)
|
38 |
+
|
39 |
+
return decoder_pred
|
models/svc/comosvc/comosvc_trainer.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import os
|
8 |
+
import json5
|
9 |
+
from collections import OrderedDict
|
10 |
+
from tqdm import tqdm
|
11 |
+
import json
|
12 |
+
import shutil
|
13 |
+
|
14 |
+
from models.svc.base import SVCTrainer
|
15 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
16 |
+
from models.svc.comosvc.comosvc import ComoSVC
|
17 |
+
|
18 |
+
|
19 |
+
class ComoSVCTrainer(SVCTrainer):
|
20 |
+
r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
|
21 |
+
implements ``_build_model`` and ``_forward_step`` methods.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, args=None, cfg=None):
|
25 |
+
SVCTrainer.__init__(self, args, cfg)
|
26 |
+
self.distill = cfg.model.comosvc.distill
|
27 |
+
self.skip_diff = True
|
28 |
+
if self.distill: # and args.resume is None:
|
29 |
+
self.teacher_model_path = cfg.model.teacher_model_path
|
30 |
+
self.teacher_state_dict = self._load_teacher_state_dict()
|
31 |
+
self._load_teacher_model(self.teacher_state_dict)
|
32 |
+
self.acoustic_mapper.decoder.init_consistency_training()
|
33 |
+
|
34 |
+
### Following are methods only for comoSVC models ###
|
35 |
+
def _load_teacher_state_dict(self):
|
36 |
+
self.checkpoint_file = self.teacher_model_path
|
37 |
+
print("Load teacher acoustic model from {}".format(self.checkpoint_file))
|
38 |
+
raw_state_dict = torch.load(self.checkpoint_file) # , map_location=self.device)
|
39 |
+
return raw_state_dict
|
40 |
+
|
41 |
+
def _load_teacher_model(self, state_dict):
|
42 |
+
raw_dict = state_dict
|
43 |
+
clean_dict = OrderedDict()
|
44 |
+
for k, v in raw_dict.items():
|
45 |
+
if k.startswith("module."):
|
46 |
+
clean_dict[k[7:]] = v
|
47 |
+
else:
|
48 |
+
clean_dict[k] = v
|
49 |
+
self.model.load_state_dict(clean_dict)
|
50 |
+
|
51 |
+
def _build_model(self):
|
52 |
+
r"""Build the model for training. This function is called in ``__init__`` function."""
|
53 |
+
|
54 |
+
# TODO: sort out the config
|
55 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
56 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
57 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
58 |
+
self.acoustic_mapper = ComoSVC(self.cfg)
|
59 |
+
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
|
60 |
+
return model
|
61 |
+
|
62 |
+
def _forward_step(self, batch):
|
63 |
+
r"""Forward step for training and inference. This function is called
|
64 |
+
in ``_train_step`` & ``_test_step`` function.
|
65 |
+
"""
|
66 |
+
loss = {}
|
67 |
+
mask = batch["mask"]
|
68 |
+
mel_input = batch["mel"]
|
69 |
+
cond = self.condition_encoder(batch)
|
70 |
+
if self.distill:
|
71 |
+
cond = cond.detach()
|
72 |
+
self.skip_diff = True if self.step < self.cfg.train.fast_steps else False
|
73 |
+
ssim_loss, prior_loss, diff_loss = self.acoustic_mapper.compute_loss(
|
74 |
+
mask, cond, mel_input, skip_diff=self.skip_diff
|
75 |
+
)
|
76 |
+
if self.distill:
|
77 |
+
loss["distil_loss"] = diff_loss
|
78 |
+
else:
|
79 |
+
loss["ssim_loss_encoder"] = ssim_loss
|
80 |
+
loss["prior_loss_encoder"] = prior_loss
|
81 |
+
loss["diffusion_loss_decoder"] = diff_loss
|
82 |
+
|
83 |
+
return loss
|
84 |
+
|
85 |
+
def _train_epoch(self):
|
86 |
+
r"""Training epoch. Should return average loss of a batch (sample) over
|
87 |
+
one epoch. See ``train_loop`` for usage.
|
88 |
+
"""
|
89 |
+
self.model.train()
|
90 |
+
epoch_sum_loss: float = 0.0
|
91 |
+
epoch_step: int = 0
|
92 |
+
for batch in tqdm(
|
93 |
+
self.train_dataloader,
|
94 |
+
desc=f"Training Epoch {self.epoch}",
|
95 |
+
unit="batch",
|
96 |
+
colour="GREEN",
|
97 |
+
leave=False,
|
98 |
+
dynamic_ncols=True,
|
99 |
+
smoothing=0.04,
|
100 |
+
disable=not self.accelerator.is_main_process,
|
101 |
+
):
|
102 |
+
# Do training step and BP
|
103 |
+
with self.accelerator.accumulate(self.model):
|
104 |
+
loss = self._train_step(batch)
|
105 |
+
total_loss = 0
|
106 |
+
for k, v in loss.items():
|
107 |
+
total_loss += v
|
108 |
+
self.accelerator.backward(total_loss)
|
109 |
+
enc_grad_norm = torch.nn.utils.clip_grad_norm_(
|
110 |
+
self.acoustic_mapper.encoder.parameters(), max_norm=1
|
111 |
+
)
|
112 |
+
dec_grad_norm = torch.nn.utils.clip_grad_norm_(
|
113 |
+
self.acoustic_mapper.decoder.parameters(), max_norm=1
|
114 |
+
)
|
115 |
+
self.optimizer.step()
|
116 |
+
self.optimizer.zero_grad()
|
117 |
+
self.batch_count += 1
|
118 |
+
|
119 |
+
# Update info for each step
|
120 |
+
# TODO: step means BP counts or batch counts?
|
121 |
+
if self.batch_count % self.cfg.train.gradient_accumulation_step == 0:
|
122 |
+
epoch_sum_loss += total_loss
|
123 |
+
log_info = {}
|
124 |
+
for k, v in loss.items():
|
125 |
+
key = "Step/Train Loss/{}".format(k)
|
126 |
+
log_info[key] = v
|
127 |
+
log_info["Step/Learning Rate"]: self.optimizer.param_groups[0]["lr"]
|
128 |
+
self.accelerator.log(
|
129 |
+
log_info,
|
130 |
+
step=self.step,
|
131 |
+
)
|
132 |
+
self.step += 1
|
133 |
+
epoch_step += 1
|
134 |
+
|
135 |
+
self.accelerator.wait_for_everyone()
|
136 |
+
return (
|
137 |
+
epoch_sum_loss
|
138 |
+
/ len(self.train_dataloader)
|
139 |
+
* self.cfg.train.gradient_accumulation_step,
|
140 |
+
loss,
|
141 |
+
)
|
142 |
+
|
143 |
+
def train_loop(self):
|
144 |
+
r"""Training loop. The public entry of training process."""
|
145 |
+
# Wait everyone to prepare before we move on
|
146 |
+
self.accelerator.wait_for_everyone()
|
147 |
+
# dump config file
|
148 |
+
if self.accelerator.is_main_process:
|
149 |
+
self.__dump_cfg(self.config_save_path)
|
150 |
+
self.model.train()
|
151 |
+
self.optimizer.zero_grad()
|
152 |
+
# Wait to ensure good to go
|
153 |
+
self.accelerator.wait_for_everyone()
|
154 |
+
while self.epoch < self.max_epoch:
|
155 |
+
self.logger.info("\n")
|
156 |
+
self.logger.info("-" * 32)
|
157 |
+
self.logger.info("Epoch {}: ".format(self.epoch))
|
158 |
+
|
159 |
+
### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict)
|
160 |
+
### It's inconvenient for the model with multiple losses
|
161 |
+
# Do training & validating epoch
|
162 |
+
train_loss, loss = self._train_epoch()
|
163 |
+
self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss))
|
164 |
+
for k, v in loss.items():
|
165 |
+
self.logger.info(" |- Train/Loss/{}: {:.6f}".format(k, v))
|
166 |
+
valid_loss = self._valid_epoch()
|
167 |
+
self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss))
|
168 |
+
self.accelerator.log(
|
169 |
+
{"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss},
|
170 |
+
step=self.epoch,
|
171 |
+
)
|
172 |
+
|
173 |
+
self.accelerator.wait_for_everyone()
|
174 |
+
# TODO: what is scheduler?
|
175 |
+
self.scheduler.step(valid_loss) # FIXME: use epoch track correct?
|
176 |
+
|
177 |
+
# Check if hit save_checkpoint_stride and run_eval
|
178 |
+
run_eval = False
|
179 |
+
if self.accelerator.is_main_process:
|
180 |
+
save_checkpoint = False
|
181 |
+
hit_dix = []
|
182 |
+
for i, num in enumerate(self.save_checkpoint_stride):
|
183 |
+
if self.epoch % num == 0:
|
184 |
+
save_checkpoint = True
|
185 |
+
hit_dix.append(i)
|
186 |
+
run_eval |= self.run_eval[i]
|
187 |
+
|
188 |
+
self.accelerator.wait_for_everyone()
|
189 |
+
if (
|
190 |
+
self.accelerator.is_main_process
|
191 |
+
and save_checkpoint
|
192 |
+
and (self.distill or not self.skip_diff)
|
193 |
+
):
|
194 |
+
path = os.path.join(
|
195 |
+
self.checkpoint_dir,
|
196 |
+
"epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
197 |
+
self.epoch, self.step, train_loss
|
198 |
+
),
|
199 |
+
)
|
200 |
+
self.accelerator.save_state(path)
|
201 |
+
json.dump(
|
202 |
+
self.checkpoints_path,
|
203 |
+
open(os.path.join(path, "ckpts.json"), "w"),
|
204 |
+
ensure_ascii=False,
|
205 |
+
indent=4,
|
206 |
+
)
|
207 |
+
|
208 |
+
# Remove old checkpoints
|
209 |
+
to_remove = []
|
210 |
+
for idx in hit_dix:
|
211 |
+
self.checkpoints_path[idx].append(path)
|
212 |
+
while len(self.checkpoints_path[idx]) > self.keep_last[idx]:
|
213 |
+
to_remove.append((idx, self.checkpoints_path[idx].pop(0)))
|
214 |
+
|
215 |
+
# Search conflicts
|
216 |
+
total = set()
|
217 |
+
for i in self.checkpoints_path:
|
218 |
+
total |= set(i)
|
219 |
+
do_remove = set()
|
220 |
+
for idx, path in to_remove[::-1]:
|
221 |
+
if path in total:
|
222 |
+
self.checkpoints_path[idx].insert(0, path)
|
223 |
+
else:
|
224 |
+
do_remove.add(path)
|
225 |
+
|
226 |
+
# Remove old checkpoints
|
227 |
+
for path in do_remove:
|
228 |
+
shutil.rmtree(path, ignore_errors=True)
|
229 |
+
self.logger.debug(f"Remove old checkpoint: {path}")
|
230 |
+
|
231 |
+
self.accelerator.wait_for_everyone()
|
232 |
+
if run_eval:
|
233 |
+
# TODO: run evaluation
|
234 |
+
pass
|
235 |
+
|
236 |
+
# Update info for each epoch
|
237 |
+
self.epoch += 1
|
238 |
+
|
239 |
+
# Finish training and save final checkpoint
|
240 |
+
self.accelerator.wait_for_everyone()
|
241 |
+
if self.accelerator.is_main_process:
|
242 |
+
self.accelerator.save_state(
|
243 |
+
os.path.join(
|
244 |
+
self.checkpoint_dir,
|
245 |
+
"final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format(
|
246 |
+
self.epoch, self.step, valid_loss
|
247 |
+
),
|
248 |
+
)
|
249 |
+
)
|
250 |
+
self.accelerator.end_training()
|
251 |
+
|
252 |
+
@torch.inference_mode()
|
253 |
+
def _valid_epoch(self):
|
254 |
+
r"""Testing epoch. Should return average loss of a batch (sample) over
|
255 |
+
one epoch. See ``train_loop`` for usage.
|
256 |
+
"""
|
257 |
+
self.model.eval()
|
258 |
+
epoch_sum_loss = 0.0
|
259 |
+
for batch in tqdm(
|
260 |
+
self.valid_dataloader,
|
261 |
+
desc=f"Validating Epoch {self.epoch}",
|
262 |
+
unit="batch",
|
263 |
+
colour="GREEN",
|
264 |
+
leave=False,
|
265 |
+
dynamic_ncols=True,
|
266 |
+
smoothing=0.04,
|
267 |
+
disable=not self.accelerator.is_main_process,
|
268 |
+
):
|
269 |
+
batch_loss = self._valid_step(batch)
|
270 |
+
for k, v in batch_loss.items():
|
271 |
+
epoch_sum_loss += v
|
272 |
+
|
273 |
+
self.accelerator.wait_for_everyone()
|
274 |
+
return epoch_sum_loss / len(self.valid_dataloader)
|
275 |
+
|
276 |
+
@staticmethod
|
277 |
+
def __count_parameters(model):
|
278 |
+
model_param = 0.0
|
279 |
+
if isinstance(model, dict):
|
280 |
+
for key, value in model.items():
|
281 |
+
model_param += sum(p.numel() for p in model[key].parameters())
|
282 |
+
else:
|
283 |
+
model_param = sum(p.numel() for p in model.parameters())
|
284 |
+
return model_param
|
285 |
+
|
286 |
+
def __dump_cfg(self, path):
|
287 |
+
os.makedirs(os.path.dirname(path), exist_ok=True)
|
288 |
+
json5.dump(
|
289 |
+
self.cfg,
|
290 |
+
open(path, "w"),
|
291 |
+
indent=4,
|
292 |
+
sort_keys=True,
|
293 |
+
ensure_ascii=False,
|
294 |
+
quote_keys=True,
|
295 |
+
)
|
models/svc/comosvc/utils.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
|
9 |
+
def slice_segments(x, ids_str, segment_size=200):
|
10 |
+
ret = torch.zeros_like(x[:, :, :segment_size])
|
11 |
+
for i in range(x.size(0)):
|
12 |
+
idx_str = ids_str[i]
|
13 |
+
idx_end = idx_str + segment_size
|
14 |
+
ret[i] = x[i, :, idx_str:idx_end]
|
15 |
+
return ret
|
16 |
+
|
17 |
+
|
18 |
+
def rand_ids_segments(lengths, segment_size=200):
|
19 |
+
b = lengths.shape[0]
|
20 |
+
ids_str_max = lengths - segment_size
|
21 |
+
ids_str = (torch.rand([b]).to(device=lengths.device) * ids_str_max).to(
|
22 |
+
dtype=torch.long
|
23 |
+
)
|
24 |
+
return ids_str
|
25 |
+
|
26 |
+
|
27 |
+
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
|
28 |
+
while True:
|
29 |
+
if length % (2**num_downsamplings_in_unet) == 0:
|
30 |
+
return length
|
31 |
+
length += 1
|
models/svc/diffusion/__init__.py
ADDED
File without changes
|
models/svc/diffusion/diffusion_inference.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler
|
8 |
+
|
9 |
+
from models.svc.base import SVCInference
|
10 |
+
from models.svc.diffusion.diffusion_inference_pipeline import DiffusionInferencePipeline
|
11 |
+
from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper
|
12 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
13 |
+
|
14 |
+
|
15 |
+
class DiffusionInference(SVCInference):
|
16 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
17 |
+
SVCInference.__init__(self, args, cfg, infer_type)
|
18 |
+
|
19 |
+
settings = {
|
20 |
+
**cfg.model.diffusion.scheduler_settings,
|
21 |
+
**cfg.inference.diffusion.scheduler_settings,
|
22 |
+
}
|
23 |
+
settings.pop("num_inference_timesteps")
|
24 |
+
|
25 |
+
if cfg.inference.diffusion.scheduler.lower() == "ddpm":
|
26 |
+
self.scheduler = DDPMScheduler(**settings)
|
27 |
+
self.logger.info("Using DDPM scheduler.")
|
28 |
+
elif cfg.inference.diffusion.scheduler.lower() == "ddim":
|
29 |
+
self.scheduler = DDIMScheduler(**settings)
|
30 |
+
self.logger.info("Using DDIM scheduler.")
|
31 |
+
elif cfg.inference.diffusion.scheduler.lower() == "pndm":
|
32 |
+
self.scheduler = PNDMScheduler(**settings)
|
33 |
+
self.logger.info("Using PNDM scheduler.")
|
34 |
+
else:
|
35 |
+
raise NotImplementedError(
|
36 |
+
"Unsupported scheduler type: {}".format(
|
37 |
+
cfg.inference.diffusion.scheduler.lower()
|
38 |
+
)
|
39 |
+
)
|
40 |
+
|
41 |
+
self.pipeline = DiffusionInferencePipeline(
|
42 |
+
self.model[1],
|
43 |
+
self.scheduler,
|
44 |
+
args.diffusion_inference_steps,
|
45 |
+
)
|
46 |
+
|
47 |
+
def _build_model(self):
|
48 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
49 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
50 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
51 |
+
self.acoustic_mapper = DiffusionWrapper(self.cfg)
|
52 |
+
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
|
53 |
+
return model
|
54 |
+
|
55 |
+
def _inference_each_batch(self, batch_data):
|
56 |
+
device = self.accelerator.device
|
57 |
+
for k, v in batch_data.items():
|
58 |
+
batch_data[k] = v.to(device)
|
59 |
+
|
60 |
+
conditioner = self.model[0](batch_data)
|
61 |
+
noise = torch.randn_like(batch_data["mel"], device=device)
|
62 |
+
y_pred = self.pipeline(noise, conditioner)
|
63 |
+
return y_pred
|
models/svc/diffusion/diffusion_inference_pipeline.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from diffusers import DiffusionPipeline
|
8 |
+
|
9 |
+
|
10 |
+
class DiffusionInferencePipeline(DiffusionPipeline):
|
11 |
+
def __init__(self, network, scheduler, num_inference_timesteps=1000):
|
12 |
+
super().__init__()
|
13 |
+
|
14 |
+
self.register_modules(network=network, scheduler=scheduler)
|
15 |
+
self.num_inference_timesteps = num_inference_timesteps
|
16 |
+
|
17 |
+
@torch.inference_mode()
|
18 |
+
def __call__(
|
19 |
+
self,
|
20 |
+
initial_noise: torch.Tensor,
|
21 |
+
conditioner: torch.Tensor = None,
|
22 |
+
):
|
23 |
+
r"""
|
24 |
+
Args:
|
25 |
+
initial_noise: The initial noise to be denoised.
|
26 |
+
conditioner:The conditioner.
|
27 |
+
n_inference_steps: The number of denoising steps. More denoising steps
|
28 |
+
usually lead to a higher quality at the expense of slower inference.
|
29 |
+
"""
|
30 |
+
|
31 |
+
mel = initial_noise
|
32 |
+
batch_size = mel.size(0)
|
33 |
+
self.scheduler.set_timesteps(self.num_inference_timesteps)
|
34 |
+
|
35 |
+
for t in self.progress_bar(self.scheduler.timesteps):
|
36 |
+
timestep = torch.full((batch_size,), t, device=mel.device, dtype=torch.long)
|
37 |
+
|
38 |
+
# 1. predict noise model_output
|
39 |
+
model_output = self.network(mel, timestep, conditioner)
|
40 |
+
|
41 |
+
# 2. denoise, compute previous step: x_t -> x_t-1
|
42 |
+
mel = self.scheduler.step(model_output, t, mel).prev_sample
|
43 |
+
|
44 |
+
# 3. clamp
|
45 |
+
mel = mel.clamp(-1.0, 1.0)
|
46 |
+
|
47 |
+
return mel
|
models/svc/diffusion/diffusion_trainer.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
from diffusers import DDPMScheduler
|
8 |
+
|
9 |
+
from models.svc.base import SVCTrainer
|
10 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
11 |
+
from .diffusion_wrapper import DiffusionWrapper
|
12 |
+
|
13 |
+
|
14 |
+
class DiffusionTrainer(SVCTrainer):
|
15 |
+
r"""The base trainer for all diffusion models. It inherits from SVCTrainer and
|
16 |
+
implements ``_build_model`` and ``_forward_step`` methods.
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(self, args=None, cfg=None):
|
20 |
+
SVCTrainer.__init__(self, args, cfg)
|
21 |
+
|
22 |
+
# Only for SVC tasks using diffusion
|
23 |
+
self.noise_scheduler = DDPMScheduler(
|
24 |
+
**self.cfg.model.diffusion.scheduler_settings,
|
25 |
+
)
|
26 |
+
self.diffusion_timesteps = (
|
27 |
+
self.cfg.model.diffusion.scheduler_settings.num_train_timesteps
|
28 |
+
)
|
29 |
+
|
30 |
+
### Following are methods only for diffusion models ###
|
31 |
+
def _build_model(self):
|
32 |
+
r"""Build the model for training. This function is called in ``__init__`` function."""
|
33 |
+
|
34 |
+
# TODO: sort out the config
|
35 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
36 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
37 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
38 |
+
self.acoustic_mapper = DiffusionWrapper(self.cfg)
|
39 |
+
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
|
40 |
+
|
41 |
+
num_of_params_encoder = self.count_parameters(self.condition_encoder)
|
42 |
+
num_of_params_am = self.count_parameters(self.acoustic_mapper)
|
43 |
+
num_of_params = num_of_params_encoder + num_of_params_am
|
44 |
+
log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format(
|
45 |
+
num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6
|
46 |
+
)
|
47 |
+
self.logger.info(log)
|
48 |
+
|
49 |
+
return model
|
50 |
+
|
51 |
+
def count_parameters(self, model):
|
52 |
+
model_param = 0.0
|
53 |
+
if isinstance(model, dict):
|
54 |
+
for key, value in model.items():
|
55 |
+
model_param += sum(p.numel() for p in model[key].parameters())
|
56 |
+
else:
|
57 |
+
model_param = sum(p.numel() for p in model.parameters())
|
58 |
+
return model_param
|
59 |
+
|
60 |
+
def _forward_step(self, batch):
|
61 |
+
r"""Forward step for training and inference. This function is called
|
62 |
+
in ``_train_step`` & ``_test_step`` function.
|
63 |
+
"""
|
64 |
+
|
65 |
+
device = self.accelerator.device
|
66 |
+
|
67 |
+
mel_input = batch["mel"]
|
68 |
+
noise = torch.randn_like(mel_input, device=device, dtype=torch.float32)
|
69 |
+
batch_size = mel_input.size(0)
|
70 |
+
timesteps = torch.randint(
|
71 |
+
0,
|
72 |
+
self.diffusion_timesteps,
|
73 |
+
(batch_size,),
|
74 |
+
device=device,
|
75 |
+
dtype=torch.long,
|
76 |
+
)
|
77 |
+
|
78 |
+
noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps)
|
79 |
+
conditioner = self.condition_encoder(batch)
|
80 |
+
|
81 |
+
y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner)
|
82 |
+
|
83 |
+
# TODO: Predict noise or gt should be configurable
|
84 |
+
loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"])
|
85 |
+
self._check_nan(loss, y_pred, noise)
|
86 |
+
|
87 |
+
# FIXME: Clarify that we should not divide it with batch size here
|
88 |
+
return loss
|
models/svc/diffusion/diffusion_wrapper.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
from modules.diffusion import BiDilConv
|
9 |
+
from modules.encoder.position_encoder import PositionEncoder
|
10 |
+
|
11 |
+
|
12 |
+
class DiffusionWrapper(nn.Module):
|
13 |
+
def __init__(self, cfg):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
self.cfg = cfg
|
17 |
+
self.diff_cfg = cfg.model.diffusion
|
18 |
+
|
19 |
+
self.diff_encoder = PositionEncoder(
|
20 |
+
d_raw_emb=self.diff_cfg.step_encoder.dim_raw_embedding,
|
21 |
+
d_out=self.diff_cfg.bidilconv.base_channel,
|
22 |
+
d_mlp=self.diff_cfg.step_encoder.dim_hidden_layer,
|
23 |
+
activation_function=self.diff_cfg.step_encoder.activation,
|
24 |
+
n_layer=self.diff_cfg.step_encoder.num_layer,
|
25 |
+
max_period=self.diff_cfg.step_encoder.max_period,
|
26 |
+
)
|
27 |
+
|
28 |
+
# FIXME: Only support BiDilConv now for debug
|
29 |
+
if self.diff_cfg.model_type.lower() == "bidilconv":
|
30 |
+
self.neural_network = BiDilConv(
|
31 |
+
input_channel=self.cfg.preprocess.n_mel, **self.diff_cfg.bidilconv
|
32 |
+
)
|
33 |
+
else:
|
34 |
+
raise ValueError(
|
35 |
+
f"Unsupported diffusion model type: {self.diff_cfg.model_type}"
|
36 |
+
)
|
37 |
+
|
38 |
+
def forward(self, x, t, c):
|
39 |
+
"""
|
40 |
+
Args:
|
41 |
+
x: [N, T, mel_band] of mel spectrogram
|
42 |
+
t: Diffusion time step with shape of [N]
|
43 |
+
c: [N, T, conditioner_size] of conditioner
|
44 |
+
|
45 |
+
Returns:
|
46 |
+
[N, T, mel_band] of mel spectrogram
|
47 |
+
"""
|
48 |
+
|
49 |
+
assert (
|
50 |
+
x.size()[:-1] == c.size()[:-1]
|
51 |
+
), "x mismatch with c, got \n x: {} \n c: {}".format(x.size(), c.size())
|
52 |
+
assert x.size(0) == t.size(
|
53 |
+
0
|
54 |
+
), "x mismatch with t, got \n x: {} \n t: {}".format(x.size(), t.size())
|
55 |
+
assert t.dim() == 1, "t must be 1D tensor, got {}".format(t.dim())
|
56 |
+
|
57 |
+
N, T, mel_band = x.size()
|
58 |
+
|
59 |
+
x = x.transpose(1, 2).contiguous() # [N, mel_band, T]
|
60 |
+
c = c.transpose(1, 2).contiguous() # [N, conditioner_size, T]
|
61 |
+
t = self.diff_encoder(t).contiguous() # [N, base_channel]
|
62 |
+
|
63 |
+
h = self.neural_network(x, t, c)
|
64 |
+
h = h.transpose(1, 2).contiguous() # [N, T, mel_band]
|
65 |
+
|
66 |
+
assert h.size() == (
|
67 |
+
N,
|
68 |
+
T,
|
69 |
+
mel_band,
|
70 |
+
), "h mismatch with input x, got \n h: {} \n x: {}".format(
|
71 |
+
h.size(), (N, T, mel_band)
|
72 |
+
)
|
73 |
+
return h
|
models/svc/transformer/__init__.py
ADDED
File without changes
|
models/svc/transformer/conformer.py
ADDED
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import numpy as np
|
9 |
+
import torch.nn as nn
|
10 |
+
from utils.util import convert_pad_shape
|
11 |
+
|
12 |
+
|
13 |
+
class BaseModule(torch.nn.Module):
|
14 |
+
def __init__(self):
|
15 |
+
super(BaseModule, self).__init__()
|
16 |
+
|
17 |
+
@property
|
18 |
+
def nparams(self):
|
19 |
+
"""
|
20 |
+
Returns number of trainable parameters of the module.
|
21 |
+
"""
|
22 |
+
num_params = 0
|
23 |
+
for name, param in self.named_parameters():
|
24 |
+
if param.requires_grad:
|
25 |
+
num_params += np.prod(param.detach().cpu().numpy().shape)
|
26 |
+
return num_params
|
27 |
+
|
28 |
+
def relocate_input(self, x: list):
|
29 |
+
"""
|
30 |
+
Relocates provided tensors to the same device set for the module.
|
31 |
+
"""
|
32 |
+
device = next(self.parameters()).device
|
33 |
+
for i in range(len(x)):
|
34 |
+
if isinstance(x[i], torch.Tensor) and x[i].device != device:
|
35 |
+
x[i] = x[i].to(device)
|
36 |
+
return x
|
37 |
+
|
38 |
+
|
39 |
+
class LayerNorm(BaseModule):
|
40 |
+
def __init__(self, channels, eps=1e-4):
|
41 |
+
super(LayerNorm, self).__init__()
|
42 |
+
self.channels = channels
|
43 |
+
self.eps = eps
|
44 |
+
|
45 |
+
self.gamma = torch.nn.Parameter(torch.ones(channels))
|
46 |
+
self.beta = torch.nn.Parameter(torch.zeros(channels))
|
47 |
+
|
48 |
+
def forward(self, x):
|
49 |
+
n_dims = len(x.shape)
|
50 |
+
mean = torch.mean(x, 1, keepdim=True)
|
51 |
+
variance = torch.mean((x - mean) ** 2, 1, keepdim=True)
|
52 |
+
|
53 |
+
x = (x - mean) * torch.rsqrt(variance + self.eps)
|
54 |
+
|
55 |
+
shape = [1, -1] + [1] * (n_dims - 2)
|
56 |
+
x = x * self.gamma.view(*shape) + self.beta.view(*shape)
|
57 |
+
return x
|
58 |
+
|
59 |
+
|
60 |
+
class ConvReluNorm(BaseModule):
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
in_channels,
|
64 |
+
hidden_channels,
|
65 |
+
out_channels,
|
66 |
+
kernel_size,
|
67 |
+
n_layers,
|
68 |
+
p_dropout,
|
69 |
+
eps=1e-5,
|
70 |
+
):
|
71 |
+
super(ConvReluNorm, self).__init__()
|
72 |
+
self.in_channels = in_channels
|
73 |
+
self.hidden_channels = hidden_channels
|
74 |
+
self.out_channels = out_channels
|
75 |
+
self.kernel_size = kernel_size
|
76 |
+
self.n_layers = n_layers
|
77 |
+
self.p_dropout = p_dropout
|
78 |
+
self.eps = eps
|
79 |
+
|
80 |
+
self.conv_layers = torch.nn.ModuleList()
|
81 |
+
self.conv_layers.append(
|
82 |
+
torch.nn.Conv1d(
|
83 |
+
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
|
84 |
+
)
|
85 |
+
)
|
86 |
+
self.relu_drop = torch.nn.Sequential(
|
87 |
+
torch.nn.ReLU(), torch.nn.Dropout(p_dropout)
|
88 |
+
)
|
89 |
+
for _ in range(n_layers - 1):
|
90 |
+
self.conv_layers.append(
|
91 |
+
torch.nn.Conv1d(
|
92 |
+
hidden_channels,
|
93 |
+
hidden_channels,
|
94 |
+
kernel_size,
|
95 |
+
padding=kernel_size // 2,
|
96 |
+
)
|
97 |
+
)
|
98 |
+
self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1)
|
99 |
+
self.proj.weight.data.zero_()
|
100 |
+
self.proj.bias.data.zero_()
|
101 |
+
|
102 |
+
def forward(self, x, x_mask):
|
103 |
+
for i in range(self.n_layers):
|
104 |
+
x = self.conv_layers[i](x * x_mask)
|
105 |
+
x = self.instance_norm(x, x_mask)
|
106 |
+
x = self.relu_drop(x)
|
107 |
+
x = self.proj(x)
|
108 |
+
return x * x_mask
|
109 |
+
|
110 |
+
def instance_norm(self, x, mask, return_mean_std=False):
|
111 |
+
mean, std = self.calc_mean_std(x, mask)
|
112 |
+
x = (x - mean) / std
|
113 |
+
if return_mean_std:
|
114 |
+
return x, mean, std
|
115 |
+
else:
|
116 |
+
return x
|
117 |
+
|
118 |
+
def calc_mean_std(self, x, mask=None):
|
119 |
+
x = x * mask
|
120 |
+
B, C = x.shape[:2]
|
121 |
+
mn = x.view(B, C, -1).mean(-1)
|
122 |
+
sd = (x.view(B, C, -1).var(-1) + self.eps).sqrt()
|
123 |
+
mn = mn.view(B, C, *((len(x.shape) - 2) * [1]))
|
124 |
+
sd = sd.view(B, C, *((len(x.shape) - 2) * [1]))
|
125 |
+
return mn, sd
|
126 |
+
|
127 |
+
|
128 |
+
class MultiHeadAttention(BaseModule):
|
129 |
+
def __init__(
|
130 |
+
self,
|
131 |
+
channels,
|
132 |
+
out_channels,
|
133 |
+
n_heads,
|
134 |
+
window_size=None,
|
135 |
+
heads_share=True,
|
136 |
+
p_dropout=0.0,
|
137 |
+
proximal_bias=False,
|
138 |
+
proximal_init=False,
|
139 |
+
):
|
140 |
+
super(MultiHeadAttention, self).__init__()
|
141 |
+
assert channels % n_heads == 0
|
142 |
+
|
143 |
+
self.channels = channels
|
144 |
+
self.out_channels = out_channels
|
145 |
+
self.n_heads = n_heads
|
146 |
+
self.window_size = window_size
|
147 |
+
self.heads_share = heads_share
|
148 |
+
self.proximal_bias = proximal_bias
|
149 |
+
self.p_dropout = p_dropout
|
150 |
+
self.attn = None
|
151 |
+
|
152 |
+
self.k_channels = channels // n_heads
|
153 |
+
self.conv_q = torch.nn.Conv1d(channels, channels, 1)
|
154 |
+
self.conv_k = torch.nn.Conv1d(channels, channels, 1)
|
155 |
+
self.conv_v = torch.nn.Conv1d(channels, channels, 1)
|
156 |
+
if window_size is not None:
|
157 |
+
n_heads_rel = 1 if heads_share else n_heads
|
158 |
+
rel_stddev = self.k_channels**-0.5
|
159 |
+
self.emb_rel_k = torch.nn.Parameter(
|
160 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
161 |
+
* rel_stddev
|
162 |
+
)
|
163 |
+
self.emb_rel_v = torch.nn.Parameter(
|
164 |
+
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
165 |
+
* rel_stddev
|
166 |
+
)
|
167 |
+
self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
|
168 |
+
self.drop = torch.nn.Dropout(p_dropout)
|
169 |
+
|
170 |
+
torch.nn.init.xavier_uniform_(self.conv_q.weight)
|
171 |
+
torch.nn.init.xavier_uniform_(self.conv_k.weight)
|
172 |
+
if proximal_init:
|
173 |
+
self.conv_k.weight.data.copy_(self.conv_q.weight.data)
|
174 |
+
self.conv_k.bias.data.copy_(self.conv_q.bias.data)
|
175 |
+
torch.nn.init.xavier_uniform_(self.conv_v.weight)
|
176 |
+
|
177 |
+
def forward(self, x, c, attn_mask=None):
|
178 |
+
q = self.conv_q(x)
|
179 |
+
k = self.conv_k(c)
|
180 |
+
v = self.conv_v(c)
|
181 |
+
|
182 |
+
x, self.attn = self.attention(q, k, v, mask=attn_mask)
|
183 |
+
|
184 |
+
x = self.conv_o(x)
|
185 |
+
return x
|
186 |
+
|
187 |
+
def attention(self, query, key, value, mask=None):
|
188 |
+
b, d, t_s, t_t = (*key.size(), query.size(2))
|
189 |
+
query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
|
190 |
+
key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
191 |
+
value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
|
192 |
+
|
193 |
+
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
|
194 |
+
if self.window_size is not None:
|
195 |
+
assert (
|
196 |
+
t_s == t_t
|
197 |
+
), "Relative attention is only available for self-attention."
|
198 |
+
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
199 |
+
rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
|
200 |
+
rel_logits = self._relative_position_to_absolute_position(rel_logits)
|
201 |
+
scores_local = rel_logits / math.sqrt(self.k_channels)
|
202 |
+
scores = scores + scores_local
|
203 |
+
if self.proximal_bias:
|
204 |
+
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
205 |
+
scores = scores + self._attention_bias_proximal(t_s).to(
|
206 |
+
device=scores.device, dtype=scores.dtype
|
207 |
+
)
|
208 |
+
if mask is not None:
|
209 |
+
scores = scores.masked_fill(mask == 0, -1e4)
|
210 |
+
p_attn = torch.nn.functional.softmax(scores, dim=-1)
|
211 |
+
p_attn = self.drop(p_attn)
|
212 |
+
output = torch.matmul(p_attn, value)
|
213 |
+
if self.window_size is not None:
|
214 |
+
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
215 |
+
value_relative_embeddings = self._get_relative_embeddings(
|
216 |
+
self.emb_rel_v, t_s
|
217 |
+
)
|
218 |
+
output = output + self._matmul_with_relative_values(
|
219 |
+
relative_weights, value_relative_embeddings
|
220 |
+
)
|
221 |
+
output = output.transpose(2, 3).contiguous().view(b, d, t_t)
|
222 |
+
return output, p_attn
|
223 |
+
|
224 |
+
def _matmul_with_relative_values(self, x, y):
|
225 |
+
ret = torch.matmul(x, y.unsqueeze(0))
|
226 |
+
return ret
|
227 |
+
|
228 |
+
def _matmul_with_relative_keys(self, x, y):
|
229 |
+
ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
|
230 |
+
return ret
|
231 |
+
|
232 |
+
def _get_relative_embeddings(self, relative_embeddings, length):
|
233 |
+
pad_length = max(length - (self.window_size + 1), 0)
|
234 |
+
slice_start_position = max((self.window_size + 1) - length, 0)
|
235 |
+
slice_end_position = slice_start_position + 2 * length - 1
|
236 |
+
if pad_length > 0:
|
237 |
+
padded_relative_embeddings = torch.nn.functional.pad(
|
238 |
+
relative_embeddings,
|
239 |
+
convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
240 |
+
)
|
241 |
+
else:
|
242 |
+
padded_relative_embeddings = relative_embeddings
|
243 |
+
used_relative_embeddings = padded_relative_embeddings[
|
244 |
+
:, slice_start_position:slice_end_position
|
245 |
+
]
|
246 |
+
return used_relative_embeddings
|
247 |
+
|
248 |
+
def _relative_position_to_absolute_position(self, x):
|
249 |
+
batch, heads, length, _ = x.size()
|
250 |
+
x = torch.nn.functional.pad(
|
251 |
+
x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])
|
252 |
+
)
|
253 |
+
x_flat = x.view([batch, heads, length * 2 * length])
|
254 |
+
x_flat = torch.nn.functional.pad(
|
255 |
+
x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
256 |
+
)
|
257 |
+
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
258 |
+
:, :, :length, length - 1 :
|
259 |
+
]
|
260 |
+
return x_final
|
261 |
+
|
262 |
+
def _absolute_position_to_relative_position(self, x):
|
263 |
+
batch, heads, length, _ = x.size()
|
264 |
+
x = torch.nn.functional.pad(
|
265 |
+
x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
266 |
+
)
|
267 |
+
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
268 |
+
x_flat = torch.nn.functional.pad(
|
269 |
+
x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]])
|
270 |
+
)
|
271 |
+
x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
|
272 |
+
return x_final
|
273 |
+
|
274 |
+
def _attention_bias_proximal(self, length):
|
275 |
+
r = torch.arange(length, dtype=torch.float32)
|
276 |
+
diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
|
277 |
+
return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
278 |
+
|
279 |
+
|
280 |
+
class FFN(BaseModule):
|
281 |
+
def __init__(
|
282 |
+
self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0
|
283 |
+
):
|
284 |
+
super(FFN, self).__init__()
|
285 |
+
self.in_channels = in_channels
|
286 |
+
self.out_channels = out_channels
|
287 |
+
self.filter_channels = filter_channels
|
288 |
+
self.kernel_size = kernel_size
|
289 |
+
self.p_dropout = p_dropout
|
290 |
+
|
291 |
+
self.conv_1 = torch.nn.Conv1d(
|
292 |
+
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
293 |
+
)
|
294 |
+
self.conv_2 = torch.nn.Conv1d(
|
295 |
+
filter_channels, out_channels, kernel_size, padding=kernel_size // 2
|
296 |
+
)
|
297 |
+
self.drop = torch.nn.Dropout(p_dropout)
|
298 |
+
|
299 |
+
def forward(self, x, x_mask):
|
300 |
+
x = self.conv_1(x * x_mask)
|
301 |
+
x = torch.relu(x)
|
302 |
+
x = self.drop(x)
|
303 |
+
x = self.conv_2(x * x_mask)
|
304 |
+
return x * x_mask
|
305 |
+
|
306 |
+
|
307 |
+
class Encoder(BaseModule):
|
308 |
+
def __init__(
|
309 |
+
self,
|
310 |
+
hidden_channels,
|
311 |
+
filter_channels,
|
312 |
+
n_heads=2,
|
313 |
+
n_layers=6,
|
314 |
+
kernel_size=3,
|
315 |
+
p_dropout=0.1,
|
316 |
+
window_size=4,
|
317 |
+
**kwargs
|
318 |
+
):
|
319 |
+
super(Encoder, self).__init__()
|
320 |
+
self.hidden_channels = hidden_channels
|
321 |
+
self.filter_channels = filter_channels
|
322 |
+
self.n_heads = n_heads
|
323 |
+
self.n_layers = n_layers
|
324 |
+
self.kernel_size = kernel_size
|
325 |
+
self.p_dropout = p_dropout
|
326 |
+
self.window_size = window_size
|
327 |
+
|
328 |
+
self.drop = torch.nn.Dropout(p_dropout)
|
329 |
+
self.attn_layers = torch.nn.ModuleList()
|
330 |
+
self.norm_layers_1 = torch.nn.ModuleList()
|
331 |
+
self.ffn_layers = torch.nn.ModuleList()
|
332 |
+
self.norm_layers_2 = torch.nn.ModuleList()
|
333 |
+
for _ in range(self.n_layers):
|
334 |
+
self.attn_layers.append(
|
335 |
+
MultiHeadAttention(
|
336 |
+
hidden_channels,
|
337 |
+
hidden_channels,
|
338 |
+
n_heads,
|
339 |
+
window_size=window_size,
|
340 |
+
p_dropout=p_dropout,
|
341 |
+
)
|
342 |
+
)
|
343 |
+
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
344 |
+
self.ffn_layers.append(
|
345 |
+
FFN(
|
346 |
+
hidden_channels,
|
347 |
+
hidden_channels,
|
348 |
+
filter_channels,
|
349 |
+
kernel_size,
|
350 |
+
p_dropout=p_dropout,
|
351 |
+
)
|
352 |
+
)
|
353 |
+
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
354 |
+
|
355 |
+
def forward(self, x, x_mask):
|
356 |
+
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
357 |
+
for i in range(self.n_layers):
|
358 |
+
x = x * x_mask
|
359 |
+
y = self.attn_layers[i](x, x, attn_mask)
|
360 |
+
y = self.drop(y)
|
361 |
+
x = self.norm_layers_1[i](x + y)
|
362 |
+
y = self.ffn_layers[i](x, x_mask)
|
363 |
+
y = self.drop(y)
|
364 |
+
x = self.norm_layers_2[i](x + y)
|
365 |
+
x = x * x_mask
|
366 |
+
return x
|
367 |
+
|
368 |
+
|
369 |
+
class Conformer(BaseModule):
|
370 |
+
def __init__(self, cfg):
|
371 |
+
super().__init__()
|
372 |
+
self.cfg = cfg
|
373 |
+
self.n_heads = self.cfg.n_heads
|
374 |
+
self.n_layers = self.cfg.n_layers
|
375 |
+
self.hidden_channels = self.cfg.input_dim
|
376 |
+
self.filter_channels = self.cfg.filter_channels
|
377 |
+
self.output_dim = self.cfg.output_dim
|
378 |
+
self.dropout = self.cfg.dropout
|
379 |
+
|
380 |
+
self.conformer_encoder = Encoder(
|
381 |
+
self.hidden_channels,
|
382 |
+
self.filter_channels,
|
383 |
+
n_heads=self.n_heads,
|
384 |
+
n_layers=self.n_layers,
|
385 |
+
kernel_size=3,
|
386 |
+
p_dropout=self.dropout,
|
387 |
+
window_size=4,
|
388 |
+
)
|
389 |
+
self.projection = nn.Conv1d(self.hidden_channels, self.output_dim, 1)
|
390 |
+
|
391 |
+
def forward(self, x, x_mask):
|
392 |
+
"""
|
393 |
+
Args:
|
394 |
+
x: (N, seq_len, input_dim)
|
395 |
+
Returns:
|
396 |
+
output: (N, seq_len, output_dim)
|
397 |
+
"""
|
398 |
+
# (N, seq_len, d_model)
|
399 |
+
x = x.transpose(1, 2)
|
400 |
+
x_mask = x_mask.transpose(1, 2)
|
401 |
+
output = self.conformer_encoder(x, x_mask)
|
402 |
+
# (N, seq_len, output_dim)
|
403 |
+
output = self.projection(output)
|
404 |
+
output = output.transpose(1, 2)
|
405 |
+
return output
|
models/svc/transformer/transformer.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
10 |
+
|
11 |
+
|
12 |
+
class Transformer(nn.Module):
|
13 |
+
def __init__(self, cfg):
|
14 |
+
super().__init__()
|
15 |
+
self.cfg = cfg
|
16 |
+
|
17 |
+
dropout = self.cfg.dropout
|
18 |
+
nhead = self.cfg.n_heads
|
19 |
+
nlayers = self.cfg.n_layers
|
20 |
+
input_dim = self.cfg.input_dim
|
21 |
+
output_dim = self.cfg.output_dim
|
22 |
+
|
23 |
+
d_model = input_dim
|
24 |
+
self.pos_encoder = PositionalEncoding(d_model, dropout)
|
25 |
+
encoder_layers = TransformerEncoderLayer(
|
26 |
+
d_model, nhead, dropout=dropout, batch_first=True
|
27 |
+
)
|
28 |
+
self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
|
29 |
+
|
30 |
+
self.output_mlp = nn.Linear(d_model, output_dim)
|
31 |
+
|
32 |
+
def forward(self, x, mask=None):
|
33 |
+
"""
|
34 |
+
Args:
|
35 |
+
x: (N, seq_len, input_dim)
|
36 |
+
Returns:
|
37 |
+
output: (N, seq_len, output_dim)
|
38 |
+
"""
|
39 |
+
# (N, seq_len, d_model)
|
40 |
+
src = self.pos_encoder(x)
|
41 |
+
# model_stats["pos_embedding"] = x
|
42 |
+
# (N, seq_len, d_model)
|
43 |
+
output = self.transformer_encoder(src)
|
44 |
+
# (N, seq_len, output_dim)
|
45 |
+
output = self.output_mlp(output)
|
46 |
+
return output
|
47 |
+
|
48 |
+
|
49 |
+
class PositionalEncoding(nn.Module):
|
50 |
+
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
|
51 |
+
super().__init__()
|
52 |
+
self.dropout = nn.Dropout(p=dropout)
|
53 |
+
|
54 |
+
position = torch.arange(max_len).unsqueeze(1)
|
55 |
+
div_term = torch.exp(
|
56 |
+
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
|
57 |
+
)
|
58 |
+
|
59 |
+
# Assume that x is (seq_len, N, d)
|
60 |
+
# pe = torch.zeros(max_len, 1, d_model)
|
61 |
+
# pe[:, 0, 0::2] = torch.sin(position * div_term)
|
62 |
+
# pe[:, 0, 1::2] = torch.cos(position * div_term)
|
63 |
+
|
64 |
+
# Assume that x in (N, seq_len, d)
|
65 |
+
pe = torch.zeros(1, max_len, d_model)
|
66 |
+
pe[0, :, 0::2] = torch.sin(position * div_term)
|
67 |
+
pe[0, :, 1::2] = torch.cos(position * div_term)
|
68 |
+
|
69 |
+
self.register_buffer("pe", pe)
|
70 |
+
|
71 |
+
def forward(self, x):
|
72 |
+
"""
|
73 |
+
Args:
|
74 |
+
x: Tensor, shape [N, seq_len, d]
|
75 |
+
"""
|
76 |
+
# Old: Assume that x is (seq_len, N, d), and self.pe is (max_len, 1, d_model)
|
77 |
+
# x = x + self.pe[: x.size(0)]
|
78 |
+
|
79 |
+
# Now: self.pe is (1, max_len, d)
|
80 |
+
x = x + self.pe[:, : x.size(1), :]
|
81 |
+
|
82 |
+
return self.dropout(x)
|
models/svc/transformer/transformer_inference.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
import time
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from tqdm import tqdm
|
11 |
+
import torch.nn as nn
|
12 |
+
from collections import OrderedDict
|
13 |
+
|
14 |
+
from models.svc.base import SVCInference
|
15 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
16 |
+
from models.svc.transformer.transformer import Transformer
|
17 |
+
from models.svc.transformer.conformer import Conformer
|
18 |
+
|
19 |
+
|
20 |
+
class TransformerInference(SVCInference):
|
21 |
+
def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
|
22 |
+
SVCInference.__init__(self, args, cfg, infer_type)
|
23 |
+
|
24 |
+
def _build_model(self):
|
25 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
26 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
27 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
28 |
+
if self.cfg.model.transformer.type == "transformer":
|
29 |
+
self.acoustic_mapper = Transformer(self.cfg.model.transformer)
|
30 |
+
elif self.cfg.model.transformer.type == "conformer":
|
31 |
+
self.acoustic_mapper = Conformer(self.cfg.model.transformer)
|
32 |
+
else:
|
33 |
+
raise NotImplementedError
|
34 |
+
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
|
35 |
+
return model
|
36 |
+
|
37 |
+
def _inference_each_batch(self, batch_data):
|
38 |
+
device = self.accelerator.device
|
39 |
+
for k, v in batch_data.items():
|
40 |
+
batch_data[k] = v.to(device)
|
41 |
+
|
42 |
+
condition = self.condition_encoder(batch_data)
|
43 |
+
y_pred = self.acoustic_mapper(condition, batch_data["mask"])
|
44 |
+
|
45 |
+
return y_pred
|
models/svc/transformer/transformer_trainer.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2023 Amphion.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the MIT license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
|
8 |
+
from models.svc.base import SVCTrainer
|
9 |
+
from modules.encoder.condition_encoder import ConditionEncoder
|
10 |
+
from models.svc.transformer.transformer import Transformer
|
11 |
+
from models.svc.transformer.conformer import Conformer
|
12 |
+
from utils.ssim import SSIM
|
13 |
+
|
14 |
+
|
15 |
+
class TransformerTrainer(SVCTrainer):
|
16 |
+
def __init__(self, args, cfg):
|
17 |
+
SVCTrainer.__init__(self, args, cfg)
|
18 |
+
self.ssim_loss = SSIM()
|
19 |
+
|
20 |
+
def _build_model(self):
|
21 |
+
self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min
|
22 |
+
self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max
|
23 |
+
self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder)
|
24 |
+
if self.cfg.model.transformer.type == "transformer":
|
25 |
+
self.acoustic_mapper = Transformer(self.cfg.model.transformer)
|
26 |
+
elif self.cfg.model.transformer.type == "conformer":
|
27 |
+
self.acoustic_mapper = Conformer(self.cfg.model.transformer)
|
28 |
+
else:
|
29 |
+
raise NotImplementedError
|
30 |
+
model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper])
|
31 |
+
return model
|
32 |
+
|
33 |
+
def _forward_step(self, batch):
|
34 |
+
total_loss = 0
|
35 |
+
device = self.accelerator.device
|
36 |
+
mel = batch["mel"]
|
37 |
+
mask = batch["mask"]
|
38 |
+
|
39 |
+
condition = self.condition_encoder(batch)
|
40 |
+
mel_pred = self.acoustic_mapper(condition, mask)
|
41 |
+
|
42 |
+
l1_loss = torch.sum(torch.abs(mel_pred - mel) * batch["mask"]) / torch.sum(
|
43 |
+
batch["mask"]
|
44 |
+
)
|
45 |
+
self._check_nan(l1_loss, mel_pred, mel)
|
46 |
+
total_loss += l1_loss
|
47 |
+
ssim_loss = self.ssim_loss(mel_pred, mel)
|
48 |
+
ssim_loss = torch.sum(ssim_loss * batch["mask"]) / torch.sum(batch["mask"])
|
49 |
+
self._check_nan(ssim_loss, mel_pred, mel)
|
50 |
+
total_loss += ssim_loss
|
51 |
+
|
52 |
+
return total_loss
|