diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..39e7ae7fd0fdd2d8e5bc370225bb1f3eb8648ac8 100644 --- a/.gitattributes +++ b/.gitattributes @@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.xz filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..b174072d5944c7044abfdde417bb3ec9eb33521e --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +__pycache__ +flagged +result + +# Developing mode +_*.sh +_*.json +*.lst +yard* +*.out +evaluation/evalset_selection +mfa +egs/svc/*wavmark +egs/svc/custom +egs/svc/*/dev* +egs/svc/dev_exp_config.json +bins/svc/demo* +bins/svc/preprocess_custom.py \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..49d1a413ec53285bb9a4ec813702c8fab99d9d1d --- /dev/null +++ b/app.py @@ -0,0 +1,78 @@ +import gradio as gr + + +SUPPORTED_TARGET_SINGERS = { + "Adele": "vocalist_l1_Adele", + "Beyonce": "vocalist_l1_Beyonce", + "Bruno Mars": "vocalist_l1_BrunoMars", + "John Mayer": "vocalist_l1_JohnMayer", + "Michael Jackson": "vocalist_l1_MichaelJackson", + "Taylor Swift": "vocalist_l1_TaylorSwift", + "Jacky Cheung 张学友": "vocalist_l1_张学友", + "Jian Li 李健": "vocalist_l1_李健", + "Feng Wang 汪峰": "vocalist_l1_汪峰", + "Faye Wong 王菲": "vocalist_l1_王菲", + "Yijie Shi 石倚洁": "vocalist_l1_石倚洁", + "Tsai Chin 蔡琴": "vocalist_l1_蔡琴", + "Ying Na 那英": "vocalist_l1_那英", + "Eason Chan 陈奕迅": "vocalist_l1_陈奕迅", + "David Tao 陶喆": "vocalist_l1_陶喆", +} + + +def svc_inference( + source_audio, + target_singer, + diffusion_steps=1000, + key_shift_mode="auto", + key_shift_num=0, +): + pass + + +demo_inputs = [ + gr.Audio( + sources=["upload", "microphone"], + label="Upload (or record) a song you want to listen", + ), + gr.Radio( + choices=list(SUPPORTED_TARGET_SINGERS.keys()), + label="Target Singer", + value="Jian Li 李健", + ), + gr.Slider( + 1, + 1000, + value=1000, + step=1, + label="Diffusion Inference Steps", + info="As the step number increases, the synthesis quality will be better while the inference speed will be lower", + ), + gr.Radio( + choices=["Auto Shift", "Key Shift"], + value="Auto Shift", + label="Pitch Shift Control", + info='If you want to control the specific pitch shift value, you need to choose "Key Shift"', + ), + gr.Slider( + -6, + 6, + value=0, + step=1, + label="Key Shift Values", + info='How many semitones you want to transpose. 
This parameter will work only if you choose "Key Shift"', + ), +] + +demo_outputs = gr.Audio(label="") + + +demo = gr.Interface( + fn=svc_inference, + inputs=demo_inputs, + outputs=demo_outputs, + title="Amphion Singing Voice Conversion", +) + +if __name__ == "__main__": + demo.launch(show_api=False) diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/args.json b/ckpts/svc/vocalist_l1_contentvec+whisper/args.json new file mode 100755 index 0000000000000000000000000000000000000000..836d5e81420921d4ec096d3445c0ff5964e13b73 --- /dev/null +++ b/ckpts/svc/vocalist_l1_contentvec+whisper/args.json @@ -0,0 +1,256 @@ +{ + "base_config": "config/diffusion.json", + "dataset": [ + "vocalist_l1", + ], + "exp_name": "vocalist_l1_contentvec+whisper", + "inference": { + "diffusion": { + "scheduler": "pndm", + "scheduler_settings": { + "num_inference_timesteps": 1000, + }, + }, + }, + "model": { + "condition_encoder": { + "content_encoder_dim": 384, + "contentvec_dim": 256, + "f0_max": 1100, + "f0_min": 50, + "input_loudness_dim": 1, + "input_melody_dim": 1, + "merge_mode": "add", + "mert_dim": 256, + "n_bins_loudness": 256, + "n_bins_melody": 256, + "output_content_dim": 384, + "output_loudness_dim": 384, + "output_melody_dim": 384, + "output_singer_dim": 384, + "pitch_max": 1100, + "pitch_min": 50, + "singer_table_size": 512, + "use_conformer_for_content_features": false, + "use_contentvec": true, + "use_log_f0": true, + "use_log_loudness": true, + "use_mert": false, + "use_singer_encoder": true, + "use_spkid": true, + "use_wenet": false, + "use_whisper": true, + "wenet_dim": 512, + "whisper_dim": 1024, + }, + "diffusion": { + "bidilconv": { + "base_channel": 384, + "conditioner_size": 384, + "conv_kernel_size": 3, + "dilation_cycle_length": 4, + "n_res_block": 20, + }, + "model_type": "bidilconv", + "scheduler": "ddpm", + "scheduler_settings": { + "beta_end": 0.02, + "beta_schedule": "linear", + "beta_start": 0.0001, + "num_train_timesteps": 1000, + }, + "step_encoder": { + "activation": "SiLU", + "dim_hidden_layer": 512, + "dim_raw_embedding": 128, + "max_period": 10000, + "num_layer": 2, + }, + "unet2d": { + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ], + "in_channels": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", + "only_cross_attention": false, + "out_channels": 1, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ], + }, + }, + }, + "model_type": "DiffWaveNetSVC", + "preprocess": { + "audio_dir": "audios", + "bits": 8, + "content_feature_batch_size": 16, + "contentvec_batch_size": 1, + "contentvec_dir": "contentvec", + "contentvec_file": "pretrained/contentvec/checkpoint_best_legacy_500.pt", + "contentvec_frameshift": 0.02, + "contentvec_sample_rate": 16000, + "dur_dir": "durs", + "duration_dir": "duration", + "emo2id": "emo2id.json", + "energy_dir": "energys", + "extract_audio": false, + "extract_contentvec_feature": true, + "extract_energy": true, + "extract_label": false, + "extract_mcep": false, + "extract_mel": true, + "extract_mert_feature": false, + "extract_pitch": true, + "extract_uv": true, + "extract_wenet_feature": false, + "extract_whisper_feature": true, + "f0_max": 1100, + "f0_min": 50, + "file_lst": "file.lst", + "fmax": 12000, + "fmin": 0, + "hop_size": 256, + "is_label": true, + "is_mu_law": true, + "lab_dir": "labs", + "label_dir": "labels", + "mcep_dir": "mcep", + "mel_dir": "mels", + "mel_min_max_norm": true, + "mel_min_max_stats_dir": 
"mel_min_max_stats", + "mert_dir": "mert", + "mert_feature_layer": -1, + "mert_frameshit": 0.01333, + "mert_hop_size": 320, + "mert_model": "m-a-p/MERT-v1-330M", + "min_level_db": -115, + "mu_law_norm": false, + "n_fft": 1024, + "n_mel": 100, + "num_silent_frames": 8, + "num_workers": 8, + "phone_seq_file": "phone_seq_file", + "pin_memory": true, + "pitch_bin": 256, + "pitch_dir": "pitches", + "pitch_extractor": "parselmouth", + "pitch_max": 1100.0, + "pitch_min": 50.0, + "processed_dir": "ckpts/svc/vocalist_l1_contentvec+whisper/data", + "ref_level_db": 20, + "sample_rate": 24000, + "spk2id": "singers.json", + "train_file": "train.json", + "trim_fft_size": 512, + "trim_hop_size": 128, + "trim_silence": false, + "trim_top_db": 30, + "trimmed_wav_dir": "trimmed_wavs", + "use_audio": false, + "use_contentvec": true, + "use_dur": false, + "use_emoid": false, + "use_frame_duration": false, + "use_frame_energy": true, + "use_frame_pitch": true, + "use_lab": false, + "use_label": false, + "use_log_scale_energy": false, + "use_log_scale_pitch": false, + "use_mel": true, + "use_mert": false, + "use_min_max_norm_mel": true, + "use_one_hot": false, + "use_phn_seq": false, + "use_phone_duration": false, + "use_phone_energy": false, + "use_phone_pitch": false, + "use_spkid": true, + "use_uv": true, + "use_wav": false, + "use_wenet": false, + "use_whisper": true, + "utt2emo": "utt2emo", + "utt2spk": "utt2singer", + "uv_dir": "uvs", + "valid_file": "test.json", + "wav_dir": "wavs", + "wenet_batch_size": 1, + "wenet_config": "pretrained/wenet/20220506_u2pp_conformer_exp/train.yaml", + "wenet_dir": "wenet", + "wenet_downsample_rate": 4, + "wenet_frameshift": 0.01, + "wenet_model_path": "pretrained/wenet/20220506_u2pp_conformer_exp/final.pt", + "wenet_sample_rate": 16000, + "whisper_batch_size": 30, + "whisper_dir": "whisper", + "whisper_downsample_rate": 2, + "whisper_frameshift": 0.01, + "whisper_model": "medium", + "whisper_model_path": "pretrained/whisper/medium.pt", + "win_size": 1024, + }, + "supported_model_type": [ + "Fastspeech2", + "DiffSVC", + "Transformer", + "EDM", + "CD", + ], + "train": { + "adamw": { + "lr": 0.0004, + }, + "batch_size": 32, + "dataloader": { + "num_worker": 8, + "pin_memory": true, + }, + "ddp": true, + "epochs": 50000, + "gradient_accumulation_step": 1, + "keep_checkpoint_max": 5, + "keep_last": [ + 5, + -1, + ], + "max_epoch": -1, + "max_steps": 1000000, + "multi_speaker_training": false, + "optimizer": "AdamW", + "random_seed": 10086, + "reducelronplateau": { + "factor": 0.8, + "min_lr": 0.0001, + "patience": 10, + }, + "run_eval": [ + false, + true, + ], + "sampler": { + "drop_last": true, + "holistic_shuffle": false, + }, + "save_checkpoint_stride": [ + 3, + 10, + ], + "save_checkpoints_steps": 10000, + "save_summary_steps": 500, + "scheduler": "ReduceLROnPlateau", + "total_training_steps": 50000, + "tracker": [ + "tensorboard", + ], + "valid_interval": 10000, + }, + "use_custom_dataset": true, +} \ No newline at end of file diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/optimizer.bin b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/optimizer.bin new file mode 100755 index 0000000000000000000000000000000000000000..6b5604e3770d0c8de693930332f32ef2e0b16fe0 --- /dev/null +++ b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/optimizer.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:836af10b834c7aec9209eb19ce43559e6ef1e3a59bd6468e90cadbc9a18749ef +size 249512389 diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/pytorch_model.bin b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/pytorch_model.bin new file mode 100755 index 0000000000000000000000000000000000000000..a11911352aa92208e5246cea59e52b3de1f0d704 --- /dev/null +++ b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/pytorch_model.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d54eed12bef331095fc367f196d07c5061d5cb72dd6fe0e1e4453b997bf1d68d +size 124755137 diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/random_states_0.pkl b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/random_states_0.pkl new file mode 100755 index 0000000000000000000000000000000000000000..be96aac1818d9f8fc4dedfcc530ee1e8ea9f78f7 --- /dev/null +++ b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/random_states_0.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6798ddffadcd7d5405a77e667c674c474e4fef0cba817fdd300c7c985c1e82fe +size 14599 diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/singers.json b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/singers.json new file mode 100755 index 0000000000000000000000000000000000000000..cd56250fa8be439b4ac6d2afe15fed300a69c973 --- /dev/null +++ b/ckpts/svc/vocalist_l1_contentvec+whisper/checkpoint/epoch-6852_step-0678447_loss-1.946773/singers.json @@ -0,0 +1,17 @@ +{ + "vocalist_l1_Adele": 0, + "vocalist_l1_Beyonce": 1, + "vocalist_l1_BrunoMars": 2, + "vocalist_l1_JohnMayer": 3, + "vocalist_l1_MichaelJackson": 4, + "vocalist_l1_TaylorSwift": 5, + "vocalist_l1_张学友": 6, + "vocalist_l1_李健": 7, + "vocalist_l1_汪峰": 8, + "vocalist_l1_王菲": 9, + "vocalist_l1_石倚洁": 10, + "vocalist_l1_蔡琴": 11, + "vocalist_l1_那英": 12, + "vocalist_l1_陈奕迅": 13, + "vocalist_l1_陶喆": 14 +} \ No newline at end of file diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_max.npy b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_max.npy new file mode 100755 index 0000000000000000000000000000000000000000..f74cf6fe3127f22eb07c931f1f9ece4c07ed00ed --- /dev/null +++ b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_max.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:04131849378aa4f525a701909f743c303f8d56571682572b888046ead9f3e2ab +size 528 diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_min.npy b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_min.npy new file mode 100755 index 0000000000000000000000000000000000000000..20326231f2c3925360e7b102eb98e22bb9a238f5 --- /dev/null +++ b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/mel_min_max_stats/mel_min.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef4895ebef0e9949a6e623315bdc8a68490ba95d2f81b2be9f5146f904203016 +size 528 diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/meta_info.json b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/meta_info.json new file mode 100755 index 
0000000000000000000000000000000000000000..5d4cf31177f9dd2bac9538df9eb649ac522fcd69 --- /dev/null +++ b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/meta_info.json @@ -0,0 +1,31 @@ +{ + "dataset": "vocalist_l1", + "train": { + "size": 3180, + "hours": 6.1643 + }, + "test": { + "size": 114, + "hours": 0.2224 + }, + "singers": { + "size": 15, + "training_minutes": { + "vocalist_l1_陶喆": 45.51, + "vocalist_l1_陈奕迅": 43.36, + "vocalist_l1_汪峰": 41.08, + "vocalist_l1_李健": 38.9, + "vocalist_l1_JohnMayer": 30.83, + "vocalist_l1_Adele": 27.23, + "vocalist_l1_那英": 27.02, + "vocalist_l1_石倚洁": 24.93, + "vocalist_l1_张学友": 18.31, + "vocalist_l1_TaylorSwift": 18.31, + "vocalist_l1_王菲": 16.78, + "vocalist_l1_MichaelJackson": 15.13, + "vocalist_l1_蔡琴": 10.12, + "vocalist_l1_BrunoMars": 6.29, + "vocalist_l1_Beyonce": 6.06 + } + } +} \ No newline at end of file diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/pitches/statistics.json b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/pitches/statistics.json new file mode 100755 index 0000000000000000000000000000000000000000..472551c609057a7eb7ba05b1eceba3ebc0461ed4 --- /dev/null +++ b/ckpts/svc/vocalist_l1_contentvec+whisper/data/vocalist_l1/pitches/statistics.json @@ -0,0 +1,242 @@ +{ + "vocalist_l1_Adele": { + "voiced_positions": { + "mean": 336.5038018286193, + "std": 100.2148774476881, + "median": 332.98363792619296, + "min": 59.99838412340723, + "max": 1099.849325287837 + }, + "total_positions": { + "mean": 231.79366581704338, + "std": 176.6042850107386, + "median": 273.2844263775394, + "min": 0.0, + "max": 1099.849325287837 + } + }, + "vocalist_l1_Beyonce": { + "voiced_positions": { + "mean": 357.5678927636881, + "std": 130.1132620135807, + "median": 318.2981879228934, + "min": 70.29719673914867, + "max": 1050.354470112099 + }, + "total_positions": { + "mean": 267.5248026267327, + "std": 191.71600807951046, + "median": 261.91981963774066, + "min": 0.0, + "max": 1050.354470112099 + } + }, + "vocalist_l1_BrunoMars": { + "voiced_positions": { + "mean": 330.92612740814315, + "std": 86.51034158515388, + "median": 324.65585832605217, + "min": 58.74277302450286, + "max": 999.2818302992808 + }, + "total_positions": { + "mean": 237.26076288057826, + "std": 166.09898203490803, + "median": 286.3097386522132, + "min": 0.0, + "max": 999.2818302992808 + } + }, + "vocalist_l1_JohnMayer": { + "voiced_positions": { + "mean": 218.3531239166661, + "std": 77.89887175223768, + "median": 200.19060542586652, + "min": 53.371912740674716, + "max": 1098.1986774161685 + }, + "total_positions": { + "mean": 112.95331907131244, + "std": 122.65534824070893, + "median": 124.71389285965317, + "min": 0.0, + "max": 1098.1986774161685 + } + }, + "vocalist_l1_MichaelJackson": { + "voiced_positions": { + "mean": 293.4663654519906, + "std": 89.02211325650234, + "median": 284.4323483619402, + "min": 61.14507754070825, + "max": 1096.4247902272325 + }, + "total_positions": { + "mean": 172.1013565770682, + "std": 159.79551912957191, + "median": 212.82938711725973, + "min": 0.0, + "max": 1096.4247902272325 + } + }, + "vocalist_l1_TaylorSwift": { + "voiced_positions": { + "mean": 302.5346928039029, + "std": 87.1724728626562, + "median": 286.91670244246586, + "min": 51.31173137207717, + "max": 1098.9374311806605 + }, + "total_positions": { + "mean": 169.90968097339214, + "std": 163.7133164876362, + "median": 220.90943653386546, + "min": 0.0, + "max": 1098.9374311806605 + } + }, + "vocalist_l1_张学友": { + "voiced_positions": { + "mean": 233.6845479691867, + 
"std": 66.47140810463938, + "median": 228.28695118043396, + "min": 51.65338480121057, + "max": 1094.4381927885959 + }, + "total_positions": { + "mean": 167.79543637603194, + "std": 119.28338415844308, + "median": 194.81504136428546, + "min": 0.0, + "max": 1094.4381927885959 + } + }, + "vocalist_l1_李健": { + "voiced_positions": { + "mean": 234.98401896504657, + "std": 71.3955175177514, + "median": 221.86415264367847, + "min": 54.070687769392585, + "max": 1096.3342286660531 + }, + "total_positions": { + "mean": 148.74760079412246, + "std": 126.70486473504008, + "median": 180.21374566147688, + "min": 0.0, + "max": 1096.3342286660531 + } + }, + "vocalist_l1_汪峰": { + "voiced_positions": { + "mean": 284.27752567207864, + "std": 78.51774150654873, + "median": 278.26186808969493, + "min": 54.30945929095861, + "max": 1053.6870553733015 + }, + "total_positions": { + "mean": 172.41584497486713, + "std": 151.74272125914902, + "median": 216.27534661524862, + "min": 0.0, + "max": 1053.6870553733015 + } + }, + "vocalist_l1_王菲": { + "voiced_positions": { + "mean": 339.1661679865587, + "std": 86.86768172635271, + "median": 327.4151031268507, + "min": 51.21299842481366, + "max": 1096.7044574066776 + }, + "total_positions": { + "mean": 217.726880186, + "std": 176.8748978138034, + "median": 277.8608050501477, + "min": 0.0, + "max": 1096.7044574066776 + } + }, + "vocalist_l1_石倚洁": { + "voiced_positions": { + "mean": 279.67710779262256, + "std": 87.82306577322389, + "median": 271.13024912248443, + "min": 59.604772357481075, + "max": 1098.0574674417153 + }, + "total_positions": { + "mean": 205.49634806008135, + "std": 144.6064344590865, + "median": 234.19454400899718, + "min": 0.0, + "max": 1098.0574674417153 + } + }, + "vocalist_l1_蔡琴": { + "voiced_positions": { + "mean": 258.9105806499278, + "std": 67.4079737418162, + "median": 250.29778287949176, + "min": 54.81875790199644, + "max": 930.3733192171918 + }, + "total_positions": { + "mean": 197.64675891035662, + "std": 124.80889987119957, + "median": 228.14775033720753, + "min": 0.0, + "max": 930.3733192171918 + } + }, + "vocalist_l1_那英": { + "voiced_positions": { + "mean": 358.98655838013195, + "std": 91.30591323348871, + "median": 346.95185476261275, + "min": 71.62879029165369, + "max": 1085.4349856526985 + }, + "total_positions": { + "mean": 243.83317702162077, + "std": 183.68660712060583, + "median": 294.9745603259994, + "min": 0.0, + "max": 1085.4349856526985 + } + }, + "vocalist_l1_陈奕迅": { + "voiced_positions": { + "mean": 222.0124146654594, + "std": 68.65002654904572, + "median": 218.9200565540147, + "min": 50.48503062529368, + "max": 1084.6336454006018 + }, + "total_positions": { + "mean": 154.2275169157727, + "std": 117.16740631313343, + "median": 176.89315636838086, + "min": 0.0, + "max": 1084.6336454006018 + } + }, + "vocalist_l1_陶喆": { + "voiced_positions": { + "mean": 242.58206762395713, + "std": 69.61805791083957, + "median": 227.5222796096177, + "min": 50.44809060945403, + "max": 1098.4942623171203 + }, + "total_positions": { + "mean": 171.59040988406485, + "std": 124.93911390018495, + "median": 204.4328861811408, + "min": 0.0, + "max": 1098.4942623171203 + } + } +} \ No newline at end of file diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.0 b/ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.0 new file mode 100755 index 
0000000000000000000000000000000000000000..df0c6ae73d3d4df3a0c2856e4ddd75bfc4cc520b --- /dev/null +++ b/ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.0 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d7f490fd0c97876e24bfc44413365ded7ff5d22c1c79f0dac0b754f3b32df76f +size 88 diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.1 b/ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.1 new file mode 100755 index 0000000000000000000000000000000000000000..4ee06f708e74717bc23b3130ddcb6f82e5cf84ee --- /dev/null +++ b/ckpts/svc/vocalist_l1_contentvec+whisper/log/vocalist_l1_contentvec+whisper/events.out.tfevents.1696052302.mmnewyardnodesz63219.120.1 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e01bcf2fa621ba563b70568c18fe0742d0f48cafae83a6e8beb0bb6d1f6d146d +size 77413046 diff --git a/ckpts/svc/vocalist_l1_contentvec+whisper/singers.json b/ckpts/svc/vocalist_l1_contentvec+whisper/singers.json new file mode 100755 index 0000000000000000000000000000000000000000..cd56250fa8be439b4ac6d2afe15fed300a69c973 --- /dev/null +++ b/ckpts/svc/vocalist_l1_contentvec+whisper/singers.json @@ -0,0 +1,17 @@ +{ + "vocalist_l1_Adele": 0, + "vocalist_l1_Beyonce": 1, + "vocalist_l1_BrunoMars": 2, + "vocalist_l1_JohnMayer": 3, + "vocalist_l1_MichaelJackson": 4, + "vocalist_l1_TaylorSwift": 5, + "vocalist_l1_张学友": 6, + "vocalist_l1_李健": 7, + "vocalist_l1_汪峰": 8, + "vocalist_l1_王菲": 9, + "vocalist_l1_石倚洁": 10, + "vocalist_l1_蔡琴": 11, + "vocalist_l1_那英": 12, + "vocalist_l1_陈奕迅": 13, + "vocalist_l1_陶喆": 14 +} \ No newline at end of file diff --git a/egs/svc/MultipleContentsSVC/README.md b/egs/svc/MultipleContentsSVC/README.md new file mode 100755 index 0000000000000000000000000000000000000000..ac999e6253076f79ca59ed05fac168e5679feaea --- /dev/null +++ b/egs/svc/MultipleContentsSVC/README.md @@ -0,0 +1,153 @@ +# Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion + +[![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2310.11160) +[![demo](https://img.shields.io/badge/SVC-Demo-red)](https://www.zhangxueyao.com/data/MultipleContentsSVC/index.html) + +
+
+ +
+
+ +This is the official implementation of the paper "[Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion](https://arxiv.org/abs/2310.11160)" (NeurIPS 2023 Workshop on Machine Learning for Audio). Specifically, + +- The multiple content features are from [Whisper](https://github.com/openai/whisper) and [ContentVec](https://github.com/auspicious3000/contentvec). +- The acoustic model is based on a Bidirectional Non-Causal Dilated CNN (called `DiffWaveNetSVC` in Amphion), which is similar to [WaveNet](https://arxiv.org/pdf/1609.03499.pdf), [DiffWave](https://openreview.net/forum?id=a-xFK8Ymz5J), and [DiffSVC](https://ieeexplore.ieee.org/document/9688219). +- The vocoder uses the [BigVGAN](https://github.com/NVIDIA/BigVGAN) architecture and we fine-tuned it on over 120 hours of singing voice data. + +There are four stages in total: + +1. Data preparation +2. Features extraction +3. Training +4. Inference/conversion + +> **NOTE:** You need to run every command of this recipe in the `Amphion` root path: +> ```bash +> cd Amphion +> ``` + +## 1. Data Preparation + +### Dataset Download + +By default, we utilize five datasets for training: M4Singer, Opencpop, OpenSinger, SVCC, and VCTK. How to download them is detailed [here](../../datasets/README.md). + +### Configuration + +Specify the dataset paths in `exp_config.json`. Note that you can change the `dataset` list to use your preferred datasets. + +```json + "dataset": [ + "m4singer", + "opencpop", + "opensinger", + "svcc", + "vctk" + ], + "dataset_path": { + // TODO: Fill in your dataset path + "m4singer": "[M4Singer dataset path]", + "opencpop": "[Opencpop dataset path]", + "opensinger": "[OpenSinger dataset path]", + "svcc": "[SVCC dataset path]", + "vctk": "[VCTK dataset path]" + }, +``` + +## 2. Features Extraction + +### Content-based Pretrained Models Download + +By default, we utilize Whisper and ContentVec to extract content features. How to download them is detailed [here](../../../pretrained/README.md). + +### Configuration + +Specify the dataset path and the output path for saving the processed data and the training model in `exp_config.json`: + +```json + // TODO: Fill in the output log path. The default value is "Amphion/ckpts/svc" + "log_dir": "ckpts/svc", + "preprocess": { + // TODO: Fill in the output data path. The default value is "Amphion/data" + "processed_dir": "data", + ... + }, +``` + +### Run + +Run `run.sh` as the preprocess stage (set `--stage 1`). + +```bash +sh egs/svc/MultipleContentsSVC/run.sh --stage 1 +``` + +> **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, for example, `--gpu "1"`. + +## 3. Training + +### Configuration + +We provide the default hyperparameters in `exp_config.json`. They work on a single 24 GB NVIDIA GPU. You can adjust them based on your GPU machines. + +```json +"train": { + "batch_size": 32, + ... + "adamw": { + "lr": 2.0e-4 + }, + ... + } +``` + +### Run + +Run `run.sh` as the training stage (set `--stage 2`). Specify an experiment name to run the following command. The TensorBoard logs and checkpoints will be saved in `Amphion/ckpts/svc/[YourExptName]`. + +```bash +sh egs/svc/MultipleContentsSVC/run.sh --stage 2 --name [YourExptName] +``` + +> **NOTE:** `CUDA_VISIBLE_DEVICES` is set to `"0"` by default. You can change it when running `run.sh` by specifying, for example, `--gpu "0,1,2,3"`. + +## 4. Inference/Conversion
 + +### Pretrained Vocoder Download + +We fine-tune the official BigVGAN pretrained model on over 120 hours of singing voice data. The benefits of fine-tuning have been investigated in our paper (see this [demo page](https://www.zhangxueyao.com/data/MultipleContentsSVC/vocoder.html)). The final pretrained singing voice vocoder is released [here](../../../pretrained/README.md#amphion-singing-bigvgan) (called `Amphion Singing BigVGAN`). + +### Run + +For inference/conversion, you need to specify the following configurations when running `run.sh`: + +| Parameters | Description | Example | +| --- | --- | --- | +| `--infer_expt_dir` | The experimental directory which contains `checkpoint` | `Amphion/ckpts/svc/[YourExptName]` | +| `--infer_output_dir` | The output directory to save inferred audios. | `Amphion/ckpts/svc/[YourExptName]/result` | +| `--infer_source_file` or `--infer_source_audio_dir` | The inference source (can be a json file or a dir). | The `infer_source_file` could be `Amphion/data/[YourDataset]/test.json`, and the `infer_source_audio_dir` is a folder which includes several audio files (*.wav, *.mp3 or *.flac). | +| `--infer_target_speaker` | The target speaker you want to convert into. You can refer to `Amphion/ckpts/svc/[YourExptName]/singers.json` to choose a trained speaker. | For the opencpop dataset, the speaker name would be `opencpop_female1`. | +| `--infer_key_shift` | How many semitones you want to transpose. | `"autoshift"` (by default), `3`, `-3`, etc. | + +For example, if you want to make `opencpop_female1` sing the songs in `[Your Audios Folder]`, just run: + +```bash +sh egs/svc/MultipleContentsSVC/run.sh --stage 3 --gpu "0" \ + --infer_expt_dir Amphion/ckpts/svc/[YourExptName] \ + --infer_output_dir Amphion/ckpts/svc/[YourExptName]/result \ + --infer_source_audio_dir [Your Audios Folder] \ + --infer_target_speaker "opencpop_female1" \ + --infer_key_shift "autoshift" +``` + +## Citations + +```bibtex +@article{zhang2023leveraging, + title={Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion}, + author={Zhang, Xueyao and Gu, Yicheng and Chen, Haopeng and Fang, Zihao and Zou, Lexiao and Xue, Liumeng and Wu, Zhizheng}, + journal={Machine Learning for Audio Workshop, NeurIPS 2023}, + year={2023} +} +``` diff --git a/egs/svc/MultipleContentsSVC/exp_config.json b/egs/svc/MultipleContentsSVC/exp_config.json new file mode 100755 index 0000000000000000000000000000000000000000..7047855abd18c25760fcdd46ec63da5c4b7ad8ba --- /dev/null +++ b/egs/svc/MultipleContentsSVC/exp_config.json @@ -0,0 +1,126 @@ +{ + "base_config": "config/diffusion.json", + "model_type": "DiffWaveNetSVC", + "dataset": [ + "m4singer", + "opencpop", + "opensinger", + "svcc", + "vctk" + ], + "dataset_path": { + // TODO: Fill in your dataset path + "m4singer": "[M4Singer dataset path]", + "opencpop": "[Opencpop dataset path]", + "opensinger": "[OpenSinger dataset path]", + "svcc": "[SVCC dataset path]", + "vctk": "[VCTK dataset path]" + }, + // TODO: Fill in the output log path. 
The default value is "Amphion/ckpts/svc" + "log_dir": "ckpts/svc", + "preprocess": { + // TODO: Fill in the output data path. The default value is "Amphion/data" + "processed_dir": "data", + // Config for features extraction + "extract_mel": true, + "extract_pitch": true, + "extract_energy": true, + "extract_whisper_feature": true, + "extract_contentvec_feature": true, + "extract_wenet_feature": false, + "whisper_batch_size": 30, // decrease it if your GPU is out of memory + "contentvec_batch_size": 1, + // Fill in the content-based pretrained model's path + "contentvec_file": "pretrained/contentvec/checkpoint_best_legacy_500.pt", + "wenet_model_path": "pretrained/wenet/20220506_u2pp_conformer_exp/final.pt", + "wenet_config": "pretrained/wenet/20220506_u2pp_conformer_exp/train.yaml", + "whisper_model": "medium", + "whisper_model_path": "pretrained/whisper/medium.pt", + // Config for features usage + "use_mel": true, + "use_min_max_norm_mel": true, + "use_frame_pitch": true, + "use_frame_energy": true, + "use_spkid": true, + "use_whisper": true, + "use_contentvec": true, + "use_wenet": false, + "n_mel": 100, + "sample_rate": 24000 + }, + "model": { + "condition_encoder": { + // Config for features usage + "use_whisper": true, + "use_contentvec": true, + "use_wenet": false, + "whisper_dim": 1024, + "contentvec_dim": 256, + "wenet_dim": 512, + "use_singer_encoder": false, + "pitch_min": 50, + "pitch_max": 1100 + }, + "diffusion": { + "scheduler": "ddpm", + "scheduler_settings": { + "num_train_timesteps": 1000, + "beta_start": 1.0e-4, + "beta_end": 0.02, + "beta_schedule": "linear" + }, + // Diffusion steps encoder + "step_encoder": { + "dim_raw_embedding": 128, + "dim_hidden_layer": 512, + "activation": "SiLU", + "num_layer": 2, + "max_period": 10000 + }, + // Diffusion decoder + "model_type": "bidilconv", + // bidilconv, unet2d, TODO: unet1d + "bidilconv": { + "base_channel": 512, + "n_res_block": 40, + "conv_kernel_size": 3, + "dilation_cycle_length": 4, + // specially, 1 means no dilation + "conditioner_size": 384 + } + } + }, + "train": { + "batch_size": 32, + "gradient_accumulation_step": 1, + "max_epoch": -1, // -1 means no limit + "save_checkpoint_stride": [ + 3, + 50 + ], + "keep_last": [ + 3, + 2 + ], + "run_eval": [ + true, + true + ], + "adamw": { + "lr": 2.0e-4 + }, + "reducelronplateau": { + "factor": 0.8, + "patience": 30, + "min_lr": 1.0e-4 + }, + "dataloader": { + "num_worker": 8, + "pin_memory": true + }, + "sampler": { + "holistic_shuffle": false, + "drop_last": true + } + } +} \ No newline at end of file diff --git a/egs/svc/MultipleContentsSVC/run.sh b/egs/svc/MultipleContentsSVC/run.sh new file mode 120000 index 0000000000000000000000000000000000000000..f8daac3da463c177e36cdf041342566cc4243257 --- /dev/null +++ b/egs/svc/MultipleContentsSVC/run.sh @@ -0,0 +1 @@ +../_template/run.sh \ No newline at end of file diff --git a/egs/svc/README.md b/egs/svc/README.md new file mode 100755 index 0000000000000000000000000000000000000000..5961eaab3782ff96ddbb65a246527ab768498fa5 --- /dev/null +++ b/egs/svc/README.md @@ -0,0 +1,34 @@ +# Amphion Singing Voice Conversion (SVC) Recipe + +## Quick Start + +We provide a **[beginner recipe](MultipleContentsSVC)** to demonstrate how to train a cutting edge SVC model. Specifically, it is also an official implementation of the paper "[Leveraging Content-based Features from Multiple Acoustic Models for Singing Voice Conversion](https://arxiv.org/abs/2310.11160)" (NeurIPS 2023 Workshop on Machine Learning for Audio). 
Some demos can be seen [here](https://www.zhangxueyao.com/data/MultipleContentsSVC/index.html). + +## Supported Model Architectures + +The main idea of SVC is to first disentangle the speaker-agnostic representations from the source audio, and then inject the desired speaker information to synthesize the target, which usually utilizes an acoustic decoder and a subsequent waveform synthesizer (vocoder): + +
+
+ +
+
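To make this data flow concrete, here is a minimal, self-contained PyTorch sketch of the disentangle-then-resynthesize paradigm described above. Every module, dimension, and name in it is an illustrative assumption for exposition only and is not Amphion's actual `DiffWaveNetSVC` implementation; in the real recipes the decoder is one of the acoustic models listed below, and the predicted mel spectrogram is rendered to a waveform by a vocoder.

```python
# Toy sketch (NOT Amphion code): fuse speaker-agnostic features with a target
# speaker embedding, then decode an acoustic representation (mel spectrogram).
import torch
import torch.nn as nn


class ToySVCAcousticModel(nn.Module):
    def __init__(self, content_dim=256, prosody_dim=2, n_singers=16, hidden=384, n_mels=100):
        super().__init__()
        self.content_proj = nn.Linear(content_dim, hidden)   # content features (e.g., Whisper/ContentVec)
        self.prosody_proj = nn.Linear(prosody_dim, hidden)   # prosody features (F0, energy)
        self.singer_table = nn.Embedding(n_singers, hidden)  # speaker look-up table
        self.decoder = nn.GRU(hidden, hidden, batch_first=True)  # stand-in for the real acoustic decoder
        self.to_mel = nn.Linear(hidden, n_mels)

    def forward(self, content, prosody, singer_id):
        # content: [B, T, content_dim], prosody: [B, T, prosody_dim], singer_id: [B]
        cond = self.content_proj(content) + self.prosody_proj(prosody)
        cond = cond + self.singer_table(singer_id)[:, None, :]  # inject the target speaker identity
        hidden, _ = self.decoder(cond)
        return self.to_mel(hidden)  # [B, T, n_mels]; a vocoder turns this into a waveform


model = ToySVCAcousticModel()
mel = model(torch.randn(1, 200, 256), torch.randn(1, 200, 2), torch.tensor([3]))
print(mel.shape)  # torch.Size([1, 200, 100])
```

The design choice the sketch highlights is that only the speaker embedding carries target-singer information, so swapping `singer_id` at inference time is what performs the conversion.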
+ +Until now, Amphion SVC has supported the following features and models: + +- **Speaker-agnostic Representations**: + - Content Features: Sourcing from [WeNet](https://github.com/wenet-e2e/wenet), [Whisper](https://github.com/openai/whisper), and [ContentVec](https://github.com/auspicious3000/contentvec). + - Prosody Features: F0 and energy. +- **Speaker Embeddings**: + - Speaker Look-Up Table. + - Reference Encoder (👨‍💻 developing): It can be used for zero-shot SVC. +- **Acoustic Decoders**: + - Diffusion-based models: + - **[DiffWaveNetSVC](MultipleContentsSVC)**: The encoder is based on Bidirectional Non-Causal Dilated CNN, which is similar to [WaveNet](https://arxiv.org/pdf/1609.03499.pdf), [DiffWave](https://openreview.net/forum?id=a-xFK8Ymz5J), and [DiffSVC](https://ieeexplore.ieee.org/document/9688219). + - **[DiffComoSVC](DiffComoSVC)** (👨‍💻 developing): The diffusion framework is based on [Consistency Model](https://proceedings.mlr.press/v202/song23a.html). It can significantly accelerate the inference process of the diffusion model. + - Transformer-based models: + - **[TransformerSVC](TransformerSVC)**: Encoder-only and Non-autoregressive Transformer Architecture. + - VAE- and Flow-based models: + - **[VitsSVC]()** (👨‍💻 developing): It is designed as a [VITS](https://arxiv.org/abs/2106.06103)-like model whose textual input is replaced by the content features, which is similar to [so-vits-svc](https://github.com/svc-develop-team/so-vits-svc). +- **Waveform Synthesizers (Vocoders)**: + - The supported vocoders can be seen in [Amphion Vocoder Recipe](../vocoder/README.md). diff --git a/egs/svc/_template/run.sh b/egs/svc/_template/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..8dc870fdef8b1464000021def5627f91d1676bbe --- /dev/null +++ b/egs/svc/_template/run.sh @@ -0,0 +1,150 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +######## Build Experiment Environment ########### +exp_dir=$(cd `dirname $0`; pwd) +work_dir=$(dirname $(dirname $(dirname $exp_dir))) + +export WORK_DIR=$work_dir +export PYTHONPATH=$work_dir +export PYTHONIOENCODING=UTF-8 + +######## Parse the Given Parameters from the Commond ########### +options=$(getopt -o c:n:s --long gpu:,config:,name:,stage:,resume:,resume_from_ckpt_path:,resume_type:,infer_expt_dir:,infer_output_dir:,infer_source_file:,infer_source_audio_dir:,infer_target_speaker:,infer_key_shift:,infer_vocoder_dir: -- "$@") +eval set -- "$options" + +while true; do + case $1 in + # Experimental Configuration File + -c | --config) shift; exp_config=$1 ; shift ;; + # Experimental Name + -n | --name) shift; exp_name=$1 ; shift ;; + # Running Stage + -s | --stage) shift; running_stage=$1 ; shift ;; + # Visible GPU machines. The default value is "0". + --gpu) shift; gpu=$1 ; shift ;; + + # [Only for Training] Resume configuration + --resume) shift; resume=$1 ; shift ;; + # [Only for Training] The specific checkpoint path that you want to resume from. + --resume_from_ckpt_path) shift; resume_from_ckpt_path=$1 ; shift ;; + # [Only for Training] `resume` for loading all the things (including model weights, optimizer, scheduler, and random states). `finetune` for loading only the model weights. + --resume_type) shift; resume_type=$1 ; shift ;; + + # [Only for Inference] The experiment dir. 
The value is like "[Your path to save logs and checkpoints]/[YourExptName]" + --infer_expt_dir) shift; infer_expt_dir=$1 ; shift ;; + # [Only for Inference] The output dir to save inferred audios. Its default value is "$expt_dir/result" + --infer_output_dir) shift; infer_output_dir=$1 ; shift ;; + # [Only for Inference] The inference source (can be a json file or a dir). For example, the source_file can be "[Your path to save processed data]/[YourDataset]/test.json", and the source_audio_dir can be "$work_dir/source_audio" which includes several audio files (*.wav, *.mp3 or *.flac). + --infer_source_file) shift; infer_source_file=$1 ; shift ;; + --infer_source_audio_dir) shift; infer_source_audio_dir=$1 ; shift ;; + # [Only for Inference] Specify the target speaker you want to convert into. You can refer to "[Your path to save logs and checkpoints]/[Your Expt Name]/singers.json". In this singer look-up table, you can see the usable speaker names (all the keys of the dictionary). For example, for opencpop dataset, the speaker name would be "opencpop_female1". + --infer_target_speaker) shift; infer_target_speaker=$1 ; shift ;; + # [Only for Inference] For advanced users, you can modify the trans_key parameters into an integer (which means the semitones you want to transpose). Its default value is "autoshift". + --infer_key_shift) shift; infer_key_shift=$1 ; shift ;; + # [Only for Inference] The vocoder dir. Its default value is Amphion/pretrained/bigvgan. See Amphion/pretrained/README.md to download the pretrained BigVGAN vocoders. + --infer_vocoder_dir) shift; infer_vocoder_dir=$1 ; shift ;; + + --) shift ; break ;; + *) echo "Invalid option: $1" exit 1 ;; + esac +done + + +### Value check ### +if [ -z "$running_stage" ]; then + echo "[Error] Please specify the running stage" + exit 1 +fi + +if [ -z "$exp_config" ]; then + exp_config="${exp_dir}"/exp_config.json +fi +echo "Exprimental Configuration File: $exp_config" + +if [ -z "$gpu" ]; then + gpu="0" +fi + +######## Features Extraction ########### +if [ $running_stage -eq 1 ]; then + CUDA_VISIBLE_DEVICES=$gpu python "${work_dir}"/bins/svc/preprocess.py \ + --config $exp_config \ + --num_workers 4 +fi + +######## Training ########### +if [ $running_stage -eq 2 ]; then + if [ -z "$exp_name" ]; then + echo "[Error] Please specify the experiments name" + exit 1 + fi + echo "Exprimental Name: $exp_name" + + if [ "$resume" = true ]; then + echo "Automatically resume from the experimental dir..." + CUDA_VISIBLE_DEVICES="$gpu" accelerate launch "${work_dir}"/bins/svc/train.py \ + --config "$exp_config" \ + --exp_name "$exp_name" \ + --log_level info \ + --resume + else + CUDA_VISIBLE_DEVICES=$gpu accelerate launch "${work_dir}"/bins/svc/train.py \ + --config "$exp_config" \ + --exp_name "$exp_name" \ + --log_level info \ + --resume_from_ckpt_path "$resume_from_ckpt_path" \ + --resume_type "$resume_type" + fi +fi + +######## Inference/Conversion ########### +if [ $running_stage -eq 3 ]; then + if [ -z "$infer_expt_dir" ]; then + echo "[Error] Please specify the experimental directionary. The value is like [Your path to save logs and checkpoints]/[YourExptName]" + exit 1 + fi + + if [ -z "$infer_output_dir" ]; then + infer_output_dir="$expt_dir/result" + fi + + if [ -z "$infer_source_file" ] && [ -z "$infer_source_audio_dir" ]; then + echo "[Error] Please specify the source file/dir. The inference source (can be a json file or a dir). 
For example, the source_file can be "[Your path to save processed data]/[YourDataset]/test.json", and the source_audio_dir should include several audio files (*.wav, *.mp3 or *.flac)." + exit 1 + fi + + if [ -z "$infer_source_file" ]; then + infer_source=$infer_source_audio_dir + fi + + if [ -z "$infer_source_audio_dir" ]; then + infer_source=$infer_source_file + fi + + if [ -z "$infer_target_speaker" ]; then + echo "[Error] Please specify the target speaker. You can refer to "[Your path to save logs and checkpoints]/[Your Expt Name]/singers.json". In this singer look-up table, you can see the usable speaker names (all the keys of the dictionary). For example, for opencpop dataset, the speaker name would be "opencpop_female1"" + exit 1 + fi + + if [ -z "$infer_key_shift" ]; then + infer_key_shift="autoshift" + fi + + if [ -z "$infer_vocoder_dir" ]; then + infer_vocoder_dir="$work_dir"/pretrained/bigvgan + echo "[Warning] You don't specify the infer_vocoder_dir. It is set $infer_vocoder_dir by default. Make sure that you have followed Amphoion/pretrained/README.md to download the pretrained BigVGAN vocoder checkpoint." + fi + + CUDA_VISIBLE_DEVICES=$gpu accelerate launch "$work_dir"/bins/svc/inference.py \ + --config $exp_config \ + --acoustics_dir $infer_expt_dir \ + --vocoder_dir $infer_vocoder_dir \ + --target_singer $infer_target_speaker \ + --trans_key $infer_key_shift \ + --source $infer_source \ + --output_dir $infer_output_dir \ + --log_level debug +fi \ No newline at end of file diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..729bbd0058b1ed399c793231ef645db106a071cf --- /dev/null +++ b/inference.py @@ -0,0 +1,258 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
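+# Module overview (descriptive summary of the code below):
+#   - build_inference() selects the inference class from cfg.model_type
+#     (DiffWaveNetSVC, DiffComoSVC, or TransformerSVC).
+#   - When --source is a directory, each audio file is split into overlapping
+#     segments, acoustic and content features are extracted for them, the model
+#     converts every segment, and merge_for_audio_segments() stitches the
+#     converted pieces back into a single waveform.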
+ +import argparse +import os +import glob +from tqdm import tqdm +import json +import torch +import time + +from models.svc.diffusion.diffusion_inference import DiffusionInference +from models.svc.comosvc.comosvc_inference import ComoSVCInference +from models.svc.transformer.transformer_inference import TransformerInference +from utils.util import load_config +from utils.audio_slicer import split_audio, merge_segments_encodec +from processors import acoustic_extractor, content_extractor + + +def build_inference(args, cfg, infer_type="from_dataset"): + supported_inference = { + "DiffWaveNetSVC": DiffusionInference, + "DiffComoSVC": ComoSVCInference, + "TransformerSVC": TransformerInference, + } + + inference_class = supported_inference[cfg.model_type] + return inference_class(args, cfg, infer_type) + + +def prepare_for_audio_file(args, cfg, num_workers=1): + preprocess_path = cfg.preprocess.processed_dir + audio_name = cfg.inference.source_audio_name + temp_audio_dir = os.path.join(preprocess_path, audio_name) + + ### eval file + t = time.time() + eval_file = prepare_source_eval_file(cfg, temp_audio_dir, audio_name) + args.source = eval_file + with open(eval_file, "r") as f: + metadata = json.load(f) + print("Prepare for meta eval data: {:.1f}s".format(time.time() - t)) + + ### acoustic features + t = time.time() + acoustic_extractor.extract_utt_acoustic_features_serial( + metadata, temp_audio_dir, cfg + ) + acoustic_extractor.cal_mel_min_max( + dataset=audio_name, output_path=preprocess_path, cfg=cfg, metadata=metadata + ) + acoustic_extractor.cal_pitch_statistics_svc( + dataset=audio_name, output_path=preprocess_path, cfg=cfg, metadata=metadata + ) + print("Prepare for acoustic features: {:.1f}s".format(time.time() - t)) + + ### content features + t = time.time() + content_extractor.extract_utt_content_features_dataloader( + cfg, metadata, num_workers + ) + print("Prepare for content features: {:.1f}s".format(time.time() - t)) + return args, cfg, temp_audio_dir + + +def merge_for_audio_segments(audio_files, args, cfg): + audio_name = cfg.inference.source_audio_name + target_singer_name = args.target_singer + + merge_segments_encodec( + wav_files=audio_files, + fs=cfg.preprocess.sample_rate, + output_path=os.path.join( + args.output_dir, "{}_{}.wav".format(audio_name, target_singer_name) + ), + overlap_duration=cfg.inference.segments_overlap_duration, + ) + + for tmp_file in audio_files: + os.remove(tmp_file) + + +def prepare_source_eval_file(cfg, temp_audio_dir, audio_name): + """ + Prepare the eval file (json) for an audio + """ + + audio_chunks_results = split_audio( + wav_file=cfg.inference.source_audio_path, + target_sr=cfg.preprocess.sample_rate, + output_dir=os.path.join(temp_audio_dir, "wavs"), + max_duration_of_segment=cfg.inference.segments_max_duration, + overlap_duration=cfg.inference.segments_overlap_duration, + ) + + metadata = [] + for i, res in enumerate(audio_chunks_results): + res["index"] = i + res["Dataset"] = audio_name + res["Singer"] = audio_name + res["Uid"] = "{}_{}".format(audio_name, res["Uid"]) + metadata.append(res) + + eval_file = os.path.join(temp_audio_dir, "eval.json") + with open(eval_file, "w") as f: + json.dump(metadata, f, indent=4, ensure_ascii=False, sort_keys=True) + + return eval_file + + +def cuda_relevant(deterministic=False): + torch.cuda.empty_cache() + # TF32 on Ampere and above + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.enabled = True + torch.backends.cudnn.allow_tf32 = True + # Deterministic + 
torch.backends.cudnn.deterministic = deterministic + torch.backends.cudnn.benchmark = not deterministic + torch.use_deterministic_algorithms(deterministic) + + +def infer(args, cfg, infer_type): + # Build inference + t = time.time() + trainer = build_inference(args, cfg, infer_type) + print("Model Init: {:.1f}s".format(time.time() - t)) + + # Run inference + t = time.time() + output_audio_files = trainer.inference() + print("Model inference: {:.1f}s".format(time.time() - t)) + return output_audio_files + + +def build_parser(): + r"""Build argument parser for inference.py. + Anything else should be put in an extra config YAML file. + """ + + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + type=str, + required=True, + help="JSON/YAML file for configurations.", + ) + parser.add_argument( + "--acoustics_dir", + type=str, + help="Acoustics model checkpoint directory. If a directory is given, " + "search for the latest checkpoint dir in the directory. If a specific " + "checkpoint dir is given, directly load the checkpoint.", + ) + parser.add_argument( + "--vocoder_dir", + type=str, + required=True, + help="Vocoder checkpoint directory. Searching behavior is the same as " + "the acoustics one.", + ) + parser.add_argument( + "--target_singer", + type=str, + required=True, + help="convert to a specific singer (e.g. --target_singers singer_id).", + ) + parser.add_argument( + "--trans_key", + default=0, + help="0: no pitch shift; autoshift: pitch shift; int: key shift.", + ) + parser.add_argument( + "--source", + type=str, + default="source_audio", + help="Source audio file or directory. If a JSON file is given, " + "inference from dataset is applied. If a directory is given, " + "inference from all wav/flac/mp3 audio files in the directory is applied. " + "Default: inference from all wav/flac/mp3 audio files in ./source_audio", + ) + parser.add_argument( + "--output_dir", + type=str, + default="conversion_results", + help="Output directory. Default: ./conversion_results", + ) + parser.add_argument( + "--log_level", + type=str, + default="warning", + help="Logging level. Default: warning", + ) + parser.add_argument( + "--keep_cache", + action="store_true", + default=True, + help="Keep cache files. Only applicable to inference from files.", + ) + parser.add_argument( + "--diffusion_inference_steps", + type=int, + default=1000, + help="Number of inference steps. 
Only applicable to diffusion inference.", + ) + return parser + + +def main(): + ### Parse arguments and config + args = build_parser().parse_args() + cfg = load_config(args.config) + + # CUDA settings + cuda_relevant() + + if os.path.isdir(args.source): + ### Infer from file + + # Get all the source audio files (.wav, .flac, .mp3) + source_audio_dir = args.source + audio_list = [] + for suffix in ["wav", "flac", "mp3"]: + audio_list += glob.glob( + os.path.join(source_audio_dir, "**/*.{}".format(suffix)), recursive=True + ) + print("There are {} source audios: ".format(len(audio_list))) + + # Infer for every file as dataset + output_root_path = args.output_dir + for audio_path in tqdm(audio_list): + audio_name = audio_path.split("/")[-1].split(".")[0] + args.output_dir = os.path.join(output_root_path, audio_name) + print("\n{}\nConversion for {}...\n".format("*" * 10, audio_name)) + + cfg.inference.source_audio_path = audio_path + cfg.inference.source_audio_name = audio_name + cfg.inference.segments_max_duration = 10.0 + cfg.inference.segments_overlap_duration = 1.0 + + # Prepare metadata and features + args, cfg, cache_dir = prepare_for_audio_file(args, cfg) + + # Infer from file + output_audio_files = infer(args, cfg, infer_type="from_file") + + # Merge the split segments + merge_for_audio_segments(output_audio_files, args, cfg) + + # Keep or remove caches + if not args.keep_cache: + os.removedirs(cache_dir) + + else: + ### Infer from dataset + infer(args, cfg, infer_type="from_dataset") diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/base/__init__.py b/models/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe0221047a62e0b9b3ddd112c79a700c48834fd1 --- /dev/null +++ b/models/base/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .new_trainer import BaseTrainer +from .new_inference import BaseInference diff --git a/models/base/base_dataset.py b/models/base/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..e7d9848bb08f29669a51f3dde200d31bafe1d8da --- /dev/null +++ b/models/base/base_dataset.py @@ -0,0 +1,350 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
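+# Module overview (descriptive summary of the code below):
+#   - BaseDataset reads the train/valid metadata JSON of a processed dataset and
+#     builds utterance-to-path maps for the enabled features (mel, pitch, energy,
+#     uv, audio, labels, phone/text sequences, speaker ids), loading them lazily
+#     in __getitem__.
+#   - BaseCollator zero-pads every per-utterance feature to the longest item in
+#     the batch and emits the corresponding length masks.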
+ +import torch +import numpy as np +import torch.utils.data +from torch.nn.utils.rnn import pad_sequence +from utils.data_utils import * +from processors.acoustic_extractor import cal_normalized_mel +from text import text_to_sequence +from text.text_token_collation import phoneIDCollation + + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, cfg, dataset, is_valid=False): + """ + Args: + cfg: config + dataset: dataset name + is_valid: whether to use train or valid dataset + """ + + assert isinstance(dataset, str) + + # self.data_root = processed_data_dir + self.cfg = cfg + + processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset) + meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file + self.metafile_path = os.path.join(processed_data_dir, meta_file) + self.metadata = self.get_metadata() + + + + ''' + load spk2id and utt2spk from json file + spk2id: {spk1: 0, spk2: 1, ...} + utt2spk: {dataset_uid: spk1, ...} + ''' + if cfg.preprocess.use_spkid: + spk2id_path = os.path.join(processed_data_dir, cfg.preprocess.spk2id) + with open(spk2id_path, "r") as f: + self.spk2id = json.load(f) + + utt2spk_path = os.path.join(processed_data_dir, cfg.preprocess.utt2spk) + self.utt2spk = dict() + with open(utt2spk_path, "r") as f: + for line in f.readlines(): + utt, spk = line.strip().split('\t') + self.utt2spk[utt] = spk + + + if cfg.preprocess.use_uv: + self.utt2uv_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + self.utt2uv_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.uv_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_frame_pitch: + self.utt2frame_pitch_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2frame_pitch_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.pitch_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_frame_energy: + self.utt2frame_energy_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2frame_energy_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.energy_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_mel: + self.utt2mel_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2mel_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.mel_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_linear: + self.utt2linear_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2linear_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.linear_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_audio: + self.utt2audio_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2audio_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.audio_dir, + uid + ".npy", + ) + elif cfg.preprocess.use_label: + self.utt2label_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + 
self.utt2label_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.label_dir, + uid + ".npy", + ) + elif cfg.preprocess.use_one_hot: + self.utt2one_hot_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2one_hot_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.one_hot_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_text or cfg.preprocess.use_phone: + self.utt2seq = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + if cfg.preprocess.use_text: + text = utt_info["Text"] + sequence = text_to_sequence(text, cfg.preprocess.text_cleaners) + elif cfg.preprocess.use_phone: + # load phoneme squence from phone file + phone_path = os.path.join(processed_data_dir, + cfg.preprocess.phone_dir, + uid+'.phone' + ) + with open(phone_path, 'r') as fin: + phones = fin.readlines() + assert len(phones) == 1 + phones = phones[0].strip() + phones_seq = phones.split(' ') + + phon_id_collator = phoneIDCollation(cfg, dataset=dataset) + sequence = phon_id_collator.get_phone_id_sequence(cfg, phones_seq) + + self.utt2seq[utt] = sequence + + + def get_metadata(self): + with open(self.metafile_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + + return metadata + + def get_dataset_name(self): + return self.metadata[0]["Dataset"] + + def __getitem__(self, index): + utt_info = self.metadata[index] + + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + single_feature = dict() + + if self.cfg.preprocess.use_spkid: + single_feature["spk_id"] = np.array( + [self.spk2id[self.utt2spk[utt]]], dtype=np.int32 + ) + + if self.cfg.preprocess.use_mel: + mel = np.load(self.utt2mel_path[utt]) + assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T] + if self.cfg.preprocess.use_min_max_norm_mel: + # do mel norm + mel = cal_normalized_mel(mel, utt_info["Dataset"], self.cfg.preprocess) + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = mel.shape[1] + single_feature["mel"] = mel.T # [T, n_mels] + + if self.cfg.preprocess.use_linear: + linear = np.load(self.utt2linear_path[utt]) + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = linear.shape[1] + single_feature["linear"] = linear.T # [T, n_linear] + + if self.cfg.preprocess.use_frame_pitch: + frame_pitch_path = self.utt2frame_pitch_path[utt] + frame_pitch = np.load(frame_pitch_path) + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = len(frame_pitch) + aligned_frame_pitch = align_length( + frame_pitch, single_feature["target_len"] + ) + single_feature["frame_pitch"] = aligned_frame_pitch + + if self.cfg.preprocess.use_uv: + frame_uv_path = self.utt2uv_path[utt] + frame_uv = np.load(frame_uv_path) + aligned_frame_uv = align_length(frame_uv, single_feature["target_len"]) + aligned_frame_uv = [ + 0 if frame_uv else 1 for frame_uv in aligned_frame_uv + ] + aligned_frame_uv = np.array(aligned_frame_uv) + single_feature["frame_uv"] = aligned_frame_uv + + if self.cfg.preprocess.use_frame_energy: + frame_energy_path = self.utt2frame_energy_path[utt] + frame_energy = np.load(frame_energy_path) + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = len(frame_energy) + aligned_frame_energy = align_length( + frame_energy, single_feature["target_len"] + ) + 
single_feature["frame_energy"] = aligned_frame_energy + + if self.cfg.preprocess.use_audio: + audio = np.load(self.utt2audio_path[utt]) + single_feature["audio"] = audio + single_feature["audio_len"] = audio.shape[0] + + if self.cfg.preprocess.use_phone or self.cfg.preprocess.use_text: + single_feature["phone_seq"] = np.array(self.utt2seq[utt]) + single_feature["phone_len"] = len(self.utt2seq[utt]) + + return single_feature + + def __len__(self): + return len(self.metadata) + + +class BaseCollator(object): + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + self.cfg = cfg + + def __call__(self, batch): + packed_batch_features = dict() + + # mel: [b, T, n_mels] + # frame_pitch, frame_energy: [1, T] + # target_len: [1] + # spk_id: [b, 1] + # mask: [b, T, 1] + + for key in batch[0].keys(): + if key == "target_len": + packed_batch_features["target_len"] = torch.LongTensor( + [b["target_len"] for b in batch] + ) + masks = [ + torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch + ] + packed_batch_features["mask"] = pad_sequence( + masks, batch_first=True, padding_value=0 + ) + elif key == "phone_len": + packed_batch_features["phone_len"] = torch.LongTensor( + [b["phone_len"] for b in batch] + ) + masks = [ + torch.ones((b["phone_len"], 1), dtype=torch.long) for b in batch + ] + packed_batch_features["phn_mask"] = pad_sequence( + masks, batch_first=True, padding_value=0 + ) + elif key == "audio_len": + packed_batch_features["audio_len"] = torch.LongTensor( + [b["audio_len"] for b in batch] + ) + masks = [ + torch.ones((b["audio_len"], 1), dtype=torch.long) for b in batch + ] + else: + values = [torch.from_numpy(b[key]) for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + return packed_batch_features + + +class BaseTestDataset(torch.utils.data.Dataset): + def __init__(self, cfg, args): + raise NotImplementedError + + + def get_metadata(self): + raise NotImplementedError + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + return len(self.metadata) + + +class BaseTestCollator(object): + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + raise NotImplementedError + + def __call__(self, batch): + raise NotImplementedError diff --git a/models/base/base_inference.py b/models/base/base_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..2713f19a0d61f06bca1f01de5ccd8a3b4d2cc02f --- /dev/null +++ b/models/base/base_inference.py @@ -0,0 +1,220 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import os +import re +import time +from pathlib import Path + +import torch +from torch.utils.data import DataLoader +from tqdm import tqdm + +from models.vocoders.vocoder_inference import synthesis +from torch.utils.data import DataLoader +from utils.util import set_all_random_seed +from utils.util import load_config + + +def parse_vocoder(vocoder_dir): + r"""Parse vocoder config""" + vocoder_dir = os.path.abspath(vocoder_dir) + ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")] + ckpt_list.sort(key=lambda x: int(x.stem), reverse=True) + ckpt_path = str(ckpt_list[0]) + vocoder_cfg = load_config(os.path.join(vocoder_dir, "args.json"), lowercase=True) + vocoder_cfg.model.bigvgan = vocoder_cfg.vocoder + return vocoder_cfg, ckpt_path + + +class BaseInference(object): + def __init__(self, cfg, args): + self.cfg = cfg + self.args = args + self.model_type = cfg.model_type + self.avg_rtf = list() + set_all_random_seed(10086) + os.makedirs(args.output_dir, exist_ok=True) + + if torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + torch.set_num_threads(10) # inference on 1 core cpu. + + # Load acoustic model + self.model = self.create_model().to(self.device) + state_dict = self.load_state_dict() + self.load_model(state_dict) + self.model.eval() + + # Load vocoder model if necessary + if self.args.checkpoint_dir_vocoder is not None: + self.get_vocoder_info() + + def create_model(self): + raise NotImplementedError + + def load_state_dict(self): + self.checkpoint_file = self.args.checkpoint_file + if self.checkpoint_file is None: + assert self.args.checkpoint_dir is not None + checkpoint_path = os.path.join(self.args.checkpoint_dir, "checkpoint") + checkpoint_filename = open(checkpoint_path).readlines()[-1].strip() + self.checkpoint_file = os.path.join( + self.args.checkpoint_dir, checkpoint_filename + ) + + self.checkpoint_dir = os.path.split(self.checkpoint_file)[0] + + print("Restore acoustic model from {}".format(self.checkpoint_file)) + raw_state_dict = torch.load(self.checkpoint_file, map_location=self.device) + self.am_restore_step = re.findall(r"step-(.+?)_loss", self.checkpoint_file)[0] + + return raw_state_dict + + def load_model(self, model): + raise NotImplementedError + + def get_vocoder_info(self): + self.checkpoint_dir_vocoder = self.args.checkpoint_dir_vocoder + self.vocoder_cfg = os.path.join( + os.path.dirname(self.checkpoint_dir_vocoder), "args.json" + ) + self.cfg.vocoder = load_config(self.vocoder_cfg, lowercase=True) + self.vocoder_tag = self.checkpoint_dir_vocoder.split("/")[-2].split(":")[-1] + self.vocoder_steps = self.checkpoint_dir_vocoder.split("/")[-1].split(".")[0] + + def build_test_utt_data(self): + raise NotImplementedError + + def build_testdata_loader(self, args, target_speaker=None): + datasets, collate = self.build_test_dataset() + self.test_dataset = datasets(self.cfg, args, target_speaker) + self.test_collate = collate(self.cfg) + self.test_batch_size = min( + self.cfg.train.batch_size, len(self.test_dataset.metadata) + ) + test_loader = DataLoader( + self.test_dataset, + collate_fn=self.test_collate, + num_workers=self.args.num_workers, + batch_size=self.test_batch_size, + shuffle=False, + ) + return test_loader + + def inference_each_batch(self, batch_data): + raise NotImplementedError + + def inference_for_batches(self, args, target_speaker=None): + ###### Construct test_batch ###### + loader = self.build_testdata_loader(args, target_speaker) + + n_batch = len(loader) + now 
= time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())) + print( + "Model eval time: {}, batch_size = {}, n_batch = {}".format( + now, self.test_batch_size, n_batch + ) + ) + self.model.eval() + + ###### Inference for each batch ###### + pred_res = [] + with torch.no_grad(): + for i, batch_data in enumerate(loader if n_batch == 1 else tqdm(loader)): + # Put the data to device + for k, v in batch_data.items(): + batch_data[k] = batch_data[k].to(self.device) + + y_pred, stats = self.inference_each_batch(batch_data) + + pred_res += y_pred + + return pred_res + + def inference(self, feature): + raise NotImplementedError + + def synthesis_by_vocoder(self, pred): + audios_pred = synthesis( + self.vocoder_cfg, + self.checkpoint_dir_vocoder, + len(pred), + pred, + ) + return audios_pred + + def __call__(self, utt): + feature = self.build_test_utt_data(utt) + start_time = time.time() + with torch.no_grad(): + outputs = self.inference(feature)[0] + time_used = time.time() - start_time + rtf = time_used / ( + outputs.shape[1] + * self.cfg.preprocess.hop_size + / self.cfg.preprocess.sample_rate + ) + print("Time used: {:.3f}, RTF: {:.4f}".format(time_used, rtf)) + self.avg_rtf.append(rtf) + audios = outputs.cpu().squeeze().numpy().reshape(-1, 1) + return audios + + +def base_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", default="config.json", help="json files for configurations." + ) + parser.add_argument("--use_ddp_inference", default=False) + parser.add_argument("--n_workers", default=1, type=int) + parser.add_argument("--local_rank", default=-1, type=int) + parser.add_argument( + "--batch_size", default=1, type=int, help="Batch size for inference" + ) + parser.add_argument( + "--num_workers", + default=1, + type=int, + help="Worker number for inference dataloader", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + default=None, + help="Checkpoint dir including model file and configuration", + ) + parser.add_argument( + "--checkpoint_file", help="checkpoint file", type=str, default=None + ) + parser.add_argument( + "--test_list", help="test utterance list for testing", type=str, default=None + ) + parser.add_argument( + "--checkpoint_dir_vocoder", + help="Vocoder's checkpoint dir including model file and configuration", + type=str, + default=None, + ) + parser.add_argument( + "--output_dir", + type=str, + default=None, + help="Output dir for saving generated results", + ) + return parser + + +if __name__ == "__main__": + parser = base_parser() + args = parser.parse_args() + cfg = load_config(args.config) + + # Build inference + inference = BaseInference(cfg, args) + inference() diff --git a/models/base/base_sampler.py b/models/base/base_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..149d1437eb1d3a00ca8c9895b150b39b2a3635fa --- /dev/null +++ b/models/base/base_sampler.py @@ -0,0 +1,136 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import random + +from torch.utils.data import ConcatDataset, Dataset +from torch.utils.data.sampler import ( + BatchSampler, + RandomSampler, + Sampler, + SequentialSampler, +) + + +class ScheduledSampler(Sampler): + """A sampler that samples data from a given concat-dataset. 
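+
+    Each sub-dataset inside the ConcatDataset is batched on its own, so a batch
+    never mixes samples from different datasets; the completed batches are then
+    shuffled across datasets in ``__iter__``.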
+ + Args: + concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets + batch_size (int): batch size + holistic_shuffle (bool): whether to shuffle the whole dataset or not + logger (logging.Logger): logger to print warning message + + Usage: + For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True: + >>> list(ScheduledSampler(ConcatDataset([0, 1, 2], [3, 4, 5], [6, 7, 8]]))) + [3, 4, 5, 0, 1, 2, 6, 7, 8] + """ + + def __init__( + self, + concat_dataset, + batch_size, + holistic_shuffle, + logger=None, + loader_type="train", + ): + if not isinstance(concat_dataset, ConcatDataset): + raise ValueError( + "concat_dataset must be an instance of ConcatDataset, but got {}".format( + type(concat_dataset) + ) + ) + if not isinstance(batch_size, int): + raise ValueError( + "batch_size must be an integer, but got {}".format(type(batch_size)) + ) + if not isinstance(holistic_shuffle, bool): + raise ValueError( + "holistic_shuffle must be a boolean, but got {}".format( + type(holistic_shuffle) + ) + ) + + self.concat_dataset = concat_dataset + self.batch_size = batch_size + self.holistic_shuffle = holistic_shuffle + + affected_dataset_name = [] + affected_dataset_len = [] + for dataset in concat_dataset.datasets: + dataset_len = len(dataset) + dataset_name = dataset.get_dataset_name() + if dataset_len < batch_size: + affected_dataset_name.append(dataset_name) + affected_dataset_len.append(dataset_len) + + self.type = loader_type + for dataset_name, dataset_len in zip( + affected_dataset_name, affected_dataset_len + ): + if not loader_type == "valid": + logger.warning( + "The {} dataset {} has a length of {}, which is smaller than the batch size {}. This may cause unexpected behavior.".format( + loader_type, dataset_name, dataset_len, batch_size + ) + ) + + def __len__(self): + # the number of batches with drop last + num_of_batches = sum( + [ + math.floor(len(dataset) / self.batch_size) + for dataset in self.concat_dataset.datasets + ] + ) + # if samples are not enough for one batch, we don't drop last + if self.type == "valid" and num_of_batches < 1: + return len(self.concat_dataset) + return num_of_batches * self.batch_size + + def __iter__(self): + iters = [] + for dataset in self.concat_dataset.datasets: + iters.append( + SequentialSampler(dataset).__iter__() + if not self.holistic_shuffle + else RandomSampler(dataset).__iter__() + ) + # e.g. 
[0, 200, 400] + init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1] + output_batches = [] + for dataset_idx in range(len(self.concat_dataset.datasets)): + cur_batch = [] + for idx in iters[dataset_idx]: + cur_batch.append(idx + init_indices[dataset_idx]) + if len(cur_batch) == self.batch_size: + output_batches.append(cur_batch) + cur_batch = [] + # if loader_type is valid, we don't need to drop last + if self.type == "valid" and len(cur_batch) > 0: + output_batches.append(cur_batch) + + # force drop last in training + random.shuffle(output_batches) + output_indices = [item for sublist in output_batches for item in sublist] + return iter(output_indices) + + +def build_samplers(concat_dataset: Dataset, cfg, logger, loader_type): + sampler = ScheduledSampler( + concat_dataset, + cfg.train.batch_size, + cfg.train.sampler.holistic_shuffle, + logger, + loader_type, + ) + batch_sampler = BatchSampler( + sampler, + cfg.train.batch_size, + cfg.train.sampler.drop_last if not loader_type == "valid" else False, + ) + return sampler, batch_sampler diff --git a/models/base/base_trainer.py b/models/base/base_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..8782216dc13ce5d9de05ae790faeb82cf7cfd501 --- /dev/null +++ b/models/base/base_trainer.py @@ -0,0 +1,348 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import collections +import json +import os +import sys +import time + +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import ConcatDataset, DataLoader +from torch.utils.tensorboard import SummaryWriter + +from models.base.base_sampler import BatchSampler +from utils.util import ( + Logger, + remove_older_ckpt, + save_config, + set_all_random_seed, + ValueWindow, +) + + +class BaseTrainer(object): + def __init__(self, args, cfg): + self.args = args + self.log_dir = args.log_dir + self.cfg = cfg + + self.checkpoint_dir = os.path.join(args.log_dir, "checkpoints") + os.makedirs(self.checkpoint_dir, exist_ok=True) + if not cfg.train.ddp or args.local_rank == 0: + self.sw = SummaryWriter(os.path.join(args.log_dir, "events")) + self.logger = self.build_logger() + self.time_window = ValueWindow(50) + + self.step = 0 + self.epoch = -1 + self.max_epochs = self.cfg.train.epochs + self.max_steps = self.cfg.train.max_steps + + # set random seed & init distributed training + set_all_random_seed(self.cfg.train.random_seed) + if cfg.train.ddp: + dist.init_process_group(backend="nccl") + + if cfg.model_type not in ["AutoencoderKL", "AudioLDM"]: + self.singers = self.build_singers_lut() + + # setup data_loader + self.data_loader = self.build_data_loader() + + # setup model & enable distributed training + self.model = self.build_model() + print(self.model) + + if isinstance(self.model, dict): + for key, value in self.model.items(): + value.cuda(self.args.local_rank) + if key == "PQMF": + continue + if cfg.train.ddp: + self.model[key] = DistributedDataParallel( + value, device_ids=[self.args.local_rank] + ) + else: + self.model.cuda(self.args.local_rank) + if cfg.train.ddp: + self.model = DistributedDataParallel( + self.model, device_ids=[self.args.local_rank] + ) + + # create criterion + self.criterion = self.build_criterion() + if isinstance(self.criterion, dict): + for key, value in self.criterion.items(): + self.criterion[key].cuda(args.local_rank) + else: + 
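# A single criterion object: move it onto this process's GPU as well.
+            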
self.criterion.cuda(self.args.local_rank) + + # optimizer + self.optimizer = self.build_optimizer() + self.scheduler = self.build_scheduler() + + # save config file + self.config_save_path = os.path.join(self.checkpoint_dir, "args.json") + + def build_logger(self): + log_file = os.path.join(self.checkpoint_dir, "train.log") + logger = Logger(log_file, level=self.args.log_level).logger + + return logger + + def build_dataset(self): + raise NotImplementedError + + def build_data_loader(self): + Dataset, Collator = self.build_dataset() + # build dataset instance for each dataset and combine them by ConcatDataset + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=False) + datasets_list.append(subdataset) + train_dataset = ConcatDataset(datasets_list) + + train_collate = Collator(self.cfg) + # TODO: multi-GPU training + if self.cfg.train.ddp: + raise NotImplementedError("DDP is not supported yet.") + + # sampler will provide indices to batch_sampler, which will perform batching and yield batch indices + batch_sampler = BatchSampler( + cfg=self.cfg, concat_dataset=train_dataset, dataset_list=datasets_list + ) + + # use batch_sampler argument instead of (sampler, shuffle, drop_last, batch_size) + train_loader = DataLoader( + train_dataset, + collate_fn=train_collate, + num_workers=self.args.num_workers, + batch_sampler=batch_sampler, + pin_memory=False, + ) + if not self.cfg.train.ddp or self.args.local_rank == 0: + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=True) + datasets_list.append(subdataset) + valid_dataset = ConcatDataset(datasets_list) + valid_collate = Collator(self.cfg) + batch_sampler = BatchSampler( + cfg=self.cfg, concat_dataset=valid_dataset, dataset_list=datasets_list + ) + valid_loader = DataLoader( + valid_dataset, + collate_fn=valid_collate, + num_workers=1, + batch_sampler=batch_sampler, + ) + else: + raise NotImplementedError("DDP is not supported yet.") + # valid_loader = None + data_loader = {"train": train_loader, "valid": valid_loader} + return data_loader + + def build_singers_lut(self): + # combine singers + if not os.path.exists(os.path.join(self.log_dir, self.cfg.preprocess.spk2id)): + singers = collections.OrderedDict() + else: + with open( + os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "r" + ) as singer_file: + singers = json.load(singer_file) + singer_count = len(singers) + for dataset in self.cfg.dataset: + singer_lut_path = os.path.join( + self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id + ) + with open(singer_lut_path, "r") as singer_lut_path: + singer_lut = json.load(singer_lut_path) + for singer in singer_lut.keys(): + if singer not in singers: + singers[singer] = singer_count + singer_count += 1 + with open( + os.path.join(self.log_dir, self.cfg.preprocess.spk2id), "w" + ) as singer_file: + json.dump(singers, singer_file, indent=4, ensure_ascii=False) + print( + "singers have been dumped to {}".format( + os.path.join(self.log_dir, self.cfg.preprocess.spk2id) + ) + ) + return singers + + def build_model(self): + raise NotImplementedError() + + def build_optimizer(self): + raise NotImplementedError + + def build_scheduler(self): + raise NotImplementedError() + + def build_criterion(self): + raise NotImplementedError + + def get_state_dict(self): + raise NotImplementedError + + def save_config_file(self): + save_config(self.config_save_path, self.cfg) + + # TODO, save without module. 
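+    # (A model wrapped in DistributedDataParallel stores its weights under
+    # "module.*" keys; stripping that prefix here would make the checkpoints
+    # loadable without DDP as well.)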
+ def save_checkpoint(self, state_dict, saved_model_path): + torch.save(state_dict, saved_model_path) + + def load_checkpoint(self): + checkpoint_path = os.path.join(self.checkpoint_dir, "checkpoint") + assert os.path.exists(checkpoint_path) + checkpoint_filename = open(checkpoint_path).readlines()[-1].strip() + model_path = os.path.join(self.checkpoint_dir, checkpoint_filename) + assert os.path.exists(model_path) + if not self.cfg.train.ddp or self.args.local_rank == 0: + self.logger.info(f"Re(store) from {model_path}") + checkpoint = torch.load(model_path, map_location="cpu") + return checkpoint + + def load_model(self, checkpoint): + raise NotImplementedError + + def restore(self): + checkpoint = self.load_checkpoint() + self.load_model(checkpoint) + + def train_step(self, data): + raise NotImplementedError( + f"Need to implement function {sys._getframe().f_code.co_name} in " + f"your sub-class of {self.__class__.__name__}. " + ) + + @torch.no_grad() + def eval_step(self): + raise NotImplementedError( + f"Need to implement function {sys._getframe().f_code.co_name} in " + f"your sub-class of {self.__class__.__name__}. " + ) + + def write_summary(self, losses, stats): + raise NotImplementedError( + f"Need to implement function {sys._getframe().f_code.co_name} in " + f"your sub-class of {self.__class__.__name__}. " + ) + + def write_valid_summary(self, losses, stats): + raise NotImplementedError( + f"Need to implement function {sys._getframe().f_code.co_name} in " + f"your sub-class of {self.__class__.__name__}. " + ) + + def echo_log(self, losses, mode="Training"): + message = [ + "{} - Epoch {} Step {}: [{:.3f} s/step]".format( + mode, self.epoch + 1, self.step, self.time_window.average + ) + ] + + for key in sorted(losses.keys()): + if isinstance(losses[key], dict): + for k, v in losses[key].items(): + message.append( + str(k).split("/")[-1] + "=" + str(round(float(v), 5)) + ) + else: + message.append( + str(key).split("/")[-1] + "=" + str(round(float(losses[key]), 5)) + ) + self.logger.info(", ".join(message)) + + def eval_epoch(self): + self.logger.info("Validation...") + valid_losses = {} + for i, batch_data in enumerate(self.data_loader["valid"]): + for k, v in batch_data.items(): + if isinstance(v, torch.Tensor): + batch_data[k] = v.cuda() + valid_loss, valid_stats, total_valid_loss = self.eval_step(batch_data, i) + for key in valid_loss: + if key not in valid_losses: + valid_losses[key] = 0 + valid_losses[key] += valid_loss[key] + + # Add mel and audio to the Tensorboard + # Average loss + for key in valid_losses: + valid_losses[key] /= i + 1 + self.echo_log(valid_losses, "Valid") + return valid_losses, valid_stats + + def train_epoch(self): + for i, batch_data in enumerate(self.data_loader["train"]): + start_time = time.time() + # Put the data to cuda device + for k, v in batch_data.items(): + if isinstance(v, torch.Tensor): + batch_data[k] = v.cuda(self.args.local_rank) + + # Training step + train_losses, train_stats, total_loss = self.train_step(batch_data) + self.time_window.append(time.time() - start_time) + + if self.args.local_rank == 0 or not self.cfg.train.ddp: + if self.step % self.args.stdout_interval == 0: + self.echo_log(train_losses, "Training") + + if self.step % self.cfg.train.save_summary_steps == 0: + self.logger.info(f"Save summary as step {self.step}") + self.write_summary(train_losses, train_stats) + + if ( + self.step % self.cfg.train.save_checkpoints_steps == 0 + and self.step != 0 + ): + saved_model_name = "step-{:07d}_loss-{:.4f}.pt".format( + self.step, 
total_loss + ) + saved_model_path = os.path.join( + self.checkpoint_dir, saved_model_name + ) + saved_state_dict = self.get_state_dict() + self.save_checkpoint(saved_state_dict, saved_model_path) + self.save_config_file() + # keep max n models + remove_older_ckpt( + saved_model_name, + self.checkpoint_dir, + max_to_keep=self.cfg.train.keep_checkpoint_max, + ) + + if self.step != 0 and self.step % self.cfg.train.valid_interval == 0: + if isinstance(self.model, dict): + for key in self.model.keys(): + self.model[key].eval() + else: + self.model.eval() + # Evaluate one epoch and get average loss + valid_losses, valid_stats = self.eval_epoch() + if isinstance(self.model, dict): + for key in self.model.keys(): + self.model[key].train() + else: + self.model.train() + # Write validation losses to summary. + self.write_valid_summary(valid_losses, valid_stats) + self.step += 1 + + def train(self): + for epoch in range(max(0, self.epoch), self.max_epochs): + self.train_epoch() + self.epoch += 1 + if self.step > self.max_steps: + self.logger.info("Training finished!") + break diff --git a/models/base/new_dataset.py b/models/base/new_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2201bb4132ab86d1110092d7ab9e509296367a22 --- /dev/null +++ b/models/base/new_dataset.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +from abc import abstractmethod +from pathlib import Path + +import json5 +import torch +import yaml + + +# TODO: for training and validating +class BaseDataset(torch.utils.data.Dataset): + r"""Base dataset for training and validating.""" + + def __init__(self, args, cfg, is_valid=False): + pass + + +class BaseTestDataset(torch.utils.data.Dataset): + r"""Test dataset for inference.""" + + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + assert infer_type in ["from_dataset", "from_file"] + + self.args = args + self.cfg = cfg + self.infer_type = infer_type + + @abstractmethod + def __getitem__(self, index): + pass + + def __len__(self): + return len(self.metadata) + + def get_metadata(self): + path = Path(self.args.source) + if path.suffix == ".json" or path.suffix == ".jsonc": + metadata = json5.load(open(self.args.source, "r")) + elif path.suffix == ".yaml" or path.suffix == ".yml": + metadata = yaml.full_load(open(self.args.source, "r")) + else: + raise ValueError(f"Unsupported file type: {path.suffix}") + + return metadata diff --git a/models/base/new_inference.py b/models/base/new_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..4813fca4aba192fb8737dd74f37f6d430e1909a4 --- /dev/null +++ b/models/base/new_inference.py @@ -0,0 +1,249 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
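+
+# Accelerate-based inference entry point: it builds the test dataloader, restores
+# the latest acoustic checkpoint through accelerate, denormalizes the predicted
+# mels with the dataset's min/max statistics, and vocodes them into waveforms.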
+ +import os +import random +import re +import time +from abc import abstractmethod +from pathlib import Path + +import accelerate +import json5 +import numpy as np +import torch +from accelerate.logging import get_logger +from torch.utils.data import DataLoader + +from models.vocoders.vocoder_inference import synthesis +from utils.io import save_audio +from utils.util import load_config +from utils.audio_slicer import is_silence + +EPS = 1.0e-12 + + +class BaseInference(object): + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + super().__init__() + + start = time.monotonic_ns() + self.args = args + self.cfg = cfg + + assert infer_type in ["from_dataset", "from_file"] + self.infer_type = infer_type + + # init with accelerate + self.accelerator = accelerate.Accelerator() + self.accelerator.wait_for_everyone() + + # Use accelerate logger for distributed inference + with self.accelerator.main_process_first(): + self.logger = get_logger("inference", log_level=args.log_level) + + # Log some info + self.logger.info("=" * 56) + self.logger.info("||\t\t" + "New inference process started." + "\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + self.logger.debug(f"Using {args.log_level.upper()} logging level.") + + self.acoustics_dir = args.acoustics_dir + self.logger.debug(f"Acoustic dir: {args.acoustics_dir}") + self.vocoder_dir = args.vocoder_dir + self.logger.debug(f"Vocoder dir: {args.vocoder_dir}") + # should be in svc inferencer + # self.target_singer = args.target_singer + # self.logger.info(f"Target singers: {args.target_singer}") + # self.trans_key = args.trans_key + # self.logger.info(f"Trans key: {args.trans_key}") + + os.makedirs(args.output_dir, exist_ok=True) + + # set random seed + with self.accelerator.main_process_first(): + start = time.monotonic_ns() + self._set_random_seed(self.cfg.train.random_seed) + end = time.monotonic_ns() + self.logger.debug( + f"Setting random seed done in {(end - start) / 1e6:.2f}ms" + ) + self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") + + # setup data_loader + with self.accelerator.main_process_first(): + self.logger.info("Building dataset...") + start = time.monotonic_ns() + self.test_dataloader = self._build_dataloader() + end = time.monotonic_ns() + self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") + + # setup model + with self.accelerator.main_process_first(): + self.logger.info("Building model...") + start = time.monotonic_ns() + self.model = self._build_model() + end = time.monotonic_ns() + # self.logger.debug(self.model) + self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms") + + # init with accelerate + self.logger.info("Initializing accelerate...") + start = time.monotonic_ns() + self.accelerator = accelerate.Accelerator() + self.model = self.accelerator.prepare(self.model) + end = time.monotonic_ns() + self.accelerator.wait_for_everyone() + self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms") + + with self.accelerator.main_process_first(): + self.logger.info("Loading checkpoint...") + start = time.monotonic_ns() + # TODO: Also, suppose only use latest one yet + self.__load_model(os.path.join(args.acoustics_dir, "checkpoint")) + end = time.monotonic_ns() + self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms") + + self.model.eval() + self.accelerator.wait_for_everyone() + + ### Abstract methods ### + @abstractmethod + def _build_test_dataset(self): + pass + + @abstractmethod + def _build_model(self): + 
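# Subclasses build and return the task-specific acoustic model here.
+        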
pass + + @abstractmethod + @torch.inference_mode() + def _inference_each_batch(self, batch_data): + pass + + ### Abstract methods end ### + + @torch.inference_mode() + def inference(self): + for i, batch in enumerate(self.test_dataloader): + y_pred = self._inference_each_batch(batch).cpu() + mel_min, mel_max = self.test_dataset.target_mel_extrema + y_pred = (y_pred + 1.0) / 2.0 * (mel_max - mel_min + EPS) + mel_min + y_ls = y_pred.chunk(self.test_batch_size) + tgt_ls = batch["target_len"].cpu().chunk(self.test_batch_size) + j = 0 + for it, l in zip(y_ls, tgt_ls): + l = l.item() + it = it.squeeze(0)[:l] + uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"] + torch.save(it, os.path.join(self.args.output_dir, f"{uid}.pt")) + j += 1 + + vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir) + + res = synthesis( + cfg=vocoder_cfg, + vocoder_weight_file=vocoder_ckpt, + n_samples=None, + pred=[ + torch.load( + os.path.join(self.args.output_dir, "{}.pt".format(i["Uid"])) + ).numpy(force=True) + for i in self.test_dataset.metadata + ], + ) + + output_audio_files = [] + for it, wav in zip(self.test_dataset.metadata, res): + uid = it["Uid"] + file = os.path.join(self.args.output_dir, f"{uid}.wav") + output_audio_files.append(file) + + wav = wav.numpy(force=True) + save_audio( + file, + wav, + self.cfg.preprocess.sample_rate, + add_silence=False, + turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate), + ) + os.remove(os.path.join(self.args.output_dir, f"{uid}.pt")) + + return sorted(output_audio_files) + + # TODO: LEGACY CODE + def _build_dataloader(self): + datasets, collate = self._build_test_dataset() + self.test_dataset = datasets(self.args, self.cfg, self.infer_type) + self.test_collate = collate(self.cfg) + self.test_batch_size = min( + self.cfg.train.batch_size, len(self.test_dataset.metadata) + ) + test_dataloader = DataLoader( + self.test_dataset, + collate_fn=self.test_collate, + num_workers=1, + batch_size=self.test_batch_size, + shuffle=False, + ) + return test_dataloader + + def __load_model(self, checkpoint_dir: str = None, checkpoint_path: str = None): + r"""Load model from checkpoint. If checkpoint_path is None, it will + load the latest checkpoint in checkpoint_dir. If checkpoint_path is not + None, it will load the checkpoint specified by checkpoint_path. **Only use this + method after** ``accelerator.prepare()``. 
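+
+        Checkpoint folders are assumed to follow the trainer's naming scheme,
+        ``epoch-{:04d}_step-{:07d}_loss-{:.6f}``, which the pattern match below
+        relies on to pick the latest one.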
+ """ + if checkpoint_path is None: + ls = [] + for i in Path(checkpoint_dir).iterdir(): + if re.match(r"epoch-\d+_step-\d+_loss-[\d.]+", str(i.stem)): + ls.append(i) + ls.sort( + key=lambda x: int(x.stem.split("_")[-3].split("-")[-1]), reverse=True + ) + checkpoint_path = ls[0] + else: + checkpoint_path = Path(checkpoint_path) + self.accelerator.load_state(str(checkpoint_path)) + # set epoch and step + self.epoch = int(checkpoint_path.stem.split("_")[-3].split("-")[-1]) + self.step = int(checkpoint_path.stem.split("_")[-2].split("-")[-1]) + return str(checkpoint_path) + + @staticmethod + def _set_random_seed(seed): + r"""Set random seed for all possible random modules.""" + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + @staticmethod + def _parse_vocoder(vocoder_dir): + r"""Parse vocoder config""" + vocoder_dir = os.path.abspath(vocoder_dir) + ckpt_list = [ckpt for ckpt in Path(vocoder_dir).glob("*.pt")] + ckpt_list.sort(key=lambda x: int(x.stem), reverse=True) + ckpt_path = str(ckpt_list[0]) + vocoder_cfg = load_config( + os.path.join(vocoder_dir, "args.json"), lowercase=True + ) + return vocoder_cfg, ckpt_path + + @staticmethod + def __count_parameters(model): + return sum(p.numel() for p in model.parameters()) + + def __dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) diff --git a/models/base/new_trainer.py b/models/base/new_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5d013d2bc2f2e47e5c7646cac5c63cc88c04486b --- /dev/null +++ b/models/base/new_trainer.py @@ -0,0 +1,722 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os +import random +import shutil +import time +from abc import abstractmethod +from pathlib import Path + +import accelerate +import json5 +import numpy as np +import torch +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration +from torch.utils.data import ConcatDataset, DataLoader +from tqdm import tqdm + +from models.base.base_sampler import build_samplers +from optimizer.optimizers import NoamLR + + +class BaseTrainer(object): + r"""The base trainer for all tasks. Any trainer should inherit from this class.""" + + def __init__(self, args=None, cfg=None): + super().__init__() + + self.args = args + self.cfg = cfg + + cfg.exp_name = args.exp_name + + # init with accelerate + self._init_accelerator() + self.accelerator.wait_for_everyone() + + # Use accelerate logger for distributed training + with self.accelerator.main_process_first(): + self.logger = get_logger(args.exp_name, log_level=args.log_level) + + # Log some info + self.logger.info("=" * 56) + self.logger.info("||\t\t" + "New training process started." 
+ "\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + self.logger.debug(f"Using {args.log_level.upper()} logging level.") + self.logger.info(f"Experiment name: {args.exp_name}") + self.logger.info(f"Experiment directory: {self.exp_dir}") + self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") + if self.accelerator.is_main_process: + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") + + # init counts + self.batch_count: int = 0 + self.step: int = 0 + self.epoch: int = 0 + self.max_epoch = ( + self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf") + ) + self.logger.info( + "Max epoch: {}".format( + self.max_epoch if self.max_epoch < float("inf") else "Unlimited" + ) + ) + + # Check values + if self.accelerator.is_main_process: + self.__check_basic_configs() + # Set runtime configs + self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride + self.checkpoints_path = [ + [] for _ in range(len(self.save_checkpoint_stride)) + ] + self.keep_last = [ + i if i > 0 else float("inf") for i in self.cfg.train.keep_last + ] + self.run_eval = self.cfg.train.run_eval + + # set random seed + with self.accelerator.main_process_first(): + start = time.monotonic_ns() + self._set_random_seed(self.cfg.train.random_seed) + end = time.monotonic_ns() + self.logger.debug( + f"Setting random seed done in {(end - start) / 1e6:.2f}ms" + ) + self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") + + # setup data_loader + with self.accelerator.main_process_first(): + self.logger.info("Building dataset...") + start = time.monotonic_ns() + self.train_dataloader, self.valid_dataloader = self._build_dataloader() + end = time.monotonic_ns() + self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") + + # setup model + with self.accelerator.main_process_first(): + self.logger.info("Building model...") + start = time.monotonic_ns() + self.model = self._build_model() + end = time.monotonic_ns() + self.logger.debug(self.model) + self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms") + self.logger.info( + f"Model parameters: {self.__count_parameters(self.model)/1e6:.2f}M" + ) + # optimizer & scheduler + with self.accelerator.main_process_first(): + self.logger.info("Building optimizer and scheduler...") + start = time.monotonic_ns() + self.optimizer = self.__build_optimizer() + self.scheduler = self.__build_scheduler() + end = time.monotonic_ns() + self.logger.info( + f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms" + ) + + # accelerate prepare + self.logger.info("Initializing accelerate...") + start = time.monotonic_ns() + ( + self.train_dataloader, + self.valid_dataloader, + self.model, + self.optimizer, + self.scheduler, + ) = self.accelerator.prepare( + self.train_dataloader, + self.valid_dataloader, + self.model, + self.optimizer, + self.scheduler, + ) + end = time.monotonic_ns() + self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms") + + # create criterion + with self.accelerator.main_process_first(): + self.logger.info("Building criterion...") + start = time.monotonic_ns() + self.criterion = self._build_criterion() + end = time.monotonic_ns() + self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms") + + # Resume or Finetune + with self.accelerator.main_process_first(): + if args.resume: + ## Automatically resume according to the current exprimental name + self.logger.info("Resuming from 
{}...".format(self.checkpoint_dir)) + start = time.monotonic_ns() + ckpt_path = self.__load_model( + checkpoint_dir=self.checkpoint_dir, resume_type=args.resume_type + ) + end = time.monotonic_ns() + self.logger.info( + f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms" + ) + self.checkpoints_path = json.load( + open(os.path.join(ckpt_path, "ckpts.json"), "r") + ) + elif args.resume_from_ckpt_path and args.resume_from_ckpt_path != "": + ## Resume from the given checkpoint path + if not os.path.exists(args.resume_from_ckpt_path): + raise ValueError( + "[Error] The resumed checkpoint path {} don't exist.".format( + args.resume_from_ckpt_path + ) + ) + + self.logger.info( + "Resuming from {}...".format(args.resume_from_ckpt_path) + ) + start = time.monotonic_ns() + ckpt_path = self.__load_model( + checkpoint_path=args.resume_from_ckpt_path, + resume_type=args.resume_type, + ) + end = time.monotonic_ns() + self.logger.info( + f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms" + ) + + # save config file path + self.config_save_path = os.path.join(self.exp_dir, "args.json") + + ### Following are abstract methods that should be implemented in child classes ### + @abstractmethod + def _build_dataset(self): + r"""Build dataset for model training/validating/evaluating.""" + pass + + @staticmethod + @abstractmethod + def _build_criterion(): + r"""Build criterion function for model loss calculation.""" + pass + + @abstractmethod + def _build_model(self): + r"""Build model for training/validating/evaluating.""" + pass + + @abstractmethod + def _forward_step(self, batch): + r"""One forward step of the neural network. This abstract method is trying to + unify ``_train_step`` and ``_valid_step`` and avoid redundant implementation. + However, for special case that using different forward step pattern for + training and validating, you could just override this method with ``pass`` and + implement ``_train_step`` and ``_valid_step`` separately. + """ + pass + + @abstractmethod + def _save_auxiliary_states(self): + r"""To save some auxiliary states when saving model's ckpt""" + pass + + ### Abstract methods end ### + + ### THIS IS MAIN ENTRY ### + def train_loop(self): + r"""Training loop. The public entry of training process.""" + # Wait everyone to prepare before we move on + self.accelerator.wait_for_everyone() + # dump config file + if self.accelerator.is_main_process: + self.__dump_cfg(self.config_save_path) + self.model.train() + self.optimizer.zero_grad() + # Wait to ensure good to go + self.accelerator.wait_for_everyone() + while self.epoch < self.max_epoch: + self.logger.info("\n") + self.logger.info("-" * 32) + self.logger.info("Epoch {}: ".format(self.epoch)) + + ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict) + ### It's inconvenient for the model with multiple losses + # Do training & validating epoch + train_loss = self._train_epoch() + self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss)) + valid_loss = self._valid_epoch() + self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss)) + self.accelerator.log( + {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss}, + step=self.epoch, + ) + + self.accelerator.wait_for_everyone() + # TODO: what is scheduler? + self.scheduler.step(valid_loss) # FIXME: use epoch track correct? 
+ + # Check if hit save_checkpoint_stride and run_eval + run_eval = False + if self.accelerator.is_main_process: + save_checkpoint = False + hit_dix = [] + for i, num in enumerate(self.save_checkpoint_stride): + if self.epoch % num == 0: + save_checkpoint = True + hit_dix.append(i) + run_eval |= self.run_eval[i] + + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process and save_checkpoint: + path = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, train_loss + ), + ) + self.tmp_checkpoint_save_path = path + self.accelerator.save_state(path) + print(f"save checkpoint in {path}") + json.dump( + self.checkpoints_path, + open(os.path.join(path, "ckpts.json"), "w"), + ensure_ascii=False, + indent=4, + ) + self._save_auxiliary_states() + + # Remove old checkpoints + to_remove = [] + for idx in hit_dix: + self.checkpoints_path[idx].append(path) + while len(self.checkpoints_path[idx]) > self.keep_last[idx]: + to_remove.append((idx, self.checkpoints_path[idx].pop(0))) + + # Search conflicts + total = set() + for i in self.checkpoints_path: + total |= set(i) + do_remove = set() + for idx, path in to_remove[::-1]: + if path in total: + self.checkpoints_path[idx].insert(0, path) + else: + do_remove.add(path) + + # Remove old checkpoints + for path in do_remove: + shutil.rmtree(path, ignore_errors=True) + self.logger.debug(f"Remove old checkpoint: {path}") + + self.accelerator.wait_for_everyone() + if run_eval: + # TODO: run evaluation + pass + + # Update info for each epoch + self.epoch += 1 + + # Finish training and save final checkpoint + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process: + self.accelerator.save_state( + os.path.join( + self.checkpoint_dir, + "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, valid_loss + ), + ) + ) + self._save_auxiliary_states() + + self.accelerator.end_training() + + ### Following are methods that can be used directly in child classes ### + def _train_epoch(self): + r"""Training epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. + """ + self.model.train() + epoch_sum_loss: float = 0.0 + epoch_step: int = 0 + for batch in tqdm( + self.train_dataloader, + desc=f"Training Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + # Do training step and BP + with self.accelerator.accumulate(self.model): + loss = self._train_step(batch) + self.accelerator.backward(loss) + self.optimizer.step() + self.optimizer.zero_grad() + self.batch_count += 1 + + # Update info for each step + # TODO: step means BP counts or batch counts? + if self.batch_count % self.cfg.train.gradient_accumulation_step == 0: + epoch_sum_loss += loss + self.accelerator.log( + { + "Step/Train Loss": loss, + "Step/Learning Rate": self.optimizer.param_groups[0]["lr"], + }, + step=self.step, + ) + self.step += 1 + epoch_step += 1 + + self.accelerator.wait_for_everyone() + return ( + epoch_sum_loss + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step + ) + + @torch.inference_mode() + def _valid_epoch(self): + r"""Testing epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. 
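+        Runs under ``torch.inference_mode()``, so no gradients are tracked here.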
+ """ + self.model.eval() + epoch_sum_loss = 0.0 + for batch in tqdm( + self.valid_dataloader, + desc=f"Validating Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + batch_loss = self._valid_step(batch) + epoch_sum_loss += batch_loss.item() + + self.accelerator.wait_for_everyone() + return epoch_sum_loss / len(self.valid_dataloader) + + def _train_step(self, batch): + r"""Training forward step. Should return average loss of a sample over + one batch. Provoke ``_forward_step`` is recommended except for special case. + See ``_train_epoch`` for usage. + """ + return self._forward_step(batch) + + @torch.inference_mode() + def _valid_step(self, batch): + r"""Testing forward step. Should return average loss of a sample over + one batch. Provoke ``_forward_step`` is recommended except for special case. + See ``_test_epoch`` for usage. + """ + return self._forward_step(batch) + + def __load_model( + self, + checkpoint_dir: str = None, + checkpoint_path: str = None, + resume_type: str = "", + ): + r"""Load model from checkpoint. If checkpoint_path is None, it will + load the latest checkpoint in checkpoint_dir. If checkpoint_path is not + None, it will load the checkpoint specified by checkpoint_path. **Only use this + method after** ``accelerator.prepare()``. + """ + if checkpoint_path is None: + ls = [str(i) for i in Path(checkpoint_dir).glob("*")] + ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) + checkpoint_path = ls[0] + self.logger.info("Resume from {}...".format(checkpoint_path)) + + if resume_type in ["resume", ""]: + # Load all the things, including model weights, optimizer, scheduler, and random states. + self.accelerator.load_state(input_dir=checkpoint_path) + + # set epoch and step + self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1 + self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1 + + elif resume_type == "finetune": + # Load only the model weights + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + self.logger.info("Load model weights for finetune...") + + else: + raise ValueError("Resume_type must be `resume` or `finetune`.") + + return checkpoint_path + + # TODO: LEGACY CODE + def _build_dataloader(self): + Dataset, Collator = self._build_dataset() + + # build dataset instance for each dataset and combine them by ConcatDataset + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=False) + datasets_list.append(subdataset) + train_dataset = ConcatDataset(datasets_list) + train_collate = Collator(self.cfg) + _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train") + self.logger.debug(f"train batch_sampler: {list(batch_sampler)}") + self.logger.debug(f"length: {train_dataset.cumulative_sizes}") + # TODO: use config instead of (sampler, shuffle, drop_last, batch_size) + train_loader = DataLoader( + train_dataset, + collate_fn=train_collate, + batch_sampler=batch_sampler, + num_workers=self.cfg.train.dataloader.num_worker, + pin_memory=self.cfg.train.dataloader.pin_memory, + ) + + # Build valid dataloader + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=True) + datasets_list.append(subdataset) + valid_dataset = ConcatDataset(datasets_list) + valid_collate = Collator(self.cfg) + _, 
batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "valid") + self.logger.debug(f"valid batch_sampler: {list(batch_sampler)}") + self.logger.debug(f"length: {valid_dataset.cumulative_sizes}") + valid_loader = DataLoader( + valid_dataset, + collate_fn=valid_collate, + batch_sampler=batch_sampler, + num_workers=self.cfg.train.dataloader.num_worker, + pin_memory=self.cfg.train.dataloader.pin_memory, + ) + return train_loader, valid_loader + + @staticmethod + def _set_random_seed(seed): + r"""Set random seed for all possible random modules.""" + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + def _check_nan(self, loss, y_pred, y_gt): + if torch.any(torch.isnan(loss)): + self.logger.fatal("Fatal Error: Training is down since loss has Nan!") + self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True) + if torch.any(torch.isnan(y_pred)): + self.logger.error( + f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True + ) + else: + self.logger.debug( + f"y_pred has Nan: {torch.any(torch.isnan(y_pred))}", in_order=True + ) + if torch.any(torch.isnan(y_gt)): + self.logger.error( + f"y_gt has Nan: {torch.any(torch.isnan(y_gt))}", in_order=True + ) + else: + self.logger.debug( + f"y_gt has nan: {torch.any(torch.isnan(y_gt))}", in_order=True + ) + if torch.any(torch.isnan(y_pred)): + self.logger.error(f"y_pred: {y_pred}", in_order=True) + else: + self.logger.debug(f"y_pred: {y_pred}", in_order=True) + if torch.any(torch.isnan(y_gt)): + self.logger.error(f"y_gt: {y_gt}", in_order=True) + else: + self.logger.debug(f"y_gt: {y_gt}", in_order=True) + + # TODO: still OK to save tracking? + self.accelerator.end_training() + raise RuntimeError("Loss has Nan! See log for more info.") + + ### Protected methods end ### + + ## Following are private methods ## + ## !!! These are inconvenient for GAN-based model training. It'd be better to move these to svc_trainer.py if needed. 
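+    # The string in cfg.train.optimizer selects one of the torch.optim classes below,
+    # and the matching cfg.train.<name> section is unpacked as its keyword arguments.
+    # For example, a config like (illustrative values, not shipped defaults)
+    #   "optimizer": "AdamW", "adamw": {"lr": 4.0e-4, "betas": [0.9, 0.98]}
+    # constructs torch.optim.AdamW(self.model.parameters(), lr=4.0e-4, betas=[0.9, 0.98]).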
+ def __build_optimizer(self): + r"""Build optimizer for model.""" + # Make case-insensitive matching + if self.cfg.train.optimizer.lower() == "adadelta": + optimizer = torch.optim.Adadelta( + self.model.parameters(), **self.cfg.train.adadelta + ) + self.logger.info("Using Adadelta optimizer.") + elif self.cfg.train.optimizer.lower() == "adagrad": + optimizer = torch.optim.Adagrad( + self.model.parameters(), **self.cfg.train.adagrad + ) + self.logger.info("Using Adagrad optimizer.") + elif self.cfg.train.optimizer.lower() == "adam": + optimizer = torch.optim.Adam(self.model.parameters(), **self.cfg.train.adam) + self.logger.info("Using Adam optimizer.") + elif self.cfg.train.optimizer.lower() == "adamw": + optimizer = torch.optim.AdamW( + self.model.parameters(), **self.cfg.train.adamw + ) + elif self.cfg.train.optimizer.lower() == "sparseadam": + optimizer = torch.optim.SparseAdam( + self.model.parameters(), **self.cfg.train.sparseadam + ) + elif self.cfg.train.optimizer.lower() == "adamax": + optimizer = torch.optim.Adamax( + self.model.parameters(), **self.cfg.train.adamax + ) + elif self.cfg.train.optimizer.lower() == "asgd": + optimizer = torch.optim.ASGD(self.model.parameters(), **self.cfg.train.asgd) + elif self.cfg.train.optimizer.lower() == "lbfgs": + optimizer = torch.optim.LBFGS( + self.model.parameters(), **self.cfg.train.lbfgs + ) + elif self.cfg.train.optimizer.lower() == "nadam": + optimizer = torch.optim.NAdam( + self.model.parameters(), **self.cfg.train.nadam + ) + elif self.cfg.train.optimizer.lower() == "radam": + optimizer = torch.optim.RAdam( + self.model.parameters(), **self.cfg.train.radam + ) + elif self.cfg.train.optimizer.lower() == "rmsprop": + optimizer = torch.optim.RMSprop( + self.model.parameters(), **self.cfg.train.rmsprop + ) + elif self.cfg.train.optimizer.lower() == "rprop": + optimizer = torch.optim.Rprop( + self.model.parameters(), **self.cfg.train.rprop + ) + elif self.cfg.train.optimizer.lower() == "sgd": + optimizer = torch.optim.SGD(self.model.parameters(), **self.cfg.train.sgd) + else: + raise NotImplementedError( + f"Optimizer {self.cfg.train.optimizer} not supported yet!" 
+ ) + return optimizer + + def __build_scheduler(self): + r"""Build scheduler for optimizer.""" + # Make case-insensitive matching + if self.cfg.train.scheduler.lower() == "lambdalr": + scheduler = torch.optim.lr_scheduler.LambdaLR( + self.optimizer, **self.cfg.train.lambdalr + ) + elif self.cfg.train.scheduler.lower() == "multiplicativelr": + scheduler = torch.optim.lr_scheduler.MultiplicativeLR( + self.optimizer, **self.cfg.train.multiplicativelr + ) + elif self.cfg.train.scheduler.lower() == "steplr": + scheduler = torch.optim.lr_scheduler.StepLR( + self.optimizer, **self.cfg.train.steplr + ) + elif self.cfg.train.scheduler.lower() == "multisteplr": + scheduler = torch.optim.lr_scheduler.MultiStepLR( + self.optimizer, **self.cfg.train.multisteplr + ) + elif self.cfg.train.scheduler.lower() == "constantlr": + scheduler = torch.optim.lr_scheduler.ConstantLR( + self.optimizer, **self.cfg.train.constantlr + ) + elif self.cfg.train.scheduler.lower() == "linearlr": + scheduler = torch.optim.lr_scheduler.LinearLR( + self.optimizer, **self.cfg.train.linearlr + ) + elif self.cfg.train.scheduler.lower() == "exponentiallr": + scheduler = torch.optim.lr_scheduler.ExponentialLR( + self.optimizer, **self.cfg.train.exponentiallr + ) + elif self.cfg.train.scheduler.lower() == "polynomiallr": + scheduler = torch.optim.lr_scheduler.PolynomialLR( + self.optimizer, **self.cfg.train.polynomiallr + ) + elif self.cfg.train.scheduler.lower() == "cosineannealinglr": + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + self.optimizer, **self.cfg.train.cosineannealinglr + ) + elif self.cfg.train.scheduler.lower() == "sequentiallr": + scheduler = torch.optim.lr_scheduler.SequentialLR( + self.optimizer, **self.cfg.train.sequentiallr + ) + elif self.cfg.train.scheduler.lower() == "reducelronplateau": + scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( + self.optimizer, **self.cfg.train.reducelronplateau + ) + elif self.cfg.train.scheduler.lower() == "cycliclr": + scheduler = torch.optim.lr_scheduler.CyclicLR( + self.optimizer, **self.cfg.train.cycliclr + ) + elif self.cfg.train.scheduler.lower() == "onecyclelr": + scheduler = torch.optim.lr_scheduler.OneCycleLR( + self.optimizer, **self.cfg.train.onecyclelr + ) + elif self.cfg.train.scheduler.lower() == "cosineannearingwarmrestarts": + scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts( + self.optimizer, **self.cfg.train.cosineannearingwarmrestarts + ) + elif self.cfg.train.scheduler.lower() == "noamlr": + scheduler = NoamLR(self.optimizer, **self.cfg.train.lr_scheduler) + else: + raise NotImplementedError( + f"Scheduler {self.cfg.train.scheduler} not supported yet!" 
+ ) + return scheduler + + def _init_accelerator(self): + self.exp_dir = os.path.join( + os.path.abspath(self.cfg.log_dir), self.args.exp_name + ) + project_config = ProjectConfiguration( + project_dir=self.exp_dir, + logging_dir=os.path.join(self.exp_dir, "log"), + ) + self.accelerator = accelerate.Accelerator( + gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step, + log_with=self.cfg.train.tracker, + project_config=project_config, + ) + if self.accelerator.is_main_process: + os.makedirs(project_config.project_dir, exist_ok=True) + os.makedirs(project_config.logging_dir, exist_ok=True) + with self.accelerator.main_process_first(): + self.accelerator.init_trackers(self.args.exp_name) + + def __check_basic_configs(self): + if self.cfg.train.gradient_accumulation_step <= 0: + self.logger.fatal("Invalid gradient_accumulation_step value!") + self.logger.error( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) + self.accelerator.end_training() + raise ValueError( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) + # TODO: check other values + + @staticmethod + def __count_parameters(model): + model_param = 0.0 + if isinstance(model, dict): + for key, value in model.items(): + model_param += sum(p.numel() for p in model[key].parameters()) + else: + model_param = sum(p.numel() for p in model.parameters()) + return model_param + + def __dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) + + ### Private methods end ### diff --git a/models/svc/__init__.py b/models/svc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svc/base/__init__.py b/models/svc/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38c2b1686db550b3b9892b8bc6e594cd847aafd1 --- /dev/null +++ b/models/svc/base/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .svc_inference import SVCInference +from .svc_trainer import SVCTrainer diff --git a/models/svc/base/svc_dataset.py b/models/svc/base/svc_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cad356bd97b92b9917db6c973b31a42552f5fa76 --- /dev/null +++ b/models/svc/base/svc_dataset.py @@ -0,0 +1,425 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
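+
+# SVC dataset: builds on BaseDataset and additionally loads frame-aligned content
+# features (Whisper / ContentVec / MERT / WeNet), clipping over-long utterances to
+# a maximum frame length before batching (see random_select below).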
+ +import random +import torch +from torch.nn.utils.rnn import pad_sequence +import json +import os +import numpy as np +from utils.data_utils import * +from processors.acoustic_extractor import cal_normalized_mel, load_mel_extrema +from processors.content_extractor import ( + ContentvecExtractor, + WhisperExtractor, + WenetExtractor, +) +from models.base.base_dataset import ( + BaseCollator, + BaseDataset, +) +from models.base.new_dataset import BaseTestDataset + +EPS = 1.0e-12 + + +class SVCDataset(BaseDataset): + def __init__(self, cfg, dataset, is_valid=False): + BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid) + + cfg = self.cfg + + if cfg.model.condition_encoder.use_whisper: + self.whisper_aligner = WhisperExtractor(self.cfg) + self.utt2whisper_path = load_content_feature_path( + self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir + ) + + if cfg.model.condition_encoder.use_contentvec: + self.contentvec_aligner = ContentvecExtractor(self.cfg) + self.utt2contentVec_path = load_content_feature_path( + self.metadata, + cfg.preprocess.processed_dir, + cfg.preprocess.contentvec_dir, + ) + + if cfg.model.condition_encoder.use_mert: + self.utt2mert_path = load_content_feature_path( + self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir + ) + if cfg.model.condition_encoder.use_wenet: + self.wenet_aligner = WenetExtractor(self.cfg) + self.utt2wenet_path = load_content_feature_path( + self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir + ) + + def __getitem__(self, index): + single_feature = BaseDataset.__getitem__(self, index) + + utt_info = self.metadata[index] + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + if self.cfg.model.condition_encoder.use_whisper: + assert "target_len" in single_feature.keys() + aligned_whisper_feat = self.whisper_aligner.offline_align( + np.load(self.utt2whisper_path[utt]), single_feature["target_len"] + ) + single_feature["whisper_feat"] = aligned_whisper_feat + + if self.cfg.model.condition_encoder.use_contentvec: + assert "target_len" in single_feature.keys() + aligned_contentvec = self.contentvec_aligner.offline_align( + np.load(self.utt2contentVec_path[utt]), single_feature["target_len"] + ) + single_feature["contentvec_feat"] = aligned_contentvec + + if self.cfg.model.condition_encoder.use_mert: + assert "target_len" in single_feature.keys() + aligned_mert_feat = align_content_feature_length( + np.load(self.utt2mert_path[utt]), + single_feature["target_len"], + source_hop=self.cfg.preprocess.mert_hop_size, + ) + single_feature["mert_feat"] = aligned_mert_feat + + if self.cfg.model.condition_encoder.use_wenet: + assert "target_len" in single_feature.keys() + aligned_wenet_feat = self.wenet_aligner.offline_align( + np.load(self.utt2wenet_path[utt]), single_feature["target_len"] + ) + single_feature["wenet_feat"] = aligned_wenet_feat + + # print(single_feature.keys()) + # for k, v in single_feature.items(): + # if type(v) in [torch.Tensor, np.ndarray]: + # print(k, v.shape) + # else: + # print(k, v) + # exit() + + return self.clip_if_too_long(single_feature) + + def __len__(self): + return len(self.metadata) + + def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812): + """ + ending_ts: to avoid invalid whisper features for over 30s audios + 2812 = 30 * 24000 // 256 + """ + ts = max(feature_seq_len - max_seq_len, 0) + ts = min(ts, ending_ts - max_seq_len) + + start = random.randint(0, ts) + end = start + max_seq_len + return start, end + 
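# A minimal, illustrative sketch (not part of the patch above): the windowing
# arithmetic of random_select, replicated as a standalone helper with a
# hypothetical utterance length. With hop size 256 at 24 kHz, whisper features
# are only valid for the first 30 * 24000 // 256 = 2812 frames, so the start
# index is capped accordingly.
import random


def pick_window(feature_seq_len, max_seq_len=512, ending_ts=2812):
    ts = max(feature_seq_len - max_seq_len, 0)
    ts = min(ts, ending_ts - max_seq_len)
    start = random.randint(0, ts)
    return start, start + max_seq_len


# e.g. a 3000-frame utterance: start is drawn from [0, 2300], so the clipped
# 512-frame window never reaches past frame 2812.
print(pick_window(3000))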
+ def clip_if_too_long(self, sample, max_seq_len=512): + """ + sample : + { + 'spk_id': (1,), + 'target_len': int + 'mel': (seq_len, dim), + 'frame_pitch': (seq_len,) + 'frame_energy': (seq_len,) + 'content_vector_feat': (seq_len, dim) + } + """ + if sample["target_len"] <= max_seq_len: + return sample + + start, end = self.random_select(sample["target_len"], max_seq_len) + sample["target_len"] = end - start + + for k in sample.keys(): + if k not in ["spk_id", "target_len"]: + sample[k] = sample[k][start:end] + + return sample + + +class SVCCollator(BaseCollator): + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + BaseCollator.__init__(self, cfg) + + def __call__(self, batch): + parsed_batch_features = BaseCollator.__call__(self, batch) + return parsed_batch_features + + +class SVCTestDataset(BaseTestDataset): + def __init__(self, args, cfg, infer_type): + BaseTestDataset.__init__(self, args, cfg, infer_type) + self.metadata = self.get_metadata() + + target_singer = args.target_singer + self.cfg = cfg + self.trans_key = args.trans_key + assert type(target_singer) == str + + self.target_singer = target_singer.split("_")[-1] + self.target_dataset = target_singer.replace( + "_{}".format(self.target_singer), "" + ) + + self.target_mel_extrema = load_mel_extrema(cfg.preprocess, self.target_dataset) + self.target_mel_extrema = torch.as_tensor( + self.target_mel_extrema[0] + ), torch.as_tensor(self.target_mel_extrema[1]) + + ######### Load source acoustic features ######### + if cfg.preprocess.use_spkid: + spk2id_path = os.path.join(args.acoustics_dir, cfg.preprocess.spk2id) + # utt2sp_path = os.path.join(self.data_root, cfg.preprocess.utt2spk) + + with open(spk2id_path, "r") as f: + self.spk2id = json.load(f) + # print("self.spk2id", self.spk2id) + + if cfg.preprocess.use_uv: + self.utt2uv_path = { + f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join( + cfg.preprocess.processed_dir, + utt_info["Dataset"], + cfg.preprocess.uv_dir, + utt_info["Uid"] + ".npy", + ) + for utt_info in self.metadata + } + + if cfg.preprocess.use_frame_pitch: + self.utt2frame_pitch_path = { + f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join( + cfg.preprocess.processed_dir, + utt_info["Dataset"], + cfg.preprocess.pitch_dir, + utt_info["Uid"] + ".npy", + ) + for utt_info in self.metadata + } + + # Target F0 median + target_f0_statistics_path = os.path.join( + cfg.preprocess.processed_dir, + self.target_dataset, + cfg.preprocess.pitch_dir, + "statistics.json", + ) + self.target_pitch_median = json.load(open(target_f0_statistics_path, "r"))[ + f"{self.target_dataset}_{self.target_singer}" + ]["voiced_positions"]["median"] + + # Source F0 median (if infer from file) + if infer_type == "from_file": + source_audio_name = cfg.inference.source_audio_name + source_f0_statistics_path = os.path.join( + cfg.preprocess.processed_dir, + source_audio_name, + cfg.preprocess.pitch_dir, + "statistics.json", + ) + self.source_pitch_median = json.load( + open(source_f0_statistics_path, "r") + )[f"{source_audio_name}_{source_audio_name}"]["voiced_positions"][ + "median" + ] + else: + self.source_pitch_median = None + + if cfg.preprocess.use_frame_energy: + self.utt2frame_energy_path = { + f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join( + cfg.preprocess.processed_dir, + utt_info["Dataset"], + cfg.preprocess.energy_dir, + utt_info["Uid"] + ".npy", + ) + for utt_info in self.metadata + } + + if cfg.preprocess.use_mel: + self.utt2mel_path = { + 
f'{utt_info["Dataset"]}_{utt_info["Uid"]}': os.path.join( + cfg.preprocess.processed_dir, + utt_info["Dataset"], + cfg.preprocess.mel_dir, + utt_info["Uid"] + ".npy", + ) + for utt_info in self.metadata + } + + ######### Load source content features' path ######### + if cfg.model.condition_encoder.use_whisper: + self.whisper_aligner = WhisperExtractor(cfg) + self.utt2whisper_path = load_content_feature_path( + self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.whisper_dir + ) + + if cfg.model.condition_encoder.use_contentvec: + self.contentvec_aligner = ContentvecExtractor(cfg) + self.utt2contentVec_path = load_content_feature_path( + self.metadata, + cfg.preprocess.processed_dir, + cfg.preprocess.contentvec_dir, + ) + + if cfg.model.condition_encoder.use_mert: + self.utt2mert_path = load_content_feature_path( + self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.mert_dir + ) + if cfg.model.condition_encoder.use_wenet: + self.wenet_aligner = WenetExtractor(cfg) + self.utt2wenet_path = load_content_feature_path( + self.metadata, cfg.preprocess.processed_dir, cfg.preprocess.wenet_dir + ) + + def __getitem__(self, index): + single_feature = {} + + utt_info = self.metadata[index] + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + source_dataset = self.metadata[index]["Dataset"] + + if self.cfg.preprocess.use_spkid: + single_feature["spk_id"] = np.array( + [self.spk2id[f"{self.target_dataset}_{self.target_singer}"]], + dtype=np.int32, + ) + + ######### Get Acoustic Features Item ######### + if self.cfg.preprocess.use_mel: + mel = np.load(self.utt2mel_path[utt]) + assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T] + if self.cfg.preprocess.use_min_max_norm_mel: + # mel norm + mel = cal_normalized_mel(mel, source_dataset, self.cfg.preprocess) + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = mel.shape[1] + single_feature["mel"] = mel.T # [T, n_mels] + + if self.cfg.preprocess.use_frame_pitch: + frame_pitch_path = self.utt2frame_pitch_path[utt] + frame_pitch = np.load(frame_pitch_path) + + if self.trans_key: + try: + self.trans_key = int(self.trans_key) + except: + pass + if type(self.trans_key) == int: + frame_pitch = transpose_key(frame_pitch, self.trans_key) + elif self.trans_key: + assert self.target_singer + + frame_pitch = pitch_shift_to_target( + frame_pitch, self.target_pitch_median, self.source_pitch_median + ) + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = len(frame_pitch) + aligned_frame_pitch = align_length( + frame_pitch, single_feature["target_len"] + ) + single_feature["frame_pitch"] = aligned_frame_pitch + + if self.cfg.preprocess.use_uv: + frame_uv_path = self.utt2uv_path[utt] + frame_uv = np.load(frame_uv_path) + aligned_frame_uv = align_length(frame_uv, single_feature["target_len"]) + aligned_frame_uv = [ + 0 if frame_uv else 1 for frame_uv in aligned_frame_uv + ] + aligned_frame_uv = np.array(aligned_frame_uv) + single_feature["frame_uv"] = aligned_frame_uv + + if self.cfg.preprocess.use_frame_energy: + frame_energy_path = self.utt2frame_energy_path[utt] + frame_energy = np.load(frame_energy_path) + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = len(frame_energy) + aligned_frame_energy = align_length( + frame_energy, single_feature["target_len"] + ) + single_feature["frame_energy"] = aligned_frame_energy + + ######### Get Content Features Item ######### + if self.cfg.model.condition_encoder.use_whisper: + 
assert "target_len" in single_feature.keys() + aligned_whisper_feat = self.whisper_aligner.offline_align( + np.load(self.utt2whisper_path[utt]), single_feature["target_len"] + ) + single_feature["whisper_feat"] = aligned_whisper_feat + + if self.cfg.model.condition_encoder.use_contentvec: + assert "target_len" in single_feature.keys() + aligned_contentvec = self.contentvec_aligner.offline_align( + np.load(self.utt2contentVec_path[utt]), single_feature["target_len"] + ) + single_feature["contentvec_feat"] = aligned_contentvec + + if self.cfg.model.condition_encoder.use_mert: + assert "target_len" in single_feature.keys() + aligned_mert_feat = align_content_feature_length( + np.load(self.utt2mert_path[utt]), + single_feature["target_len"], + source_hop=self.cfg.preprocess.mert_hop_size, + ) + single_feature["mert_feat"] = aligned_mert_feat + + if self.cfg.model.condition_encoder.use_wenet: + assert "target_len" in single_feature.keys() + aligned_wenet_feat = self.wenet_aligner.offline_align( + np.load(self.utt2wenet_path[utt]), single_feature["target_len"] + ) + single_feature["wenet_feat"] = aligned_wenet_feat + + return single_feature + + def __len__(self): + return len(self.metadata) + + +class SVCTestCollator: + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + self.cfg = cfg + + def __call__(self, batch): + packed_batch_features = dict() + + # mel: [b, T, n_mels] + # frame_pitch, frame_energy: [1, T] + # target_len: [1] + # spk_id: [b, 1] + # mask: [b, T, 1] + + for key in batch[0].keys(): + if key == "target_len": + packed_batch_features["target_len"] = torch.LongTensor( + [b["target_len"] for b in batch] + ) + masks = [ + torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch + ] + packed_batch_features["mask"] = pad_sequence( + masks, batch_first=True, padding_value=0 + ) + else: + values = [torch.from_numpy(b[key]) for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + + return packed_batch_features diff --git a/models/svc/base/svc_inference.py b/models/svc/base/svc_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..52f88d5d915e1616292c03927b4f51557351f58b --- /dev/null +++ b/models/svc/base/svc_inference.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from models.base.new_inference import BaseInference +from models.svc.base.svc_dataset import SVCTestCollator, SVCTestDataset + + +class SVCInference(BaseInference): + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + BaseInference.__init__(self, args, cfg, infer_type) + + def _build_test_dataset(self): + return SVCTestDataset, SVCTestCollator diff --git a/models/svc/base/svc_trainer.py b/models/svc/base/svc_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a093a86712bb7ccfa786a6c18dd1683ffc013c --- /dev/null +++ b/models/svc/base/svc_trainer.py @@ -0,0 +1,111 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os + +import torch +import torch.nn as nn + +from models.base.new_trainer import BaseTrainer +from models.svc.base.svc_dataset import SVCCollator, SVCDataset + + +class SVCTrainer(BaseTrainer): + r"""The base trainer for all SVC models. 
It inherits from BaseTrainer and implements + ``build_criterion``, ``_build_dataset`` and ``_build_singer_lut`` methods. You can inherit from this + class, and implement ``_build_model``, ``_forward_step``. + """ + + def __init__(self, args=None, cfg=None): + self.args = args + self.cfg = cfg + + self._init_accelerator() + + # Only for SVC tasks + with self.accelerator.main_process_first(): + self.singers = self._build_singer_lut() + + # Super init + BaseTrainer.__init__(self, args, cfg) + + # Only for SVC tasks + self.task_type = "SVC" + self.logger.info("Task type: {}".format(self.task_type)) + + ### Following are methods only for SVC tasks ### + # TODO: LEGACY CODE, NEED TO BE REFACTORED + def _build_dataset(self): + return SVCDataset, SVCCollator + + @staticmethod + def _build_criterion(): + criterion = nn.MSELoss(reduction="none") + return criterion + + @staticmethod + def _compute_loss(criterion, y_pred, y_gt, loss_mask): + """ + Args: + criterion: MSELoss(reduction='none') + y_pred, y_gt: (bs, seq_len, D) + loss_mask: (bs, seq_len, 1) + Returns: + loss: Tensor of shape [] + """ + + # (bs, seq_len, D) + loss = criterion(y_pred, y_gt) + # expand loss_mask to (bs, seq_len, D) + loss_mask = loss_mask.repeat(1, 1, loss.shape[-1]) + + loss = torch.sum(loss * loss_mask) / torch.sum(loss_mask) + return loss + + def _save_auxiliary_states(self): + """ + To save the singer's look-up table in the checkpoint saving path + """ + with open( + os.path.join(self.tmp_checkpoint_save_path, self.cfg.preprocess.spk2id), "w" + ) as f: + json.dump(self.singers, f, indent=4, ensure_ascii=False) + + def _build_singer_lut(self): + resumed_singer_path = None + if self.args.resume_from_ckpt_path and self.args.resume_from_ckpt_path != "": + resumed_singer_path = os.path.join( + self.args.resume_from_ckpt_path, self.cfg.preprocess.spk2id + ) + if os.path.exists(os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)): + resumed_singer_path = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id) + + if resumed_singer_path: + with open(resumed_singer_path, "r") as f: + singers = json.load(f) + else: + singers = dict() + + for dataset in self.cfg.dataset: + singer_lut_path = os.path.join( + self.cfg.preprocess.processed_dir, dataset, self.cfg.preprocess.spk2id + ) + with open(singer_lut_path, "r") as singer_lut_path: + singer_lut = json.load(singer_lut_path) + for singer in singer_lut.keys(): + if singer not in singers: + singers[singer] = len(singers) + + with open( + os.path.join(self.exp_dir, self.cfg.preprocess.spk2id), "w" + ) as singer_file: + json.dump(singers, singer_file, indent=4, ensure_ascii=False) + print( + "singers have been dumped to {}".format( + os.path.join(self.exp_dir, self.cfg.preprocess.spk2id) + ) + ) + return singers diff --git a/models/svc/comosvc/__init__.py b/models/svc/comosvc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..19f1cb162e95d8a992002beaa0c0d8bada9cddd5 --- /dev/null +++ b/models/svc/comosvc/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. diff --git a/models/svc/comosvc/comosvc.py b/models/svc/comosvc/comosvc.py new file mode 100644 index 0000000000000000000000000000000000000000..6cecd7a3f40f3a78f0df06ef2340159d321d6117 --- /dev/null +++ b/models/svc/comosvc/comosvc.py @@ -0,0 +1,377 @@ +# Copyright (c) 2023 Amphion. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +"""Adapted from https://github.com/zhenye234/CoMoSpeech""" + +import torch +import torch.nn as nn +import copy +import numpy as np +import math +from tqdm.auto import tqdm + +from utils.ssim import SSIM + +from models.svc.transformer.conformer import Conformer, BaseModule +from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper +from models.svc.comosvc.utils import slice_segments, rand_ids_segments + + +class Consistency(nn.Module): + def __init__(self, cfg, distill=False): + super().__init__() + self.cfg = cfg + # self.denoise_fn = GradLogPEstimator2d(96) + self.denoise_fn = DiffusionWrapper(self.cfg) + self.cfg = cfg.model.comosvc + self.teacher = not distill + self.P_mean = self.cfg.P_mean + self.P_std = self.cfg.P_std + self.sigma_data = self.cfg.sigma_data + self.sigma_min = self.cfg.sigma_min + self.sigma_max = self.cfg.sigma_max + self.rho = self.cfg.rho + self.N = self.cfg.n_timesteps + self.ssim_loss = SSIM() + + # Time step discretization + step_indices = torch.arange(self.N) + # karras boundaries formula + t_steps = ( + self.sigma_min ** (1 / self.rho) + + step_indices + / (self.N - 1) + * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho)) + ) ** self.rho + self.t_steps = torch.cat( + [torch.zeros_like(t_steps[:1]), self.round_sigma(t_steps)] + ) + + def init_consistency_training(self): + self.denoise_fn_ema = copy.deepcopy(self.denoise_fn) + self.denoise_fn_pretrained = copy.deepcopy(self.denoise_fn) + + def EDMPrecond(self, x, sigma, cond, denoise_fn, mask, spk=None): + """ + karras diffusion reverse process + + Args: + x: noisy mel-spectrogram [B x n_mel x L] + sigma: noise level [B x 1 x 1] + cond: output of conformer encoder [B x n_mel x L] + denoise_fn: denoiser neural network e.g. 
DilatedCNN + mask: mask of padded frames [B x n_mel x L] + + Returns: + denoised mel-spectrogram [B x n_mel x L] + """ + sigma = sigma.reshape(-1, 1, 1) + + c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt() + c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt() + c_noise = sigma.log() / 4 + + x_in = c_in * x + x_in = x_in.transpose(1, 2) + x = x.transpose(1, 2) + cond = cond.transpose(1, 2) + F_x = denoise_fn(x_in, c_noise.squeeze(), cond) + # F_x = denoise_fn((c_in * x), mask, cond, c_noise.flatten()) + D_x = c_skip * x + c_out * (F_x) + D_x = D_x.transpose(1, 2) + return D_x + + def EDMLoss(self, x_start, cond, mask): + """ + compute loss for EDM model + + Args: + x_start: ground truth mel-spectrogram [B x n_mel x L] + cond: output of conformer encoder [B x n_mel x L] + mask: mask of padded frames [B x n_mel x L] + """ + rnd_normal = torch.randn([x_start.shape[0], 1, 1], device=x_start.device) + sigma = (rnd_normal * self.P_std + self.P_mean).exp() + weight = (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2 + + # follow Grad-TTS, start from Gaussian noise with mean cond and std I + noise = (torch.randn_like(x_start) + cond) * sigma + D_yn = self.EDMPrecond(x_start + noise, sigma, cond, self.denoise_fn, mask) + loss = weight * ((D_yn - x_start) ** 2) + loss = torch.sum(loss * mask) / torch.sum(mask) + return loss + + def round_sigma(self, sigma): + return torch.as_tensor(sigma) + + def edm_sampler( + self, + latents, + cond, + nonpadding, + num_steps=50, + sigma_min=0.002, + sigma_max=80, + rho=7, + S_churn=0, + S_min=0, + S_max=float("inf"), + S_noise=1, + # S_churn=40 ,S_min=0.05,S_max=50,S_noise=1.003,# S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, + # S_churn=30 ,S_min=0.01,S_max=30,S_noise=1.007, + # S_churn=30 ,S_min=0.01,S_max=1,S_noise=1.007, + # S_churn=80 ,S_min=0.05,S_max=50,S_noise=1.003, + ): + """ + karras diffusion sampler + + Args: + latents: noisy mel-spectrogram [B x n_mel x L] + cond: output of conformer encoder [B x n_mel x L] + nonpadding: mask of padded frames [B x n_mel x L] + num_steps: number of steps for diffusion inference + + Returns: + denoised mel-spectrogram [B x n_mel x L] + """ + # Time step discretization. + step_indices = torch.arange(num_steps, device=latents.device) + + num_steps = num_steps + 1 + t_steps = ( + sigma_max ** (1 / rho) + + step_indices + / (num_steps - 1) + * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) + ) ** rho + t_steps = torch.cat([self.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) + + # Main sampling loop. + x_next = latents * t_steps[0] + # wrap in tqdm for progress bar + bar = tqdm(enumerate(zip(t_steps[:-1], t_steps[1:]))) + for i, (t_cur, t_next) in bar: + x_cur = x_next + # Increase noise temporarily. + gamma = ( + min(S_churn / num_steps, np.sqrt(2) - 1) + if S_min <= t_cur <= S_max + else 0 + ) + t_hat = self.round_sigma(t_cur + gamma * t_cur) + t = torch.zeros((x_cur.shape[0], 1, 1), device=x_cur.device) + t[:, 0, 0] = t_hat + t_hat = t + x_hat = x_cur + ( + t_hat**2 - t_cur**2 + ).sqrt() * S_noise * torch.randn_like(x_cur) + # Euler step. 
+ denoised = self.EDMPrecond(x_hat, t_hat, cond, self.denoise_fn, nonpadding) + d_cur = (x_hat - denoised) / t_hat + x_next = x_hat + (t_next - t_hat) * d_cur + + return x_next + + def CTLoss_D(self, y, cond, mask): + """ + compute loss for consistency distillation + + Args: + y: ground truth mel-spectrogram [B x n_mel x L] + cond: output of conformer encoder [B x n_mel x L] + mask: mask of padded frames [B x n_mel x L] + """ + with torch.no_grad(): + mu = 0.95 + for p, ema_p in zip( + self.denoise_fn.parameters(), self.denoise_fn_ema.parameters() + ): + ema_p.mul_(mu).add_(p, alpha=1 - mu) + + n = torch.randint(1, self.N, (y.shape[0],)) + z = torch.randn_like(y) + cond + + tn_1 = self.t_steps[n + 1].reshape(-1, 1, 1).to(y.device) + f_theta = self.EDMPrecond(y + tn_1 * z, tn_1, cond, self.denoise_fn, mask) + + with torch.no_grad(): + tn = self.t_steps[n].reshape(-1, 1, 1).to(y.device) + + # euler step + x_hat = y + tn_1 * z + denoised = self.EDMPrecond( + x_hat, tn_1, cond, self.denoise_fn_pretrained, mask + ) + d_cur = (x_hat - denoised) / tn_1 + y_tn = x_hat + (tn - tn_1) * d_cur + + f_theta_ema = self.EDMPrecond(y_tn, tn, cond, self.denoise_fn_ema, mask) + + # loss = (f_theta - f_theta_ema.detach()) ** 2 + # loss = torch.sum(loss * mask) / torch.sum(mask) + loss = self.ssim_loss(f_theta, f_theta_ema.detach()) + loss = torch.sum(loss * mask) / torch.sum(mask) + + return loss + + def get_t_steps(self, N): + N = N + 1 + step_indices = torch.arange(N) # , device=latents.device) + t_steps = ( + self.sigma_min ** (1 / self.rho) + + step_indices + / (N - 1) + * (self.sigma_max ** (1 / self.rho) - self.sigma_min ** (1 / self.rho)) + ) ** self.rho + + return t_steps.flip(0) + + def CT_sampler(self, latents, cond, nonpadding, t_steps=1): + """ + consistency distillation sampler + + Args: + latents: noisy mel-spectrogram [B x n_mel x L] + cond: output of conformer encoder [B x n_mel x L] + nonpadding: mask of padded frames [B x n_mel x L] + t_steps: number of steps for diffusion inference + + Returns: + denoised mel-spectrogram [B x n_mel x L] + """ + # one-step + if t_steps == 1: + t_steps = [80] + # multi-step + else: + t_steps = self.get_t_steps(t_steps) + + t_steps = torch.as_tensor(t_steps).to(latents.device) + latents = latents * t_steps[0] + _t = torch.zeros((latents.shape[0], 1, 1), device=latents.device) + _t[:, 0, 0] = t_steps + x = self.EDMPrecond(latents, _t, cond, self.denoise_fn_ema, nonpadding) + + for t in t_steps[1:-1]: + z = torch.randn_like(x) + cond + x_tn = x + (t**2 - self.sigma_min**2).sqrt() * z + _t = torch.zeros((x.shape[0], 1, 1), device=x.device) + _t[:, 0, 0] = t + t = _t + print(t) + x = self.EDMPrecond(x_tn, t, cond, self.denoise_fn_ema, nonpadding) + return x + + def forward(self, x, nonpadding, cond, t_steps=1, infer=False): + """ + calculate loss or sample mel-spectrogram + + Args: + x: + training: ground truth mel-spectrogram [B x n_mel x L] + inference: output of encoder [B x n_mel x L] + """ + if self.teacher: # teacher model -- karras diffusion + if not infer: + loss = self.EDMLoss(x, cond, nonpadding) + return loss + else: + shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2]) + x = torch.randn(shape, device=x.device) + cond + x = self.edm_sampler(x, cond, nonpadding, t_steps) + + return x + else: # Consistency distillation + if not infer: + loss = self.CTLoss_D(x, cond, nonpadding) + return loss + + else: + shape = (cond.shape[0], self.cfg.n_mel, cond.shape[2]) + x = torch.randn(shape, device=x.device) + cond + x = self.CT_sampler(x, cond, nonpadding, 
t_steps=1) + + return x + + +class ComoSVC(BaseModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.cfg.model.comosvc.n_mel = self.cfg.preprocess.n_mel + self.distill = self.cfg.model.comosvc.distill + self.encoder = Conformer(self.cfg.model.comosvc) + self.decoder = Consistency(self.cfg, distill=self.distill) + self.ssim_loss = SSIM() + + @torch.no_grad() + def forward(self, x_mask, x, n_timesteps, temperature=1.0): + """ + Generates mel-spectrogram from pitch, content vector, energy. Returns: + 1. encoder outputs (from conformer) + 2. decoder outputs (from diffusion-based decoder) + + Args: + x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel] + x : output of encoder framework. [B x L x d_condition] + n_timesteps : number of steps to use for reverse diffusion in decoder. + temperature : controls variance of terminal distribution. + """ + + # Get encoder_outputs `mu_x` + mu_x = self.encoder(x, x_mask) + encoder_outputs = mu_x + + mu_x = mu_x.transpose(1, 2) + x_mask = x_mask.transpose(1, 2) + + # Generate sample by performing reverse dynamics + decoder_outputs = self.decoder( + mu_x, x_mask, mu_x, t_steps=n_timesteps, infer=True + ) + decoder_outputs = decoder_outputs.transpose(1, 2) + return encoder_outputs, decoder_outputs + + def compute_loss(self, x_mask, x, mel, out_size=None, skip_diff=False): + """ + Computes 2 losses: + 1. prior loss: loss between mel-spectrogram and encoder outputs. + 2. diffusion loss: loss between gaussian noise and its reconstruction by diffusion-based decoder. + + Args: + x_mask : mask of padded frames in mel-spectrogram. [B x L x n_mel] + x : output of encoder framework. [B x L x d_condition] + mel : ground truth mel-spectrogram. [B x L x n_mel] + """ + + mu_x = self.encoder(x, x_mask) + # prior loss + prior_loss = torch.sum( + 0.5 * ((mel - mu_x) ** 2 + math.log(2 * math.pi)) * x_mask + ) + prior_loss = prior_loss / (torch.sum(x_mask) * self.cfg.model.comosvc.n_mel) + # ssim loss + ssim_loss = self.ssim_loss(mu_x, mel) + ssim_loss = torch.sum(ssim_loss * x_mask) / torch.sum(x_mask) + + x_mask = x_mask.transpose(1, 2) + mu_x = mu_x.transpose(1, 2) + mel = mel.transpose(1, 2) + if not self.distill and skip_diff: + diff_loss = prior_loss.clone() + diff_loss.fill_(0) + + # Cut a small segment of mel-spectrogram in order to increase batch size + else: + if self.distill: + mu_y = mu_x.detach() + else: + mu_y = mu_x + mask_y = x_mask + + diff_loss = self.decoder(mel, mask_y, mu_y, infer=False) + + return ssim_loss, prior_loss, diff_loss diff --git a/models/svc/comosvc/comosvc_inference.py b/models/svc/comosvc/comosvc_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..2783ec7e468c367c7d2f5f8988ed1f7e272d4cb7 --- /dev/null +++ b/models/svc/comosvc/comosvc_inference.py @@ -0,0 +1,39 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
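# A minimal, illustrative sketch (not part of the patch above): the Karras
# noise-level discretization used by Consistency.__init__ and edm_sampler,
# evaluated with the sampler's default sigma_min=0.002, sigma_max=80, rho=7
# and a hypothetical N=4 steps.
import torch

sigma_min, sigma_max, rho, N = 0.002, 80.0, 7, 4

step_indices = torch.arange(N)
t_steps = (
    sigma_min ** (1 / rho)
    + step_indices / (N - 1) * (sigma_max ** (1 / rho) - sigma_min ** (1 / rho))
) ** rho
# Levels rise from sigma_min to sigma_max, spaced densely near sigma_min;
# Consistency.__init__ then prepends a zero level, giving roughly
# [0, 0.002, 0.47, 9.73, 80] for these settings.
t_steps = torch.cat([torch.zeros_like(t_steps[:1]), t_steps])
print(t_steps)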
+ +import torch + +from models.svc.base import SVCInference +from modules.encoder.condition_encoder import ConditionEncoder +from models.svc.comosvc.comosvc import ComoSVC + + +class ComoSVCInference(SVCInference): + def __init__(self, args, cfg, infer_type="from_dataset"): + SVCInference.__init__(self, args, cfg, infer_type) + + def _build_model(self): + # TODO: sort out the config + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + self.acoustic_mapper = ComoSVC(self.cfg) + if self.cfg.model.comosvc.distill: + self.acoustic_mapper.decoder.init_consistency_training() + model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) + return model + + def _inference_each_batch(self, batch_data): + device = self.accelerator.device + for k, v in batch_data.items(): + batch_data[k] = v.to(device) + + cond = self.condition_encoder(batch_data) + mask = batch_data["mask"] + encoder_pred, decoder_pred = self.acoustic_mapper( + mask, cond, self.cfg.inference.comosvc.inference_steps + ) + + return decoder_pred diff --git a/models/svc/comosvc/comosvc_trainer.py b/models/svc/comosvc/comosvc_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5ba49fd4539b8ae351137a85595ff9cfba1f4677 --- /dev/null +++ b/models/svc/comosvc/comosvc_trainer.py @@ -0,0 +1,295 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import os +import json5 +from collections import OrderedDict +from tqdm import tqdm +import json +import shutil + +from models.svc.base import SVCTrainer +from modules.encoder.condition_encoder import ConditionEncoder +from models.svc.comosvc.comosvc import ComoSVC + + +class ComoSVCTrainer(SVCTrainer): + r"""The base trainer for all diffusion models. It inherits from SVCTrainer and + implements ``_build_model`` and ``_forward_step`` methods. + """ + + def __init__(self, args=None, cfg=None): + SVCTrainer.__init__(self, args, cfg) + self.distill = cfg.model.comosvc.distill + self.skip_diff = True + if self.distill: # and args.resume is None: + self.teacher_model_path = cfg.model.teacher_model_path + self.teacher_state_dict = self._load_teacher_state_dict() + self._load_teacher_model(self.teacher_state_dict) + self.acoustic_mapper.decoder.init_consistency_training() + + ### Following are methods only for comoSVC models ### + def _load_teacher_state_dict(self): + self.checkpoint_file = self.teacher_model_path + print("Load teacher acoustic model from {}".format(self.checkpoint_file)) + raw_state_dict = torch.load(self.checkpoint_file) # , map_location=self.device) + return raw_state_dict + + def _load_teacher_model(self, state_dict): + raw_dict = state_dict + clean_dict = OrderedDict() + for k, v in raw_dict.items(): + if k.startswith("module."): + clean_dict[k[7:]] = v + else: + clean_dict[k] = v + self.model.load_state_dict(clean_dict) + + def _build_model(self): + r"""Build the model for training. 
This function is called in ``__init__`` function.""" + + # TODO: sort out the config + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + self.acoustic_mapper = ComoSVC(self.cfg) + model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) + return model + + def _forward_step(self, batch): + r"""Forward step for training and inference. This function is called + in ``_train_step`` & ``_test_step`` function. + """ + loss = {} + mask = batch["mask"] + mel_input = batch["mel"] + cond = self.condition_encoder(batch) + if self.distill: + cond = cond.detach() + self.skip_diff = True if self.step < self.cfg.train.fast_steps else False + ssim_loss, prior_loss, diff_loss = self.acoustic_mapper.compute_loss( + mask, cond, mel_input, skip_diff=self.skip_diff + ) + if self.distill: + loss["distil_loss"] = diff_loss + else: + loss["ssim_loss_encoder"] = ssim_loss + loss["prior_loss_encoder"] = prior_loss + loss["diffusion_loss_decoder"] = diff_loss + + return loss + + def _train_epoch(self): + r"""Training epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. + """ + self.model.train() + epoch_sum_loss: float = 0.0 + epoch_step: int = 0 + for batch in tqdm( + self.train_dataloader, + desc=f"Training Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + # Do training step and BP + with self.accelerator.accumulate(self.model): + loss = self._train_step(batch) + total_loss = 0 + for k, v in loss.items(): + total_loss += v + self.accelerator.backward(total_loss) + enc_grad_norm = torch.nn.utils.clip_grad_norm_( + self.acoustic_mapper.encoder.parameters(), max_norm=1 + ) + dec_grad_norm = torch.nn.utils.clip_grad_norm_( + self.acoustic_mapper.decoder.parameters(), max_norm=1 + ) + self.optimizer.step() + self.optimizer.zero_grad() + self.batch_count += 1 + + # Update info for each step + # TODO: step means BP counts or batch counts? + if self.batch_count % self.cfg.train.gradient_accumulation_step == 0: + epoch_sum_loss += total_loss + log_info = {} + for k, v in loss.items(): + key = "Step/Train Loss/{}".format(k) + log_info[key] = v + log_info["Step/Learning Rate"]: self.optimizer.param_groups[0]["lr"] + self.accelerator.log( + log_info, + step=self.step, + ) + self.step += 1 + epoch_step += 1 + + self.accelerator.wait_for_everyone() + return ( + epoch_sum_loss + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step, + loss, + ) + + def train_loop(self): + r"""Training loop. 
The public entry of training process.""" + # Wait everyone to prepare before we move on + self.accelerator.wait_for_everyone() + # dump config file + if self.accelerator.is_main_process: + self.__dump_cfg(self.config_save_path) + self.model.train() + self.optimizer.zero_grad() + # Wait to ensure good to go + self.accelerator.wait_for_everyone() + while self.epoch < self.max_epoch: + self.logger.info("\n") + self.logger.info("-" * 32) + self.logger.info("Epoch {}: ".format(self.epoch)) + + ### TODO: change the return values of _train_epoch() to a loss dict, or (total_loss, loss_dict) + ### It's inconvenient for the model with multiple losses + # Do training & validating epoch + train_loss, loss = self._train_epoch() + self.logger.info(" |- Train/Loss: {:.6f}".format(train_loss)) + for k, v in loss.items(): + self.logger.info(" |- Train/Loss/{}: {:.6f}".format(k, v)) + valid_loss = self._valid_epoch() + self.logger.info(" |- Valid/Loss: {:.6f}".format(valid_loss)) + self.accelerator.log( + {"Epoch/Train Loss": train_loss, "Epoch/Valid Loss": valid_loss}, + step=self.epoch, + ) + + self.accelerator.wait_for_everyone() + # TODO: what is scheduler? + self.scheduler.step(valid_loss) # FIXME: use epoch track correct? + + # Check if hit save_checkpoint_stride and run_eval + run_eval = False + if self.accelerator.is_main_process: + save_checkpoint = False + hit_dix = [] + for i, num in enumerate(self.save_checkpoint_stride): + if self.epoch % num == 0: + save_checkpoint = True + hit_dix.append(i) + run_eval |= self.run_eval[i] + + self.accelerator.wait_for_everyone() + if ( + self.accelerator.is_main_process + and save_checkpoint + and (self.distill or not self.skip_diff) + ): + path = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, train_loss + ), + ) + self.accelerator.save_state(path) + json.dump( + self.checkpoints_path, + open(os.path.join(path, "ckpts.json"), "w"), + ensure_ascii=False, + indent=4, + ) + + # Remove old checkpoints + to_remove = [] + for idx in hit_dix: + self.checkpoints_path[idx].append(path) + while len(self.checkpoints_path[idx]) > self.keep_last[idx]: + to_remove.append((idx, self.checkpoints_path[idx].pop(0))) + + # Search conflicts + total = set() + for i in self.checkpoints_path: + total |= set(i) + do_remove = set() + for idx, path in to_remove[::-1]: + if path in total: + self.checkpoints_path[idx].insert(0, path) + else: + do_remove.add(path) + + # Remove old checkpoints + for path in do_remove: + shutil.rmtree(path, ignore_errors=True) + self.logger.debug(f"Remove old checkpoint: {path}") + + self.accelerator.wait_for_everyone() + if run_eval: + # TODO: run evaluation + pass + + # Update info for each epoch + self.epoch += 1 + + # Finish training and save final checkpoint + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process: + self.accelerator.save_state( + os.path.join( + self.checkpoint_dir, + "final_epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, valid_loss + ), + ) + ) + self.accelerator.end_training() + + @torch.inference_mode() + def _valid_epoch(self): + r"""Testing epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. 
+ """ + self.model.eval() + epoch_sum_loss = 0.0 + for batch in tqdm( + self.valid_dataloader, + desc=f"Validating Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + batch_loss = self._valid_step(batch) + for k, v in batch_loss.items(): + epoch_sum_loss += v + + self.accelerator.wait_for_everyone() + return epoch_sum_loss / len(self.valid_dataloader) + + @staticmethod + def __count_parameters(model): + model_param = 0.0 + if isinstance(model, dict): + for key, value in model.items(): + model_param += sum(p.numel() for p in model[key].parameters()) + else: + model_param = sum(p.numel() for p in model.parameters()) + return model_param + + def __dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) diff --git a/models/svc/comosvc/utils.py b/models/svc/comosvc/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f576f9a237d0a22ddfdb160122b906da9bcf889 --- /dev/null +++ b/models/svc/comosvc/utils.py @@ -0,0 +1,31 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def slice_segments(x, ids_str, segment_size=200): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_ids_segments(lengths, segment_size=200): + b = lengths.shape[0] + ids_str_max = lengths - segment_size + ids_str = (torch.rand([b]).to(device=lengths.device) * ids_str_max).to( + dtype=torch.long + ) + return ids_str + + +def fix_len_compatibility(length, num_downsamplings_in_unet=2): + while True: + if length % (2**num_downsamplings_in_unet) == 0: + return length + length += 1 diff --git a/models/svc/diffusion/__init__.py b/models/svc/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svc/diffusion/diffusion_inference.py b/models/svc/diffusion/diffusion_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..a752ef8f195b59d0ac0ad402dc35ce5840626ab9 --- /dev/null +++ b/models/svc/diffusion/diffusion_inference.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler + +from models.svc.base import SVCInference +from models.svc.diffusion.diffusion_inference_pipeline import DiffusionInferencePipeline +from models.svc.diffusion.diffusion_wrapper import DiffusionWrapper +from modules.encoder.condition_encoder import ConditionEncoder + + +class DiffusionInference(SVCInference): + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + SVCInference.__init__(self, args, cfg, infer_type) + + settings = { + **cfg.model.diffusion.scheduler_settings, + **cfg.inference.diffusion.scheduler_settings, + } + settings.pop("num_inference_timesteps") + + if cfg.inference.diffusion.scheduler.lower() == "ddpm": + self.scheduler = DDPMScheduler(**settings) + self.logger.info("Using DDPM scheduler.") + elif cfg.inference.diffusion.scheduler.lower() == "ddim": + self.scheduler = DDIMScheduler(**settings) + self.logger.info("Using DDIM scheduler.") + elif cfg.inference.diffusion.scheduler.lower() == "pndm": + self.scheduler = PNDMScheduler(**settings) + self.logger.info("Using PNDM scheduler.") + else: + raise NotImplementedError( + "Unsupported scheduler type: {}".format( + cfg.inference.diffusion.scheduler.lower() + ) + ) + + self.pipeline = DiffusionInferencePipeline( + self.model[1], + self.scheduler, + args.diffusion_inference_steps, + ) + + def _build_model(self): + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + self.acoustic_mapper = DiffusionWrapper(self.cfg) + model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) + return model + + def _inference_each_batch(self, batch_data): + device = self.accelerator.device + for k, v in batch_data.items(): + batch_data[k] = v.to(device) + + conditioner = self.model[0](batch_data) + noise = torch.randn_like(batch_data["mel"], device=device) + y_pred = self.pipeline(noise, conditioner) + return y_pred diff --git a/models/svc/diffusion/diffusion_inference_pipeline.py b/models/svc/diffusion/diffusion_inference_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..e2461aada99179ac17a2aaffebdb24864af1f5ee --- /dev/null +++ b/models/svc/diffusion/diffusion_inference_pipeline.py @@ -0,0 +1,47 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from diffusers import DiffusionPipeline + + +class DiffusionInferencePipeline(DiffusionPipeline): + def __init__(self, network, scheduler, num_inference_timesteps=1000): + super().__init__() + + self.register_modules(network=network, scheduler=scheduler) + self.num_inference_timesteps = num_inference_timesteps + + @torch.inference_mode() + def __call__( + self, + initial_noise: torch.Tensor, + conditioner: torch.Tensor = None, + ): + r""" + Args: + initial_noise: The initial noise to be denoised. + conditioner:The conditioner. + n_inference_steps: The number of denoising steps. More denoising steps + usually lead to a higher quality at the expense of slower inference. + """ + + mel = initial_noise + batch_size = mel.size(0) + self.scheduler.set_timesteps(self.num_inference_timesteps) + + for t in self.progress_bar(self.scheduler.timesteps): + timestep = torch.full((batch_size,), t, device=mel.device, dtype=torch.long) + + # 1. 
predict noise model_output + model_output = self.network(mel, timestep, conditioner) + + # 2. denoise, compute previous step: x_t -> x_t-1 + mel = self.scheduler.step(model_output, t, mel).prev_sample + + # 3. clamp + mel = mel.clamp(-1.0, 1.0) + + return mel diff --git a/models/svc/diffusion/diffusion_trainer.py b/models/svc/diffusion/diffusion_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..6f5aeb56a825f84c57bb1d2ba9a5ff5a32d5f486 --- /dev/null +++ b/models/svc/diffusion/diffusion_trainer.py @@ -0,0 +1,88 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from diffusers import DDPMScheduler + +from models.svc.base import SVCTrainer +from modules.encoder.condition_encoder import ConditionEncoder +from .diffusion_wrapper import DiffusionWrapper + + +class DiffusionTrainer(SVCTrainer): + r"""The base trainer for all diffusion models. It inherits from SVCTrainer and + implements ``_build_model`` and ``_forward_step`` methods. + """ + + def __init__(self, args=None, cfg=None): + SVCTrainer.__init__(self, args, cfg) + + # Only for SVC tasks using diffusion + self.noise_scheduler = DDPMScheduler( + **self.cfg.model.diffusion.scheduler_settings, + ) + self.diffusion_timesteps = ( + self.cfg.model.diffusion.scheduler_settings.num_train_timesteps + ) + + ### Following are methods only for diffusion models ### + def _build_model(self): + r"""Build the model for training. This function is called in ``__init__`` function.""" + + # TODO: sort out the config + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + self.acoustic_mapper = DiffusionWrapper(self.cfg) + model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) + + num_of_params_encoder = self.count_parameters(self.condition_encoder) + num_of_params_am = self.count_parameters(self.acoustic_mapper) + num_of_params = num_of_params_encoder + num_of_params_am + log = "Diffusion Model's Parameters: #Encoder is {:.2f}M, #Diffusion is {:.2f}M. The total is {:.2f}M".format( + num_of_params_encoder / 1e6, num_of_params_am / 1e6, num_of_params / 1e6 + ) + self.logger.info(log) + + return model + + def count_parameters(self, model): + model_param = 0.0 + if isinstance(model, dict): + for key, value in model.items(): + model_param += sum(p.numel() for p in model[key].parameters()) + else: + model_param = sum(p.numel() for p in model.parameters()) + return model_param + + def _forward_step(self, batch): + r"""Forward step for training and inference. This function is called + in ``_train_step`` & ``_test_step`` function. 
+ """ + + device = self.accelerator.device + + mel_input = batch["mel"] + noise = torch.randn_like(mel_input, device=device, dtype=torch.float32) + batch_size = mel_input.size(0) + timesteps = torch.randint( + 0, + self.diffusion_timesteps, + (batch_size,), + device=device, + dtype=torch.long, + ) + + noisy_mel = self.noise_scheduler.add_noise(mel_input, noise, timesteps) + conditioner = self.condition_encoder(batch) + + y_pred = self.acoustic_mapper(noisy_mel, timesteps, conditioner) + + # TODO: Predict noise or gt should be configurable + loss = self._compute_loss(self.criterion, y_pred, noise, batch["mask"]) + self._check_nan(loss, y_pred, noise) + + # FIXME: Clarify that we should not divide it with batch size here + return loss diff --git a/models/svc/diffusion/diffusion_wrapper.py b/models/svc/diffusion/diffusion_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ef66c2b6b85ceb8fe7a2cf9b53c62edc6b3ef6bc --- /dev/null +++ b/models/svc/diffusion/diffusion_wrapper.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch.nn as nn + +from modules.diffusion import BiDilConv +from modules.encoder.position_encoder import PositionEncoder + + +class DiffusionWrapper(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + self.diff_cfg = cfg.model.diffusion + + self.diff_encoder = PositionEncoder( + d_raw_emb=self.diff_cfg.step_encoder.dim_raw_embedding, + d_out=self.diff_cfg.bidilconv.base_channel, + d_mlp=self.diff_cfg.step_encoder.dim_hidden_layer, + activation_function=self.diff_cfg.step_encoder.activation, + n_layer=self.diff_cfg.step_encoder.num_layer, + max_period=self.diff_cfg.step_encoder.max_period, + ) + + # FIXME: Only support BiDilConv now for debug + if self.diff_cfg.model_type.lower() == "bidilconv": + self.neural_network = BiDilConv( + input_channel=self.cfg.preprocess.n_mel, **self.diff_cfg.bidilconv + ) + else: + raise ValueError( + f"Unsupported diffusion model type: {self.diff_cfg.model_type}" + ) + + def forward(self, x, t, c): + """ + Args: + x: [N, T, mel_band] of mel spectrogram + t: Diffusion time step with shape of [N] + c: [N, T, conditioner_size] of conditioner + + Returns: + [N, T, mel_band] of mel spectrogram + """ + + assert ( + x.size()[:-1] == c.size()[:-1] + ), "x mismatch with c, got \n x: {} \n c: {}".format(x.size(), c.size()) + assert x.size(0) == t.size( + 0 + ), "x mismatch with t, got \n x: {} \n t: {}".format(x.size(), t.size()) + assert t.dim() == 1, "t must be 1D tensor, got {}".format(t.dim()) + + N, T, mel_band = x.size() + + x = x.transpose(1, 2).contiguous() # [N, mel_band, T] + c = c.transpose(1, 2).contiguous() # [N, conditioner_size, T] + t = self.diff_encoder(t).contiguous() # [N, base_channel] + + h = self.neural_network(x, t, c) + h = h.transpose(1, 2).contiguous() # [N, T, mel_band] + + assert h.size() == ( + N, + T, + mel_band, + ), "h mismatch with input x, got \n h: {} \n x: {}".format( + h.size(), (N, T, mel_band) + ) + return h diff --git a/models/svc/transformer/__init__.py b/models/svc/transformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/svc/transformer/conformer.py b/models/svc/transformer/conformer.py new file mode 100644 index 0000000000000000000000000000000000000000..5e48019cfc17d5f3825ce989f4852cec55fe1daa --- /dev/null +++ 
b/models/svc/transformer/conformer.py @@ -0,0 +1,405 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import torch +import numpy as np +import torch.nn as nn +from utils.util import convert_pad_shape + + +class BaseModule(torch.nn.Module): + def __init__(self): + super(BaseModule, self).__init__() + + @property + def nparams(self): + """ + Returns number of trainable parameters of the module. + """ + num_params = 0 + for name, param in self.named_parameters(): + if param.requires_grad: + num_params += np.prod(param.detach().cpu().numpy().shape) + return num_params + + def relocate_input(self, x: list): + """ + Relocates provided tensors to the same device set for the module. + """ + device = next(self.parameters()).device + for i in range(len(x)): + if isinstance(x[i], torch.Tensor) and x[i].device != device: + x[i] = x[i].to(device) + return x + + +class LayerNorm(BaseModule): + def __init__(self, channels, eps=1e-4): + super(LayerNorm, self).__init__() + self.channels = channels + self.eps = eps + + self.gamma = torch.nn.Parameter(torch.ones(channels)) + self.beta = torch.nn.Parameter(torch.zeros(channels)) + + def forward(self, x): + n_dims = len(x.shape) + mean = torch.mean(x, 1, keepdim=True) + variance = torch.mean((x - mean) ** 2, 1, keepdim=True) + + x = (x - mean) * torch.rsqrt(variance + self.eps) + + shape = [1, -1] + [1] * (n_dims - 2) + x = x * self.gamma.view(*shape) + self.beta.view(*shape) + return x + + +class ConvReluNorm(BaseModule): + def __init__( + self, + in_channels, + hidden_channels, + out_channels, + kernel_size, + n_layers, + p_dropout, + eps=1e-5, + ): + super(ConvReluNorm, self).__init__() + self.in_channels = in_channels + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.n_layers = n_layers + self.p_dropout = p_dropout + self.eps = eps + + self.conv_layers = torch.nn.ModuleList() + self.conv_layers.append( + torch.nn.Conv1d( + in_channels, hidden_channels, kernel_size, padding=kernel_size // 2 + ) + ) + self.relu_drop = torch.nn.Sequential( + torch.nn.ReLU(), torch.nn.Dropout(p_dropout) + ) + for _ in range(n_layers - 1): + self.conv_layers.append( + torch.nn.Conv1d( + hidden_channels, + hidden_channels, + kernel_size, + padding=kernel_size // 2, + ) + ) + self.proj = torch.nn.Conv1d(hidden_channels, out_channels, 1) + self.proj.weight.data.zero_() + self.proj.bias.data.zero_() + + def forward(self, x, x_mask): + for i in range(self.n_layers): + x = self.conv_layers[i](x * x_mask) + x = self.instance_norm(x, x_mask) + x = self.relu_drop(x) + x = self.proj(x) + return x * x_mask + + def instance_norm(self, x, mask, return_mean_std=False): + mean, std = self.calc_mean_std(x, mask) + x = (x - mean) / std + if return_mean_std: + return x, mean, std + else: + return x + + def calc_mean_std(self, x, mask=None): + x = x * mask + B, C = x.shape[:2] + mn = x.view(B, C, -1).mean(-1) + sd = (x.view(B, C, -1).var(-1) + self.eps).sqrt() + mn = mn.view(B, C, *((len(x.shape) - 2) * [1])) + sd = sd.view(B, C, *((len(x.shape) - 2) * [1])) + return mn, sd + + +class MultiHeadAttention(BaseModule): + def __init__( + self, + channels, + out_channels, + n_heads, + window_size=None, + heads_share=True, + p_dropout=0.0, + proximal_bias=False, + proximal_init=False, + ): + super(MultiHeadAttention, self).__init__() + assert channels % n_heads == 0 + + self.channels = channels + 
self.out_channels = out_channels + self.n_heads = n_heads + self.window_size = window_size + self.heads_share = heads_share + self.proximal_bias = proximal_bias + self.p_dropout = p_dropout + self.attn = None + + self.k_channels = channels // n_heads + self.conv_q = torch.nn.Conv1d(channels, channels, 1) + self.conv_k = torch.nn.Conv1d(channels, channels, 1) + self.conv_v = torch.nn.Conv1d(channels, channels, 1) + if window_size is not None: + n_heads_rel = 1 if heads_share else n_heads + rel_stddev = self.k_channels**-0.5 + self.emb_rel_k = torch.nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.emb_rel_v = torch.nn.Parameter( + torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) + * rel_stddev + ) + self.conv_o = torch.nn.Conv1d(channels, out_channels, 1) + self.drop = torch.nn.Dropout(p_dropout) + + torch.nn.init.xavier_uniform_(self.conv_q.weight) + torch.nn.init.xavier_uniform_(self.conv_k.weight) + if proximal_init: + self.conv_k.weight.data.copy_(self.conv_q.weight.data) + self.conv_k.bias.data.copy_(self.conv_q.bias.data) + torch.nn.init.xavier_uniform_(self.conv_v.weight) + + def forward(self, x, c, attn_mask=None): + q = self.conv_q(x) + k = self.conv_k(c) + v = self.conv_v(c) + + x, self.attn = self.attention(q, k, v, mask=attn_mask) + + x = self.conv_o(x) + return x + + def attention(self, query, key, value, mask=None): + b, d, t_s, t_t = (*key.size(), query.size(2)) + query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3) + key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3) + + scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels) + if self.window_size is not None: + assert ( + t_s == t_t + ), "Relative attention is only available for self-attention." + key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s) + rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings) + rel_logits = self._relative_position_to_absolute_position(rel_logits) + scores_local = rel_logits / math.sqrt(self.k_channels) + scores = scores + scores_local + if self.proximal_bias: + assert t_s == t_t, "Proximal bias is only available for self-attention." 
+ scores = scores + self._attention_bias_proximal(t_s).to( + device=scores.device, dtype=scores.dtype + ) + if mask is not None: + scores = scores.masked_fill(mask == 0, -1e4) + p_attn = torch.nn.functional.softmax(scores, dim=-1) + p_attn = self.drop(p_attn) + output = torch.matmul(p_attn, value) + if self.window_size is not None: + relative_weights = self._absolute_position_to_relative_position(p_attn) + value_relative_embeddings = self._get_relative_embeddings( + self.emb_rel_v, t_s + ) + output = output + self._matmul_with_relative_values( + relative_weights, value_relative_embeddings + ) + output = output.transpose(2, 3).contiguous().view(b, d, t_t) + return output, p_attn + + def _matmul_with_relative_values(self, x, y): + ret = torch.matmul(x, y.unsqueeze(0)) + return ret + + def _matmul_with_relative_keys(self, x, y): + ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1)) + return ret + + def _get_relative_embeddings(self, relative_embeddings, length): + pad_length = max(length - (self.window_size + 1), 0) + slice_start_position = max((self.window_size + 1) - length, 0) + slice_end_position = slice_start_position + 2 * length - 1 + if pad_length > 0: + padded_relative_embeddings = torch.nn.functional.pad( + relative_embeddings, + convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]), + ) + else: + padded_relative_embeddings = relative_embeddings + used_relative_embeddings = padded_relative_embeddings[ + :, slice_start_position:slice_end_position + ] + return used_relative_embeddings + + def _relative_position_to_absolute_position(self, x): + batch, heads, length, _ = x.size() + x = torch.nn.functional.pad( + x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]) + ) + x_flat = x.view([batch, heads, length * 2 * length]) + x_flat = torch.nn.functional.pad( + x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]) + ) + x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[ + :, :, :length, length - 1 : + ] + return x_final + + def _absolute_position_to_relative_position(self, x): + batch, heads, length, _ = x.size() + x = torch.nn.functional.pad( + x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]) + ) + x_flat = x.view([batch, heads, length**2 + length * (length - 1)]) + x_flat = torch.nn.functional.pad( + x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]) + ) + x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:] + return x_final + + def _attention_bias_proximal(self, length): + r = torch.arange(length, dtype=torch.float32) + diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1) + return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0) + + +class FFN(BaseModule): + def __init__( + self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0 + ): + super(FFN, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.filter_channels = filter_channels + self.kernel_size = kernel_size + self.p_dropout = p_dropout + + self.conv_1 = torch.nn.Conv1d( + in_channels, filter_channels, kernel_size, padding=kernel_size // 2 + ) + self.conv_2 = torch.nn.Conv1d( + filter_channels, out_channels, kernel_size, padding=kernel_size // 2 + ) + self.drop = torch.nn.Dropout(p_dropout) + + def forward(self, x, x_mask): + x = self.conv_1(x * x_mask) + x = torch.relu(x) + x = self.drop(x) + x = self.conv_2(x * x_mask) + return x * x_mask + + +class Encoder(BaseModule): + def __init__( + self, + hidden_channels, + filter_channels, + n_heads=2, + n_layers=6, + kernel_size=3, 
+ p_dropout=0.1, + window_size=4, + **kwargs + ): + super(Encoder, self).__init__() + self.hidden_channels = hidden_channels + self.filter_channels = filter_channels + self.n_heads = n_heads + self.n_layers = n_layers + self.kernel_size = kernel_size + self.p_dropout = p_dropout + self.window_size = window_size + + self.drop = torch.nn.Dropout(p_dropout) + self.attn_layers = torch.nn.ModuleList() + self.norm_layers_1 = torch.nn.ModuleList() + self.ffn_layers = torch.nn.ModuleList() + self.norm_layers_2 = torch.nn.ModuleList() + for _ in range(self.n_layers): + self.attn_layers.append( + MultiHeadAttention( + hidden_channels, + hidden_channels, + n_heads, + window_size=window_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_1.append(LayerNorm(hidden_channels)) + self.ffn_layers.append( + FFN( + hidden_channels, + hidden_channels, + filter_channels, + kernel_size, + p_dropout=p_dropout, + ) + ) + self.norm_layers_2.append(LayerNorm(hidden_channels)) + + def forward(self, x, x_mask): + attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1) + for i in range(self.n_layers): + x = x * x_mask + y = self.attn_layers[i](x, x, attn_mask) + y = self.drop(y) + x = self.norm_layers_1[i](x + y) + y = self.ffn_layers[i](x, x_mask) + y = self.drop(y) + x = self.norm_layers_2[i](x + y) + x = x * x_mask + return x + + +class Conformer(BaseModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.n_heads = self.cfg.n_heads + self.n_layers = self.cfg.n_layers + self.hidden_channels = self.cfg.input_dim + self.filter_channels = self.cfg.filter_channels + self.output_dim = self.cfg.output_dim + self.dropout = self.cfg.dropout + + self.conformer_encoder = Encoder( + self.hidden_channels, + self.filter_channels, + n_heads=self.n_heads, + n_layers=self.n_layers, + kernel_size=3, + p_dropout=self.dropout, + window_size=4, + ) + self.projection = nn.Conv1d(self.hidden_channels, self.output_dim, 1) + + def forward(self, x, x_mask): + """ + Args: + x: (N, seq_len, input_dim) + Returns: + output: (N, seq_len, output_dim) + """ + # (N, seq_len, d_model) + x = x.transpose(1, 2) + x_mask = x_mask.transpose(1, 2) + output = self.conformer_encoder(x, x_mask) + # (N, seq_len, output_dim) + output = self.projection(output) + output = output.transpose(1, 2) + return output diff --git a/models/svc/transformer/transformer.py b/models/svc/transformer/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..fd3cdb6c2d0fc93534d005b9f67a3058c9185c60 --- /dev/null +++ b/models/svc/transformer/transformer.py @@ -0,0 +1,82 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
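A minimal shape check for the Conformer wrapper defined above in conformer.py. The SimpleNamespace config is an illustrative assumption standing in for Amphion's parsed cfg object; the only requirement visible in the code is that it exposes n_heads, n_layers, input_dim, filter_channels, output_dim and dropout, and that the padding mask enters with shape (N, seq_len, 1):

from types import SimpleNamespace

import torch

from models.svc.transformer.conformer import Conformer

# illustrative hyper-parameters, not values taken from this diff
cfg = SimpleNamespace(
    n_heads=2, n_layers=4, input_dim=384, filter_channels=768,
    output_dim=100, dropout=0.1,
)
model = Conformer(cfg)

x = torch.randn(8, 250, cfg.input_dim)                      # (N, seq_len, input_dim)
lengths = torch.randint(100, 250, (8,))
mask = (torch.arange(250)[None, :] < lengths[:, None]).float().unsqueeze(-1)

y = model(x, mask)                                          # (N, seq_len, output_dim)
assert y.shape == (8, 250, cfg.output_dim)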
+ +import math +import torch +import torch.nn as nn +from torch.nn import TransformerEncoder, TransformerEncoderLayer + + +class Transformer(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + + dropout = self.cfg.dropout + nhead = self.cfg.n_heads + nlayers = self.cfg.n_layers + input_dim = self.cfg.input_dim + output_dim = self.cfg.output_dim + + d_model = input_dim + self.pos_encoder = PositionalEncoding(d_model, dropout) + encoder_layers = TransformerEncoderLayer( + d_model, nhead, dropout=dropout, batch_first=True + ) + self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers) + + self.output_mlp = nn.Linear(d_model, output_dim) + + def forward(self, x, mask=None): + """ + Args: + x: (N, seq_len, input_dim) + Returns: + output: (N, seq_len, output_dim) + """ + # (N, seq_len, d_model) + src = self.pos_encoder(x) + # model_stats["pos_embedding"] = x + # (N, seq_len, d_model) + output = self.transformer_encoder(src) + # (N, seq_len, output_dim) + output = self.output_mlp(output) + return output + + +class PositionalEncoding(nn.Module): + def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model) + ) + + # Assume that x is (seq_len, N, d) + # pe = torch.zeros(max_len, 1, d_model) + # pe[:, 0, 0::2] = torch.sin(position * div_term) + # pe[:, 0, 1::2] = torch.cos(position * div_term) + + # Assume that x in (N, seq_len, d) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + + self.register_buffer("pe", pe) + + def forward(self, x): + """ + Args: + x: Tensor, shape [N, seq_len, d] + """ + # Old: Assume that x is (seq_len, N, d), and self.pe is (max_len, 1, d_model) + # x = x + self.pe[: x.size(0)] + + # Now: self.pe is (1, max_len, d) + x = x + self.pe[:, : x.size(1), :] + + return self.dropout(x) diff --git a/models/svc/transformer/transformer_inference.py b/models/svc/transformer/transformer_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..f6299c532aec6cb9283ee87ee9f0142f0b5c981b --- /dev/null +++ b/models/svc/transformer/transformer_inference.py @@ -0,0 +1,45 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
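An analogous sketch for the batch-first Transformer above in transformer.py. Note that forward() accepts a mask argument but does not forward it to the encoder, so padded frames still attend to each other. The SimpleNamespace below is an assumed stand-in for cfg.model.transformer, and input_dim must be divisible by n_heads because it is used directly as d_model:

from types import SimpleNamespace

import torch

from models.svc.transformer.transformer import Transformer

cfg = SimpleNamespace(dropout=0.1, n_heads=4, n_layers=6, input_dim=384, output_dim=100)
mapper = Transformer(cfg)

condition = torch.randn(4, 300, cfg.input_dim)   # (N, seq_len, input_dim)
mel_pred = mapper(condition)
assert mel_pred.shape == (4, 300, cfg.output_dim)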
+ +import os +import time +import numpy as np +import torch +from tqdm import tqdm +import torch.nn as nn +from collections import OrderedDict + +from models.svc.base import SVCInference +from modules.encoder.condition_encoder import ConditionEncoder +from models.svc.transformer.transformer import Transformer +from models.svc.transformer.conformer import Conformer + + +class TransformerInference(SVCInference): + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + SVCInference.__init__(self, args, cfg, infer_type) + + def _build_model(self): + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + if self.cfg.model.transformer.type == "transformer": + self.acoustic_mapper = Transformer(self.cfg.model.transformer) + elif self.cfg.model.transformer.type == "conformer": + self.acoustic_mapper = Conformer(self.cfg.model.transformer) + else: + raise NotImplementedError + model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) + return model + + def _inference_each_batch(self, batch_data): + device = self.accelerator.device + for k, v in batch_data.items(): + batch_data[k] = v.to(device) + + condition = self.condition_encoder(batch_data) + y_pred = self.acoustic_mapper(condition, batch_data["mask"]) + + return y_pred diff --git a/models/svc/transformer/transformer_trainer.py b/models/svc/transformer/transformer_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3633078475d26e708280bc354f091bb9ab01ae45 --- /dev/null +++ b/models/svc/transformer/transformer_trainer.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
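The inference class above selects its acoustic mapper from cfg.model.transformer.type. A stand-alone sketch of that dispatch, handy when instantiating one of these checkpoints outside the Inference machinery (the helper name is ours, not part of the diff):

from models.svc.transformer.transformer import Transformer
from models.svc.transformer.conformer import Conformer


def build_acoustic_mapper(transformer_cfg):
    # transformer_cfg corresponds to the cfg.model.transformer sub-config used above
    if transformer_cfg.type == "transformer":
        return Transformer(transformer_cfg)
    elif transformer_cfg.type == "conformer":
        return Conformer(transformer_cfg)
    raise NotImplementedError(f"Unsupported mapper type: {transformer_cfg.type}")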
+ +import torch + +from models.svc.base import SVCTrainer +from modules.encoder.condition_encoder import ConditionEncoder +from models.svc.transformer.transformer import Transformer +from models.svc.transformer.conformer import Conformer +from utils.ssim import SSIM + + +class TransformerTrainer(SVCTrainer): + def __init__(self, args, cfg): + SVCTrainer.__init__(self, args, cfg) + self.ssim_loss = SSIM() + + def _build_model(self): + self.cfg.model.condition_encoder.f0_min = self.cfg.preprocess.f0_min + self.cfg.model.condition_encoder.f0_max = self.cfg.preprocess.f0_max + self.condition_encoder = ConditionEncoder(self.cfg.model.condition_encoder) + if self.cfg.model.transformer.type == "transformer": + self.acoustic_mapper = Transformer(self.cfg.model.transformer) + elif self.cfg.model.transformer.type == "conformer": + self.acoustic_mapper = Conformer(self.cfg.model.transformer) + else: + raise NotImplementedError + model = torch.nn.ModuleList([self.condition_encoder, self.acoustic_mapper]) + return model + + def _forward_step(self, batch): + total_loss = 0 + device = self.accelerator.device + mel = batch["mel"] + mask = batch["mask"] + + condition = self.condition_encoder(batch) + mel_pred = self.acoustic_mapper(condition, mask) + + l1_loss = torch.sum(torch.abs(mel_pred - mel) * batch["mask"]) / torch.sum( + batch["mask"] + ) + self._check_nan(l1_loss, mel_pred, mel) + total_loss += l1_loss + ssim_loss = self.ssim_loss(mel_pred, mel) + ssim_loss = torch.sum(ssim_loss * batch["mask"]) / torch.sum(batch["mask"]) + self._check_nan(ssim_loss, mel_pred, mel) + total_loss += ssim_loss + + return total_loss diff --git a/models/vocoders/autoregressive/autoregressive_vocoder_dataset.py b/models/vocoders/autoregressive/autoregressive_vocoder_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/autoregressive/autoregressive_vocoder_inference.py b/models/vocoders/autoregressive/autoregressive_vocoder_inference.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/autoregressive/autoregressive_vocoder_trainer.py b/models/vocoders/autoregressive/autoregressive_vocoder_trainer.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/autoregressive/wavenet/conv.py b/models/vocoders/autoregressive/wavenet/conv.py new file mode 100644 index 0000000000000000000000000000000000000000..a095aad5d7203f6e5fb5a4d585b894e34dbe63c7 --- /dev/null +++ b/models/vocoders/autoregressive/wavenet/conv.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
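The trainer above sums a masked L1 term and a masked SSIM term, each normalized by the number of valid mask entries. A self-contained illustration of the masked L1 part (assuming, as the broadcasting in _forward_step implies, that the mask has shape (N, T, 1) with ones on valid frames):

import torch


def masked_l1(mel_pred, mel, mask):
    # mel_pred, mel: (N, T, n_mels); mask: (N, T, 1) with 1 = valid frame, 0 = padding
    return torch.sum(torch.abs(mel_pred - mel) * mask) / torch.sum(mask)


mel = torch.randn(2, 10, 80)
mel_pred = torch.randn(2, 10, 80)
mask = torch.ones(2, 10, 1)
mask[1, 6:] = 0.0            # the second utterance has only 6 valid frames
print(masked_l1(mel_pred, mel, mask))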
+ +from torch import nn +from torch.nn import functional as F + + +class Conv1d(nn.Conv1d): + """Extended nn.Conv1d for incremental dilated convolutions""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.clear_buffer() + self._linearized_weight = None + self.register_backward_hook(self._clear_linearized_weight) + + def incremental_forward(self, input): + # input (B, T, C) + # run forward pre hooks + for hook in self._forward_pre_hooks.values(): + hook(self, input) + + # reshape weight + weight = self._get_linearized_weight() + kw = self.kernel_size[0] + dilation = self.dilation[0] + + bsz = input.size(0) + if kw > 1: + input = input.data + if self.input_buffer is None: + self.input_buffer = input.new( + bsz, kw + (kw - 1) * (dilation - 1), input.size(2) + ) + self.input_buffer.zero_() + else: + # shift buffer + self.input_buffer[:, :-1, :] = self.input_buffer[:, 1:, :].clone() + # append next input + self.input_buffer[:, -1, :] = input[:, -1, :] + input = self.input_buffer + if dilation > 1: + input = input[:, 0::dilation, :].contiguous() + output = F.linear(input.view(bsz, -1), weight, self.bias) + return output.view(bsz, 1, -1) + + def clear_buffer(self): + self.input_buffer = None + + def _get_linearized_weight(self): + if self._linearized_weight is None: + kw = self.kernel_size[0] + # nn.Conv1d + if self.weight.size() == (self.out_channels, self.in_channels, kw): + weight = self.weight.transpose(1, 2).contiguous() + else: + # fairseq.modules.conv_tbc.ConvTBC + weight = self.weight.transpose(2, 1).transpose(1, 0).contiguous() + assert weight.size() == (self.out_channels, kw, self.in_channels) + self._linearized_weight = weight.view(self.out_channels, -1) + return self._linearized_weight + + def _clear_linearized_weight(self, *args): + self._linearized_weight = None diff --git a/models/vocoders/autoregressive/wavenet/modules.py b/models/vocoders/autoregressive/wavenet/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..13d51e52a50af3bc1f7fe9627aeae8d2b1b28b7d --- /dev/null +++ b/models/vocoders/autoregressive/wavenet/modules.py @@ -0,0 +1,152 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import math + +from torch import nn +from torch.nn import functional as F + +from .conv import Conv1d as conv_Conv1d + + +def Conv1d(in_channels, out_channels, kernel_size, dropout=0, **kwargs): + m = conv_Conv1d(in_channels, out_channels, kernel_size, **kwargs) + nn.init.kaiming_normal_(m.weight, nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + return nn.utils.weight_norm(m) + + +def Conv1d1x1(in_channels, out_channels, bias=True): + return Conv1d( + in_channels, out_channels, kernel_size=1, padding=0, dilation=1, bias=bias + ) + + +def _conv1x1_forward(conv, x, is_incremental): + if is_incremental: + x = conv.incremental_forward(x) + else: + x = conv(x) + return x + + +class ResidualConv1dGLU(nn.Module): + """Residual dilated conv1d + Gated linear unit + + Args: + residual_channels (int): Residual input / output channels + gate_channels (int): Gated activation channels. + kernel_size (int): Kernel size of convolution layers. + skip_out_channels (int): Skip connection channels. If None, set to same + as ``residual_channels``. + cin_channels (int): Local conditioning channels. If negative value is + set, local conditioning is disabled. + dropout (float): Dropout probability. 
+ padding (int): Padding for convolution layers. If None, proper padding + is computed depends on dilation and kernel_size. + dilation (int): Dilation factor. + """ + + def __init__( + self, + residual_channels, + gate_channels, + kernel_size, + skip_out_channels=None, + cin_channels=-1, + dropout=1 - 0.95, + padding=None, + dilation=1, + causal=True, + bias=True, + *args, + **kwargs, + ): + super(ResidualConv1dGLU, self).__init__() + self.dropout = dropout + + if skip_out_channels is None: + skip_out_channels = residual_channels + if padding is None: + # no future time stamps available + if causal: + padding = (kernel_size - 1) * dilation + else: + padding = (kernel_size - 1) // 2 * dilation + self.causal = causal + + self.conv = Conv1d( + residual_channels, + gate_channels, + kernel_size, + padding=padding, + dilation=dilation, + bias=bias, + *args, + **kwargs, + ) + + # mel conditioning + self.conv1x1c = Conv1d1x1(cin_channels, gate_channels, bias=False) + + gate_out_channels = gate_channels // 2 + self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias) + self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_out_channels, bias=bias) + + def forward(self, x, c=None): + return self._forward(x, c, False) + + def incremental_forward(self, x, c=None): + return self._forward(x, c, True) + + def clear_buffer(self): + for c in [ + self.conv, + self.conv1x1_out, + self.conv1x1_skip, + self.conv1x1c, + ]: + if c is not None: + c.clear_buffer() + + def _forward(self, x, c, is_incremental): + """Forward + + Args: + x (Tensor): B x C x T + c (Tensor): B x C x T, Mel conditioning features + Returns: + Tensor: output + """ + residual = x + x = F.dropout(x, p=self.dropout, training=self.training) + if is_incremental: + splitdim = -1 + x = self.conv.incremental_forward(x) + else: + splitdim = 1 + x = self.conv(x) + # remove future time steps + x = x[:, :, : residual.size(-1)] if self.causal else x + + a, b = x.split(x.size(splitdim) // 2, dim=splitdim) + + assert self.conv1x1c is not None + c = _conv1x1_forward(self.conv1x1c, c, is_incremental) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + a, b = a + ca, b + cb + + x = torch.tanh(a) * torch.sigmoid(b) + + # For skip connection + s = _conv1x1_forward(self.conv1x1_skip, x, is_incremental) + + # For residual connection + x = _conv1x1_forward(self.conv1x1_out, x, is_incremental) + + x = (x + residual) * math.sqrt(0.5) + return x, s diff --git a/models/vocoders/autoregressive/wavenet/upsample.py b/models/vocoders/autoregressive/wavenet/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..b664302cd56545f1709a4f1874ebadd8e9375a9c --- /dev/null +++ b/models/vocoders/autoregressive/wavenet/upsample.py @@ -0,0 +1,109 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
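A quick shape check for the gated residual block above in wavenet/modules.py; the channel sizes are illustrative, the real values come from the WaveNet config read later in this diff. The block returns the residual path and the skip path separately:

import torch

from models.vocoders.autoregressive.wavenet.modules import ResidualConv1dGLU

block = ResidualConv1dGLU(
    residual_channels=64, gate_channels=128, kernel_size=3,
    skip_out_channels=128, cin_channels=80, dilation=2, dropout=0.05,
)
x = torch.randn(2, 64, 400)      # (B, residual_channels, T)
c = torch.randn(2, 80, 400)      # (B, cin_channels, T): upsampled mel conditioning
out, skip = block(x, c)
assert out.shape == (2, 64, 400) and skip.shape == (2, 128, 400)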
+ +import torch + +import numpy as np + +from torch import nn +from torch.nn import functional as F + + +class Stretch2d(nn.Module): + def __init__(self, x_scale, y_scale, mode="nearest"): + super(Stretch2d, self).__init__() + self.x_scale = x_scale + self.y_scale = y_scale + self.mode = mode + + def forward(self, x): + return F.interpolate( + x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode + ) + + +def _get_activation(upsample_activation): + nonlinear = getattr(nn, upsample_activation) + return nonlinear + + +class UpsampleNetwork(nn.Module): + def __init__( + self, + upsample_scales, + upsample_activation="none", + upsample_activation_params={}, + mode="nearest", + freq_axis_kernel_size=1, + cin_pad=0, + cin_channels=128, + ): + super(UpsampleNetwork, self).__init__() + self.up_layers = nn.ModuleList() + total_scale = np.prod(upsample_scales) + self.indent = cin_pad * total_scale + for scale in upsample_scales: + freq_axis_padding = (freq_axis_kernel_size - 1) // 2 + k_size = (freq_axis_kernel_size, scale * 2 + 1) + padding = (freq_axis_padding, scale) + stretch = Stretch2d(scale, 1, mode) + conv = nn.Conv2d(1, 1, kernel_size=k_size, padding=padding, bias=False) + conv.weight.data.fill_(1.0 / np.prod(k_size)) + conv = nn.utils.weight_norm(conv) + self.up_layers.append(stretch) + self.up_layers.append(conv) + if upsample_activation != "none": + nonlinear = _get_activation(upsample_activation) + self.up_layers.append(nonlinear(**upsample_activation_params)) + + def forward(self, c): + """ + Args: + c : B x C x T + """ + + # B x 1 x C x T + c = c.unsqueeze(1) + for f in self.up_layers: + c = f(c) + # B x C x T + c = c.squeeze(1) + + if self.indent > 0: + c = c[:, :, self.indent : -self.indent] + return c + + +class ConvInUpsampleNetwork(nn.Module): + def __init__( + self, + upsample_scales, + upsample_activation="none", + upsample_activation_params={}, + mode="nearest", + freq_axis_kernel_size=1, + cin_pad=0, + cin_channels=128, + ): + super(ConvInUpsampleNetwork, self).__init__() + # To capture wide-context information in conditional features + # meaningless if cin_pad == 0 + ks = 2 * cin_pad + 1 + self.conv_in = nn.Conv1d( + cin_channels, cin_channels, kernel_size=ks, padding=cin_pad, bias=False + ) + self.upsample = UpsampleNetwork( + upsample_scales, + upsample_activation, + upsample_activation_params, + mode, + freq_axis_kernel_size, + cin_pad=cin_pad, + cin_channels=cin_channels, + ) + + def forward(self, c): + c_up = self.upsample(self.conv_in(c)) + return c_up diff --git a/models/vocoders/autoregressive/wavenet/wavenet.py b/models/vocoders/autoregressive/wavenet/wavenet.py new file mode 100644 index 0000000000000000000000000000000000000000..d63f22c2600fd0f83e5bdf339ebb121b3d2f35e6 --- /dev/null +++ b/models/vocoders/autoregressive/wavenet/wavenet.py @@ -0,0 +1,170 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math + +from torch import nn +from torch.nn import functional as F + +from .modules import Conv1d1x1, ResidualConv1dGLU +from .upsample import ConvInUpsampleNetwork + + +def receptive_field_size( + total_layers, num_cycles, kernel_size, dilation=lambda x: 2**x +): + """Compute receptive field size + + Args: + total_layers (int): total layers + num_cycles (int): cycles + kernel_size (int): kernel size + dilation (lambda): lambda to compute dilation factor. ``lambda x : 1`` + to disable dilated convolution. 
+ + Returns: + int: receptive field size in sample + + """ + assert total_layers % num_cycles == 0 + + layers_per_cycle = total_layers // num_cycles + dilations = [dilation(i % layers_per_cycle) for i in range(total_layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + +class WaveNet(nn.Module): + """The WaveNet model that supports local and global conditioning. + + Args: + out_channels (int): Output channels. If input_type is mu-law quantized + one-hot vecror. this must equal to the quantize channels. Other wise + num_mixtures x 3 (pi, mu, log_scale). + layers (int): Number of total layers + stacks (int): Number of dilation cycles + residual_channels (int): Residual input / output channels + gate_channels (int): Gated activation channels. + skip_out_channels (int): Skip connection channels. + kernel_size (int): Kernel size of convolution layers. + dropout (float): Dropout probability. + input_dim (int): Number of mel-spec dimension. + upsample_scales (list): List of upsample scale. + ``np.prod(upsample_scales)`` must equal to hop size. Used only if + upsample_conditional_features is enabled. + freq_axis_kernel_size (int): Freq-axis kernel_size for transposed + convolution layers for upsampling. If you only care about time-axis + upsampling, set this to 1. + scalar_input (Bool): If True, scalar input ([-1, 1]) is expected, otherwise + quantized one-hot vector is expected.. + """ + + def __init__(self, cfg): + super(WaveNet, self).__init__() + self.cfg = cfg + self.scalar_input = self.cfg.VOCODER.SCALAR_INPUT + self.out_channels = self.cfg.VOCODER.OUT_CHANNELS + self.cin_channels = self.cfg.VOCODER.INPUT_DIM + self.residual_channels = self.cfg.VOCODER.RESIDUAL_CHANNELS + self.layers = self.cfg.VOCODER.LAYERS + self.stacks = self.cfg.VOCODER.STACKS + self.gate_channels = self.cfg.VOCODER.GATE_CHANNELS + self.kernel_size = self.cfg.VOCODER.KERNEL_SIZE + self.skip_out_channels = self.cfg.VOCODER.SKIP_OUT_CHANNELS + self.dropout = self.cfg.VOCODER.DROPOUT + self.upsample_scales = self.cfg.VOCODER.UPSAMPLE_SCALES + self.mel_frame_pad = self.cfg.VOCODER.MEL_FRAME_PAD + + assert self.layers % self.stacks == 0 + + layers_per_stack = self.layers // self.stacks + if self.scalar_input: + self.first_conv = Conv1d1x1(1, self.residual_channels) + else: + self.first_conv = Conv1d1x1(self.out_channels, self.residual_channels) + + self.conv_layers = nn.ModuleList() + for layer in range(self.layers): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualConv1dGLU( + self.residual_channels, + self.gate_channels, + kernel_size=self.kernel_size, + skip_out_channels=self.skip_out_channels, + bias=True, + dilation=dilation, + dropout=self.dropout, + cin_channels=self.cin_channels, + ) + self.conv_layers.append(conv) + + self.last_conv_layers = nn.ModuleList( + [ + nn.ReLU(inplace=True), + Conv1d1x1(self.skip_out_channels, self.skip_out_channels), + nn.ReLU(inplace=True), + Conv1d1x1(self.skip_out_channels, self.out_channels), + ] + ) + + self.upsample_net = ConvInUpsampleNetwork( + upsample_scales=self.upsample_scales, + cin_pad=self.mel_frame_pad, + cin_channels=self.cin_channels, + ) + + self.receptive_field = receptive_field_size( + self.layers, self.stacks, self.kernel_size + ) + + def forward(self, x, mel, softmax=False): + """Forward step + + Args: + x (Tensor): One-hot encoded audio signal, shape (B x C x T) + mel (Tensor): Local conditioning features, + shape (B x cin_channels x T) + softmax (bool): Whether applies softmax or not. 
+ + Returns: + Tensor: output, shape B x out_channels x T + """ + B, _, T = x.size() + + mel = self.upsample_net(mel) + assert mel.shape[-1] == x.shape[-1] + + x = self.first_conv(x) + skips = 0 + for f in self.conv_layers: + x, h = f(x, mel) + skips += h + skips *= math.sqrt(1.0 / len(self.conv_layers)) + + x = skips + for f in self.last_conv_layers: + x = f(x) + + x = F.softmax(x, dim=1) if softmax else x + + return x + + def clear_buffer(self): + self.first_conv.clear_buffer() + for f in self.conv_layers: + f.clear_buffer() + for f in self.last_conv_layers: + try: + f.clear_buffer() + except AttributeError: + pass + + def make_generation_fast_(self): + def remove_weight_norm(m): + try: + nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(remove_weight_norm) diff --git a/models/vocoders/autoregressive/wavernn/wavernn.py b/models/vocoders/autoregressive/wavernn/wavernn.py new file mode 100644 index 0000000000000000000000000000000000000000..c7475fa8fe8b4575bf714e615349582ff98bbc27 --- /dev/null +++ b/models/vocoders/autoregressive/wavernn/wavernn.py @@ -0,0 +1,188 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +import numpy as np + + +class ResBlock(nn.Module): + def __init__(self, dims): + super().__init__() + self.conv1 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.conv2 = nn.Conv1d(dims, dims, kernel_size=1, bias=False) + self.batch_norm1 = nn.BatchNorm1d(dims) + self.batch_norm2 = nn.BatchNorm1d(dims) + + def forward(self, x): + residual = x + x = self.conv1(x) + x = self.batch_norm1(x) + x = F.relu(x) + x = self.conv2(x) + x = self.batch_norm2(x) + x = x + residual + return x + + +class MelResNet(nn.Module): + def __init__(self, res_blocks, in_dims, compute_dims, res_out_dims, pad): + super().__init__() + kernel_size = pad * 2 + 1 + self.conv_in = nn.Conv1d( + in_dims, compute_dims, kernel_size=kernel_size, bias=False + ) + self.batch_norm = nn.BatchNorm1d(compute_dims) + self.layers = nn.ModuleList() + for i in range(res_blocks): + self.layers.append(ResBlock(compute_dims)) + self.conv_out = nn.Conv1d(compute_dims, res_out_dims, kernel_size=1) + + def forward(self, x): + x = self.conv_in(x) + x = self.batch_norm(x) + x = F.relu(x) + for f in self.layers: + x = f(x) + x = self.conv_out(x) + return x + + +class Stretch2d(nn.Module): + def __init__(self, x_scale, y_scale): + super().__init__() + self.x_scale = x_scale + self.y_scale = y_scale + + def forward(self, x): + b, c, h, w = x.size() + x = x.unsqueeze(-1).unsqueeze(3) + x = x.repeat(1, 1, 1, self.y_scale, 1, self.x_scale) + return x.view(b, c, h * self.y_scale, w * self.x_scale) + + +class UpsampleNetwork(nn.Module): + def __init__( + self, feat_dims, upsample_scales, compute_dims, res_blocks, res_out_dims, pad + ): + super().__init__() + total_scale = np.cumproduct(upsample_scales)[-1] + self.indent = pad * total_scale + self.resnet = MelResNet(res_blocks, feat_dims, compute_dims, res_out_dims, pad) + self.resnet_stretch = Stretch2d(total_scale, 1) + self.up_layers = nn.ModuleList() + for scale in upsample_scales: + kernel_size = (1, scale * 2 + 1) + padding = (0, scale) + stretch = Stretch2d(scale, 1) + conv = nn.Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + conv.weight.data.fill_(1.0 / kernel_size[1]) + self.up_layers.append(stretch) + 
self.up_layers.append(conv) + + def forward(self, m): + aux = self.resnet(m).unsqueeze(1) + aux = self.resnet_stretch(aux) + aux = aux.squeeze(1) + m = m.unsqueeze(1) + for f in self.up_layers: + m = f(m) + m = m.squeeze(1)[:, :, self.indent : -self.indent] + return m.transpose(1, 2), aux.transpose(1, 2) + + +class WaveRNN(nn.Module): + def __init__(self, cfg): + super().__init__() + + self.cfg = cfg + self.pad = self.cfg.VOCODER.MEL_FRAME_PAD + + if self.cfg.VOCODER.MODE == "mu_law_quantize": + self.n_classes = 2**self.cfg.VOCODER.BITS + elif self.cfg.VOCODER.MODE == "mu_law" or self.cfg.VOCODER: + self.n_classes = 30 + + self._to_flatten = [] + + self.rnn_dims = self.cfg.VOCODER.RNN_DIMS + self.aux_dims = self.cfg.VOCODER.RES_OUT_DIMS // 4 + self.hop_length = self.cfg.VOCODER.HOP_LENGTH + self.fc_dims = self.cfg.VOCODER.FC_DIMS + self.upsample_factors = self.cfg.VOCODER.UPSAMPLE_FACTORS + self.feat_dims = self.cfg.VOCODER.INPUT_DIM + self.compute_dims = self.cfg.VOCODER.COMPUTE_DIMS + self.res_out_dims = self.cfg.VOCODER.RES_OUT_DIMS + self.res_blocks = self.cfg.VOCODER.RES_BLOCKS + + self.upsample = UpsampleNetwork( + self.feat_dims, + self.upsample_factors, + self.compute_dims, + self.res_blocks, + self.res_out_dims, + self.pad, + ) + self.I = nn.Linear(self.feat_dims + self.aux_dims + 1, self.rnn_dims) + + self.rnn1 = nn.GRU(self.rnn_dims, self.rnn_dims, batch_first=True) + self.rnn2 = nn.GRU( + self.rnn_dims + self.aux_dims, self.rnn_dims, batch_first=True + ) + self._to_flatten += [self.rnn1, self.rnn2] + + self.fc1 = nn.Linear(self.rnn_dims + self.aux_dims, self.fc_dims) + self.fc2 = nn.Linear(self.fc_dims + self.aux_dims, self.fc_dims) + self.fc3 = nn.Linear(self.fc_dims, self.n_classes) + + self.num_params() + + self._flatten_parameters() + + def forward(self, x, mels): + device = next(self.parameters()).device + + self._flatten_parameters() + + batch_size = x.size(0) + h1 = torch.zeros(1, batch_size, self.rnn_dims, device=device) + h2 = torch.zeros(1, batch_size, self.rnn_dims, device=device) + mels, aux = self.upsample(mels) + + aux_idx = [self.aux_dims * i for i in range(5)] + a1 = aux[:, :, aux_idx[0] : aux_idx[1]] + a2 = aux[:, :, aux_idx[1] : aux_idx[2]] + a3 = aux[:, :, aux_idx[2] : aux_idx[3]] + a4 = aux[:, :, aux_idx[3] : aux_idx[4]] + + x = torch.cat([x.unsqueeze(-1), mels, a1], dim=2) + x = self.I(x) + res = x + x, _ = self.rnn1(x, h1) + + x = x + res + res = x + x = torch.cat([x, a2], dim=2) + x, _ = self.rnn2(x, h2) + + x = x + res + x = torch.cat([x, a3], dim=2) + x = F.relu(self.fc1(x)) + + x = torch.cat([x, a4], dim=2) + x = F.relu(self.fc2(x)) + return self.fc3(x) + + def num_params(self, print_out=True): + parameters = filter(lambda p: p.requires_grad, self.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + if print_out: + print("Trainable Parameters: %.3fM" % parameters) + return parameters + + def _flatten_parameters(self): + [m.flatten_parameters() for m in self._to_flatten] diff --git a/models/vocoders/diffusion/diffusion_vocoder_dataset.py b/models/vocoders/diffusion/diffusion_vocoder_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/diffusion/diffusion_vocoder_inference.py b/models/vocoders/diffusion/diffusion_vocoder_inference.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/diffusion/diffusion_vocoder_trainer.py 
b/models/vocoders/diffusion/diffusion_vocoder_trainer.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/diffusion/diffwave/diffwave.py b/models/vocoders/diffusion/diffwave/diffwave.py new file mode 100644 index 0000000000000000000000000000000000000000..c9379b0b622c6da8a754f2cc87fd7723eacfa995 --- /dev/null +++ b/models/vocoders/diffusion/diffwave/diffwave.py @@ -0,0 +1,173 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from math import sqrt + + +Linear = nn.Linear +ConvTranspose2d = nn.ConvTranspose2d + + +def Conv1d(*args, **kwargs): + layer = nn.Conv1d(*args, **kwargs) + nn.init.kaiming_normal_(layer.weight) + return layer + + +@torch.jit.script +def silu(x): + return x * torch.sigmoid(x) + + +class DiffusionEmbedding(nn.Module): + def __init__(self, max_steps): + super().__init__() + self.register_buffer( + "embedding", self._build_embedding(max_steps), persistent=False + ) + self.projection1 = Linear(128, 512) + self.projection2 = Linear(512, 512) + + def forward(self, diffusion_step): + if diffusion_step.dtype in [torch.int32, torch.int64]: + x = self.embedding[diffusion_step] + else: + x = self._lerp_embedding(diffusion_step) + x = self.projection1(x) + x = silu(x) + x = self.projection2(x) + x = silu(x) + return x + + def _lerp_embedding(self, t): + low_idx = torch.floor(t).long() + high_idx = torch.ceil(t).long() + low = self.embedding[low_idx] + high = self.embedding[high_idx] + return low + (high - low) * (t - low_idx) + + def _build_embedding(self, max_steps): + steps = torch.arange(max_steps).unsqueeze(1) # [T,1] + dims = torch.arange(64).unsqueeze(0) # [1,64] + table = steps * 10.0 ** (dims * 4.0 / 63.0) # [T,64] + table = torch.cat([torch.sin(table), torch.cos(table)], dim=1) + return table + + +class SpectrogramUpsampler(nn.Module): + def __init__(self, upsample_factors): + super().__init__() + self.conv1 = ConvTranspose2d( + 1, + 1, + [3, upsample_factors[0] * 2], + stride=[1, upsample_factors[0]], + padding=[1, upsample_factors[0] // 2], + ) + self.conv2 = ConvTranspose2d( + 1, + 1, + [3, upsample_factors[1] * 2], + stride=[1, upsample_factors[1]], + padding=[1, upsample_factors[1] // 2], + ) + + def forward(self, x): + x = torch.unsqueeze(x, 1) + x = self.conv1(x) + x = F.leaky_relu(x, 0.4) + x = self.conv2(x) + x = F.leaky_relu(x, 0.4) + x = torch.squeeze(x, 1) + return x + + +class ResidualBlock(nn.Module): + def __init__(self, n_mels, residual_channels, dilation): + super().__init__() + self.dilated_conv = Conv1d( + residual_channels, + 2 * residual_channels, + 3, + padding=dilation, + dilation=dilation, + ) + self.diffusion_projection = Linear(512, residual_channels) + + self.conditioner_projection = Conv1d(n_mels, 2 * residual_channels, 1) + + self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) + + def forward(self, x, diffusion_step, conditioner): + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + y = x + diffusion_step + + conditioner = self.conditioner_projection(conditioner) + y = self.dilated_conv(y) + conditioner + + gate, filter = torch.chunk(y, 2, dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter) + + y = self.output_projection(y) + residual, skip = torch.chunk(y, 2, dim=1) + return (x + residual) / 
sqrt(2.0), skip + + +class DiffWave(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.cfg.VOCODER.NOISE_SCHEDULE = np.linspace( + self.cfg.VOCODER.NOISE_SCHEDULE_FACTORS[0], + self.cfg.VOCODER.NOISE_SCHEDULE_FACTORS[1], + self.cfg.VOCODER.NOISE_SCHEDULE_FACTORS[2], + ).tolist() + self.input_projection = Conv1d(1, self.cfg.VOCODER.RESIDUAL_CHANNELS, 1) + self.diffusion_embedding = DiffusionEmbedding( + len(self.cfg.VOCODER.NOISE_SCHEDULE) + ) + self.spectrogram_upsampler = SpectrogramUpsampler( + self.cfg.VOCODER.UPSAMPLE_FACTORS + ) + + self.residual_layers = nn.ModuleList( + [ + ResidualBlock( + self.cfg.VOCODER.INPUT_DIM, + self.cfg.VOCODER.RESIDUAL_CHANNELS, + 2 ** (i % self.cfg.VOCODER.DILATION_CYCLE_LENGTH), + ) + for i in range(self.cfg.VOCODER.RESIDUAL_LAYERS) + ] + ) + self.skip_projection = Conv1d( + self.cfg.VOCODER.RESIDUAL_CHANNELS, self.cfg.VOCODER.RESIDUAL_CHANNELS, 1 + ) + self.output_projection = Conv1d(self.cfg.VOCODER.RESIDUAL_CHANNELS, 1, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, audio, diffusion_step, spectrogram): + x = audio.unsqueeze(1) + x = self.input_projection(x) + x = F.relu(x) + + diffusion_step = self.diffusion_embedding(diffusion_step) + spectrogram = self.spectrogram_upsampler(spectrogram) + + skip = None + for layer in self.residual_layers: + x, skip_connection = layer(x, diffusion_step, spectrogram) + skip = skip_connection if skip is None else skip_connection + skip + + x = skip / sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) + return x diff --git a/models/vocoders/dsp/world/world.py b/models/vocoders/dsp/world/world.py new file mode 100644 index 0000000000000000000000000000000000000000..59f28e8e896f883fe6ce243dfb7f254e78fd09c6 --- /dev/null +++ b/models/vocoders/dsp/world/world.py @@ -0,0 +1,183 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# 1. Extract WORLD features including F0, AP, SP +# 2. 
Transform between SP and MCEP +import torchaudio +import pyworld as pw +import numpy as np +import torch +import diffsptk +import os +from tqdm import tqdm +import pickle +import json +import re +import torchaudio + +from cuhkszsvc.configs.config_parse import get_wav_path, get_wav_file_path +from utils.io import has_existed + + +def get_mcep_params(fs): + """Hyperparameters of transformation between SP and MCEP + + Reference: + https://github.com/CSTR-Edinburgh/merlin/blob/master/misc/scripts/vocoder/world_v2/copy_synthesis.sh + + """ + if fs in [44100, 48000]: + fft_size = 2048 + alpha = 0.77 + if fs in [16000]: + fft_size = 1024 + alpha = 0.58 + return fft_size, alpha + + +def extract_world_features(wave_file, fs, frameshift): + # waveform: (1, seq) + waveform, sample_rate = torchaudio.load(wave_file) + if sample_rate != fs: + waveform = torchaudio.functional.resample( + waveform, orig_freq=sample_rate, new_freq=fs + ) + # x: (seq,) + x = np.array(torch.clamp(waveform[0], -1.0, 1.0), dtype=np.double) + + _f0, t = pw.dio(x, fs, frame_period=frameshift) # raw pitch extractor + f0 = pw.stonemask(x, _f0, t, fs) # pitch refinement + sp = pw.cheaptrick(x, f0, t, fs) # extract smoothed spectrogram + ap = pw.d4c(x, f0, t, fs) # extract aperiodicity + + return f0, sp, ap, fs + + +def sp2mcep(x, mcsize, fs): + fft_size, alpha = get_mcep_params(fs) + x = torch.as_tensor(x, dtype=torch.float) + + tmp = diffsptk.ScalarOperation("SquareRoot")(x) + tmp = diffsptk.ScalarOperation("Multiplication", 32768.0)(tmp) + mgc = diffsptk.MelCepstralAnalysis( + cep_order=mcsize - 1, fft_length=fft_size, alpha=alpha, n_iter=1 + )(tmp) + return mgc.numpy() + + +def mcep2sp(x, mcsize, fs): + fft_size, alpha = get_mcep_params(fs) + x = torch.as_tensor(x, dtype=torch.float) + + tmp = diffsptk.MelGeneralizedCepstrumToSpectrum( + alpha=alpha, + cep_order=mcsize - 1, + fft_length=fft_size, + )(x) + tmp = diffsptk.ScalarOperation("Division", 32768.0)(tmp) + sp = diffsptk.ScalarOperation("Power", 2)(tmp) + return sp.double().numpy() + + +def extract_mcep_features_of_dataset( + output_path, dataset_path, dataset, mcsize, fs, frameshift, splits=None +): + output_dir = os.path.join(output_path, dataset, "mcep/{}".format(fs)) + + if not splits: + splits = ["train", "test"] if dataset != "m4singer" else ["test"] + + for dataset_type in splits: + print("-" * 20) + print("Dataset: {}, {}".format(dataset, dataset_type)) + + output_file = os.path.join(output_dir, "{}.pkl".format(dataset_type)) + if has_existed(output_file): + continue + + # Extract SP features + print("\nExtracting SP featuers...") + sp_features = get_world_features_of_dataset( + output_path, dataset_path, dataset, dataset_type, fs, frameshift + ) + + # SP to MCEP + print("\nTransform SP to MCEP...") + mcep_features = [sp2mcep(sp, mcsize=mcsize, fs=fs) for sp in tqdm(sp_features)] + + # Save + os.makedirs(output_dir, exist_ok=True) + with open(output_file, "wb") as f: + pickle.dump(mcep_features, f) + + +def get_world_features_of_dataset( + output_path, + dataset_path, + dataset, + dataset_type, + fs, + frameshift, + save_sp_feature=False, +): + data_dir = os.path.join(output_path, dataset) + wave_dir = get_wav_path(dataset_path, dataset) + + # Dataset + dataset_file = os.path.join(data_dir, "{}.json".format(dataset_type)) + if not os.path.exists(dataset_file): + print("File {} has not existed.".format(dataset_file)) + return None + + with open(dataset_file, "r") as f: + datasets = json.load(f) + + # Save dir + f0_dir = os.path.join(output_path, dataset, "f0") + 
os.makedirs(f0_dir, exist_ok=True) + + # Extract + f0_features = [] + sp_features = [] + for utt in tqdm(datasets): + wave_file = get_wav_file_path(dataset, wave_dir, utt) + f0, sp, _, _ = extract_world_features(wave_file, fs, frameshift) + + sp_features.append(sp) + f0_features.append(f0) + + # Save sp + if save_sp_feature: + sp_dir = os.path.join(output_path, dataset, "sp") + os.makedirs(sp_dir, exist_ok=True) + with open(os.path.join(sp_dir, "{}.pkl".format(dataset_type)), "wb") as f: + pickle.dump(sp_features, f) + + # F0 statistics + f0_statistics_file = os.path.join(f0_dir, "{}_f0.pkl".format(dataset_type)) + f0_statistics(f0_features, f0_statistics_file) + + return sp_features + + +def f0_statistics(f0_features, path): + print("\nF0 statistics...") + + total_f0 = [] + for f0 in tqdm(f0_features): + total_f0 += [f for f in f0 if f != 0] + + mean = sum(total_f0) / len(total_f0) + print("Min = {}, Max = {}, Mean = {}".format(min(total_f0), max(total_f0), mean)) + + with open(path, "wb") as f: + pickle.dump([mean, total_f0], f) + + +def world_synthesis(f0, sp, ap, fs, frameshift): + y = pw.synthesize( + f0, sp, ap, fs, frame_period=frameshift + ) # synthesize an utterance using the parameters + return y diff --git a/models/vocoders/flow/flow_vocoder_dataset.py b/models/vocoders/flow/flow_vocoder_dataset.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/flow/flow_vocoder_inference.py b/models/vocoders/flow/flow_vocoder_inference.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/flow/flow_vocoder_trainer.py b/models/vocoders/flow/flow_vocoder_trainer.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/flow/waveglow/waveglow.py b/models/vocoders/flow/waveglow/waveglow.py new file mode 100644 index 0000000000000000000000000000000000000000..13e2a1bf8f5e3c3d47a031ceec87e4ff111cd5fe --- /dev/null +++ b/models/vocoders/flow/waveglow/waveglow.py @@ -0,0 +1,249 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.autograd import Variable +import torch.nn.functional as F + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +class Invertible1x1Conv(torch.nn.Module): + """ + The layer outputs both the convolution, and the log determinant + of its weight matrix. 
If reverse=True it does convolution with + inverse + """ + + def __init__(self, c): + super(Invertible1x1Conv, self).__init__() + self.conv = torch.nn.Conv1d( + c, c, kernel_size=1, stride=1, padding=0, bias=False + ) + + # Sample a random orthonormal matrix to initialize weights + W = torch.linalg.qr(torch.FloatTensor(c, c).normal_())[0] + + # Ensure determinant is 1.0 not -1.0 + if torch.det(W) < 0: + W[:, 0] = -1 * W[:, 0] + W = W.view(c, c, 1) + self.conv.weight.data = W + + def forward(self, z, reverse=False): + # shape + batch_size, group_size, n_of_groups = z.size() + + W = self.conv.weight.squeeze() + + if reverse: + if not hasattr(self, "W_inverse"): + # Reverse computation + W_inverse = W.float().inverse() + W_inverse = Variable(W_inverse[..., None]) + if z.type() == "torch.cuda.HalfTensor": + W_inverse = W_inverse.half() + self.W_inverse = W_inverse + z = F.conv1d(z, self.W_inverse, bias=None, stride=1, padding=0) + return z + else: + # Forward computation + log_det_W = batch_size * n_of_groups * torch.logdet(W) + z = self.conv(z) + return z, log_det_W + + +class WN(torch.nn.Module): + """ + This is the WaveNet like layer for the affine coupling. The primary difference + from WaveNet is the convolutions need not be causal. There is also no dilation + size reset. The dilation only doubles on each layer + """ + + def __init__( + self, n_in_channels, n_mel_channels, n_layers, n_channels, kernel_size + ): + super(WN, self).__init__() + assert kernel_size % 2 == 1 + assert n_channels % 2 == 0 + self.n_layers = n_layers + self.n_channels = n_channels + self.in_layers = torch.nn.ModuleList() + self.res_skip_layers = torch.nn.ModuleList() + + start = torch.nn.Conv1d(n_in_channels, n_channels, 1) + start = torch.nn.utils.weight_norm(start, name="weight") + self.start = start + + # Initializing last layer to 0 makes the affine coupling layers + # do nothing at first. 
This helps with training stability + end = torch.nn.Conv1d(n_channels, 2 * n_in_channels, 1) + end.weight.data.zero_() + end.bias.data.zero_() + self.end = end + + cond_layer = torch.nn.Conv1d(n_mel_channels, 2 * n_channels * n_layers, 1) + self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight") + + for i in range(n_layers): + dilation = 2**i + padding = int((kernel_size * dilation - dilation) / 2) + in_layer = torch.nn.Conv1d( + n_channels, + 2 * n_channels, + kernel_size, + dilation=dilation, + padding=padding, + ) + in_layer = torch.nn.utils.weight_norm(in_layer, name="weight") + self.in_layers.append(in_layer) + + # last one is not necessary + if i < n_layers - 1: + res_skip_channels = 2 * n_channels + else: + res_skip_channels = n_channels + res_skip_layer = torch.nn.Conv1d(n_channels, res_skip_channels, 1) + res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight") + self.res_skip_layers.append(res_skip_layer) + + def forward(self, forward_input): + audio, spect = forward_input + audio = self.start(audio) + output = torch.zeros_like(audio) + n_channels_tensor = torch.IntTensor([self.n_channels]) + + spect = self.cond_layer(spect) + + for i in range(self.n_layers): + spect_offset = i * 2 * self.n_channels + acts = fused_add_tanh_sigmoid_multiply( + self.in_layers[i](audio), + spect[:, spect_offset : spect_offset + 2 * self.n_channels, :], + n_channels_tensor, + ) + + res_skip_acts = self.res_skip_layers[i](acts) + if i < self.n_layers - 1: + audio = audio + res_skip_acts[:, : self.n_channels, :] + output = output + res_skip_acts[:, self.n_channels :, :] + else: + output = output + res_skip_acts + + return self.end(output) + + +class WaveGlow(torch.nn.Module): + def __init__(self, cfg): + super(WaveGlow, self).__init__() + + self.cfg = cfg + + self.upsample = torch.nn.ConvTranspose1d( + self.cfg.VOCODER.INPUT_DIM, + self.cfg.VOCODER.INPUT_DIM, + 1024, + stride=256, + ) + assert self.cfg.VOCODER.N_GROUP % 2 == 0 + self.n_flows = self.cfg.VOCODER.N_FLOWS + self.n_group = self.cfg.VOCODER.N_GROUP + self.n_early_every = self.cfg.VOCODER.N_EARLY_EVERY + self.n_early_size = self.cfg.VOCODER.N_EARLY_SIZE + self.WN = torch.nn.ModuleList() + self.convinv = torch.nn.ModuleList() + + n_half = int(self.cfg.VOCODER.N_GROUP / 2) + + # Set up layers with the right sizes based on how many dimensions + # have been output already + n_remaining_channels = self.cfg.VOCODER.N_GROUP + for k in range(self.cfg.VOCODER.N_FLOWS): + if k % self.n_early_every == 0 and k > 0: + n_half = n_half - int(self.n_early_size / 2) + n_remaining_channels = n_remaining_channels - self.n_early_size + self.convinv.append(Invertible1x1Conv(n_remaining_channels)) + self.WN.append( + WN( + n_half, + self.cfg.VOCODER.INPUT_DIM * self.cfg.VOCODER.N_GROUP, + self.cfg.VOCODER.N_LAYERS, + self.cfg.VOCODER.N_CHANNELS, + self.cfg.VOCODER.KERNEL_SIZE, + ) + ) + self.n_remaining_channels = n_remaining_channels # Useful during inference + + def forward(self, forward_input): + """ + forward_input[0] = mel_spectrogram: batch x n_mel_channels x frames + forward_input[1] = audio: batch x time + """ + spect, audio = forward_input + + # Upsample spectrogram to size of audio + spect = self.upsample(spect) + assert spect.size(2) >= audio.size(1) + if spect.size(2) > audio.size(1): + spect = spect[:, :, : audio.size(1)] + + spect = spect.unfold(2, self.n_group, self.n_group).permute(0, 2, 1, 3) + spect = ( + spect.contiguous().view(spect.size(0), spect.size(1), -1).permute(0, 2, 1) + ) + + audio = audio.unfold(1, 
self.n_group, self.n_group).permute(0, 2, 1) + output_audio = [] + log_s_list = [] + log_det_W_list = [] + + for k in range(self.n_flows): + if k % self.n_early_every == 0 and k > 0: + output_audio.append(audio[:, : self.n_early_size, :]) + audio = audio[:, self.n_early_size :, :] + + audio, log_det_W = self.convinv[k](audio) + log_det_W_list.append(log_det_W) + + n_half = int(audio.size(1) / 2) + audio_0 = audio[:, :n_half, :] + audio_1 = audio[:, n_half:, :] + + output = self.WN[k]((audio_0, spect)) + log_s = output[:, n_half:, :] + b = output[:, :n_half, :] + audio_1 = torch.exp(log_s) * audio_1 + b + log_s_list.append(log_s) + + audio = torch.cat([audio_0, audio_1], 1) + + output_audio.append(audio) + return torch.cat(output_audio, 1), log_s_list, log_det_W_list + + @staticmethod + def remove_weightnorm(model): + waveglow = model + for WN in waveglow.WN: + WN.start = torch.nn.utils.remove_weight_norm(WN.start) + WN.in_layers = remove(WN.in_layers) + WN.cond_layer = torch.nn.utils.remove_weight_norm(WN.cond_layer) + WN.res_skip_layers = remove(WN.res_skip_layers) + return waveglow + + +def remove(conv_list): + new_conv_list = torch.nn.ModuleList() + for old_conv in conv_list: + old_conv = torch.nn.utils.remove_weight_norm(old_conv) + new_conv_list.append(old_conv) + return new_conv_list diff --git a/models/vocoders/gan/discriminator/mpd.py b/models/vocoders/gan/discriminator/mpd.py new file mode 100644 index 0000000000000000000000000000000000000000..f28711d18847a106a998cab90871fe6303a4fd08 --- /dev/null +++ b/models/vocoders/gan/discriminator/mpd.py @@ -0,0 +1,269 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv2d, Conv1d +from torch.nn.utils import weight_norm, spectral_norm +from torch import nn +from modules.vocoder_blocks import * + +LRELU_SLOPE = 0.1 + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, cfg, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + self.d_mult = cfg.model.mpd.discriminator_channel_mult_factor + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + int(32 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(32 * self.d_mult), + int(128 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(128 * self.d_mult), + int(512 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(512 * self.d_mult), + int(1024 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(get_padding(5, 1), 0), + ) + ), + norm_f( + Conv2d( + int(1024 * self.d_mult), + int(1024 * self.d_mult), + (kernel_size, 1), + (stride, 1), + padding=(2, 0), + ) + ), + ] + ) + self.conv_post = norm_f( + Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)) + ) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + + x = 
self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, cfg): + super(MultiPeriodDiscriminator, self).__init__() + self.mpd_reshapes = cfg.model.mpd.mpd_reshapes + print("mpd_reshapes: {}".format(self.mpd_reshapes)) + discriminators = [ + DiscriminatorP(cfg, rs, use_spectral_norm=cfg.model.mpd.use_spectral_norm) + for rs in self.mpd_reshapes + ] + self.discriminators = nn.ModuleList(discriminators) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +# TODO: merge with DiscriminatorP (lmxue, yicheng) +class DiscriminatorP_vits(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False): + super(DiscriminatorP_vits, self).__init__() + self.period = period + self.use_spectral_norm = use_spectral_norm + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f( + Conv2d( + 1, + 32, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 32, + 128, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 128, + 512, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 512, + 1024, + (kernel_size, 1), + (stride, 1), + padding=(get_padding(kernel_size, 1), 0), + ) + ), + norm_f( + Conv2d( + 1024, + 1024, + (kernel_size, 1), + 1, + padding=(get_padding(kernel_size, 1), 0), + ) + ), + ] + ) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 16, 15, 1, padding=7)), + norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)), + norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +# TODO: merge with MultiPeriodDiscriminator (lmxue, yicheng) +class MultiPeriodDiscriminator_vits(torch.nn.Module): + def __init__(self, use_spectral_norm=False): + super(MultiPeriodDiscriminator_vits, self).__init__() + periods = [2, 3, 5, 7, 11] + + discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)] + discs = discs + [ + DiscriminatorP_vits(i, 
use_spectral_norm=use_spectral_norm) for i in periods + ] + self.discriminators = nn.ModuleList(discs) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + y_d_gs.append(y_d_g) + fmap_rs.append(fmap_r) + fmap_gs.append(fmap_g) + + outputs = { + "y_d_hat_r": y_d_rs, + "y_d_hat_g": y_d_gs, + "fmap_rs": fmap_rs, + "fmap_gs": fmap_gs, + } + + return outputs diff --git a/models/vocoders/gan/discriminator/mrd.py b/models/vocoders/gan/discriminator/mrd.py new file mode 100644 index 0000000000000000000000000000000000000000..38ee80bfbf82b6aa63c80dbc2c6ffed8cb50a924 --- /dev/null +++ b/models/vocoders/gan/discriminator/mrd.py @@ -0,0 +1,160 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm +from torch import nn + +LRELU_SLOPE = 0.1 + + +# This code is a refined MRD adopted from BigVGAN under the MIT License +# https://github.com/NVIDIA/BigVGAN + + +class DiscriminatorR(nn.Module): + def __init__(self, cfg, resolution): + super().__init__() + + self.resolution = resolution + assert ( + len(self.resolution) == 3 + ), "MRD layer requires list with len=3, got {}".format(self.resolution) + self.lrelu_slope = LRELU_SLOPE + + norm_f = ( + weight_norm if cfg.model.mrd.use_spectral_norm == False else spectral_norm + ) + if cfg.model.mrd.mrd_override: + print( + "INFO: overriding MRD use_spectral_norm as {}".format( + cfg.model.mrd.mrd_use_spectral_norm + ) + ) + norm_f = ( + weight_norm + if cfg.model.mrd.mrd_use_spectral_norm == False + else spectral_norm + ) + self.d_mult = cfg.model.mrd.discriminator_channel_mult_factor + if cfg.model.mrd.mrd_override: + print( + "INFO: overriding mrd channel multiplier as {}".format( + cfg.model.mrd.mrd_channel_mult + ) + ) + self.d_mult = cfg.model.mrd.mrd_channel_mult + + self.convs = nn.ModuleList( + [ + norm_f(nn.Conv2d(1, int(32 * self.d_mult), (3, 9), padding=(1, 4))), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 9), + stride=(1, 2), + padding=(1, 4), + ) + ), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 9), + stride=(1, 2), + padding=(1, 4), + ) + ), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 9), + stride=(1, 2), + padding=(1, 4), + ) + ), + norm_f( + nn.Conv2d( + int(32 * self.d_mult), + int(32 * self.d_mult), + (3, 3), + padding=(1, 1), + ) + ), + ] + ) + self.conv_post = norm_f( + nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)) + ) + + def forward(self, x): + fmap = [] + + x = self.spectrogram(x) + x = x.unsqueeze(1) + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, self.lrelu_slope) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + def spectrogram(self, x): + n_fft, hop_length, win_length = self.resolution + x = F.pad( + x, + (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)), + mode="reflect", + ) + x = x.squeeze(1) + x = torch.stft( + x, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + center=False, + return_complex=True, + ) + x = torch.view_as_real(x) # [B, F, TT, 2] + mag = torch.norm(x, p=2, dim=-1) # [B, F, TT] + + return mag + + 
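For orientation, each DiscriminatorR above scores one STFT resolution: `resolution` is an `(n_fft, hop_length, win_length)` triple, `spectrogram()` returns the magnitude spectrogram, and the 2-D convolution stack turns it into flattened logits plus per-layer feature maps. A minimal sketch of exercising a single resolution against this module; the config namespace and all numbers below are illustrative stand-ins, not the project defaults:

import torch
from types import SimpleNamespace

# Illustrative stand-in for the fields DiscriminatorR reads from cfg.model.mrd
cfg = SimpleNamespace(
    model=SimpleNamespace(
        mrd=SimpleNamespace(
            use_spectral_norm=False,
            mrd_override=False,
            discriminator_channel_mult_factor=1,
        )
    )
)

disc = DiscriminatorR(cfg, resolution=[1024, 256, 1024])  # (n_fft, hop, win)
wav = torch.randn(2, 1, 16000)  # (batch, channel, samples)
logits, fmap = disc(wav)
# logits: flattened scores per example; fmap: one feature map per conv layer (6 in total)
print(logits.shape, len(fmap))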
+class MultiResolutionDiscriminator(nn.Module): + def __init__(self, cfg, debug=False): + super().__init__() + self.resolutions = cfg.model.mrd.resolutions + assert ( + len(self.resolutions) == 3 + ), "MRD requires list of list with len=3, each element having a list with len=3. got {}".format( + self.resolutions + ) + self.discriminators = nn.ModuleList( + [DiscriminatorR(cfg, resolution) for resolution in self.resolutions] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(x=y) + y_d_g, fmap_g = d(x=y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/models/vocoders/gan/discriminator/msd.py b/models/vocoders/gan/discriminator/msd.py new file mode 100644 index 0000000000000000000000000000000000000000..4c1556aea581878dcbe10f7a3bdebc33a4972e2c --- /dev/null +++ b/models/vocoders/gan/discriminator/msd.py @@ -0,0 +1,88 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, AvgPool1d +from torch.nn.utils import weight_norm, spectral_norm +from torch import nn +from modules.vocoder_blocks import * + + +LRELU_SLOPE = 0.1 + + +class DiscriminatorS(nn.Module): + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + + self.convs = nn.ModuleList( + [ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ] + ) + + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(nn.Module): + def __init__(self, cfg): + super(MultiScaleDiscriminator, self).__init__() + + self.cfg = cfg + + self.discriminators = nn.ModuleList( + [ + DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ] + ) + + self.meanpools = nn.ModuleList( + [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/models/vocoders/gan/discriminator/mssbcqtd.py b/models/vocoders/gan/discriminator/mssbcqtd.py new file mode 100644 index 0000000000000000000000000000000000000000..213de5441754944a360707e99a3734ad035d9077 --- /dev/null +++ b/models/vocoders/gan/discriminator/mssbcqtd.py @@ -0,0 +1,182 @@ +# Copyright (c) 2023 Amphion. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch import nn +from modules.vocoder_blocks import * + +from einops import rearrange +import torchaudio.transforms as T + +from nnAudio import features + +LRELU_SLOPE = 0.1 + + +class DiscriminatorCQT(nn.Module): + def __init__(self, cfg, hop_length, n_octaves, bins_per_octave): + super(DiscriminatorCQT, self).__init__() + self.cfg = cfg + + self.filters = cfg.model.mssbcqtd.filters + self.max_filters = cfg.model.mssbcqtd.max_filters + self.filters_scale = cfg.model.mssbcqtd.filters_scale + self.kernel_size = (3, 9) + self.dilations = cfg.model.mssbcqtd.dilations + self.stride = (1, 2) + + self.in_channels = cfg.model.mssbcqtd.in_channels + self.out_channels = cfg.model.mssbcqtd.out_channels + self.fs = cfg.preprocess.sample_rate + self.hop_length = hop_length + self.n_octaves = n_octaves + self.bins_per_octave = bins_per_octave + + self.cqt_transform = features.cqt.CQT2010v2( + sr=self.fs * 2, + hop_length=self.hop_length, + n_bins=self.bins_per_octave * self.n_octaves, + bins_per_octave=self.bins_per_octave, + output_format="Complex", + pad_mode="constant", + ) + + self.conv_pres = nn.ModuleList() + for i in range(self.n_octaves): + self.conv_pres.append( + NormConv2d( + self.in_channels * 2, + self.in_channels * 2, + kernel_size=self.kernel_size, + padding=get_2d_padding(self.kernel_size), + ) + ) + + self.convs = nn.ModuleList() + + self.convs.append( + NormConv2d( + self.in_channels * 2, + self.filters, + kernel_size=self.kernel_size, + padding=get_2d_padding(self.kernel_size), + ) + ) + + in_chs = min(self.filters_scale * self.filters, self.max_filters) + for i, dilation in enumerate(self.dilations): + out_chs = min( + (self.filters_scale ** (i + 1)) * self.filters, self.max_filters + ) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=self.kernel_size, + stride=self.stride, + dilation=(dilation, 1), + padding=get_2d_padding(self.kernel_size, (dilation, 1)), + norm="weight_norm", + ) + ) + in_chs = out_chs + out_chs = min( + (self.filters_scale ** (len(self.dilations) + 1)) * self.filters, + self.max_filters, + ) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=(self.kernel_size[0], self.kernel_size[0]), + padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])), + norm="weight_norm", + ) + ) + + self.conv_post = NormConv2d( + out_chs, + self.out_channels, + kernel_size=(self.kernel_size[0], self.kernel_size[0]), + padding=get_2d_padding((self.kernel_size[0], self.kernel_size[0])), + norm="weight_norm", + ) + + self.activation = torch.nn.LeakyReLU(negative_slope=LRELU_SLOPE) + self.resample = T.Resample(orig_freq=self.fs, new_freq=self.fs * 2) + + def forward(self, x): + fmap = [] + + x = self.resample(x) + + z = self.cqt_transform(x) + + z_amplitude = z[:, :, :, 0].unsqueeze(1) + z_phase = z[:, :, :, 1].unsqueeze(1) + + z = torch.cat([z_amplitude, z_phase], dim=1) + z = rearrange(z, "b c w t -> b c t w") + + latent_z = [] + for i in range(self.n_octaves): + latent_z.append( + self.conv_pres[i]( + z[ + :, + :, + :, + i * self.bins_per_octave : (i + 1) * self.bins_per_octave, + ] + ) + ) + latent_z = torch.cat(latent_z, dim=-1) + + for i, l in enumerate(self.convs): + latent_z = l(latent_z) + + latent_z = self.activation(latent_z) + fmap.append(latent_z) + + latent_z = self.conv_post(latent_z) + + return latent_z, fmap + + +class 
MultiScaleSubbandCQTDiscriminator(nn.Module): + def __init__(self, cfg): + super(MultiScaleSubbandCQTDiscriminator, self).__init__() + + self.cfg = cfg + + self.discriminators = nn.ModuleList( + [ + DiscriminatorCQT( + cfg, + hop_length=cfg.model.mssbcqtd.hop_lengths[i], + n_octaves=cfg.model.mssbcqtd.n_octaves[i], + bins_per_octave=cfg.model.mssbcqtd.bins_per_octaves[i], + ) + for i in range(len(cfg.model.mssbcqtd.hop_lengths)) + ] + ) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for disc in self.discriminators: + y_d_r, fmap_r = disc(y) + y_d_g, fmap_g = disc(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/models/vocoders/gan/discriminator/msstftd.py b/models/vocoders/gan/discriminator/msstftd.py new file mode 100644 index 0000000000000000000000000000000000000000..83dedb78848d2d73ac667e7a191f05de1ed7bf21 --- /dev/null +++ b/models/vocoders/gan/discriminator/msstftd.py @@ -0,0 +1,226 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is adopted from META's Encodec under MIT License +# https://github.com/facebookresearch/encodec + +"""MS-STFT discriminator, provided here for reference.""" + +import typing as tp + +import torchaudio +import torch +from torch import nn +from einops import rearrange + +from modules.vocoder_blocks import * + + +FeatureMapType = tp.List[torch.Tensor] +LogitsType = torch.Tensor +DiscriminatorOutput = tp.Tuple[tp.List[LogitsType], tp.List[FeatureMapType]] + + +def get_2d_padding( + kernel_size: tp.Tuple[int, int], dilation: tp.Tuple[int, int] = (1, 1) +): + return ( + ((kernel_size[0] - 1) * dilation[0]) // 2, + ((kernel_size[1] - 1) * dilation[1]) // 2, + ) + + +class DiscriminatorSTFT(nn.Module): + """STFT sub-discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. Default: 1 + n_fft (int): Size of FFT for each scale. Default: 1024 + hop_length (int): Length of hop between STFT windows for each scale. Default: 256 + kernel_size (tuple of int): Inner Conv2d kernel sizes. Default: ``(3, 9)`` + stride (tuple of int): Inner Conv2d strides. Default: ``(1, 2)`` + dilations (list of int): Inner Conv2d dilation on the time dimension. Default: ``[1, 2, 4]`` + win_length (int): Window size for each scale. Default: 1024 + normalized (bool): Whether to normalize by magnitude after stft. Default: True + norm (str): Normalization method. Default: `'weight_norm'` + activation (str): Activation function. Default: `'LeakyReLU'` + activation_params (dict): Parameters to provide to the activation function. + growth (int): Growth factor for the filters. 
Default: 1 + """ + + def __init__( + self, + filters: int, + in_channels: int = 1, + out_channels: int = 1, + n_fft: int = 1024, + hop_length: int = 256, + win_length: int = 1024, + max_filters: int = 1024, + filters_scale: int = 1, + kernel_size: tp.Tuple[int, int] = (3, 9), + dilations: tp.List = [1, 2, 4], + stride: tp.Tuple[int, int] = (1, 2), + normalized: bool = True, + norm: str = "weight_norm", + activation: str = "LeakyReLU", + activation_params: dict = {"negative_slope": 0.2}, + ): + super().__init__() + assert len(kernel_size) == 2 + assert len(stride) == 2 + self.filters = filters + self.in_channels = in_channels + self.out_channels = out_channels + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + self.normalized = normalized + self.activation = getattr(torch.nn, activation)(**activation_params) + self.spec_transform = torchaudio.transforms.Spectrogram( + n_fft=self.n_fft, + hop_length=self.hop_length, + win_length=self.win_length, + window_fn=torch.hann_window, + normalized=self.normalized, + center=False, + pad_mode=None, + power=None, + ) + spec_channels = 2 * self.in_channels + self.convs = nn.ModuleList() + self.convs.append( + NormConv2d( + spec_channels, + self.filters, + kernel_size=kernel_size, + padding=get_2d_padding(kernel_size), + ) + ) + in_chs = min(filters_scale * self.filters, max_filters) + for i, dilation in enumerate(dilations): + out_chs = min((filters_scale ** (i + 1)) * self.filters, max_filters) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=(dilation, 1), + padding=get_2d_padding(kernel_size, (dilation, 1)), + norm=norm, + ) + ) + in_chs = out_chs + out_chs = min( + (filters_scale ** (len(dilations) + 1)) * self.filters, max_filters + ) + self.convs.append( + NormConv2d( + in_chs, + out_chs, + kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm, + ) + ) + self.conv_post = NormConv2d( + out_chs, + self.out_channels, + kernel_size=(kernel_size[0], kernel_size[0]), + padding=get_2d_padding((kernel_size[0], kernel_size[0])), + norm=norm, + ) + + def forward(self, x: torch.Tensor): + """Discriminator STFT Module is the sub module of MultiScaleSTFTDiscriminator. + + Args: + x (torch.Tensor): input tensor of shape [B, 1, Time] + + Returns: + z: z is the output of the last convolutional layer of shape + fmap: fmap is the list of feature maps of every convolutional layer of shape + """ + fmap = [] + z = self.spec_transform(x) # [B, 2, Freq, Frames, 2] + z = torch.cat([z.real, z.imag], dim=1) + z = rearrange(z, "b c w t -> b c t w") + for i, layer in enumerate(self.convs): + z = layer(z) + + z = self.activation(z) + fmap.append(z) + z = self.conv_post(z) + return z, fmap + + +class MultiScaleSTFTDiscriminator(nn.Module): + """Multi-Scale STFT (MS-STFT) discriminator. + Args: + filters (int): Number of filters in convolutions + in_channels (int): Number of input channels. Default: 1 + out_channels (int): Number of output channels. 
Default: 1 + n_ffts (Sequence[int]): Size of FFT for each scale + hop_lengths (Sequence[int]): Length of hop between STFT windows for each scale + win_lengths (Sequence[int]): Window size for each scale + **kwargs: additional args for STFTDiscriminator + """ + + def __init__( + self, + cfg, + in_channels: int = 1, + out_channels: int = 1, + n_ffts: tp.List[int] = [1024, 2048, 512], + hop_lengths: tp.List[int] = [256, 512, 256], + win_lengths: tp.List[int] = [1024, 2048, 512], + **kwargs, + ): + self.cfg = cfg + super().__init__() + assert len(n_ffts) == len(hop_lengths) == len(win_lengths) + self.discriminators = nn.ModuleList( + [ + DiscriminatorSTFT( + filters=self.cfg.model.msstftd.filters, + in_channels=in_channels, + out_channels=out_channels, + n_fft=n_ffts[i], + win_length=win_lengths[i], + hop_length=hop_lengths[i], + **kwargs, + ) + for i in range(len(n_ffts)) + ] + ) + self.num_discriminators = len(self.discriminators) + + def forward(self, y, y_hat) -> DiscriminatorOutput: + """Multi-Scale STFT (MS-STFT) discriminator. + + Args: + x (torch.Tensor): input waveform + + Returns: + logits: list of every discriminator's output + fmaps: list of every discriminator's feature maps, + each feature maps is a list of Discriminator STFT's every layer + """ + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + + for disc in self.discriminators: + y_d_r, fmap_r = disc(y) + y_d_g, fmap_g = disc(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs diff --git a/models/vocoders/gan/gan_vocoder_dataset.py b/models/vocoders/gan/gan_vocoder_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5bf87c371647a44fb5bcae33701eda65616e5fd7 --- /dev/null +++ b/models/vocoders/gan/gan_vocoder_dataset.py @@ -0,0 +1,205 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
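The GANVocoderDataset defined below crops (or zero-pads) every feature to a fixed window of `cfg.preprocess.cut_mel_frame` mel frames and slices the waveform over the same region scaled by `cfg.preprocess.hop_size`, so the generator and discriminators always see aligned, equal-length segments. A minimal sketch of that frame-to-sample bookkeeping, using illustrative numbers rather than the project defaults:

import numpy as np

cut_mel_frame, hop_size = 32, 256                 # illustrative values
mel = np.random.randn(80, 100)                    # (n_mel, frames)
audio = np.random.randn(100 * hop_size)           # matching waveform

start = np.random.randint(0, mel.shape[-1] - cut_mel_frame)
end = start + cut_mel_frame
mel_crop = mel[:, start:end]                            # (80, 32)
audio_crop = audio[start * hop_size : end * hop_size]   # (32 * 256,)
assert audio_crop.shape[-1] == mel_crop.shape[-1] * hop_size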
+ +import torch +import random + +import numpy as np + +from torch.nn import functional as F + +from torch.nn.utils.rnn import pad_sequence +from utils.data_utils import * +from models.vocoders.vocoder_dataset import VocoderDataset + + +class GANVocoderDataset(VocoderDataset): + def __init__(self, cfg, dataset, is_valid=False): + """ + Args: + cfg: config + dataset: dataset name + is_valid: whether to use train or valid dataset + """ + super().__init__(cfg, dataset, is_valid) + + eval_index = random.randint(0, len(self.metadata) - 1) + eval_utt_info = self.metadata[eval_index] + eval_utt = "{}_{}".format(eval_utt_info["Dataset"], eval_utt_info["Uid"]) + self.eval_audio = np.load(self.utt2audio_path[eval_utt]) + if cfg.preprocess.use_mel: + self.eval_mel = np.load(self.utt2mel_path[eval_utt]) + if cfg.preprocess.use_frame_pitch: + self.eval_pitch = np.load(self.utt2frame_pitch_path[eval_utt]) + + def __getitem__(self, index): + utt_info = self.metadata[index] + + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + single_feature = dict() + + if self.cfg.preprocess.use_mel: + mel = np.load(self.utt2mel_path[utt]) + assert mel.shape[0] == self.cfg.preprocess.n_mel + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = mel.shape[1] + + if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame: + mel = np.pad( + mel, + ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])), + mode="constant", + ) + else: + if "start" not in single_feature.keys(): + start = random.randint( + 0, mel.shape[-1] - self.cfg.preprocess.cut_mel_frame + ) + end = start + self.cfg.preprocess.cut_mel_frame + single_feature["start"] = start + single_feature["end"] = end + mel = mel[:, single_feature["start"] : single_feature["end"]] + single_feature["mel"] = mel + + if self.cfg.preprocess.use_frame_pitch: + frame_pitch = np.load(self.utt2frame_pitch_path[utt]) + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = len(frame_pitch) + aligned_frame_pitch = align_length( + frame_pitch, single_feature["target_len"] + ) + + if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame: + aligned_frame_pitch = np.pad( + aligned_frame_pitch, + ( + ( + 0, + self.cfg.preprocess.cut_mel_frame + * self.cfg.preprocess.hop_size + - audio.shape[-1], + ) + ), + mode="constant", + ) + else: + if "start" not in single_feature.keys(): + start = random.randint( + 0, + aligned_frame_pitch.shape[-1] + - self.cfg.preprocess.cut_mel_frame, + ) + end = start + self.cfg.preprocess.cut_mel_frame + single_feature["start"] = start + single_feature["end"] = end + aligned_frame_pitch = aligned_frame_pitch[ + single_feature["start"] : single_feature["end"] + ] + single_feature["frame_pitch"] = aligned_frame_pitch + + if self.cfg.preprocess.use_audio: + audio = np.load(self.utt2audio_path[utt]) + + assert "target_len" in single_feature.keys() + + if ( + audio.shape[-1] + <= self.cfg.preprocess.cut_mel_frame * self.cfg.preprocess.hop_size + ): + audio = np.pad( + audio, + ( + ( + 0, + self.cfg.preprocess.cut_mel_frame + * self.cfg.preprocess.hop_size + - audio.shape[-1], + ) + ), + mode="constant", + ) + else: + if "start" not in single_feature.keys(): + audio = audio[ + 0 : self.cfg.preprocess.cut_mel_frame + * self.cfg.preprocess.hop_size + ] + else: + audio = audio[ + single_feature["start"] + * self.cfg.preprocess.hop_size : single_feature["end"] + * self.cfg.preprocess.hop_size, + ] + single_feature["audio"] = audio + + if 
self.cfg.preprocess.use_amplitude_phase: + logamp = np.load(self.utt2logamp_path[utt]) + pha = np.load(self.utt2pha_path[utt]) + rea = np.load(self.utt2rea_path[utt]) + imag = np.load(self.utt2imag_path[utt]) + + assert "target_len" in single_feature.keys() + + if single_feature["target_len"] <= self.cfg.preprocess.cut_mel_frame: + logamp = np.pad( + logamp, + ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])), + mode="constant", + ) + pha = np.pad( + pha, + ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])), + mode="constant", + ) + rea = np.pad( + rea, + ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])), + mode="constant", + ) + imag = np.pad( + imag, + ((0, 0), (0, self.cfg.preprocess.cut_mel_frame - mel.shape[-1])), + mode="constant", + ) + else: + logamp = logamp[:, single_feature["start"] : single_feature["end"]] + pha = pha[:, single_feature["start"] : single_feature["end"]] + rea = rea[:, single_feature["start"] : single_feature["end"]] + imag = imag[:, single_feature["start"] : single_feature["end"]] + single_feature["logamp"] = logamp + single_feature["pha"] = pha + single_feature["rea"] = rea + single_feature["imag"] = imag + + return single_feature + + +class GANVocoderCollator(object): + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + self.cfg = cfg + + def __call__(self, batch): + packed_batch_features = dict() + + # mel: [b, n_mels, frame] + # frame_pitch: [b, frame] + # audios: [b, frame * hop_size] + + for key in batch[0].keys(): + if key in ["target_len", "start", "end"]: + continue + else: + values = [torch.from_numpy(b[key]) for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + + return packed_batch_features diff --git a/models/vocoders/gan/gan_vocoder_inference.py b/models/vocoders/gan/gan_vocoder_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..cb69631662dedf4fc73a29f675f0a4bc361b03ec --- /dev/null +++ b/models/vocoders/gan/gan_vocoder_inference.py @@ -0,0 +1,110 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
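The helpers below (`vocoder_inference` and `synthesis_audios`) run a trained generator on padded batches of mel-spectrograms and trim every synthesized waveform back to `frames * hop_size` samples. A hedged usage sketch, assuming `cfg` and an already-built, checkpoint-loaded generator `vocoder` (both names are placeholders, not objects provided by this file):

import torch

# Placeholder inputs: mel-spectrograms shaped (n_mel, frames)
mels = [torch.randn(100, 123), torch.randn(100, 87)]

# `vocoder` stands for a generator such as HiFiGAN(cfg) or BigVGAN(cfg)
# with weights already loaded and moved to the target device.
audios = synthesis_audios(cfg, vocoder, mels, batch_size=2)
for mel, audio in zip(mels, audios):
    # Each waveform is trimmed back to its own frame count times hop_size.
    assert audio.shape[-1] == mel.shape[-1] * cfg.preprocess.hop_size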
+
+import torch
+
+from utils.util import pad_mels_to_tensors, pad_f0_to_tensors
+
+
+def vocoder_inference(cfg, model, mels, f0s=None, device=None, fast_inference=False):
+    """Run the vocoder on a batch of mel-spectrograms.
+    Args:
+        mels: A tensor of mel-specs with the shape (batch_size, num_mels, frames)
+    Returns:
+        audios: A tensor of audios with the shape (batch_size, seq_len)
+    """
+    model.eval()
+
+    with torch.no_grad():
+        mels = mels.to(device)
+        if f0s is not None:
+            f0s = f0s.to(device)
+
+        if f0s is None and not cfg.preprocess.extract_amplitude_phase:
+            output = model.forward(mels)
+        elif cfg.preprocess.extract_amplitude_phase:
+            (
+                _,
+                _,
+                _,
+                _,
+                output,
+            ) = model.forward(mels)
+        else:
+            output = model.forward(mels, f0s)
+
+    return output.squeeze(1).detach().cpu()
+
+
+def synthesis_audios(cfg, model, mels, f0s=None, batch_size=None, fast_inference=False):
+    """Synthesize audios from a list of mel-spectrograms.
+    Args:
+        mels: A list of mel-specs
+    Returns:
+        audios: A list of audios
+    """
+    # Get the device
+    device = next(model.parameters()).device
+
+    audios = []
+
+    # Pad the given list into tensors
+    mel_batches, mel_frames = pad_mels_to_tensors(mels, batch_size)
+    if f0s is not None:
+        f0_batches = pad_f0_to_tensors(f0s, batch_size)
+
+    if f0s is None:
+        for mel_batch, mel_frame in zip(mel_batches, mel_frames):
+            for i in range(mel_batch.shape[0]):
+                mel = mel_batch[i]
+                frame = mel_frame[i]
+                audio = vocoder_inference(
+                    cfg,
+                    model,
+                    mel.unsqueeze(0),
+                    device=device,
+                    fast_inference=fast_inference,
+                ).squeeze(0)
+
+                # # Apply fade_out to make the sound more natural
+                # fade_out = torch.linspace(
+                #     1, 0, steps=20 * model.cfg.preprocess.hop_size
+                # ).cpu()
+
+                # Trim the padded waveform back to its true length
+                audio_length = frame * model.cfg.preprocess.hop_size
+                audio = audio[:audio_length]
+
+                # audio[-20 * model.cfg.preprocess.hop_size :] *= fade_out
+
+                audios.append(audio)
+    else:
+        for mel_batch, f0_batch, mel_frame in zip(mel_batches, f0_batches, mel_frames):
+            for i in range(mel_batch.shape[0]):
+                mel = mel_batch[i]
+                f0 = f0_batch[i]
+                frame = mel_frame[i]
+                audio = vocoder_inference(
+                    cfg,
+                    model,
+                    mel.unsqueeze(0),
+                    f0s=f0.unsqueeze(0),
+                    device=device,
+                    fast_inference=fast_inference,
+                ).squeeze(0)
+
+                # # Apply fade_out to make the sound more natural
+                # fade_out = torch.linspace(
+                #     1, 0, steps=20 * model.cfg.preprocess.hop_size
+                # ).cpu()
+
+                # Trim the padded waveform back to its true length
+                audio_length = frame * model.cfg.preprocess.hop_size
+                audio = audio[:audio_length]
+
+                # audio[-20 * model.cfg.preprocess.hop_size :] *= fade_out
+
+                audios.append(audio)
+    return audios
diff --git a/models/vocoders/gan/gan_vocoder_trainer.py b/models/vocoders/gan/gan_vocoder_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2fb9c8f03a7de14d0162bfd671d33b76890293a5
--- /dev/null
+++ b/models/vocoders/gan/gan_vocoder_trainer.py
@@ -0,0 +1,1112 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
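The GANVocoderTrainer defined below alternates two optimizers per batch: the discriminators are first updated on (real, detached fake) pairs, then the generator is updated against the adversarial, feature-matching, and reconstruction losses built in `_build_criterion`. A condensed sketch of that alternating update for the hifigan-family branch; `loader`, `generator`, `discriminators` (a dict of the discriminator modules above), and the two optimizers `d_optim`/`g_optim` are assumed to already exist, and the mel/wav reconstruction terms are omitted:

import torch


def lsgan_d(y_r, y_g):
    # Least-squares discriminator loss: real outputs toward 1, fake toward 0
    return sum(torch.mean((1 - r) ** 2) + torch.mean(g**2) for r, g in zip(y_r, y_g))


def lsgan_g(y_g):
    # Least-squares generator loss: fake outputs toward 1
    return sum(torch.mean((1 - g) ** 2) for g in y_g)


def feat_match(f_r, f_g):
    # L1 feature matching over every layer of every sub-discriminator
    return 2 * sum(
        torch.mean(torch.abs(r - g)) for dr, dg in zip(f_r, f_g) for r, g in zip(dr, dg)
    )


for mel, audio_gt in loader:
    audio_pred = generator(mel)

    # 1) Discriminator step on (real, detached fake)
    d_optim.zero_grad()
    d_loss = sum(
        lsgan_d(*disc(audio_gt.unsqueeze(1), audio_pred.detach())[:2])
        for disc in discriminators.values()
    )
    d_loss.backward()
    d_optim.step()

    # 2) Generator step: adversarial + feature matching
    g_optim.zero_grad()
    g_loss = 0.0
    for disc in discriminators.values():
        y_r, y_g, f_r, f_g = disc(audio_gt.unsqueeze(1), audio_pred)
        g_loss = g_loss + lsgan_g(y_g) + feat_match(f_r, f_g)
    g_loss.backward()
    g_optim.step()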
+ +import os +import sys +import time +import torch +import json +import itertools +import accelerate +import torch.distributed as dist +import torch.nn.functional as F +from tqdm import tqdm +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.utils.tensorboard import SummaryWriter + +from torch.optim import AdamW +from torch.optim.lr_scheduler import ExponentialLR + +from librosa.filters import mel as librosa_mel_fn + +from accelerate.logging import get_logger +from pathlib import Path + +from utils.io import save_audio +from utils.data_utils import * +from utils.util import ( + Logger, + ValueWindow, + remove_older_ckpt, + set_all_random_seed, + save_config, +) +from utils.mel import extract_mel_features +from models.vocoders.vocoder_trainer import VocoderTrainer +from models.vocoders.gan.gan_vocoder_dataset import ( + GANVocoderDataset, + GANVocoderCollator, +) + +from models.vocoders.gan.generator.bigvgan import BigVGAN +from models.vocoders.gan.generator.hifigan import HiFiGAN +from models.vocoders.gan.generator.melgan import MelGAN +from models.vocoders.gan.generator.nsfhifigan import NSFHiFiGAN +from models.vocoders.gan.generator.apnet import APNet + +from models.vocoders.gan.discriminator.mpd import MultiPeriodDiscriminator +from models.vocoders.gan.discriminator.mrd import MultiResolutionDiscriminator +from models.vocoders.gan.discriminator.mssbcqtd import MultiScaleSubbandCQTDiscriminator +from models.vocoders.gan.discriminator.msd import MultiScaleDiscriminator +from models.vocoders.gan.discriminator.msstftd import MultiScaleSTFTDiscriminator + +from models.vocoders.gan.gan_vocoder_inference import vocoder_inference + +supported_generators = { + "bigvgan": BigVGAN, + "hifigan": HiFiGAN, + "melgan": MelGAN, + "nsfhifigan": NSFHiFiGAN, + "apnet": APNet, +} + +supported_discriminators = { + "mpd": MultiPeriodDiscriminator, + "msd": MultiScaleDiscriminator, + "mrd": MultiResolutionDiscriminator, + "msstftd": MultiScaleSTFTDiscriminator, + "mssbcqtd": MultiScaleSubbandCQTDiscriminator, +} + + +class GANVocoderTrainer(VocoderTrainer): + def __init__(self, args, cfg): + super().__init__() + + self.args = args + self.cfg = cfg + + cfg.exp_name = args.exp_name + + # Init accelerator + self._init_accelerator() + self.accelerator.wait_for_everyone() + + # Init logger + with self.accelerator.main_process_first(): + self.logger = get_logger(args.exp_name, log_level=args.log_level) + + self.logger.info("=" * 56) + self.logger.info("||\t\t" + "New training process started." 
+ "\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + self.logger.debug(f"Using {args.log_level.upper()} logging level.") + self.logger.info(f"Experiment name: {args.exp_name}") + self.logger.info(f"Experiment directory: {self.exp_dir}") + self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") + if self.accelerator.is_main_process: + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") + + # Init training status + self.batch_count: int = 0 + self.step: int = 0 + self.epoch: int = 0 + + self.max_epoch = ( + self.cfg.train.max_epoch if self.cfg.train.max_epoch > 0 else float("inf") + ) + self.logger.info( + "Max epoch: {}".format( + self.max_epoch if self.max_epoch < float("inf") else "Unlimited" + ) + ) + + # Check potential erorrs + if self.accelerator.is_main_process: + self._check_basic_configs() + self.save_checkpoint_stride = self.cfg.train.save_checkpoint_stride + self.checkpoints_path = [ + [] for _ in range(len(self.save_checkpoint_stride)) + ] + self.run_eval = self.cfg.train.run_eval + + # Set random seed + with self.accelerator.main_process_first(): + start = time.monotonic_ns() + self._set_random_seed(self.cfg.train.random_seed) + end = time.monotonic_ns() + self.logger.debug( + f"Setting random seed done in {(end - start) / 1e6:.2f}ms" + ) + self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") + + # Build dataloader + with self.accelerator.main_process_first(): + self.logger.info("Building dataset...") + start = time.monotonic_ns() + self.train_dataloader, self.valid_dataloader = self._build_dataloader() + end = time.monotonic_ns() + self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") + + # Build model + with self.accelerator.main_process_first(): + self.logger.info("Building model...") + start = time.monotonic_ns() + self.generator, self.discriminators = self._build_model() + end = time.monotonic_ns() + self.logger.debug(self.generator) + for _, discriminator in self.discriminators.items(): + self.logger.debug(discriminator) + self.logger.info(f"Building model done in {(end - start) / 1e6:.2f}ms") + self.logger.info(f"Model parameters: {self._count_parameters()/1e6:.2f}M") + + # Build optimizers and schedulers + with self.accelerator.main_process_first(): + self.logger.info("Building optimizer and scheduler...") + start = time.monotonic_ns() + ( + self.generator_optimizer, + self.discriminator_optimizer, + ) = self._build_optimizer() + ( + self.generator_scheduler, + self.discriminator_scheduler, + ) = self._build_scheduler() + end = time.monotonic_ns() + self.logger.info( + f"Building optimizer and scheduler done in {(end - start) / 1e6:.2f}ms" + ) + + # Accelerator preparing + self.logger.info("Initializing accelerate...") + start = time.monotonic_ns() + ( + self.train_dataloader, + self.valid_dataloader, + self.generator, + self.generator_optimizer, + self.discriminator_optimizer, + self.generator_scheduler, + self.discriminator_scheduler, + ) = self.accelerator.prepare( + self.train_dataloader, + self.valid_dataloader, + self.generator, + self.generator_optimizer, + self.discriminator_optimizer, + self.generator_scheduler, + self.discriminator_scheduler, + ) + for key, discriminator in self.discriminators.items(): + self.discriminators[key] = self.accelerator.prepare_model(discriminator) + end = time.monotonic_ns() + self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.2f}ms") + + # Build criterions + with 
self.accelerator.main_process_first(): + self.logger.info("Building criterion...") + start = time.monotonic_ns() + self.criterions = self._build_criterion() + end = time.monotonic_ns() + self.logger.info(f"Building criterion done in {(end - start) / 1e6:.2f}ms") + + # Resume checkpoints + with self.accelerator.main_process_first(): + if args.resume_type: + self.logger.info("Resuming from checkpoint...") + start = time.monotonic_ns() + ckpt_path = Path(args.checkpoint) + if self._is_valid_pattern(ckpt_path.parts[-1]): + ckpt_path = self._load_model( + None, args.checkpoint, args.resume_type + ) + else: + ckpt_path = self._load_model( + args.checkpoint, resume_type=args.resume_type + ) + end = time.monotonic_ns() + self.logger.info( + f"Resuming from checkpoint done in {(end - start) / 1e6:.2f}ms" + ) + self.checkpoints_path = json.load( + open(os.path.join(ckpt_path, "ckpts.json"), "r") + ) + + self.checkpoint_dir = os.path.join(self.exp_dir, "checkpoint") + if self.accelerator.is_main_process: + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.logger.debug(f"Checkpoint directory: {self.checkpoint_dir}") + + # Save config + self.config_save_path = os.path.join(self.exp_dir, "args.json") + + def _build_dataset(self): + return GANVocoderDataset, GANVocoderCollator + + def _build_criterion(self): + class feature_criterion(torch.nn.Module): + def __init__(self, cfg): + super(feature_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, fmap_r, fmap_g): + loss = 0 + + if self.cfg.model.generator in [ + "hifigan", + "nsfhifigan", + "bigvgan", + "apnet", + ]: + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + loss = loss * 2 + elif self.cfg.model.generator in ["melgan"]: + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += self.l1Loss(rl, gl) + + loss = loss * 10 + elif self.cfg.model.generator in ["codec"]: + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss = loss + self.l1Loss(rl, gl) / torch.mean( + torch.abs(rl) + ) + + KL_scale = len(fmap_r) * len(fmap_r[0]) + + loss = 3 * loss / KL_scale + else: + raise NotImplementedError + + return loss + + class discriminator_criterion(torch.nn.Module): + def __init__(self, cfg): + super(discriminator_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + + if self.cfg.model.generator in [ + "hifigan", + "nsfhifigan", + "bigvgan", + "apnet", + ]: + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg**2) + loss += r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + elif self.cfg.model.generator in ["melgan"]: + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean(self.relu(1 - dr)) + g_loss = torch.mean(self.relu(1 + dg)) + loss = loss + r_loss + g_loss + r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + elif self.cfg.model.generator in ["codec"]: + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean(self.relu(1 - dr)) + g_loss = torch.mean(self.relu(1 + dg)) + loss = loss + r_loss + g_loss + 
r_losses.append(r_loss.item()) + g_losses.append(g_loss.item()) + + loss = loss / len(disc_real_outputs) + else: + raise NotImplementedError + + return loss, r_losses, g_losses + + class generator_criterion(torch.nn.Module): + def __init__(self, cfg): + super(generator_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, disc_outputs): + loss = 0 + gen_losses = [] + + if self.cfg.model.generator in [ + "hifigan", + "nsfhifigan", + "bigvgan", + "apnet", + ]: + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + gen_losses.append(l) + loss += l + elif self.cfg.model.generator in ["melgan"]: + for dg in disc_outputs: + l = -torch.mean(dg) + gen_losses.append(l) + loss += l + elif self.cfg.model.generator in ["codec"]: + for dg in disc_outputs: + l = torch.mean(self.relu(1 - dg)) / len(disc_outputs) + gen_losses.append(l) + loss += l + else: + raise NotImplementedError + + return loss, gen_losses + + class mel_criterion(torch.nn.Module): + def __init__(self, cfg): + super(mel_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, y_gt, y_pred): + loss = 0 + + if self.cfg.model.generator in [ + "hifigan", + "nsfhifigan", + "bigvgan", + "melgan", + "codec", + "apnet", + ]: + y_gt_mel = extract_mel_features(y_gt, self.cfg.preprocess) + y_pred_mel = extract_mel_features( + y_pred.squeeze(1), self.cfg.preprocess + ) + + loss = self.l1Loss(y_gt_mel, y_pred_mel) * 45 + else: + raise NotImplementedError + + return loss + + class wav_criterion(torch.nn.Module): + def __init__(self, cfg): + super(wav_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, y_gt, y_pred): + loss = 0 + + if self.cfg.model.generator in [ + "hifigan", + "nsfhifigan", + "bigvgan", + "apnet", + ]: + loss = self.l2Loss(y_gt, y_pred.squeeze(1)) * 100 + elif self.cfg.model.generator in ["melgan"]: + loss = self.l1Loss(y_gt, y_pred.squeeze(1)) / 10 + elif self.cfg.model.generator in ["codec"]: + loss = self.l1Loss(y_gt, y_pred.squeeze(1)) + self.l2Loss( + y_gt, y_pred.squeeze(1) + ) + loss /= 10 + else: + raise NotImplementedError + + return loss + + class phase_criterion(torch.nn.Module): + def __init__(self, cfg): + super(phase_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, phase_gt, phase_pred): + n_fft = self.cfg.preprocess.n_fft + frames = phase_gt.size()[-1] + + GD_matrix = ( + torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=1) + - torch.triu(torch.ones(n_fft // 2 + 1, n_fft // 2 + 1), diagonal=2) + - torch.eye(n_fft // 2 + 1) + ) + GD_matrix = GD_matrix.to(phase_pred.device) + + GD_r = torch.matmul(phase_gt.permute(0, 2, 1), GD_matrix) + GD_g = torch.matmul(phase_pred.permute(0, 2, 1), GD_matrix) + + PTD_matrix = ( + torch.triu(torch.ones(frames, frames), diagonal=1) + - torch.triu(torch.ones(frames, frames), diagonal=2) + - torch.eye(frames) + ) + PTD_matrix = PTD_matrix.to(phase_pred.device) + + PTD_r = torch.matmul(phase_gt, PTD_matrix) + PTD_g = torch.matmul(phase_pred, PTD_matrix) + + IP_loss = 
torch.mean(-torch.cos(phase_gt - phase_pred)) + GD_loss = torch.mean(-torch.cos(GD_r - GD_g)) + PTD_loss = torch.mean(-torch.cos(PTD_r - PTD_g)) + + return 100 * (IP_loss + GD_loss + PTD_loss) + + class amplitude_criterion(torch.nn.Module): + def __init__(self, cfg): + super(amplitude_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__(self, log_amplitude_gt, log_amplitude_pred): + amplitude_loss = self.l2Loss(log_amplitude_gt, log_amplitude_pred) + + return 45 * amplitude_loss + + class consistency_criterion(torch.nn.Module): + def __init__(self, cfg): + super(consistency_criterion, self).__init__() + self.cfg = cfg + self.l1Loss = torch.nn.L1Loss(reduction="mean") + self.l2Loss = torch.nn.MSELoss(reduction="mean") + self.relu = torch.nn.ReLU() + + def __call__( + self, + rea_gt, + rea_pred, + rea_pred_final, + imag_gt, + imag_pred, + imag_pred_final, + ): + C_loss = torch.mean( + torch.mean( + (rea_pred - rea_pred_final) ** 2 + + (imag_pred - imag_pred_final) ** 2, + (1, 2), + ) + ) + + L_R = self.l1Loss(rea_gt, rea_pred) + L_I = self.l1Loss(imag_gt, imag_pred) + + return 20 * (C_loss + 2.25 * (L_R + L_I)) + + criterions = dict() + for key in self.cfg.train.criterions: + if key == "feature": + criterions["feature"] = feature_criterion(self.cfg) + elif key == "discriminator": + criterions["discriminator"] = discriminator_criterion(self.cfg) + elif key == "generator": + criterions["generator"] = generator_criterion(self.cfg) + elif key == "mel": + criterions["mel"] = mel_criterion(self.cfg) + elif key == "wav": + criterions["wav"] = wav_criterion(self.cfg) + elif key == "phase": + criterions["phase"] = phase_criterion(self.cfg) + elif key == "amplitude": + criterions["amplitude"] = amplitude_criterion(self.cfg) + elif key == "consistency": + criterions["consistency"] = consistency_criterion(self.cfg) + else: + raise NotImplementedError + + return criterions + + def _build_model(self): + generator = supported_generators[self.cfg.model.generator](self.cfg) + discriminators = dict() + for key in self.cfg.model.discriminators: + discriminators[key] = supported_discriminators[key](self.cfg) + + return generator, discriminators + + def _build_optimizer(self): + optimizer_params_generator = [dict(params=self.generator.parameters())] + generator_optimizer = AdamW( + optimizer_params_generator, + lr=self.cfg.train.adamw.lr, + betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2), + ) + + optimizer_params_discriminator = [] + for discriminator in self.discriminators.keys(): + optimizer_params_discriminator.append( + dict(params=self.discriminators[discriminator].parameters()) + ) + discriminator_optimizer = AdamW( + optimizer_params_discriminator, + lr=self.cfg.train.adamw.lr, + betas=(self.cfg.train.adamw.adam_b1, self.cfg.train.adamw.adam_b2), + ) + + return generator_optimizer, discriminator_optimizer + + def _build_scheduler(self): + discriminator_scheduler = ExponentialLR( + self.discriminator_optimizer, + gamma=self.cfg.train.exponential_lr.lr_decay, + last_epoch=self.epoch - 1, + ) + + generator_scheduler = ExponentialLR( + self.generator_optimizer, + gamma=self.cfg.train.exponential_lr.lr_decay, + last_epoch=self.epoch - 1, + ) + + return generator_scheduler, discriminator_scheduler + + def train_loop(self): + """Training process""" + self.accelerator.wait_for_everyone() + + # Dump config + if self.accelerator.is_main_process: + 
self._dump_cfg(self.config_save_path) + self.generator.train() + for key in self.discriminators.keys(): + self.discriminators[key].train() + self.generator_optimizer.zero_grad() + self.discriminator_optimizer.zero_grad() + + # Sync and start training + self.accelerator.wait_for_everyone() + while self.epoch < self.max_epoch: + self.logger.info("\n") + self.logger.info("-" * 32) + self.logger.info("Epoch {}: ".format(self.epoch)) + + # Train and Validate + train_total_loss, train_losses = self._train_epoch() + for key, loss in train_losses.items(): + self.logger.info(" |- Train/{} Loss: {:.6f}".format(key, loss)) + self.accelerator.log( + {"Epoch/Train {} Loss".format(key): loss}, + step=self.epoch, + ) + valid_total_loss, valid_losses = self._valid_epoch() + for key, loss in valid_losses.items(): + self.logger.info(" |- Valid/{} Loss: {:.6f}".format(key, loss)) + self.accelerator.log( + {"Epoch/Valid {} Loss".format(key): loss}, + step=self.epoch, + ) + self.accelerator.log( + { + "Epoch/Train Total Loss": train_total_loss, + "Epoch/Valid Total Loss": valid_total_loss, + }, + step=self.epoch, + ) + + # Update scheduler + self.accelerator.wait_for_everyone() + self.generator_scheduler.step() + self.discriminator_scheduler.step() + + # Check save checkpoint interval + run_eval = False + if self.accelerator.is_main_process: + save_checkpoint = False + for i, num in enumerate(self.save_checkpoint_stride): + if self.epoch % num == 0: + save_checkpoint = True + run_eval |= self.run_eval[i] + + # Save checkpoints + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process and save_checkpoint: + path = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, valid_total_loss + ), + ) + self.accelerator.save_state(path) + json.dump( + self.checkpoints_path, + open(os.path.join(path, "ckpts.json"), "w"), + ensure_ascii=False, + indent=4, + ) + + # Save eval audios + self.accelerator.wait_for_everyone() + if self.accelerator.is_main_process and run_eval: + for i in range(len(self.valid_dataloader.dataset.eval_audios)): + if self.cfg.preprocess.use_frame_pitch: + eval_audio = self._inference( + self.valid_dataloader.dataset.eval_mels[i], + eval_pitch=self.valid_dataloader.dataset.eval_pitchs[i], + use_pitch=True, + ) + else: + eval_audio = self._inference( + self.valid_dataloader.dataset.eval_mels[i] + ) + path = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}_eval_audio_{}.wav".format( + self.epoch, + self.step, + valid_total_loss, + self.valid_dataloader.dataset.eval_dataset_names[i], + ), + ) + path_gt = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}_eval_audio_{}_gt.wav".format( + self.epoch, + self.step, + valid_total_loss, + self.valid_dataloader.dataset.eval_dataset_names[i], + ), + ) + save_audio(path, eval_audio, self.cfg.preprocess.sample_rate) + save_audio( + path_gt, + self.valid_dataloader.dataset.eval_audios[i], + self.cfg.preprocess.sample_rate, + ) + + self.accelerator.wait_for_everyone() + + self.epoch += 1 + + # Finish training + self.accelerator.wait_for_everyone() + path = os.path.join( + self.checkpoint_dir, + "epoch-{:04d}_step-{:07d}_loss-{:.6f}".format( + self.epoch, self.step, valid_total_loss + ), + ) + self.accelerator.save_state(path) + + def _train_epoch(self): + """Training epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. 
+ """ + self.generator.train() + for key, _ in self.discriminators.items(): + self.discriminators[key].train() + + epoch_losses: dict = {} + epoch_total_loss: int = 0 + + for batch in tqdm( + self.train_dataloader, + desc=f"Training Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + # Get losses + total_loss, losses = self._train_step(batch) + self.batch_count += 1 + + # Log info + if self.batch_count % self.cfg.train.gradient_accumulation_step == 0: + self.accelerator.log( + { + "Step/Generator Learning Rate": self.generator_optimizer.param_groups[ + 0 + ][ + "lr" + ], + "Step/Discriminator Learning Rate": self.discriminator_optimizer.param_groups[ + 0 + ][ + "lr" + ], + }, + step=self.step, + ) + for key, _ in losses.items(): + self.accelerator.log( + { + "Step/Train {} Loss".format(key): losses[key], + }, + step=self.step, + ) + + if not epoch_losses: + epoch_losses = losses + else: + for key, value in losses.items(): + epoch_losses[key] += value + epoch_total_loss += total_loss + self.step += 1 + + # Get and log total losses + self.accelerator.wait_for_everyone() + epoch_total_loss = ( + epoch_total_loss + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step + ) + for key in epoch_losses.keys(): + epoch_losses[key] = ( + epoch_losses[key] + / len(self.train_dataloader) + * self.cfg.train.gradient_accumulation_step + ) + return epoch_total_loss, epoch_losses + + def _train_step(self, data): + """Training forward step. Should return average loss of a sample over + one batch. Provoke ``_forward_step`` is recommended except for special case. + See ``_train_epoch`` for usage. + """ + # Init losses + train_losses = {} + total_loss = 0 + + generator_losses = {} + generator_total_loss = 0 + discriminator_losses = {} + discriminator_total_loss = 0 + + # Use input feature to get predictions + mel_input = data["mel"] + audio_gt = data["audio"] + + if self.cfg.preprocess.extract_amplitude_phase: + logamp_gt = data["logamp"] + pha_gt = data["pha"] + rea_gt = data["rea"] + imag_gt = data["imag"] + + if self.cfg.preprocess.use_frame_pitch: + pitch_input = data["frame_pitch"] + + if self.cfg.preprocess.use_frame_pitch: + pitch_input = pitch_input.float() + audio_pred = self.generator.forward(mel_input, pitch_input) + elif self.cfg.preprocess.extract_amplitude_phase: + ( + logamp_pred, + pha_pred, + rea_pred, + imag_pred, + audio_pred, + ) = self.generator.forward(mel_input) + from utils.mel import amplitude_phase_spectrum + + _, _, rea_pred_final, imag_pred_final = amplitude_phase_spectrum( + audio_pred.squeeze(1), self.cfg.preprocess + ) + else: + audio_pred = self.generator.forward(mel_input) + + # Calculate and BP Discriminator losses + self.discriminator_optimizer.zero_grad() + for key, _ in self.discriminators.items(): + y_r, y_g, _, _ = self.discriminators[key].forward( + audio_gt.unsqueeze(1), audio_pred.detach() + ) + ( + discriminator_losses["{}_discriminator".format(key)], + _, + _, + ) = self.criterions["discriminator"](y_r, y_g) + discriminator_total_loss += discriminator_losses[ + "{}_discriminator".format(key) + ] + + self.accelerator.backward(discriminator_total_loss) + self.discriminator_optimizer.step() + + # Calculate and BP Generator losses + self.generator_optimizer.zero_grad() + for key, _ in self.discriminators.items(): + y_r, y_g, f_r, f_g = self.discriminators[key].forward( + audio_gt.unsqueeze(1), audio_pred + ) + 
generator_losses["{}_feature".format(key)] = self.criterions["feature"]( + f_r, f_g + ) + generator_losses["{}_generator".format(key)], _ = self.criterions[ + "generator" + ](y_g) + generator_total_loss += generator_losses["{}_feature".format(key)] + generator_total_loss += generator_losses["{}_generator".format(key)] + + if "mel" in self.criterions.keys(): + generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred) + generator_total_loss += generator_losses["mel"] + + if "wav" in self.criterions.keys(): + generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred) + generator_total_loss += generator_losses["wav"] + + if "amplitude" in self.criterions.keys(): + generator_losses["amplitude"] = self.criterions["amplitude"]( + logamp_gt, logamp_pred + ) + generator_total_loss += generator_losses["amplitude"] + + if "phase" in self.criterions.keys(): + generator_losses["phase"] = self.criterions["phase"](pha_gt, pha_pred) + generator_total_loss += generator_losses["phase"] + + if "consistency" in self.criterions.keys(): + generator_losses["consistency"] = self.criterions["consistency"]( + rea_gt, rea_pred, rea_pred_final, imag_gt, imag_pred, imag_pred_final + ) + generator_total_loss += generator_losses["consistency"] + + self.accelerator.backward(generator_total_loss) + self.generator_optimizer.step() + + # Get the total losses + total_loss = discriminator_total_loss + generator_total_loss + train_losses.update(discriminator_losses) + train_losses.update(generator_losses) + + for key, _ in train_losses.items(): + train_losses[key] = train_losses[key].item() + + return total_loss.item(), train_losses + + def _valid_epoch(self): + """Testing epoch. Should return average loss of a batch (sample) over + one epoch. See ``train_loop`` for usage. + """ + self.generator.eval() + for key, _ in self.discriminators.items(): + self.discriminators[key].eval() + + epoch_losses: dict = {} + epoch_total_loss: int = 0 + + for batch in tqdm( + self.valid_dataloader, + desc=f"Validating Epoch {self.epoch}", + unit="batch", + colour="GREEN", + leave=False, + dynamic_ncols=True, + smoothing=0.04, + disable=not self.accelerator.is_main_process, + ): + # Get losses + total_loss, losses = self._valid_step(batch) + + # Log info + for key, _ in losses.items(): + self.accelerator.log( + { + "Step/Valid {} Loss".format(key): losses[key], + }, + step=self.step, + ) + + if not epoch_losses: + epoch_losses = losses + else: + for key, value in losses.items(): + epoch_losses[key] += value + epoch_total_loss += total_loss + + # Get and log total losses + self.accelerator.wait_for_everyone() + epoch_total_loss = epoch_total_loss / len(self.valid_dataloader) + for key in epoch_losses.keys(): + epoch_losses[key] = epoch_losses[key] / len(self.valid_dataloader) + return epoch_total_loss, epoch_losses + + def _valid_step(self, data): + """Testing forward step. Should return average loss of a sample over + one batch. Provoke ``_forward_step`` is recommended except for special case. + See ``_test_epoch`` for usage. 
+ """ + # Init losses + valid_losses = {} + total_loss = 0 + + generator_losses = {} + generator_total_loss = 0 + discriminator_losses = {} + discriminator_total_loss = 0 + + # Use feature inputs to get the predicted audio + mel_input = data["mel"] + audio_gt = data["audio"] + + if self.cfg.preprocess.extract_amplitude_phase: + logamp_gt = data["logamp"] + pha_gt = data["pha"] + rea_gt = data["rea"] + imag_gt = data["imag"] + + if self.cfg.preprocess.use_frame_pitch: + pitch_input = data["frame_pitch"] + + if self.cfg.preprocess.use_frame_pitch: + pitch_input = pitch_input.float() + audio_pred = self.generator.forward(mel_input, pitch_input) + elif self.cfg.preprocess.extract_amplitude_phase: + ( + logamp_pred, + pha_pred, + rea_pred, + imag_pred, + audio_pred, + ) = self.generator.forward(mel_input) + from utils.mel import amplitude_phase_spectrum + + _, _, rea_pred_final, imag_pred_final = amplitude_phase_spectrum( + audio_pred.squeeze(1), self.cfg.preprocess + ) + else: + audio_pred = self.generator.forward(mel_input) + + # Get Discriminator losses + for key, _ in self.discriminators.items(): + y_r, y_g, _, _ = self.discriminators[key].forward( + audio_gt.unsqueeze(1), audio_pred + ) + ( + discriminator_losses["{}_discriminator".format(key)], + _, + _, + ) = self.criterions["discriminator"](y_r, y_g) + discriminator_total_loss += discriminator_losses[ + "{}_discriminator".format(key) + ] + + for key, _ in self.discriminators.items(): + y_r, y_g, f_r, f_g = self.discriminators[key].forward( + audio_gt.unsqueeze(1), audio_pred + ) + generator_losses["{}_feature".format(key)] = self.criterions["feature"]( + f_r, f_g + ) + generator_losses["{}_generator".format(key)], _ = self.criterions[ + "generator" + ](y_g) + generator_total_loss += generator_losses["{}_feature".format(key)] + generator_total_loss += generator_losses["{}_generator".format(key)] + + if "mel" in self.criterions.keys(): + generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred) + generator_total_loss += generator_losses["mel"] + if "mel" in self.criterions.keys(): + generator_losses["mel"] = self.criterions["mel"](audio_gt, audio_pred) + generator_total_loss += generator_losses["mel"] + + if "wav" in self.criterions.keys(): + generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred) + generator_total_loss += generator_losses["wav"] + if "wav" in self.criterions.keys(): + generator_losses["wav"] = self.criterions["wav"](audio_gt, audio_pred) + generator_total_loss += generator_losses["wav"] + + if "amplitude" in self.criterions.keys(): + generator_losses["amplitude"] = self.criterions["amplitude"]( + logamp_gt, logamp_pred + ) + generator_total_loss += generator_losses["amplitude"] + + if "phase" in self.criterions.keys(): + generator_losses["phase"] = self.criterions["phase"](pha_gt, pha_pred) + generator_total_loss += generator_losses["phase"] + + if "consistency" in self.criterions.keys(): + generator_losses["consistency"] = self.criterions["consistency"]( + rea_gt, + rea_pred, + rea_pred_final, + imag_gt, + imag_pred, + imag_pred_final, + ) + generator_total_loss += generator_losses["consistency"] + + total_loss = discriminator_total_loss + generator_total_loss + valid_losses.update(discriminator_losses) + valid_losses.update(generator_losses) + + for item in valid_losses: + valid_losses[item] = valid_losses[item].item() + for item in valid_losses: + valid_losses[item] = valid_losses[item].item() + + return total_loss.item(), valid_losses + return total_loss.item(), valid_losses + + def 
_inference(self, eval_mel, eval_pitch=None, use_pitch=False): + """Inference during training for test audios.""" + if use_pitch: + eval_pitch = align_length(eval_pitch, eval_mel.shape[1]) + eval_audio = vocoder_inference( + self.cfg, + self.generator, + torch.from_numpy(eval_mel).unsqueeze(0), + f0s=torch.from_numpy(eval_pitch).unsqueeze(0).float(), + device=next(self.generator.parameters()).device, + ).squeeze(0) + else: + eval_audio = vocoder_inference( + self.cfg, + self.generator, + torch.from_numpy(eval_mel).unsqueeze(0), + device=next(self.generator.parameters()).device, + ).squeeze(0) + return eval_audio + + def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"): + """Load model from checkpoint. If checkpoint_path is None, it will + load the latest checkpoint in checkpoint_dir. If checkpoint_path is not + None, it will load the checkpoint specified by checkpoint_path. **Only use this + method after** ``accelerator.prepare()``. + """ + if checkpoint_path is None: + ls = [str(i) for i in Path(checkpoint_dir).glob("*")] + ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) + checkpoint_path = ls[0] + if resume_type == "resume": + self.accelerator.load_state(checkpoint_path) + self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1 + self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1 + elif resume_type == "finetune": + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.generator), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + for key, _ in self.discriminators.items(): + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.discriminators[key]), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + self.logger.info("Load model weights for finetune SUCCESS!") + else: + raise ValueError("Unsupported resume type: {}".format(resume_type)) + return checkpoint_path + + def _count_parameters(self): + result = sum(p.numel() for p in self.generator.parameters()) + for _, discriminator in self.discriminators.items(): + result += sum(p.numel() for p in discriminator.parameters()) + return result diff --git a/models/vocoders/gan/generator/apnet.py b/models/vocoders/gan/generator/apnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f39dd6b01f6be5a6bdd2e04ca822a4a9c9b4c9b4 --- /dev/null +++ b/models/vocoders/gan/generator/apnet.py @@ -0,0 +1,395 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import torch.nn as nn +from modules.vocoder_blocks import * + +LRELU_SLOPE = 0.1 + + +class ISTFT(nn.Module): + """ + Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with + windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges. + See issue: https://github.com/pytorch/pytorch/issues/62323 + Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs. + The NOLA constraint is met as we trim padded samples anyway. + + Args: + n_fft (int): Size of Fourier transform. + hop_length (int): The distance between neighboring sliding window frames. + win_length (int): The size of window frame and STFT filter. + padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same". 
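+
+    Example (illustrative only; the shapes assume n_fft=1024, hop_length=256 and
+    win_length=1024, which are not fixed by this class):
+        >>> istft = ISTFT(n_fft=1024, hop_length=256, win_length=1024, padding="same")
+        >>> spec = torch.randn(2, 513, 100, dtype=torch.complex64)  # (B, n_fft // 2 + 1, T)
+        >>> window = torch.hann_window(1024)
+        >>> audio = istft(spec, window)  # (2, 100 * 256): one hop of audio per frame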
+ """ + + def __init__( + self, + n_fft: int, + hop_length: int, + win_length: int, + padding: str = "same", + ): + super().__init__() + if padding not in ["center", "same"]: + raise ValueError("Padding must be 'center' or 'same'.") + self.padding = padding + self.n_fft = n_fft + self.hop_length = hop_length + self.win_length = win_length + + def forward(self, spec: torch.Tensor, window) -> torch.Tensor: + """ + Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram. + + Args: + spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size, + N is the number of frequency bins, and T is the number of time frames. + + Returns: + Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal. + """ + if self.padding == "center": + # Fallback to pytorch native implementation + return torch.istft( + spec, + self.n_fft, + self.hop_length, + self.win_length, + window, + center=True, + ) + elif self.padding == "same": + pad = (self.win_length - self.hop_length) // 2 + else: + raise ValueError("Padding must be 'center' or 'same'.") + + assert spec.dim() == 3, "Expected a 3D tensor as input" + B, N, T = spec.shape + + # Inverse FFT + ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") + ifft = ifft * window[None, :, None] + + # Overlap and Add + output_size = (T - 1) * self.hop_length + self.win_length + y = torch.nn.functional.fold( + ifft, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + )[:, 0, 0, pad:-pad] + + # Window envelope + window_sq = window.square().expand(1, T, -1).transpose(1, 2) + window_envelope = torch.nn.functional.fold( + window_sq, + output_size=(1, output_size), + kernel_size=(1, self.win_length), + stride=(1, self.hop_length), + ).squeeze()[pad:-pad] + + # Normalize + assert (window_envelope > 1e-11).all() + y = y / window_envelope + + return y + +# The ASP and PSP Module are adopted from APNet under the MIT License +# https://github.com/YangAi520/APNet/blob/main/models.py + +class ASPResBlock(torch.nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ASPResBlock, self).__init__() + self.cfg = cfg + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + +class PSPResBlock(torch.nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)): + 
super(PSPResBlock, self).__init__() + self.cfg = cfg + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + +class APNet(torch.nn.Module): + def __init__(self, cfg): + super(APNet, self).__init__() + self.cfg = cfg + self.ASP_num_kernels = len(cfg.model.apnet.ASP_resblock_kernel_sizes) + self.PSP_num_kernels = len(cfg.model.apnet.PSP_resblock_kernel_sizes) + + self.ASP_input_conv = weight_norm( + Conv1d( + cfg.preprocess.n_mel, + cfg.model.apnet.ASP_channel, + cfg.model.apnet.ASP_input_conv_kernel_size, + 1, + padding=get_padding(cfg.model.apnet.ASP_input_conv_kernel_size, 1), + ) + ) + self.PSP_input_conv = weight_norm( + Conv1d( + cfg.preprocess.n_mel, + cfg.model.apnet.PSP_channel, + cfg.model.apnet.PSP_input_conv_kernel_size, + 1, + padding=get_padding(cfg.model.apnet.PSP_input_conv_kernel_size, 1), + ) + ) + + self.ASP_ResNet = nn.ModuleList() + for j, (k, d) in enumerate( + zip( + cfg.model.apnet.ASP_resblock_kernel_sizes, + cfg.model.apnet.ASP_resblock_dilation_sizes, + ) + ): + self.ASP_ResNet.append(ASPResBlock(cfg, cfg.model.apnet.ASP_channel, k, d)) + + self.PSP_ResNet = nn.ModuleList() + for j, (k, d) in enumerate( + zip( + cfg.model.apnet.PSP_resblock_kernel_sizes, + cfg.model.apnet.PSP_resblock_dilation_sizes, + ) + ): + self.PSP_ResNet.append(PSPResBlock(cfg, cfg.model.apnet.PSP_channel, k, d)) + + self.ASP_output_conv = weight_norm( + Conv1d( + cfg.model.apnet.ASP_channel, + cfg.preprocess.n_fft // 2 + 1, + cfg.model.apnet.ASP_output_conv_kernel_size, + 1, + padding=get_padding(cfg.model.apnet.ASP_output_conv_kernel_size, 1), + ) + ) + self.PSP_output_R_conv = weight_norm( + Conv1d( + cfg.model.apnet.PSP_channel, + cfg.preprocess.n_fft // 2 + 1, + cfg.model.apnet.PSP_output_R_conv_kernel_size, + 1, + padding=get_padding(cfg.model.apnet.PSP_output_R_conv_kernel_size, 1), + ) + ) + self.PSP_output_I_conv = weight_norm( + Conv1d( + cfg.model.apnet.PSP_channel, + cfg.preprocess.n_fft // 2 + 1, + cfg.model.apnet.PSP_output_I_conv_kernel_size, + 1, + padding=get_padding(cfg.model.apnet.PSP_output_I_conv_kernel_size, 1), + ) + ) + + self.iSTFT = ISTFT( + self.cfg.preprocess.n_fft, + hop_length=self.cfg.preprocess.hop_size, + win_length=self.cfg.preprocess.win_size, + ) + + self.ASP_output_conv.apply(init_weights) + self.PSP_output_R_conv.apply(init_weights) + self.PSP_output_I_conv.apply(init_weights) + + def forward(self, mel): + logamp = 
self.ASP_input_conv(mel) + logamps = None + for j in range(self.ASP_num_kernels): + if logamps is None: + logamps = self.ASP_ResNet[j](logamp) + else: + logamps += self.ASP_ResNet[j](logamp) + logamp = logamps / self.ASP_num_kernels + logamp = F.leaky_relu(logamp) + logamp = self.ASP_output_conv(logamp) + + pha = self.PSP_input_conv(mel) + phas = None + for j in range(self.PSP_num_kernels): + if phas is None: + phas = self.PSP_ResNet[j](pha) + else: + phas += self.PSP_ResNet[j](pha) + pha = phas / self.PSP_num_kernels + pha = F.leaky_relu(pha) + R = self.PSP_output_R_conv(pha) + I = self.PSP_output_I_conv(pha) + + pha = torch.atan2(I, R) + + rea = torch.exp(logamp) * torch.cos(pha) + imag = torch.exp(logamp) * torch.sin(pha) + + spec = torch.cat((rea.unsqueeze(-1), imag.unsqueeze(-1)), -1) + + spec = torch.view_as_complex(spec) + + audio = self.iSTFT.forward( + spec, torch.hann_window(self.cfg.preprocess.win_size).to(mel.device) + ) + + return logamp, pha, rea, imag, audio.unsqueeze(1) diff --git a/models/vocoders/gan/generator/bigvgan.py b/models/vocoders/gan/generator/bigvgan.py new file mode 100644 index 0000000000000000000000000000000000000000..205f8d697b9a49aa4ac3f4c46015fef861d4fe1b --- /dev/null +++ b/models/vocoders/gan/generator/bigvgan.py @@ -0,0 +1,344 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +import torch.nn as nn + +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +from modules.vocoder_blocks import * +from modules.activation_functions import * +from modules.anti_aliasing import * + +LRELU_SLOPE = 0.1 + +# The AMPBlock Module is adopted from BigVGAN under the MIT License +# https://github.com/NVIDIA/BigVGAN + +class AMPBlock1(torch.nn.Module): + def __init__( + self, cfg, channels, kernel_size=3, dilation=(1, 3, 5), activation=None + ): + super(AMPBlock1, self).__init__() + self.cfg = cfg + + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + self.num_layers = len(self.convs1) + len( + self.convs2 + ) # total number of conv layers + + if ( + activation == "snake" + ): # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d( + activation=Snake( + channels, alpha_logscale=cfg.model.bigvgan.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + elif ( + activation == "snakebeta" + ): # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = 
nn.ModuleList( + [ + Activation1d( + activation=SnakeBeta( + channels, alpha_logscale=cfg.model.bigvgan.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + def forward(self, x): + acts1, acts2 = self.activations[::2], self.activations[1::2] + for c1, c2, a1, a2 in zip(self.convs1, self.convs2, acts1, acts2): + xt = a1(x) + xt = c1(xt) + xt = a2(xt) + xt = c2(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class AMPBlock2(torch.nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3), activation=None): + super(AMPBlock2, self).__init__() + self.cfg = cfg + + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + self.num_layers = len(self.convs) # total number of conv layers + + if ( + activation == "snake" + ): # periodic nonlinearity with snake function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d( + activation=Snake( + channels, alpha_logscale=cfg.model.bigvgan.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + elif ( + activation == "snakebeta" + ): # periodic nonlinearity with snakebeta function and anti-aliasing + self.activations = nn.ModuleList( + [ + Activation1d( + activation=SnakeBeta( + channels, alpha_logscale=cfg.model.bigvgan.snake_logscale + ) + ) + for _ in range(self.num_layers) + ] + ) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." 
+ ) + + def forward(self, x): + for c, a in zip(self.convs, self.activations): + xt = a(x) + xt = c(xt) + x = xt + x + + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class BigVGAN(torch.nn.Module): + def __init__(self, cfg): + super(BigVGAN, self).__init__() + self.cfg = cfg + + self.num_kernels = len(cfg.model.bigvgan.resblock_kernel_sizes) + self.num_upsamples = len(cfg.model.bigvgan.upsample_rates) + + # Conv pre to boost channels + self.conv_pre = weight_norm( + Conv1d( + cfg.preprocess.n_mel, + cfg.model.bigvgan.upsample_initial_channel, + 7, + 1, + padding=3, + ) + ) + + resblock = AMPBlock1 if cfg.model.bigvgan.resblock == "1" else AMPBlock2 + + # Upsamplers + self.ups = nn.ModuleList() + for i, (u, k) in enumerate( + zip( + cfg.model.bigvgan.upsample_rates, + cfg.model.bigvgan.upsample_kernel_sizes, + ) + ): + self.ups.append( + nn.ModuleList( + [ + weight_norm( + ConvTranspose1d( + cfg.model.bigvgan.upsample_initial_channel // (2**i), + cfg.model.bigvgan.upsample_initial_channel + // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ] + ) + ) + + # Res Blocks with AMP and Anti-aliasing + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = cfg.model.bigvgan.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip( + cfg.model.bigvgan.resblock_kernel_sizes, + cfg.model.bigvgan.resblock_dilation_sizes, + ) + ): + self.resblocks.append( + resblock(cfg, ch, k, d, activation=cfg.model.bigvgan.activation) + ) + + # Conv post for result + if ( + cfg.model.bigvgan.activation == "snake" + ): + activation_post = Snake(ch, alpha_logscale=cfg.model.bigvgan.snake_logscale) + self.activation_post = Activation1d(activation=activation_post) + elif ( + cfg.model.bigvgan.activation == "snakebeta" + ): + activation_post = SnakeBeta( + ch, alpha_logscale=cfg.model.bigvgan.snake_logscale + ) + self.activation_post = Activation1d(activation=activation_post) + else: + raise NotImplementedError( + "activation incorrectly specified. check the config file and look for 'activation'." + ) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + + # Weight Norm + for i in range(len(self.ups)): + self.ups[i].apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + + for i in range(self.num_upsamples): + for i_up in range(len(self.ups[i])): + x = self.ups[i][i_up](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + + x = self.activation_post(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + for l_i in l: + remove_weight_norm(l_i) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/models/vocoders/gan/generator/hifigan.py b/models/vocoders/gan/generator/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..9f5f32498f5eb6441db787b0ae204a1eeff36aa3 --- /dev/null +++ b/models/vocoders/gan/generator/hifigan.py @@ -0,0 +1,449 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
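+
+# Usage sketch (illustrative only, not part of the module): the HiFiGAN generator
+# defined below maps a mel spectrogram of shape (B, n_mel, T) to a waveform of shape
+# (B, 1, T * prod(upsample_rates)); the config fields named here are the ones read
+# in __init__ (cfg.preprocess.n_mel and cfg.model.hifigan.*).
+#
+#   generator = HiFiGAN(cfg)
+#   mel = torch.randn(1, cfg.preprocess.n_mel, 100)
+#   audio = generator(mel)  # e.g. (1, 1, 25600) when prod(upsample_rates) == 256
+#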
+ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm +from modules.vocoder_blocks import * + + +LRELU_SLOPE = 0.1 + + +class ResBlock1(torch.nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.cfg = cfg + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.cfg = cfg + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class HiFiGAN(torch.nn.Module): + def __init__(self, cfg): + super(HiFiGAN, self).__init__() + self.cfg = cfg + self.num_kernels = len(self.cfg.model.hifigan.resblock_kernel_sizes) + self.num_upsamples = len(self.cfg.model.hifigan.upsample_rates) + self.conv_pre = weight_norm( + Conv1d( + cfg.preprocess.n_mel, + self.cfg.model.hifigan.upsample_initial_channel, + 7, + 1, + padding=3, + ) + ) + resblock = ResBlock1 if self.cfg.model.hifigan.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate( + zip( + self.cfg.model.hifigan.upsample_rates, + self.cfg.model.hifigan.upsample_kernel_sizes, + ) + ): + self.ups.append( + weight_norm( + ConvTranspose1d( + self.cfg.model.hifigan.upsample_initial_channel // (2**i), + self.cfg.model.hifigan.upsample_initial_channel + // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = self.cfg.model.hifigan.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip( + 
self.cfg.model.hifigan.resblock_kernel_sizes, + self.cfg.model.hifigan.resblock_dilation_sizes, + ) + ): + self.resblocks.append(resblock(self.cfg, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +# todo: merge with ResBlock1 (lmxue, yicheng) +class ResBlock1_vits(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1_vits, self).__init__() + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x, x_mask=None): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c2(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +# todo: merge with ResBlock2 (lmxue, yicheng) +class ResBlock2_vits(torch.nn.Module): + def __init__(self, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2_vits, self).__init__() + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x, x_mask=None): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + if x_mask is not None: + xt = xt * x_mask + xt = c(xt) + x = xt + x + if x_mask is not None: + x = x * x_mask + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +# todo: merge with HiFiGAN (lmxue, yicheng) 
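+# Usage sketch (illustrative; all hyperparameter values below are assumptions, based
+# only on the constructor signature that follows): unlike HiFiGAN above, HiFiGAN_vits
+# is configured with explicit arguments instead of a cfg object and optionally accepts
+# a global conditioning tensor g when gin_channels > 0.
+#
+#   generator = HiFiGAN_vits(
+#       initial_channel=192,
+#       resblock="1",
+#       resblock_kernel_sizes=[3, 7, 11],
+#       resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+#       upsample_rates=[8, 8, 2, 2],
+#       upsample_initial_channel=512,
+#       upsample_kernel_sizes=[16, 16, 4, 4],
+#       gin_channels=256,
+#   )
+#   audio = generator(z, g=spk_emb)  # z: (B, initial_channel, T), g: (B, gin_channels, 1)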
+class HiFiGAN_vits(torch.nn.Module): + def __init__( + self, + initial_channel, + resblock, + resblock_kernel_sizes, + resblock_dilation_sizes, + upsample_rates, + upsample_initial_channel, + upsample_kernel_sizes, + gin_channels=0, + ): + super(HiFiGAN_vits, self).__init__() + self.num_kernels = len(resblock_kernel_sizes) + self.num_upsamples = len(upsample_rates) + self.conv_pre = Conv1d( + initial_channel, upsample_initial_channel, 7, 1, padding=3 + ) + resblock = ResBlock1_vits if resblock == "1" else ResBlock2_vits + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + upsample_initial_channel // (2**i), + upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(resblock_kernel_sizes, resblock_dilation_sizes) + ): + self.resblocks.append(resblock(ch, k, d)) + + self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False) + self.ups.apply(init_weights) + + if gin_channels != 0: + self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1) + + def forward(self, x, g=None): + x = self.conv_pre(x) + if g is not None: + x = x + self.cond(g) + + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() diff --git a/models/vocoders/gan/generator/melgan.py b/models/vocoders/gan/generator/melgan.py new file mode 100644 index 0000000000000000000000000000000000000000..ca90d684ca1f5a0a813db4192540adac0cee2558 --- /dev/null +++ b/models/vocoders/gan/generator/melgan.py @@ -0,0 +1,103 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
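+
+# Usage sketch (illustrative only): the MelGAN generator below upsamples a mel
+# spectrogram (B, n_mel, T) to a waveform (B, 1, T * prod(cfg.model.melgan.ratios)),
+# so the product of the ratios is expected to match the hop size used during feature
+# extraction (hop_length is derived from the ratios in __init__).
+#
+#   vocoder = MelGAN(cfg)
+#   mel = torch.randn(1, cfg.preprocess.n_mel, 100)
+#   audio = vocoder(mel)  # (1, 1, 100 * prod(cfg.model.melgan.ratios))
+#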
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from torch.nn.utils import weight_norm
+
+# This code is adopted from MelGAN under the MIT License
+# https://github.com/descriptinc/melgan-neurips
+
+def weights_init(m):
+    classname = m.__class__.__name__
+    if classname.find("Conv") != -1:
+        m.weight.data.normal_(0.0, 0.02)
+    elif classname.find("BatchNorm2d") != -1:
+        m.weight.data.normal_(1.0, 0.02)
+        m.bias.data.fill_(0)
+
+
+def WNConv1d(*args, **kwargs):
+    return weight_norm(nn.Conv1d(*args, **kwargs))
+
+
+def WNConvTranspose1d(*args, **kwargs):
+    return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, dim, dilation=1):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.LeakyReLU(0.2),
+            nn.ReflectionPad1d(dilation),
+            WNConv1d(dim, dim, kernel_size=3, dilation=dilation),
+            nn.LeakyReLU(0.2),
+            WNConv1d(dim, dim, kernel_size=1),
+        )
+        self.shortcut = WNConv1d(dim, dim, kernel_size=1)
+
+    def forward(self, x):
+        return self.shortcut(x) + self.block(x)
+
+
+class MelGAN(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+
+        self.cfg = cfg
+
+        self.hop_length = np.prod(self.cfg.model.melgan.ratios)
+        mult = int(2 ** len(self.cfg.model.melgan.ratios))
+
+        model = [
+            nn.ReflectionPad1d(3),
+            WNConv1d(
+                self.cfg.preprocess.n_mel,
+                mult * self.cfg.model.melgan.ngf,
+                kernel_size=7,
+                padding=0,
+            ),
+        ]
+
+        # Upsample to raw audio scale
+        for i, r in enumerate(self.cfg.model.melgan.ratios):
+            model += [
+                nn.LeakyReLU(0.2),
+                WNConvTranspose1d(
+                    mult * self.cfg.model.melgan.ngf,
+                    mult * self.cfg.model.melgan.ngf // 2,
+                    kernel_size=r * 2,
+                    stride=r,
+                    padding=r // 2 + r % 2,
+                    output_padding=r % 2,
+                ),
+            ]
+
+            for j in range(self.cfg.model.melgan.n_residual_layers):
+                model += [
+                    ResnetBlock(mult * self.cfg.model.melgan.ngf // 2, dilation=3**j)
+                ]
+
+            mult //= 2
+
+        model += [
+            nn.LeakyReLU(0.2),
+            nn.ReflectionPad1d(3),
+            WNConv1d(self.cfg.model.melgan.ngf, 1, kernel_size=7, padding=0),
+            nn.Tanh(),
+        ]
+
+        self.model = nn.Sequential(*model)
+        self.apply(weights_init)
+
+    def forward(self, x):
+        return self.model(x)
diff --git a/models/vocoders/gan/generator/nsfhifigan.py b/models/vocoders/gan/generator/nsfhifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..8deb4ce7a9348f4a03bfef55fd21c634b3f25a78
--- /dev/null
+++ b/models/vocoders/gan/generator/nsfhifigan.py
@@ -0,0 +1,281 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
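+
+# Usage sketch (illustrative only): NSFHiFiGAN below is conditioned on a frame-level
+# F0 contour in addition to the mel spectrogram. SourceModuleHnNSF turns the F0 into
+# a sample-level harmonic source signal (upsampled by prod(upsample_rates)), which is
+# mixed into every upsampling stage through the noise_convs branches.
+#
+#   vocoder = NSFHiFiGAN(cfg)
+#   mel = torch.randn(1, cfg.preprocess.n_mel, 100)
+#   f0 = torch.full((1, 100), 220.0)  # frame-level F0 in Hz, same frame count as mel
+#   audio = vocoder(mel, f0)  # roughly (1, 1, 100 * prod(upsample_rates))
+#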
+ +import torch + +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +from modules.neural_source_filter import * +from modules.vocoder_blocks import * + + +LRELU_SLOPE = 0.1 + + +class ResBlock1(nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.cfg = cfg + + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(nn.Module): + def __init__(self, cfg, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock1, self).__init__() + self.cfg = cfg + + self.convs = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + ] + ) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + +# This NSF Module is adopted from Xin Wang's NSF under the MIT License +# https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts + +class SourceModuleHnNSF(nn.Module): + def __init__( + self, fs, harmonic_num=0, amp=0.1, noise_std=0.003, voiced_threshold=0 + ): + super(SourceModuleHnNSF, self).__init__() + + self.amp = amp + self.noise_std = noise_std + self.l_sin_gen = SineGen(fs, harmonic_num, amp, noise_std, voiced_threshold) + + self.l_linear = nn.Linear(harmonic_num + 1, 1) + self.l_tanh = nn.Tanh() + + def forward(self, x, upp): + sine_wavs, uv, _ = self.l_sin_gen(x, upp) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + return sine_merge + + +class NSFHiFiGAN(nn.Module): + def __init__(self, cfg): + super(NSFHiFiGAN, self).__init__() + + self.cfg = cfg + self.num_kernels = len(self.cfg.model.nsfhifigan.resblock_kernel_sizes) + self.num_upsamples = len(self.cfg.model.nsfhifigan.upsample_rates) + self.m_source = SourceModuleHnNSF( + fs=self.cfg.preprocess.sample_rate, 
+ harmonic_num=self.cfg.model.nsfhifigan.harmonic_num, + ) + self.noise_convs = nn.ModuleList() + self.conv_pre = weight_norm( + Conv1d( + self.cfg.preprocess.n_mel, + self.cfg.model.nsfhifigan.upsample_initial_channel, + 7, + 1, + padding=3, + ) + ) + + resblock = ResBlock1 if self.cfg.model.nsfhifigan.resblock == "1" else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate( + zip( + self.cfg.model.nsfhifigan.upsample_rates, + self.cfg.model.nsfhifigan.upsample_kernel_sizes, + ) + ): + c_cur = self.cfg.model.nsfhifigan.upsample_initial_channel // (2 ** (i + 1)) + self.ups.append( + weight_norm( + ConvTranspose1d( + self.cfg.model.nsfhifigan.upsample_initial_channel // (2**i), + self.cfg.model.nsfhifigan.upsample_initial_channel + // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + if i + 1 < len(self.cfg.model.nsfhifigan.upsample_rates): + stride_f0 = int( + np.prod(self.cfg.model.nsfhifigan.upsample_rates[i + 1 :]) + ) + self.noise_convs.append( + Conv1d( + 1, + c_cur, + kernel_size=stride_f0 * 2, + stride=stride_f0, + padding=stride_f0 // 2, + ) + ) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + + self.resblocks = nn.ModuleList() + ch = self.cfg.model.nsfhifigan.upsample_initial_channel + for i in range(len(self.ups)): + ch //= 2 + for j, (k, d) in enumerate( + zip( + self.cfg.model.nsfhifigan.resblock_kernel_sizes, + self.cfg.model.nsfhifigan.resblock_dilation_sizes, + ) + ): + self.resblocks.append(resblock(cfg, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + self.upp = int(np.prod(self.cfg.model.nsfhifigan.upsample_rates)) + + def forward(self, x, f0): + har_source = self.m_source(f0, self.upp).transpose(1, 2) + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + x_source = self.noise_convs[i](har_source) + + length = min(x.shape[-1], x_source.shape[-1]) + x = x[:, :, :length] + x_source = x[:, :, :length] + + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x diff --git a/models/vocoders/gan/generator/sifigan.py b/models/vocoders/gan/generator/sifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/vocoders/vocoder_dataset.py b/models/vocoders/vocoder_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7df17b97ba7a4f770f01971324126eca4a2db272 --- /dev/null +++ b/models/vocoders/vocoder_dataset.py @@ -0,0 +1,264 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
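+
+# Data layout sketch (illustrative only; the dataset name and uid are placeholders):
+# each entry of the metadata file (cfg.preprocess.train_file / valid_file, a JSON
+# list) is expected to look like
+#
+#   {"Dataset": "some_dataset", "Uid": "some_utterance_id", ...}
+#
+# and per-utterance features are read from
+# "<processed_dir>/<Dataset>/<feature_dir>/<Uid>.npy". After collation (see
+# VocoderCollator below), a batch roughly looks like:
+#
+#   {
+#       "mel": Tensor (B, T, n_mel),        # padded; note the transpose in the collator
+#       "frame_pitch": Tensor (B, T),       # only if use_frame_pitch
+#       "audio": Tensor (B, T * hop_size),  # padded waveform
+#       "target_len": LongTensor (B,),
+#       "mask": LongTensor (B, T, 1),
+#   }
+#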
+ +from typing import Iterable +import torch +import numpy as np +import torch.utils.data +from torch.nn.utils.rnn import pad_sequence +from utils.data_utils import * +from torch.utils.data import ConcatDataset, Dataset + + +class VocoderDataset(torch.utils.data.Dataset): + def __init__(self, cfg, dataset, is_valid=False): + """ + Args: + cfg: config + dataset: dataset name + is_valid: whether to use train or valid dataset + """ + assert isinstance(dataset, str) + + processed_data_dir = os.path.join(cfg.preprocess.processed_dir, dataset) + + meta_file = cfg.preprocess.valid_file if is_valid else cfg.preprocess.train_file + self.metafile_path = os.path.join(processed_data_dir, meta_file) + self.metadata = self.get_metadata() + + self.data_root = processed_data_dir + self.cfg = cfg + + if cfg.preprocess.use_audio: + self.utt2audio_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2audio_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.audio_dir, + uid + ".npy", + ) + elif cfg.preprocess.use_label: + self.utt2label_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2label_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.label_dir, + uid + ".npy", + ) + elif cfg.preprocess.use_one_hot: + self.utt2one_hot_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2one_hot_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.one_hot_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_mel: + self.utt2mel_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2mel_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.mel_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_frame_pitch: + self.utt2frame_pitch_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + self.utt2frame_pitch_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.pitch_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_uv: + self.utt2uv_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + self.utt2uv_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.uv_dir, + uid + ".npy", + ) + + if cfg.preprocess.use_amplitude_phase: + self.utt2logamp_path = {} + self.utt2pha_path = {} + self.utt2rea_path = {} + self.utt2imag_path = {} + for utt_info in self.metadata: + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + self.utt2logamp_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.log_amplitude_dir, + uid + ".npy", + ) + self.utt2pha_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.phase_dir, + uid + ".npy", + ) + self.utt2rea_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.real_dir, + uid + ".npy", + ) + self.utt2imag_path[utt] = os.path.join( + cfg.preprocess.processed_dir, + dataset, + cfg.preprocess.imaginary_dir, + uid + ".npy", + ) + + def 
__getitem__(self, index): + utt_info = self.metadata[index] + + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + single_feature = dict() + + if self.cfg.preprocess.use_mel: + mel = np.load(self.utt2mel_path[utt]) + assert mel.shape[0] == self.cfg.preprocess.n_mel # [n_mels, T] + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = mel.shape[1] + + single_feature["mel"] = mel + + if self.cfg.preprocess.use_frame_pitch: + frame_pitch = np.load(self.utt2frame_pitch_path[utt]) + + if "target_len" not in single_feature.keys(): + single_feature["target_len"] = len(frame_pitch) + + aligned_frame_pitch = align_length( + frame_pitch, single_feature["target_len"] + ) + + single_feature["frame_pitch"] = aligned_frame_pitch + + if self.cfg.preprocess.use_audio: + audio = np.load(self.utt2audio_path[utt]) + + single_feature["audio"] = audio + + return single_feature + + def get_metadata(self): + with open(self.metafile_path, "r", encoding="utf-8") as f: + metadata = json.load(f) + + return metadata + + def get_dataset_name(self): + return self.metadata[0]["Dataset"] + + def __len__(self): + return len(self.metadata) + + +class VocoderConcatDataset(ConcatDataset): + def __init__(self, datasets: Iterable[Dataset], full_audio_inference=False): + """Concatenate a series of datasets with their random inference audio merged.""" + super().__init__(datasets) + + self.cfg = self.datasets[0].cfg + + self.metadata = [] + + # Merge metadata + for dataset in self.datasets: + self.metadata += dataset.metadata + + # Merge random inference features + if full_audio_inference: + self.eval_audios = [] + self.eval_dataset_names = [] + if self.cfg.preprocess.use_mel: + self.eval_mels = [] + if self.cfg.preprocess.use_frame_pitch: + self.eval_pitchs = [] + for dataset in self.datasets: + self.eval_audios.append(dataset.eval_audio) + self.eval_dataset_names.append(dataset.get_dataset_name()) + if self.cfg.preprocess.use_mel: + self.eval_mels.append(dataset.eval_mel) + if self.cfg.preprocess.use_frame_pitch: + self.eval_pitchs.append(dataset.eval_pitch) + + +class VocoderCollator(object): + """Zero-pads model inputs and targets based on number of frames per step""" + + def __init__(self, cfg): + self.cfg = cfg + + def __call__(self, batch): + packed_batch_features = dict() + + # mel: [b, n_mels, frame] + # frame_pitch: [b, frame] + # audios: [b, frame * hop_size] + + for key in batch[0].keys(): + if key == "target_len": + packed_batch_features["target_len"] = torch.LongTensor( + [b["target_len"] for b in batch] + ) + masks = [ + torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch + ] + packed_batch_features["mask"] = pad_sequence( + masks, batch_first=True, padding_value=0 + ) + elif key == "mel": + values = [torch.from_numpy(b[key]).T for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + else: + values = [torch.from_numpy(b[key]) for b in batch] + packed_batch_features[key] = pad_sequence( + values, batch_first=True, padding_value=0 + ) + + return packed_batch_features diff --git a/models/vocoders/vocoder_inference.py b/models/vocoders/vocoder_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..dfd09ee6aa44c544c51a62c0e014cca4260cc6a8 --- /dev/null +++ b/models/vocoders/vocoder_inference.py @@ -0,0 +1,488 @@ +# Copyright (c) 2023 Amphion. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import torch +import json +import json5 +import time +import accelerate +import random +import numpy as np +import shutil + +from pathlib import Path +from tqdm import tqdm +from glob import glob +from accelerate.logging import get_logger +from torch.utils.data import DataLoader + +from models.vocoders.vocoder_dataset import ( + VocoderDataset, + VocoderCollator, + VocoderConcatDataset, +) + +from models.vocoders.gan.generator import bigvgan, hifigan, melgan, nsfhifigan, apnet +from models.vocoders.flow.waveglow import waveglow +from models.vocoders.diffusion.diffwave import diffwave +from models.vocoders.autoregressive.wavenet import wavenet +from models.vocoders.autoregressive.wavernn import wavernn +from models.vocoders.gan import gan_vocoder_inference +from utils.io import save_audio + +_vocoders = { + "diffwave": diffwave.DiffWave, + "wavernn": wavernn.WaveRNN, + "wavenet": wavenet.WaveNet, + "waveglow": waveglow.WaveGlow, + "nsfhifigan": nsfhifigan.NSFHiFiGAN, + "bigvgan": bigvgan.BigVGAN, + "hifigan": hifigan.HiFiGAN, + "melgan": melgan.MelGAN, + "apnet": apnet.APNet, +} + +_vocoder_infer_funcs = { + # "world": world_inference.synthesis_audios, + # "wavernn": wavernn_inference.synthesis_audios, + # "wavenet": wavenet_inference.synthesis_audios, + # "diffwave": diffwave_inference.synthesis_audios, + "nsfhifigan": gan_vocoder_inference.synthesis_audios, + "bigvgan": gan_vocoder_inference.synthesis_audios, + "melgan": gan_vocoder_inference.synthesis_audios, + "hifigan": gan_vocoder_inference.synthesis_audios, + "apnet": gan_vocoder_inference.synthesis_audios, +} + + +class VocoderInference(object): + def __init__(self, args=None, cfg=None, infer_type="from_dataset"): + super().__init__() + + start = time.monotonic_ns() + self.args = args + self.cfg = cfg + self.infer_type = infer_type + + # Init accelerator + self.accelerator = accelerate.Accelerator() + self.accelerator.wait_for_everyone() + + # Get logger + with self.accelerator.main_process_first(): + self.logger = get_logger("inference", log_level=args.log_level) + + # Log some info + self.logger.info("=" * 56) + self.logger.info("||\t\t" + "New inference process started." 
+ "\t\t||") + self.logger.info("=" * 56) + self.logger.info("\n") + + self.vocoder_dir = args.vocoder_dir + self.logger.debug(f"Vocoder dir: {args.vocoder_dir}") + + os.makedirs(args.output_dir, exist_ok=True) + if os.path.exists(os.path.join(args.output_dir, "pred")): + shutil.rmtree(os.path.join(args.output_dir, "pred")) + if os.path.exists(os.path.join(args.output_dir, "gt")): + shutil.rmtree(os.path.join(args.output_dir, "gt")) + os.makedirs(os.path.join(args.output_dir, "pred"), exist_ok=True) + os.makedirs(os.path.join(args.output_dir, "gt"), exist_ok=True) + + # Set random seed + with self.accelerator.main_process_first(): + start = time.monotonic_ns() + self._set_random_seed(self.cfg.train.random_seed) + end = time.monotonic_ns() + self.logger.debug( + f"Setting random seed done in {(end - start) / 1e6:.2f}ms" + ) + self.logger.debug(f"Random seed: {self.cfg.train.random_seed}") + + # Setup inference mode + if self.infer_type == "infer_from_dataset": + self.cfg.dataset = self.args.infer_datasets + elif self.infer_type == "infer_from_feature": + self._build_tmp_dataset_from_feature() + self.cfg.dataset = ["tmp"] + elif self.infer_type == "infer_from_audio": + self._build_tmp_dataset_from_audio() + self.cfg.dataset = ["tmp"] + + # Setup data loader + with self.accelerator.main_process_first(): + self.logger.info("Building dataset...") + start = time.monotonic_ns() + self.test_dataloader = self._build_dataloader() + end = time.monotonic_ns() + self.logger.info(f"Building dataset done in {(end - start) / 1e6:.2f}ms") + + # Build model + with self.accelerator.main_process_first(): + self.logger.info("Building model...") + start = time.monotonic_ns() + self.model = self._build_model() + end = time.monotonic_ns() + self.logger.info(f"Building model done in {(end - start) / 1e6:.3f}ms") + + # Init with accelerate + self.logger.info("Initializing accelerate...") + start = time.monotonic_ns() + self.accelerator = accelerate.Accelerator() + (self.model, self.test_dataloader) = self.accelerator.prepare( + self.model, self.test_dataloader + ) + end = time.monotonic_ns() + self.accelerator.wait_for_everyone() + self.logger.info(f"Initializing accelerate done in {(end - start) / 1e6:.3f}ms") + + with self.accelerator.main_process_first(): + self.logger.info("Loading checkpoint...") + start = time.monotonic_ns() + if os.path.isdir(args.vocoder_dir): + if os.path.isdir(os.path.join(args.vocoder_dir, "checkpoint")): + self._load_model(os.path.join(args.vocoder_dir, "checkpoint")) + else: + self._load_model(os.path.join(args.vocoder_dir)) + else: + self._load_model(os.path.join(args.vocoder_dir)) + end = time.monotonic_ns() + self.logger.info(f"Loading checkpoint done in {(end - start) / 1e6:.3f}ms") + + self.model.eval() + self.accelerator.wait_for_everyone() + + def _build_tmp_dataset_from_feature(self): + if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")): + shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + + utts = [] + mels = glob(os.path.join(self.args.feature_folder, "mels", "*.npy")) + for i, mel in enumerate(mels): + uid = mel.split("/")[-1].split(".")[0] + utt = {"Dataset": "tmp", "Uid": uid, "index": i} + utts.append(utt) + + os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + with open( + os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w" + ) as f: + json.dump(utts, f) + + meta_info = {"dataset": "tmp", "test": {"size": len(utts)}} + + with open( + os.path.join(self.cfg.preprocess.processed_dir, "tmp", 
"meta_info.json"), + "w", + ) as f: + json.dump(meta_info, f) + + features = glob(os.path.join(self.args.feature_folder, "*")) + for feature in features: + feature_name = feature.split("/")[-1] + if os.path.isfile(feature): + continue + shutil.copytree( + os.path.join(self.args.feature_folder, feature_name), + os.path.join(self.cfg.preprocess.processed_dir, "tmp", feature_name), + ) + + def _build_tmp_dataset_from_audio(self): + if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")): + shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + + utts = [] + audios = glob(os.path.join(self.args.audio_folder, "*")) + for i, audio in enumerate(audios): + uid = audio.split("/")[-1].split(".")[0] + utt = {"Dataset": "tmp", "Uid": uid, "index": i, "Path": audio} + utts.append(utt) + + os.makedirs(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + with open( + os.path.join(self.cfg.preprocess.processed_dir, "tmp", "test.json"), "w" + ) as f: + json.dump(utts, f) + + meta_info = {"dataset": "tmp", "test": {"size": len(utts)}} + + with open( + os.path.join(self.cfg.preprocess.processed_dir, "tmp", "meta_info.json"), + "w", + ) as f: + json.dump(meta_info, f) + + from processors import acoustic_extractor + + acoustic_extractor.extract_utt_acoustic_features_serial( + utts, os.path.join(self.cfg.preprocess.processed_dir, "tmp"), self.cfg + ) + + def _build_test_dataset(self): + return VocoderDataset, VocoderCollator + + def _build_model(self): + model = _vocoders[self.cfg.model.generator](self.cfg) + return model + + def _build_dataloader(self): + """Build dataloader which merges a series of datasets.""" + Dataset, Collator = self._build_test_dataset() + + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=True) + datasets_list.append(subdataset) + test_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=False) + test_collate = Collator(self.cfg) + test_batch_size = min(self.cfg.inference.batch_size, len(test_dataset)) + test_dataloader = DataLoader( + test_dataset, + collate_fn=test_collate, + num_workers=1, + batch_size=test_batch_size, + shuffle=False, + ) + self.test_batch_size = test_batch_size + self.test_dataset = test_dataset + return test_dataloader + + def _load_model(self, checkpoint_dir, from_multi_gpu=False): + """Load model from checkpoint. If a folder is given, it will + load the latest checkpoint in checkpoint_dir. If a path is given + it will load the checkpoint specified by checkpoint_path. + **Only use this method after** ``accelerator.prepare()``. 
+ """ + if os.path.isdir(checkpoint_dir): + if "epoch" in checkpoint_dir and "step" in checkpoint_dir: + checkpoint_path = checkpoint_dir + else: + # Load the latest accelerator state dicts + ls = [ + str(i) + for i in Path(checkpoint_dir).glob("*") + if not "audio" in str(i) + ] + ls.sort( + key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True + ) + checkpoint_path = ls[0] + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + return str(checkpoint_path) + else: + # Load old .pt checkpoints + if self.cfg.model.generator in [ + "bigvgan", + "hifigan", + "melgan", + "nsfhifigan", + ]: + ckpt = torch.load( + checkpoint_dir, + map_location=torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu"), + ) + if from_multi_gpu: + pretrained_generator_dict = ckpt["generator_state_dict"] + generator_dict = self.model.state_dict() + + new_generator_dict = { + k.split("module.")[-1]: v + for k, v in pretrained_generator_dict.items() + if ( + k.split("module.")[-1] in generator_dict + and v.shape == generator_dict[k.split("module.")[-1]].shape + ) + } + + generator_dict.update(new_generator_dict) + + self.model.load_state_dict(generator_dict) + else: + self.model.load_state_dict(ckpt["generator_state_dict"]) + else: + self.model.load_state_dict(torch.load(checkpoint_dir)["state_dict"]) + return str(checkpoint_dir) + + def inference(self): + """Inference via batches""" + for i, batch in tqdm(enumerate(self.test_dataloader)): + if self.cfg.preprocess.use_frame_pitch: + audio_pred = self.model.forward( + batch["mel"].transpose(-1, -2), batch["frame_pitch"].float() + ).cpu() + elif self.cfg.preprocess.extract_amplitude_phase: + audio_pred = self.model.forward(batch["mel"].transpose(-1, -2))[-1] + else: + audio_pred = ( + self.model.forward(batch["mel"].transpose(-1, -2)).detach().cpu() + ) + audio_ls = audio_pred.chunk(self.test_batch_size) + audio_gt_ls = batch["audio"].cpu().chunk(self.test_batch_size) + length_ls = batch["target_len"].cpu().chunk(self.test_batch_size) + j = 0 + for it, it_gt, l in zip(audio_ls, audio_gt_ls, length_ls): + l = l.item() + it = it.squeeze(0).squeeze(0)[: l * self.cfg.preprocess.hop_size] + it_gt = it_gt.squeeze(0)[: l * self.cfg.preprocess.hop_size] + uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"] + save_audio( + os.path.join(self.args.output_dir, "pred", "{}.wav").format(uid), + it, + self.cfg.preprocess.sample_rate, + ) + save_audio( + os.path.join(self.args.output_dir, "gt", "{}.wav").format(uid), + it_gt, + self.cfg.preprocess.sample_rate, + ) + j += 1 + + if os.path.exists(os.path.join(self.cfg.preprocess.processed_dir, "tmp")): + shutil.rmtree(os.path.join(self.cfg.preprocess.processed_dir, "tmp")) + + def _set_random_seed(self, seed): + """Set random seed for all possible random modules.""" + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + def _count_parameters(self, model): + return sum(p.numel() for p in model.parameters()) + + def _dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) + + +def load_nnvocoder( + cfg, + vocoder_name, + weights_file, + from_multi_gpu=False, +): + """Load the specified vocoder. + cfg: the vocoder config filer. + weights_file: a folder or a .pt path. 
+        from_multi_gpu: if True, strip the "module." prefix from the keys of the loaded state dict.
+    """
+    print("Loading Vocoder from Weights file: {}".format(weights_file))
+
+    # Build model
+    model = _vocoders[vocoder_name](cfg)
+    if not os.path.isdir(weights_file):
+        # Load from .pt file
+        if vocoder_name in ["bigvgan", "hifigan", "melgan", "nsfhifigan"]:
+            ckpt = torch.load(
+                weights_file,
+                map_location=torch.device("cuda")
+                if torch.cuda.is_available()
+                else torch.device("cpu"),
+            )
+            if from_multi_gpu:
+                pretrained_generator_dict = ckpt["generator_state_dict"]
+                generator_dict = model.state_dict()
+
+                new_generator_dict = {
+                    k.split("module.")[-1]: v
+                    for k, v in pretrained_generator_dict.items()
+                    if (
+                        k.split("module.")[-1] in generator_dict
+                        and v.shape == generator_dict[k.split("module.")[-1]].shape
+                    )
+                }
+
+                generator_dict.update(new_generator_dict)
+
+                model.load_state_dict(generator_dict)
+            else:
+                model.load_state_dict(ckpt["generator_state_dict"])
+        else:
+            model.load_state_dict(torch.load(weights_file)["state_dict"])
+    else:
+        # Load from accelerator state dict
+        weights_file = os.path.join(weights_file, "checkpoint")
+        ls = [str(i) for i in Path(weights_file).glob("*") if "audio" not in str(i)]
+        ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True)
+        checkpoint_path = ls[0]
+        accelerator = accelerate.Accelerator()
+        model = accelerator.prepare(model)
+        accelerator.load_state(checkpoint_path)
+
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    model = model.eval()
+    return model
+
+
+def tensorize(data, device, n_samples):
+    """
+    data: a list of numpy arrays
+    """
+    assert isinstance(data, list)
+    if n_samples:
+        data = data[:n_samples]
+    data = [torch.as_tensor(x, device=device) for x in data]
+    return data
+
+
+def synthesis(
+    cfg,
+    vocoder_weight_file,
+    n_samples,
+    pred,
+    f0s=None,
+    batch_size=64,
+    fast_inference=False,
+):
+    """Synthesize audios with a given vocoder from a series of acoustic features.
+    cfg: vocoder config.
+    vocoder_weight_file: a folder of accelerator state dict or a path to the .pt file.
+    pred: a list of numpy arrays. [(seq_len1, acoustic_features_dim), (seq_len2, acoustic_features_dim), ...]
+    """
+
+    vocoder_name = cfg.model.generator
+
+    print("Synthesizing audios using {} vocoder...".format(vocoder_name))
+
+    ###### TODO: World Vocoder Refactor ######
+    # if vocoder_name == "world":
+    #     world_inference.synthesis_audios(
+    #         cfg, dataset_name, split, n_samples, pred, save_dir, tag
+    #     )
+    #     return
+
+    # ====== Loading neural vocoder model ======
+    vocoder = load_nnvocoder(
+        cfg, vocoder_name, weights_file=vocoder_weight_file, from_multi_gpu=True
+    )
+    device = next(vocoder.parameters()).device
+
+    # ====== Inference for predicted acoustic features ======
+    # pred: (frame_len, n_mels) -> (n_mels, frame_len)
+    mels_pred = tensorize([p.T for p in pred], device, n_samples)
+    print("For predicted mels, #sample = {}...".format(len(mels_pred)))
+    audios_pred = _vocoder_infer_funcs[vocoder_name](
+        cfg,
+        vocoder,
+        mels_pred,
+        f0s=f0s,
+        batch_size=batch_size,
+        fast_inference=fast_inference,
+    )
+    return audios_pred
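To make the call chain above concrete, here is a minimal sketch of driving ``synthesis`` with predicted mels. ``vocoder_cfg`` is a placeholder for a loaded vocoder configuration (its ``model.generator`` field selects the registered vocoder), the random mel is only a stand-in for real predicted features, and the checkpoint path refers to the ``pretrained/bigvgan/400000.pt`` file added elsewhere in this diff:

    import numpy as np

    # pred is a list of (frame_len, n_mels) arrays; n_mels must match the
    # vocoder config (n_mel = 100 for the bundled bigvgan checkpoint)
    pred_mels = [np.random.randn(200, 100).astype(np.float32)]

    audios = synthesis(
        cfg=vocoder_cfg,  # hypothetical: a loaded vocoder config object
        vocoder_weight_file="pretrained/bigvgan/400000.pt",
        n_samples=None,  # keep every item in pred
        pred=pred_mels,
    )
    # one synthesized waveform per input mel, in the same order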
diff --git a/models/vocoders/vocoder_sampler.py b/models/vocoders/vocoder_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d29f88a291dcf7386cadaeae0d990c8e76ebf98
--- /dev/null
+++ b/models/vocoders/vocoder_sampler.py
@@ -0,0 +1,126 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import random
+
+from torch.utils.data import ConcatDataset, Dataset
+from torch.utils.data.sampler import (
+    BatchSampler,
+    RandomSampler,
+    Sampler,
+    SequentialSampler,
+)
+
+
+class ScheduledSampler(Sampler):
+    """A sampler that samples data from a given concat-dataset.
+
+    Args:
+        concat_dataset (ConcatDataset): a concatenated dataset consisting of all datasets
+        batch_size (int): batch size
+        holistic_shuffle (bool): whether to shuffle the whole dataset or not
+        logger (logging.Logger): logger for printing warning messages
+
+    Usage (illustrative):
+        For cfg.train.batch_size = 3, cfg.train.holistic_shuffle = False, cfg.train.drop_last = True:
+        >>> list(ScheduledSampler(ConcatDataset([[0, 1, 2], [3, 4, 5], [6, 7, 8]]), 3, False))
+        [3, 4, 5, 0, 1, 2, 6, 7, 8]
+    """
+
+    def __init__(
+        self, concat_dataset, batch_size, holistic_shuffle, logger=None, type="train"
+    ):
+        if not isinstance(concat_dataset, ConcatDataset):
+            raise ValueError(
+                "concat_dataset must be an instance of ConcatDataset, but got {}".format(
+                    type(concat_dataset)
+                )
+            )
+        if not isinstance(batch_size, int):
+            raise ValueError(
+                "batch_size must be an integer, but got {}".format(type(batch_size))
+            )
+        if not isinstance(holistic_shuffle, bool):
+            raise ValueError(
+                "holistic_shuffle must be a boolean, but got {}".format(
+                    type(holistic_shuffle)
+                )
+            )
+
+        self.concat_dataset = concat_dataset
+        self.batch_size = batch_size
+        self.holistic_shuffle = holistic_shuffle
+
+        affected_dataset_name = []
+        affected_dataset_len = []
+        for dataset in concat_dataset.datasets:
+            dataset_len = len(dataset)
+            dataset_name = dataset.get_dataset_name()
+            if dataset_len < batch_size:
+                affected_dataset_name.append(dataset_name)
+                affected_dataset_len.append(dataset_len)
+
+        self.type = type
+        for dataset_name, dataset_len in zip(
+            affected_dataset_name, affected_dataset_len
+        ):
+            if not type == "valid":
+                logger.warning(
+                    "The {} dataset {} has a length of {}, which is smaller than the batch size {}. 
This may cause unexpected behavior.".format( + type, dataset_name, dataset_len, batch_size + ) + ) + + def __len__(self): + # the number of batches with drop last + num_of_batches = sum( + [ + math.floor(len(dataset) / self.batch_size) + for dataset in self.concat_dataset.datasets + ] + ) + return num_of_batches * self.batch_size + + def __iter__(self): + iters = [] + for dataset in self.concat_dataset.datasets: + iters.append( + SequentialSampler(dataset).__iter__() + if self.holistic_shuffle + else RandomSampler(dataset).__iter__() + ) + init_indices = [0] + self.concat_dataset.cumulative_sizes[:-1] + output_batches = [] + for dataset_idx in range(len(self.concat_dataset.datasets)): + cur_batch = [] + for idx in iters[dataset_idx]: + cur_batch.append(idx + init_indices[dataset_idx]) + if len(cur_batch) == self.batch_size: + output_batches.append(cur_batch) + cur_batch = [] + if self.type == "valid" and len(cur_batch) > 0: + output_batches.append(cur_batch) + cur_batch = [] + # force drop last in training + random.shuffle(output_batches) + output_indices = [item for sublist in output_batches for item in sublist] + return iter(output_indices) + + +def build_samplers(concat_dataset: Dataset, cfg, logger, type): + sampler = ScheduledSampler( + concat_dataset, + cfg.train.batch_size, + cfg.train.sampler.holistic_shuffle, + logger, + type, + ) + batch_sampler = BatchSampler( + sampler, + cfg.train.batch_size, + cfg.train.sampler.drop_last if not type == "valid" else False, + ) + return sampler, batch_sampler diff --git a/models/vocoders/vocoder_trainer.py b/models/vocoders/vocoder_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5821e735a64f07fcf9c782712670e24ce6a91c04 --- /dev/null +++ b/models/vocoders/vocoder_trainer.py @@ -0,0 +1,180 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
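To illustrate the sampling behaviour implemented in ``vocoder_sampler.py`` above, here is a small self-contained sketch. ``ToyDataset`` and the ``SimpleNamespace`` config are invented stand-ins for the real dataset and config objects; only ``build_samplers`` comes from the module itself:

    import logging
    from types import SimpleNamespace

    from torch.utils.data import ConcatDataset, Dataset

    from models.vocoders.vocoder_sampler import build_samplers

    class ToyDataset(Dataset):
        # minimal dataset exposing get_dataset_name(), which ScheduledSampler calls
        def __init__(self, name, size):
            self.name, self.size = name, size

        def __len__(self):
            return self.size

        def __getitem__(self, idx):
            return idx

        def get_dataset_name(self):
            return self.name

    # invented config: only the fields read by build_samplers are filled in
    cfg = SimpleNamespace(
        train=SimpleNamespace(
            batch_size=3,
            sampler=SimpleNamespace(holistic_shuffle=False, drop_last=True),
        )
    )

    concat = ConcatDataset([ToyDataset("a", 6), ToyDataset("b", 7)])
    _, batch_sampler = build_samplers(concat, cfg, logging.getLogger(__name__), "train")
    for batch in batch_sampler:
        # each batch stays inside one sub-dataset; the leftover 7th item of "b"
        # is dropped (training forces drop-last), and batch order is shuffled
        print(batch)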
+ +import os +import random +from pathlib import Path +import re + +import accelerate +import json5 +import numpy as np +import torch +from accelerate.utils import ProjectConfiguration +from torch.utils.data import DataLoader +from tqdm import tqdm + +from models.vocoders.vocoder_dataset import VocoderConcatDataset +from models.vocoders.vocoder_sampler import build_samplers + + +class VocoderTrainer: + def __init__(self): + super().__init__() + + def _init_accelerator(self): + """Initialize the accelerator components.""" + self.exp_dir = os.path.join( + os.path.abspath(self.cfg.log_dir), self.args.exp_name + ) + project_config = ProjectConfiguration( + project_dir=self.exp_dir, logging_dir=os.path.join(self.exp_dir, "log") + ) + self.accelerator = accelerate.Accelerator( + gradient_accumulation_steps=self.cfg.train.gradient_accumulation_step, + log_with=self.cfg.train.tracker, + project_config=project_config, + ) + if self.accelerator.is_main_process: + os.makedirs(project_config.project_dir, exist_ok=True) + os.makedirs(project_config.logging_dir, exist_ok=True) + with self.accelerator.main_process_first(): + self.accelerator.init_trackers(self.args.exp_name) + + def _build_dataset(self): + pass + + def _build_criterion(self): + pass + + def _build_model(self): + pass + + def _build_dataloader(self): + """Build dataloader which merges a series of datasets.""" + # Build dataset instance for each dataset and combine them by ConcatDataset + Dataset, Collator = self._build_dataset() + + # Build train set + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=False) + datasets_list.append(subdataset) + train_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=True) + train_collate = Collator(self.cfg) + _, batch_sampler = build_samplers(train_dataset, self.cfg, self.logger, "train") + train_loader = DataLoader( + train_dataset, + collate_fn=train_collate, + batch_sampler=batch_sampler, + num_workers=self.cfg.train.dataloader.num_worker, + pin_memory=self.cfg.train.dataloader.pin_memory, + ) + + # Build test set + datasets_list = [] + for dataset in self.cfg.dataset: + subdataset = Dataset(self.cfg, dataset, is_valid=True) + datasets_list.append(subdataset) + valid_dataset = VocoderConcatDataset(datasets_list, full_audio_inference=True) + valid_collate = Collator(self.cfg) + _, batch_sampler = build_samplers(valid_dataset, self.cfg, self.logger, "train") + valid_loader = DataLoader( + valid_dataset, + collate_fn=valid_collate, + batch_sampler=batch_sampler, + num_workers=self.cfg.train.dataloader.num_worker, + pin_memory=self.cfg.train.dataloader.pin_memory, + ) + return train_loader, valid_loader + + def _build_optimizer(self): + pass + + def _build_scheduler(self): + pass + + def _load_model(self, checkpoint_dir, checkpoint_path=None, resume_type="resume"): + """Load model from checkpoint. If a folder is given, it will + load the latest checkpoint in checkpoint_dir. If a path is given + it will load the checkpoint specified by checkpoint_path. + **Only use this method after** ``accelerator.prepare()``. 
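        The two resume modes handled below, with an illustrative call (the
        checkpoint directory is hypothetical):

            # resume_type="resume": restore the full accelerator state
            # (model, optimizer, scheduler) and continue training
            # resume_type="finetune": load only the model weights from
            # pytorch_model.bin and fine-tune from them
            self._load_model("ckpt/my_vocoder_exp/checkpoint", resume_type="finetune")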
+ """ + if checkpoint_path is None: + ls = [str(i) for i in Path(checkpoint_dir).glob("*")] + ls.sort(key=lambda x: int(x.split("_")[-3].split("-")[-1]), reverse=True) + checkpoint_path = ls[0] + if resume_type == "resume": + self.accelerator.load_state(checkpoint_path) + elif resume_type == "finetune": + accelerate.load_checkpoint_and_dispatch( + self.accelerator.unwrap_model(self.model), + os.path.join(checkpoint_path, "pytorch_model.bin"), + ) + self.logger.info("Load model weights for finetune SUCCESS!") + else: + raise ValueError("Unsupported resume type: {}".format(resume_type)) + self.epoch = int(checkpoint_path.split("_")[-3].split("-")[-1]) + 1 + self.step = int(checkpoint_path.split("_")[-2].split("-")[-1]) + 1 + return checkpoint_path + + def train_loop(self): + pass + + def _train_epoch(self): + pass + + def _valid_epoch(self): + pass + + def _train_step(self): + pass + + def _valid_step(self): + pass + + def _inference(self): + pass + + def _set_random_seed(self, seed): + """Set random seed for all possible random modules.""" + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + def _check_nan(self, loss): + if torch.any(torch.isnan(loss)): + self.logger.fatal("Fatal Error: NaN!") + self.logger.error("loss = {:.6f}".format(loss.item()), in_order=True) + + def _check_basic_configs(self): + if self.cfg.train.gradient_accumulation_step <= 0: + self.logger.fatal("Invalid gradient_accumulation_step value!") + self.logger.error( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) + self.accelerator.end_training() + raise ValueError( + f"Invalid gradient_accumulation_step value: {self.cfg.train.gradient_accumulation_step}. It should be positive." + ) + + def _count_parameters(self): + pass + + def _dump_cfg(self, path): + os.makedirs(os.path.dirname(path), exist_ok=True) + json5.dump( + self.cfg, + open(path, "w"), + indent=4, + sort_keys=True, + ensure_ascii=False, + quote_keys=True, + ) + + def _is_valid_pattern(self, directory_name): + directory_name = str(directory_name) + pattern = r"^epoch-\d{4}_step-\d{7}_loss-\d{1}\.\d{6}" + return re.match(pattern, directory_name) is not None diff --git a/pretrained/bigvgan/400000.pt b/pretrained/bigvgan/400000.pt new file mode 100755 index 0000000000000000000000000000000000000000..a36c956e753aef0862753c496ade62d23ab5c906 --- /dev/null +++ b/pretrained/bigvgan/400000.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:989df2350b502e1175cdb1d204d9f81c27ddf97fe1919db4fa2605631e4cab1d +size 1846939571 diff --git a/pretrained/bigvgan/args.json b/pretrained/bigvgan/args.json new file mode 100644 index 0000000000000000000000000000000000000000..06878c8cc6e7c667b107836a51c1d577c15fa7b1 --- /dev/null +++ b/pretrained/bigvgan/args.json @@ -0,0 +1,235 @@ +{ + "base_config": "egs/vocoder/gan/exp_config_base.json", + "exp_name": "bigvgan_large", + "inference": { + "batch_size": 1, + }, + "model": { + "bigvgan": { + "activation": "snakebeta", + "resblock": "1", + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5, + ], + [ + 1, + 3, + 5, + ], + [ + 1, + 3, + 5, + ], + ], + "resblock_kernel_sizes": [ + 3, + 7, + 11, + ], + "snake_logscale": true, + "upsample_initial_channel": 1536, + "upsample_kernel_sizes": [ + 8, + 8, + 4, + 4, + 4, + 4, + ], + "upsample_rates": [ + 4, + 4, + 2, + 2, + 2, + 2, + ], + }, + "discriminators": [ + "mpd", + "msstftd", + ], + "generator": "bigvgan", + "mpd": { + "discriminator_channel_multi": 1, + "mpd_reshapes": [ + 
2, + 3, + 5, + 7, + 11, + ], + "use_spectral_norm": false, + }, + "mrd": { + "discriminator_channel_multi": 1, + "mrd_override": false, + "resolutions": [ + [ + 1024, + 120, + 600, + ], + [ + 2048, + 240, + 1200, + ], + [ + 512, + 50, + 240, + ], + ], + "use_spectral_norm": false, + }, + "msstftd": { + "filters": 32, + }, + }, + "model_type": "GANVocoder", + "preprocess": { + "audio_dir": "audios", + "bits": 8, + "contentvec_dir": "contentvec", + "cut_mel_frame": 32, + "data_augment": false, + "dur_dir": "durs", + "duration_dir": "duration", + "emo2id": "emo2id.json", + "energy_dir": "energys", + "energy_extract_mode": "from_mel", + "energy_norm": false, + "extract_audio": true, + "extract_contentvec_feature": false, + "extract_duration": false, + "extract_energy": false, + "extract_label": false, + "extract_mcep": false, + "extract_mel": true, + "extract_mert_feature": false, + "extract_one_hot": false, + "extract_pitch": false, + "extract_uv": false, + "extract_wenet_feature": false, + "extract_whisper_feature": false, + "f0_max": 1100, + "f0_min": 50, + "file_lst": "file.lst", + "fmax": 12000, + "fmin": 0, + "hop_size": 256, + "is_mu_law": false, + "lab_dir": "labs", + "label_dir": "labels", + "mcep_dir": "mcep", + "mel_dir": "mels", + "mel_min_max_norm": false, + "min_level_db": -115, + "n_fft": 1024, + "n_mel": 100, + "num_silent_frames": 8, + "phone_seq_file": "phone_seq_file", + "pitch_bin": 256, + "pitch_dir": "pitches", + "pitch_extractor": "parselmouth", + "pitch_max": 1100.0, + "pitch_min": 50.0, + "pitch_norm": false, + "processed_dir": "processed_data", + "ref_level_db": 20, + "sample_rate": 24000, + "spk2id": "singers.json", + "train_file": "train.json", + "trim_fft_size": 512, + "trim_hop_size": 128, + "trim_silence": false, + "trim_top_db": 30, + "trimmed_wav_dir": "trimmed_wavs", + "use_audio": true, + "use_dur": false, + "use_emoid": false, + "use_frame_duration": false, + "use_frame_energy": false, + "use_frame_pitch": false, + "use_lab": false, + "use_label": false, + "use_log_scale_energy": false, + "use_log_scale_pitch": false, + "use_mel": true, + "use_one_hot": false, + "use_phn_seq": false, + "use_phone_duration": false, + "use_phone_energy": false, + "use_phone_pitch": false, + "use_spkid": false, + "use_uv": false, + "use_wav": false, + "use_wenet": false, + "utt2emo": "utt2emo", + "utt2spk": "utt2spk", + "uv_dir": "uvs", + "valid_file": "test.json", + "wav_dir": "wavs", + "wenet_dir": "wenet", + "win_size": 1024, + }, + "supported_model_type": [ + "GANVocoder", + "Fastspeech2", + "DiffSVC", + "Transformer", + "EDM", + "CD", + ], + "train": { + "adamw": { + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr": 0.0002, + }, + "batch_size": 4, + "criterions": [ + "feature", + "discriminator", + "generator", + "mel", + ], + "dataloader": { + "num_worker": 4, + "pin_memory": true, + }, + "ddp": true, + "epochs": 50000, + "exponential_lr": { + "lr_decay": 0.999, + }, + "gradient_accumulation_step": 1, + "keep_checkpoint_max": 5, + "max_epoch": 1000000, + "max_steps": 1000000, + "multi_speaker_training": false, + "random_seed": 114514, + "run_eval": [ + true, + ], + "sampler": { + "drop_last": true, + "holistic_shuffle": true, + }, + "save_checkpoint_stride": [ + 200, + ], + "save_checkpoints_steps": 10000, + "save_summary_steps": 500, + "total_training_steps": 50000, + "tracker": [ + "tensorboard", + ], + "valid_interval": 10000, + }, +} \ No newline at end of file diff --git a/pretrained/contentvec/README.md b/pretrained/contentvec/README.md new file mode 100644 index 
0000000000000000000000000000000000000000..6ea10938244c7282355be035c9489efa5bf08bdd --- /dev/null +++ b/pretrained/contentvec/README.md @@ -0,0 +1,5 @@ +# Download + +- [Link](https://github.com/auspicious3000/contentvec) +- Model: `ContentVec_legacy` +- Classes: 500 diff --git a/pretrained/contentvec/checkpoint_best_legacy_500.pt b/pretrained/contentvec/checkpoint_best_legacy_500.pt new file mode 100755 index 0000000000000000000000000000000000000000..9a2f13fb9c7047dff746e2d5d88c0d0a5aecf643 --- /dev/null +++ b/pretrained/contentvec/checkpoint_best_legacy_500.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:60d936ec5a566776fc392e69ad8b630d14eb588111233fe313436e200a7b187b +size 1330114945 diff --git a/pretrained/whisper/medium.pt b/pretrained/whisper/medium.pt new file mode 100644 index 0000000000000000000000000000000000000000..8aca41c710014a3d39774cd7592fa086177c672f --- /dev/null +++ b/pretrained/whisper/medium.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1 +size 1528008539 diff --git a/processors/__init__.py b/processors/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/processors/acoustic_extractor.py b/processors/acoustic_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..3c56486336b28a8ae4560d927cf8eb633b7c4513 --- /dev/null +++ b/processors/acoustic_extractor.py @@ -0,0 +1,864 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import torch +import numpy as np + +import json +from tqdm import tqdm +from sklearn.preprocessing import StandardScaler +from utils.io import save_feature, save_txt +from utils.util import has_existed +from utils.tokenizer import extract_encodec_token +from utils.stft import TacotronSTFT +from utils.dsp import compress, audio_to_label +from utils.data_utils import remove_outlier +from preprocessors.metadata import replace_augment_name +from scipy.interpolate import interp1d + +ZERO = 1e-12 + + +def extract_utt_acoustic_features_parallel(metadata, dataset_output, cfg, n_workers=1): + """Extract acoustic features from utterances using muliprocess + + Args: + metadata (dict): dictionary that stores data in train.json and test.json files + dataset_output (str): directory to store acoustic features + cfg (dict): dictionary that stores configurations + n_workers (int, optional): num of processes to extract features in parallel. Defaults to 1. 
+ + Returns: + list: acoustic features + """ + for utt in tqdm(metadata): + if cfg.task_type == "tts": + extract_utt_acoustic_features_tts(dataset_output, cfg, utt) + if cfg.task_type == "svc": + extract_utt_acoustic_features_svc(dataset_output, cfg, utt) + if cfg.task_type == "vocoder": + extract_utt_acoustic_features_vocoder(dataset_output, cfg, utt) + if cfg.task_type == "tta": + extract_utt_acoustic_features_tta(dataset_output, cfg, utt) + + +def avg_phone_feature(feature, duration, interpolation=False): + feature = feature[: sum(duration)] + if interpolation: + nonzero_ids = np.where(feature != 0)[0] + interp_fn = interp1d( + nonzero_ids, + feature[nonzero_ids], + fill_value=(feature[nonzero_ids[0]], feature[nonzero_ids[-1]]), + bounds_error=False, + ) + feature = interp_fn(np.arange(0, len(feature))) + + # Phoneme-level average + pos = 0 + for i, d in enumerate(duration): + if d > 0: + feature[i] = np.mean(feature[pos : pos + d]) + else: + feature[i] = 0 + pos += d + feature = feature[: len(duration)] + return feature + + +def extract_utt_acoustic_features_serial(metadata, dataset_output, cfg): + """Extract acoustic features from utterances (in single process) + + Args: + metadata (dict): dictionary that stores data in train.json and test.json files + dataset_output (str): directory to store acoustic features + cfg (dict): dictionary that stores configurations + + """ + for utt in tqdm(metadata): + if cfg.task_type == "tts": + extract_utt_acoustic_features_tts(dataset_output, cfg, utt) + if cfg.task_type == "svc": + extract_utt_acoustic_features_svc(dataset_output, cfg, utt) + if cfg.task_type == "vocoder": + extract_utt_acoustic_features_vocoder(dataset_output, cfg, utt) + if cfg.task_type == "tta": + extract_utt_acoustic_features_tta(dataset_output, cfg, utt) + + +def __extract_utt_acoustic_features(dataset_output, cfg, utt): + """Extract acoustic features from utterances (in single process) + + Args: + dataset_output (str): directory to store acoustic features + cfg (dict): dictionary that stores configurations + utt (dict): utterance info including dataset, singer, uid:{singer}_{song}_{index}, + path to utternace, duration, utternace index + + """ + from utils import audio, f0, world, duration + + uid = utt["Uid"] + wav_path = utt["Path"] + if os.path.exists(os.path.join(dataset_output, cfg.preprocess.raw_data)): + wav_path = os.path.join( + dataset_output, cfg.preprocess.raw_data, utt["Singer"], uid + ".wav" + ) + + with torch.no_grad(): + # Load audio data into tensor with sample rate of the config file + wav_torch, _ = audio.load_audio_torch(wav_path, cfg.preprocess.sample_rate) + wav = wav_torch.cpu().numpy() + + # extract features + if cfg.preprocess.extract_duration: + durations, phones, start, end = duration.get_duration( + utt, wav, cfg.preprocess + ) + save_feature(dataset_output, cfg.preprocess.duration_dir, uid, durations) + save_txt(dataset_output, cfg.preprocess.lab_dir, uid, phones) + wav = wav[start:end].astype(np.float32) + wav_torch = torch.from_numpy(wav).to(wav_torch.device) + + if cfg.preprocess.extract_linear_spec: + from utils.mel import extract_linear_features + + linear = extract_linear_features(wav_torch.unsqueeze(0), cfg.preprocess) + save_feature( + dataset_output, cfg.preprocess.linear_dir, uid, linear.cpu().numpy() + ) + + if cfg.preprocess.extract_mel: + from utils.mel import extract_mel_features + + if cfg.preprocess.mel_extract_mode == "taco": + _stft = TacotronSTFT( + sampling_rate=cfg.preprocess.sample_rate, + win_length=cfg.preprocess.win_size, 
+ hop_length=cfg.preprocess.hop_size, + filter_length=cfg.preprocess.n_fft, + n_mel_channels=cfg.preprocess.n_mel, + mel_fmin=cfg.preprocess.fmin, + mel_fmax=cfg.preprocess.fmax, + ) + mel = extract_mel_features( + wav_torch.unsqueeze(0), cfg.preprocess, taco=True, _stft=_stft + ) + if cfg.preprocess.extract_duration: + mel = mel[:, : sum(durations)] + else: + mel = extract_mel_features(wav_torch.unsqueeze(0), cfg.preprocess) + save_feature(dataset_output, cfg.preprocess.mel_dir, uid, mel.cpu().numpy()) + + if cfg.preprocess.extract_energy: + if ( + cfg.preprocess.energy_extract_mode == "from_mel" + and cfg.preprocess.extract_mel + ): + energy = (mel.exp() ** 2).sum(0).sqrt().cpu().numpy() + elif cfg.preprocess.energy_extract_mode == "from_waveform": + energy = audio.energy(wav, cfg.preprocess) + elif cfg.preprocess.energy_extract_mode == "from_tacotron_stft": + _stft = TacotronSTFT( + sampling_rate=cfg.preprocess.sample_rate, + win_length=cfg.preprocess.win_size, + hop_length=cfg.preprocess.hop_size, + filter_length=cfg.preprocess.n_fft, + n_mel_channels=cfg.preprocess.n_mel, + mel_fmin=cfg.preprocess.fmin, + mel_fmax=cfg.preprocess.fmax, + ) + _, energy = audio.get_energy_from_tacotron(wav, _stft) + else: + assert cfg.preprocess.energy_extract_mode in [ + "from_mel", + "from_waveform", + "from_tacotron_stft", + ], f"{cfg.preprocess.energy_extract_mode} not in supported energy_extract_mode [from_mel, from_waveform, from_tacotron_stft]" + if cfg.preprocess.extract_duration: + energy = energy[: sum(durations)] + phone_energy = avg_phone_feature(energy, durations) + save_feature( + dataset_output, cfg.preprocess.phone_energy_dir, uid, phone_energy + ) + + save_feature(dataset_output, cfg.preprocess.energy_dir, uid, energy) + + if cfg.preprocess.extract_pitch: + pitch = f0.get_f0(wav, cfg.preprocess) + if cfg.preprocess.extract_duration: + pitch = pitch[: sum(durations)] + phone_pitch = avg_phone_feature(pitch, durations, interpolation=True) + save_feature( + dataset_output, cfg.preprocess.phone_pitch_dir, uid, phone_pitch + ) + save_feature(dataset_output, cfg.preprocess.pitch_dir, uid, pitch) + + if cfg.preprocess.extract_uv: + assert isinstance(pitch, np.ndarray) + uv = pitch != 0 + save_feature(dataset_output, cfg.preprocess.uv_dir, uid, uv) + + if cfg.preprocess.extract_audio: + save_feature(dataset_output, cfg.preprocess.audio_dir, uid, wav) + + if cfg.preprocess.extract_label: + if cfg.preprocess.is_mu_law: + # compress audio + wav = compress(wav, cfg.preprocess.bits) + label = audio_to_label(wav, cfg.preprocess.bits) + save_feature(dataset_output, cfg.preprocess.label_dir, uid, label) + + if cfg.preprocess.extract_acoustic_token: + if cfg.preprocess.acoustic_token_extractor == "Encodec": + codes = extract_encodec_token(wav_path) + save_feature(dataset_output, cfg.preprocess.acoustic_token_dir, uid, codes) + + +def extract_utt_acoustic_features_tts(dataset_output, cfg, utt): + __extract_utt_acoustic_features(dataset_output, cfg, utt) + + +def extract_utt_acoustic_features_svc(dataset_output, cfg, utt): + __extract_utt_acoustic_features(dataset_output, cfg, utt) + + +def extract_utt_acoustic_features_tta(dataset_output, cfg, utt): + __extract_utt_acoustic_features(dataset_output, cfg, utt) + + +def extract_utt_acoustic_features_vocoder(dataset_output, cfg, utt): + """Extract acoustic features from utterances (in single process) + + Args: + dataset_output (str): directory to store acoustic features + cfg (dict): dictionary that stores configurations + utt (dict): utterance info including 
dataset, singer, uid:{singer}_{song}_{index}, + path to utternace, duration, utternace index + + """ + from utils import audio, f0, world, duration + + uid = utt["Uid"] + wav_path = utt["Path"] + + with torch.no_grad(): + # Load audio data into tensor with sample rate of the config file + wav_torch, _ = audio.load_audio_torch(wav_path, cfg.preprocess.sample_rate) + wav = wav_torch.cpu().numpy() + + # extract features + if cfg.preprocess.extract_mel: + from utils.mel import extract_mel_features + + mel = extract_mel_features(wav_torch.unsqueeze(0), cfg.preprocess) + save_feature(dataset_output, cfg.preprocess.mel_dir, uid, mel.cpu().numpy()) + + if cfg.preprocess.extract_energy: + if ( + cfg.preprocess.energy_extract_mode == "from_mel" + and cfg.preprocess.extract_mel + ): + energy = (mel.exp() ** 2).sum(0).sqrt().cpu().numpy() + elif cfg.preprocess.energy_extract_mode == "from_waveform": + energy = audio.energy(wav, cfg.preprocess) + else: + assert cfg.preprocess.energy_extract_mode in [ + "from_mel", + "from_waveform", + ], f"{cfg.preprocess.energy_extract_mode} not in supported energy_extract_mode [from_mel, from_waveform, from_tacotron_stft]" + + save_feature(dataset_output, cfg.preprocess.energy_dir, uid, energy) + + if cfg.preprocess.extract_pitch: + pitch = f0.get_f0(wav, cfg.preprocess) + save_feature(dataset_output, cfg.preprocess.pitch_dir, uid, pitch) + + if cfg.preprocess.extract_uv: + assert isinstance(pitch, np.ndarray) + uv = pitch != 0 + save_feature(dataset_output, cfg.preprocess.uv_dir, uid, uv) + + if cfg.preprocess.extract_audio: + save_feature(dataset_output, cfg.preprocess.audio_dir, uid, wav) + + if cfg.preprocess.extract_label: + if cfg.preprocess.is_mu_law: + # compress audio + wav = compress(wav, cfg.preprocess.bits) + label = audio_to_label(wav, cfg.preprocess.bits) + save_feature(dataset_output, cfg.preprocess.label_dir, uid, label) + + +def cal_normalized_mel(mel, dataset_name, cfg): + mel_min, mel_max = load_mel_extrema(cfg, dataset_name) + mel_norm = normalize_mel_channel(mel, mel_min, mel_max) + return mel_norm + + +def cal_mel_min_max(dataset, output_path, cfg, metadata=None): + dataset_output = os.path.join(output_path, dataset) + + if metadata is None: + metadata = [] + for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]: + dataset_file = os.path.join(dataset_output, "{}.json".format(dataset_type)) + with open(dataset_file, "r") as f: + metadata.extend(json.load(f)) + + tmp_mel_min = [] + tmp_mel_max = [] + for item in metadata: + mel_path = os.path.join( + dataset_output, cfg.preprocess.mel_dir, item["Uid"] + ".npy" + ) + if not os.path.exists(mel_path): + continue + mel = np.load(mel_path) + if mel.shape[0] != cfg.preprocess.n_mel: + mel = mel.T + # mel: (n_mels, T) + assert mel.shape[0] == cfg.preprocess.n_mel + + tmp_mel_min.append(np.min(mel, axis=-1)) + tmp_mel_max.append(np.max(mel, axis=-1)) + + mel_min = np.min(tmp_mel_min, axis=0) + mel_max = np.max(tmp_mel_max, axis=0) + + ## save mel min max data + mel_min_max_dir = os.path.join(dataset_output, cfg.preprocess.mel_min_max_stats_dir) + os.makedirs(mel_min_max_dir, exist_ok=True) + + mel_min_path = os.path.join(mel_min_max_dir, "mel_min.npy") + mel_max_path = os.path.join(mel_min_max_dir, "mel_max.npy") + np.save(mel_min_path, mel_min) + np.save(mel_max_path, mel_max) + + +def denorm_for_pred_mels(cfg, dataset_name, split, pred): + """ + Args: + pred: a list whose every element is (frame_len, n_mels) + Return: + similar like pred + """ + mel_min, mel_max = 
load_mel_extrema(cfg.preprocess, dataset_name) + recovered_mels = [ + denormalize_mel_channel(mel.T, mel_min, mel_max).T for mel in pred + ] + + return recovered_mels + + +def load_mel_extrema(cfg, dataset_name): + data_dir = os.path.join(cfg.processed_dir, dataset_name, cfg.mel_min_max_stats_dir) + + min_file = os.path.join(data_dir, "mel_min.npy") + max_file = os.path.join(data_dir, "mel_max.npy") + + mel_min = np.load(min_file) + mel_max = np.load(max_file) + + return mel_min, mel_max + + +def denormalize_mel_channel(mel, mel_min, mel_max): + mel_min = np.expand_dims(mel_min, -1) + mel_max = np.expand_dims(mel_max, -1) + return (mel + 1) / 2 * (mel_max - mel_min + ZERO) + mel_min + + +def normalize_mel_channel(mel, mel_min, mel_max): + mel_min = np.expand_dims(mel_min, -1) + mel_max = np.expand_dims(mel_max, -1) + return (mel - mel_min) / (mel_max - mel_min + ZERO) * 2 - 1 + + +def normalize(dataset, feat_dir, cfg): + dataset_output = os.path.join(cfg.preprocess.processed_dir, dataset) + print(f"normalize {feat_dir}") + + max_value = np.finfo(np.float64).min + min_value = np.finfo(np.float64).max + + scaler = StandardScaler() + feat_files = os.listdir(os.path.join(dataset_output, feat_dir)) + + for feat_file in tqdm(feat_files): + feat_file = os.path.join(dataset_output, feat_dir, feat_file) + if not feat_file.endswith(".npy"): + continue + feat = np.load(feat_file) + max_value = max(max_value, max(feat)) + min_value = min(min_value, min(feat)) + scaler.partial_fit(feat.reshape((-1, 1))) + mean = scaler.mean_[0] + std = scaler.scale_[0] + stat = np.array([min_value, max_value, mean, std]) + stat_npy = os.path.join(dataset_output, f"{feat_dir}_stat.npy") + np.save(stat_npy, stat) + return mean, std, min_value, max_value + + +def load_normalized(feat_dir, dataset_name, cfg): + dataset_output = os.path.join(cfg.preprocess.processed_dir, dataset_name) + stat_npy = os.path.join(dataset_output, f"{feat_dir}_stat.npy") + min_value, max_value, mean, std = np.load(stat_npy) + return mean, std, min_value, max_value + + +def cal_pitch_statistics_svc(dataset, output_path, cfg, metadata=None): + # path of dataset + dataset_dir = os.path.join(output_path, dataset) + save_dir = os.path.join(dataset_dir, cfg.preprocess.pitch_dir) + os.makedirs(save_dir, exist_ok=True) + if has_existed(os.path.join(save_dir, "statistics.json")): + return + + if metadata is None: + # load singers and ids + singers = json.load(open(os.path.join(dataset_dir, "singers.json"), "r")) + + # combine train and test metadata + metadata = [] + for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]: + dataset_file = os.path.join(dataset_dir, "{}.json".format(dataset_type)) + with open(dataset_file, "r") as f: + metadata.extend(json.load(f)) + else: + singers = list(set([item["Singer"] for item in metadata])) + singers = { + "{}_{}".format(dataset, name): idx for idx, name in enumerate(singers) + } + + # use different scalers for each singer + pitch_scalers = [[] for _ in range(len(singers))] + total_pitch_scalers = [[] for _ in range(len(singers))] + + for utt_info in tqdm(metadata, desc="Loading F0..."): + # utt = f'{utt_info["Dataset"]}_{utt_info["Uid"]}' + singer = utt_info["Singer"] + pitch_path = os.path.join( + dataset_dir, cfg.preprocess.pitch_dir, utt_info["Uid"] + ".npy" + ) + # total_pitch contains all pitch including unvoiced frames + if not os.path.exists(pitch_path): + continue + total_pitch = np.load(pitch_path) + assert len(total_pitch) > 0 + # pitch contains only voiced frames + pitch = 
total_pitch[total_pitch != 0] + spkid = singers[f"{replace_augment_name(dataset)}_{singer}"] + + # update pitch scalers + pitch_scalers[spkid].extend(pitch.tolist()) + # update total pitch scalers + total_pitch_scalers[spkid].extend(total_pitch.tolist()) + + # save pitch statistics for each singer in dict + sta_dict = {} + for singer in tqdm(singers, desc="Singers statistics"): + spkid = singers[singer] + # voiced pitch statistics + mean, std, min, max, median = ( + np.mean(pitch_scalers[spkid]), + np.std(pitch_scalers[spkid]), + np.min(pitch_scalers[spkid]), + np.max(pitch_scalers[spkid]), + np.median(pitch_scalers[spkid]), + ) + + # total pitch statistics + mean_t, std_t, min_t, max_t, median_t = ( + np.mean(total_pitch_scalers[spkid]), + np.std(total_pitch_scalers[spkid]), + np.min(total_pitch_scalers[spkid]), + np.max(total_pitch_scalers[spkid]), + np.median(total_pitch_scalers[spkid]), + ) + sta_dict[singer] = { + "voiced_positions": { + "mean": mean, + "std": std, + "median": median, + "min": min, + "max": max, + }, + "total_positions": { + "mean": mean_t, + "std": std_t, + "median": median_t, + "min": min_t, + "max": max_t, + }, + } + + # save statistics + with open(os.path.join(save_dir, "statistics.json"), "w") as f: + json.dump(sta_dict, f, indent=4, ensure_ascii=False) + + +def cal_pitch_statistics(dataset, output_path, cfg): + # path of dataset + dataset_dir = os.path.join(output_path, dataset) + if cfg.preprocess.use_phone_pitch: + pitch_dir = cfg.preprocess.phone_pitch_dir + else: + pitch_dir = cfg.preprocess.pitch_dir + save_dir = os.path.join(dataset_dir, pitch_dir) + + os.makedirs(save_dir, exist_ok=True) + if has_existed(os.path.join(save_dir, "statistics.json")): + return + # load singers and ids + singers = json.load(open(os.path.join(dataset_dir, "singers.json"), "r")) + + # combine train and test metadata + metadata = [] + for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]: + dataset_file = os.path.join(dataset_dir, "{}.json".format(dataset_type)) + with open(dataset_file, "r") as f: + metadata.extend(json.load(f)) + + # use different scalers for each singer + pitch_scalers = [[] for _ in range(len(singers))] + total_pitch_scalers = [[] for _ in range(len(singers))] + + for utt_info in metadata: + utt = f'{utt_info["Dataset"]}_{utt_info["Uid"]}' + singer = utt_info["Singer"] + pitch_path = os.path.join(dataset_dir, pitch_dir, utt_info["Uid"] + ".npy") + # total_pitch contains all pitch including unvoiced frames + if not os.path.exists(pitch_path): + continue + total_pitch = np.load(pitch_path) + assert len(total_pitch) > 0 + # pitch contains only voiced frames + # pitch = total_pitch[total_pitch != 0] + if cfg.preprocess.pitch_remove_outlier: + pitch = remove_outlier(total_pitch) + spkid = singers[f"{replace_augment_name(dataset)}_{singer}"] + + # update pitch scalers + pitch_scalers[spkid].extend(pitch.tolist()) + # update total pitch scalers + total_pitch_scalers[spkid].extend(total_pitch.tolist()) + + # save pitch statistics for each singer in dict + sta_dict = {} + for singer in singers: + spkid = singers[singer] + # voiced pitch statistics + mean, std, min, max, median = ( + np.mean(pitch_scalers[spkid]), + np.std(pitch_scalers[spkid]), + np.min(pitch_scalers[spkid]), + np.max(pitch_scalers[spkid]), + np.median(pitch_scalers[spkid]), + ) + + # total pitch statistics + mean_t, std_t, min_t, max_t, median_t = ( + np.mean(total_pitch_scalers[spkid]), + np.std(total_pitch_scalers[spkid]), + np.min(total_pitch_scalers[spkid]), + 
np.max(total_pitch_scalers[spkid]), + np.median(total_pitch_scalers[spkid]), + ) + sta_dict[singer] = { + "voiced_positions": { + "mean": mean, + "std": std, + "median": median, + "min": min, + "max": max, + }, + "total_positions": { + "mean": mean_t, + "std": std_t, + "median": median_t, + "min": min_t, + "max": max_t, + }, + } + + # save statistics + with open(os.path.join(save_dir, "statistics.json"), "w") as f: + json.dump(sta_dict, f, indent=4, ensure_ascii=False) + + +def cal_energy_statistics(dataset, output_path, cfg): + # path of dataset + dataset_dir = os.path.join(output_path, dataset) + if cfg.preprocess.use_phone_energy: + energy_dir = cfg.preprocess.phone_energy_dir + else: + energy_dir = cfg.preprocess.energy_dir + save_dir = os.path.join(dataset_dir, energy_dir) + os.makedirs(save_dir, exist_ok=True) + print(os.path.join(save_dir, "statistics.json")) + if has_existed(os.path.join(save_dir, "statistics.json")): + return + # load singers and ids + singers = json.load(open(os.path.join(dataset_dir, "singers.json"), "r")) + + # combine train and test metadata + metadata = [] + for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]: + dataset_file = os.path.join(dataset_dir, "{}.json".format(dataset_type)) + with open(dataset_file, "r") as f: + metadata.extend(json.load(f)) + + # use different scalers for each singer + energy_scalers = [[] for _ in range(len(singers))] + total_energy_scalers = [[] for _ in range(len(singers))] + + for utt_info in metadata: + utt = f'{utt_info["Dataset"]}_{utt_info["Uid"]}' + singer = utt_info["Singer"] + energy_path = os.path.join(dataset_dir, energy_dir, utt_info["Uid"] + ".npy") + # total_energy contains all energy including unvoiced frames + if not os.path.exists(energy_path): + continue + total_energy = np.load(energy_path) + assert len(total_energy) > 0 + # energy contains only voiced frames + # energy = total_energy[total_energy != 0] + if cfg.preprocess.energy_remove_outlier: + energy = remove_outlier(total_energy) + spkid = singers[f"{replace_augment_name(dataset)}_{singer}"] + + # update energy scalers + energy_scalers[spkid].extend(energy.tolist()) + # update total energyscalers + total_energy_scalers[spkid].extend(total_energy.tolist()) + + # save energy statistics for each singer in dict + sta_dict = {} + for singer in singers: + spkid = singers[singer] + # voiced energy statistics + mean, std, min, max, median = ( + np.mean(energy_scalers[spkid]), + np.std(energy_scalers[spkid]), + np.min(energy_scalers[spkid]), + np.max(energy_scalers[spkid]), + np.median(energy_scalers[spkid]), + ) + + # total energy statistics + mean_t, std_t, min_t, max_t, median_t = ( + np.mean(total_energy_scalers[spkid]), + np.std(total_energy_scalers[spkid]), + np.min(total_energy_scalers[spkid]), + np.max(total_energy_scalers[spkid]), + np.median(total_energy_scalers[spkid]), + ) + sta_dict[singer] = { + "voiced_positions": { + "mean": mean, + "std": std, + "median": median, + "min": min, + "max": max, + }, + "total_positions": { + "mean": mean_t, + "std": std_t, + "median": median_t, + "min": min_t, + "max": max_t, + }, + } + + # save statistics + with open(os.path.join(save_dir, "statistics.json"), "w") as f: + json.dump(sta_dict, f, indent=4, ensure_ascii=False) + + +def copy_acoustic_features(metadata, dataset_dir, src_dataset_dir, cfg): + """Copy acoustic features from src_dataset_dir to dataset_dir + + Args: + metadata (dict): dictionary that stores data in train.json and test.json files + dataset_dir (str): directory to store 
acoustic features + src_dataset_dir (str): directory to store acoustic features + cfg (dict): dictionary that stores configurations + + """ + + if cfg.preprocess.extract_mel: + if not has_existed(os.path.join(dataset_dir, cfg.preprocess.mel_dir)): + os.makedirs( + os.path.join(dataset_dir, cfg.preprocess.mel_dir), exist_ok=True + ) + print( + "Copying mel features from {} to {}...".format( + src_dataset_dir, dataset_dir + ) + ) + for utt_info in tqdm(metadata): + src_mel_path = os.path.join( + src_dataset_dir, cfg.preprocess.mel_dir, utt_info["Uid"] + ".npy" + ) + dst_mel_path = os.path.join( + dataset_dir, cfg.preprocess.mel_dir, utt_info["Uid"] + ".npy" + ) + # create soft-links + if not os.path.exists(dst_mel_path): + os.symlink(src_mel_path, dst_mel_path) + if cfg.preprocess.extract_energy: + if not has_existed(os.path.join(dataset_dir, cfg.preprocess.energy_dir)): + os.makedirs( + os.path.join(dataset_dir, cfg.preprocess.energy_dir), exist_ok=True + ) + print( + "Copying energy features from {} to {}...".format( + src_dataset_dir, dataset_dir + ) + ) + for utt_info in tqdm(metadata): + src_energy_path = os.path.join( + src_dataset_dir, cfg.preprocess.energy_dir, utt_info["Uid"] + ".npy" + ) + dst_energy_path = os.path.join( + dataset_dir, cfg.preprocess.energy_dir, utt_info["Uid"] + ".npy" + ) + # create soft-links + if not os.path.exists(dst_energy_path): + os.symlink(src_energy_path, dst_energy_path) + if cfg.preprocess.extract_pitch: + if not has_existed(os.path.join(dataset_dir, cfg.preprocess.pitch_dir)): + os.makedirs( + os.path.join(dataset_dir, cfg.preprocess.pitch_dir), exist_ok=True + ) + print( + "Copying pitch features from {} to {}...".format( + src_dataset_dir, dataset_dir + ) + ) + for utt_info in tqdm(metadata): + src_pitch_path = os.path.join( + src_dataset_dir, cfg.preprocess.pitch_dir, utt_info["Uid"] + ".npy" + ) + dst_pitch_path = os.path.join( + dataset_dir, cfg.preprocess.pitch_dir, utt_info["Uid"] + ".npy" + ) + # create soft-links + if not os.path.exists(dst_pitch_path): + os.symlink(src_pitch_path, dst_pitch_path) + if cfg.preprocess.extract_uv: + if not has_existed(os.path.join(dataset_dir, cfg.preprocess.uv_dir)): + os.makedirs( + os.path.join(dataset_dir, cfg.preprocess.uv_dir), exist_ok=True + ) + print( + "Copying uv features from {} to {}...".format( + src_dataset_dir, dataset_dir + ) + ) + for utt_info in tqdm(metadata): + src_uv_path = os.path.join( + src_dataset_dir, cfg.preprocess.uv_dir, utt_info["Uid"] + ".npy" + ) + dst_uv_path = os.path.join( + dataset_dir, cfg.preprocess.uv_dir, utt_info["Uid"] + ".npy" + ) + # create soft-links + if not os.path.exists(dst_uv_path): + os.symlink(src_uv_path, dst_uv_path) + if cfg.preprocess.extract_audio: + if not has_existed(os.path.join(dataset_dir, cfg.preprocess.audio_dir)): + os.makedirs( + os.path.join(dataset_dir, cfg.preprocess.audio_dir), exist_ok=True + ) + print( + "Copying audio features from {} to {}...".format( + src_dataset_dir, dataset_dir + ) + ) + for utt_info in tqdm(metadata): + src_audio_path = os.path.join( + src_dataset_dir, cfg.preprocess.audio_dir, utt_info["Uid"] + ".npy" + ) + dst_audio_path = os.path.join( + dataset_dir, cfg.preprocess.audio_dir, utt_info["Uid"] + ".npy" + ) + # create soft-links + if not os.path.exists(dst_audio_path): + os.symlink(src_audio_path, dst_audio_path) + if cfg.preprocess.extract_label: + if not has_existed(os.path.join(dataset_dir, cfg.preprocess.label_dir)): + os.makedirs( + os.path.join(dataset_dir, cfg.preprocess.label_dir), exist_ok=True + ) + print( 
+ "Copying label features from {} to {}...".format( + src_dataset_dir, dataset_dir + ) + ) + for utt_info in tqdm(metadata): + src_label_path = os.path.join( + src_dataset_dir, cfg.preprocess.label_dir, utt_info["Uid"] + ".npy" + ) + dst_label_path = os.path.join( + dataset_dir, cfg.preprocess.label_dir, utt_info["Uid"] + ".npy" + ) + # create soft-links + if not os.path.exists(dst_label_path): + os.symlink(src_label_path, dst_label_path) + + +def align_duration_mel(dataset, output_path, cfg): + print("align the duration and mel") + + dataset_dir = os.path.join(output_path, dataset) + metadata = [] + for dataset_type in ["train", "test"] if "eval" not in dataset else ["test"]: + dataset_file = os.path.join(dataset_dir, "{}.json".format(dataset_type)) + with open(dataset_file, "r") as f: + metadata.extend(json.load(f)) + + utt2dur = {} + for index in tqdm(range(len(metadata))): + utt_info = metadata[index] + dataset = utt_info["Dataset"] + uid = utt_info["Uid"] + utt = "{}_{}".format(dataset, uid) + + mel_path = os.path.join(dataset_dir, cfg.preprocess.mel_dir, uid + ".npy") + mel = np.load(mel_path).transpose(1, 0) + duration_path = os.path.join( + dataset_dir, cfg.preprocess.duration_dir, uid + ".npy" + ) + duration = np.load(duration_path) + if sum(duration) != mel.shape[0]: + duration_sum = sum(duration) + mel_len = mel.shape[0] + mismatch = abs(duration_sum - mel_len) + assert mismatch <= 5, "duration and mel length mismatch!" + cloned = np.array(duration, copy=True) + if duration_sum > mel_len: + for j in range(1, len(duration) - 1): + if mismatch == 0: + break + dur_val = cloned[-j] + if dur_val >= mismatch: + cloned[-j] -= mismatch + mismatch -= dur_val + break + else: + cloned[-j] = 0 + mismatch -= dur_val + + elif duration_sum < mel_len: + cloned[-1] += mismatch + duration = cloned + utt2dur[utt] = duration + np.save(duration_path, duration) + + return utt2dur diff --git a/processors/content_extractor.py b/processors/content_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..c034b6bf6aac19eb3ab912661110c8066aa3119b --- /dev/null +++ b/processors/content_extractor.py @@ -0,0 +1,540 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import torch +import numpy as np +import yaml +import copy +from tqdm import tqdm +from torchaudio.compliance import kaldi +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import DataLoader +from fairseq import checkpoint_utils +from transformers import AutoModel, Wav2Vec2FeatureExtractor + +from utils.io_optim import ( + TorchaudioDataset, + LibrosaDataset, + FFmpegDataset, + collate_batch, +) +from modules import whisper_extractor as whisper +from modules.wenet_extractor.utils.init_model import init_model +from modules.wenet_extractor.utils.checkpoint import load_checkpoint + +""" + Extractor for content features + 1. whisper + 2. contentvec + 3. wenet + 4. mert + + Pipeline: + in preprocess.py: + call extract_utt_content_features() to extract content features for each utterance + extract_utt_content_features() envelopes the following steps: + 1. load the model (whisper, contentvec, wenet) + 2. extract the content features + 3. save the content features into files + in svc_dataset.py: + call offline_align() to align the content features to the given target length + +""" + +""" + Extractor Usage: + 1. initialize an instance of extractor + extractor = WhisperExtractor(cfg) + 2. 
load the specified model + extractor.load_model() + 3. extract the content features + extractor.extract_content(utt) for single utterance + extractor.extract_content_batch(utts) for batch utterances + 4. save the content features + extractor.save_feature(utt, content_feature) for single utterance +""" + + +class BaseExtractor: + def __init__(self, cfg): + self.cfg = cfg + self.extractor_type = None + self.model = None + + def offline_align(self, content, target_len): + """ + args: + content: (source_len, dim) + target_len: target length + return: + mapped_feature: (target_len, dim) + """ + target_hop = self.cfg.preprocess.hop_size + + assert self.extractor_type in ["whisper", "contentvec", "wenet"] + if self.extractor_type == "whisper": + source_hop = ( + self.cfg.preprocess.whisper_frameshift + * self.cfg.preprocess.whisper_downsample_rate + * self.cfg.preprocess.sample_rate + ) + elif self.extractor_type == "contentvec": + source_hop = ( + self.cfg.preprocess.contentvec_frameshift + * self.cfg.preprocess.sample_rate + ) + elif self.extractor_type == "wenet": + source_hop = ( + self.cfg.preprocess.wenet_frameshift + * self.cfg.preprocess.wenet_downsample_rate + * self.cfg.preprocess.sample_rate + ) + source_hop = int(source_hop) + factor = np.gcd(source_hop, target_hop) + source_hop //= factor + target_hop //= factor + + # (source_len, 256) + _, width = content.shape + # slice the content from padded feature + source_len = min(target_len * target_hop // source_hop + 1, len(content)) + + # const ~= target_len * target_hop + const = source_len * source_hop // target_hop * target_hop + + # (source_len * source_hop, dim) + up_sampling_feats = np.repeat(content, source_hop, axis=0) + # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) + down_sampling_feats = np.average( + up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 + ) + + err = abs(target_len - len(down_sampling_feats)) + if err > 8: + # err_log_dir is indeterminate + err_log_dir = os.path.join( + self.cfg.preprocess.processed_dir, "align_max_err.log" + ) + try: + with open(err_log_dir, "r") as f: + err_num = int(f.read()) + except: + with open(err_log_dir, "w") as f: + f.write("0") + err_num = 0 + if err > err_num: + with open(err_log_dir, "w") as f: + f.write(str(err)) + + if len(down_sampling_feats) < target_len: + # (1, dim) -> (err, dim) + end = down_sampling_feats[-1][None, :].repeat(err, axis=0) + down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0) + + # (target_len, dim) + mapped_feature = down_sampling_feats[:target_len] + + return mapped_feature + + def save_feature(self, utt, content_feature): + """Save a single utternace to path {cfg.preprocess.processed_dir} + + Args: + utt (dict): one item in metadata, containing information for one utterance + content_feature (tensor): content feature of one utterance + """ + uid = utt["Uid"] + assert self.extractor_type != None + out_dir = os.path.join( + self.cfg.preprocess.processed_dir, utt["Dataset"], self.extractor_type + ) + os.makedirs(out_dir, exist_ok=True) + save_path = os.path.join(out_dir, uid + ".npy") + # only keep effective parts + duration = utt["Duration"] + if self.extractor_type == "whisper": + frameshift = ( + self.cfg.preprocess.whisper_frameshift + * self.cfg.preprocess.whisper_downsample_rate + ) # 20ms + elif self.extractor_type == "contentvec": + frameshift = self.cfg.preprocess.contentvec_frameshift # 20ms + elif self.extractor_type == "wenet": + frameshift = ( + 
self.cfg.preprocess.wenet_frameshift + * self.cfg.preprocess.wenet_downsample_rate + ) # 40ms + elif self.extractor_type == "mert": + frameshift = self.cfg.preprocess.mert_frameshift + else: + raise NotImplementedError + # calculate the number of valid frames + num_frames = int(np.ceil((duration - frameshift) / frameshift)) + 1 + # (num_frames, dim) -> (valid_frames, dim) + assert ( + len(content_feature.shape) == 2 + ), "content feature shape error, it should be (num_frames, dim)" + content_feature = content_feature[:num_frames, :] + np.save(save_path, content_feature.cpu().detach().numpy()) + + +class WhisperExtractor(BaseExtractor): + def __init__(self, config): + super(WhisperExtractor, self).__init__(config) + self.extractor_type = "whisper" + + def load_model(self): + # load whisper checkpoint + print("Loading Whisper Model...") + + checkpoint_file = ( + self.cfg.preprocess.whisper_model_path + if "whisper_model_path" in self.cfg.preprocess + else None + ) + model = whisper.load_model( + self.cfg.preprocess.whisper_model, checkpoint_file=checkpoint_file + ) + if torch.cuda.is_available(): + print("Using GPU...\n") + model = model.cuda() + else: + print("Using CPU...\n") + + self.model = model.eval() + + def extract_content_features(self, wavs, lens): + """extract content features from a batch of dataloader + Args: + wavs: tensor (batch_size, T) + lens: list + """ + # wavs: (batch, max_len) + wavs = whisper.pad_or_trim(wavs) + # batch_mel: (batch, 80, 3000) + batch_mel = whisper.log_mel_spectrogram(wavs).to(self.model.device) + with torch.no_grad(): + # (batch, 1500, 1024) + features = self.model.embed_audio(batch_mel) + return features + + +class ContentvecExtractor(BaseExtractor): + def __init__(self, cfg): + super(ContentvecExtractor, self).__init__(cfg) + self.extractor_type = "contentvec" + + def load_model(self): + assert self.model == None + # Load model + ckpt_path = self.cfg.preprocess.contentvec_file + print("Load Contentvec Model...") + + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [ckpt_path], + suffix="", + ) + model = models[0] + model.eval() + + if torch.cuda.is_available(): + # print("Using GPU...\n") + model = model.cuda() + + self.model = model + + def extract_content_features(self, wavs, lens): + """extract content features from a batch of dataloader + Args: + wavs: tensor (batch, T) + lens: list + """ + device = next(self.model.parameters()).device + wavs = wavs.to(device) # (batch, max_len) + padding_mask = torch.eq(wavs, torch.zeros_like(wavs)).to(device) + with torch.no_grad(): + logits = self.model.extract_features( + source=wavs, padding_mask=padding_mask, output_layer=12 + ) + # feats: (batch, T, 256) + feats = self.model.final_proj(logits[0]) + return feats + + +class WenetExtractor(BaseExtractor): + def __init__(self, config): + super(WenetExtractor, self).__init__(config) + self.extractor_type = "wenet" + + def load_model(self): + wenet_cfg = self.cfg.preprocess.wenet_config + wenet_model_path = self.cfg.preprocess.wenet_model_path + # load Wenet config + with open(wenet_cfg, "r") as w: + wenet_configs = yaml.load(w, Loader=yaml.FullLoader) + self.extract_conf = copy.deepcopy(wenet_configs["dataset_conf"]) + print("Loading Wenet Model...") + self.model = init_model(wenet_configs) + load_checkpoint(self.model, wenet_model_path) + + if torch.cuda.is_available(): + print("Using GPU...\n") + self.model = self.model.cuda() + else: + print("Using CPU...\n") + + self.model = self.model.eval() + + def extract_content_features(self, 
wavs, lens): + """extract content features from a batch of dataloader + Args: + wavs: tensor + lens: list + """ + feats_list = [] + lengths_list = [] + + device = next(self.model.parameters()).device + # Extract fbank/mfcc features by kaldi + assert self.extract_conf is not None, "load model first!" + feats_type = self.extract_conf.get("feats_type", "fbank") + assert feats_type in ["fbank", "mfcc"] + + for idx, wav in enumerate(wavs): + # wav: (T) + wav = wav[: lens[idx]].to(device) + + # pad one frame to compensate for the frame cut off after feature extraction + pad_tensor = torch.zeros(160, device=wav.device) + wav = torch.cat((wav, pad_tensor), dim=-1) + wav *= 1 << 15 + + wav = wav.unsqueeze(0) # (T) -> (1, T) + if feats_type == "fbank": + fbank_conf = self.extract_conf.get("fbank_conf", {}) + feat = kaldi.fbank( + wav, + sample_frequency=16000, + num_mel_bins=fbank_conf["num_mel_bins"], + frame_length=fbank_conf["frame_length"], + frame_shift=fbank_conf["frame_shift"], + dither=fbank_conf["dither"], + ) + elif feats_type == "mfcc": + mfcc_conf = self.extract_conf.get("mfcc", {}) + feat = kaldi.mfcc( + wav, + sample_frequency=16000, + num_mel_bins=mfcc_conf["num_mel_bins"], + frame_length=mfcc_conf["frame_length"], + frame_shift=mfcc_conf["frame_shift"], + dither=mfcc_conf["dither"], + num_ceps=mfcc_conf.get("num_ceps", 40), + high_freq=mfcc_conf.get("high_freq", 0.0), + low_freq=mfcc_conf.get("low_freq", 20.0), + ) + feats_list.append(feat) + lengths_list.append(feat.shape[0]) + + feats_lengths = torch.tensor(lengths_list, dtype=torch.int32).to(device) + feats_tensor = pad_sequence(feats_list, batch_first=True).to( + device + ) # (batch, len, 80) + + features = self.model.encoder_extractor( + feats_tensor, + feats_lengths, + decoding_chunk_size=-1, + num_decoding_left_chunks=-1, + simulate_streaming=False, + ) + return features + + +class MertExtractor(BaseExtractor): + def __init__(self, cfg): + super(MertExtractor, self).__init__(cfg) + self.extractor_type = "mert" + self.preprocessor = None + + def load_model(self): + assert self.model == None + assert self.preprocessor == None + + print("Loading MERT Model: ...", self.cfg.preprocess.mert_model) + + local_mert_path = "/mnt/workspace/fangzihao/acce/Amphion/pretrained/MERT" + + model_name = self.cfg.preprocess.mert_model + model = AutoModel.from_pretrained(local_mert_path, trust_remote_code=True) + + if torch.cuda.is_available(): + model = model.cuda() + preprocessor = Wav2Vec2FeatureExtractor.from_pretrained( + local_mert_path, trust_remote_code=True + ) + + self.model = model + self.preprocessor = preprocessor + + def extract_content_features(self, wavs, lens): + """extract content features from a batch of dataloader + Args: + wavs: tensor (batch, T) + lens: list + """ + with torch.no_grad(): + sample_rate = self.preprocessor.sampling_rate + device = next(self.model.parameters()).device + assert ( + sample_rate == self.cfg.preprocess.mert_sample_rate + ), "mert sample rate mismatch, expected {}, got {}".format( + self.cfg.preprocess.mert_sample_rate, sample_rate + ) + mert_features = [] + # wav: (len) + for wav in wavs: + # {input_values: tensor, attention_mask: tensor} + inputs = self.preprocessor( + wavs, sampling_rate=sample_rate, return_tensors="pt" + ).to(device) + + outputs = self.model(**inputs, output_hidden_states=True) + # (25 layers, time steps, 1024 feature_dim) + all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze() + # (1, frame_len, 1024) -> (frame_len, 1024) + feature = outputs.hidden_states[ + 
self.cfg.preprocess.mert_feature_layer + ].squeeze(0) + mert_features.append(feature) + + return mert_features + + +def extract_utt_content_features_dataloader(cfg, metadata, num_workers): + dataset_name = metadata[0]["Dataset"] + + if cfg.preprocess.extract_whisper_feature: + feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, "whisper") + os.makedirs(feat_dir, exist_ok=True) + feat_files_num = len(os.listdir(feat_dir)) + + if feat_files_num != len(metadata): + whisper_waveforms = FFmpegDataset( + cfg, dataset_name, cfg.preprocess.whisper_sample_rate, metadata=metadata + ) + data_loader = DataLoader( + whisper_waveforms, + num_workers=num_workers, + shuffle=False, + pin_memory=cfg.preprocess.pin_memory, + batch_size=cfg.preprocess.content_feature_batch_size, + collate_fn=collate_batch, + drop_last=False, + ) + extractor = WhisperExtractor(cfg) + extractor.load_model() + for batch_idx, items in enumerate(tqdm(data_loader)): + _metadata, wavs, lens = items + + batch_content_features = extractor.extract_content_features( + wavs, + lens, + ) + for index, utt in enumerate(_metadata): + extractor.save_feature(utt, batch_content_features[index]) + + if cfg.preprocess.extract_contentvec_feature: + feat_dir = os.path.join( + cfg.preprocess.processed_dir, dataset_name, "contentvec" + ) + os.makedirs(feat_dir, exist_ok=True) + feat_files_num = len(os.listdir(feat_dir)) + + if feat_files_num != len(metadata): + contentvec_waveforms = LibrosaDataset( + cfg, + dataset_name, + cfg.preprocess.contentvec_sample_rate, + metadata=metadata, + ) + data_loader = DataLoader( + contentvec_waveforms, + num_workers=num_workers, + shuffle=False, + pin_memory=cfg.preprocess.pin_memory, + batch_size=cfg.preprocess.content_feature_batch_size, + collate_fn=collate_batch, + drop_last=False, + ) + extractor = ContentvecExtractor(cfg) + extractor.load_model() + for batch_idx, items in enumerate(tqdm(data_loader)): + _metadata, wavs, lens = items + + batch_content_features = extractor.extract_content_features(wavs, lens) + for index, utt in enumerate(_metadata): + extractor.save_feature(utt, batch_content_features[index]) + + if cfg.preprocess.extract_wenet_feature: + feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, "wenet") + os.makedirs(feat_dir, exist_ok=True) + feat_files_num = len(os.listdir(feat_dir)) + + if feat_files_num != len(metadata): + wenet_waveforms = TorchaudioDataset( + cfg, dataset_name, cfg.preprocess.wenet_sample_rate, metadata=metadata + ) + data_loader = DataLoader( + wenet_waveforms, + num_workers=num_workers, + shuffle=False, + pin_memory=cfg.preprocess.pin_memory, + batch_size=cfg.preprocess.content_feature_batch_size, + collate_fn=collate_batch, + drop_last=False, + ) + extractor = WenetExtractor(cfg) + extractor.load_model() + for batch_idx, items in enumerate(tqdm(data_loader)): + _metadata, wavs, lens = items + + batch_content_features = extractor.extract_content_features( + wavs, + lens, + ) + for index, utt in enumerate(_metadata): + extractor.save_feature(utt, batch_content_features[index]) + + if cfg.preprocess.extract_mert_feature: + feat_dir = os.path.join(cfg.preprocess.processed_dir, dataset_name, "mert") + os.makedirs(feat_dir, exist_ok=True) + feat_files_num = len(os.listdir(feat_dir)) + + if feat_files_num != len(metadata): + mert_waveforms = TorchaudioDataset( + cfg, dataset_name, cfg.preprocess.mert_sample_rate, metadata=metadata + ) + data_loader = DataLoader( + mert_waveforms, + num_workers=num_workers, + shuffle=False, + 
pin_memory=cfg.preprocess.pin_memory, + batch_size=cfg.preprocess.content_feature_batch_size, + collate_fn=collate_batch, + drop_last=False, + ) + extractor = MertExtractor(cfg) + extractor.load_model() + for batch_idx, items in enumerate(tqdm(data_loader)): + _metadata, wavs, lens = items + + batch_content_features = extractor.extract_content_features( + wavs, + lens, + ) + for index, utt in enumerate(_metadata): + extractor.save_feature(utt, batch_content_features[index]) diff --git a/processors/data_augment.py b/processors/data_augment.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc183361d4bcfd454693ee0b7ffdd9758c09312 --- /dev/null +++ b/processors/data_augment.py @@ -0,0 +1,378 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +import random +import os +import json + +import numpy as np +import parselmouth +import torch +import torchaudio +from tqdm import tqdm + +from audiomentations import TimeStretch + +from pedalboard import ( + Pedalboard, + HighShelfFilter, + LowShelfFilter, + PeakFilter, + PitchShift, +) + +from utils.util import has_existed + +PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT = 0.0 +PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT = 1.0 +PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT = 1.0 +PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT = 1.0 +PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT = 1.0 + + +def wav_to_Sound(wav, sr: int) -> parselmouth.Sound: + """Convert a waveform to a parselmouth.Sound object + + Args: + wav (np.ndarray/torch.Tensor): waveform of shape (n_channels, n_samples) + sr (int, optional): sampling rate. + + Returns: + parselmouth.Sound: a parselmouth.Sound object + """ + assert wav.shape == (1, len(wav[0])), "wav must be of shape (1, n_samples)" + sound = None + if isinstance(wav, np.ndarray): + sound = parselmouth.Sound(wav[0], sampling_frequency=sr) + elif isinstance(wav, torch.Tensor): + sound = parselmouth.Sound(wav[0].numpy(), sampling_frequency=sr) + assert sound is not None, "wav must be either np.ndarray or torch.Tensor" + return sound + + +def get_pitch_median(wav, sr: int): + """Get the median pitch of a waveform + + Args: + wav (np.ndarray/torch.Tensor): waveform of shape (n_channels, n_samples) + sr (int, optional): sampling rate. 
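Example (a minimal sketch mirroring the Praat calls used here; "sample.wav" is a hypothetical mono file and parselmouth/torchaudio are assumed to be installed):

import parselmouth
import torchaudio

wav, sr = torchaudio.load("sample.wav")  # (1, T) mono waveform
sound = parselmouth.Sound(wav[0].numpy(), sampling_frequency=sr)
# "To Pitch": time step (s), pitch floor (Hz), pitch ceiling (Hz)
pitch = parselmouth.praat.call(sound, "To Pitch", 0.8 / 75, 75, 600)
# the 0.5 quantile over the whole file is the median F0 in Hertz
median_f0 = parselmouth.praat.call(pitch, "Get quantile", 0.0, 0.0, 0.5, "Hertz")
print("Median F0: {:.1f} Hz".format(median_f0))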
+ + Returns: + parselmouth.Pitch, float: a parselmouth.Pitch object and the median pitch + """ + if not isinstance(wav, parselmouth.Sound): + sound = wav_to_Sound(wav, sr) + else: + sound = wav + pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT + + # To Pitch: Time step(s)(standard value: 0.0), Pitch floor (Hz)(standard value: 75), Pitch ceiling (Hz)(standard value: 600.0) + pitch = parselmouth.praat.call(sound, "To Pitch", 0.8 / 75, 75, 600) + # Get quantile: From time (s), To time (s), Quantile(0.5 is then the 50% quantile, i.e., the median), Units (Hertz or Bark) + pitch_median = parselmouth.praat.call(pitch, "Get quantile", 0.0, 0.0, 0.5, "Hertz") + + return pitch, pitch_median + + +def change_gender( + sound, + pitch=None, + formant_shift_ratio: float = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT, + new_pitch_median: float = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT, + pitch_range_ratio: float = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT, + duration_factor: float = PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT, +) -> parselmouth.Sound: + """Invoke change gender function in praat + + Args: + sound (parselmouth.Sound): a parselmouth.Sound object + pitch (parselmouth.Pitch, optional): a parselmouth.Pitch object. Defaults to None. + formant_shift_ratio (float, optional): formant shift ratio. A value of 1.0 means no change. Greater than 1.0 means higher pitch. Less than 1.0 means lower pitch. + new_pitch_median (float, optional): new pitch median. + pitch_range_ratio (float, optional): pitch range ratio. A value of 1.0 means no change. Greater than 1.0 means higher pitch range. Less than 1.0 means lower pitch range. + duration_factor (float, optional): duration factor. A value of 1.0 means no change. Greater than 1.0 means longer duration. Less than 1.0 means shorter duration. 
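Example (a hedged sketch relying only on the signature documented here; `sound` is assumed to be a parselmouth.Sound built with wav_to_Sound): raise the formants by about 10% while leaving pitch and duration untouched. In Praat's "Change gender" command a new pitch median of 0.0 means the original median is kept.

shifted = change_gender(
    sound,
    pitch=None,
    formant_shift_ratio=1.1,  # > 1.0 shifts formants upward
    new_pitch_median=0.0,     # 0.0 keeps the original pitch median
    pitch_range_ratio=1.0,    # keep the pitch range
    duration_factor=1.0,      # keep the duration
)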
+ + Returns: + parselmouth.Sound: a parselmouth.Sound object + """ + if pitch is None: + new_sound = parselmouth.praat.call( + sound, + "Change gender", + 75, + 600, + formant_shift_ratio, + new_pitch_median, + pitch_range_ratio, + duration_factor, + ) + else: + new_sound = parselmouth.praat.call( + (sound, pitch), + "Change gender", + formant_shift_ratio, + new_pitch_median, + pitch_range_ratio, + duration_factor, + ) + return new_sound + + +def apply_formant_and_pitch_shift( + sound: parselmouth.Sound, + formant_shift_ratio: float = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT, + pitch_shift_ratio: float = PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT, + pitch_range_ratio: float = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT, + duration_factor: float = PRAAT_CHANGEGENDER_DURATIONFACTOR_DEFAULT, +) -> parselmouth.Sound: + """use Praat "Changer gender" command to manipulate pitch and formant + "Change gender": Praat -> Sound Object -> Convert -> Change gender + refer to Help of Praat for more details + # https://github.com/YannickJadoul/Parselmouth/issues/25#issuecomment-608632887 might help + """ + pitch = None + new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT + if pitch_shift_ratio != 1.0: + pitch, pitch_median = get_pitch_median(sound, sound.sampling_frequency) + new_pitch_median = pitch_median * pitch_shift_ratio + + # refer to https://github.com/praat/praat/issues/1926#issuecomment-974909408 + pitch_minimum = parselmouth.praat.call( + pitch, "Get minimum", 0.0, 0.0, "Hertz", "Parabolic" + ) + new_median = pitch_median * pitch_shift_ratio + scaled_minimum = pitch_minimum * pitch_shift_ratio + result_minimum = new_median + (scaled_minimum - new_median) * pitch_range_ratio + if result_minimum < 0: + new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT + pitch_range_ratio = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT + + if math.isnan(new_pitch_median): + new_pitch_median = PRAAT_CHANGEGENDER_PITCHMEDIAN_DEFAULT + pitch_range_ratio = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT + + new_sound = change_gender( + sound, + pitch, + formant_shift_ratio, + new_pitch_median, + pitch_range_ratio, + duration_factor, + ) + return new_sound + + +# Function used in EQ +def pedalboard_equalizer(wav: np.ndarray, sr: int) -> np.ndarray: + """Use pedalboard to do equalizer""" + board = Pedalboard() + + cutoff_low_freq = 60 + cutoff_high_freq = 10000 + + q_min = 2 + q_max = 5 + + random_all_freq = True + num_filters = 10 + if random_all_freq: + key_freqs = [random.uniform(1, 12000) for _ in range(num_filters)] + else: + key_freqs = [ + power_ratio(float(z) / (num_filters - 1), cutoff_low_freq, cutoff_high_freq) + for z in range(num_filters) + ] + q_values = [ + power_ratio(random.uniform(0, 1), q_min, q_max) for _ in range(num_filters) + ] + gains = [random.uniform(-12, 12) for _ in range(num_filters)] + # low-shelving filter + board.append( + LowShelfFilter( + cutoff_frequency_hz=key_freqs[0], gain_db=gains[0], q=q_values[0] + ) + ) + # peaking filters + for i in range(1, 9): + board.append( + PeakFilter( + cutoff_frequency_hz=key_freqs[i], gain_db=gains[i], q=q_values[i] + ) + ) + # high-shelving filter + board.append( + HighShelfFilter( + cutoff_frequency_hz=key_freqs[9], gain_db=gains[9], q=q_values[9] + ) + ) + + # Apply the pedalboard to the audio + processed_audio = board(wav, sr) + return processed_audio + + +def power_ratio(r: float, a: float, b: float): + return a * math.pow((b / a), r) + + +def audiomentations_time_stretch(wav: np.ndarray, sr: int) -> np.ndarray: + """Use audiomentations 
to do time stretch""" + transform = TimeStretch( + min_rate=0.8, max_rate=1.25, leave_length_unchanged=False, p=1.0 + ) + augmented_wav = transform(wav, sample_rate=sr) + return augmented_wav + + +def formant_and_pitch_shift( + sound: parselmouth.Sound, fs: bool, ps: bool +) -> parselmouth.Sound: + """ """ + formant_shift_ratio = PRAAT_CHANGEGENDER_FORMANTSHIFTRATIO_DEFAULT + pitch_shift_ratio = PRAAT_CHANGEGENDER_PITCHSHIFTRATIO_DEFAULT + pitch_range_ratio = PRAAT_CHANGEGENDER_PITCHRANGERATIO_DEFAULT + + assert fs != ps, "fs, ps are mutually exclusive" + + if fs: + formant_shift_ratio = random.uniform(1.0, 1.4) + use_reciprocal = random.uniform(-1, 1) > 0 + if use_reciprocal: + formant_shift_ratio = 1.0 / formant_shift_ratio + # only use praat to change formant + new_sound = apply_formant_and_pitch_shift( + sound, + formant_shift_ratio=formant_shift_ratio, + ) + return new_sound + + if ps: + board = Pedalboard() + board.append(PitchShift(random.uniform(-12, 12))) + wav_numpy = sound.values + wav_numpy = board(wav_numpy, sound.sampling_frequency) + # use pedalboard to change pitch + new_sound = parselmouth.Sound( + wav_numpy, sampling_frequency=sound.sampling_frequency + ) + return new_sound + + +def wav_manipulation( + wav: torch.Tensor, + sr: int, + aug_type: str = "None", + formant_shift: bool = False, + pitch_shift: bool = False, + time_stretch: bool = False, + equalizer: bool = False, +) -> torch.Tensor: + assert aug_type == "None" or aug_type in [ + "formant_shift", + "pitch_shift", + "time_stretch", + "equalizer", + ], "aug_type must be one of formant_shift, pitch_shift, time_stretch, equalizer" + + assert aug_type == "None" or ( + formant_shift == False + and pitch_shift == False + and time_stretch == False + and equalizer == False + ), "if aug_type is specified, other argument must be False" + + if aug_type != "None": + if aug_type == "formant_shift": + formant_shift = True + if aug_type == "pitch_shift": + pitch_shift = True + if aug_type == "equalizer": + equalizer = True + if aug_type == "time_stretch": + time_stretch = True + + wav_numpy = wav.numpy() + + if equalizer: + wav_numpy = pedalboard_equalizer(wav_numpy, sr) + + if time_stretch: + wav_numpy = audiomentations_time_stretch(wav_numpy, sr) + + sound = wav_to_Sound(wav_numpy, sr) + + if formant_shift or pitch_shift: + sound = formant_and_pitch_shift(sound, formant_shift, pitch_shift) + + wav = torch.from_numpy(sound.values).float() + # shape (1, n_samples) + return wav + + +def augment_dataset(cfg, dataset) -> list: + """Augment dataset with formant_shift, pitch_shift, time_stretch, equalizer + + Args: + cfg (dict): configuration + dataset (str): dataset name + + Returns: + list: augmented dataset names + """ + # load metadata + dataset_path = os.path.join(cfg.preprocess.processed_dir, dataset) + split = ["train", "test"] if "eval" not in dataset else ["test"] + augment_datasets = [] + aug_types = [ + "formant_shift" if cfg.preprocess.use_formant_shift else None, + "pitch_shift" if cfg.preprocess.use_pitch_shift else None, + "time_stretch" if cfg.preprocess.use_time_stretch else None, + "equalizer" if cfg.preprocess.use_equalizer else None, + ] + aug_types = filter(None, aug_types) + for aug_type in aug_types: + print("Augmenting {} with {}...".format(dataset, aug_type)) + new_dataset = dataset + "_" + aug_type + augment_datasets.append(new_dataset) + new_dataset_path = os.path.join(cfg.preprocess.processed_dir, new_dataset) + + for dataset_type in split: + metadata_path = os.path.join(dataset_path, 
"{}.json".format(dataset_type)) + augmented_metadata = [] + new_metadata_path = os.path.join( + new_dataset_path, "{}.json".format(dataset_type) + ) + os.makedirs(new_dataset_path, exist_ok=True) + new_dataset_wav_dir = os.path.join(new_dataset_path, "wav") + os.makedirs(new_dataset_wav_dir, exist_ok=True) + + if has_existed(new_metadata_path): + continue + + with open(metadata_path, "r") as f: + metadata = json.load(f) + + for utt in tqdm(metadata): + original_wav_path = utt["Path"] + original_wav, sr = torchaudio.load(original_wav_path) + new_wav = wav_manipulation(original_wav, sr, aug_type=aug_type) + new_wav_path = os.path.join(new_dataset_wav_dir, utt["Uid"] + ".wav") + torchaudio.save(new_wav_path, new_wav, sr) + new_utt = { + "Dataset": utt["Dataset"] + "_" + aug_type, + "index": utt["index"], + "Singer": utt["Singer"], + "Uid": utt["Uid"], + "Path": new_wav_path, + "Duration": utt["Duration"], + } + augmented_metadata.append(new_utt) + new_metadata_path = os.path.join( + new_dataset_path, "{}.json".format(dataset_type) + ) + with open(new_metadata_path, "w") as f: + json.dump(augmented_metadata, f, indent=4, ensure_ascii=False) + return augment_datasets diff --git a/processors/phone_extractor.py b/processors/phone_extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..b0c53a79decf6c8c8e5ee68c5d8e05d878564f6a --- /dev/null +++ b/processors/phone_extractor.py @@ -0,0 +1,142 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +from tqdm import tqdm +from text.g2p_module import G2PModule, LexiconModule +from text.symbol_table import SymbolTable + +''' + phoneExtractor: extract phone from text +''' +class phoneExtractor: + def __init__(self, cfg, dataset_name=None, phone_symbol_file=None): + ''' + Args: + cfg: config + dataset_name: name of dataset + ''' + self.cfg = cfg + + # phone symbols dict + self.phone_symbols = set() + + # phone symbols dict file + if phone_symbol_file is not None: + self.phone_symbols_file = phone_symbol_file + elif dataset_name is not None: + self.dataset_name = dataset_name + self.phone_symbols_file = os.path.join(cfg.preprocess.processed_dir, + dataset_name, + cfg.preprocess.symbols_dict) + + + # initialize g2p module + if cfg.preprocess.phone_extractor in ["espeak", "pypinyin", "pypinyin_initials_finals"]: + self.g2p_module = G2PModule(backend=cfg.preprocess.phone_extractor) + elif cfg.preprocess.phone_extractor == 'lexicon': + assert cfg.preprocess.lexicon_path != "" + self.g2p_module = LexiconModule(cfg.preprocess.lexicon_path) + else: + print('No suppert to', cfg.preprocess.phone_extractor) + raise + + + def extract_phone(self, text): + ''' + Extract phone from text + Args: + + text: text of utterance + + Returns: + phone_symbols: set of phone symbols + phone_seq: list of phone sequence of each utterance + ''' + + if self.cfg.preprocess.phone_extractor in ["espeak", "pypinyin", "pypinyin_initials_finals"]: + text = text.replace("”", '"').replace("“", '"') + phone = self.g2p_module.g2p_conversion(text=text) + self.phone_symbols.update(phone) + phone_seq = [phn for phn in phone] + + elif self.cfg.preprocess.phone_extractor == 'lexicon': + phone_seq = self.g2p_module.g2p_conversion(text) + phone = phone_seq + if not isinstance(phone_seq, list): + phone_seq = phone_seq.split() + + return phone_seq + + def save_dataset_phone_symbols_to_table(self): + # load and merge saved phone symbols + if 
os.path.exists(self.phone_symbols_file): + phone_symbol_dict_saved = SymbolTable.from_file(self.phone_symbols_file)._sym2id.keys() + self.phone_symbols.update(set(phone_symbol_dict_saved)) + + # save phone symbols + phone_symbol_dict = SymbolTable() + for s in sorted(list(self.phone_symbols)): + phone_symbol_dict.add(s) + phone_symbol_dict.to_file(self.phone_symbols_file) + + +def extract_utt_phone_sequence(cfg, metadata): + ''' + Extract phone sequence from text + Args: + cfg: config + metadata: list of dict, each dict contains "Uid", "Text" + + ''' + + dataset_name = cfg.dataset[0] + + # output path + out_path = os.path.join(cfg.preprocess.processed_dir, dataset_name, cfg.preprocess.phone_dir) + os.makedirs(out_path, exist_ok=True) + + phone_extractor = phoneExtractor(cfg, dataset_name) + + for utt in tqdm(metadata): + uid = utt["Uid"] + text = utt["Text"] + + phone_seq = phone_extractor.extract_phone(text) + + phone_path = os.path.join(out_path, uid+'.phone') + with open(phone_path, 'w') as fin: + fin.write(' '.join(phone_seq)) + + if cfg.preprocess.phone_extractor != 'lexicon': + phone_extractor.save_dataset_phone_symbols_to_table() + + + +def save_all_dataset_phone_symbols_to_table(self, cfg, dataset): + # phone symbols dict + phone_symbols = set() + + for dataset_name in dataset: + phone_symbols_file = os.path.join(cfg.preprocess.processed_dir, + dataset_name, + cfg.preprocess.symbols_dict) + + # load and merge saved phone symbols + assert os.path.exists(phone_symbols_file) + phone_symbol_dict_saved = SymbolTable.from_file(phone_symbols_file)._sym2id.keys() + phone_symbols.update(set(phone_symbol_dict_saved)) + + # save all phone symbols to each dataset + phone_symbol_dict = SymbolTable() + for s in sorted(list(phone_symbols)): + phone_symbol_dict.add(s) + for dataset_name in dataset: + phone_symbols_file = os.path.join(cfg.preprocess.processed_dir, + dataset_name, + cfg.preprocess.symbols_dict) + phone_symbol_dict.to_file(phone_symbols_file) + + \ No newline at end of file diff --git a/utils/HyperParams/__init__.py b/utils/HyperParams/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..706e31b1c499f11548d08d38e1c3091aeb2dadaa --- /dev/null +++ b/utils/HyperParams/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .hps import HyperParams diff --git a/utils/HyperParams/hps.py b/utils/HyperParams/hps.py new file mode 100644 index 0000000000000000000000000000000000000000..cc6f474c4e28d092ea8ba2cdcf7233322c5ff281 --- /dev/null +++ b/utils/HyperParams/hps.py @@ -0,0 +1,43 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +class HyperParams: + """The class to store hyperparameters. The key is case-insensitive. + + Args: + *args: a list of dict or HyperParams. + **kwargs: a list of key-value pairs. 
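Example (a minimal sketch with made-up keys): nested dicts are converted to nested HyperParams instances, and values can be read either as attributes or with dict-style indexing.

hps = HyperParams(model={"hidden_dim": 256, "n_layers": 4}, lr=2e-4)
assert hps.model.hidden_dim == 256   # nested dict -> nested HyperParams
assert hps["lr"] == 2e-4             # __getitem__ falls back to getattr
assert "model" in hps and len(hps) == 2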
+ """ + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = HyperParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/audio.py b/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..374d50915cafa106a3035e77b99adc96e8484f0b --- /dev/null +++ b/utils/audio.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import numpy as np +from numpy import linalg as LA +import librosa +import soundfile as sf +import librosa.filters + + +def load_audio_torch(wave_file, fs): + """Load audio data into torch tensor + + Args: + wave_file (str): path to wave file + fs (int): sample rate + + Returns: + audio (tensor): audio data in tensor + fs (int): sample rate + """ + + audio, sample_rate = librosa.load(wave_file, sr=fs, mono=True) + # audio: (T,) + assert len(audio) > 2 + + # Check the audio type (for soundfile loading backbone) - float, 8bit or 16bit + if np.issubdtype(audio.dtype, np.integer): + max_mag = -np.iinfo(audio.dtype).min + else: + max_mag = max(np.amax(audio), -np.amin(audio)) + max_mag = ( + (2**31) + 1 + if max_mag > (2**15) + else ((2**15) + 1 if max_mag > 1.01 else 1.0) + ) + + # Normalize the audio + audio = torch.FloatTensor(audio.astype(np.float32)) / max_mag + + if (torch.isnan(audio) | torch.isinf(audio)).any(): + return [], sample_rate or fs or 48000 + + # Resample the audio to our target samplerate + if fs is not None and fs != sample_rate: + audio = torch.from_numpy( + librosa.core.resample(audio.numpy(), orig_sr=sample_rate, target_sr=fs) + ) + sample_rate = fs + + return audio, fs + + +def _stft(y, cfg): + return librosa.stft( + y=y, n_fft=cfg.n_fft, hop_length=cfg.hop_size, win_length=cfg.win_size + ) + + +def energy(wav, cfg): + D = _stft(wav, cfg) + magnitudes = np.abs(D).T # [F, T] + return LA.norm(magnitudes, axis=1) + + +def get_energy_from_tacotron(audio, _stft): + audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) + audio = torch.autograd.Variable(audio, requires_grad=False) + mel, energy = _stft.mel_spectrogram(audio) + energy = torch.squeeze(energy, 0).numpy().astype(np.float32) + return mel, energy diff --git a/utils/audio_slicer.py b/utils/audio_slicer.py new file mode 100644 index 0000000000000000000000000000000000000000..28474596b42c8f8215b878a80967112960d0c9e0 --- /dev/null +++ b/utils/audio_slicer.py @@ -0,0 +1,476 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import json +import numpy as np +from tqdm import tqdm +import torch +import torchaudio + +from utils.io import save_audio +from utils.audio import load_audio_torch + + +# This function is obtained from librosa. 
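# A hedged self-check of the helper below (shapes and values are expected to
# match librosa's own librosa.feature.rms, from which this code was copied;
# the signal here is synthetic):
#
#   import numpy as np
#   import librosa
#
#   y = np.random.randn(22050).astype(np.float32)
#   ours = get_rms(y, frame_length=2048, hop_length=512)           # (1, n_frames)
#   ref = librosa.feature.rms(y=y, frame_length=2048, hop_length=512)
#   assert np.allclose(ours, ref, atol=1e-5)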
+def get_rms( + y, + *, + frame_length=2048, + hop_length=512, + pad_mode="constant", +): + padding = (int(frame_length // 2), int(frame_length // 2)) + y = np.pad(y, padding, mode=pad_mode) + + axis = -1 + # put our new within-frame axis at the end for now + out_strides = y.strides + tuple([y.strides[axis]]) + # Reduce the shape on the framing axis + x_shape_trimmed = list(y.shape) + x_shape_trimmed[axis] -= frame_length - 1 + out_shape = tuple(x_shape_trimmed) + tuple([frame_length]) + xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides) + if axis < 0: + target_axis = axis - 1 + else: + target_axis = axis + 1 + xw = np.moveaxis(xw, -1, target_axis) + # Downsample along the target axis + slices = [slice(None)] * xw.ndim + slices[axis] = slice(0, None, hop_length) + x = xw[tuple(slices)] + + # Calculate power + power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True) + + return np.sqrt(power) + + +class Slicer: + """ + Copy from: https://github.com/openvpi/audio-slicer/blob/main/slicer2.py + """ + + def __init__( + self, + sr: int, + threshold: float = -40.0, + min_length: int = 5000, + min_interval: int = 300, + hop_size: int = 10, + max_sil_kept: int = 5000, + ): + if not min_length >= min_interval >= hop_size: + raise ValueError( + "The following condition must be satisfied: min_length >= min_interval >= hop_size" + ) + if not max_sil_kept >= hop_size: + raise ValueError( + "The following condition must be satisfied: max_sil_kept >= hop_size" + ) + min_interval = sr * min_interval / 1000 + self.threshold = 10 ** (threshold / 20.0) + self.hop_size = round(sr * hop_size / 1000) + self.win_size = min(round(min_interval), 4 * self.hop_size) + self.min_length = round(sr * min_length / 1000 / self.hop_size) + self.min_interval = round(min_interval / self.hop_size) + self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size) + + def _apply_slice(self, waveform, begin, end): + begin = begin * self.hop_size + if len(waveform.shape) > 1: + end = min(waveform.shape[1], end * self.hop_size) + return waveform[:, begin:end], begin, end + else: + end = min(waveform.shape[0], end * self.hop_size) + return waveform[begin:end], begin, end + + # @timeit + def slice(self, waveform, return_chunks_positions=False): + if len(waveform.shape) > 1: + # (#channle, wave_len) -> (wave_len) + samples = waveform.mean(axis=0) + else: + samples = waveform + if samples.shape[0] <= self.min_length: + return [waveform] + rms_list = get_rms( + y=samples, frame_length=self.win_size, hop_length=self.hop_size + ).squeeze(0) + sil_tags = [] + silence_start = None + clip_start = 0 + for i, rms in enumerate(rms_list): + # Keep looping while frame is silent. + if rms < self.threshold: + # Record start of silent frames. + if silence_start is None: + silence_start = i + continue + # Keep looping while frame is not silent and silence start has not been recorded. + if silence_start is None: + continue + # Clear recorded silence start if interval is not enough or clip is too short + is_leading_silence = silence_start == 0 and i > self.max_sil_kept + need_slice_middle = ( + i - silence_start >= self.min_interval + and i - clip_start >= self.min_length + ) + if not is_leading_silence and not need_slice_middle: + silence_start = None + continue + # Need slicing. Record the range of silent frames to be removed. 
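            # The three branches below handle (in order):
            #   1) a silent run no longer than max_sil_kept: keep the silence and
            #      cut once at its quietest frame (a run at the very start is
            #      trimmed instead);
            #   2) a run of at most 2 * max_sil_kept frames: cut twice, keeping up
            #      to max_sil_kept frames of silence on each side of the removed span;
            #   3) a longer run: keep max_sil_kept frames next to the speech at both
            #      ends and drop everything in between.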
+ if i - silence_start <= self.max_sil_kept: + pos = rms_list[silence_start : i + 1].argmin() + silence_start + if silence_start == 0: + sil_tags.append((0, pos)) + else: + sil_tags.append((pos, pos)) + clip_start = pos + elif i - silence_start <= self.max_sil_kept * 2: + pos = rms_list[ + i - self.max_sil_kept : silence_start + self.max_sil_kept + 1 + ].argmin() + pos += i - self.max_sil_kept + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + clip_start = pos_r + else: + sil_tags.append((min(pos_l, pos), max(pos_r, pos))) + clip_start = max(pos_r, pos) + else: + pos_l = ( + rms_list[ + silence_start : silence_start + self.max_sil_kept + 1 + ].argmin() + + silence_start + ) + pos_r = ( + rms_list[i - self.max_sil_kept : i + 1].argmin() + + i + - self.max_sil_kept + ) + if silence_start == 0: + sil_tags.append((0, pos_r)) + else: + sil_tags.append((pos_l, pos_r)) + clip_start = pos_r + silence_start = None + # Deal with trailing silence. + total_frames = rms_list.shape[0] + if ( + silence_start is not None + and total_frames - silence_start >= self.min_interval + ): + silence_end = min(total_frames, silence_start + self.max_sil_kept) + pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start + sil_tags.append((pos, total_frames + 1)) + # Apply and return slices. + if len(sil_tags) == 0: + return [waveform] + else: + chunks = [] + chunks_pos_of_waveform = [] + + if sil_tags[0][0] > 0: + chunk, begin, end = self._apply_slice(waveform, 0, sil_tags[0][0]) + chunks.append(chunk) + chunks_pos_of_waveform.append((begin, end)) + + for i in range(len(sil_tags) - 1): + chunk, begin, end = self._apply_slice( + waveform, sil_tags[i][1], sil_tags[i + 1][0] + ) + chunks.append(chunk) + chunks_pos_of_waveform.append((begin, end)) + + if sil_tags[-1][1] < total_frames: + chunk, begin, end = self._apply_slice( + waveform, sil_tags[-1][1], total_frames + ) + chunks.append(chunk) + chunks_pos_of_waveform.append((begin, end)) + + return ( + chunks + if not return_chunks_positions + else ( + chunks, + chunks_pos_of_waveform, + ) + ) + + +def split_utterances_from_audio( + wav_file, + output_dir, + max_duration_of_utterance=10.0, + min_interval=300, + db_threshold=-40, +): + """ + Split a long audio into utterances accoring to the silence (VAD). + + max_duration_of_utterance (second): + The maximum duration of every utterance (seconds) + min_interval (millisecond): + The smaller min_interval is, the more sliced audio clips this script is likely to generate. 
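    Example (a usage sketch of the Slicer defined above; "song.wav" is a hypothetical
    input and the thresholds mirror the defaults used in this file; note that an
    input with no detected silence is returned as a single chunk):

    import torchaudio

    waveform, fs = torchaudio.load("song.wav")          # (n_channels, T)
    slicer = Slicer(sr=fs, threshold=-40.0, min_interval=300)
    chunks, positions = slicer.slice(waveform, return_chunks_positions=True)
    for chunk, (begin, end) in zip(chunks, positions):
        print("utterance: {:.2f} s".format((end - begin) / fs))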
+ """ + print("File:", wav_file.split("/")[-1]) + waveform, fs = torchaudio.load(wav_file) + + slicer = Slicer(sr=fs, min_interval=min_interval, threshold=db_threshold) + chunks, positions = slicer.slice(waveform, return_chunks_positions=True) + + durations = [(end - begin) / fs for begin, end in positions] + print( + "Slicer's min silence part is {}ms, min and max duration of sliced utterances is {}s and {}s".format( + min_interval, min(durations), max(durations) + ) + ) + + res_chunks, res_positions = [], [] + for i, chunk in enumerate(chunks): + if len(chunk.shape) == 1: + chunk = chunk[None, :] + + begin, end = positions[i] + assert end - begin == chunk.shape[-1] + + max_wav_len = max_duration_of_utterance * fs + if chunk.shape[-1] <= max_wav_len: + res_chunks.append(chunk) + res_positions.append(positions[i]) + else: + # TODO: to reserve overlapping and conduct fade-in, fade-out + + # Get segments number + number = 2 + while chunk.shape[-1] // number >= max_wav_len: + number += 1 + seg_len = chunk.shape[-1] // number + + # Split + for num in range(number): + s = seg_len * num + t = min(s + seg_len, chunk.shape[-1]) + + seg_begin = begin + s + seg_end = begin + t + + res_chunks.append(chunk[:, s:t]) + res_positions.append((seg_begin, seg_end)) + + # Save utterances + os.makedirs(output_dir, exist_ok=True) + res = {"fs": int(fs)} + for i, chunk in enumerate(res_chunks): + filename = "{:04d}.wav".format(i) + res[filename] = [int(p) for p in res_positions[i]] + save_audio(os.path.join(output_dir, filename), chunk, fs) + + # Save positions + with open(os.path.join(output_dir, "positions.json"), "w") as f: + json.dump(res, f, indent=4, ensure_ascii=False) + return res + + +def is_silence( + wavform, + fs, + threshold=-40.0, + min_interval=300, + hop_size=10, + min_length=5000, +): + """ + Detect whether the given wavform is a silence + + wavform: (T, ) + """ + threshold = 10 ** (threshold / 20.0) + + hop_size = round(fs * hop_size / 1000) + win_size = min(round(min_interval), 4 * hop_size) + min_length = round(fs * min_length / 1000 / hop_size) + + if wavform.shape[0] <= min_length: + return True + + # (#Frame,) + rms_array = get_rms(y=wavform, frame_length=win_size, hop_length=hop_size).squeeze( + 0 + ) + return (rms_array < threshold).all() + + +def split_audio( + wav_file, target_sr, output_dir, max_duration_of_segment=10.0, overlap_duration=1.0 +): + """ + Split a long audio into segments. + + target_sr: + The target sampling rate to save the segments. 
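    A small arithmetic sketch of the windowing below (illustrative numbers): each
    segment covers `length` samples and consecutive segments start `stride` samples
    apart, so neighbouring segments share exactly `overlap_duration` seconds.

    target_sr = 24000                     # assumed sampling rate
    max_duration_of_segment = 10.0
    overlap_duration = 1.0

    length = int(max_duration_of_segment * target_sr)                       # 240000 samples
    stride = int((max_duration_of_segment - overlap_duration) * target_sr)  # 216000 samples

    total = 60 * target_sr                # a hypothetical 60-second recording
    n_segments = 0
    for i in range(0, total, stride):
        n_segments += 1
        if i + length >= total:
            break
    print(n_segments)                     # 7 segments for this example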
+ max_duration_of_utterance (second): + The maximum duration of every utterance (second) + overlap_duraion: + Each segment has "overlap duration" (second) overlap with its previous and next segment + """ + # (#channel, T) -> (T,) + waveform, fs = torchaudio.load(wav_file) + waveform = torchaudio.functional.resample( + waveform, orig_freq=fs, new_freq=target_sr + ) + waveform = torch.mean(waveform, dim=0) + + # waveform, _ = load_audio_torch(wav_file, target_sr) + assert len(waveform.shape) == 1 + + assert overlap_duration < max_duration_of_segment + length = int(max_duration_of_segment * target_sr) + stride = int((max_duration_of_segment - overlap_duration) * target_sr) + chunks = [] + for i in range(0, len(waveform), stride): + # (length,) + chunks.append(waveform[i : i + length]) + if i + length >= len(waveform): + break + + # Save segments + os.makedirs(output_dir, exist_ok=True) + results = [] + for i, chunk in enumerate(chunks): + uid = "{:04d}".format(i) + filename = os.path.join(output_dir, "{}.wav".format(uid)) + results.append( + {"Uid": uid, "Path": filename, "Duration": len(chunk) / target_sr} + ) + save_audio( + filename, + chunk, + target_sr, + turn_up=not is_silence(chunk, target_sr), + add_silence=False, + ) + + return results + + +def merge_segments_torchaudio(wav_files, fs, output_path, overlap_duration=1.0): + """Merge the given wav_files (may have overlaps) into a long audio + + fs: + The sampling rate of the wav files. + output_path: + The output path to save the merged audio. + overlap_duration (float, optional): + Each segment has "overlap duration" (second) overlap with its previous and next segment. Defaults to 1.0. + """ + + waveforms = [] + for file in wav_files: + # (T,) + waveform, _ = load_audio_torch(file, fs) + waveforms.append(waveform) + + if len(waveforms) == 1: + save_audio(output_path, waveforms[0], fs, add_silence=False, turn_up=False) + return + + overlap_len = int(overlap_duration * fs) + fade_out = torchaudio.transforms.Fade(fade_out_len=overlap_len) + fade_in = torchaudio.transforms.Fade(fade_in_len=overlap_len) + fade_in_and_out = torchaudio.transforms.Fade(fade_out_len=overlap_len) + + segments_lens = [len(wav) for wav in waveforms] + merged_waveform_len = sum(segments_lens) - overlap_len * (len(waveforms) - 1) + merged_waveform = torch.zeros(merged_waveform_len) + + start = 0 + for index, wav in enumerate( + tqdm(waveforms, desc="Merge for {}".format(output_path)) + ): + wav_len = len(wav) + + if index == 0: + wav = fade_out(wav) + elif index == len(waveforms) - 1: + wav = fade_in(wav) + else: + wav = fade_in_and_out(wav) + + merged_waveform[start : start + wav_len] = wav + start += wav_len - overlap_len + + save_audio(output_path, merged_waveform, fs, add_silence=False, turn_up=True) + + +def merge_segments_encodec(wav_files, fs, output_path, overlap_duration=1.0): + """Merge the given wav_files (may have overlaps) into a long audio + + fs: + The sampling rate of the wav files. + output_path: + The output path to save the merged audio. + overlap_duration (float, optional): + Each segment has "overlap duration" (second) overlap with its previous and next segment. Defaults to 1.0. 
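    For comparison, a toy sketch of the linear cross-fade idea used in
    merge_segments_torchaudio above (random segments, assumed 1-second overlap):
    fade the tail of one segment out, the head of the next one in, and sum them
    over the shared samples.

    import torch
    import torchaudio

    fs = 24000
    overlap_len = fs                                   # 1 second of overlap
    a = torch.rand(5 * fs)                             # two fake 5-second segments
    b = torch.rand(5 * fs)

    a = torchaudio.transforms.Fade(fade_out_len=overlap_len)(a)
    b = torchaudio.transforms.Fade(fade_in_len=overlap_len)(b)

    merged = torch.zeros(len(a) + len(b) - overlap_len)
    merged[: len(a)] = a
    merged[len(a) - overlap_len :] += b                # overlap region is summed
    print(merged.shape)                                # torch.Size([216000])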
+ """ + + waveforms = [] + for file in wav_files: + # (T,) + waveform, _ = load_audio_torch(file, fs) + waveforms.append(waveform) + + if len(waveforms) == 1: + save_audio(output_path, waveforms[0], fs, add_silence=False, turn_up=False) + return + + device = waveforms[0].device + dtype = waveforms[0].dtype + shape = waveforms[0].shape[:-1] + + overlap_len = int(overlap_duration * fs) + segments_lens = [len(wav) for wav in waveforms] + merged_waveform_len = sum(segments_lens) - overlap_len * (len(waveforms) - 1) + + sum_weight = torch.zeros(merged_waveform_len, device=device, dtype=dtype) + out = torch.zeros(*shape, merged_waveform_len, device=device, dtype=dtype) + offset = 0 + + for frame in waveforms: + frame_length = frame.size(-1) + t = torch.linspace(0, 1, frame_length + 2, device=device, dtype=torch.float32)[ + 1:-1 + ] + weight = 0.5 - (t - 0.5).abs() + weighted_frame = frame * weight + + cur = out[..., offset : offset + frame_length] + cur += weighted_frame[..., : cur.size(-1)] + out[..., offset : offset + frame_length] = cur + + cur = sum_weight[offset : offset + frame_length] + cur += weight[..., : cur.size(-1)] + sum_weight[offset : offset + frame_length] = cur + + offset += frame_length - overlap_len + + assert sum_weight.min() > 0 + merged_waveform = out / sum_weight + save_audio(output_path, merged_waveform, fs, add_silence=False, turn_up=True) diff --git a/utils/data_utils.py b/utils/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7976d050f01990c8a98a37d48be67bc68695a6c3 --- /dev/null +++ b/utils/data_utils.py @@ -0,0 +1,575 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import json +import os + +import numpy as np +from scipy.interpolate import interp1d +from tqdm import tqdm +from sklearn.preprocessing import StandardScaler + + +def load_content_feature_path(meta_data, processed_dir, feat_dir): + utt2feat_path = {} + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + feat_path = os.path.join( + processed_dir, utt_info["Dataset"], feat_dir, f'{utt_info["Uid"]}.npy' + ) + utt2feat_path[utt] = feat_path + + return utt2feat_path + + +def load_source_content_feature_path(meta_data, feat_dir): + utt2feat_path = {} + for utt in meta_data: + feat_path = os.path.join(feat_dir, f"{utt}.npy") + utt2feat_path[utt] = feat_path + + return utt2feat_path + + +def get_spk_map(spk2id_path, utt2spk_path): + utt2spk = {} + with open(spk2id_path, "r") as spk2id_file: + spk2id = json.load(spk2id_file) + with open(utt2spk_path, encoding="utf-8") as f: + for line in f.readlines(): + utt, spk = line.strip().split("\t") + utt2spk[utt] = spk + return spk2id, utt2spk + + +def get_target_f0_median(f0_dir): + total_f0 = [] + for utt in os.listdir(f0_dir): + if not utt.endswith(".npy"): + continue + f0_feat_path = os.path.join(f0_dir, utt) + f0 = np.load(f0_feat_path) + total_f0 += f0.tolist() + + total_f0 = np.array(total_f0) + voiced_position = np.where(total_f0 != 0) + return np.median(total_f0[voiced_position]) + + +def get_conversion_f0_factor(source_f0, target_median, source_median=None): + """Align the median between source f0 and target f0 + + Note: Here we use multiplication, whose factor is target_median/source_median + + Reference: Frequency and pitch interval + http://blog.ccyg.studio/article/be12c2ee-d47c-4098-9782-ca76da3035e4/ + """ + if source_median is None: + voiced_position = np.where(source_f0 != 0) + 
source_median = np.median(source_f0[voiced_position]) + factor = target_median / source_median + return source_median, factor + + +def transpose_key(frame_pitch, trans_key): + # Transpose by user's argument + print("Transpose key = {} ...\n".format(trans_key)) + + transed_pitch = frame_pitch * 2 ** (trans_key / 12) + return transed_pitch + + +def pitch_shift_to_target(frame_pitch, target_pitch_median, source_pitch_median=None): + # Loading F0 Base (median) and shift + source_pitch_median, factor = get_conversion_f0_factor( + frame_pitch, target_pitch_median, source_pitch_median + ) + print( + "Auto transposing: source f0 median = {:.1f}, target f0 median = {:.1f}, factor = {:.2f}".format( + source_pitch_median, target_pitch_median, factor + ) + ) + transed_pitch = frame_pitch * factor + return transed_pitch + + +def load_frame_pitch( + meta_data, + processed_dir, + pitch_dir, + use_log_scale=False, + return_norm=False, + interoperate=False, + utt2spk=None, +): + utt2pitch = {} + utt2uv = {} + if utt2spk is None: + pitch_scaler = StandardScaler() + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + pitch_path = os.path.join( + processed_dir, utt_info["Dataset"], pitch_dir, f'{utt_info["Uid"]}.npy' + ) + pitch = np.load(pitch_path) + assert len(pitch) > 0 + uv = pitch != 0 + utt2uv[utt] = uv + if use_log_scale: + nonzero_idxes = np.where(pitch != 0)[0] + pitch[nonzero_idxes] = np.log(pitch[nonzero_idxes]) + utt2pitch[utt] = pitch + pitch_scaler.partial_fit(pitch.reshape(-1, 1)) + + mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0] + if return_norm: + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + pitch = utt2pitch[utt] + normalized_pitch = (pitch - mean) / std + utt2pitch[utt] = normalized_pitch + pitch_statistic = {"mean": mean, "std": std} + else: + spk2utt = {} + pitch_statistic = [] + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + if not utt2spk[utt] in spk2utt: + spk2utt[utt2spk[utt]] = [] + spk2utt[utt2spk[utt]].append(utt) + + for spk in spk2utt: + pitch_scaler = StandardScaler() + for utt in spk2utt[spk]: + dataset = utt.split("_")[0] + uid = "_".join(utt.split("_")[1:]) + pitch_path = os.path.join( + processed_dir, dataset, pitch_dir, f"{uid}.npy" + ) + pitch = np.load(pitch_path) + assert len(pitch) > 0 + uv = pitch != 0 + utt2uv[utt] = uv + if use_log_scale: + nonzero_idxes = np.where(pitch != 0)[0] + pitch[nonzero_idxes] = np.log(pitch[nonzero_idxes]) + utt2pitch[utt] = pitch + pitch_scaler.partial_fit(pitch.reshape(-1, 1)) + + mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0] + if return_norm: + for utt in spk2utt[spk]: + pitch = utt2pitch[utt] + normalized_pitch = (pitch - mean) / std + utt2pitch[utt] = normalized_pitch + pitch_statistic.append({"spk": spk, "mean": mean, "std": std}) + + return utt2pitch, utt2uv, pitch_statistic + + +# discard +def load_phone_pitch( + meta_data, + processed_dir, + pitch_dir, + utt2dur, + use_log_scale=False, + return_norm=False, + interoperate=True, + utt2spk=None, +): + print("Load Phone Pitch") + utt2pitch = {} + utt2uv = {} + if utt2spk is None: + pitch_scaler = StandardScaler() + for utt_info in tqdm(meta_data): + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + pitch_path = os.path.join( + processed_dir, utt_info["Dataset"], pitch_dir, f'{utt_info["Uid"]}.npy' + ) + frame_pitch = np.load(pitch_path) + assert len(frame_pitch) > 0 + uv = frame_pitch != 0 + utt2uv[utt] = uv + phone_pitch = phone_average_pitch(frame_pitch, 
utt2dur[utt], interoperate) + if use_log_scale: + nonzero_idxes = np.where(phone_pitch != 0)[0] + phone_pitch[nonzero_idxes] = np.log(phone_pitch[nonzero_idxes]) + utt2pitch[utt] = phone_pitch + pitch_scaler.partial_fit(remove_outlier(phone_pitch).reshape(-1, 1)) + + mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0] + max_value = np.finfo(np.float64).min + min_value = np.finfo(np.float64).max + if return_norm: + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + pitch = utt2pitch[utt] + normalized_pitch = (pitch - mean) / std + max_value = max(max_value, max(normalized_pitch)) + min_value = min(min_value, min(normalized_pitch)) + utt2pitch[utt] = normalized_pitch + phone_normalized_pitch_path = os.path.join( + processed_dir, + utt_info["Dataset"], + "phone_level_" + pitch_dir, + f'{utt_info["Uid"]}.npy', + ) + pitch_statistic = { + "mean": mean, + "std": std, + "min_value": min_value, + "max_value": max_value, + } + else: + spk2utt = {} + pitch_statistic = [] + for utt_info in tqdm(meta_data): + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + if not utt2spk[utt] in spk2utt: + spk2utt[utt2spk[utt]] = [] + spk2utt[utt2spk[utt]].append(utt) + + for spk in spk2utt: + pitch_scaler = StandardScaler() + for utt in spk2utt[spk]: + dataset = utt.split("_")[0] + uid = "_".join(utt.split("_")[1:]) + pitch_path = os.path.join( + processed_dir, dataset, pitch_dir, f"{uid}.npy" + ) + frame_pitch = np.load(pitch_path) + assert len(frame_pitch) > 0 + uv = frame_pitch != 0 + utt2uv[utt] = uv + phone_pitch = phone_average_pitch( + frame_pitch, utt2dur[utt], interoperate + ) + if use_log_scale: + nonzero_idxes = np.where(phone_pitch != 0)[0] + phone_pitch[nonzero_idxes] = np.log(phone_pitch[nonzero_idxes]) + utt2pitch[utt] = phone_pitch + pitch_scaler.partial_fit(remove_outlier(phone_pitch).reshape(-1, 1)) + + mean, std = pitch_scaler.mean_[0], pitch_scaler.scale_[0] + max_value = np.finfo(np.float64).min + min_value = np.finfo(np.float64).max + + if return_norm: + for utt in spk2utt[spk]: + pitch = utt2pitch[utt] + normalized_pitch = (pitch - mean) / std + max_value = max(max_value, max(normalized_pitch)) + min_value = min(min_value, min(normalized_pitch)) + utt2pitch[utt] = normalized_pitch + pitch_statistic.append( + { + "spk": spk, + "mean": mean, + "std": std, + "min_value": min_value, + "max_value": max_value, + } + ) + + return utt2pitch, utt2uv, pitch_statistic + + +def phone_average_pitch(pitch, dur, interoperate=False): + pos = 0 + + if interoperate: + nonzero_ids = np.where(pitch != 0)[0] + interp_fn = interp1d( + nonzero_ids, + pitch[nonzero_ids], + fill_value=(pitch[nonzero_ids[0]], pitch[nonzero_ids[-1]]), + bounds_error=False, + ) + pitch = interp_fn(np.arange(0, len(pitch))) + phone_pitch = np.zeros(len(dur)) + + for i, d in enumerate(dur): + d = int(d) + if d > 0 and pos < len(pitch): + phone_pitch[i] = np.mean(pitch[pos : pos + d]) + else: + phone_pitch[i] = 0 + pos += d + return phone_pitch + + +def load_energy( + meta_data, + processed_dir, + energy_dir, + use_log_scale=False, + return_norm=False, + utt2spk=None, +): + utt2energy = {} + if utt2spk is None: + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + energy_path = os.path.join( + processed_dir, utt_info["Dataset"], energy_dir, f'{utt_info["Uid"]}.npy' + ) + if not os.path.exists(energy_path): + continue + energy = np.load(energy_path) + assert len(energy) > 0 + + if use_log_scale: + nonzero_idxes = np.where(energy != 0)[0] + energy[nonzero_idxes] = 
np.log(energy[nonzero_idxes]) + utt2energy[utt] = energy + + if return_norm: + with open( + os.path.join( + processed_dir, utt_info["Dataset"], energy_dir, "statistics.json" + ) + ) as f: + stats = json.load(f) + mean, std = ( + stats[utt_info["Dataset"] + "_" + utt_info["Singer"]][ + "voiced_positions" + ]["mean"], + stats["LJSpeech_LJSpeech"]["voiced_positions"]["std"], + ) + for utt in utt2energy.keys(): + energy = utt2energy[utt] + normalized_energy = (energy - mean) / std + utt2energy[utt] = normalized_energy + + energy_statistic = {"mean": mean, "std": std} + else: + spk2utt = {} + energy_statistic = [] + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + if not utt2spk[utt] in spk2utt: + spk2utt[utt2spk[utt]] = [] + spk2utt[utt2spk[utt]].append(utt) + + for spk in spk2utt: + energy_scaler = StandardScaler() + for utt in spk2utt[spk]: + dataset = utt.split("_")[0] + uid = "_".join(utt.split("_")[1:]) + energy_path = os.path.join( + processed_dir, dataset, energy_dir, f"{uid}.npy" + ) + if not os.path.exists(energy_path): + continue + frame_energy = np.load(energy_path) + assert len(frame_energy) > 0 + + if use_log_scale: + nonzero_idxes = np.where(frame_energy != 0)[0] + frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes]) + utt2energy[utt] = frame_energy + energy_scaler.partial_fit(frame_energy.reshape(-1, 1)) + + mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0] + if return_norm: + for utt in spk2utt[spk]: + energy = utt2energy[utt] + normalized_energy = (energy - mean) / std + utt2energy[utt] = normalized_energy + energy_statistic.append({"spk": spk, "mean": mean, "std": std}) + + return utt2energy, energy_statistic + + +def load_frame_energy( + meta_data, + processed_dir, + energy_dir, + use_log_scale=False, + return_norm=False, + interoperate=False, + utt2spk=None, +): + utt2energy = {} + if utt2spk is None: + energy_scaler = StandardScaler() + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + energy_path = os.path.join( + processed_dir, utt_info["Dataset"], energy_dir, f'{utt_info["Uid"]}.npy' + ) + frame_energy = np.load(energy_path) + assert len(frame_energy) > 0 + + if use_log_scale: + nonzero_idxes = np.where(frame_energy != 0)[0] + frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes]) + utt2energy[utt] = frame_energy + energy_scaler.partial_fit(frame_energy.reshape(-1, 1)) + + mean, std = energy_scaler.mean_[0], energy_scaler.scale_[0] + if return_norm: + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + energy = utt2energy[utt] + normalized_energy = (energy - mean) / std + utt2energy[utt] = normalized_energy + energy_statistic = {"mean": mean, "std": std} + + else: + spk2utt = {} + energy_statistic = [] + for utt_info in meta_data: + utt = utt_info["Dataset"] + "_" + utt_info["Uid"] + if not utt2spk[utt] in spk2utt: + spk2utt[utt2spk[utt]] = [] + spk2utt[utt2spk[utt]].append(utt) + + for spk in spk2utt: + energy_scaler = StandardScaler() + for utt in spk2utt[spk]: + dataset = utt.split("_")[0] + uid = "_".join(utt.split("_")[1:]) + energy_path = os.path.join( + processed_dir, dataset, energy_dir, f"{uid}.npy" + ) + frame_energy = np.load(energy_path) + assert len(frame_energy) > 0 + + if use_log_scale: + nonzero_idxes = np.where(frame_energy != 0)[0] + frame_energy[nonzero_idxes] = np.log(frame_energy[nonzero_idxes]) + utt2energy[utt] = frame_energy + energy_scaler.partial_fit(frame_energy.reshape(-1, 1)) + + mean, std = energy_scaler.mean_[0], 
energy_scaler.scale_[0] + if return_norm: + for utt in spk2utt[spk]: + energy = utt2energy[utt] + normalized_energy = (energy - mean) / std + utt2energy[utt] = normalized_energy + energy_statistic.append({"spk": spk, "mean": mean, "std": std}) + + return utt2energy, energy_statistic + + +def align_length(feature, target_len, pad_value=0.0): + feature_len = feature.shape[-1] + dim = len(feature.shape) + # align 1-D data + if dim == 2: + if target_len > feature_len: + feature = np.pad( + feature, + ((0, 0), (0, target_len - feature_len)), + constant_values=pad_value, + ) + else: + feature = feature[:, :target_len] + # align 2-D data + elif dim == 1: + if target_len > feature_len: + feature = np.pad( + feature, (0, target_len - feature_len), constant_values=pad_value + ) + else: + feature = feature[:target_len] + else: + raise NotImplementedError + return feature + + +def align_whisper_feauture_length( + feature, target_len, fast_mapping=True, source_hop=320, target_hop=256 +): + factor = np.gcd(source_hop, target_hop) + source_hop //= factor + target_hop //= factor + # print( + # "Mapping source's {} frames => target's {} frames".format( + # target_hop, source_hop + # ) + # ) + + max_source_len = 1500 + target_len = min(target_len, max_source_len * source_hop // target_hop) + + width = feature.shape[-1] + + if fast_mapping: + source_len = target_len * target_hop // source_hop + 1 + feature = feature[:source_len] + + else: + source_len = max_source_len + + # const ~= target_len * target_hop + const = source_len * source_hop // target_hop * target_hop + + # (source_len * source_hop, dim) + up_sampling_feats = np.repeat(feature, source_hop, axis=0) + # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) + down_sampling_feats = np.average( + up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 + ) + assert len(down_sampling_feats) >= target_len + + # (target_len, dim) + feat = down_sampling_feats[:target_len] + + return feat + + +def align_content_feature_length(feature, target_len, source_hop=320, target_hop=256): + factor = np.gcd(source_hop, target_hop) + source_hop //= factor + target_hop //= factor + # print( + # "Mapping source's {} frames => target's {} frames".format( + # target_hop, source_hop + # ) + # ) + + # (source_len, 256) + source_len, width = feature.shape + + # const ~= target_len * target_hop + const = source_len * source_hop // target_hop * target_hop + + # (source_len * source_hop, dim) + up_sampling_feats = np.repeat(feature, source_hop, axis=0) + # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) + down_sampling_feats = np.average( + up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 + ) + + err = abs(target_len - len(down_sampling_feats)) + if err > 4: ## why 4 not 3? 
+ print("target_len:", target_len) + print("raw feature:", feature.shape) + print("up_sampling:", up_sampling_feats.shape) + print("down_sampling_feats:", down_sampling_feats.shape) + exit() + if len(down_sampling_feats) < target_len: + # (1, dim) -> (err, dim) + end = down_sampling_feats[-1][None, :].repeat(err, axis=0) + down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0) + + # (target_len, dim) + feat = down_sampling_feats[:target_len] + + return feat + + +def remove_outlier(values): + values = np.array(values) + p25 = np.percentile(values, 25) + p75 = np.percentile(values, 75) + lower = p25 - 1.5 * (p75 - p25) + upper = p75 + 1.5 * (p75 - p25) + normal_indices = np.logical_and(values > lower, values < upper) + return values[normal_indices] diff --git a/utils/distribution.py b/utils/distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..de3000e99194f7e848712d8b4cb77c988f098fd2 --- /dev/null +++ b/utils/distribution.py @@ -0,0 +1,270 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +import torch.nn.functional as F + +from torch.distributions import Normal + + +def log_sum_exp(x): + """numerically stable log_sum_exp implementation that prevents overflow""" + # TF ordering + axis = len(x.size()) - 1 + m, _ = torch.max(x, dim=axis) + m2, _ = torch.max(x, dim=axis, keepdim=True) + return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis)) + + +def discretized_mix_logistic_loss( + y_hat, y, num_classes=256, log_scale_min=-7.0, reduce=True +): + """Discretized mixture of logistic distributions loss + + Note that it is assumed that input is scaled to [-1, 1]. + + Args: + y_hat (Tensor): Predicted output (B x C x T) + y (Tensor): Target (B x T x 1). + num_classes (int): Number of classes + log_scale_min (float): Log scale minimum value + reduce (bool): If True, the losses are averaged or summed for each + minibatch. + + Returns + Tensor: loss + """ + assert y_hat.dim() == 3 + assert y_hat.size(1) % 3 == 0 + nr_mix = y_hat.size(1) // 3 + + # (B x T x C) + y_hat = y_hat.transpose(1, 2) + + # unpack parameters. 
(B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix : 2 * nr_mix] + log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min) + + # B x T x 1 -> B x T x num_mixtures + y = y.expand_as(means) + + centered_y = y - means + inv_stdv = torch.exp(-log_scales) + plus_in = inv_stdv * (centered_y + 1.0 / (num_classes - 1)) + cdf_plus = torch.sigmoid(plus_in) + min_in = inv_stdv * (centered_y - 1.0 / (num_classes - 1)) + cdf_min = torch.sigmoid(min_in) + + # log probability for edge case of 0 (before scaling) + # equivalent: torch.log(torch.sigmoid(plus_in)) + log_cdf_plus = plus_in - F.softplus(plus_in) + + # log probability for edge case of 255 (before scaling) + # equivalent: (1 - torch.sigmoid(min_in)).log() + log_one_minus_cdf_min = -F.softplus(min_in) + + # probability for all other cases + cdf_delta = cdf_plus - cdf_min + + mid_in = inv_stdv * centered_y + # log probability in the center of the bin, to be used in extreme cases + # (not actually used in our code) + log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in) + + # tf equivalent + """ + log_probs = tf.where(x < -0.999, log_cdf_plus, + tf.where(x > 0.999, log_one_minus_cdf_min, + tf.where(cdf_delta > 1e-5, + tf.log(tf.maximum(cdf_delta, 1e-12)), + log_pdf_mid - np.log(127.5)))) + """ + # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value + # for num_classes=65536 case? 1e-7? not sure.. + inner_inner_cond = (cdf_delta > 1e-5).float() + + inner_inner_out = inner_inner_cond * torch.log( + torch.clamp(cdf_delta, min=1e-12) + ) + (1.0 - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2)) + inner_cond = (y > 0.999).float() + inner_out = ( + inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out + ) + cond = (y < -0.999).float() + log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out + + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if reduce: + return -torch.sum(log_sum_exp(log_probs)) + else: + return -log_sum_exp(log_probs).unsqueeze(-1) + + +def to_one_hot(tensor, n, fill_with=1.0): + # we perform one hot encore with respect to the last axis + one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_() + if tensor.is_cuda: + one_hot = one_hot.cuda() + one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with) + return one_hot + + +def sample_from_discretized_mix_logistic(y, log_scale_min=-7.0, clamp_log_scale=False): + """ + Sample from discretized mixture of logistic distributions + + Args: + y (Tensor): B x C x T + log_scale_min (float): Log scale minimum value + + Returns: + Tensor: sample in range of [-1, 1]. 
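    For reference, a shape-only sketch of calling discretized_mix_logistic_loss
    defined above (random tensors, an assumed 10 mixtures): the channel axis
    carries 3 parameters per mixture and targets must already be scaled to [-1, 1].

    import torch

    B, T, nr_mix = 2, 100, 10
    y_hat = torch.randn(B, 3 * nr_mix, T)            # (B, C, T) mixture parameters
    y = torch.empty(B, T, 1).uniform_(-1.0, 1.0)     # (B, T, 1) target in [-1, 1]
    loss = discretized_mix_logistic_loss(y_hat, y, num_classes=256, reduce=True)
    print(loss.item())                               # negative log-likelihood, summed over B and T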
+ """ + assert y.size(1) % 3 == 0 + nr_mix = y.size(1) // 3 + + # B x T x C + y = y.transpose(1, 2) + logit_probs = y[:, :, :nr_mix] + + # sample mixture indicator from softmax + temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) + temp = logit_probs.data - torch.log(-torch.log(temp)) + _, argmax = temp.max(dim=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = to_one_hot(argmax, nr_mix) + # select logistic parameters + means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1) + if clamp_log_scale: + log_scales = torch.clamp(log_scales, min=log_scale_min) + # sample from logistic & clip to interval + # we don't actually round to the nearest 8bit value when sampling + u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5) + x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1.0 - u)) + + x = torch.clamp(torch.clamp(x, min=-1.0), max=1.0) + + return x + + +# we can easily define discretized version of the gaussian loss, however, +# use continuous version as same as the https://clarinet-demo.github.io/ +def mix_gaussian_loss(y_hat, y, log_scale_min=-7.0, reduce=True): + """Mixture of continuous gaussian distributions loss + + Note that it is assumed that input is scaled to [-1, 1]. + + Args: + y_hat (Tensor): Predicted output (B x C x T) + y (Tensor): Target (B x T x 1). + log_scale_min (float): Log scale minimum value + reduce (bool): If True, the losses are averaged or summed for each + minibatch. + Returns + Tensor: loss + """ + assert y_hat.dim() == 3 + C = y_hat.size(1) + if C == 2: + nr_mix = 1 + else: + assert y_hat.size(1) % 3 == 0 + nr_mix = y_hat.size(1) // 3 + + # (B x T x C) + y_hat = y_hat.transpose(1, 2) + + # unpack parameters. + if C == 2: + # special case for C == 2, just for compatibility + logit_probs = None + means = y_hat[:, :, 0:1] + log_scales = torch.clamp(y_hat[:, :, 1:2], min=log_scale_min) + else: + # (B, T, num_mixtures) x 3 + logit_probs = y_hat[:, :, :nr_mix] + means = y_hat[:, :, nr_mix : 2 * nr_mix] + log_scales = torch.clamp( + y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=log_scale_min + ) + + # B x T x 1 -> B x T x num_mixtures + y = y.expand_as(means) + + centered_y = y - means + dist = Normal(loc=0.0, scale=torch.exp(log_scales)) + # do we need to add a trick to avoid log(0)? + log_probs = dist.log_prob(centered_y) + + if nr_mix > 1: + log_probs = log_probs + F.log_softmax(logit_probs, -1) + + if reduce: + if nr_mix == 1: + return -torch.sum(log_probs) + else: + return -torch.sum(log_sum_exp(log_probs)) + else: + if nr_mix == 1: + return -log_probs + else: + return -log_sum_exp(log_probs).unsqueeze(-1) + + +def sample_from_mix_gaussian(y, log_scale_min=-7.0): + """ + Sample from (discretized) mixture of gaussian distributions + Args: + y (Tensor): B x C x T + log_scale_min (float): Log scale minimum value + Returns: + Tensor: sample in range of [-1, 1]. 
+ """ + C = y.size(1) + if C == 2: + nr_mix = 1 + else: + assert y.size(1) % 3 == 0 + nr_mix = y.size(1) // 3 + + # B x T x C + y = y.transpose(1, 2) + + if C == 2: + logit_probs = None + else: + logit_probs = y[:, :, :nr_mix] + + if nr_mix > 1: + # sample mixture indicator from softmax + temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5) + temp = logit_probs.data - torch.log(-torch.log(temp)) + _, argmax = temp.max(dim=-1) + + # (B, T) -> (B, T, nr_mix) + one_hot = to_one_hot(argmax, nr_mix) + + # Select means and log scales + means = torch.sum(y[:, :, nr_mix : 2 * nr_mix] * one_hot, dim=-1) + log_scales = torch.sum(y[:, :, 2 * nr_mix : 3 * nr_mix] * one_hot, dim=-1) + else: + if C == 2: + means, log_scales = y[:, :, 0], y[:, :, 1] + elif C == 3: + means, log_scales = y[:, :, 1], y[:, :, 2] + else: + assert False, "shouldn't happen" + + scales = torch.exp(log_scales) + dist = Normal(loc=means, scale=scales) + x = dist.sample() + + x = torch.clamp(x, min=-1.0, max=1.0) + return x diff --git a/utils/dsp.py b/utils/dsp.py new file mode 100644 index 0000000000000000000000000000000000000000..18f9466f6b12e5539ce221f86030f5114ccdb503 --- /dev/null +++ b/utils/dsp.py @@ -0,0 +1,97 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +# ZERO = 1e-12 + + +def gaussian_normalize_mel_channel(mel, mu, sigma): + """ + Shift to Standorm Normal Distribution + + Args: + mel: (n_mels, frame_len) + mu: (n_mels,), mean value + sigma: (n_mels,), sd value + Return: + Tensor like mel + """ + mu = np.expand_dims(mu, -1) + sigma = np.expand_dims(sigma, -1) + return (mel - mu) / sigma + + +def de_gaussian_normalize_mel_channel(mel, mu, sigma): + """ + + Args: + mel: (n_mels, frame_len) + mu: (n_mels,), mean value + sigma: (n_mels,), sd value + Return: + Tensor like mel + """ + mu = np.expand_dims(mu, -1) + sigma = np.expand_dims(sigma, -1) + return sigma * mel + mu + + +def decompress(audio_compressed, bits): + mu = 2**bits - 1 + audio = np.sign(audio_compressed) / mu * ((1 + mu) ** np.abs(audio_compressed) - 1) + return audio + + +def compress(audio, bits): + mu = 2**bits - 1 + audio_compressed = np.sign(audio) * np.log(1 + mu * np.abs(audio)) / np.log(mu + 1) + return audio_compressed + + +def label_to_audio(quant, bits): + classes = 2**bits + audio = 2 * quant / (classes - 1.0) - 1.0 + return audio + + +def audio_to_label(audio, bits): + """Normalized audio data tensor to digit array + + Args: + audio (tensor): audio data + bits (int): data bits + + Returns: + array: digit array of audio data + """ + classes = 2**bits + # initialize an increasing array with values from -1 to 1 + bins = np.linspace(-1, 1, classes) + # change value in audio tensor to digits + quant = np.digitize(audio, bins) - 1 + return quant + + +def label_to_onehot(x, bits): + """Converts a class vector (integers) to binary class matrix. + Args: + x: class vector to be converted into a matrix + (integers from 0 to num_classes). + num_classes: total number of classes. + Returns: + A binary matrix representation of the input. The classes axis + is placed last. 
+ """ + classes = 2**bits + + result = torch.zeros((x.shape[0], classes), dtype=torch.float32) + for i in range(x.shape[0]): + result[i, x[i]] = 1 + + output_shape = x.shape + (classes,) + output = torch.reshape(result, output_shape) + return output diff --git a/utils/duration.py b/utils/duration.py new file mode 100644 index 0000000000000000000000000000000000000000..c9544b40b88c68b4e1df33ab1c81a6196a43111e --- /dev/null +++ b/utils/duration.py @@ -0,0 +1,86 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import os +import tgt + + +def get_alignment(tier, cfg): + sample_rate = cfg["sample_rate"] + hop_size = cfg["hop_size"] + + sil_phones = ["sil", "sp", "spn"] + + phones = [] + durations = [] + start_time = 0 + end_time = 0 + end_idx = 0 + + for t in tier._objects: + s, e, p = t.start_time, t.end_time, t.text + + # Trim leading silences + if phones == []: + if p in sil_phones: + continue + else: + start_time = s + + if p not in sil_phones: + # For ordinary phones + phones.append(p) + end_time = e + end_idx = len(phones) + else: + # For silent phones + phones.append(p) + + durations.append( + int( + np.round(e * sample_rate / hop_size) + - np.round(s * sample_rate / hop_size) + ) + ) + + # Trim tailing silences + phones = phones[:end_idx] + durations = durations[:end_idx] + + return phones, durations, start_time, end_time + + +def get_duration(utt, wav, cfg): + speaker = utt["Singer"] + basename = utt["Uid"] + dataset = utt["Dataset"] + sample_rate = cfg["sample_rate"] + + # print(cfg.processed_dir, dataset, speaker, basename) + wav_path = os.path.join( + cfg.processed_dir, dataset, "raw_data", speaker, "{}.wav".format(basename) + ) + text_path = os.path.join( + cfg.processed_dir, dataset, "raw_data", speaker, "{}.lab".format(basename) + ) + tg_path = os.path.join( + cfg.processed_dir, dataset, "TextGrid", speaker, "{}.TextGrid".format(basename) + ) + + # Read raw text + with open(text_path, "r") as f: + raw_text = f.readline().strip("\n") + + # Get alignments + textgrid = tgt.io.read_textgrid(tg_path) + phone, duration, start, end = get_alignment( + textgrid.get_tier_by_name("phones"), cfg + ) + text = "{" + " ".join(phone) + "}" + if start >= end: + return None + + return duration, text, int(sample_rate * start), int(sample_rate * end) diff --git a/utils/f0.py b/utils/f0.py new file mode 100644 index 0000000000000000000000000000000000000000..fcd95396a6b08d8ffa1a310c5bee6f0d8b556796 --- /dev/null +++ b/utils/f0.py @@ -0,0 +1,299 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import librosa +import numpy as np +import torch +import parselmouth +import torchcrepe +import pyworld as pw + + +def get_bin_index(f0, m, M, n_bins, use_log_scale): + """ + WARNING: to abandon! 
+ + Args: + raw_f0: tensor whose shpae is (N, frame_len) + Returns: + index: tensor whose shape is same to f0 + """ + raw_f0 = f0.clone() + raw_m, raw_M = m, M + + if use_log_scale: + f0[torch.where(f0 == 0)] = 1 + f0 = torch.log(f0) + m, M = float(np.log(m)), float(np.log(M)) + + # Set normal index in [1, n_bins - 1] + width = (M + 1e-7 - m) / (n_bins - 1) + index = (f0 - m) // width + 1 + # Set unvoiced frames as 0, Therefore, the vocabulary is [0, n_bins- 1], whose size is n_bins + index[torch.where(f0 == 0)] = 0 + + # TODO: Boundary check (special: to judge whether 0 for unvoiced) + if torch.any(raw_f0 > raw_M): + print("F0 Warning: too high f0: {}".format(raw_f0[torch.where(raw_f0 > raw_M)])) + index[torch.where(raw_f0 > raw_M)] = n_bins - 1 + if torch.any(raw_f0 < raw_m): + print("F0 Warning: too low f0: {}".format(raw_f0[torch.where(f0 < m)])) + index[torch.where(f0 < m)] = 0 + + return torch.as_tensor(index, dtype=torch.long, device=f0.device) + + +def f0_to_coarse(f0, pitch_bin, pitch_min, pitch_max): + ## TODO: Figure out the detail of this function + + f0_mel_min = 1127 * np.log(1 + pitch_min / 700) + f0_mel_max = 1127 * np.log(1 + pitch_max / 700) + + is_torch = isinstance(f0, torch.Tensor) + f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) + f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (pitch_bin - 2) / ( + f0_mel_max - f0_mel_min + ) + 1 + + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > pitch_bin - 1] = pitch_bin - 1 + f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int32) + assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, ( + f0_coarse.max(), + f0_coarse.min(), + ) + return f0_coarse + + +def interpolate(f0): + """Interpolate the unvoiced part. Thus the f0 can be passed to a subtractive synthesizer. + Args: + f0: A numpy array of shape (seq_len,) + Returns: + f0: Interpolated f0 of shape (seq_len,) + uv: Unvoiced part of shape (seq_len,) + """ + uv = f0 == 0 + if len(f0[~uv]) > 0: + # interpolate the unvoiced f0 + f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) + uv = uv.astype("float") + uv = np.min(np.array([uv[:-2], uv[1:-1], uv[2:]]), axis=0) + uv = np.pad(uv, (1, 1)) + return f0, uv + + +def get_log_f0(f0): + f0[np.where(f0 == 0)] = 1 + log_f0 = np.log(f0) + return log_f0 + + +# ========== Methods ========== + + +def get_f0_features_using_pyin(audio, cfg): + """Using pyin to extract the f0 feature. + Args: + audio + fs + win_length + hop_length + f0_min + f0_max + Returns: + f0: numpy array of shape (frame_len,) + """ + f0, voiced_flag, voiced_probs = librosa.pyin( + y=audio, + fmin=cfg.f0_min, + fmax=cfg.f0_max, + sr=cfg.sample_rate, + win_length=cfg.win_size, + hop_length=cfg.hop_size, + ) + # Set nan to 0 + f0[voiced_flag == False] = 0 + return f0 + + +def get_f0_features_using_parselmouth(audio, cfg, speed=1): + """Using parselmouth to extract the f0 feature. 
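+     Uses Praat's autocorrelation pitch tracker via parselmouth, with the frame
+     shift derived from cfg.hop_size, and also returns the coarse pitch indices
+     computed by f0_to_coarse.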
+ Args: + audio + mel_len + hop_length + fs + f0_min + f0_max + speed(default=1) + Returns: + f0: numpy array of shape (frame_len,) + pitch_coarse: numpy array of shape (frame_len,) + """ + hop_size = int(np.round(cfg.hop_size * speed)) + + # Calculate the time step for pitch extraction + time_step = hop_size / cfg.sample_rate * 1000 + + f0 = ( + parselmouth.Sound(audio, cfg.sample_rate) + .to_pitch_ac( + time_step=time_step / 1000, + voicing_threshold=0.6, + pitch_floor=cfg.f0_min, + pitch_ceiling=cfg.f0_max, + ) + .selected_array["frequency"] + ) + + # Pad the pitch to the mel_len + # pad_size = (int(len(audio) // hop_size) - len(f0) + 1) // 2 + # f0 = np.pad(f0, [[pad_size, mel_len - len(f0) - pad_size]], mode="constant") + + # Get the coarse part + pitch_coarse = f0_to_coarse(f0, cfg.pitch_bin, cfg.f0_min, cfg.f0_max) + return f0, pitch_coarse + + +def get_f0_features_using_dio(audio, cfg): + """Using dio to extract the f0 feature. + Args: + audio + mel_len + fs + hop_length + f0_min + f0_max + Returns: + f0: numpy array of shape (frame_len,) + """ + # Get the raw f0 + _f0, t = pw.dio( + audio.astype("double"), + cfg.sample_rate, + f0_floor=cfg.f0_min, + f0_ceil=cfg.f0_max, + channels_in_octave=2, + frame_period=(1000 * cfg.hop_size / cfg.sample_rate), + ) + # Get the f0 + f0 = pw.stonemask(audio.astype("double"), _f0, t, cfg.sample_rate) + return f0 + + +def get_f0_features_using_harvest(audio, mel_len, fs, hop_length, f0_min, f0_max): + """Using harvest to extract the f0 feature. + Args: + audio + mel_len + fs + hop_length + f0_min + f0_max + Returns: + f0: numpy array of shape (frame_len,) + """ + f0, _ = pw.harvest( + audio.astype("double"), + fs, + f0_floor=f0_min, + f0_ceil=f0_max, + frame_period=(1000 * hop_length / fs), + ) + f0 = f0.astype("float")[:mel_len] + return f0 + + +def get_f0_features_using_crepe( + audio, mel_len, fs, hop_length, hop_length_new, f0_min, f0_max, threshold=0.3 +): + """Using torchcrepe to extract the f0 feature. 
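+     The audio is resampled to 16 kHz before running CREPE, so hop_length_new
+     should be the hop size expressed in 16 kHz samples; unvoiced frames are
+     zeroed via a periodicity threshold and then interpolated over.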
+ Args: + audio + mel_len + fs + hop_length + hop_length_new + f0_min + f0_max + threshold(default=0.3) + Returns: + f0: numpy array of shape (frame_len,) + """ + # Currently, crepe only supports 16khz audio + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + audio_16k = librosa.resample(audio, orig_sr=fs, target_sr=16000) + audio_16k_torch = torch.FloatTensor(audio_16k).unsqueeze(0).to(device) + + # Get the raw pitch + f0, pd = torchcrepe.predict( + audio_16k_torch, + 16000, + hop_length_new, + f0_min, + f0_max, + pad=True, + model="full", + batch_size=1024, + device=device, + return_periodicity=True, + ) + + # Filter, de-silence, set up threshold for unvoiced part + pd = torchcrepe.filter.median(pd, 3) + pd = torchcrepe.threshold.Silence(-60.0)(pd, audio_16k_torch, 16000, hop_length_new) + f0 = torchcrepe.threshold.At(threshold)(f0, pd) + f0 = torchcrepe.filter.mean(f0, 3) + + # Convert unvoiced part to 0hz + f0 = torch.where(torch.isnan(f0), torch.full_like(f0, 0), f0) + + # Interpolate f0 + nzindex = torch.nonzero(f0[0]).squeeze() + f0 = torch.index_select(f0[0], dim=0, index=nzindex).cpu().numpy() + time_org = 0.005 * nzindex.cpu().numpy() + time_frame = np.arange(mel_len) * hop_length / fs + f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1]) + return f0 + + +def get_f0(audio, cfg): + if cfg.pitch_extractor == "dio": + f0 = get_f0_features_using_dio(audio, cfg) + elif cfg.pitch_extractor == "pyin": + f0 = get_f0_features_using_pyin(audio, cfg) + elif cfg.pitch_extractor == "parselmouth": + f0, _ = get_f0_features_using_parselmouth(audio, cfg) + # elif cfg.data.f0_extractor == 'cwt': # todo + + return f0 + + +def get_cents(f0_hz): + """ + F_{cent} = 1200 * log2 (F/440) + + Reference: + APSIPA'17, Perceptual Evaluation of Singing Quality + """ + voiced_f0 = f0_hz[f0_hz != 0] + return 1200 * np.log2(voiced_f0 / 440) + + +def get_pitch_derivatives(f0_hz): + """ + f0_hz: (,T) + """ + f0_cent = get_cents(f0_hz) + return f0_cent[1:] - f0_cent[:-1] + + +def get_pitch_sub_median(f0_hz): + """ + f0_hz: (,T) + """ + f0_cent = get_cents(f0_hz) + return f0_cent - np.median(f0_cent) diff --git a/utils/hparam.py b/utils/hparam.py new file mode 100644 index 0000000000000000000000000000000000000000..c5dd35c6a3158b0aaf8d936dba139a030a48bc62 --- /dev/null +++ b/utils/hparam.py @@ -0,0 +1,659 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is modified from https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/training/python/training/hparam.py pylint: disable=line-too-long +"""Hyperparameter values.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import json +import numbers +import re +import six + +# Define the regular expression for parsing a single clause of the input +# (delimited by commas). A legal clause looks like: +# []? = +# where is either a single token or [] enclosed list of tokens. +# For example: "var[1] = a" or "x = [1,2,3]" +PARAM_RE = re.compile( + r""" + (?P[a-zA-Z][\w\.]*) # variable name: "var" or "x" + (\[\s*(?P\d+)\s*\])? 
# (optional) index: "1" or None + \s*=\s* + ((?P[^,\[]*) # single value: "a" or None + | + \[(?P[^\]]*)\]) # list of values: None or "1,2,3" + ($|,\s*)""", + re.VERBOSE, +) + + +def _parse_fail(name, var_type, value, values): + """Helper function for raising a value error for bad assignment.""" + raise ValueError( + "Could not parse hparam '%s' of type '%s' with value '%s' in %s" + % (name, var_type.__name__, value, values) + ) + + +def _reuse_fail(name, values): + """Helper function for raising a value error for reuse of name.""" + raise ValueError("Multiple assignments to variable '%s' in %s" % (name, values)) + + +def _process_scalar_value(name, parse_fn, var_type, m_dict, values, results_dictionary): + """Update results_dictionary with a scalar value. + + Used to update the results_dictionary to be returned by parse_values when + encountering a clause with a scalar RHS (e.g. "s=5" or "arr[0]=5".) + + Mutates results_dictionary. + + Args: + name: Name of variable in assignment ("s" or "arr"). + parse_fn: Function for parsing the actual value. + var_type: Type of named variable. + m_dict: Dictionary constructed from regex parsing. + m_dict['val']: RHS value (scalar) + m_dict['index']: List index value (or None) + values: Full expression being parsed + results_dictionary: The dictionary being updated for return by the parsing + function. + + Raises: + ValueError: If the name has already been used. + """ + try: + parsed_value = parse_fn(m_dict["val"]) + except ValueError: + _parse_fail(name, var_type, m_dict["val"], values) + + # If no index is provided + if not m_dict["index"]: + if name in results_dictionary: + _reuse_fail(name, values) + results_dictionary[name] = parsed_value + else: + if name in results_dictionary: + # The name has already been used as a scalar, then it + # will be in this dictionary and map to a non-dictionary. + if not isinstance(results_dictionary.get(name), dict): + _reuse_fail(name, values) + else: + results_dictionary[name] = {} + + index = int(m_dict["index"]) + # Make sure the index position hasn't already been assigned a value. + if index in results_dictionary[name]: + _reuse_fail("{}[{}]".format(name, index), values) + results_dictionary[name][index] = parsed_value + + +def _process_list_value(name, parse_fn, var_type, m_dict, values, results_dictionary): + """Update results_dictionary from a list of values. + + Used to update results_dictionary to be returned by parse_values when + encountering a clause with a list RHS (e.g. "arr=[1,2,3]".) + + Mutates results_dictionary. + + Args: + name: Name of variable in assignment ("arr"). + parse_fn: Function for parsing individual values. + var_type: Type of named variable. + m_dict: Dictionary constructed from regex parsing. + m_dict['val']: RHS value (scalar) + values: Full expression being parsed + results_dictionary: The dictionary being updated for return by the parsing + function. + + Raises: + ValueError: If the name has an index or the values cannot be parsed. + """ + if m_dict["index"] is not None: + raise ValueError("Assignment of a list to a list index.") + elements = filter(None, re.split("[ ,]", m_dict["vals"])) + # Make sure the name hasn't already been assigned a value + if name in results_dictionary: + raise _reuse_fail(name, values) + try: + results_dictionary[name] = [parse_fn(e) for e in elements] + except ValueError: + _parse_fail(name, var_type, m_dict["vals"], values) + + +def _cast_to_type_if_compatible(name, param_type, value): + """Cast hparam to the provided type, if compatible. 
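+     For example, a float-typed hparam accepts an int (3 -> 3.0), but a string
+     such as "0.1" is rejected rather than silently parsed.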
+ + Args: + name: Name of the hparam to be cast. + param_type: The type of the hparam. + value: The value to be cast, if compatible. + + Returns: + The result of casting `value` to `param_type`. + + Raises: + ValueError: If the type of `value` is not compatible with param_type. + * If `param_type` is a string type, but `value` is not. + * If `param_type` is a boolean, but `value` is not, or vice versa. + * If `param_type` is an integer type, but `value` is not. + * If `param_type` is a float type, but `value` is not a numeric type. + """ + fail_msg = "Could not cast hparam '%s' of type '%s' from value %r" % ( + name, + param_type, + value, + ) + + # Some callers use None, for which we can't do any casting/checking. :( + if issubclass(param_type, type(None)): + return value + + # Avoid converting a non-string type to a string. + if issubclass(param_type, (six.string_types, six.binary_type)) and not isinstance( + value, (six.string_types, six.binary_type) + ): + raise ValueError(fail_msg) + + # Avoid converting a number or string type to a boolean or vice versa. + if issubclass(param_type, bool) != isinstance(value, bool): + raise ValueError(fail_msg) + + # Avoid converting float to an integer (the reverse is fine). + if issubclass(param_type, numbers.Integral) and not isinstance( + value, numbers.Integral + ): + raise ValueError(fail_msg) + + # Avoid converting a non-numeric type to a numeric type. + if issubclass(param_type, numbers.Number) and not isinstance(value, numbers.Number): + raise ValueError(fail_msg) + + return param_type(value) + + +def parse_values(values, type_map, ignore_unknown=False): + """Parses hyperparameter values from a string into a python map. + + `values` is a string containing comma-separated `name=value` pairs. + For each pair, the value of the hyperparameter named `name` is set to + `value`. + + If a hyperparameter name appears multiple times in `values`, a ValueError + is raised (e.g. 'a=1,a=2', 'a[1]=1,a[1]=2'). + + If a hyperparameter name in both an index assignment and scalar assignment, + a ValueError is raised. (e.g. 'a=[1,2,3],a[0] = 1'). + + The hyperparameter name may contain '.' symbols, which will result in an + attribute name that is only accessible through the getattr and setattr + functions. (And must be first explicit added through add_hparam.) + + WARNING: Use of '.' in your variable names is allowed, but is not well + supported and not recommended. + + The `value` in `name=value` must follows the syntax according to the + type of the parameter: + + * Scalar integer: A Python-parsable integer point value. E.g.: 1, + 100, -12. + * Scalar float: A Python-parsable floating point value. E.g.: 1.0, + -.54e89. + * Boolean: Either true or false. + * Scalar string: A non-empty sequence of characters, excluding comma, + spaces, and square brackets. E.g.: foo, bar_1. + * List: A comma separated list of scalar values of the parameter type + enclosed in square brackets. E.g.: [1,2,3], [1.0,1e-12], [high,low]. + + When index assignment is used, the corresponding type_map key should be the + list name. E.g. for "arr[1]=0" the type_map must have the key "arr" (not + "arr[1]"). + + Args: + values: String. Comma separated list of `name=value` pairs where + 'value' must follow the syntax described above. + type_map: A dictionary mapping hyperparameter names to types. Note every + parameter name in values must be a key in type_map. 
The values must + conform to the types indicated, where a value V is said to conform to a + type T if either V has type T, or V is a list of elements of type T. + Hence, for a multidimensional parameter 'x' taking float values, + 'x=[0.1,0.2]' will parse successfully if type_map['x'] = float. + ignore_unknown: Bool. Whether values that are missing a type in type_map + should be ignored. If set to True, a ValueError will not be raised for + unknown hyperparameter type. + + Returns: + A python map mapping each name to either: + * A scalar value. + * A list of scalar values. + * A dictionary mapping index numbers to scalar values. + (e.g. "x=5,L=[1,2],arr[1]=3" results in {'x':5,'L':[1,2],'arr':{1:3}}") + + Raises: + ValueError: If there is a problem with input. + * If `values` cannot be parsed. + * If a list is assigned to a list index (e.g. 'a[1] = [1,2,3]'). + * If the same rvalue is assigned two different values (e.g. 'a=1,a=2', + 'a[1]=1,a[1]=2', or 'a=1,a=[1]') + """ + results_dictionary = {} + pos = 0 + while pos < len(values): + m = PARAM_RE.match(values, pos) + if not m: + raise ValueError("Malformed hyperparameter value: %s" % values[pos:]) + # Check that there is a comma between parameters and move past it. + pos = m.end() + # Parse the values. + m_dict = m.groupdict() + name = m_dict["name"] + if name not in type_map: + if ignore_unknown: + continue + raise ValueError("Unknown hyperparameter type for %s" % name) + type_ = type_map[name] + + # Set up correct parsing function (depending on whether type_ is a bool) + if type_ == bool: + + def parse_bool(value): + if value in ["true", "True"]: + return True + elif value in ["false", "False"]: + return False + else: + try: + return bool(int(value)) + except ValueError: + _parse_fail(name, type_, value, values) + + parse = parse_bool + else: + parse = type_ + + # If a singe value is provided + if m_dict["val"] is not None: + _process_scalar_value( + name, parse, type_, m_dict, values, results_dictionary + ) + + # If the assigned value is a list: + elif m_dict["vals"] is not None: + _process_list_value(name, parse, type_, m_dict, values, results_dictionary) + + else: # Not assigned a list or value + _parse_fail(name, type_, "", values) + + return results_dictionary + + +class HParams(object): + """Class to hold a set of hyperparameters as name-value pairs. + + A `HParams` object holds hyperparameters used to build and train a model, + such as the number of hidden units in a neural net layer or the learning rate + to use when training. + + You first create a `HParams` object by specifying the names and values of the + hyperparameters. + + To make them easily accessible the parameter names are added as direct + attributes of the class. A typical usage is as follows: + + ```python + # Create a HParams object specifying names and values of the model + # hyperparameters: + hparams = HParams(learning_rate=0.1, num_hidden_units=100) + + # The hyperparameter are available as attributes of the HParams object: + hparams.learning_rate ==> 0.1 + hparams.num_hidden_units ==> 100 + ``` + + Hyperparameters have type, which is inferred from the type of their value + passed at construction type. The currently supported types are: integer, + float, boolean, string, and list of integer, float, boolean, or string. + + You can override hyperparameter values by calling the + [`parse()`](#HParams.parse) method, passing a string of comma separated + `name=value` pairs. 
This is intended to make it possible to override + any hyperparameter values from a single command-line flag to which + the user passes 'hyper-param=value' pairs. It avoids having to define + one flag for each hyperparameter. + + The syntax expected for each value depends on the type of the parameter. + See `parse()` for a description of the syntax. + + Example: + + ```python + # Define a command line flag to pass name=value pairs. + # For example using argparse: + import argparse + parser = argparse.ArgumentParser(description='Train my model.') + parser.add_argument('--hparams', type=str, + help='Comma separated list of "name=value" pairs.') + args = parser.parse_args() + ... + def my_program(): + # Create a HParams object specifying the names and values of the + # model hyperparameters: + hparams = tf.HParams(learning_rate=0.1, num_hidden_units=100, + activations=['relu', 'tanh']) + + # Override hyperparameters values by parsing the command line + hparams.parse(args.hparams) + + # If the user passed `--hparams=learning_rate=0.3` on the command line + # then 'hparams' has the following attributes: + hparams.learning_rate ==> 0.3 + hparams.num_hidden_units ==> 100 + hparams.activations ==> ['relu', 'tanh'] + + # If the hyperparameters are in json format use parse_json: + hparams.parse_json('{"learning_rate": 0.3, "activations": "relu"}') + ``` + """ + + _HAS_DYNAMIC_ATTRIBUTES = True # Required for pytype checks. + + def __init__(self, model_structure=None, **kwargs): + """Create an instance of `HParams` from keyword arguments. + + The keyword arguments specify name-values pairs for the hyperparameters. + The parameter types are inferred from the type of the values passed. + + The parameter names are added as attributes of `HParams` object, so they + can be accessed directly with the dot notation `hparams._name_`. + + Example: + + ```python + # Define 3 hyperparameters: 'learning_rate' is a float parameter, + # 'num_hidden_units' an integer parameter, and 'activation' a string + # parameter. + hparams = tf.HParams( + learning_rate=0.1, num_hidden_units=100, activation='relu') + + hparams.activation ==> 'relu' + ``` + + Note that a few names are reserved and cannot be used as hyperparameter + names. If you use one of the reserved name the constructor raises a + `ValueError`. + + Args: + model_structure: An instance of ModelStructure, defining the feature + crosses to be used in the Trial. + **kwargs: Key-value pairs where the key is the hyperparameter name and + the value is the value for the parameter. + + Raises: + ValueError: If both `hparam_def` and initialization values are provided, + or if one of the arguments is invalid. + + """ + # Register the hyperparameters and their type in _hparam_types. + # This simplifies the implementation of parse(). + # _hparam_types maps the parameter name to a tuple (type, bool). + # The type value is the type of the parameter for scalar hyperparameters, + # or the type of the list elements for multidimensional hyperparameters. + # The bool value is True if the value is a list, False otherwise. + self._hparam_types = {} + self._model_structure = model_structure + for name, value in six.iteritems(kwargs): + self.add_hparam(name, value) + + def add_hparam(self, name, value): + """Adds {name, value} pair to hyperparameters. + + Args: + name: Name of the hyperparameter. + value: Value of the hyperparameter. Can be one of the following types: + int, float, string, int list, float list, or string list. 
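+         Example (illustrative sketch):
+             hparams = HParams(learning_rate=0.1)
+             hparams.add_hparam("activations", ["relu", "tanh"])
+             hparams.activations  # ==> ['relu', 'tanh']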
+ + Raises: + ValueError: if one of the arguments is invalid. + """ + # Keys in kwargs are unique, but 'name' could the name of a pre-existing + # attribute of this object. In that case we refuse to use it as a + # hyperparameter name. + if getattr(self, name, None) is not None: + raise ValueError("Hyperparameter name is reserved: %s" % name) + if isinstance(value, (list, tuple)): + if not value: + raise ValueError( + "Multi-valued hyperparameters cannot be empty: %s" % name + ) + self._hparam_types[name] = (type(value[0]), True) + else: + self._hparam_types[name] = (type(value), False) + setattr(self, name, value) + + def set_hparam(self, name, value): + """Set the value of an existing hyperparameter. + + This function verifies that the type of the value matches the type of the + existing hyperparameter. + + Args: + name: Name of the hyperparameter. + value: New value of the hyperparameter. + + Raises: + KeyError: If the hyperparameter doesn't exist. + ValueError: If there is a type mismatch. + """ + param_type, is_list = self._hparam_types[name] + if isinstance(value, list): + if not is_list: + raise ValueError( + "Must not pass a list for single-valued parameter: %s" % name + ) + setattr( + self, + name, + [_cast_to_type_if_compatible(name, param_type, v) for v in value], + ) + else: + if is_list: + raise ValueError( + "Must pass a list for multi-valued parameter: %s." % name + ) + setattr(self, name, _cast_to_type_if_compatible(name, param_type, value)) + + def del_hparam(self, name): + """Removes the hyperparameter with key 'name'. + + Does nothing if it isn't present. + + Args: + name: Name of the hyperparameter. + """ + if hasattr(self, name): + delattr(self, name) + del self._hparam_types[name] + + def parse(self, values): + """Override existing hyperparameter values, parsing new values from a string. + + See parse_values for more detail on the allowed format for values. + + Args: + values: String. Comma separated list of `name=value` pairs where 'value' + must follow the syntax described above. + + Returns: + The `HParams` instance. + + Raises: + ValueError: If `values` cannot be parsed or a hyperparameter in `values` + doesn't exist. + """ + type_map = {} + for name, t in self._hparam_types.items(): + param_type, _ = t + type_map[name] = param_type + + values_map = parse_values(values, type_map) + return self.override_from_dict(values_map) + + def override_from_dict(self, values_dict): + """Override existing hyperparameter values, parsing new values from a dictionary. + + Args: + values_dict: Dictionary of name:value pairs. + + Returns: + The `HParams` instance. + + Raises: + KeyError: If a hyperparameter in `values_dict` doesn't exist. + ValueError: If `values_dict` cannot be parsed. + """ + for name, value in values_dict.items(): + self.set_hparam(name, value) + return self + + def set_model_structure(self, model_structure): + self._model_structure = model_structure + + def get_model_structure(self): + return self._model_structure + + def to_json(self, indent=None, separators=None, sort_keys=False): + """Serializes the hyperparameters into JSON. + + Args: + indent: If a non-negative integer, JSON array elements and object members + will be pretty-printed with that indent level. An indent level of 0, or + negative, will only insert newlines. `None` (the default) selects the + most compact representation. + separators: Optional `(item_separator, key_separator)` tuple. Default is + `(', ', ': ')`. + sort_keys: If `True`, the output dictionaries will be sorted by key. 
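+         Example (illustrative sketch):
+             HParams(learning_rate=0.1, activations=["relu", "tanh"]).to_json()
+             # ==> '{"learning_rate": 0.1, "activations": ["relu", "tanh"]}'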
+ + Returns: + A JSON string. + """ + + def remove_callables(x): + """Omit callable elements from input with arbitrary nesting.""" + if isinstance(x, dict): + return { + k: remove_callables(v) + for k, v in six.iteritems(x) + if not callable(v) + } + elif isinstance(x, list): + return [remove_callables(i) for i in x if not callable(i)] + return x + + return json.dumps( + remove_callables(self.values()), + indent=indent, + separators=separators, + sort_keys=sort_keys, + ) + + def parse_json(self, values_json): + """Override existing hyperparameter values, parsing new values from a json object. + + Args: + values_json: String containing a json object of name:value pairs. + + Returns: + The `HParams` instance. + + Raises: + KeyError: If a hyperparameter in `values_json` doesn't exist. + ValueError: If `values_json` cannot be parsed. + """ + values_map = json.loads(values_json) + return self.override_from_dict(values_map) + + def values(self): + """Return the hyperparameter values as a Python dictionary. + + Returns: + A dictionary with hyperparameter names as keys. The values are the + hyperparameter values. + """ + return {n: getattr(self, n) for n in self._hparam_types.keys()} + + def get(self, key, default=None): + """Returns the value of `key` if it exists, else `default`.""" + if key in self._hparam_types: + # Ensure that default is compatible with the parameter type. + if default is not None: + param_type, is_param_list = self._hparam_types[key] + type_str = "list<%s>" % param_type if is_param_list else str(param_type) + fail_msg = ( + "Hparam '%s' of type '%s' is incompatible with " + "default=%s" % (key, type_str, default) + ) + + is_default_list = isinstance(default, list) + if is_param_list != is_default_list: + raise ValueError(fail_msg) + + try: + if is_default_list: + for value in default: + _cast_to_type_if_compatible(key, param_type, value) + else: + _cast_to_type_if_compatible(key, param_type, default) + except ValueError as e: + raise ValueError("%s. %s" % (fail_msg, e)) + + return getattr(self, key) + + return default + + def __contains__(self, key): + return key in self._hparam_types + + def __str__(self): + return str(sorted(self.values().items())) + + def __repr__(self): + return "%s(%s)" % (type(self).__name__, self.__str__()) + + @staticmethod + def _get_kind_name(param_type, is_list): + """Returns the field name given parameter type and is_list. + + Args: + param_type: Data type of the hparam. + is_list: Whether this is a list. + + Returns: + A string representation of the field name. + + Raises: + ValueError: If parameter type is not recognized. + """ + if issubclass(param_type, bool): + # This check must happen before issubclass(param_type, six.integer_types), + # since Python considers bool to be a subclass of int. + typename = "bool" + elif issubclass(param_type, six.integer_types): + # Setting 'int' and 'long' types to be 'int64' to ensure the type is + # compatible with both Python2 and Python3. + typename = "int64" + elif issubclass(param_type, (six.string_types, six.binary_type)): + # Setting 'string' and 'bytes' types to be 'bytes' to ensure the type is + # compatible with both Python2 and Python3. 
+ typename = "bytes" + elif issubclass(param_type, float): + typename = "float" + else: + raise ValueError("Unsupported parameter type: %s" % str(param_type)) + + suffix = "list" if is_list else "value" + return "_".join([typename, suffix]) diff --git a/utils/hubert.py b/utils/hubert.py new file mode 100644 index 0000000000000000000000000000000000000000..84b509fb9fde8485cfb504a675e5d3b7d27622ff --- /dev/null +++ b/utils/hubert.py @@ -0,0 +1,155 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is modified from https://github.com/svc-develop-team/so-vits-svc/blob/4.0/preprocess_hubert_f0.py + +import os +import librosa +import torch +import numpy as np +from fairseq import checkpoint_utils +from tqdm import tqdm +import torch + + +def load_hubert_model(hps): + # Load model + ckpt_path = hps.hubert_file + print("Load Hubert Model...") + + models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task( + [ckpt_path], + suffix="", + ) + model = models[0] + model.eval() + + if torch.cuda.is_available(): + model = model.cuda() + + return model + + +def get_hubert_content(hmodel, wav_16k_tensor): + feats = wav_16k_tensor + if feats.dim() == 2: # double channels + feats = feats.mean(-1) + assert feats.dim() == 1, feats.dim() + feats = feats.view(1, -1) + padding_mask = torch.BoolTensor(feats.shape).fill_(False) + inputs = { + "source": feats.to(wav_16k_tensor.device), + "padding_mask": padding_mask.to(wav_16k_tensor.device), + "output_layer": 9, # layer 9 + } + with torch.no_grad(): + logits = hmodel.extract_features(**inputs) + feats = hmodel.final_proj(logits[0]).squeeze(0) + + return feats + + +def content_vector_encoder(model, audio_path, default_sampling_rate=16000): + """ + # content vector default sr: 16000 + """ + + wav16k, sr = librosa.load(audio_path, sr=default_sampling_rate) + device = next(model.parameters()).device + wav16k = torch.from_numpy(wav16k).to(device) + + # (1, 256, frame_len) + content_feature = get_hubert_content(model, wav_16k_tensor=wav16k) + + return content_feature.cpu().detach().numpy() + + +def repeat_expand_2d(content, target_len): + """ + content : [hubert_dim(256), src_len] + target: [hubert_dim(256), target_len] + """ + src_len = content.shape[-1] + target = torch.zeros([content.shape[0], target_len], dtype=torch.float).to( + content.device + ) + temp = torch.arange(src_len + 1) * target_len / src_len + current_pos = 0 + for i in range(target_len): + if i < temp[current_pos + 1]: + target[:, i] = content[:, current_pos] + else: + current_pos += 1 + target[:, i] = content[:, current_pos] + + return target + + +def get_mapped_features(raw_content_features, mapping_features): + """ + Content Vector: frameshift = 20ms, hop_size = 480 in 24k + + Now it's only used for mapping to bigvgan's mels (sr = 24k, hop_size = 256, frameshift ~= 10.7 ms) + """ + source_hop = 480 + target_hop = 256 + + factor = np.gcd(source_hop, target_hop) + source_hop //= factor + target_hop //= factor + print( + "Mapping source's {} frames => target's {} frames".format( + target_hop, source_hop + ) + ) + + results = [] + for index, mapping_feat in enumerate(tqdm(mapping_features)): + # mappping_feat: (mels_frame_len, n_mels) + target_len = len(mapping_feat) + + # (source_len, 256) + raw_feats = raw_content_features[index][0].cpu().numpy().T + source_len, width = raw_feats.shape + + # const ~= target_len * target_hop + const = source_len * source_hop // 
target_hop * target_hop + + # (source_len * source_hop, dim) + up_sampling_feats = np.repeat(raw_feats, source_hop, axis=0) + # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) + down_sampling_feats = np.average( + up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 + ) + + err = abs(target_len - len(down_sampling_feats)) + if err > 3: + print("index:", index) + print("mels:", mapping_feat.shape) + print("raw content vector:", raw_feats.shape) + print("up_sampling:", up_sampling_feats.shape) + print("down_sampling_feats:", down_sampling_feats.shape) + exit() + if len(down_sampling_feats) < target_len: + # (1, dim) -> (err, dim) + end = down_sampling_feats[-1][None, :].repeat(err, axis=0) + down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0) + + # (target_len, dim) + feats = down_sampling_feats[:target_len] + results.append(feats) + + return results + + +def extract_hubert_features_of_dataset(datasets, model, out_dir): + for utt in tqdm(datasets): + uid = utt["Uid"] + audio_path = utt["Path"] + + content_vector_feature = content_vector_encoder(model, audio_path) # (T, 256) + + save_path = os.path.join(out_dir, uid + ".npy") + np.save(save_path, content_vector_feature) diff --git a/utils/io.py b/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..e6e31d2fc6ebf3ca888f58dbce96bb3d6c7c2905 --- /dev/null +++ b/utils/io.py @@ -0,0 +1,153 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import os +import numpy as np +import torch +import torchaudio + + +def save_feature(process_dir, feature_dir, item, feature, overrides=True): + """Save features to path + + Args: + process_dir (str): directory to store features + feature_dir (_type_): directory to store one type of features (mel, energy, ...) + item (str): uid + feature (tensor): feature tensor + overrides (bool, optional): whether to override existing files. Defaults to True. 
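+     Example (illustrative sketch; the directory and uid are made-up):
+         mel = np.random.randn(80, 403)
+         save_feature("data/processed/vocalist_l1", "mels", "utt_0001", mel)
+         # -> data/processed/vocalist_l1/mels/utt_0001.npy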
+ """ + process_dir = os.path.join(process_dir, feature_dir) + os.makedirs(process_dir, exist_ok=True) + out_path = os.path.join(process_dir, item + ".npy") + + if os.path.exists(out_path): + if overrides: + np.save(out_path, feature) + else: + np.save(out_path, feature) + + +def save_txt(process_dir, feature_dir, item, feature, overrides=True): + process_dir = os.path.join(process_dir, feature_dir) + os.makedirs(process_dir, exist_ok=True) + out_path = os.path.join(process_dir, item + ".txt") + + if os.path.exists(out_path): + if overrides: + f = open(out_path, "w") + f.writelines(feature) + f.close() + else: + f = open(out_path, "w") + f.writelines(feature) + f.close() + + +def save_audio(path, waveform, fs, add_silence=False, turn_up=False, volume_peak=0.9): + if turn_up: + # continue to turn up to volume_peak + ratio = volume_peak / max(waveform.max(), abs(waveform.min())) + waveform = waveform * ratio + + if add_silence: + silence_len = fs // 20 + silence = np.zeros((silence_len,), dtype=waveform.dtype) + result = np.concatenate([silence, waveform, silence]) + waveform = result + + waveform = torch.as_tensor(waveform, dtype=torch.float32, device="cpu") + if len(waveform.size()) == 1: + waveform = waveform[None, :] + elif waveform.size(0) != 1: + # Stereo to mono + waveform = torch.mean(waveform, dim=0, keepdim=True) + torchaudio.save(path, waveform, fs, encoding="PCM_S", bits_per_sample=16) + + +async def async_load_audio(path, sample_rate: int = 24000): + r""" + Args: + path: The source loading path. + sample_rate: The target sample rate, will automatically resample if necessary. + + Returns: + waveform: The waveform object. Should be [1 x sequence_len]. + """ + + async def use_torchaudio_load(path): + return torchaudio.load(path) + + waveform, sr = await use_torchaudio_load(path) + waveform = torch.mean(waveform, dim=0, keepdim=True) + + if sr != sample_rate: + waveform = torchaudio.functional.resample(waveform, sr, sample_rate) + + if torch.any(torch.isnan(waveform) or torch.isinf(waveform)): + raise ValueError("NaN or Inf found in waveform.") + return waveform + + +async def async_save_audio( + path, + waveform, + sample_rate: int = 24000, + add_silence: bool = False, + volume_peak: float = 0.9, +): + r""" + Args: + path: The target saving path. + waveform: The waveform object. Should be [n_channel x sequence_len]. + sample_rate: Sample rate. + add_silence: If ``true``, concat 0.05s silence to beginning and end. + volume_peak: Turn up volume for larger number, vice versa. 
+ """ + + async def use_torchaudio_save(path, waveform, sample_rate): + torchaudio.save( + path, waveform, sample_rate, encoding="PCM_S", bits_per_sample=16 + ) + + waveform = torch.as_tensor(waveform, device="cpu", dtype=torch.float32) + shape = waveform.size()[:-1] + + ratio = abs(volume_peak) / max(waveform.max(), abs(waveform.min())) + waveform = waveform * ratio + + if add_silence: + silence_len = sample_rate // 20 + silence = torch.zeros((*shape, silence_len), dtype=waveform.type()) + waveform = torch.concatenate((silence, waveform, silence), dim=-1) + + if waveform.dim() == 1: + waveform = waveform[None] + + await use_torchaudio_save(path, waveform, sample_rate) + + +def load_mel_extrema(cfg, dataset_name, split): + dataset_dir = os.path.join( + cfg.OUTPUT_PATH, + "preprocess/{}_version".format(cfg.data.process_version), + dataset_name, + ) + + min_file = os.path.join( + dataset_dir, + "mel_min_max", + split.split("_")[-1], + "mel_min.npy", + ) + max_file = os.path.join( + dataset_dir, + "mel_min_max", + split.split("_")[-1], + "mel_max.npy", + ) + mel_min = np.load(min_file) + mel_max = np.load(max_file) + return mel_min, mel_max diff --git a/utils/io_optim.py b/utils/io_optim.py new file mode 100644 index 0000000000000000000000000000000000000000..942619d625d4c8e2d00ae1255421ceaf6ab39986 --- /dev/null +++ b/utils/io_optim.py @@ -0,0 +1,123 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torchaudio +import json +import os +import numpy as np +import librosa +from torch.nn.utils.rnn import pad_sequence +from modules import whisper_extractor as whisper + + +class TorchaudioDataset(torch.utils.data.Dataset): + def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None): + """ + Args: + cfg: config + dataset: dataset name + + """ + assert isinstance(dataset, str) + + self.sr = sr + self.cfg = cfg + + if metadata is None: + self.train_metadata_path = os.path.join( + cfg.preprocess.processed_dir, dataset, cfg.preprocess.train_file + ) + self.valid_metadata_path = os.path.join( + cfg.preprocess.processed_dir, dataset, cfg.preprocess.valid_file + ) + self.metadata = self.get_metadata() + else: + self.metadata = metadata + + if accelerator is not None: + self.device = accelerator.device + elif torch.cuda.is_available(): + self.device = torch.device("cuda") + else: + self.device = torch.device("cpu") + + def get_metadata(self): + metadata = [] + with open(self.train_metadata_path, "r", encoding="utf-8") as t: + metadata.extend(json.load(t)) + with open(self.valid_metadata_path, "r", encoding="utf-8") as v: + metadata.extend(json.load(v)) + return metadata + + def __len__(self): + return len(self.metadata) + + def __getitem__(self, index): + utt_info = self.metadata[index] + wav_path = utt_info["Path"] + + wav, sr = torchaudio.load(wav_path) + + # resample + if sr != self.sr: + wav = torchaudio.functional.resample(wav, sr, self.sr) + # downmixing + if wav.shape[0] > 1: + wav = torch.mean(wav, dim=0, keepdim=True) + assert wav.shape[0] == 1 + wav = wav.squeeze(0) + # record the length of wav without padding + length = wav.shape[0] + # wav: (T) + return utt_info, wav, length + + +class LibrosaDataset(TorchaudioDataset): + def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None): + super().__init__(cfg, dataset, sr, accelerator, metadata) + + def __getitem__(self, index): + utt_info = self.metadata[index] + wav_path = utt_info["Path"] + 
+ wav, _ = librosa.load(wav_path, sr=self.sr) + # wav: (T) + wav = torch.from_numpy(wav) + + # record the length of wav without padding + length = wav.shape[0] + return utt_info, wav, length + + +class FFmpegDataset(TorchaudioDataset): + def __init__(self, cfg, dataset, sr, accelerator=None, metadata=None): + super().__init__(cfg, dataset, sr, accelerator, metadata) + + def __getitem__(self, index): + utt_info = self.metadata[index] + wav_path = utt_info["Path"] + + # wav: (T,) + wav = whisper.load_audio(wav_path) # sr = 16000 + # convert to torch tensor + wav = torch.from_numpy(wav) + # record the length of wav without padding + length = wav.shape[0] + + return utt_info, wav, length + + +def collate_batch(batch_list): + """ + Args: + batch_list: list of (metadata, wav, length) + """ + metadata = [item[0] for item in batch_list] + # wavs: (B, T) + wavs = pad_sequence([item[1] for item in batch_list], batch_first=True) + lens = [item[2] for item in batch_list] + + return metadata, wavs, lens diff --git a/utils/mel.py b/utils/mel.py new file mode 100644 index 0000000000000000000000000000000000000000..d32d38226dfe6c0162527b3136a392b9168c7f06 --- /dev/null +++ b/utils/mel.py @@ -0,0 +1,283 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from librosa.filters import mel as librosa_mel_fn + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def extract_linear_features(y, cfg, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global hann_window + hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + # complex tensor as default, then use view_as_real for future pytorch compatibility + spec = torch.stft( + y, + cfg.n_fft, + hop_length=cfg.hop_size, + win_length=cfg.win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + spec = torch.squeeze(spec, 0) + return spec + + +def mel_spectrogram_torch(y, cfg, center=False): + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + if cfg.fmax not in mel_basis: + mel = librosa_mel_fn( + sr=cfg.sample_rate, + n_fft=cfg.n_fft, + n_mels=cfg.n_mel, + fmin=cfg.fmin, + fmax=cfg.fmax, + ) + mel_basis[str(cfg.fmax) + "_" + str(y.device)] = ( + torch.from_numpy(mel).float().to(y.device) + ) + hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + spec = torch.stft( + y, + cfg.n_fft, + hop_length=cfg.hop_size, + win_length=cfg.win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + + spec = 
torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6) + + spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec + + +mel_basis = {} +hann_window = {} + + +def extract_mel_features( + y, + cfg, + center=False + # n_fft, n_mel, sampling_rate, hop_size, win_size, fmin, fmax, center=False +): + """Extract mel features + + Args: + y (tensor): audio data in tensor + cfg (dict): configuration in cfg.preprocess + center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False. + + Returns: + tensor: a tensor containing the mel feature calculated based on STFT result + """ + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + if cfg.fmax not in mel_basis: + mel = librosa_mel_fn( + sr=cfg.sample_rate, + n_fft=cfg.n_fft, + n_mels=cfg.n_mel, + fmin=cfg.fmin, + fmax=cfg.fmax, + ) + mel_basis[str(cfg.fmax) + "_" + str(y.device)] = ( + torch.from_numpy(mel).float().to(y.device) + ) + hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + # complex tensor as default, then use view_as_real for future pytorch compatibility + spec = torch.stft( + y, + cfg.n_fft, + hop_length=cfg.hop_size, + win_length=cfg.win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec.squeeze(0) + + +def extract_mel_features_tts( + y, + cfg, + center=False, + taco=False, + _stft=None, +): + """Extract mel features + + Args: + y (tensor): audio data in tensor + cfg (dict): configuration in cfg.preprocess + center (bool, optional): In STFT, whether t-th frame is centered at time t*hop_length. Defaults to False. 
+ taco: use tacotron mel + + Returns: + tensor: a tensor containing the mel feature calculated based on STFT result + """ + if not taco: + if torch.min(y) < -1.0: + print("min value is ", torch.min(y)) + if torch.max(y) > 1.0: + print("max value is ", torch.max(y)) + + global mel_basis, hann_window + if cfg.fmax not in mel_basis: + mel = librosa_mel_fn( + sr=cfg.sample_rate, + n_fft=cfg.n_fft, + n_mels=cfg.n_mel, + fmin=cfg.fmin, + fmax=cfg.fmax, + ) + mel_basis[str(cfg.fmax) + "_" + str(y.device)] = ( + torch.from_numpy(mel).float().to(y.device) + ) + hann_window[str(y.device)] = torch.hann_window(cfg.win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + # complex tensor as default, then use view_as_real for future pytorch compatibility + spec = torch.stft( + y, + cfg.n_fft, + hop_length=cfg.hop_size, + win_length=cfg.win_size, + window=hann_window[str(y.device)], + center=center, + pad_mode="reflect", + normalized=False, + onesided=True, + return_complex=True, + ) + spec = torch.view_as_real(spec) + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + + spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + spec = spec.squeeze(0) + else: + audio = torch.clip(y, -1, 1) + audio = torch.autograd.Variable(audio, requires_grad=False) + spec, energy = _stft.mel_spectrogram(audio) + spec = torch.squeeze(spec, 0) + + spec = torch.matmul(mel_basis[str(cfg.fmax) + "_" + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + + return spec.squeeze(0) + + +def amplitude_phase_spectrum(y, cfg): + hann_window = torch.hann_window(cfg.win_size).to(y.device) + + y = torch.nn.functional.pad( + y.unsqueeze(1), + (int((cfg.n_fft - cfg.hop_size) / 2), int((cfg.n_fft - cfg.hop_size) / 2)), + mode="reflect", + ) + y = y.squeeze(1) + + stft_spec = torch.stft( + y, + cfg.n_fft, + hop_length=cfg.hop_size, + win_length=cfg.win_size, + window=hann_window, + center=False, + return_complex=True, + ) + + stft_spec = torch.view_as_real(stft_spec) + if stft_spec.size()[0] == 1: + stft_spec = stft_spec.squeeze(0) + + if len(list(stft_spec.size())) == 4: + rea = stft_spec[:, :, :, 0] # [batch_size, n_fft//2+1, frames] + imag = stft_spec[:, :, :, 1] # [batch_size, n_fft//2+1, frames] + else: + rea = stft_spec[:, :, 0] # [n_fft//2+1, frames] + imag = stft_spec[:, :, 1] # [n_fft//2+1, frames] + + log_amplitude = torch.log( + torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5 + ) # [n_fft//2+1, frames] + phase = torch.atan2(imag, rea) # [n_fft//2+1, frames] + + return log_amplitude, phase, rea, imag diff --git a/utils/mert.py b/utils/mert.py new file mode 100644 index 0000000000000000000000000000000000000000..4181429feb36f5013bafafff9505b0b9571485b4 --- /dev/null +++ b/utils/mert.py @@ -0,0 +1,139 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +# This code is modified from https://huggingface.co/m-a-p/MERT-v1-330M + +import torch +from tqdm import tqdm +import numpy as np + +from transformers import Wav2Vec2FeatureExtractor +from transformers import AutoModel +import torchaudio +import torchaudio.transforms as T +from sklearn.preprocessing import StandardScaler + + +def mert_encoder(model, processor, audio_path, hps): + """ + # mert default sr: 24000 + """ + with torch.no_grad(): + resample_rate = processor.sampling_rate + device = next(model.parameters()).device + + input_audio, sampling_rate = torchaudio.load(audio_path) + input_audio = input_audio.squeeze() + + if sampling_rate != resample_rate: + resampler = T.Resample(sampling_rate, resample_rate) + input_audio = resampler(input_audio) + + inputs = processor( + input_audio, sampling_rate=resample_rate, return_tensors="pt" + ).to( + device + ) # {input_values: tensor, attention_mask: tensor} + + outputs = model(**inputs, output_hidden_states=True) # list: len is 25 + + # [25 layer, Time steps, 1024 feature_dim] + # all_layer_hidden_states = torch.stack(outputs.hidden_states).squeeze() + # mert_features.append(all_layer_hidden_states) + + feature = outputs.hidden_states[ + hps.mert_feature_layer + ].squeeze() # [1, frame len, 1024] -> [frame len, 1024] + + return feature.cpu().detach().numpy() + + +def mert_features_normalization(raw_mert_features): + normalized_mert_features = list() + + mert_features = np.array(raw_mert_features) + scaler = StandardScaler().fit(mert_features) + for raw_mert_feature in raw_mert_feature: + normalized_mert_feature = scaler.transform(raw_mert_feature) + normalized_mert_features.append(normalized_mert_feature) + return normalized_mert_features + + +def get_mapped_mert_features(raw_mert_features, mapping_features, fast_mapping=True): + source_hop = 320 + target_hop = 256 + + factor = np.gcd(source_hop, target_hop) + source_hop //= factor + target_hop //= factor + print( + "Mapping source's {} frames => target's {} frames".format( + target_hop, source_hop + ) + ) + + mert_features = [] + for index, mapping_feat in enumerate(tqdm(mapping_features)): + # mapping_feat: (mels_frame_len, n_mels) + target_len = mapping_feat.shape[0] + + # (frame_len, 1024) + raw_feats = raw_mert_features[index].cpu().numpy() + source_len, width = raw_feats.shape + + # const ~= target_len * target_hop + const = source_len * source_hop // target_hop * target_hop + + # (source_len * source_hop, dim) + up_sampling_feats = np.repeat(raw_feats, source_hop, axis=0) + # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) + down_sampling_feats = np.average( + up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 + ) + + err = abs(target_len - len(down_sampling_feats)) + if err > 3: + print("index:", index) + print("mels:", mapping_feat.shape) + print("raw mert vector:", raw_feats.shape) + print("up_sampling:", up_sampling_feats.shape) + print("const:", const) + print("down_sampling_feats:", down_sampling_feats.shape) + exit() + if len(down_sampling_feats) < target_len: + # (1, dim) -> (err, dim) + end = down_sampling_feats[-1][None, :].repeat(err, axis=0) + down_sampling_feats = np.concatenate([down_sampling_feats, end], axis=0) + + # (target_len, dim) + feats = down_sampling_feats[:target_len] + mert_features.append(feats) + + return mert_features + + +def load_mert_model(hps): + print("Loading MERT Model: ", hps.mert_model) + + # Load model + model_name = hps.mert_model + model = AutoModel.from_pretrained(model_name, 
trust_remote_code=True) + + if torch.cuda.is_available(): + model = model.cuda() + + # model = model.eval() + + preprocessor = Wav2Vec2FeatureExtractor.from_pretrained( + model_name, trust_remote_code=True + ) + return model, preprocessor + + +# loading the corresponding preprocessor config +# def load_preprocessor (model_name="m-a-p/MERT-v1-330M"): +# print('load_preprocessor...') +# preprocessor = Wav2Vec2FeatureExtractor.from_pretrained(model_name,trust_remote_code=True) +# return preprocessor diff --git a/utils/model_summary.py b/utils/model_summary.py new file mode 100644 index 0000000000000000000000000000000000000000..ec72b0d17869dc0886a9eb665efedb70bb307bbf --- /dev/null +++ b/utils/model_summary.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import humanfriendly +import numpy as np +import torch + + +def get_human_readable_count(number: int) -> str: + """Return human_readable_count + + Originated from: + https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/core/memory.py + + Abbreviates an integer number with K, M, B, T for thousands, millions, + billions and trillions, respectively. + Examples: + >>> get_human_readable_count(123) + '123 ' + >>> get_human_readable_count(1234) # (one thousand) + '1 K' + >>> get_human_readable_count(2e6) # (two million) + '2 M' + >>> get_human_readable_count(3e9) # (three billion) + '3 B' + >>> get_human_readable_count(4e12) # (four trillion) + '4 T' + >>> get_human_readable_count(5e15) # (more than trillion) + '5,000 T' + Args: + number: a positive integer number + Return: + A string formatted according to the pattern described above. + """ + assert number >= 0 + labels = [" ", "K", "M", "B", "T"] + num_digits = int(np.floor(np.log10(number)) + 1 if number > 0 else 1) + num_groups = int(np.ceil(num_digits / 3)) + num_groups = min(num_groups, len(labels)) + shift = -3 * (num_groups - 1) + number = number * (10**shift) + index = num_groups - 1 + return f"{number:.2f} {labels[index]}" + + +def to_bytes(dtype) -> int: + return int(str(dtype)[-2:]) // 8 + + +def model_summary(model: torch.nn.Module) -> str: + message = "Model structure:\n" + message += str(model) + tot_params = sum(p.numel() for p in model.parameters()) + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) + percent_trainable = "{:.1f}".format(num_params * 100.0 / tot_params) + tot_params = get_human_readable_count(tot_params) + num_params = get_human_readable_count(num_params) + message += "\n\nModel summary:\n" + message += f" Class Name: {model.__class__.__name__}\n" + message += f" Total Number of model parameters: {tot_params}\n" + message += ( + f" Number of trainable parameters: {num_params} ({percent_trainable}%)\n" + ) + num_bytes = humanfriendly.format_size( + sum( + p.numel() * to_bytes(p.dtype) for p in model.parameters() if p.requires_grad + ) + ) + message += f" Size: {num_bytes}\n" + dtype = next(iter(model.parameters())).dtype + message += f" Type: {dtype}" + return message diff --git a/utils/prompt_preparer.py b/utils/prompt_preparer.py new file mode 100644 index 0000000000000000000000000000000000000000..dba833ee72cd661bbc178ffc4735e05a4e4949c9 --- /dev/null +++ b/utils/prompt_preparer.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
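+
+# Illustrative sketch (assumption, not part of the original file): PromptPreparer below
+# is written as a mixin and reads several attributes it does not define itself
+# (`prefix_mode`, `num_quantizers`, `audio_token_num`, `rng`, `nar_audio_embeddings`).
+# The hypothetical host below only shows the expected wiring; the attribute values and
+# sizes are assumptions, not taken from the original code.
+def _example_prompt_preparer_host(num_quantizers=8, audio_token_num=1024, dim=512):
+    """Hypothetical example: build a minimal module that can use PromptPreparer."""
+    import random
+
+    class _Host(PromptPreparer, torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.prefix_mode = 1
+            self.num_quantizers = num_quantizers
+            self.audio_token_num = audio_token_num
+            self.rng = random.Random(0)
+            # One embedding table per quantizer stage; the +1 leaves room for the
+            # placeholder id written by _handle_prefix_mode_2_4().
+            self.nar_audio_embeddings = torch.nn.ModuleList(
+                [torch.nn.Embedding(audio_token_num + 1, dim) for _ in range(num_quantizers)]
+            )
+
+    return _Host()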
+ +import torch + +class PromptPreparer: + def prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes): + if self.prefix_mode == 0: + y_emb, prefix_len = self._handle_prefix_mode_0(y, codes, nar_stage) + elif self.prefix_mode == 1: + y_emb, prefix_len = self._handle_prefix_mode_1(y, y_lens, codes, nar_stage) + elif self.prefix_mode in [2, 4]: + y_emb, prefix_len = self._handle_prefix_mode_2_4(y, y_lens, codes, nar_stage, y_prompts_codes) + else: + raise ValueError("Invalid prefix mode") + + return y_emb, prefix_len + + def _handle_prefix_mode_0(self, y, codes, nar_stage): + prefix_len = 0 + y_emb = self.nar_audio_embeddings[0](y) + for j in range(1, nar_stage): + y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j]) + return y_emb, 0 + + def _handle_prefix_mode_1(self, y, y_lens, codes, nar_stage): + int_low = (0.25 * y_lens.min()).type(torch.int64).item() + prefix_len = torch.randint(int_low, int_low * 2, size=()).item() + prefix_len = min(prefix_len, 225) + + y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len]) + y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:]) + for j in range(1, self.num_quantizers): + y_prompts += self.nar_audio_embeddings[j]( + codes[:, :prefix_len, j] + ) + if j < nar_stage: + y_emb += self.nar_audio_embeddings[j]( + codes[:, prefix_len:, j] + ) + y_emb = torch.concat([y_prompts, y_emb], axis=1) + return y_emb, prefix_len + + def _handle_prefix_mode_2_4(self, y, y_lens, codes, nar_stage, y_prompts_codes): + if self.prefix_mode == 2: + prefix_len = min(225, int(0.25 * y_lens.min().item())) + + y_prompts_codes = [] + for b in range(codes.shape[0]): + start = self.rng.randint(0, y_lens[b].item() - prefix_len) + y_prompts_codes.append( + torch.clone(codes[b, start : start + prefix_len]) + ) + codes[ + b, start : start + prefix_len, nar_stage + ] = self.audio_token_num + y_prompts_codes = torch.stack(y_prompts_codes, dim=0) + else: + prefix_len = y_prompts_codes.shape[1] + + y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0]) + y_emb = self.nar_audio_embeddings[0](y) + for j in range(1, self.num_quantizers): + y_prompts += self.nar_audio_embeddings[j]( + y_prompts_codes[..., j] + ) + if j < nar_stage: + y_emb += self.nar_audio_embeddings[j](codes[..., j]) + y_emb = torch.concat([y_prompts, y_emb], axis=1) + + return y_emb, prefix_len diff --git a/utils/ssim.py b/utils/ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..a0b95007b2c225f3cae869f0653a733d1d92043a --- /dev/null +++ b/utils/ssim.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
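+
+# Illustrative usage sketch (assumption, not part of the original file): the SSIM module
+# below is used as a structural loss between predicted and ground-truth mel-spectrograms
+# shaped [B, T, n_mels]; the random tensors here are placeholders.
+def _example_ssim_loss():
+    """Hypothetical example: compute the SSIM-based loss on dummy mel batches."""
+    criterion = SSIM()
+    fake = torch.rand(2, 400, 80)  # [B, T, n_mels]
+    real = torch.rand(2, 400, 80)
+    return criterion(fake, real)  # scalar loss; lower means more similar mels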
+ +# This code is modified from https://github.com/Po-Hsun-Su/pytorch-ssim + +import torch +import torch.nn.functional as F +from torch.autograd import Variable +from math import exp + + +def gaussian(window_size, sigma): + gauss = torch.Tensor( + [ + exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2)) + for x in range(window_size) + ] + ) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable( + _2D_window.expand(channel, 1, window_size, window_size).contiguous() + ) + return window + + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = ( + F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + ) + sigma2_sq = ( + F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + ) + sigma12 = ( + F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) + - mu1_mu2 + ) + + C1 = 0.01**2 + C2 = 0.03**2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( + (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) + ) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1) + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, fake, real, bias=6.0): + fake = fake[:, None, :, :] + bias # [B, 1, T, n_mels] + real = real[:, None, :, :] + bias # [B, 1, T, n_mels] + self.window = self.window.to(dtype=fake.dtype, device=fake.device) + loss = 1 - _ssim( + fake, real, self.window, self.window_size, self.channel, self.size_average + ) + return loss diff --git a/utils/stft.py b/utils/stft.py new file mode 100644 index 0000000000000000000000000000000000000000..bcec4c84ace0cc40d65361316222b090428cc391 --- /dev/null +++ b/utils/stft.py @@ -0,0 +1,278 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +import numpy as np +from scipy.signal import get_window +from librosa.util import pad_center, tiny +from librosa.filters import mel as librosa_mel_fn + +import torch +import numpy as np +import librosa.util as librosa_util +from scipy.signal import get_window + + +def window_sumsquare( + window, + n_frames, + hop_length, + win_length, + n_fft, + dtype=np.float32, + norm=None, +): + """ + # from librosa 0.6 + Compute the sum-square envelope of a window function at a given hop length. + + This is used to estimate modulation effects induced by windowing + observations in short-time fourier transforms. + + Parameters + ---------- + window : string, tuple, number, callable, or list-like + Window specification, as in `get_window` + + n_frames : int > 0 + The number of analysis frames + + hop_length : int > 0 + The number of samples to advance between frames + + win_length : [optional] + The length of the window function. By default, this matches `n_fft`. + + n_fft : int > 0 + The length of each analysis frame. 
+ + dtype : np.dtype + The data type of the output + + Returns + ------- + wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))` + The sum-squared envelope of the window function + """ + if win_length is None: + win_length = n_fft + + n = n_fft + hop_length * (n_frames - 1) + x = np.zeros(n, dtype=dtype) + + # Compute the squared window at the desired length + win_sq = get_window(window, win_length, fftbins=True) + win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2 + win_sq = librosa_util.pad_center(win_sq, n_fft) + + # Fill the envelope + for i in range(n_frames): + sample = i * hop_length + x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))] + return x + + +def griffin_lim(magnitudes, stft_fn, n_iters=30): + """ + PARAMS + ------ + magnitudes: spectrogram magnitudes + stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods + """ + + angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size()))) + angles = angles.astype(np.float32) + angles = torch.autograd.Variable(torch.from_numpy(angles)) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + + for i in range(n_iters): + _, angles = stft_fn.transform(signal) + signal = stft_fn.inverse(magnitudes, angles).squeeze(1) + return signal + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + """ + PARAMS + ------ + C: compression factor + """ + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression(x, C=1): + """ + PARAMS + ------ + C: compression factor used to compress + """ + return torch.exp(x) / C + + +class STFT(torch.nn.Module): + """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft""" + + def __init__(self, filter_length, hop_length, win_length, window="hann"): + super(STFT, self).__init__() + self.filter_length = filter_length + self.hop_length = hop_length + self.win_length = win_length + self.window = window + self.forward_transform = None + scale = self.filter_length / self.hop_length + fourier_basis = np.fft.fft(np.eye(self.filter_length)) + + cutoff = int((self.filter_length / 2 + 1)) + fourier_basis = np.vstack( + [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])] + ) + + forward_basis = torch.FloatTensor(fourier_basis[:, None, :]) + inverse_basis = torch.FloatTensor( + np.linalg.pinv(scale * fourier_basis).T[:, None, :] + ) + + if window is not None: + assert filter_length >= win_length + # get window and zero center pad it to filter_length + fft_window = get_window(window, win_length, fftbins=True) + fft_window = pad_center(fft_window, filter_length) + fft_window = torch.from_numpy(fft_window).float() + + # window the bases + forward_basis *= fft_window + inverse_basis *= fft_window + + self.register_buffer("forward_basis", forward_basis.float()) + self.register_buffer("inverse_basis", inverse_basis.float()) + + def transform(self, input_data): + num_batches = input_data.size(0) + num_samples = input_data.size(1) + + self.num_samples = num_samples + + # similar to librosa, reflect-pad the input + input_data = input_data.view(num_batches, 1, num_samples) + input_data = F.pad( + input_data.unsqueeze(1), + (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0), + mode="reflect", + ) + input_data = input_data.squeeze(1) + + forward_transform = F.conv1d( + input_data.cuda(), + torch.autograd.Variable(self.forward_basis, requires_grad=False).cuda(), + stride=self.hop_length, + padding=0, + ).cpu() + + cutoff = int((self.filter_length / 2) + 1) + real_part = 
forward_transform[:, :cutoff, :] + imag_part = forward_transform[:, cutoff:, :] + + magnitude = torch.sqrt(real_part**2 + imag_part**2) + phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data)) + + return magnitude, phase + + def inverse(self, magnitude, phase): + recombine_magnitude_phase = torch.cat( + [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1 + ) + + inverse_transform = F.conv_transpose1d( + recombine_magnitude_phase, + torch.autograd.Variable(self.inverse_basis, requires_grad=False), + stride=self.hop_length, + padding=0, + ) + + if self.window is not None: + window_sum = window_sumsquare( + self.window, + magnitude.size(-1), + hop_length=self.hop_length, + win_length=self.win_length, + n_fft=self.filter_length, + dtype=np.float32, + ) + # remove modulation effects + approx_nonzero_indices = torch.from_numpy( + np.where(window_sum > tiny(window_sum))[0] + ) + window_sum = torch.autograd.Variable( + torch.from_numpy(window_sum), requires_grad=False + ) + window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum + inverse_transform[:, :, approx_nonzero_indices] /= window_sum[ + approx_nonzero_indices + ] + + # scale by hop ratio + inverse_transform *= float(self.filter_length) / self.hop_length + + inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :] + inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :] + + return inverse_transform + + def forward(self, input_data): + self.magnitude, self.phase = self.transform(input_data) + reconstruction = self.inverse(self.magnitude, self.phase) + return reconstruction + + +class TacotronSTFT(torch.nn.Module): + def __init__( + self, + filter_length, + hop_length, + win_length, + n_mel_channels, + sampling_rate, + mel_fmin, + mel_fmax, + ): + super(TacotronSTFT, self).__init__() + self.n_mel_channels = n_mel_channels + self.sampling_rate = sampling_rate + self.stft_fn = STFT(filter_length, hop_length, win_length) + mel_basis = librosa_mel_fn( + sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax + ) + mel_basis = torch.from_numpy(mel_basis).float() + self.register_buffer("mel_basis", mel_basis) + + def spectral_normalize(self, magnitudes): + output = dynamic_range_compression(magnitudes) + return output + + def spectral_de_normalize(self, magnitudes): + output = dynamic_range_decompression(magnitudes) + return output + + def mel_spectrogram(self, y): + """Computes mel-spectrograms from a batch of waves + PARAMS + ------ + y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1] + + RETURNS + ------- + mel_output: torch.FloatTensor of shape (B, n_mel_channels, T) + """ + assert torch.min(y.data) >= -1 + assert torch.max(y.data) <= 1 + + magnitudes, phases = self.stft_fn.transform(y) + magnitudes = magnitudes.data + mel_output = torch.matmul(self.mel_basis, magnitudes) + mel_output = self.spectral_normalize(mel_output) + energy = torch.norm(magnitudes, dim=1) + + return mel_output, energy diff --git a/utils/symbol_table.py b/utils/symbol_table.py new file mode 100644 index 0000000000000000000000000000000000000000..730ffe7a8018f80f662e260542a859a6f6f74a47 --- /dev/null +++ b/utils/symbol_table.py @@ -0,0 +1,313 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
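+
+# Illustrative usage sketch (assumption, not part of the original file): a tiny round
+# trip through the SymbolTable defined below; the symbols and ids are made up.
+def _example_symbol_table():
+    """Hypothetical example: build, extend and query a SymbolTable."""
+    table = SymbolTable.from_str("<eps> 0\na 1\nb 2")
+    new_id = table.add("c")  # assigned the next free id, here 3
+    assert table["a"] == 1 and table[new_id] == "c"
+    return table.to_str()  # "<eps> 0\na 1\nb 2\nc 3\n"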
+ +# This code is modified from +# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/utils/symbol_table.py + +from dataclasses import dataclass +from dataclasses import field +from typing import Dict +from typing import Generic +from typing import List +from typing import Optional +from typing import TypeVar +from typing import Union + +Symbol = TypeVar('Symbol') + + +@dataclass(repr=False) +class SymbolTable(Generic[Symbol]): + '''SymbolTable that maps symbol IDs, found on the FSA arcs to + actual objects. These objects can be arbitrary Python objects + that can serve as keys in a dictionary (i.e. they need to be + hashable and immutable). + + The SymbolTable can only be read to/written from disk if the + symbols are strings. + ''' + _id2sym: Dict[int, Symbol] = field(default_factory=dict) + '''Map an integer to a symbol. + ''' + + _sym2id: Dict[Symbol, int] = field(default_factory=dict) + '''Map a symbol to an integer. + ''' + + _next_available_id: int = 1 + '''A helper internal field that helps adding new symbols + to the table efficiently. + ''' + + eps: Symbol = '' + '''Null symbol, always mapped to index 0. + ''' + + def __post_init__(self): + assert all(self._sym2id[sym] == idx for idx, sym in self._id2sym.items()) + assert all(self._id2sym[idx] == sym for sym, idx in self._sym2id.items()) + assert 0 not in self._id2sym or self._id2sym[0] == self.eps + + self._next_available_id = max(self._id2sym, default=0) + 1 + self._id2sym.setdefault(0, self.eps) + self._sym2id.setdefault(self.eps, 0) + + + @staticmethod + def from_str(s: str) -> 'SymbolTable': + '''Build a symbol table from a string. + + The string consists of lines. Every line has two fields separated + by space(s), tab(s) or both. The first field is the symbol and the + second the integer id of the symbol. + + Args: + s: + The input string with the format described above. + Returns: + An instance of :class:`SymbolTable`. + ''' + id2sym: Dict[int, str] = dict() + sym2id: Dict[str, int] = dict() + + for line in s.split('\n'): + fields = line.split() + if len(fields) == 0: + continue # skip empty lines + assert len(fields) == 2, \ + f'Expect a line with 2 fields. Given: {len(fields)}' + sym, idx = fields[0], int(fields[1]) + assert sym not in sym2id, f'Duplicated symbol {sym}' + assert idx not in id2sym, f'Duplicated id {idx}' + id2sym[idx] = sym + sym2id[sym] = idx + + eps = id2sym.get(0, '') + + return SymbolTable(_id2sym=id2sym, _sym2id=sym2id, eps=eps) + + @staticmethod + def from_file(filename: str) -> 'SymbolTable': + '''Build a symbol table from file. + + Every line in the symbol table file has two fields separated by + space(s), tab(s) or both. The following is an example file: + + .. code-block:: + + 0 + a 1 + b 2 + c 3 + + Args: + filename: + Name of the symbol table file. Its format is documented above. + + Returns: + An instance of :class:`SymbolTable`. + + ''' + with open(filename, 'r', encoding='utf-8') as f: + return SymbolTable.from_str(f.read().strip()) + + def to_str(self) -> str: + ''' + Returns: + Return a string representation of this object. You can pass + it to the method ``from_str`` to recreate an identical object. + ''' + s = '' + for idx, symbol in sorted(self._id2sym.items()): + s += f'{symbol} {idx}\n' + return s + + def to_file(self, filename: str): + '''Serialize the SymbolTable to a file. + + Every line in the symbol table file has two fields separated by + space(s), tab(s) or both. The following is an example file: + + .. 
code-block:: + + 0 + a 1 + b 2 + c 3 + + Args: + filename: + Name of the symbol table file. Its format is documented above. + ''' + with open(filename, 'w') as f: + for idx, symbol in sorted(self._id2sym.items()): + print(symbol, idx, file=f) + + def add(self, symbol: Symbol, index: Optional[int] = None) -> int: + '''Add a new symbol to the SymbolTable. + + Args: + symbol: + The symbol to be added. + index: + Optional int id to which the symbol should be assigned. + If it is not available, a ValueError will be raised. + + Returns: + The int id to which the symbol has been assigned. + ''' + # Already in the table? Return its ID. + if symbol in self._sym2id: + return self._sym2id[symbol] + # Specific ID not provided - use next available. + if index is None: + index = self._next_available_id + # Specific ID provided but not available. + if index in self._id2sym: + raise ValueError(f"Cannot assign id '{index}' to '{symbol}' - " + f"already occupied by {self._id2sym[index]}") + self._sym2id[symbol] = index + self._id2sym[index] = symbol + + # Update next available ID if needed + if self._next_available_id <= index: + self._next_available_id = index + 1 + + return index + + def get(self, k: Union[int, Symbol]) -> Union[Symbol, int]: + '''Get a symbol for an id or get an id for a symbol + + Args: + k: + If it is an id, it tries to find the symbol corresponding + to the id; if it is a symbol, it tries to find the id + corresponding to the symbol. + + Returns: + An id or a symbol depending on the given `k`. + ''' + if isinstance(k, int): + return self._id2sym[k] + else: + return self._sym2id[k] + + def merge(self, other: 'SymbolTable') -> 'SymbolTable': + '''Create a union of two SymbolTables. + Raises an AssertionError if the same IDs are occupied by + different symbols. + + Args: + other: + A symbol table to merge with ``self``. + + Returns: + A new symbol table. + ''' + self._check_compatible(other) + return SymbolTable( + _id2sym={**self._id2sym, **other._id2sym}, + _sym2id={**self._sym2id, **other._sym2id}, + eps=self.eps + ) + + def _check_compatible(self, other: 'SymbolTable') -> None: + # Epsilon compatibility + assert self.eps == other.eps, f'Mismatched epsilon symbol: ' \ + f'{self.eps} != {other.eps}' + # IDs compatibility + common_ids = set(self._id2sym).intersection(other._id2sym) + for idx in common_ids: + assert self[idx] == other[idx], f'ID conflict for id: {idx}, ' \ + f'self[idx] = "{self[idx]}", ' \ + f'other[idx] = "{other[idx]}"' + # Symbols compatibility + common_symbols = set(self._sym2id).intersection(other._sym2id) + for sym in common_symbols: + assert self[sym] == other[sym], f'ID conflict for id: {sym}, ' \ + f'self[sym] = "{self[sym]}", ' \ + f'other[sym] = "{other[sym]}"' + + def __getitem__(self, item: Union[int, Symbol]) -> Union[Symbol, int]: + return self.get(item) + + def __contains__(self, item: Union[int, Symbol]) -> bool: + if isinstance(item, int): + return item in self._id2sym + else: + return item in self._sym2id + + def __len__(self) -> int: + return len(self._id2sym) + + def __eq__(self, other: 'SymbolTable') -> bool: + if len(self) != len(other): + return False + + for s in self.symbols: + if self[s] != other[s]: + return False + + return True + + @property + def ids(self) -> List[int]: + '''Returns a list of integer IDs corresponding to the symbols. + ''' + ans = list(self._id2sym.keys()) + ans.sort() + return ans + + @property + def symbols(self) -> List[Symbol]: + '''Returns a list of symbols (e.g., strings) corresponding to + the integer IDs. 
+ ''' + ans = list(self._sym2id.keys()) + ans.sort() + return ans + + +class TextToken: + def __init__( + self, + text_tokens: List[str], + add_eos: bool = True, + add_bos: bool = True, + pad_symbol: str = "", + bos_symbol: str = "", + eos_symbol: str = "", + ): + self.pad_symbol = pad_symbol + self.add_eos = add_eos + self.add_bos = add_bos + self.bos_symbol = bos_symbol + self.eos_symbol = eos_symbol + + unique_tokens = [pad_symbol] + if add_bos: + unique_tokens.append(bos_symbol) + if add_eos: + unique_tokens.append(eos_symbol) + unique_tokens.extend(sorted(text_tokens)) + + self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)} + self.idx2token = unique_tokens + + + def get_token_id_seq(self, text): + tokens_seq = [p for p in text] + seq = ( + ([self.bos_symbol] if self.add_bos else []) + + tokens_seq + + ([self.eos_symbol] if self.add_eos else []) + ) + + token_ids = [self.token2idx[token] for token in seq] + token_lens = len(tokens_seq) + self.add_eos + self.add_bos + + return token_ids, token_lens + + \ No newline at end of file diff --git a/utils/tokenizer.py b/utils/tokenizer.py new file mode 100644 index 0000000000000000000000000000000000000000..8971432bdcc2f3f2920775bb3397c90f2a91f8b8 --- /dev/null +++ b/utils/tokenizer.py @@ -0,0 +1,151 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# This code is modified from +# https://github.com/lifeiteng/vall-e/blob/9c69096d603ce13174fb5cb025f185e2e9b36ac7/valle/data/tokenizer.py + +import re +from typing import Any, Dict, List, Optional, Pattern, Union + +import torch +import torchaudio +from encodec import EncodecModel +from encodec.utils import convert_audio + + + +class AudioTokenizer: + """EnCodec audio tokenizer for encoding and decoding audio. + + Attributes: + device: The device on which the codec model is loaded. + codec: The pretrained EnCodec model. + sample_rate: Sample rate of the model. + channels: Number of audio channels in the model. + """ + + def __init__(self, device: Any = None) -> None: + model = EncodecModel.encodec_model_24khz() + model.set_target_bandwidth(6.0) + remove_encodec_weight_norm(model) + + if not device: + device = torch.device("cpu") + if torch.cuda.is_available(): + device = torch.device("cuda:0") + + self._device = device + + self.codec = model.to(device) + self.sample_rate = model.sample_rate + self.channels = model.channels + + @property + def device(self): + return self._device + + def encode(self, wav: torch.Tensor) -> torch.Tensor: + """Encode the audio waveform. + + Args: + wav: A tensor representing the audio waveform. + + Returns: + A tensor representing the encoded audio. + """ + return self.codec.encode(wav.to(self.device)) + + def decode(self, frames: torch.Tensor) -> torch.Tensor: + """Decode the encoded audio frames. + + Args: + frames: A tensor representing the encoded audio frames. + + Returns: + A tensor representing the decoded audio waveform. + """ + return self.codec.decode(frames) + + + +def tokenize_audio(tokenizer: AudioTokenizer, audio_path: str): + """ + Tokenize the audio waveform using the given AudioTokenizer. + + Args: + tokenizer: An instance of AudioTokenizer. + audio_path: Path to the audio file. + + Returns: + A tensor of encoded frames from the audio. + + Raises: + FileNotFoundError: If the audio file is not found. + RuntimeError: If there's an error processing the audio data. 
+ """ + # try: + # Load and preprocess the audio waveform + wav, sr = torchaudio.load(audio_path) + wav = convert_audio(wav, sr, tokenizer.sample_rate, tokenizer.channels) + wav = wav.unsqueeze(0) + + # Extract discrete codes from EnCodec + with torch.no_grad(): + encoded_frames = tokenizer.encode(wav) + return encoded_frames + + # except FileNotFoundError: + # raise FileNotFoundError(f"Audio file not found at {audio_path}") + # except Exception as e: + # raise RuntimeError(f"Error processing audio data: {e}") + + + +def remove_encodec_weight_norm(model): + from encodec.modules import SConv1d + from encodec.modules.seanet import SConvTranspose1d, SEANetResnetBlock + from torch.nn.utils import remove_weight_norm + + encoder = model.encoder.model + for key in encoder._modules: + if isinstance(encoder._modules[key], SEANetResnetBlock): + remove_weight_norm(encoder._modules[key].shortcut.conv.conv) + block_modules = encoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(encoder._modules[key], SConv1d): + remove_weight_norm(encoder._modules[key].conv.conv) + + decoder = model.decoder.model + for key in decoder._modules: + if isinstance(decoder._modules[key], SEANetResnetBlock): + remove_weight_norm(decoder._modules[key].shortcut.conv.conv) + block_modules = decoder._modules[key].block._modules + for skey in block_modules: + if isinstance(block_modules[skey], SConv1d): + remove_weight_norm(block_modules[skey].conv.conv) + elif isinstance(decoder._modules[key], SConvTranspose1d): + remove_weight_norm(decoder._modules[key].convtr.convtr) + elif isinstance(decoder._modules[key], SConv1d): + remove_weight_norm(decoder._modules[key].conv.conv) + + +def extract_encodec_token(wav_path): + model = EncodecModel.encodec_model_24khz() + model.set_target_bandwidth(6.0) + + wav, sr = torchaudio.load(wav_path) + wav = convert_audio(wav, sr, model.sample_rate, model.channels) + wav = wav.unsqueeze(0) + if torch.cuda.is_available(): + model = model.cuda() + wav = wav.cuda() + with torch.no_grad(): + encoded_frames = model.encode(wav) + codes_ = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [B, n_q, T] + codes = codes_.cpu().numpy()[0,:,:].T # [T, 8] + + return codes \ No newline at end of file diff --git a/utils/topk_sampling.py b/utils/topk_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..03f405c15506609bbd571818123cf7f0db7f1f4b --- /dev/null +++ b/utils/topk_sampling.py @@ -0,0 +1,87 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +import torch.nn.functional as F + + +# This function is modified from https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py +def top_k_top_p_filtering( + logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1 +): + """ + Filter a distribution of logits using top-k and/or nucleus (top-p) filtering. + + Args: + logits (torch.Tensor): Logits distribution with shape (batch size, vocabulary size). + top_k (int, optional): Keep only top k tokens with highest probability (top-k filtering). + Set to 0 to disable. Defaults to 0. + top_p (float, optional): Keep the top tokens with a cumulative probability >= top_p (nucleus filtering). + Must be between 0 and 1, inclusive. Defaults to 1.0. 
+ filter_value (float, optional): The value to assign to filtered logits. Defaults to -float('Inf'). + min_tokens_to_keep (int, optional): Ensure that at least this number of tokens are kept per batch example. + Defaults to 1. + + Returns: + torch.Tensor: The filtered logits. + """ + """ + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + Make sure we keep at least min_tokens_to_keep per batch example in the output + From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 + """ + if top_k > 0: + # Apply top-k filtering + top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) + indices_to_remove = logits < torch.topk(logits, top_k).values[..., -1, None] + logits[indices_to_remove] = filter_value + + if top_p < 1.0: + # Apply top-p filtering + sorted_logits, sorted_indices = torch.sort(logits, descending=True) + cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) + + # Create a mask to remove tokens with cumulative probability above the top_p threshold + sorted_indices_to_remove = cumulative_probs > top_p + if min_tokens_to_keep > 1: + sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + # Scatter sorted tensors back to original indexing + indices_to_remove = sorted_indices.scatter(1, sorted_indices, sorted_indices_to_remove) + logits[indices_to_remove] = filter_value + + return logits + + +def topk_sampling(logits, top_k=50, top_p=1.0, temperature=1.0): + """ + Perform top-k and top-p sampling on logits. + + Args: + logits (torch.Tensor): The logits to sample from. + top_k (int, optional): The number of highest probability tokens to keep for top-k filtering. + Must be a positive integer. Defaults to 50. + top_p (float, optional): The cumulative probability threshold for nucleus sampling. + Must be between 0 and 1. Defaults to 1.0. + temperature (float, optional): The scaling factor to adjust the logits distribution. + Must be strictly positive. Defaults to 1.0. + + Returns: + torch.Tensor: The sampled token. + """ + + # Adjust logits using temperature + if temperature != 1.0: + logits = logits / temperature + + # Top-p/top-k filtering + logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) + + # Sample from the filtered distribution + token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1) + return token diff --git a/utils/trainer_utils.py b/utils/trainer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e5d9ad794864aa3ee0c49a86e9be293f69442886 --- /dev/null +++ b/utils/trainer_utils.py @@ -0,0 +1,16 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch + + +def check_nan(logger, loss, y_pred, y_gt): + if torch.any(torch.isnan(loss)): + logger.info("out has nan: ", torch.any(torch.isnan(y_pred))) + logger.info("y_gt has nan: ", torch.any(torch.isnan(y_gt))) + logger.info("out: ", y_pred) + logger.info("y_gt: ", y_gt) + logger.info("loss = {:.4f}\n".format(loss.item())) + exit() diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000000000000000000000000000000000000..096d9129cd0dab4798f205ab7fdc690303ef5d37 --- /dev/null +++ b/utils/util.py @@ -0,0 +1,688 @@ +# Copyright (c) 2023 Amphion. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import collections +import glob +import os +import random +import time +import argparse +from collections import OrderedDict + +import json5 +import numpy as np +import glob +from torch.nn import functional as F + + +try: + from ruamel.yaml import YAML as yaml +except: + from ruamel_yaml import YAML as yaml + +import torch + +from utils.hparam import HParams +import logging +from logging import handlers + + +def str2bool(v): + """Used in argparse.ArgumentParser.add_argument to indicate + that a type is a bool type and user can enter + + - yes, true, t, y, 1, to represent True + - no, false, f, n, 0, to represent False + + See https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse # noqa + """ + if isinstance(v, bool): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): + return True + elif v.lower() in ("no", "false", "f", "n", "0"): + return False + else: + raise argparse.ArgumentTypeError("Boolean value expected.") + + +def find_checkpoint_of_mapper(mapper_ckpt_dir): + mapper_ckpts = glob.glob(os.path.join(mapper_ckpt_dir, "ckpts/*.pt")) + + # Select the max steps + mapper_ckpts.sort() + mapper_weights_file = mapper_ckpts[-1] + return mapper_weights_file + + +def pad_f0_to_tensors(f0s, batched=None): + # Initialize + tensors = [] + + if batched == None: + # Get the max frame for padding + size = -1 + for f0 in f0s: + size = max(size, f0.shape[-1]) + + tensor = torch.zeros(len(f0s), size) + + for i, f0 in enumerate(f0s): + tensor[i, : f0.shape[-1]] = f0[:] + + tensors.append(tensor) + else: + start = 0 + while start + batched - 1 < len(f0s): + end = start + batched - 1 + + # Get the max frame for padding + size = -1 + for i in range(start, end + 1): + size = max(size, f0s[i].shape[-1]) + + tensor = torch.zeros(batched, size) + + for i in range(start, end + 1): + tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:] + + tensors.append(tensor) + + start = start + batched + + if start != len(f0s): + end = len(f0s) + + # Get the max frame for padding + size = -1 + for i in range(start, end): + size = max(size, f0s[i].shape[-1]) + + tensor = torch.zeros(len(f0s) - start, size) + + for i in range(start, end): + tensor[i - start, : f0s[i].shape[-1]] = f0s[i][:] + + tensors.append(tensor) + + return tensors + + +def pad_mels_to_tensors(mels, batched=None): + """ + Args: + mels: A list of mel-specs + Returns: + tensors: A list of tensors containing the batched mel-specs + mel_frames: A list of tensors containing the frames of the original mel-specs + """ + # Initialize + tensors = [] + mel_frames = [] + + # Split mel-specs into batches to avoid cuda memory exceed + if batched == None: + # Get the max frame for padding + size = -1 + for mel in mels: + size = max(size, mel.shape[-1]) + + tensor = torch.zeros(len(mels), mels[0].shape[0], size) + mel_frame = torch.zeros(len(mels), dtype=torch.int32) + + for i, mel in enumerate(mels): + tensor[i, :, : mel.shape[-1]] = mel[:] + mel_frame[i] = mel.shape[-1] + + tensors.append(tensor) + mel_frames.append(mel_frame) + else: + start = 0 + while start + batched - 1 < len(mels): + end = start + batched - 1 + + # Get the max frame for padding + size = -1 + for i in range(start, end + 1): + size = max(size, mels[i].shape[-1]) + + tensor = torch.zeros(batched, mels[0].shape[0], size) + mel_frame = torch.zeros(batched, dtype=torch.int32) + + for i in range(start, end + 1): + tensor[i - start, :, : 
mels[i].shape[-1]] = mels[i][:] + mel_frame[i - start] = mels[i].shape[-1] + + tensors.append(tensor) + mel_frames.append(mel_frame) + + start = start + batched + + if start != len(mels): + end = len(mels) + + # Get the max frame for padding + size = -1 + for i in range(start, end): + size = max(size, mels[i].shape[-1]) + + tensor = torch.zeros(len(mels) - start, mels[0].shape[0], size) + mel_frame = torch.zeros(len(mels) - start, dtype=torch.int32) + + for i in range(start, end): + tensor[i - start, :, : mels[i].shape[-1]] = mels[i][:] + mel_frame[i - start] = mels[i].shape[-1] + + tensors.append(tensor) + mel_frames.append(mel_frame) + + return tensors, mel_frames + + +def load_model_config(args): + """Load model configurations (in args.json under checkpoint directory) + + Args: + args (ArgumentParser): arguments to run bins/preprocess.py + + Returns: + dict: dictionary that stores model configurations + """ + if args.checkpoint_dir is None: + assert args.checkpoint_file is not None + checkpoint_dir = os.path.split(args.checkpoint_file)[0] + else: + checkpoint_dir = args.checkpoint_dir + config_path = os.path.join(checkpoint_dir, "args.json") + print("config_path: ", config_path) + + config = load_config(config_path) + return config + + +def remove_and_create(dir): + if os.path.exists(dir): + os.system("rm -r {}".format(dir)) + os.makedirs(dir, exist_ok=True) + + +def has_existed(path, warning=False): + if not warning: + return os.path.exists(path) + + if os.path.exists(path): + answer = input( + "The path {} has existed. \nInput 'y' (or hit Enter) to skip it, and input 'n' to re-write it [y/n]\n".format( + path + ) + ) + if not answer == "n": + return True + + return False + + +def remove_older_ckpt(saved_model_name, checkpoint_dir, max_to_keep=5): + if os.path.exists(os.path.join(checkpoint_dir, "checkpoint")): + with open(os.path.join(checkpoint_dir, "checkpoint"), "r") as f: + ckpts = [x.strip() for x in f.readlines()] + else: + ckpts = [] + ckpts.append(saved_model_name) + for item in ckpts[:-max_to_keep]: + if os.path.exists(os.path.join(checkpoint_dir, item)): + os.remove(os.path.join(checkpoint_dir, item)) + with open(os.path.join(checkpoint_dir, "checkpoint"), "w") as f: + for item in ckpts[-max_to_keep:]: + f.write("{}\n".format(item)) + + +def set_all_random_seed(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.random.manual_seed(seed) + + +def save_checkpoint( + args, + generator, + g_optimizer, + step, + discriminator=None, + d_optimizer=None, + max_to_keep=5, +): + saved_model_name = "model.ckpt-{}.pt".format(step) + checkpoint_path = os.path.join(args.checkpoint_dir, saved_model_name) + + if discriminator and d_optimizer: + torch.save( + { + "generator": generator.state_dict(), + "discriminator": discriminator.state_dict(), + "g_optimizer": g_optimizer.state_dict(), + "d_optimizer": d_optimizer.state_dict(), + "global_step": step, + }, + checkpoint_path, + ) + else: + torch.save( + { + "generator": generator.state_dict(), + "g_optimizer": g_optimizer.state_dict(), + "global_step": step, + }, + checkpoint_path, + ) + + print("Saved checkpoint: {}".format(checkpoint_path)) + + if os.path.exists(os.path.join(args.checkpoint_dir, "checkpoint")): + with open(os.path.join(args.checkpoint_dir, "checkpoint"), "r") as f: + ckpts = [x.strip() for x in f.readlines()] + else: + ckpts = [] + ckpts.append(saved_model_name) + for item in ckpts[:-max_to_keep]: + if os.path.exists(os.path.join(args.checkpoint_dir, item)): + os.remove(os.path.join(args.checkpoint_dir, item)) + 
with open(os.path.join(args.checkpoint_dir, "checkpoint"), "w") as f: + for item in ckpts[-max_to_keep:]: + f.write("{}\n".format(item)) + + +def attempt_to_restore( + generator, g_optimizer, checkpoint_dir, discriminator=None, d_optimizer=None +): + checkpoint_list = os.path.join(checkpoint_dir, "checkpoint") + if os.path.exists(checkpoint_list): + checkpoint_filename = open(checkpoint_list).readlines()[-1].strip() + checkpoint_path = os.path.join(checkpoint_dir, "{}".format(checkpoint_filename)) + print("Restore from {}".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path, map_location="cpu") + if generator: + if not list(generator.state_dict().keys())[0].startswith("module."): + raw_dict = checkpoint["generator"] + clean_dict = OrderedDict() + for k, v in raw_dict.items(): + if k.startswith("module."): + clean_dict[k[7:]] = v + else: + clean_dict[k] = v + generator.load_state_dict(clean_dict) + else: + generator.load_state_dict(checkpoint["generator"]) + if g_optimizer: + g_optimizer.load_state_dict(checkpoint["g_optimizer"]) + global_step = 100000 + if discriminator and "discriminator" in checkpoint.keys(): + discriminator.load_state_dict(checkpoint["discriminator"]) + global_step = checkpoint["global_step"] + print("restore discriminator") + if d_optimizer and "d_optimizer" in checkpoint.keys(): + d_optimizer.load_state_dict(checkpoint["d_optimizer"]) + print("restore d_optimizer...") + else: + global_step = 0 + return global_step + + +class ExponentialMovingAverage(object): + def __init__(self, decay): + self.decay = decay + self.shadow = {} + + def register(self, name, val): + self.shadow[name] = val.clone() + + def update(self, name, x): + assert name in self.shadow + update_delta = self.shadow[name] - x + self.shadow[name] -= (1.0 - self.decay) * update_delta + + +def apply_moving_average(model, ema): + for name, param in model.named_parameters(): + if name in ema.shadow: + ema.update(name, param.data) + + +def register_model_to_ema(model, ema): + for name, param in model.named_parameters(): + if param.requires_grad: + ema.register(name, param.data) + + +class YParams(HParams): + def __init__(self, yaml_file): + if not os.path.exists(yaml_file): + raise IOError("yaml file: {} is not existed".format(yaml_file)) + super().__init__() + self.d = collections.OrderedDict() + with open(yaml_file) as fp: + for _, v in yaml().load(fp).items(): + for k1, v1 in v.items(): + try: + if self.get(k1): + self.set_hparam(k1, v1) + else: + self.add_hparam(k1, v1) + self.d[k1] = v1 + except Exception: + import traceback + + print(traceback.format_exc()) + + # @property + def get_elements(self): + return self.d.items() + + +def override_config(base_config, new_config): + """Update new configurations in the original dict with the new dict + + Args: + base_config (dict): original dict to be overridden + new_config (dict): dict with new configurations + + Returns: + dict: updated configuration dict + """ + for k, v in new_config.items(): + if type(v) == dict: + if k not in base_config.keys(): + base_config[k] = {} + base_config[k] = override_config(base_config[k], v) + else: + base_config[k] = v + return base_config + + +def get_lowercase_keys_config(cfg): + """Change all keys in cfg to lower case + + Args: + cfg (dict): dictionary that stores configurations + + Returns: + dict: dictionary that stores configurations + """ + updated_cfg = dict() + for k, v in cfg.items(): + if type(v) == dict: + v = get_lowercase_keys_config(v) + updated_cfg[k.lower()] = v + return updated_cfg + + +def 
_load_config(config_fn, lowercase=False): + """Load configurations into a dictionary + + Args: + config_fn (str): path to configuration file + lowercase (bool, optional): whether changing keys to lower case. Defaults to False. + + Returns: + dict: dictionary that stores configurations + """ + with open(config_fn, "r") as f: + data = f.read() + config_ = json5.loads(data) + if "base_config" in config_: + # load configurations from new path + p_config_path = os.path.join(os.getenv("WORK_DIR"), config_["base_config"]) + p_config_ = _load_config(p_config_path) + config_ = override_config(p_config_, config_) + if lowercase: + # change keys in config_ to lower case + config_ = get_lowercase_keys_config(config_) + return config_ + + +def load_config(config_fn, lowercase=False): + """Load configurations into a dictionary + + Args: + config_fn (str): path to configuration file + lowercase (bool, optional): _description_. Defaults to False. + + Returns: + JsonHParams: an object that stores configurations + """ + config_ = _load_config(config_fn, lowercase=lowercase) + # create an JsonHParams object with configuration dict + cfg = JsonHParams(**config_) + return cfg + + +def save_config(save_path, cfg): + """Save configurations into a json file + + Args: + save_path (str): path to save configurations + cfg (dict): dictionary that stores configurations + """ + with open(save_path, "w") as f: + json5.dump( + cfg, f, ensure_ascii=False, indent=4, quote_keys=True, sort_keys=True + ) + + +class JsonHParams: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + if type(v) == dict: + v = JsonHParams(**v) + self[k] = v + + def keys(self): + return self.__dict__.keys() + + def items(self): + return self.__dict__.items() + + def values(self): + return self.__dict__.values() + + def __len__(self): + return len(self.__dict__) + + def __getitem__(self, key): + return getattr(self, key) + + def __setitem__(self, key, value): + return setattr(self, key, value) + + def __contains__(self, key): + return key in self.__dict__ + + def __repr__(self): + return self.__dict__.__repr__() + + +class ValueWindow: + def __init__(self, window_size=100): + self._window_size = window_size + self._values = [] + + def append(self, x): + self._values = self._values[-(self._window_size - 1) :] + [x] + + @property + def sum(self): + return sum(self._values) + + @property + def count(self): + return len(self._values) + + @property + def average(self): + return self.sum / max(1, self.count) + + def reset(self): + self._values = [] + + +class Logger(object): + def __init__( + self, + filename, + level="info", + when="D", + backCount=10, + fmt="%(asctime)s : %(message)s", + ): + self.level_relations = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "crit": logging.CRITICAL, + } + if level == "debug": + fmt = "%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s" + self.logger = logging.getLogger(filename) + format_str = logging.Formatter(fmt) + self.logger.setLevel(self.level_relations.get(level)) + sh = logging.StreamHandler() + sh.setFormatter(format_str) + th = handlers.TimedRotatingFileHandler( + filename=filename, when=when, backupCount=backCount, encoding="utf-8" + ) + th.setFormatter(format_str) + self.logger.addHandler(sh) + self.logger.addHandler(th) + self.logger.info( + "==========================New Starting Here==============================" + ) + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if 
classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +def slice_segments(x, ids_str, segment_size=4): + ret = torch.zeros_like(x[:, :, :segment_size]) + for i in range(x.size(0)): + idx_str = ids_str[i] + idx_end = idx_str + segment_size + ret[i] = x[i, :, idx_str:idx_end] + return ret + + +def rand_slice_segments(x, x_lengths=None, segment_size=4): + b, d, t = x.size() + if x_lengths is None: + x_lengths = t + ids_str_max = x_lengths - segment_size + 1 + ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long) + ret = slice_segments(x, ids_str, segment_size) + return ret, ids_str + + +def subsequent_mask(length): + mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0) + return mask + + +@torch.jit.script +def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels): + n_channels_int = n_channels[0] + in_act = input_a + input_b + t_act = torch.tanh(in_act[:, :n_channels_int, :]) + s_act = torch.sigmoid(in_act[:, n_channels_int:, :]) + acts = t_act * s_act + return acts + + +def convert_pad_shape(pad_shape): + l = pad_shape[::-1] + pad_shape = [item for sublist in l for item in sublist] + return pad_shape + + +def sequence_mask(length, max_length=None): + if max_length is None: + max_length = length.max() + x = torch.arange(max_length, dtype=length.dtype, device=length.device) + return x.unsqueeze(0) < length.unsqueeze(1) + + +def generate_path(duration, mask): + """ + duration: [b, 1, t_x] + mask: [b, 1, t_y, t_x] + """ + device = duration.device + + b, _, t_y, t_x = mask.shape + cum_duration = torch.cumsum(duration, -1) + + cum_duration_flat = cum_duration.view(b * t_x) + path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype) + path = path.view(b, t_x, t_y) + path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1] + path = path.unsqueeze(1).transpose(2, 3) * mask + return path + + +def clip_grad_value_(parameters, clip_value, norm_type=2): + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + norm_type = float(norm_type) + if clip_value is not None: + clip_value = float(clip_value) + + total_norm = 0 + for p in parameters: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm.item() ** norm_type + if clip_value is not None: + p.grad.data.clamp_(min=-clip_value, max=clip_value) + total_norm = total_norm ** (1.0 / norm_type) + return total_norm + + +def get_current_time(): + pass + + +def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor: + """ + Args: + lengths: + A 1-D tensor containing sentence lengths. + max_len: + The length of masks. + Returns: + Return a 2-D bool tensor, where masked positions + are filled with `True` and non-masked positions are + filled with `False`. 
+ + >>> lengths = torch.tensor([1, 3, 2, 5]) + >>> make_pad_mask(lengths) + tensor([[False, True, True, True, True], + [False, False, False, True, True], + [False, False, True, True, True], + [False, False, False, False, False]]) + """ + assert lengths.ndim == 1, lengths.ndim + max_len = max(max_len, lengths.max()) + n = lengths.size(0) + seq_range = torch.arange(0, max_len, device=lengths.device) + expaned_lengths = seq_range.unsqueeze(0).expand(n, max_len) + + return expaned_lengths >= lengths.unsqueeze(-1) + diff --git a/utils/whisper.py b/utils/whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..16462c7e84f2ce71d4fd5e57832d705fd07b95ca --- /dev/null +++ b/utils/whisper.py @@ -0,0 +1,165 @@ +# Copyright (c) 2023 Amphion. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import os +import pickle +from tqdm import tqdm +import numpy as np + +from modules import whisper_extractor as whisper + + +def whisper_encoder_batch(model, audio_paths): + batch = len(audio_paths) + batch_mel = torch.zeros((batch, 80, 3000), dtype=torch.float32, device=model.device) + + for i, audio_path in enumerate(audio_paths): + # (48000,) + audio = whisper.load_audio(str(audio_path)) + audio = whisper.pad_or_trim(audio) + + # (80, 3000) + mel = whisper.log_mel_spectrogram(audio).to(model.device) + batch_mel[i] = mel + + with torch.no_grad(): + # (batch, 1500, 1024) + features = model.embed_audio(batch_mel) + + return features.cpu().detach().numpy() + + +def whisper_encoder(model, audio_path): + audio = whisper.load_audio(str(audio_path)) + audio = whisper.pad_or_trim(audio) + + # (80, 3000) + mel = whisper.log_mel_spectrogram(audio).to(model.device).unsqueeze(0) + + with torch.no_grad(): + # (1, 1500, 1024) -> # (1500, 1024) + features = model.embed_audio(mel).squeeze(0) + + return features.cpu().detach().numpy() + + +def get_mapped_whisper_features( + raw_whisper_features, mapping_features, fast_mapping=True +): + """ + Whisper: frameshift = 20ms (30s audio -> 1500 frames), hop_size = 480 in 24k + # Ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/model.py#L136 + + Now it's only used for mapping to bigvgan's mels (sr = 24k, hop_size = 256, frameshift ~= 10.7 ms) + """ + source_hop = 480 + target_hop = 256 + + factor = np.gcd(source_hop, target_hop) + source_hop //= factor + target_hop //= factor + print( + "Mapping source's {} frames => target's {} frames".format( + target_hop, source_hop + ) + ) + + max_source_len = 1500 + whisper_features = [] + for index, mapping_feat in enumerate(tqdm(mapping_features)): + # mapping_feat: (mels_frame_len, n_mels) + target_len = mapping_feat.shape[0] + # The max target_len is 2812 + target_len = min(target_len, max_source_len * source_hop // target_hop) + + # (1500, dim) + raw_feats = raw_whisper_features[index] + width = raw_feats.shape[-1] + + if fast_mapping: + source_len = target_len * target_hop // source_hop + 1 + raw_feats = raw_feats[:source_len] + else: + source_len = max_source_len + + # const ~= target_len * target_hop + const = source_len * source_hop // target_hop * target_hop + + # (source_len * source_hop, dim) + up_sampling_feats = np.repeat(raw_feats, source_hop, axis=0) + # (const, dim) -> (const/target_hop, target_hop, dim) -> (const/target_hop, dim) + down_sampling_feats = np.average( + up_sampling_feats[:const].reshape(-1, target_hop, width), axis=1 + ) + assert len(down_sampling_feats) >= 
+
+
+def load_whisper_model(hps):
+    print("Loading Whisper Model: ", hps.whisper_model)
+    model = whisper.load_model(hps.whisper_model)
+    if torch.cuda.is_available():
+        model = model.cuda()
+
+    model = model.eval()
+    return model
+
+
+def load_target_acoustic_features(
+    output_path, dataset, acoustic_features_name, acoustic_features_fs, dataset_type
+):
+    mapping_dir = os.path.join(
+        output_path,
+        dataset,
+        "{}/{}".format(acoustic_features_name, acoustic_features_fs),
+    )
+    with open(os.path.join(mapping_dir, "{}.pkl".format(dataset_type)), "rb") as f:
+        mapping_features = pickle.load(f)
+
+    # Mels: (n_mels, frame_len) -> (frame_len, n_mels)
+    if acoustic_features_name == "mels":
+        print("Transposing mel features...")
+        mapping_features = [feat.T for feat in mapping_features]
+
+    print(
+        "Mapping to the acoustic features {}, #sz = {}, feats[0] is {}".format(
+            acoustic_features_name, len(mapping_features), mapping_features[0].shape
+        )
+    )
+    return mapping_features
+
+
+def extract_whisper_features_of_dataset(
+    datasets,
+    model,
+    batch_size,
+    out_dir,
+):
+    audio_paths = [utt["Path"] for utt in datasets]
+    if len(audio_paths) < batch_size:
+        batch_size = len(audio_paths)
+
+    start, end = 0, 0
+    while end < len(audio_paths):
+        # Raw features: (batch_size, 1500, dim)
+        start = end
+        end = start + batch_size
+        tmp_raw_whisper_features = whisper_encoder_batch(model, audio_paths[start:end])
+
+        # Save the raw Whisper features of each utterance in this batch
+        for index, utt in enumerate(tqdm(datasets[start:end])):
+            uid = utt["Uid"]
+            raw_whisper_feature = tmp_raw_whisper_features[index]
+
+            save_path = os.path.join(out_dir, uid + ".npy")
+            np.save(save_path, raw_whisper_feature)
+
+        print("{}/{} Done...".format(end, len(audio_paths)))
diff --git a/utils/world.py b/utils/world.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce5f61bd9b571607fd83da6b22283757e67201da
--- /dev/null
+++ b/utils/world.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2023 Amphion.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+# 1. Extract WORLD features including F0, AP, SP
+# 2. Transform between SP and MCEP
+import torchaudio
+import pyworld as pw
+import numpy as np
+import torch
+import diffsptk
+import os
+from tqdm import tqdm
+import pickle
+
+
+def get_mcep_params(fs):
+    """Hyperparameters of the transformation between SP and MCEP
+
+    Reference:
+    https://github.com/CSTR-Edinburgh/merlin/blob/master/misc/scripts/vocoder/world_v2/copy_synthesis.sh
+
+    """
+    if fs in [44100, 48000]:
+        fft_size = 2048
+        alpha = 0.77
+    elif fs in [16000]:
+        fft_size = 1024
+        alpha = 0.58
+    else:
+        raise ValueError(
+            "Unsupported sample rate for MCEP transformation: {}".format(fs)
+        )
+    return fft_size, alpha
+
+
+def extract_world_features(waveform, fs, frameshift=10):
+    # waveform: (1, seq)
+    # x: (seq,)
+    x = np.array(waveform, dtype=np.double).flatten()
+
+    _f0, t = pw.dio(x, fs, frame_period=frameshift)  # raw pitch extractor
+    f0 = pw.stonemask(x, _f0, t, fs)  # pitch refinement
+    sp = pw.cheaptrick(x, f0, t, fs)  # extract smoothed spectrogram
+    ap = pw.d4c(x, f0, t, fs)  # extract aperiodicity
+
+    return f0, sp, ap, fs
+
+
+def sp2mcep(x, mcsize, fs):
+    fft_size, alpha = get_mcep_params(fs)
+    x = torch.as_tensor(x, dtype=torch.float)
+
+    tmp = diffsptk.ScalarOperation("SquareRoot")(x)
+    tmp = diffsptk.ScalarOperation("Multiplication", 32768.0)(tmp)
+    mgc = diffsptk.MelCepstralAnalysis(
+        cep_order=mcsize - 1, fft_length=fft_size, alpha=alpha, n_iter=1
+    )(tmp)
+    return mgc.numpy()
+
+
+def mcep2sp(x, mcsize, fs):
+    fft_size, alpha = get_mcep_params(fs)
+    x = torch.as_tensor(x, dtype=torch.float)
+
+    tmp = diffsptk.MelGeneralizedCepstrumToSpectrum(
+        alpha=alpha,
+        cep_order=mcsize - 1,
+        fft_length=fft_size,
+    )(x)
+    tmp = diffsptk.ScalarOperation("Division", 32768.0)(tmp)
+    sp = diffsptk.ScalarOperation("Power", 2)(tmp)
+    return sp.double().numpy()
+
+
+def f0_statistics(f0_features, path):
+    print("\nF0 statistics...")
+
+    total_f0 = []
+    for f0 in tqdm(f0_features):
+        total_f0 += [f for f in f0 if f != 0]
+
+    mean = sum(total_f0) / len(total_f0)
+    print("Min = {}, Max = {}, Mean = {}".format(min(total_f0), max(total_f0), mean))
+
+    with open(path, "wb") as f:
+        pickle.dump([mean, total_f0], f)
+
+
+def world_synthesis(f0, sp, ap, fs, frameshift):
+    y = pw.synthesize(
+        f0, sp, ap, fs, frame_period=frameshift
+    )  # synthesize an utterance using the parameters
+    return y
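+
+
+if __name__ == "__main__":
+    # Illustrative sketch (assumed usage, not part of any documented API):
+    # a WORLD analysis -> SP/MCEP round trip -> synthesis pipeline built from
+    # the helpers above. "example.wav" is a hypothetical placeholder, and its
+    # sample rate must be one handled by get_mcep_params (16000, 44100, 48000).
+    waveform, fs = torchaudio.load("example.wav")
+
+    # (1, seq) mono waveform -> F0, smoothed spectrogram (SP), aperiodicity (AP)
+    f0, sp, ap, fs = extract_world_features(waveform[:1], fs, frameshift=10)
+
+    # SP -> MCEP -> SP round trip (cep_order = mcsize - 1)
+    mcep = sp2mcep(sp, mcsize=60, fs=fs)
+    sp_rec = mcep2sp(mcep, mcsize=60, fs=fs)
+
+    # Resynthesize the utterance from the (reconstructed) WORLD parameters
+    y = world_synthesis(f0, sp_rec, ap, fs, frameshift=10)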