diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..c7d9f3332a950355d5a77d85000f05e6f45435ea --- /dev/null +++ b/.gitattributes @@ -0,0 +1,34 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..8839d8dc8e2b8fa468ebe15f8ae5093d69bcaade --- /dev/null +++ b/README.md @@ -0,0 +1,14 @@ +--- +title: 111 +emoji: 🎶 +colorFrom: yellow +colorTo: green +sdk: gradio +sdk_version: 3.8.1 +app_file: inference/m4singer/gradio/infer.py +pinned: false +duplicated_from: zlc99/M4Singer +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference + diff --git a/checkpoints/m4singer_diff_e2e/config.yaml b/checkpoints/m4singer_diff_e2e/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..846ab237ada9be0be6721b1d7bae78f8e9b66635 --- /dev/null +++ b/checkpoints/m4singer_diff_e2e/config.yaml @@ -0,0 +1,348 @@ +K_step: 1000 +accumulate_grad_batches: 1 +audio_num_mel_bins: 80 +audio_sample_rate: 24000 +base_config: +- usr/configs/m4singer/base.yaml +binarization_args: + shuffle: false + with_align: true + with_f0: true + with_f0cwt: true + with_spk_embed: true + with_txt: true + with_wav: false +binarizer_cls: data_gen.singing.binarize.M4SingerBinarizer +binary_data_dir: data/binary/m4singer +check_val_every_n_epoch: 10 +clip_grad_norm: 1 +content_cond_steps: [] +cwt_add_f0_loss: false +cwt_hidden_size: 128 +cwt_layers: 2 +cwt_loss: l1 +cwt_std_scale: 0.8 +datasets: +- m4singer +debug: false +dec_ffn_kernel_size: 9 +dec_layers: 4 +decay_steps: 100000 +decoder_type: fft +dict_dir: '' +diff_decoder_type: wavenet +diff_loss_type: l1 +dilation_cycle_length: 4 +dropout: 0.1 +ds_workers: 4 +dur_enc_hidden_stride_kernel: +- 0,2,3 +- 0,2,3 +- 0,1,3 +dur_loss: mse +dur_predictor_kernel: 3 +dur_predictor_layers: 5 +enc_ffn_kernel_size: 9 +enc_layers: 4 +encoder_K: 8 +encoder_type: fft +endless_ds: true +ffn_act: gelu +ffn_padding: SAME +fft_size: 512 +fmax: 12000 +fmin: 30 +fs2_ckpt: checkpoints/m4singer_fs2_e2e +gaussian_start: true +gen_dir_name: '' +gen_tgt_spk_id: -1 +hidden_size: 256 +hop_size: 128 +infer: false +keep_bins: 80 +lambda_commit: 0.25 +lambda_energy: 0.0 +lambda_f0: 0.0 +lambda_ph_dur: 1.0 +lambda_sent_dur: 1.0 +lambda_uv: 0.0 +lambda_word_dur: 1.0 +load_ckpt: '' +log_interval: 100 +loud_norm: false +lr: 0.001 +max_beta: 0.02 +max_epochs: 1000 +max_eval_sentences: 1 +max_eval_tokens: 60000 +max_frames: 5000 +max_input_tokens: 1550 +max_sentences: 28 +max_tokens: 36000 +max_updates: 900000 +mel_loss: ssim:0.5|l1:0.5 +mel_vmax: 1.5 +mel_vmin: -6.0 +min_level_db: -120 +norm_type: gn +num_ckpt_keep: 3 +num_heads: 2 +num_sanity_val_steps: 1 +num_spk: 20 +num_test_samples: 0 +num_valid_plots: 10 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.98 +out_wav_norm: false +pe_ckpt: checkpoints/m4singer_pe +pe_enable: true +pitch_ar: false +pitch_enc_hidden_stride_kernel: +- 0,2,5 +- 0,2,5 +- 0,2,5 +pitch_extractor: parselmouth +pitch_loss: l1 +pitch_norm: log +pitch_type: frame +pndm_speedup: 10 +pre_align_args: + allow_no_txt: false + denoise: false + forced_align: mfa + txt_processor: zh_g2pM + use_sox: true + use_tone: false +pre_align_cls: data_gen.singing.pre_align.SingingPreAlign +predictor_dropout: 0.5 +predictor_grad: 0.1 +predictor_hidden: -1 +predictor_kernel: 5 +predictor_layers: 5 +prenet_dropout: 0.5 +prenet_hidden_size: 256 +pretrain_fs_ckpt: '' +processed_data_dir: xxx +profile_infer: false +raw_data_dir: data/raw/m4singer +ref_norm_layer: bn +rel_pos: true +reset_phone_dict: true +residual_channels: 256 +residual_layers: 20 +save_best: false +save_ckpt: true +save_codes: +- configs +- modules +- tasks +- utils +- usr +save_f0: true +save_gt: true +schedule_type: linear +seed: 1234 +sort_by_len: true +spec_max: +- -0.3894500136375427 +- -0.3796464204788208 +- -0.2914905250072479 +- -0.15550297498703003 +- -0.08502643555402756 +- 0.10698417574167252 +- -0.0739326998591423 +- -0.0541548952460289 +- 0.15501998364925385 +- 0.06483431905508041 +- 0.03054228238761425 +- -0.013737732544541359 +- -0.004876468330621719 +- 0.04368264228105545 +- 0.13329921662807465 +- 0.16471388936042786 +- 0.04605761915445328 +- -0.05680707097053528 +- 0.0542571023106575 +- -0.0076539707370102406 +- -0.00953489076346159 +- -0.04434828832745552 +- 0.001293870504014194 +- -0.12238839268684387 +- 0.06418416649103165 +- 0.02843189612030983 +- 0.08505241572856903 +- 0.07062800228595734 +- 0.00120724702719599 +- -0.07675088942050934 +- 0.03785804659128189 +- 0.04890783503651619 +- -0.06888376921415329 +- -0.0839693546295166 +- -0.17545585334300995 +- -0.2911079525947571 +- -0.4238220453262329 +- -0.262084037065506 +- -0.3002263605594635 +- -0.3845032751560211 +- -0.3906497061252594 +- -0.6550108790397644 +- -0.7810799479484558 +- -0.7503029704093933 +- -0.7995198965072632 +- -0.8092347383499146 +- -0.6196113228797913 +- -0.6684317588806152 +- -0.7735874056816101 +- -0.8324533104896545 +- -0.9601566791534424 +- -0.955253541469574 +- -0.748817503452301 +- -0.9106167554855347 +- -0.9707801342010498 +- -1.053107500076294 +- -1.0448424816131592 +- -1.1082794666290283 +- -1.1296544075012207 +- -1.071642279624939 +- -1.1003081798553467 +- -1.166810154914856 +- -1.1408926248550415 +- -1.1330615282058716 +- -1.1167492866516113 +- -1.0716774463653564 +- -1.035891056060791 +- -1.0092483758926392 +- -0.9675999879837036 +- -0.938962996006012 +- -1.0120564699172974 +- -0.9777995347976685 +- -1.029313564300537 +- -0.9459163546562195 +- -0.8519706130027771 +- -0.7751091122627258 +- -0.7933766841888428 +- -0.9019735455513 +- -0.9983296990394592 +- -1.505873441696167 +spec_min: +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +spk_cond_steps: [] +stop_token_weight: 5.0 +task_cls: usr.diffsinger_task.DiffSingerMIDITask +test_ids: [] +test_input_dir: '' +test_num: 0 +test_prefixes: +- "Alto-2#\u5C81\u6708\u795E\u5077" +- "Alto-2#\u5947\u5999\u80FD\u529B\u6B4C" +- "Tenor-1#\u4E00\u5343\u5E74\u4EE5\u540E" +- "Tenor-1#\u7AE5\u8BDD" +- "Tenor-2#\u6D88\u6101" +- "Tenor-2#\u4E00\u8364\u4E00\u7D20" +- "Soprano-1#\u5FF5\u5974\u5A07\u8D64\u58C1\u6000\u53E4" +- "Soprano-1#\u95EE\u6625" +test_set_name: test +timesteps: 1000 +train_set_name: train +use_denoise: false +use_energy_embed: false +use_gt_dur: false +use_gt_f0: false +use_midi: true +use_nsf: true +use_pitch_embed: false +use_pos_embed: true +use_spk_embed: false +use_spk_id: true +use_split_spk_id: false +use_uv: true +use_var_enc: false +val_check_interval: 2000 +valid_num: 0 +valid_set_name: valid +vocoder: vocoders.hifigan.HifiGAN +vocoder_ckpt: checkpoints/m4singer_hifigan +warmup_updates: 2000 +wav2spec_eps: 1e-6 +weight_decay: 0 +win_size: 512 +work_dir: checkpoints/m4singer_diff_e2e diff --git a/checkpoints/m4singer_diff_e2e/model_ckpt_steps_900000.ckpt b/checkpoints/m4singer_diff_e2e/model_ckpt_steps_900000.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..8a224229c8756253c7a24937b29354e34dac0a26 --- /dev/null +++ b/checkpoints/m4singer_diff_e2e/model_ckpt_steps_900000.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dbea4e8b9712d2cca54cc07915859472a17f2f3b97a86f33a6c9974192bb5b47 +size 392239086 diff --git a/checkpoints/m4singer_fs2_e2e/config.yaml b/checkpoints/m4singer_fs2_e2e/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d894bc33c7e4d8eebe4b71b0400c88f7342e55cc --- /dev/null +++ b/checkpoints/m4singer_fs2_e2e/config.yaml @@ -0,0 +1,347 @@ +K_step: 51 +accumulate_grad_batches: 1 +audio_num_mel_bins: 80 +audio_sample_rate: 24000 +base_config: +- configs/singing/fs2.yaml +- usr/configs/m4singer/base.yaml +binarization_args: + shuffle: false + with_align: true + with_f0: true + with_f0cwt: true + with_spk_embed: true + with_txt: true + with_wav: false +binarizer_cls: data_gen.singing.binarize.M4SingerBinarizer +binary_data_dir: data/binary/m4singer +check_val_every_n_epoch: 10 +clip_grad_norm: 1 +content_cond_steps: [] +cwt_add_f0_loss: false +cwt_hidden_size: 128 +cwt_layers: 2 +cwt_loss: l1 +cwt_std_scale: 0.8 +datasets: +- m4singer +debug: false +dec_ffn_kernel_size: 9 +dec_layers: 4 +decay_steps: 50000 +decoder_type: fft +dict_dir: '' +diff_decoder_type: wavenet +diff_loss_type: l1 +dilation_cycle_length: 1 +dropout: 0.1 +ds_workers: 4 +dur_enc_hidden_stride_kernel: +- 0,2,3 +- 0,2,3 +- 0,1,3 +dur_loss: mse +dur_predictor_kernel: 3 +dur_predictor_layers: 5 +enc_ffn_kernel_size: 9 +enc_layers: 4 +encoder_K: 8 +encoder_type: fft +endless_ds: true +ffn_act: gelu +ffn_padding: SAME +fft_size: 512 +fmax: 12000 +fmin: 30 +fs2_ckpt: '' +gen_dir_name: '' +gen_tgt_spk_id: -1 +hidden_size: 256 +hop_size: 128 +infer: false +keep_bins: 80 +lambda_commit: 0.25 +lambda_energy: 0.0 +lambda_f0: 1.0 +lambda_ph_dur: 1.0 +lambda_sent_dur: 1.0 +lambda_uv: 1.0 +lambda_word_dur: 1.0 +load_ckpt: '' +log_interval: 100 +loud_norm: false +lr: 1 +max_beta: 0.06 +max_epochs: 1000 +max_eval_sentences: 1 +max_eval_tokens: 60000 +max_frames: 5000 +max_input_tokens: 1550 +max_sentences: 12 +max_tokens: 40000 +max_updates: 320000 +mel_loss: ssim:0.5|l1:0.5 +mel_vmax: 1.5 +mel_vmin: -6.0 +min_level_db: -120 +norm_type: gn +num_ckpt_keep: 3 +num_heads: 2 +num_sanity_val_steps: 1 +num_spk: 20 +num_test_samples: 0 +num_valid_plots: 10 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.98 +out_wav_norm: false +pe_ckpt: checkpoints/m4singer_pe +pe_enable: true +pitch_ar: false +pitch_enc_hidden_stride_kernel: +- 0,2,5 +- 0,2,5 +- 0,2,5 +pitch_extractor: parselmouth +pitch_loss: l1 +pitch_norm: log +pitch_type: frame +pre_align_args: + allow_no_txt: false + denoise: false + forced_align: mfa + txt_processor: zh_g2pM + use_sox: true + use_tone: false +pre_align_cls: data_gen.singing.pre_align.SingingPreAlign +predictor_dropout: 0.5 +predictor_grad: 0.1 +predictor_hidden: -1 +predictor_kernel: 5 +predictor_layers: 5 +prenet_dropout: 0.5 +prenet_hidden_size: 256 +pretrain_fs_ckpt: '' +processed_data_dir: xxx +profile_infer: false +raw_data_dir: data/raw/m4singer +ref_norm_layer: bn +rel_pos: true +reset_phone_dict: true +residual_channels: 256 +residual_layers: 20 +save_best: false +save_ckpt: true +save_codes: +- configs +- modules +- tasks +- utils +- usr +save_f0: true +save_gt: true +schedule_type: linear +seed: 1234 +sort_by_len: true +spec_max: +- -0.3894500136375427 +- -0.3796464204788208 +- -0.2914905250072479 +- -0.15550297498703003 +- -0.08502643555402756 +- 0.10698417574167252 +- -0.0739326998591423 +- -0.0541548952460289 +- 0.15501998364925385 +- 0.06483431905508041 +- 0.03054228238761425 +- -0.013737732544541359 +- -0.004876468330621719 +- 0.04368264228105545 +- 0.13329921662807465 +- 0.16471388936042786 +- 0.04605761915445328 +- -0.05680707097053528 +- 0.0542571023106575 +- -0.0076539707370102406 +- -0.00953489076346159 +- -0.04434828832745552 +- 0.001293870504014194 +- -0.12238839268684387 +- 0.06418416649103165 +- 0.02843189612030983 +- 0.08505241572856903 +- 0.07062800228595734 +- 0.00120724702719599 +- -0.07675088942050934 +- 0.03785804659128189 +- 0.04890783503651619 +- -0.06888376921415329 +- -0.0839693546295166 +- -0.17545585334300995 +- -0.2911079525947571 +- -0.4238220453262329 +- -0.262084037065506 +- -0.3002263605594635 +- -0.3845032751560211 +- -0.3906497061252594 +- -0.6550108790397644 +- -0.7810799479484558 +- -0.7503029704093933 +- -0.7995198965072632 +- -0.8092347383499146 +- -0.6196113228797913 +- -0.6684317588806152 +- -0.7735874056816101 +- -0.8324533104896545 +- -0.9601566791534424 +- -0.955253541469574 +- -0.748817503452301 +- -0.9106167554855347 +- -0.9707801342010498 +- -1.053107500076294 +- -1.0448424816131592 +- -1.1082794666290283 +- -1.1296544075012207 +- -1.071642279624939 +- -1.1003081798553467 +- -1.166810154914856 +- -1.1408926248550415 +- -1.1330615282058716 +- -1.1167492866516113 +- -1.0716774463653564 +- -1.035891056060791 +- -1.0092483758926392 +- -0.9675999879837036 +- -0.938962996006012 +- -1.0120564699172974 +- -0.9777995347976685 +- -1.029313564300537 +- -0.9459163546562195 +- -0.8519706130027771 +- -0.7751091122627258 +- -0.7933766841888428 +- -0.9019735455513 +- -0.9983296990394592 +- -1.505873441696167 +spec_min: +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +- -6.0 +spk_cond_steps: [] +stop_token_weight: 5.0 +task_cls: usr.diffsinger_task.AuxDecoderMIDITask +test_ids: [] +test_input_dir: '' +test_num: 0 +test_prefixes: +- "Alto-2#\u5C81\u6708\u795E\u5077" +- "Alto-2#\u5947\u5999\u80FD\u529B\u6B4C" +- "Tenor-1#\u4E00\u5343\u5E74\u4EE5\u540E" +- "Tenor-1#\u7AE5\u8BDD" +- "Tenor-2#\u6D88\u6101" +- "Tenor-2#\u4E00\u8364\u4E00\u7D20" +- "Soprano-1#\u5FF5\u5974\u5A07\u8D64\u58C1\u6000\u53E4" +- "Soprano-1#\u95EE\u6625" +test_set_name: test +timesteps: 100 +train_set_name: train +use_denoise: false +use_energy_embed: false +use_gt_dur: false +use_gt_f0: false +use_midi: true +use_nsf: true +use_pitch_embed: false +use_pos_embed: true +use_spk_embed: false +use_spk_id: true +use_split_spk_id: false +use_uv: true +use_var_enc: false +val_check_interval: 2000 +valid_num: 0 +valid_set_name: valid +vocoder: vocoders.hifigan.HifiGAN +vocoder_ckpt: checkpoints/m4singer_hifigan +warmup_updates: 2000 +wav2spec_eps: 1e-6 +weight_decay: 0 +win_size: 512 +work_dir: checkpoints/m4singer_fs2_e2e diff --git a/checkpoints/m4singer_fs2_e2e/model_ckpt_steps_320000.ckpt b/checkpoints/m4singer_fs2_e2e/model_ckpt_steps_320000.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..b10e16d301639ede0d3bee931a73a0e2bcb06a5f --- /dev/null +++ b/checkpoints/m4singer_fs2_e2e/model_ckpt_steps_320000.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:993d7063a1773bd29d2810591f98152218a4cf8440e2b10c4761516a28f9d566 +size 290456153 diff --git a/checkpoints/m4singer_hifigan/config.yaml b/checkpoints/m4singer_hifigan/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2186ec6c74f0fab4de1424868f3548aeabe43212 --- /dev/null +++ b/checkpoints/m4singer_hifigan/config.yaml @@ -0,0 +1,246 @@ +max_eval_tokens: 60000 +max_eval_sentences: 1 +save_ckpt: true +log_interval: 100 +accumulate_grad_batches: 1 +adam_b1: 0.8 +adam_b2: 0.99 +amp: false +audio_num_mel_bins: 80 +audio_sample_rate: 24000 +aux_context_window: 0 +#base_config: +#- egs/egs_bases/singing/pwg.yaml +#- egs/egs_bases/tts/vocoder/hifigan.yaml +binarization_args: + reset_phone_dict: true + reset_word_dict: true + shuffle: false + trim_eos_bos: false + trim_sil: false + with_align: false + with_f0: true + with_f0cwt: false + with_linear: false + with_spk_embed: false + with_spk_id: true + with_txt: false + with_wav: true + with_word: false +binarizer_cls: data_gen.tts.singing.binarize.SingingBinarizer +binary_data_dir: data/binary/m4singer_vocoder +check_val_every_n_epoch: 10 +clip_grad_norm: 1 +clip_grad_value: 0 +datasets: [] +debug: false +dec_ffn_kernel_size: 9 +dec_layers: 4 +dict_dir: '' +disc_start_steps: 40000 +discriminator_grad_norm: 1 +discriminator_optimizer_params: + eps: 1.0e-06 + lr: 0.0002 + weight_decay: 0.0 +discriminator_params: + bias: true + conv_channels: 64 + in_channels: 1 + kernel_size: 3 + layers: 10 + nonlinear_activation: LeakyReLU + nonlinear_activation_params: + negative_slope: 0.2 + out_channels: 1 + use_weight_norm: true +discriminator_scheduler_params: + gamma: 0.999 + step_size: 600 +dropout: 0.1 +ds_workers: 1 +enc_ffn_kernel_size: 9 +enc_layers: 4 +endless_ds: true +ffn_act: gelu +ffn_padding: SAME +fft_size: 512 +fmax: 12000 +fmin: 30 +frames_multiple: 1 +gen_dir_name: '' +generator_grad_norm: 10 +generator_optimizer_params: + eps: 1.0e-06 + lr: 0.0002 + weight_decay: 0.0 +generator_params: + aux_context_window: 0 + aux_channels: 80 + dropout: 0.0 + gate_channels: 128 + in_channels: 1 + kernel_size: 3 + layers: 30 + out_channels: 1 + residual_channels: 64 + skip_channels: 64 + stacks: 3 + upsample_net: ConvInUpsampleNetwork + upsample_params: + upsample_scales: + - 2 + - 4 + - 4 + - 4 + use_nsf: false + use_pitch_embed: true + use_weight_norm: true +generator_scheduler_params: + gamma: 0.999 + step_size: 600 +griffin_lim_iters: 60 +hidden_size: 256 +hop_size: 128 +infer: false +lambda_adv: 1.0 +lambda_cdisc: 4.0 +lambda_energy: 0.0 +lambda_f0: 0.0 +lambda_mel: 5.0 +lambda_mel_adv: 1.0 +lambda_ph_dur: 0.0 +lambda_sent_dur: 0.0 +lambda_uv: 0.0 +lambda_word_dur: 0.0 +load_ckpt: 'checkpoints/m4singer_hifigan' +loud_norm: false +lr: 2.0 +max_epochs: 1000 +max_frames: 2400 +max_input_tokens: 1550 +max_samples: 8192 +max_sentences: 20 +max_tokens: 24000 +max_updates: 3000000 +max_valid_sentences: 1 +max_valid_tokens: 60000 +mel_loss: ssim:0.5|l1:0.5 +mel_vmax: 1.5 +mel_vmin: -6 +min_frames: 0 +min_level_db: -120 +num_ckpt_keep: 3 +num_heads: 2 +num_mels: 80 +num_sanity_val_steps: 5 +num_spk: 100 +num_test_samples: 0 +num_valid_plots: 10 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.98 +out_wav_norm: false +pitch_extractor: parselmouth +pitch_type: frame +pre_align_args: + allow_no_txt: false + denoise: false + sox_resample: true + sox_to_wav: false + trim_sil: false + txt_processor: zh + use_tone: false +pre_align_cls: data_gen.tts.singing.pre_align.SingingPreAlign +predictor_grad: 0.0 +print_nan_grads: false +processed_data_dir: '' +profile_infer: false +raw_data_dir: '' +ref_level_db: 20 +rename_tmux: true +rerun_gen: true +resblock: '1' +resblock_dilation_sizes: +- - 1 + - 3 + - 5 +- - 1 + - 3 + - 5 +- - 1 + - 3 + - 5 +resblock_kernel_sizes: +- 3 +- 7 +- 11 +resume_from_checkpoint: 0 +save_best: true +save_codes: [] +save_f0: true +save_gt: true +scheduler: rsqrt +seed: 1234 +sort_by_len: true +stft_loss_params: + fft_sizes: + - 1024 + - 2048 + - 512 + hop_sizes: + - 120 + - 240 + - 50 + win_lengths: + - 600 + - 1200 + - 240 + window: hann_window +task_cls: tasks.vocoder.hifigan.HifiGanTask +tb_log_interval: 100 +test_ids: [] +test_input_dir: '' +test_num: 50 +test_prefixes: [] +test_set_name: test +train_set_name: train +train_sets: '' +upsample_initial_channel: 512 +upsample_kernel_sizes: +- 16 +- 16 +- 4 +- 4 +upsample_rates: +- 8 +- 4 +- 2 +- 2 +use_cdisc: false +use_cond_disc: false +use_fm_loss: false +use_gt_dur: true +use_gt_f0: true +use_mel_loss: true +use_ms_stft: false +use_pitch_embed: true +use_ref_enc: true +use_spec_disc: false +use_spk_embed: false +use_spk_id: false +use_split_spk_id: false +val_check_interval: 2000 +valid_infer_interval: 10000 +valid_monitor_key: val_loss +valid_monitor_mode: min +valid_set_name: valid +vocoder: pwg +vocoder_ckpt: '' +vocoder_denoise_c: 0.0 +warmup_updates: 8000 +weight_decay: 0 +win_length: null +win_size: 512 +window: hann +word_size: 3000 +work_dir: checkpoints/m4singer_hifigan diff --git a/checkpoints/m4singer_hifigan/model_ckpt_steps_1970000.ckpt b/checkpoints/m4singer_hifigan/model_ckpt_steps_1970000.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..3b82718692236fd16d91241951ff36f47e3e6067 --- /dev/null +++ b/checkpoints/m4singer_hifigan/model_ckpt_steps_1970000.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c3e859bd2b1e125fe661aedfd6fa3e97e10e06f3ec3d03b7735a041984402f89 +size 1016324099 diff --git a/checkpoints/m4singer_pe/config.yaml b/checkpoints/m4singer_pe/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..573bbc903f185f9cb208844c445d8fb5ebd73451 --- /dev/null +++ b/checkpoints/m4singer_pe/config.yaml @@ -0,0 +1,172 @@ +accumulate_grad_batches: 1 +audio_num_mel_bins: 80 +audio_sample_rate: 24000 +base_config: +- configs/tts/lj/fs2.yaml +binarization_args: + shuffle: false + with_align: true + with_f0: true + with_f0cwt: true + with_spk_embed: true + with_txt: true + with_wav: false +binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer +binary_data_dir: data/binary/m4singer +check_val_every_n_epoch: 10 +clip_grad_norm: 1 +cwt_add_f0_loss: false +cwt_hidden_size: 128 +cwt_layers: 2 +cwt_loss: l1 +cwt_std_scale: 0.8 +debug: false +dec_ffn_kernel_size: 9 +dec_layers: 4 +decoder_type: fft +dict_dir: '' +dropout: 0.1 +ds_workers: 4 +dur_enc_hidden_stride_kernel: +- 0,2,3 +- 0,2,3 +- 0,1,3 +dur_loss: mse +dur_predictor_kernel: 3 +dur_predictor_layers: 2 +enc_ffn_kernel_size: 9 +enc_layers: 4 +encoder_K: 8 +encoder_type: fft +endless_ds: true +ffn_act: gelu +ffn_padding: SAME +fft_size: 512 +fmax: 12000 +fmin: 30 +gen_dir_name: '' +hidden_size: 256 +hop_size: 128 +infer: false +lambda_commit: 0.25 +lambda_energy: 0.1 +lambda_f0: 1.0 +lambda_ph_dur: 1.0 +lambda_sent_dur: 1.0 +lambda_uv: 1.0 +lambda_word_dur: 1.0 +load_ckpt: '' +log_interval: 100 +loud_norm: false +lr: 0.1 +max_epochs: 1000 +max_eval_sentences: 1 +max_eval_tokens: 60000 +max_frames: 5000 +max_input_tokens: 1550 +max_sentences: 100000 +max_tokens: 20000 +max_updates: 280000 +mel_loss: l1 +mel_vmax: 1.5 +mel_vmin: -6 +min_level_db: -120 +norm_type: gn +num_ckpt_keep: 3 +num_heads: 2 +num_sanity_val_steps: 5 +num_spk: 1 +num_test_samples: 20 +num_valid_plots: 10 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.98 +out_wav_norm: false +pitch_ar: false +pitch_enc_hidden_stride_kernel: +- 0,2,5 +- 0,2,5 +- 0,2,5 +pitch_extractor_conv_layers: 2 +pitch_loss: l1 +pitch_norm: log +pitch_type: frame +pre_align_args: + allow_no_txt: false + denoise: false + forced_align: mfa + txt_processor: en + use_sox: false + use_tone: true +pre_align_cls: data_gen.tts.lj.pre_align.LJPreAlign +predictor_dropout: 0.5 +predictor_grad: 0.1 +predictor_hidden: -1 +predictor_kernel: 5 +predictor_layers: 2 +prenet_dropout: 0.5 +prenet_hidden_size: 256 +pretrain_fs_ckpt: '' +processed_data_dir: data/processed/ljspeech +profile_infer: false +raw_data_dir: data/raw/LJSpeech-1.1 +ref_norm_layer: bn +reset_phone_dict: true +save_best: false +save_ckpt: true +save_codes: +- configs +- modules +- tasks +- utils +- usr +save_f0: false +save_gt: false +seed: 1234 +sort_by_len: true +stop_token_weight: 5.0 +task_cls: tasks.tts.pe.PitchExtractionTask +test_ids: +- 68 +- 70 +- 74 +- 87 +- 110 +- 172 +- 190 +- 215 +- 231 +- 294 +- 316 +- 324 +- 402 +- 422 +- 485 +- 500 +- 505 +- 508 +- 509 +- 519 +test_input_dir: '' +test_num: 523 +test_set_name: test +train_set_name: train +use_denoise: false +use_energy_embed: false +use_gt_dur: false +use_gt_f0: false +use_pitch_embed: true +use_pos_embed: true +use_spk_embed: false +use_spk_id: false +use_split_spk_id: false +use_uv: true +use_var_enc: false +val_check_interval: 2000 +valid_num: 348 +valid_set_name: valid +vocoder: pwg +vocoder_ckpt: '' +warmup_updates: 2000 +weight_decay: 0 +win_size: 512 +work_dir: checkpoints/m4singer_pe diff --git a/checkpoints/m4singer_pe/model_ckpt_steps_280000.ckpt b/checkpoints/m4singer_pe/model_ckpt_steps_280000.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..cda4c78b2d640f0ad536f7571abdee2caa232fb3 --- /dev/null +++ b/checkpoints/m4singer_pe/model_ckpt_steps_280000.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:10cbf382bf82ecf335fbf68ba226f93c9c715b0476f6604351cbad9783f529fe +size 39146292 diff --git a/configs/config_base.yaml b/configs/config_base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8519ba64caf62da365571b0b3edcace16e54c46c --- /dev/null +++ b/configs/config_base.yaml @@ -0,0 +1,42 @@ +# task +binary_data_dir: '' +work_dir: '' # experiment directory. +infer: false # infer +seed: 1234 +debug: false +save_codes: + - configs + - modules + - tasks + - utils + - usr + +############# +# dataset +############# +ds_workers: 1 +test_num: 100 +valid_num: 100 +endless_ds: false +sort_by_len: true + +######### +# train and eval +######### +load_ckpt: '' +save_ckpt: true +save_best: false +num_ckpt_keep: 3 +clip_grad_norm: 0 +accumulate_grad_batches: 1 +log_interval: 100 +num_sanity_val_steps: 5 # steps of validation at the beginning +check_val_every_n_epoch: 10 +val_check_interval: 2000 +max_epochs: 1000 +max_updates: 160000 +max_tokens: 31250 +max_sentences: 100000 +max_eval_tokens: -1 +max_eval_sentences: -1 +test_input_dir: '' diff --git a/configs/singing/base.yaml b/configs/singing/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d211a1a544d86b1ccc4698a68e0c0dd94a4eec0c --- /dev/null +++ b/configs/singing/base.yaml @@ -0,0 +1,42 @@ +base_config: + - configs/tts/base.yaml + - configs/tts/base_zh.yaml + + +datasets: [] +test_prefixes: [] +test_num: 0 +valid_num: 0 + +pre_align_cls: data_gen.singing.pre_align.SingingPreAlign +binarizer_cls: data_gen.singing.binarize.SingingBinarizer +pre_align_args: + use_tone: false # for ZH + forced_align: mfa + use_sox: true +hop_size: 128 # Hop size. +fft_size: 512 # FFT size. +win_size: 512 # FFT size. +max_frames: 8000 +fmin: 50 # Minimum freq in mel basis calculation. +fmax: 11025 # Maximum frequency in mel basis calculation. +pitch_type: frame + +hidden_size: 256 +mel_loss: "ssim:0.5|l1:0.5" +lambda_f0: 0.0 +lambda_uv: 0.0 +lambda_energy: 0.0 +lambda_ph_dur: 0.0 +lambda_sent_dur: 0.0 +lambda_word_dur: 0.0 +predictor_grad: 0.0 +use_spk_embed: true +use_spk_id: false + +max_tokens: 20000 +max_updates: 400000 +num_spk: 100 +save_f0: true +use_gt_dur: true +use_gt_f0: true diff --git a/configs/singing/fs2.yaml b/configs/singing/fs2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..83fe12394dd07f9baa948c96ea1dcfcfbcaa4bc3 --- /dev/null +++ b/configs/singing/fs2.yaml @@ -0,0 +1,3 @@ +base_config: + - configs/tts/fs2.yaml + - configs/singing/base.yaml diff --git a/configs/tts/base.yaml b/configs/tts/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f6d249ff47479a9693e6dd97ae0cb00aa436d98f --- /dev/null +++ b/configs/tts/base.yaml @@ -0,0 +1,95 @@ +# task +base_config: configs/config_base.yaml +task_cls: '' +############# +# dataset +############# +raw_data_dir: '' +processed_data_dir: '' +binary_data_dir: '' +dict_dir: '' +pre_align_cls: '' +binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer +pre_align_args: + use_tone: true # for ZH + forced_align: mfa + use_sox: false + txt_processor: en + allow_no_txt: false + denoise: false +binarization_args: + shuffle: false + with_txt: true + with_wav: false + with_align: true + with_spk_embed: true + with_f0: true + with_f0cwt: true + +loud_norm: false +endless_ds: true +reset_phone_dict: true + +test_num: 100 +valid_num: 100 +max_frames: 1550 +max_input_tokens: 1550 +audio_num_mel_bins: 80 +audio_sample_rate: 22050 +hop_size: 256 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) +win_size: 1024 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate) +fmin: 80 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) +fmax: 7600 # To be increased/reduced depending on data. +fft_size: 1024 # Extra window size is filled with 0 paddings to match this parameter +min_level_db: -100 +num_spk: 1 +mel_vmin: -6 +mel_vmax: 1.5 +ds_workers: 4 + +######### +# model +######### +dropout: 0.1 +enc_layers: 4 +dec_layers: 4 +hidden_size: 384 +num_heads: 2 +prenet_dropout: 0.5 +prenet_hidden_size: 256 +stop_token_weight: 5.0 +enc_ffn_kernel_size: 9 +dec_ffn_kernel_size: 9 +ffn_act: gelu +ffn_padding: 'SAME' + + +########### +# optimization +########### +lr: 2.0 +warmup_updates: 8000 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.98 +weight_decay: 0 +clip_grad_norm: 1 + + +########### +# train and eval +########### +max_tokens: 30000 +max_sentences: 100000 +max_eval_sentences: 1 +max_eval_tokens: 60000 +train_set_name: 'train' +valid_set_name: 'valid' +test_set_name: 'test' +vocoder: pwg +vocoder_ckpt: '' +profile_infer: false +out_wav_norm: false +save_gt: false +save_f0: false +gen_dir_name: '' +use_denoise: false diff --git a/configs/tts/base_zh.yaml b/configs/tts/base_zh.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a43a9ad399233b5520c9dc7476ff611d3fa6e9bb --- /dev/null +++ b/configs/tts/base_zh.yaml @@ -0,0 +1,3 @@ +pre_align_args: + txt_processor: zh_g2pM +binarizer_cls: data_gen.tts.binarizer_zh.ZhBinarizer \ No newline at end of file diff --git a/configs/tts/fs2.yaml b/configs/tts/fs2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..dd6aa37175542f503d85b2e9f9cbe86238b32700 --- /dev/null +++ b/configs/tts/fs2.yaml @@ -0,0 +1,80 @@ +base_config: configs/tts/base.yaml +task_cls: tasks.tts.fs2.FastSpeech2Task + +# model +hidden_size: 256 +dropout: 0.1 +encoder_type: fft # fft|tacotron|tacotron2|conformer +encoder_K: 8 # for tacotron encoder +decoder_type: fft # fft|rnn|conv|conformer +use_pos_embed: true + +# duration +predictor_hidden: -1 +predictor_kernel: 5 +predictor_layers: 2 +dur_predictor_kernel: 3 +dur_predictor_layers: 2 +predictor_dropout: 0.5 + +# pitch and energy +use_pitch_embed: true +pitch_type: ph # frame|ph|cwt +use_uv: true +cwt_hidden_size: 128 +cwt_layers: 2 +cwt_loss: l1 +cwt_add_f0_loss: false +cwt_std_scale: 0.8 + +pitch_ar: false +#pitch_embed_type: 0q +pitch_loss: 'l1' # l1|l2|ssim +pitch_norm: log +use_energy_embed: false + +# reference encoder and speaker embedding +use_spk_id: false +use_split_spk_id: false +use_spk_embed: false +use_var_enc: false +lambda_commit: 0.25 +ref_norm_layer: bn +pitch_enc_hidden_stride_kernel: + - 0,2,5 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size + - 0,2,5 + - 0,2,5 +dur_enc_hidden_stride_kernel: + - 0,2,3 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size + - 0,2,3 + - 0,1,3 + + +# mel +mel_loss: l1:0.5|ssim:0.5 # l1|l2|gdl|ssim or l1:0.5|ssim:0.5 + +# loss lambda +lambda_f0: 1.0 +lambda_uv: 1.0 +lambda_energy: 0.1 +lambda_ph_dur: 1.0 +lambda_sent_dur: 1.0 +lambda_word_dur: 1.0 +predictor_grad: 0.1 + +# train and eval +pretrain_fs_ckpt: '' +warmup_updates: 2000 +max_tokens: 32000 +max_sentences: 100000 +max_eval_sentences: 1 +max_updates: 120000 +num_valid_plots: 5 +num_test_samples: 0 +test_ids: [] +use_gt_dur: false +use_gt_f0: false + +# exp +dur_loss: mse # huber|mol +norm_type: gn \ No newline at end of file diff --git a/configs/tts/hifigan.yaml b/configs/tts/hifigan.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5ecd7390bcbd3c34a8303057c40a95bc841dd9db --- /dev/null +++ b/configs/tts/hifigan.yaml @@ -0,0 +1,21 @@ +base_config: configs/tts/pwg.yaml +task_cls: tasks.vocoder.hifigan.HifiGanTask +resblock: "1" +adam_b1: 0.8 +adam_b2: 0.99 +upsample_rates: [ 8,8,2,2 ] +upsample_kernel_sizes: [ 16,16,4,4 ] +upsample_initial_channel: 128 +resblock_kernel_sizes: [ 3,7,11 ] +resblock_dilation_sizes: [ [ 1,3,5 ], [ 1,3,5 ], [ 1,3,5 ] ] + +lambda_mel: 45.0 + +max_samples: 8192 +max_sentences: 16 + +generator_params: + lr: 0.0002 # Generator's learning rate. + aux_context_window: 0 # Context window size for auxiliary feature. +discriminator_optimizer_params: + lr: 0.0002 # Discriminator's learning rate. \ No newline at end of file diff --git a/configs/tts/lj/base_mel2wav.yaml b/configs/tts/lj/base_mel2wav.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f99da439f539cee27f29d60bfbe2edcc06858bac --- /dev/null +++ b/configs/tts/lj/base_mel2wav.yaml @@ -0,0 +1,3 @@ +raw_data_dir: 'data/raw/LJSpeech-1.1' +processed_data_dir: 'data/processed/ljspeech' +binary_data_dir: 'data/binary/ljspeech_wav' diff --git a/configs/tts/lj/base_text2mel.yaml b/configs/tts/lj/base_text2mel.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0e0f3ce585226bb0f44533281033098d174ca7fb --- /dev/null +++ b/configs/tts/lj/base_text2mel.yaml @@ -0,0 +1,13 @@ +raw_data_dir: 'data/raw/LJSpeech-1.1' +processed_data_dir: 'data/processed/ljspeech' +binary_data_dir: 'data/binary/ljspeech' +pre_align_cls: data_gen.tts.lj.pre_align.LJPreAlign + +pitch_type: cwt +mel_loss: l1 +num_test_samples: 20 +test_ids: [ 68, 70, 74, 87, 110, 172, 190, 215, 231, 294, + 316, 324, 402, 422, 485, 500, 505, 508, 509, 519 ] +use_energy_embed: false +test_num: 523 +valid_num: 348 \ No newline at end of file diff --git a/configs/tts/lj/fs2.yaml b/configs/tts/lj/fs2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc894b16c094b3d94771a0ad73566c28381c61e4 --- /dev/null +++ b/configs/tts/lj/fs2.yaml @@ -0,0 +1,3 @@ +base_config: + - configs/tts/fs2.yaml + - configs/tts/lj/base_text2mel.yaml \ No newline at end of file diff --git a/configs/tts/lj/hifigan.yaml b/configs/tts/lj/hifigan.yaml new file mode 100644 index 0000000000000000000000000000000000000000..68354a46f3c38b420283b12bc7d62761bfd38d5e --- /dev/null +++ b/configs/tts/lj/hifigan.yaml @@ -0,0 +1,3 @@ +base_config: + - configs/tts/hifigan.yaml + - configs/tts/lj/base_mel2wav.yaml \ No newline at end of file diff --git a/configs/tts/lj/pwg.yaml b/configs/tts/lj/pwg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..72d539eaabd8035b8e38a0ded5d579efa8d42520 --- /dev/null +++ b/configs/tts/lj/pwg.yaml @@ -0,0 +1,3 @@ +base_config: + - configs/tts/pwg.yaml + - configs/tts/lj/base_mel2wav.yaml \ No newline at end of file diff --git a/configs/tts/pwg.yaml b/configs/tts/pwg.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bb443d584b64223ac467417c12f62f6f403fa52f --- /dev/null +++ b/configs/tts/pwg.yaml @@ -0,0 +1,110 @@ +base_config: configs/tts/base.yaml +task_cls: tasks.vocoder.pwg.PwgTask + +binarization_args: + with_wav: true + with_spk_embed: false + with_align: false +test_input_dir: '' + +########### +# train and eval +########### +max_samples: 25600 +max_sentences: 5 +max_eval_sentences: 1 +max_updates: 1000000 +val_check_interval: 2000 + + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +sampling_rate: 22050 # Sampling rate. +fft_size: 1024 # FFT size. +hop_size: 256 # Hop size. +win_length: null # Window length. +# If set to null, it will be the same as fft_size. +window: "hann" # Window function. +num_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. +fmax: 7600 # Maximum frequency in mel basis calculation. +format: "hdf5" # Feature file format. "npy" or "hdf5" is supported. + +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_size: 3 # Kernel size of dilated convolution. + layers: 30 # Number of residual block layers. + stacks: 3 # Number of stacks i.e., dilation cycles. + residual_channels: 64 # Number of channels in residual conv. + gate_channels: 128 # Number of channels in gated conv. + skip_channels: 64 # Number of channels in skip conv. + aux_channels: 80 # Number of channels for auxiliary feature conv. + # Must be the same as num_mels. + aux_context_window: 2 # Context window size for auxiliary feature. + # If set to 2, previous 2 and future 2 frames will be considered. + dropout: 0.0 # Dropout rate. 0.0 means no dropout applied. + use_weight_norm: true # Whether to use weight norm. + # If set to true, it will be applied to all of the conv layers. + upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture. + upsample_params: # Upsampling network parameters. + upsample_scales: [4, 4, 4, 4] # Upsampling scales. Prodcut of these must be the same as hop size. + use_pitch_embed: false + +########################################################### +# DISCRIMINATOR NETWORK ARCHITECTURE SETTING # +########################################################### +discriminator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_size: 3 # Number of output channels. + layers: 10 # Number of conv layers. + conv_channels: 64 # Number of chnn layers. + bias: true # Whether to use bias parameter in conv. + use_weight_norm: true # Whether to use weight norm. + # If set to true, it will be applied to all of the conv layers. + nonlinear_activation: "LeakyReLU" # Nonlinear function after each conv. + nonlinear_activation_params: # Nonlinear function parameters + negative_slope: 0.2 # Alpha in LeakyReLU. + +########################################################### +# STFT LOSS SETTING # +########################################################### +stft_loss_params: + fft_sizes: [1024, 2048, 512] # List of FFT size for STFT-based loss. + hop_sizes: [120, 240, 50] # List of hop size for STFT-based loss + win_lengths: [600, 1200, 240] # List of window length for STFT-based loss. + window: "hann_window" # Window function for STFT-based loss +use_mel_loss: false + +########################################################### +# ADVERSARIAL LOSS SETTING # +########################################################### +lambda_adv: 4.0 # Loss balancing coefficient. + +########################################################### +# OPTIMIZER & SCHEDULER SETTING # +########################################################### +generator_optimizer_params: + lr: 0.0001 # Generator's learning rate. + eps: 1.0e-6 # Generator's epsilon. + weight_decay: 0.0 # Generator's weight decay coefficient. +generator_scheduler_params: + step_size: 200000 # Generator's scheduler step size. + gamma: 0.5 # Generator's scheduler gamma. + # At each step size, lr will be multiplied by this parameter. +generator_grad_norm: 10 # Generator's gradient norm. +discriminator_optimizer_params: + lr: 0.00005 # Discriminator's learning rate. + eps: 1.0e-6 # Discriminator's epsilon. + weight_decay: 0.0 # Discriminator's weight decay coefficient. +discriminator_scheduler_params: + step_size: 200000 # Discriminator's scheduler step size. + gamma: 0.5 # Discriminator's scheduler gamma. + # At each step size, lr will be multiplied by this parameter. +discriminator_grad_norm: 1 # Discriminator's gradient norm. +disc_start_steps: 40000 # Number of steps to start to train discriminator. diff --git a/data_gen/singing/binarize.py b/data_gen/singing/binarize.py new file mode 100644 index 0000000000000000000000000000000000000000..533a145046d087602091f75bc97d516850d626f8 --- /dev/null +++ b/data_gen/singing/binarize.py @@ -0,0 +1,393 @@ +import os +import random +from copy import deepcopy +import pandas as pd +import logging +from tqdm import tqdm +import json +import glob +import re +from resemblyzer import VoiceEncoder +import traceback +import numpy as np +import pretty_midi +import librosa +from scipy.interpolate import interp1d +import torch +from textgrid import TextGrid + +from utils.hparams import hparams +from data_gen.tts.data_gen_utils import build_phone_encoder, get_pitch +from utils.pitch_utils import f0_to_coarse +from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError +from data_gen.tts.binarizer_zh import ZhBinarizer +from data_gen.tts.txt_processors.zh_g2pM import ALL_YUNMU +from vocoders.base_vocoder import VOCODERS + + +class SingingBinarizer(BaseBinarizer): + def __init__(self, processed_data_dir=None): + if processed_data_dir is None: + processed_data_dir = hparams['processed_data_dir'] + self.processed_data_dirs = processed_data_dir.split(",") + self.binarization_args = hparams['binarization_args'] + self.pre_align_args = hparams['pre_align_args'] + self.item2txt = {} + self.item2ph = {} + self.item2wavfn = {} + self.item2f0fn = {} + self.item2tgfn = {} + self.item2spk = {} + + def split_train_test_set(self, item_names): + item_names = deepcopy(item_names) + test_item_names = [x for x in item_names if any([ts in x for ts in hparams['test_prefixes']])] + train_item_names = [x for x in item_names if x not in set(test_item_names)] + logging.info("train {}".format(len(train_item_names))) + logging.info("test {}".format(len(test_item_names))) + return train_item_names, test_item_names + + def load_meta_data(self): + for ds_id, processed_data_dir in enumerate(self.processed_data_dirs): + wav_suffix = '_wf0.wav' + txt_suffix = '.txt' + ph_suffix = '_ph.txt' + tg_suffix = '.TextGrid' + all_wav_pieces = glob.glob(f'{processed_data_dir}/*/*{wav_suffix}') + + for piece_path in all_wav_pieces: + item_name = raw_item_name = piece_path[len(processed_data_dir)+1:].replace('/', '-')[:-len(wav_suffix)] + if len(self.processed_data_dirs) > 1: + item_name = f'ds{ds_id}_{item_name}' + self.item2txt[item_name] = open(f'{piece_path.replace(wav_suffix, txt_suffix)}').readline() + self.item2ph[item_name] = open(f'{piece_path.replace(wav_suffix, ph_suffix)}').readline() + self.item2wavfn[item_name] = piece_path + + self.item2spk[item_name] = re.split('-|#', piece_path.split('/')[-2])[0] + if len(self.processed_data_dirs) > 1: + self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}" + self.item2tgfn[item_name] = piece_path.replace(wav_suffix, tg_suffix) + print('spkers: ', set(self.item2spk.values())) + self.item_names = sorted(list(self.item2txt.keys())) + if self.binarization_args['shuffle']: + random.seed(1234) + random.shuffle(self.item_names) + self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names) + + @property + def train_item_names(self): + return self._train_item_names + + @property + def valid_item_names(self): + return self._test_item_names + + @property + def test_item_names(self): + return self._test_item_names + + def process(self): + self.load_meta_data() + os.makedirs(hparams['binary_data_dir'], exist_ok=True) + self.spk_map = self.build_spk_map() + print("| spk_map: ", self.spk_map) + spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json" + json.dump(self.spk_map, open(spk_map_fn, 'w')) + + self.phone_encoder = self._phone_encoder() + self.process_data('valid') + self.process_data('test') + self.process_data('train') + + def _phone_encoder(self): + ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json" + ph_set = [] + if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn): + for ph_sent in self.item2ph.values(): + ph_set += ph_sent.split(' ') + ph_set = sorted(set(ph_set)) + json.dump(ph_set, open(ph_set_fn, 'w')) + print("| Build phone set: ", ph_set) + else: + ph_set = json.load(open(ph_set_fn, 'r')) + print("| Load phone set: ", ph_set) + return build_phone_encoder(hparams['binary_data_dir']) + + # @staticmethod + # def get_pitch(wav_fn, spec, res): + # wav_suffix = '_wf0.wav' + # f0_suffix = '_f0.npy' + # f0fn = wav_fn.replace(wav_suffix, f0_suffix) + # pitch_info = np.load(f0fn) + # f0 = [x[1] for x in pitch_info] + # spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)] + # f0_x_coor = np.arange(0, 1, 1 / len(f0))[:len(f0)] + # f0 = interp1d(f0_x_coor, f0, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)] + # # f0_x_coor = np.arange(0, 1, 1 / len(f0)) + # # f0_x_coor[-1] = 1 + # # f0 = interp1d(f0_x_coor, f0, 'nearest')(spec_x_coor)[:len(spec)] + # if sum(f0) == 0: + # raise BinarizationError("Empty f0") + # assert len(f0) == len(spec), (len(f0), len(spec)) + # pitch_coarse = f0_to_coarse(f0) + # + # # vis f0 + # # import matplotlib.pyplot as plt + # # from textgrid import TextGrid + # # tg_fn = wav_fn.replace(wav_suffix, '.TextGrid') + # # fig = plt.figure(figsize=(12, 6)) + # # plt.pcolor(spec.T, vmin=-5, vmax=0) + # # ax = plt.gca() + # # ax2 = ax.twinx() + # # ax2.plot(f0, color='red') + # # ax2.set_ylim(0, 800) + # # itvs = TextGrid.fromFile(tg_fn)[0] + # # for itv in itvs: + # # x = itv.maxTime * hparams['audio_sample_rate'] / hparams['hop_size'] + # # plt.vlines(x=x, ymin=0, ymax=80, color='black') + # # plt.text(x=x, y=20, s=itv.mark, color='black') + # # plt.savefig('tmp/20211229_singing_plots_test.png') + # + # res['f0'] = f0 + # res['pitch'] = pitch_coarse + + @classmethod + def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args): + if hparams['vocoder'] in VOCODERS: + wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn) + else: + wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn) + res = { + 'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn, + 'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id + } + try: + if binarization_args['with_f0']: + # cls.get_pitch(wav_fn, mel, res) + cls.get_pitch(wav, mel, res) + if binarization_args['with_txt']: + try: + # print(ph) + phone_encoded = res['phone'] = encoder.encode(ph) + except: + traceback.print_exc() + raise BinarizationError(f"Empty phoneme") + if binarization_args['with_align']: + cls.get_align(tg_fn, ph, mel, phone_encoded, res) + except BinarizationError as e: + print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}") + return None + return res + + +class MidiSingingBinarizer(SingingBinarizer): + item2midi = {} + item2midi_dur = {} + item2is_slur = {} + item2ph_durs = {} + item2wdb = {} + + def load_meta_data(self): + for ds_id, processed_data_dir in enumerate(self.processed_data_dirs): + meta_midi = json.load(open(os.path.join(processed_data_dir, 'meta.json'))) # [list of dict] + + for song_item in meta_midi: + item_name = raw_item_name = song_item['item_name'] + if len(self.processed_data_dirs) > 1: + item_name = f'ds{ds_id}_{item_name}' + self.item2wavfn[item_name] = song_item['wav_fn'] + self.item2txt[item_name] = song_item['txt'] + + self.item2ph[item_name] = ' '.join(song_item['phs']) + self.item2wdb[item_name] = [1 if x in ALL_YUNMU + ['AP', 'SP', ''] else 0 for x in song_item['phs']] + self.item2ph_durs[item_name] = song_item['ph_dur'] + + self.item2midi[item_name] = song_item['notes'] + self.item2midi_dur[item_name] = song_item['notes_dur'] + self.item2is_slur[item_name] = song_item['is_slur'] + self.item2spk[item_name] = 'pop-cs' + if len(self.processed_data_dirs) > 1: + self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}" + + print('spkers: ', set(self.item2spk.values())) + self.item_names = sorted(list(self.item2txt.keys())) + if self.binarization_args['shuffle']: + random.seed(1234) + random.shuffle(self.item_names) + self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names) + + @staticmethod + def get_pitch(wav_fn, wav, spec, ph, res): + wav_suffix = '.wav' + # midi_suffix = '.mid' + wav_dir = 'wavs' + f0_dir = 'f0' + + item_name = '/'.join(os.path.splitext(wav_fn)[0].split('/')[-2:]).replace('_wf0', '') + res['pitch_midi'] = np.asarray(MidiSingingBinarizer.item2midi[item_name]) + res['midi_dur'] = np.asarray(MidiSingingBinarizer.item2midi_dur[item_name]) + res['is_slur'] = np.asarray(MidiSingingBinarizer.item2is_slur[item_name]) + res['word_boundary'] = np.asarray(MidiSingingBinarizer.item2wdb[item_name]) + assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, ( + res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape) + + # gt f0. + gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams) + if sum(gt_f0) == 0: + raise BinarizationError("Empty **gt** f0") + res['f0'] = gt_f0 + res['pitch'] = gt_pitch_coarse + + @staticmethod + def get_align(ph_durs, mel, phone_encoded, res, hop_size=hparams['hop_size'], audio_sample_rate=hparams['audio_sample_rate']): + mel2ph = np.zeros([mel.shape[0]], int) + startTime = 0 + + for i_ph in range(len(ph_durs)): + start_frame = int(startTime * audio_sample_rate / hop_size + 0.5) + end_frame = int((startTime + ph_durs[i_ph]) * audio_sample_rate / hop_size + 0.5) + mel2ph[start_frame:end_frame] = i_ph + 1 + startTime = startTime + ph_durs[i_ph] + + # print('ph durs: ', ph_durs) + # print('mel2ph: ', mel2ph, len(mel2ph)) + res['mel2ph'] = mel2ph + # res['dur'] = None + + @classmethod + def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args): + if hparams['vocoder'] in VOCODERS: + wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn) + else: + wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn) + res = { + 'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn, + 'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id + } + try: + if binarization_args['with_f0']: + cls.get_pitch(wav_fn, wav, mel, ph, res) + if binarization_args['with_txt']: + try: + phone_encoded = res['phone'] = encoder.encode(ph) + except: + traceback.print_exc() + raise BinarizationError(f"Empty phoneme") + if binarization_args['with_align']: + cls.get_align(MidiSingingBinarizer.item2ph_durs[item_name], mel, phone_encoded, res) + except BinarizationError as e: + print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}") + return None + return res + + +class ZhSingingBinarizer(ZhBinarizer, SingingBinarizer): + pass + +class M4SingerBinarizer(MidiSingingBinarizer): + item2midi = {} + item2midi_dur = {} + item2is_slur = {} + item2ph_durs = {} + item2wdb = {} + + def split_train_test_set(self, item_names): + item_names = deepcopy(item_names) + test_item_names = [x for x in item_names if any([x.startswith(ts) for ts in hparams['test_prefixes']])] + train_item_names = [x for x in item_names if x not in set(test_item_names)] + logging.info("train {}".format(len(train_item_names))) + logging.info("test {}".format(len(test_item_names))) + return train_item_names, test_item_names + + def load_meta_data(self): + raw_data_dir = hparams['raw_data_dir'] + song_items = json.load(open(os.path.join(raw_data_dir, 'meta.json'))) # [list of dict] + for song_item in song_items: + item_name = raw_item_name = song_item['item_name'] + singer, song_name, sent_id = item_name.split("#") + self.item2wavfn[item_name] = f'{raw_data_dir}/{singer}#{song_name}/{sent_id}.wav' + self.item2txt[item_name] = song_item['txt'] + + self.item2ph[item_name] = ' '.join(song_item['phs']) + self.item2ph_durs[item_name] = song_item['ph_dur'] + + self.item2midi[item_name] = song_item['notes'] + self.item2midi_dur[item_name] = song_item['notes_dur'] + self.item2is_slur[item_name] = song_item['is_slur'] + self.item2wdb[item_name] = [1 if (0 < i < len(song_item['phs']) - 1 and p in ALL_YUNMU + ['', ''])\ + or i == len(song_item['phs']) - 1 else 0 for i, p in enumerate(song_item['phs'])] + self.item2spk[item_name] = singer + + print('spkers: ', set(self.item2spk.values())) + self.item_names = sorted(list(self.item2txt.keys())) + if self.binarization_args['shuffle']: + random.seed(1234) + random.shuffle(self.item_names) + self._train_item_names, self._test_item_names = self.split_train_test_set(self.item_names) + + @staticmethod + def get_pitch(item_name, wav, spec, ph, res): + wav_suffix = '.wav' + # midi_suffix = '.mid' + wav_dir = 'wavs' + f0_dir = 'text_f0_align' + + #item_name = os.path.splitext(os.path.basename(wav_fn))[0] + res['pitch_midi'] = np.asarray(M4SingerBinarizer.item2midi[item_name]) + res['midi_dur'] = np.asarray(M4SingerBinarizer.item2midi_dur[item_name]) + res['is_slur'] = np.asarray(M4SingerBinarizer.item2is_slur[item_name]) + res['word_boundary'] = np.asarray(M4SingerBinarizer.item2wdb[item_name]) + assert res['pitch_midi'].shape == res['midi_dur'].shape == res['is_slur'].shape, (res['pitch_midi'].shape, res['midi_dur'].shape, res['is_slur'].shape) + + # gt f0. + # f0 = None + # f0_suffix = '_f0.npy' + # f0fn = wav_fn.replace(wav_suffix, f0_suffix).replace(wav_dir, f0_dir) + # pitch_info = np.load(f0fn) + # f0 = [x[1] for x in pitch_info] + # spec_x_coor = np.arange(0, 1, 1 / len(spec))[:len(spec)] + # + # f0_x_coor = np.arange(0, 1, 1 / len(f0))[:len(f0)] + # f0 = interp1d(f0_x_coor, f0, 'nearest', fill_value='extrapolate')(spec_x_coor)[:len(spec)] + # if sum(f0) == 0: + # raise BinarizationError("Empty **gt** f0") + # + # pitch_coarse = f0_to_coarse(f0) + # res['f0'] = f0 + # res['pitch'] = pitch_coarse + + # gt f0. + gt_f0, gt_pitch_coarse = get_pitch(wav, spec, hparams) + if sum(gt_f0) == 0: + raise BinarizationError("Empty **gt** f0") + res['f0'] = gt_f0 + res['pitch'] = gt_pitch_coarse + + @classmethod + def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args): + if hparams['vocoder'] in VOCODERS: + wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn) + else: + wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn) + res = { + 'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn, + 'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id + } + try: + if binarization_args['with_f0']: + cls.get_pitch(item_name, wav, mel, ph, res) + if binarization_args['with_txt']: + try: + phone_encoded = res['phone'] = encoder.encode(ph) + except: + traceback.print_exc() + raise BinarizationError(f"Empty phoneme") + if binarization_args['with_align']: + cls.get_align(M4SingerBinarizer.item2ph_durs[item_name], mel, phone_encoded, res) + except BinarizationError as e: + print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}") + return None + return res + +if __name__ == "__main__": + SingingBinarizer().process() diff --git a/data_gen/tts/base_binarizer.py b/data_gen/tts/base_binarizer.py new file mode 100644 index 0000000000000000000000000000000000000000..da3c8e1e6e3854e74032e41a2ab23dddb9fb5c16 --- /dev/null +++ b/data_gen/tts/base_binarizer.py @@ -0,0 +1,224 @@ +import os +os.environ["OMP_NUM_THREADS"] = "1" + +from utils.multiprocess_utils import chunked_multiprocess_run +import random +import traceback +import json +from resemblyzer import VoiceEncoder +from tqdm import tqdm +from data_gen.tts.data_gen_utils import get_mel2ph, get_pitch, build_phone_encoder +from utils.hparams import set_hparams, hparams +import numpy as np +from utils.indexed_datasets import IndexedDatasetBuilder +from vocoders.base_vocoder import VOCODERS +import pandas as pd + + +class BinarizationError(Exception): + pass + + +class BaseBinarizer: + def __init__(self, processed_data_dir=None): + if processed_data_dir is None: + processed_data_dir = hparams['processed_data_dir'] + self.processed_data_dirs = processed_data_dir.split(",") + self.binarization_args = hparams['binarization_args'] + self.pre_align_args = hparams['pre_align_args'] + self.forced_align = self.pre_align_args['forced_align'] + tg_dir = None + if self.forced_align == 'mfa': + tg_dir = 'mfa_outputs' + if self.forced_align == 'kaldi': + tg_dir = 'kaldi_outputs' + self.item2txt = {} + self.item2ph = {} + self.item2wavfn = {} + self.item2tgfn = {} + self.item2spk = {} + for ds_id, processed_data_dir in enumerate(self.processed_data_dirs): + self.meta_df = pd.read_csv(f"{processed_data_dir}/metadata_phone.csv", dtype=str) + for r_idx, r in self.meta_df.iterrows(): + item_name = raw_item_name = r['item_name'] + if len(self.processed_data_dirs) > 1: + item_name = f'ds{ds_id}_{item_name}' + self.item2txt[item_name] = r['txt'] + self.item2ph[item_name] = r['ph'] + self.item2wavfn[item_name] = os.path.join(hparams['raw_data_dir'], 'wavs', os.path.basename(r['wav_fn']).split('_')[1]) + self.item2spk[item_name] = r.get('spk', 'SPK1') + if len(self.processed_data_dirs) > 1: + self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}" + if tg_dir is not None: + self.item2tgfn[item_name] = f"{processed_data_dir}/{tg_dir}/{raw_item_name}.TextGrid" + self.item_names = sorted(list(self.item2txt.keys())) + if self.binarization_args['shuffle']: + random.seed(1234) + random.shuffle(self.item_names) + + @property + def train_item_names(self): + return self.item_names[hparams['test_num']+hparams['valid_num']:] + + @property + def valid_item_names(self): + return self.item_names[0: hparams['test_num']+hparams['valid_num']] # + + @property + def test_item_names(self): + return self.item_names[0: hparams['test_num']] # Audios for MOS testing are in 'test_ids' + + def build_spk_map(self): + spk_map = set() + for item_name in self.item_names: + spk_name = self.item2spk[item_name] + spk_map.add(spk_name) + spk_map = {x: i for i, x in enumerate(sorted(list(spk_map)))} + assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map) + return spk_map + + def item_name2spk_id(self, item_name): + return self.spk_map[self.item2spk[item_name]] + + def _phone_encoder(self): + ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json" + ph_set = [] + if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn): + for processed_data_dir in self.processed_data_dirs: + ph_set += [x.split(' ')[0] for x in open(f'{processed_data_dir}/dict.txt').readlines()] + ph_set = sorted(set(ph_set)) + json.dump(ph_set, open(ph_set_fn, 'w')) + else: + ph_set = json.load(open(ph_set_fn, 'r')) + print("| phone set: ", ph_set) + return build_phone_encoder(hparams['binary_data_dir']) + + def meta_data(self, prefix): + if prefix == 'valid': + item_names = self.valid_item_names + elif prefix == 'test': + item_names = self.test_item_names + else: + item_names = self.train_item_names + for item_name in item_names: + ph = self.item2ph[item_name] + txt = self.item2txt[item_name] + tg_fn = self.item2tgfn.get(item_name) + wav_fn = self.item2wavfn[item_name] + spk_id = self.item_name2spk_id(item_name) + yield item_name, ph, txt, tg_fn, wav_fn, spk_id + + def process(self): + os.makedirs(hparams['binary_data_dir'], exist_ok=True) + self.spk_map = self.build_spk_map() + print("| spk_map: ", self.spk_map) + spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json" + json.dump(self.spk_map, open(spk_map_fn, 'w')) + + self.phone_encoder = self._phone_encoder() + self.process_data('valid') + self.process_data('test') + self.process_data('train') + + def process_data(self, prefix): + data_dir = hparams['binary_data_dir'] + args = [] + builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}') + lengths = [] + f0s = [] + total_sec = 0 + if self.binarization_args['with_spk_embed']: + voice_encoder = VoiceEncoder().cuda() + + meta_data = list(self.meta_data(prefix)) + for m in meta_data: + args.append(list(m) + [self.phone_encoder, self.binarization_args]) + num_workers = int(os.getenv('N_PROC', os.cpu_count() // 3)) + for f_id, (_, item) in enumerate( + zip(tqdm(meta_data), chunked_multiprocess_run(self.process_item, args, num_workers=num_workers))): + if item is None: + continue + item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \ + if self.binarization_args['with_spk_embed'] else None + if not self.binarization_args['with_wav'] and 'wav' in item: + #print("del wav") + del item['wav'] + builder.add_item(item) + lengths.append(item['len']) + total_sec += item['sec'] + if item.get('f0') is not None: + f0s.append(item['f0']) + builder.finalize() + np.save(f'{data_dir}/{prefix}_lengths.npy', lengths) + if len(f0s) > 0: + f0s = np.concatenate(f0s, 0) + f0s = f0s[f0s != 0] + np.save(f'{data_dir}/{prefix}_f0s_mean_std.npy', [np.mean(f0s).item(), np.std(f0s).item()]) + print(f"| {prefix} total duration: {total_sec:.3f}s") + + @classmethod + def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args): + if hparams['vocoder'] in VOCODERS: + wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn) + else: + wav, mel = VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn) + res = { + 'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn, + 'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id + } + try: + if binarization_args['with_f0']: + cls.get_pitch(wav, mel, res) + if binarization_args['with_f0cwt']: + cls.get_f0cwt(res['f0'], res) + if binarization_args['with_txt']: + try: + phone_encoded = res['phone'] = encoder.encode(ph) + except: + traceback.print_exc() + raise BinarizationError(f"Empty phoneme") + if binarization_args['with_align']: + cls.get_align(tg_fn, ph, mel, phone_encoded, res) + except BinarizationError as e: + print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}") + return None + return res + + @staticmethod + def get_align(tg_fn, ph, mel, phone_encoded, res): + if tg_fn is not None and os.path.exists(tg_fn): + mel2ph, dur = get_mel2ph(tg_fn, ph, mel, hparams) + else: + raise BinarizationError(f"Align not found") + if mel2ph.max() - 1 >= len(phone_encoded): + raise BinarizationError( + f"Align does not match: mel2ph.max() - 1: {mel2ph.max() - 1}, len(phone_encoded): {len(phone_encoded)}") + res['mel2ph'] = mel2ph + res['dur'] = dur + + @staticmethod + def get_pitch(wav, mel, res): + f0, pitch_coarse = get_pitch(wav, mel, hparams) + if sum(f0) == 0: + raise BinarizationError("Empty f0") + res['f0'] = f0 + res['pitch'] = pitch_coarse + + @staticmethod + def get_f0cwt(f0, res): + from utils.cwt import get_cont_lf0, get_lf0_cwt + uv, cont_lf0_lpf = get_cont_lf0(f0) + logf0s_mean_org, logf0s_std_org = np.mean(cont_lf0_lpf), np.std(cont_lf0_lpf) + cont_lf0_lpf_norm = (cont_lf0_lpf - logf0s_mean_org) / logf0s_std_org + Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm) + if np.any(np.isnan(Wavelet_lf0)): + raise BinarizationError("NaN CWT") + res['cwt_spec'] = Wavelet_lf0 + res['cwt_scales'] = scales + res['f0_mean'] = logf0s_mean_org + res['f0_std'] = logf0s_std_org + + +if __name__ == "__main__": + set_hparams() + BaseBinarizer().process() diff --git a/data_gen/tts/bin/binarize.py b/data_gen/tts/bin/binarize.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd3c1f69fa59ed52fdd32eb80e746dedbae7535 --- /dev/null +++ b/data_gen/tts/bin/binarize.py @@ -0,0 +1,20 @@ +import os + +os.environ["OMP_NUM_THREADS"] = "1" + +import importlib +from utils.hparams import set_hparams, hparams + + +def binarize(): + binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer') + pkg = ".".join(binarizer_cls.split(".")[:-1]) + cls_name = binarizer_cls.split(".")[-1] + binarizer_cls = getattr(importlib.import_module(pkg), cls_name) + print("| Binarizer: ", binarizer_cls) + binarizer_cls().process() + + +if __name__ == '__main__': + set_hparams() + binarize() diff --git a/data_gen/tts/binarizer_zh.py b/data_gen/tts/binarizer_zh.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd424a1a669ecf1a74cb69d690a690d0d39fe55 --- /dev/null +++ b/data_gen/tts/binarizer_zh.py @@ -0,0 +1,59 @@ +import os + +os.environ["OMP_NUM_THREADS"] = "1" + +from data_gen.tts.txt_processors.zh_g2pM import ALL_SHENMU +from data_gen.tts.base_binarizer import BaseBinarizer, BinarizationError +from data_gen.tts.data_gen_utils import get_mel2ph +from utils.hparams import set_hparams, hparams +import numpy as np + + +class ZhBinarizer(BaseBinarizer): + @staticmethod + def get_align(tg_fn, ph, mel, phone_encoded, res): + if tg_fn is not None and os.path.exists(tg_fn): + _, dur = get_mel2ph(tg_fn, ph, mel, hparams) + else: + raise BinarizationError(f"Align not found") + ph_list = ph.split(" ") + assert len(dur) == len(ph_list) + mel2ph = [] + # 分隔符的时长分配给韵母 + dur_cumsum = np.pad(np.cumsum(dur), [1, 0], mode='constant', constant_values=0) + for i in range(len(dur)): + p = ph_list[i] + if p[0] != '<' and not p[0].isalpha(): + uv_ = res['f0'][dur_cumsum[i]:dur_cumsum[i + 1]] == 0 + j = 0 + while j < len(uv_) and not uv_[j]: + j += 1 + dur[i - 1] += j + dur[i] -= j + if dur[i] < 100: + dur[i - 1] += dur[i] + dur[i] = 0 + # 声母和韵母等长 + for i in range(len(dur)): + p = ph_list[i] + if p in ALL_SHENMU: + p_next = ph_list[i + 1] + if not (dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU): + print(f"assert dur[i] > 0 and p_next[0].isalpha() and p_next not in ALL_SHENMU, " + f"dur[i]: {dur[i]}, p: {p}, p_next: {p_next}.") + continue + total = dur[i + 1] + dur[i] + dur[i] = total // 2 + dur[i + 1] = total - dur[i] + for i in range(len(dur)): + mel2ph += [i + 1] * dur[i] + mel2ph = np.array(mel2ph) + if mel2ph.max() - 1 >= len(phone_encoded): + raise BinarizationError(f"| Align does not match: {(mel2ph.max() - 1, len(phone_encoded))}") + res['mel2ph'] = mel2ph + res['dur'] = dur + + +if __name__ == "__main__": + set_hparams() + ZhBinarizer().process() diff --git a/data_gen/tts/data_gen_utils.py b/data_gen/tts/data_gen_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1d416b78c1e7aa6b03951c1db12bd4fd26d0a708 --- /dev/null +++ b/data_gen/tts/data_gen_utils.py @@ -0,0 +1,347 @@ +import warnings + +warnings.filterwarnings("ignore") + +import parselmouth +import os +import torch +from skimage.transform import resize +from utils.text_encoder import TokenTextEncoder +from utils.pitch_utils import f0_to_coarse +import struct +import webrtcvad +from scipy.ndimage.morphology import binary_dilation +import librosa +import numpy as np +from utils import audio +import pyloudnorm as pyln +import re +import json +from collections import OrderedDict + +PUNCS = '!,.?;:' + +int16_max = (2 ** 15) - 1 + + +def trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, vad_max_silence_length=12): + """ + Ensures that segments without voice in the waveform remain no longer than a + threshold determined by the VAD parameters in params.py. + :param wav: the raw waveform as a numpy array of floats + :param vad_max_silence_length: Maximum number of consecutive silent frames a segment can have. + :return: the same waveform with silences trimmed away (length <= original wav length) + """ + + ## Voice Activation Detection + # Window size of the VAD. Must be either 10, 20 or 30 milliseconds. + # This sets the granularity of the VAD. Should not need to be changed. + sampling_rate = 16000 + wav_raw, sr = librosa.core.load(path, sr=sr) + + if norm: + meter = pyln.Meter(sr) # create BS.1770 meter + loudness = meter.integrated_loudness(wav_raw) + wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0) + if np.abs(wav_raw).max() > 1.0: + wav_raw = wav_raw / np.abs(wav_raw).max() + + wav = librosa.resample(wav_raw, sr, sampling_rate, res_type='kaiser_best') + + vad_window_length = 30 # In milliseconds + # Number of frames to average together when performing the moving average smoothing. + # The larger this value, the larger the VAD variations must be to not get smoothed out. + vad_moving_average_width = 8 + + # Compute the voice detection window size + samples_per_window = (vad_window_length * sampling_rate) // 1000 + + # Trim the end of the audio to have a multiple of the window size + wav = wav[:len(wav) - (len(wav) % samples_per_window)] + + # Convert the float waveform to 16-bit mono PCM + pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) + + # Perform voice activation detection + voice_flags = [] + vad = webrtcvad.Vad(mode=3) + for window_start in range(0, len(wav), samples_per_window): + window_end = window_start + samples_per_window + voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], + sample_rate=sampling_rate)) + voice_flags = np.array(voice_flags) + + # Smooth the voice detection with a moving average + def moving_average(array, width): + array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) + ret = np.cumsum(array_padded, dtype=float) + ret[width:] = ret[width:] - ret[:-width] + return ret[width - 1:] / width + + audio_mask = moving_average(voice_flags, vad_moving_average_width) + audio_mask = np.round(audio_mask).astype(np.bool) + + # Dilate the voiced regions + audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) + audio_mask = np.repeat(audio_mask, samples_per_window) + audio_mask = resize(audio_mask, (len(wav_raw),)) > 0 + if return_raw_wav: + return wav_raw, audio_mask, sr + return wav_raw[audio_mask], audio_mask, sr + + +def process_utterance(wav_path, + fft_size=1024, + hop_size=256, + win_length=1024, + window="hann", + num_mels=80, + fmin=80, + fmax=7600, + eps=1e-6, + sample_rate=22050, + loud_norm=False, + min_level_db=-100, + return_linear=False, + trim_long_sil=False, vocoder='pwg'): + if isinstance(wav_path, str): + if trim_long_sil: + wav, _, _ = trim_long_silences(wav_path, sample_rate) + else: + wav, _ = librosa.core.load(wav_path, sr=sample_rate) + else: + wav = wav_path + + if loud_norm: + meter = pyln.Meter(sample_rate) # create BS.1770 meter + loudness = meter.integrated_loudness(wav) + wav = pyln.normalize.loudness(wav, loudness, -22.0) + if np.abs(wav).max() > 1: + wav = wav / np.abs(wav).max() + + # get amplitude spectrogram + x_stft = librosa.stft(wav, n_fft=fft_size, hop_length=hop_size, + win_length=win_length, window=window, pad_mode="constant") + spc = np.abs(x_stft) # (n_bins, T) + + # get mel basis + fmin = 0 if fmin == -1 else fmin + fmax = sample_rate / 2 if fmax == -1 else fmax + mel_basis = librosa.filters.mel(sample_rate, fft_size, num_mels, fmin, fmax) + mel = mel_basis @ spc + + if vocoder == 'pwg': + mel = np.log10(np.maximum(eps, mel)) # (n_mel_bins, T) + else: + assert False, f'"{vocoder}" is not in ["pwg"].' + + l_pad, r_pad = audio.librosa_pad_lr(wav, fft_size, hop_size, 1) + wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0) + wav = wav[:mel.shape[1] * hop_size] + + if not return_linear: + return wav, mel + else: + spc = audio.amp_to_db(spc) + spc = audio.normalize(spc, {'min_level_db': min_level_db}) + return wav, mel, spc + + +def get_pitch(wav_data, mel, hparams): + """ + + :param wav_data: [T] + :param mel: [T, 80] + :param hparams: + :return: + """ + time_step = hparams['hop_size'] / hparams['audio_sample_rate'] * 1000 + f0_min = 80 + f0_max = 750 + + if hparams['hop_size'] == 128: + pad_size = 4 + elif hparams['hop_size'] == 256: + pad_size = 2 + else: + assert False + + f0 = parselmouth.Sound(wav_data, hparams['audio_sample_rate']).to_pitch_ac( + time_step=time_step / 1000, voicing_threshold=0.6, + pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency'] + lpad = pad_size * 2 + rpad = len(mel) - len(f0) - lpad + f0 = np.pad(f0, [[lpad, rpad]], mode='constant') + # mel and f0 are extracted by 2 different libraries. we should force them to have the same length. + # Attention: we find that new version of some libraries could cause ``rpad'' to be a negetive value... + # Just to be sure, we recommend users to set up the same environments as them in requirements_auto.txt (by Anaconda) + delta_l = len(mel) - len(f0) + assert np.abs(delta_l) <= 8 + if delta_l > 0: + f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0) + f0 = f0[:len(mel)] + pitch_coarse = f0_to_coarse(f0) + return f0, pitch_coarse + + +def remove_empty_lines(text): + """remove empty lines""" + assert (len(text) > 0) + assert (isinstance(text, list)) + text = [t.strip() for t in text] + if "" in text: + text.remove("") + return text + + +class TextGrid(object): + def __init__(self, text): + text = remove_empty_lines(text) + self.text = text + self.line_count = 0 + self._get_type() + self._get_time_intval() + self._get_size() + self.tier_list = [] + self._get_item_list() + + def _extract_pattern(self, pattern, inc): + """ + Parameters + ---------- + pattern : regex to extract pattern + inc : increment of line count after extraction + Returns + ------- + group : extracted info + """ + try: + group = re.match(pattern, self.text[self.line_count]).group(1) + self.line_count += inc + except AttributeError: + raise ValueError("File format error at line %d:%s" % (self.line_count, self.text[self.line_count])) + return group + + def _get_type(self): + self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2) + + def _get_time_intval(self): + self.xmin = self._extract_pattern(r"xmin = (.*)", 1) + self.xmax = self._extract_pattern(r"xmax = (.*)", 2) + + def _get_size(self): + self.size = int(self._extract_pattern(r"size = (.*)", 2)) + + def _get_item_list(self): + """Only supports IntervalTier currently""" + for itemIdx in range(1, self.size + 1): + tier = OrderedDict() + item_list = [] + tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1) + tier_class = self._extract_pattern(r"class = \"(.*)\"", 1) + if tier_class != "IntervalTier": + raise NotImplementedError("Only IntervalTier class is supported currently") + tier_name = self._extract_pattern(r"name = \"(.*)\"", 1) + tier_xmin = self._extract_pattern(r"xmin = (.*)", 1) + tier_xmax = self._extract_pattern(r"xmax = (.*)", 1) + tier_size = self._extract_pattern(r"intervals: size = (.*)", 1) + for i in range(int(tier_size)): + item = OrderedDict() + item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1) + item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1) + item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1) + item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1) + item_list.append(item) + tier["idx"] = tier_idx + tier["class"] = tier_class + tier["name"] = tier_name + tier["xmin"] = tier_xmin + tier["xmax"] = tier_xmax + tier["size"] = tier_size + tier["items"] = item_list + self.tier_list.append(tier) + + def toJson(self): + _json = OrderedDict() + _json["file_type"] = self.file_type + _json["xmin"] = self.xmin + _json["xmax"] = self.xmax + _json["size"] = self.size + _json["tiers"] = self.tier_list + return json.dumps(_json, ensure_ascii=False, indent=2) + + +def get_mel2ph(tg_fn, ph, mel, hparams): + ph_list = ph.split(" ") + with open(tg_fn, "r") as f: + tg = f.readlines() + tg = remove_empty_lines(tg) + tg = TextGrid(tg) + tg = json.loads(tg.toJson()) + split = np.ones(len(ph_list) + 1, np.float) * -1 + tg_idx = 0 + ph_idx = 0 + tg_align = [x for x in tg['tiers'][-1]['items']] + tg_align_ = [] + for x in tg_align: + x['xmin'] = float(x['xmin']) + x['xmax'] = float(x['xmax']) + if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC']: + x['text'] = '' + if len(tg_align_) > 0 and tg_align_[-1]['text'] == '': + tg_align_[-1]['xmax'] = x['xmax'] + continue + tg_align_.append(x) + tg_align = tg_align_ + tg_len = len([x for x in tg_align if x['text'] != '']) + ph_len = len([x for x in ph_list if not is_sil_phoneme(x)]) + assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, tg_fn) + while tg_idx < len(tg_align) or ph_idx < len(ph_list): + if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]): + split[ph_idx] = 1e8 + ph_idx += 1 + continue + x = tg_align[tg_idx] + if x['text'] == '' and ph_idx == len(ph_list): + tg_idx += 1 + continue + assert ph_idx < len(ph_list), (tg_len, ph_len, tg_align, ph_list, tg_fn) + ph = ph_list[ph_idx] + if x['text'] == '' and not is_sil_phoneme(ph): + assert False, (ph_list, tg_align) + if x['text'] != '' and is_sil_phoneme(ph): + ph_idx += 1 + else: + assert (x['text'] == '' and is_sil_phoneme(ph)) \ + or x['text'].lower() == ph.lower() \ + or x['text'].lower() == 'sil', (x['text'], ph) + split[ph_idx] = x['xmin'] + if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(ph_list[ph_idx - 1]): + split[ph_idx - 1] = split[ph_idx] + ph_idx += 1 + tg_idx += 1 + assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align]) + assert ph_idx >= len(ph_list) - 1, (ph_idx, ph_list, len(ph_list), [x['text'] for x in tg_align], tg_fn) + mel2ph = np.zeros([mel.shape[0]], np.int) + split[0] = 0 + split[-1] = 1e8 + for i in range(len(split) - 1): + assert split[i] != -1 and split[i] <= split[i + 1], (split[:-1],) + split = [int(s * hparams['audio_sample_rate'] / hparams['hop_size'] + 0.5) for s in split] + for ph_idx in range(len(ph_list)): + mel2ph[split[ph_idx]:split[ph_idx + 1]] = ph_idx + 1 + mel2ph_torch = torch.from_numpy(mel2ph) + T_t = len(ph_list) + dur = mel2ph_torch.new_zeros([T_t + 1]).scatter_add(0, mel2ph_torch, torch.ones_like(mel2ph_torch)) + dur = dur[1:].numpy() + return mel2ph, dur + + +def build_phone_encoder(data_dir): + phone_list_file = os.path.join(data_dir, 'phone_set.json') + phone_list = json.load(open(phone_list_file)) + return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',') + + +def is_sil_phoneme(p): + return not p[0].isalpha() diff --git a/data_gen/tts/txt_processors/base_text_processor.py b/data_gen/tts/txt_processors/base_text_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..84a9772fee5bcd910fca4ca7396945b03ed9982e --- /dev/null +++ b/data_gen/tts/txt_processors/base_text_processor.py @@ -0,0 +1,8 @@ +class BaseTxtProcessor: + @staticmethod + def sp_phonemes(): + return ['|'] + + @classmethod + def process(cls, txt, pre_align_args): + raise NotImplementedError diff --git a/data_gen/tts/txt_processors/en.py b/data_gen/tts/txt_processors/en.py new file mode 100644 index 0000000000000000000000000000000000000000..f9d4eedff5c1b057d81fa8a50c031b6656fc3708 --- /dev/null +++ b/data_gen/tts/txt_processors/en.py @@ -0,0 +1,78 @@ +import re +from data_gen.tts.data_gen_utils import PUNCS +from g2p_en import G2p +import unicodedata +from g2p_en.expand import normalize_numbers +from nltk import pos_tag +from nltk.tokenize import TweetTokenizer + +from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor + + +class EnG2p(G2p): + word_tokenize = TweetTokenizer().tokenize + + def __call__(self, text): + # preprocessing + words = EnG2p.word_tokenize(text) + tokens = pos_tag(words) # tuples of (word, tag) + + # steps + prons = [] + for word, pos in tokens: + if re.search("[a-z]", word) is None: + pron = [word] + + elif word in self.homograph2features: # Check homograph + pron1, pron2, pos1 = self.homograph2features[word] + if pos.startswith(pos1): + pron = pron1 + else: + pron = pron2 + elif word in self.cmu: # lookup CMU dict + pron = self.cmu[word][0] + else: # predict for oov + pron = self.predict(word) + + prons.extend(pron) + prons.extend([" "]) + + return prons[:-1] + + +class TxtProcessor(BaseTxtProcessor): + g2p = EnG2p() + + @staticmethod + def preprocess_text(text): + text = normalize_numbers(text) + text = ''.join(char for char in unicodedata.normalize('NFD', text) + if unicodedata.category(char) != 'Mn') # Strip accents + text = text.lower() + text = re.sub("[\'\"()]+", "", text) + text = re.sub("[-]+", " ", text) + text = re.sub(f"[^ a-z{PUNCS}]", "", text) + text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text) # !! -> ! + text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! + text = text.replace("i.e.", "that is") + text = text.replace("i.e.", "that is") + text = text.replace("etc.", "etc") + text = re.sub(f"([{PUNCS}])", r" \1 ", text) + text = re.sub(rf"\s+", r" ", text) + return text + + @classmethod + def process(cls, txt, pre_align_args): + txt = cls.preprocess_text(txt).strip() + phs = cls.g2p(txt) + phs_ = [] + n_word_sep = 0 + for p in phs: + if p.strip() == '': + phs_ += ['|'] + n_word_sep += 1 + else: + phs_ += p.split(" ") + phs = phs_ + assert n_word_sep + 1 == len(txt.split(" ")), (phs, f"\"{txt}\"") + return phs, txt diff --git a/data_gen/tts/txt_processors/zh.py b/data_gen/tts/txt_processors/zh.py new file mode 100644 index 0000000000000000000000000000000000000000..3fb73ebb96792e6842a176df1c848ba8639df839 --- /dev/null +++ b/data_gen/tts/txt_processors/zh.py @@ -0,0 +1,41 @@ +import re +from pypinyin import pinyin, Style +from data_gen.tts.data_gen_utils import PUNCS +from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor +from utils.text_norm import NSWNormalizer + + +class TxtProcessor(BaseTxtProcessor): + table = {ord(f): ord(t) for f, t in zip( + u':,。!?【】()%#@&1234567890', + u':,.!?[]()%#@&1234567890')} + + @staticmethod + def preprocess_text(text): + text = text.translate(TxtProcessor.table) + text = NSWNormalizer(text).normalize(remove_punc=False) + text = re.sub("[\'\"()]+", "", text) + text = re.sub("[-]+", " ", text) + text = re.sub(f"[^ A-Za-z\u4e00-\u9fff{PUNCS}]", "", text) + text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! + text = re.sub(f"([{PUNCS}])", r" \1 ", text) + text = re.sub(rf"\s+", r"", text) + return text + + @classmethod + def process(cls, txt, pre_align_args): + txt = cls.preprocess_text(txt) + shengmu = pinyin(txt, style=Style.INITIALS) # https://blog.csdn.net/zhoulei124/article/details/89055403 + yunmu_finals = pinyin(txt, style=Style.FINALS) + yunmu_tone3 = pinyin(txt, style=Style.FINALS_TONE3) + yunmu = [[t[0] + '5'] if t[0] == f[0] else t for f, t in zip(yunmu_finals, yunmu_tone3)] \ + if pre_align_args['use_tone'] else yunmu_finals + + assert len(shengmu) == len(yunmu) + phs = ["|"] + for a, b, c in zip(shengmu, yunmu, yunmu_finals): + if a[0] == c[0]: + phs += [a[0], "|"] + else: + phs += [a[0], b[0], "|"] + return phs, txt diff --git a/data_gen/tts/txt_processors/zh_g2pM.py b/data_gen/tts/txt_processors/zh_g2pM.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd97f6db9fa0e5da85f6c92da48b4df795e6828 --- /dev/null +++ b/data_gen/tts/txt_processors/zh_g2pM.py @@ -0,0 +1,71 @@ +import re +import jieba +from pypinyin import pinyin, Style +from data_gen.tts.data_gen_utils import PUNCS +from data_gen.tts.txt_processors import zh +from g2pM import G2pM + +ALL_SHENMU = ['b', 'c', 'ch', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'q', 'r', 's', 'sh', 't', 'x', 'z', 'zh'] +ALL_YUNMU = ['a', 'ai', 'an', 'ang', 'ao', 'e', 'ei', 'en', 'eng', 'er', 'i', 'ia', 'ian', 'iang', 'iao', + 'ie', 'in', 'ing', 'iong', 'iou', 'o', 'ong', 'ou', 'u', 'ua', 'uai', 'uan', 'uang', 'uei', + 'uen', 'uo', 'v', 'van', 've', 'vn'] + + +class TxtProcessor(zh.TxtProcessor): + model = G2pM() + + @staticmethod + def sp_phonemes(): + return ['|', '#'] + + @classmethod + def process(cls, txt, pre_align_args): + txt = cls.preprocess_text(txt) + ph_list = cls.model(txt, tone=pre_align_args['use_tone'], char_split=True) + seg_list = '#'.join(jieba.cut(txt)) + assert len(ph_list) == len([s for s in seg_list if s != '#']), (ph_list, seg_list) + + # 加入词边界'#' + ph_list_ = [] + seg_idx = 0 + for p in ph_list: + p = p.replace("u:", "v") + if seg_list[seg_idx] == '#': + ph_list_.append('#') + seg_idx += 1 + else: + ph_list_.append("|") + seg_idx += 1 + if re.findall('[\u4e00-\u9fff]', p): + if pre_align_args['use_tone']: + p = pinyin(p, style=Style.TONE3, strict=True)[0][0] + if p[-1] not in ['1', '2', '3', '4', '5']: + p = p + '5' + else: + p = pinyin(p, style=Style.NORMAL, strict=True)[0][0] + + finished = False + if len([c.isalpha() for c in p]) > 1: + for shenmu in ALL_SHENMU: + if p.startswith(shenmu) and not p.lstrip(shenmu).isnumeric(): + ph_list_ += [shenmu, p.lstrip(shenmu)] + finished = True + break + if not finished: + ph_list_.append(p) + + ph_list = ph_list_ + + # 去除静音符号周围的词边界标记 [..., '#', ',', '#', ...] + sil_phonemes = list(PUNCS) + TxtProcessor.sp_phonemes() + ph_list_ = [] + for i in range(0, len(ph_list), 1): + if ph_list[i] != '#' or (ph_list[i - 1] not in sil_phonemes and ph_list[i + 1] not in sil_phonemes): + ph_list_.append(ph_list[i]) + ph_list = ph_list_ + return ph_list, txt + + +if __name__ == '__main__': + phs, txt = TxtProcessor.process('他来到了,网易杭研大厦', {'use_tone': True}) + print(phs) diff --git a/inference/m4singer/base_svs_infer.py b/inference/m4singer/base_svs_infer.py new file mode 100644 index 0000000000000000000000000000000000000000..380fc594895ab8904f33cc091d396d2723216619 --- /dev/null +++ b/inference/m4singer/base_svs_infer.py @@ -0,0 +1,242 @@ +import os + +import torch +import numpy as np +from modules.hifigan.hifigan import HifiGanGenerator +from vocoders.hifigan import HifiGAN +from inference.m4singer.m4singer.map import m4singer_pinyin2ph_func + +from utils import load_ckpt +from utils.hparams import set_hparams, hparams +from utils.text_encoder import TokenTextEncoder +from pypinyin import pinyin, lazy_pinyin, Style +import librosa +import glob +import re + + +class BaseSVSInfer: + def __init__(self, hparams, device=None): + if device is None: + device = 'cuda' if torch.cuda.is_available() else 'cpu' + self.hparams = hparams + self.device = device + + phone_list = ["", "", "a", "ai", "an", "ang", "ao", "b", "c", "ch", "d", "e", "ei", "en", "eng", "er", "f", "g", "h", + "i", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "iou", "j", "k", "l", "m", "n", "o", "ong", "ou", + "p", "q", "r", "s", "sh", "t", "u", "ua", "uai", "uan", "uang", "uei", "uen", "uo", "v", "van", "ve", "vn", + "x", "z", "zh"] + self.ph_encoder = TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',') + self.pinyin2phs = m4singer_pinyin2ph_func() + self.spk_map = {"Alto-1": 0, "Alto-2": 1, "Alto-3": 2, "Alto-4": 3, "Alto-5": 4, "Alto-6": 5, "Alto-7": 6, "Bass-1": 7, + "Bass-2": 8, "Bass-3": 9, "Soprano-1": 10, "Soprano-2": 11, "Soprano-3": 12, "Tenor-1": 13, "Tenor-2": 14, + "Tenor-3": 15, "Tenor-4": 16, "Tenor-5": 17, "Tenor-6": 18, "Tenor-7": 19} + + self.model = self.build_model() + self.model.eval() + self.model.to(self.device) + self.vocoder = self.build_vocoder() + self.vocoder.eval() + self.vocoder.to(self.device) + + def build_model(self): + raise NotImplementedError + + def forward_model(self, inp): + raise NotImplementedError + + def build_vocoder(self): + base_dir = hparams['vocoder_ckpt'] + config_path = f'{base_dir}/config.yaml' + ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key= + lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1] + print('| load HifiGAN: ', ckpt) + ckpt_dict = torch.load(ckpt, map_location="cpu") + config = set_hparams(config_path, global_hparams=False) + state = ckpt_dict["state_dict"]["model_gen"] + vocoder = HifiGanGenerator(config) + vocoder.load_state_dict(state, strict=True) + vocoder.remove_weight_norm() + vocoder = vocoder.eval().to(self.device) + return vocoder + + def run_vocoder(self, c, **kwargs): + c = c.transpose(2, 1) # [B, 80, T] + f0 = kwargs.get('f0') # [B, T] + if f0 is not None and hparams.get('use_nsf'): + # f0 = torch.FloatTensor(f0).to(self.device) + y = self.vocoder(c, f0).view(-1) + else: + y = self.vocoder(c).view(-1) + # [T] + return y[None] + + def preprocess_word_level_input(self, inp): + # Pypinyin can't solve polyphonic words + text_raw = inp['text'] + + # lyric + pinyins = lazy_pinyin(text_raw, strict=False) + ph_per_word_lst = [self.pinyin2phs[pinyin.strip()] for pinyin in pinyins if pinyin.strip() in self.pinyin2phs] + + # Note + note_per_word_lst = [x.strip() for x in inp['notes'].split('|') if x.strip() != ''] + mididur_per_word_lst = [x.strip() for x in inp['notes_duration'].split('|') if x.strip() != ''] + + if len(note_per_word_lst) == len(ph_per_word_lst) == len(mididur_per_word_lst): + print('Pass word-notes check.') + else: + print('The number of words does\'t match the number of notes\' windows. ', + 'You should split the note(s) for each word by | mark.') + print(ph_per_word_lst, note_per_word_lst, mididur_per_word_lst) + print(len(ph_per_word_lst), len(note_per_word_lst), len(mididur_per_word_lst)) + return None + + note_lst = [] + ph_lst = [] + midi_dur_lst = [] + is_slur = [] + for idx, ph_per_word in enumerate(ph_per_word_lst): + # for phs in one word: + # single ph like ['ai'] or multiple phs like ['n', 'i'] + ph_in_this_word = ph_per_word.split() + + # for notes in one word: + # single note like ['D4'] or multiple notes like ['D4', 'E4'] which means a 'slur' here. + note_in_this_word = note_per_word_lst[idx].split() + midi_dur_in_this_word = mididur_per_word_lst[idx].split() + # process for the model input + # Step 1. + # Deal with note of 'not slur' case or the first note of 'slur' case + # j ie + # F#4/Gb4 F#4/Gb4 + # 0 0 + for ph in ph_in_this_word: + ph_lst.append(ph) + note_lst.append(note_in_this_word[0]) + midi_dur_lst.append(midi_dur_in_this_word[0]) + is_slur.append(0) + # step 2. + # Deal with the 2nd, 3rd... notes of 'slur' case + # j ie ie + # F#4/Gb4 F#4/Gb4 C#4/Db4 + # 0 0 1 + if len(note_in_this_word) > 1: # is_slur = True, we should repeat the YUNMU to match the 2nd, 3rd... notes. + for idx in range(1, len(note_in_this_word)): + ph_lst.append(ph_in_this_word[-1]) + note_lst.append(note_in_this_word[idx]) + midi_dur_lst.append(midi_dur_in_this_word[idx]) + is_slur.append(1) + ph_seq = ' '.join(ph_lst) + + if len(ph_lst) == len(note_lst) == len(midi_dur_lst): + print(len(ph_lst), len(note_lst), len(midi_dur_lst)) + print('Pass word-notes check.') + else: + print('The number of words does\'t match the number of notes\' windows. ', + 'You should split the note(s) for each word by | mark.') + return None + return ph_seq, note_lst, midi_dur_lst, is_slur + + def preprocess_phoneme_level_input(self, inp): + ph_seq = inp['ph_seq'] + note_lst = inp['note_seq'].split() + midi_dur_lst = inp['note_dur_seq'].split() + is_slur = [float(x) for x in inp['is_slur_seq'].split()] + print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst)) + if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst): + print('Pass word-notes check.') + else: + print('The number of words does\'t match the number of notes\' windows. ', + 'You should split the note(s) for each word by | mark.') + return None + return ph_seq, note_lst, midi_dur_lst, is_slur + + def preprocess_input(self, inp, input_type='word'): + """ + + :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)} + :return: + """ + + item_name = inp.get('item_name', '') + spk_name = inp.get('spk_name', 'Alto-1') + + # single spk + spk_id = self.spk_map[spk_name] + + # get ph seq, note lst, midi dur lst, is slur lst. + if input_type == 'word': + ret = self.preprocess_word_level_input(inp) + elif input_type == 'phoneme': + ret = self.preprocess_phoneme_level_input(inp) + else: + print('Invalid input type.') + return None + + if ret: + ph_seq, note_lst, midi_dur_lst, is_slur = ret + else: + print('==========> Preprocess_word_level or phone_level input wrong.') + return None + + # convert note lst to midi id; convert note dur lst to midi duration + try: + midis = [librosa.note_to_midi(x.split("/")[0]) if x != 'rest' else 0 + for x in note_lst] + midi_dur_lst = [float(x) for x in midi_dur_lst] + except Exception as e: + print(e) + print('Invalid Input Type.') + return None + + ph_token = self.ph_encoder.encode(ph_seq) + item = {'item_name': item_name, 'text': inp['text'], 'ph': ph_seq, 'spk_id': spk_id, + 'ph_token': ph_token, 'pitch_midi': np.asarray(midis), 'midi_dur': np.asarray(midi_dur_lst), + 'is_slur': np.asarray(is_slur), } + item['ph_len'] = len(item['ph_token']) + return item + + def input_to_batch(self, item): + item_names = [item['item_name']] + text = [item['text']] + ph = [item['ph']] + txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device) + txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device) + spk_ids = torch.LongTensor([item['spk_id']])[:].to(self.device) + + pitch_midi = torch.LongTensor(item['pitch_midi'])[None, :hparams['max_frames']].to(self.device) + midi_dur = torch.FloatTensor(item['midi_dur'])[None, :hparams['max_frames']].to(self.device) + is_slur = torch.LongTensor(item['is_slur'])[None, :hparams['max_frames']].to(self.device) + + batch = { + 'item_name': item_names, + 'text': text, + 'ph': ph, + 'txt_tokens': txt_tokens, + 'txt_lengths': txt_lengths, + 'spk_ids': spk_ids, + 'pitch_midi': pitch_midi, + 'midi_dur': midi_dur, + 'is_slur': is_slur + } + return batch + + def postprocess_output(self, output): + return output + + def infer_once(self, inp): + inp = self.preprocess_input(inp, input_type=inp['input_type'] if inp.get('input_type') else 'word') + output = self.forward_model(inp) + output = self.postprocess_output(output) + return output + + @classmethod + def example_run(cls, inp): + from utils.audio import save_wav + set_hparams(print_hparams=False) + infer_ins = cls(hparams) + out = infer_ins.infer_once(inp) + os.makedirs('infer_out', exist_ok=True) + f_name = inp['spk_name'] + ' | ' + inp['text'] + save_wav(out, f'infer_out/{f_name}.wav', hparams['audio_sample_rate']) \ No newline at end of file diff --git a/inference/m4singer/ds_e2e.py b/inference/m4singer/ds_e2e.py new file mode 100644 index 0000000000000000000000000000000000000000..566b1e00389e1907a68d5a6ae4f156fb84d50b16 --- /dev/null +++ b/inference/m4singer/ds_e2e.py @@ -0,0 +1,67 @@ +import torch +# from inference.tts.fs import FastSpeechInfer +# from modules.tts.fs2_orig import FastSpeech2Orig +from inference.m4singer.base_svs_infer import BaseSVSInfer +from utils import load_ckpt +from utils.hparams import hparams +from usr.diff.shallow_diffusion_tts import GaussianDiffusion +from usr.diffsinger_task import DIFF_DECODERS +from modules.fastspeech.pe import PitchExtractor +import utils + + +class DiffSingerE2EInfer(BaseSVSInfer): + def build_model(self): + model = GaussianDiffusion( + phone_encoder=self.ph_encoder, + out_dims=hparams['audio_num_mel_bins'], denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams), + timesteps=hparams['timesteps'], + K_step=hparams['K_step'], + loss_type=hparams['diff_loss_type'], + spec_min=hparams['spec_min'], spec_max=hparams['spec_max'], + ) + model.eval() + load_ckpt(model, hparams['work_dir'], 'model') + + if hparams.get('pe_enable') is not None and hparams['pe_enable']: + self.pe = PitchExtractor().to(self.device) + utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True) + self.pe.eval() + return model + + def forward_model(self, inp): + sample = self.input_to_batch(inp) + txt_tokens = sample['txt_tokens'] # [B, T_t] + spk_id = sample.get('spk_ids') + with torch.no_grad(): + output = self.model(txt_tokens, spk_embed=spk_id, ref_mels=None, infer=True, + pitch_midi=sample['pitch_midi'], midi_dur=sample['midi_dur'], + is_slur=sample['is_slur']) + mel_out = output['mel_out'] # [B, T,80] + if hparams.get('pe_enable') is not None and hparams['pe_enable']: + f0_pred = self.pe(mel_out)['f0_denorm_pred'] # pe predict from Pred mel + else: + f0_pred = output['f0_denorm'] + wav_out = self.run_vocoder(mel_out, f0=f0_pred) + wav_out = wav_out.cpu().numpy() + return wav_out[0] + +if __name__ == '__main__': + inp = { + 'spk_name': 'Tenor-1', + 'text': 'AP你要相信AP相信我们会像童话故事里AP', + 'notes': 'rest | G#3 | A#3 C4 | D#4 | D#4 F4 | rest | E4 F4 | F4 | D#4 A#3 | A#3 | A#3 | C#4 | B3 C4 | C#4 | B3 C4 | A#3 | G#3 | rest', + 'notes_duration': '0.14 | 0.47 | 0.1905 0.1895 | 0.41 | 0.3005 0.3895 | 0.21 | 0.2391 0.1809 | 0.32 | 0.4105 0.2095 | 0.35 | 0.43 | 0.45 | 0.2309 0.2291 | 0.48 | 0.225 0.195 | 0.29 | 0.71 | 0.14', + 'input_type': 'word', + } + + c = { + 'spk_name': 'Tenor-1', + 'text': '你要相信相信我们会像童话故事里', + 'ph_seq': ' n i iao iao x iang x in in x iang iang x in uo uo m en h uei x iang t ong ong h ua g u u sh i l i ', + 'note_seq': 'rest G#3 G#3 A#3 C4 D#4 D#4 D#4 D#4 F4 rest E4 E4 F4 F4 F4 D#4 A#3 A#3 A#3 A#3 A#3 C#4 C#4 B3 B3 C4 C#4 C#4 B3 B3 C4 A#3 A#3 G#3 G#3 rest', + 'note_dur_seq': '0.14 0.47 0.47 0.1905 0.1895 0.41 0.41 0.3005 0.3005 0.3895 0.21 0.2391 0.2391 0.1809 0.32 0.32 0.4105 0.2095 0.35 0.35 0.43 0.43 0.45 0.45 0.2309 0.2309 0.2291 0.48 0.48 0.225 0.225 0.195 0.29 0.29 0.71 0.71 0.14', + 'is_slur_seq': '0 0 0 0 1 0 0 0 0 1 0 0 0 1 0 0 0 1 0 0 0 0 0 0 0 0 1 0 0 0 0 1 0 0 0 0 0', + 'input_type': 'phoneme' + } + DiffSingerE2EInfer.example_run(inp) diff --git a/inference/m4singer/gradio/gradio_settings.yaml b/inference/m4singer/gradio/gradio_settings.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5266b23e2c7ad57d221d67174ba7f5efa47ff7c3 --- /dev/null +++ b/inference/m4singer/gradio/gradio_settings.yaml @@ -0,0 +1,48 @@ +title: 'M4Singer' +description: | + This page aims to display the singing voice synthesis function of M4Singer. SingerID can be switched freely to preview the timbre of each singer. Click examples below to quickly load scores and audio. + (本页面为M4Singer歌声合成功能展示。SingerID可以自由切换用以预览各歌手的音色。点击下方Examples可以快速加载乐谱和音频。) + + Please assign pitch and duration values to each Chinese character. The corresponding pitch and duration value of each character should be separated by a | separator. It is necessary to ensure that the note window separated by the separator is consistent with the number of Chinese characters. AP (aspirate) or SP (silence) is also viewed as a Chinese character. + (请给每个汉字分配音高和时值, 每个字对应的音高和时值需要用 | 分隔符隔开。需要保证分隔符分割出来的音符窗口与汉字个数一致。换气或静音符也算一个汉字。) + + The notes corresponding to AP and SP are fixed as rest. If there are multiple notes in a window (| .... |), it means that the Chinese character corresponding to the window is glissando, and each note needs to be assigned a duration. + (AP和SP对应的音符固定为rest。若一个窗口(| .... |)内有多个音符, 代表该窗口对应的汉字为滑音, 需要为每个音符都分配时长。) + +article: | + Note: This page is running on CPU, please refer to Github REPO for the local running solutions and for our dataset. + + -------- + If our work is useful for your research, please consider citing: + ```bibtex + @inproceedings{ + zhang2022msinger, + title={M4Singer: A Multi-Style, Multi-Singer and Musical Score Provided Mandarin Singing Corpus}, + author={Lichao Zhang and Ruiqi Li and Shoutong Wang and Liqun Deng and Jinglin Liu and Yi Ren and Jinzheng He and Rongjie Huang and Jieming Zhu and Xiao Chen and Zhou Zhao}, + booktitle={Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track}, + year={2022}, + } + ``` + + ![visitors](https://visitor-badge.laobi.icu/badge?page_id=zlc99/M4Singer) +example_inputs: + - |- + Tenor-1AP你要相信AP相信我们会像童话故事里APrest | G#3 | A#3 C4 | D#4 | D#4 F4 | rest | E4 F4 | F4 | D#4 A#3 | A#3 | A#3 | C#4 | B3 C4 | C#4 | B3 C4 | A#3 | G#3 | rest0.14 | 0.47 | 0.1905 0.1895 | 0.41 | 0.3005 0.3895 | 0.21 | 0.2391 0.1809 | 0.32 | 0.4105 0.2095 | 0.35 | 0.43 | 0.45 | 0.2309 0.2291 | 0.48 | 0.225 0.195 | 0.29 | 0.71 | 0.14 + - |- + Tenor-1AP因为在一千年以后AP世界早已没有我APrest | C#4 | D4 | E4 | F#4 | E4 | D4 G#3 | A3 | D4 E4 | rest | F#4 | E4 | D4 | C#4 | B3 F#3 | F#3 | C4 C#4 | rest0.18 | 0.32 | 0.38 | 0.81 | 0.38 | 0.39 | 0.3155 0.2045 | 0.28 | 0.4609 1.0291 | 0.27 | 0.42 | 0.15 | 0.53 | 0.22 | 0.3059 0.2841 | 0.4 | 0.2909 1.1091 | 0.3 + - |- + Tenor-2AP可是你在敲打AP我的窗棂APrest | G#3 | B3 | B3 C#4 | E4 | C#4 B3 | G#3 | rest | C3 | E3 | B3 G#3 | F#3 | rest0.2 | 0.38 | 0.48 | 0.41 0.72 | 0.39 | 0.5195 0.2905 | 0.5 | 0.33 | 0.4 | 0.31 | 0.565 0.265 | 1.15 | 0.24 + - |- + Tenor-2SP一杯敬朝阳一杯敬月光APrest | G#3 | G#3 | G#3 | G3 | G3 G#3 | G3 | C4 | C4 | A#3 | C4 | rest0.33 | 0.26 | 0.23 | 0.27 | 0.36 | 0.3159 0.4041 | 0.54 | 0.21 | 0.32 | 0.24 | 0.58 | 0.17 + - |- + Soprano-1SP乱石穿空AP惊涛拍岸APrest | C#5 | D#5 | F5 D#5 | C#5 | rest | C#5 | C#5 | C#5 G#4 | G#4 | rest0.325 | 0.75 | 0.54 | 0.48 0.55 | 1.38 | 0.31 | 0.55 | 0.48 | 0.4891 0.4709 | 1.15 | 0.22 + - |- + Soprano-1AP点点滴滴染绿了村寨APrest | C5 | A#4 | C5 | D#5 F5 D#5 | D#5 | C5 | C5 | C5 | A#4 | rest0.175 | 0.24 | 0.26 | 1.08 | 0.3541 0.4364 0.2195 | 0.47 | 0.27 | 0.12 | 0.51 | 0.72 | 0.15 + - |- + Alto-2AP拒绝声色的张扬AP不拒绝你APrest | C4 | C4 | C4 | B3 A3 | C4 | C4 D4 | D4 | rest | D4 | D4 | C4 | G4 E4 | rest0.49 | 0.31 | 0.18 | 0.48 | 0.3 0.4 | 0.25 | 0.3591 0.2409 | 0.46 | 0.34 | 0.4 | 0.45 | 0.45 | 2.4545 0.9855 | 0.215 + - |- + Alto-2AP半醒着AP笑着哭着都快活APrest | D4 | B3 | C4 D4 | rest | E4 | D4 | E4 | D4 | E4 | E4 F#4 | F4 F#4 | rest0.165 | 0.45 | 0.53 | 0.3859 0.2441 | 0.35 | 0.38 | 0.17 | 0.32 | 0.26 | 0.33 | 0.38 0.21 | 0.3309 0.9491 | 0.125 + + +inference_cls: inference.m4singer.ds_e2e.DiffSingerE2EInfer +exp_name: m4singer_diff_e2e \ No newline at end of file diff --git a/inference/m4singer/gradio/infer.py b/inference/m4singer/gradio/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..d18159303f6327bb3efe265c7799e277e8126acb --- /dev/null +++ b/inference/m4singer/gradio/infer.py @@ -0,0 +1,143 @@ +import importlib +import re + +import gradio as gr +import yaml +from gradio.components import Textbox, Dropdown + +from inference.m4singer.base_svs_infer import BaseSVSInfer +from utils.hparams import set_hparams +from utils.hparams import hparams as hp +import numpy as np +from inference.m4singer.gradio.share_btn import community_icon_html, loading_icon_html, share_js + +class GradioInfer: + def __init__(self, exp_name, inference_cls, title, description, article, example_inputs): + self.exp_name = exp_name + self.title = title + self.description = description + self.article = article + self.example_inputs = example_inputs + pkg = ".".join(inference_cls.split(".")[:-1]) + cls_name = inference_cls.split(".")[-1] + self.inference_cls = getattr(importlib.import_module(pkg), cls_name) + + def greet(self, singer, text, notes, notes_duration): + PUNCS = '。?;:' + sents = re.split(rf'([{PUNCS}])', text.replace('\n', ',')) + sents_notes = re.split(rf'([{PUNCS}])', notes.replace('\n', ',')) + sents_notes_dur = re.split(rf'([{PUNCS}])', notes_duration.replace('\n', ',')) + + if sents[-1] not in list(PUNCS): + sents = sents + [''] + sents_notes = sents_notes + [''] + sents_notes_dur = sents_notes_dur + [''] + + audio_outs = [] + s, n, n_dur = "", "", "" + for i in range(0, len(sents), 2): + if len(sents[i]) > 0: + s += sents[i] + sents[i + 1] + n += sents_notes[i] + sents_notes[i+1] + n_dur += sents_notes_dur[i] + sents_notes_dur[i+1] + if len(s) >= 400 or (i >= len(sents) - 2 and len(s) > 0): + audio_out = self.infer_ins.infer_once({ + 'spk_name': singer, + 'text': s, + 'notes': n, + 'notes_duration': n_dur, + }) + audio_out = audio_out * 32767 + audio_out = audio_out.astype(np.int16) + audio_outs.append(audio_out) + audio_outs.append(np.zeros(int(hp['audio_sample_rate'] * 0.3)).astype(np.int16)) + s = "" + n = "" + audio_outs = np.concatenate(audio_outs) + return (hp['audio_sample_rate'], audio_outs), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True) + + def run(self): + set_hparams(config=f'checkpoints/{self.exp_name}/config.yaml', exp_name=self.exp_name, print_hparams=False) + infer_cls = self.inference_cls + self.infer_ins: BaseSVSInfer = infer_cls(hp) + example_inputs = self.example_inputs + for i in range(len(example_inputs)): + singer, text, notes, notes_dur = example_inputs[i].split('') + example_inputs[i] = [singer, text, notes, notes_dur] + + singerList = \ + [ + 'Tenor-1', 'Tenor-2', 'Tenor-3', 'Tenor-4', 'Tenor-5', 'Tenor-6', 'Tenor-7', + 'Alto-1', 'Alto-2', 'Alto-3', 'Alto-4', 'Alto-5', 'Alto-6', 'Alto-7', + 'Soprano-1', 'Soprano-2', 'Soprano-3', + 'Bass-1', 'Bass-2', 'Bass-3', + ] + + css = """ + #share-btn-container { + display: flex; padding-left: 0.5rem !important; padding-right: 0.5rem !important; background-color: #000000; justify-content: center; align-items: center; border-radius: 9999px !important; width: 13rem; + } + #share-btn { + all: initial; color: #ffffff;font-weight: 600; cursor:pointer; font-family: 'IBM Plex Sans', sans-serif; margin-left: 0.5rem !important; padding-top: 0.25rem !important; padding-bottom: 0.25rem !important;right:0; + } + #share-btn * { + all: unset; + } + #share-btn-container div:nth-child(-n+2){ + width: auto !important; + min-height: 0px !important; + } + #share-btn-container .wrap { + display: none !important; + } + """ + with gr.Blocks(css=css) as demo: + gr.HTML("""
+
+

+ M4Singer +

+
+
+ """ + ) + gr.Markdown(self.description) + with gr.Row(): + with gr.Column(): + singer_l = Dropdown(choices=singerList, value=example_inputs[0][0], label="SingerID", elem_id="inp_singer") + inp_text = Textbox(lines=2, placeholder=None, value=example_inputs[0][1], label="input text", elem_id="inp_text") + inp_note = Textbox(lines=2, placeholder=None, value=example_inputs[0][2], label="input note", elem_id="inp_note") + inp_duration = Textbox(lines=2, placeholder=None, value=example_inputs[0][3], label="input duration", elem_id="inp_duration") + generate = gr.Button("Generate Singing Voice from Musical Score") + with gr.Column(lem_id="col-container"): + singing_output = gr.Audio(label="Result", type="numpy", elem_id="music-output") + + with gr.Group(elem_id="share-btn-container"): + community_icon = gr.HTML(community_icon_html, visible=False) + loading_icon = gr.HTML(loading_icon_html, visible=False) + share_button = gr.Button("Share to community", elem_id="share-btn", visible=False) + gr.Examples(examples=self.example_inputs, + inputs=[singer_l, inp_text, inp_note, inp_duration], + outputs=[singing_output, share_button, community_icon, loading_icon], + fn=self.greet, + cache_examples=True) + gr.Markdown(self.article) + generate.click(self.greet, + inputs=[singer_l, inp_text, inp_note, inp_duration], + outputs=[singing_output, share_button, community_icon, loading_icon],) + share_button.click(None, [], [], _js=share_js) + demo.queue().launch(share=False) + + +if __name__ == '__main__': + gradio_config = yaml.safe_load(open('inference/m4singer/gradio/gradio_settings.yaml')) + g = GradioInfer(**gradio_config) + g.run() + diff --git a/inference/m4singer/gradio/share_btn.py b/inference/m4singer/gradio/share_btn.py new file mode 100644 index 0000000000000000000000000000000000000000..9e054130eebdb38ba6d33565d49207238f3aa244 --- /dev/null +++ b/inference/m4singer/gradio/share_btn.py @@ -0,0 +1,86 @@ +community_icon_html = """""" + +loading_icon_html = """""" + +share_js = """async () => { + async function uploadFile(file){ + const UPLOAD_URL = 'https://huggingface.co/uploads'; + const response = await fetch(UPLOAD_URL, { + method: 'POST', + headers: { + 'Content-Type': file.type, + 'X-Requested-With': 'XMLHttpRequest', + }, + body: file, /// <- File inherits from Blob + }); + const url = await response.text(); + return url; + } + + async function getOutputMusicFile(audioEL){ + const res = await fetch(audioEL.src); + const blob = await res.blob(); + const audioId = Date.now() % 200; + const fileName = `SVS-${{audioId}}.wav`; + const musicBlob = new File([blob], fileName, { type: 'audio/wav' }); + return musicBlob; + } + + const gradioEl = document.querySelector('body > gradio-app'); + + //const gradioEl = document.querySelector("gradio-app").shadowRoot; + const inputSinger = gradioEl.querySelector('#inp_singer select').value; + const inputText = gradioEl.querySelector('#inp_text textarea').value; + const inputNote = gradioEl.querySelector('#inp_note textarea').value; + const inputDuration = gradioEl.querySelector('#inp_duration textarea').value; + const outputMusic = gradioEl.querySelector('#music-output audio'); + const outputMusic_src = gradioEl.querySelector('#music-output audio').src; + + const outputMusic_name = outputMusic_src.split('/').pop(); + let titleTxt = outputMusic_name; + if(titleTxt.length > 30){ + titleTxt = 'demo'; + } + const shareBtnEl = gradioEl.querySelector('#share-btn'); + const shareIconEl = gradioEl.querySelector('#share-btn-share-icon'); + const loadingIconEl = gradioEl.querySelector('#share-btn-loading-icon'); + if(!outputMusic){ + return; + }; + shareBtnEl.style.pointerEvents = 'none'; + shareIconEl.style.display = 'none'; + loadingIconEl.style.removeProperty('display'); + const musicFile = await getOutputMusicFile(outputMusic); + const dataOutputMusic = await uploadFile(musicFile); + const descriptionMd = `#### Input Musical Score: +${inputSinger} + +${inputText} + +${inputNote} + +${inputDuration} + +#### Singing Voice: + + +`; + const params = new URLSearchParams({ + title: titleTxt, + description: descriptionMd, + }); + const paramsStr = params.toString(); + window.open(`https://huggingface.co/spaces/zlc99/M4Singer/discussions/new?${paramsStr}`, '_blank'); + shareBtnEl.style.removeProperty('pointer-events'); + shareIconEl.style.removeProperty('display'); + loadingIconEl.style.display = 'none'; +}""" \ No newline at end of file diff --git a/inference/m4singer/m4singer/m4singer_pinyin2ph.txt b/inference/m4singer/m4singer/m4singer_pinyin2ph.txt new file mode 100644 index 0000000000000000000000000000000000000000..1b069a94e29e115f1b44bb516de59bb3a4d84486 --- /dev/null +++ b/inference/m4singer/m4singer/m4singer_pinyin2ph.txt @@ -0,0 +1,413 @@ +| a | a | +| ai | ai | +| an | an | +| ang | ang | +| ao | ao | +| ba | b a | +| bai | b ai | +| ban | b an | +| bang | b ang | +| bao | b ao | +| bei | b ei | +| ben | b en | +| beng | b eng | +| bi | b i | +| bian | b ian | +| biao | b iao | +| bie | b ie | +| bin | b in | +| bing | b ing | +| bo | b o | +| bu | b u | +| ca | c a | +| cai | c ai | +| can | c an | +| cang | c ang | +| cao | c ao | +| ce | c e | +| cei | c ei | +| cen | c en | +| ceng | c eng | +| cha | ch a | +| chai | ch ai | +| chan | ch an | +| chang | ch ang | +| chao | ch ao | +| che | ch e | +| chen | ch en | +| cheng | ch eng | +| chi | ch i | +| chong | ch ong | +| chou | ch ou | +| chu | ch u | +| chua | ch ua | +| chuai | ch uai | +| chuan | ch uan | +| chuang | ch uang | +| chui | ch uei | +| chun | ch uen | +| chuo | ch uo | +| ci | c i | +| cong | c ong | +| cou | c ou | +| cu | c u | +| cuan | c uan | +| cui | c uei | +| cun | c uen | +| cuo | c uo | +| da | d a | +| dai | d ai | +| dan | d an | +| dang | d ang | +| dao | d ao | +| de | d e | +| dei | d ei | +| den | d en | +| deng | d eng | +| di | d i | +| dia | d ia | +| dian | d ian | +| diao | d iao | +| die | d ie | +| ding | d ing | +| diu | d iou | +| dong | d ong | +| dou | d ou | +| du | d u | +| duan | d uan | +| dui | d uei | +| dun | d uen | +| duo | d uo | +| e | e | +| ei | ei | +| en | en | +| eng | eng | +| er | er | +| fa | f a | +| fan | f an | +| fang | f ang | +| fei | f ei | +| fen | f en | +| feng | f eng | +| fo | f o | +| fou | f ou | +| fu | f u | +| ga | g a | +| gai | g ai | +| gan | g an | +| gang | g ang | +| gao | g ao | +| ge | g e | +| gei | g ei | +| gen | g en | +| geng | g eng | +| gong | g ong | +| gou | g ou | +| gu | g u | +| gua | g ua | +| guai | g uai | +| guan | g uan | +| guang | g uang | +| gui | g uei | +| gun | g uen | +| guo | g uo | +| ha | h a | +| hai | h ai | +| han | h an | +| hang | h ang | +| hao | h ao | +| he | h e | +| hei | h ei | +| hen | h en | +| heng | h eng | +| hong | h ong | +| hou | h ou | +| hu | h u | +| hua | h ua | +| huai | h uai | +| huan | h uan | +| huang | h uang | +| hui | h uei | +| hun | h uen | +| huo | h uo | +| ji | j i | +| jia | j ia | +| jian | j ian | +| jiang | j iang | +| jiao | j iao | +| jie | j ie | +| jin | j in | +| jing | j ing | +| jiong | j iong | +| jiu | j iou | +| ju | j v | +| juan | j van | +| jue | j ve | +| jun | j vn | +| ka | k a | +| kai | k ai | +| kan | k an | +| kang | k ang | +| kao | k ao | +| ke | k e | +| kei | k ei | +| ken | k en | +| keng | k eng | +| kong | k ong | +| kou | k ou | +| ku | k u | +| kua | k ua | +| kuai | k uai | +| kuan | k uan | +| kuang | k uang | +| kui | k uei | +| kun | k uen | +| kuo | k uo | +| la | l a | +| lai | l ai | +| lan | l an | +| lang | l ang | +| lao | l ao | +| le | l e | +| lei | l ei | +| leng | l eng | +| li | l i | +| lia | l ia | +| lian | l ian | +| liang | l iang | +| liao | l iao | +| lie | l ie | +| lin | l in | +| ling | l ing | +| liu | l iou | +| lo | l o | +| long | l ong | +| lou | l ou | +| lu | l u | +| luan | l uan | +| lun | l uen | +| luo | l uo | +| lv | l v | +| lve | l ve | +| m | m | +| ma | m a | +| mai | m ai | +| man | m an | +| mang | m ang | +| mao | m ao | +| me | m e | +| mei | m ei | +| men | m en | +| meng | m eng | +| mi | m i | +| mian | m ian | +| miao | m iao | +| mie | m ie | +| min | m in | +| ming | m ing | +| miu | m iou | +| mo | m o | +| mou | m ou | +| mu | m u | +| n | n | +| na | n a | +| nai | n ai | +| nan | n an | +| nang | n ang | +| nao | n ao | +| ne | n e | +| nei | n ei | +| nen | n en | +| neng | n eng | +| ni | n i | +| nian | n ian | +| niang | n iang | +| niao | n iao | +| nie | n ie | +| nin | n in | +| ning | n ing | +| niu | n iou | +| nong | n ong | +| nou | n ou | +| nu | n u | +| nuan | n uan | +| nuo | n uo | +| nv | n v | +| nve | n ve | +| o | o | +| ou | ou | +| pa | p a | +| pai | p ai | +| pan | p an | +| pang | p ang | +| pao | p ao | +| pei | p ei | +| pen | p en | +| peng | p eng | +| pi | p i | +| pian | p ian | +| piao | p iao | +| pie | p ie | +| pin | p in | +| ping | p ing | +| po | p o | +| pou | p ou | +| pu | p u | +| qi | q i | +| qia | q ia | +| qian | q ian | +| qiang | q iang | +| qiao | q iao | +| qie | q ie | +| qin | q in | +| qing | q ing | +| qiong | q iong | +| qiu | q iou | +| qu | q v | +| quan | q van | +| que | q ve | +| qun | q vn | +| ran | r an | +| rang | r ang | +| rao | r ao | +| re | r e | +| ren | r en | +| reng | r eng | +| ri | r i | +| rong | r ong | +| rou | r ou | +| ru | r u | +| rua | r ua | +| ruan | r uan | +| rui | r uei | +| run | r uen | +| ruo | r uo | +| sa | s a | +| sai | s ai | +| san | s an | +| sang | s ang | +| sao | s ao | +| se | s e | +| sen | s en | +| seng | s eng | +| sha | sh a | +| shai | sh ai | +| shan | sh an | +| shang | sh ang | +| shao | sh ao | +| she | sh e | +| shei | sh ei | +| shen | sh en | +| sheng | sh eng | +| shi | sh i | +| shou | sh ou | +| shu | sh u | +| shua | sh ua | +| shuai | sh uai | +| shuan | sh uan | +| shuang | sh uang | +| shui | sh uei | +| shun | sh uen | +| shuo | sh uo | +| si | s i | +| song | s ong | +| sou | s ou | +| su | s u | +| suan | s uan | +| sui | s uei | +| sun | s uen | +| suo | s uo | +| ta | t a | +| tai | t ai | +| tan | t an | +| tang | t ang | +| tao | t ao | +| te | t e | +| tei | t ei | +| teng | t eng | +| ti | t i | +| tian | t ian | +| tiao | t iao | +| tie | t ie | +| ting | t ing | +| tong | t ong | +| tou | t ou | +| tu | t u | +| tuan | t uan | +| tui | t uei | +| tun | t uen | +| tuo | t uo | +| wa | ua | +| wai | uai | +| wan | uan | +| wang | uang | +| wei | uei | +| wen | uen | +| weng | ueng | +| wo | uo | +| wu | u | +| xi | x i | +| xia | x ia | +| xian | x ian | +| xiang | x iang | +| xiao | x iao | +| xie | x ie | +| xin | x in | +| xing | x ing | +| xiong | x iong | +| xiu | x iou | +| xu | x v | +| xuan | x van | +| xue | x ve | +| xun | x vn | +| ya | ia | +| yan | ian | +| yang | iang | +| yao | iao | +| ye | ie | +| yi | i | +| yin | in | +| ying | ing | +| yong | iong | +| you | iou | +| yu | v | +| yuan | van | +| yue | ve | +| yun | vn | +| za | z a | +| zai | z ai | +| zan | z an | +| zang | z ang | +| zao | z ao | +| ze | z e | +| zei | z ei | +| zen | z en | +| zeng | z eng | +| zha | zh a | +| zhai | zh ai | +| zhan | zh an | +| zhang | zh ang | +| zhao | zh ao | +| zhe | zh e | +| zhei | zh ei | +| zhen | zh en | +| zheng | zh eng | +| zhi | zh i | +| zhong | zh ong | +| zhou | zh ou | +| zhu | zh u | +| zhua | zh ua | +| zhuai | zh uai | +| zhuan | zh uan | +| zhuang | zh uang | +| zhui | zh uei | +| zhun | zh uen | +| zhuo | zh uo | +| zi | z i | +| zong | z ong | +| zou | z ou | +| zu | z u | +| zuan | z uan | +| zui | z uei | +| zun | z uen | +| zuo | z uo | \ No newline at end of file diff --git a/inference/m4singer/m4singer/map.py b/inference/m4singer/m4singer/map.py new file mode 100644 index 0000000000000000000000000000000000000000..3059bfc754464ef7018c4479fd044fda3d4693a8 --- /dev/null +++ b/inference/m4singer/m4singer/map.py @@ -0,0 +1,7 @@ +def m4singer_pinyin2ph_func(): + pinyin2phs = {'AP': '', 'SP': ''} + with open('inference/m4singer/m4singer/m4singer_pinyin2ph.txt') as rf: + for line in rf.readlines(): + elements = [x.strip() for x in line.split('|') if x.strip() != ''] + pinyin2phs[elements[0]] = elements[1] + return pinyin2phs \ No newline at end of file diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..8701c5ba5bb519ade02864da34115911d7eb9c7e --- /dev/null +++ b/modules/commons/common_layers.py @@ -0,0 +1,668 @@ +import math +import torch +from torch import nn +from torch.nn import Parameter +import torch.onnx.operators +import torch.nn.functional as F +import utils + + +class Reshape(nn.Module): + def __init__(self, *args): + super(Reshape, self).__init__() + self.shape = args + + def forward(self, x): + return x.view(self.shape) + + +class Permute(nn.Module): + def __init__(self, *args): + super(Permute, self).__init__() + self.args = args + + def forward(self, x): + return x.permute(self.args) + + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, + gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class ConvNorm(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=None, dilation=1, bias=True, w_init_gain='linear'): + super(ConvNorm, self).__init__() + if padding is None: + assert (kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + bias=bias) + + torch.nn.init.xavier_uniform_( + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, signal): + conv_signal = self.conv(signal) + return conv_signal + + +def Embedding(num_embeddings, embedding_dim, padding_idx=None): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + if padding_idx is not None: + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + if not export and torch.cuda.is_available(): + try: + from apex.normalization import FusedLayerNorm + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + except ImportError: + pass + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.) + return m + + +class SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length. + + Padding symbols are ignored. + """ + + def __init__(self, embedding_dim, padding_idx, init_size=1024): + super().__init__() + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.weights = SinusoidalPositionalEmbedding.get_embedding( + init_size, + embedding_dim, + padding_idx, + ) + self.register_buffer('_float_tensor', torch.FloatTensor(1)) + + @staticmethod + def get_embedding(num_embeddings, embedding_dim, padding_idx=None): + """Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs): + """Input is expected to be of size [bsz x seqlen].""" + bsz, seq_len = input.shape[:2] + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.size(0): + # recompute/expand embeddings if needed + self.weights = SinusoidalPositionalEmbedding.get_embedding( + max_pos, + self.embedding_dim, + self.padding_idx, + ) + self.weights = self.weights.to(self._float_tensor) + + if incremental_state is not None: + # positions is the same for every token when decoding a single step + pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len + return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) + + positions = utils.make_positions(input, self.padding_idx) if positions is None else positions + return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() + + def max_positions(self): + """Maximum number of supported positions.""" + return int(1e5) # an arbitrary large number + + +class ConvTBC(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding=0): + super(ConvTBC, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + + self.weight = torch.nn.Parameter(torch.Tensor( + self.kernel_size, in_channels, out_channels)) + self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) + + def forward(self, input): + return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding) + + +class MultiheadAttention(nn.Module): + def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True, + add_bias_kv=False, add_zero_attn=False, self_attention=False, + encoder_decoder_attention=False): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \ + 'value to be of the same size' + + if self.qkv_same_dim: + self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) + else: + self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) + self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) + self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + + if bias: + self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.enable_torch_version = False + if hasattr(F, "multi_head_attention_forward"): + self.enable_torch_version = True + else: + self.enable_torch_version = False + self.last_attn_probs = None + + def reset_parameters(self): + if self.qkv_same_dim: + nn.init.xavier_uniform_(self.in_proj_weight) + else: + nn.init.xavier_uniform_(self.k_proj_weight) + nn.init.xavier_uniform_(self.v_proj_weight) + nn.init.xavier_uniform_(self.q_proj_weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.in_proj_bias is not None: + nn.init.constant_(self.in_proj_bias, 0.) + nn.init.constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query, key, value, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None, + before_softmax=False, + need_head_weights=False, + enc_dec_attn_constraint_mask=None, + reset_attn_weight=None + ): + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. + """ + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + + if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None: + if self.qkv_same_dim: + return F.multi_head_attention_forward(query, key, value, + self.embed_dim, self.num_heads, + self.in_proj_weight, + self.in_proj_bias, self.bias_k, self.bias_v, + self.add_zero_attn, self.dropout, + self.out_proj.weight, self.out_proj.bias, + self.training, key_padding_mask, need_weights, + attn_mask) + else: + return F.multi_head_attention_forward(query, key, value, + self.embed_dim, self.num_heads, + torch.empty([0]), + self.in_proj_bias, self.bias_k, self.bias_v, + self.add_zero_attn, self.dropout, + self.out_proj.weight, self.out_proj.bias, + self.training, key_padding_mask, need_weights, + attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight) + + if incremental_state is not None: + print('Not implemented error.') + exit() + else: + saved_state = None + + if self.self_attention: + # self-attention + q, k, v = self.in_proj_qkv(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.in_proj_q(query) + if key is None: + assert value is None + k = v = None + else: + k = self.in_proj_k(key) + v = self.in_proj_v(key) + + else: + q = self.in_proj_q(query) + k = self.in_proj_k(key) + v = self.in_proj_v(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + if saved_state is not None: + print('Not implemented error.') + exit() + + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]): + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + if len(attn_mask.shape) == 2: + attn_mask = attn_mask.unsqueeze(0) + elif len(attn_mask.shape) == 3: + attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape( + bsz * self.num_heads, tgt_len, src_len) + attn_weights = attn_weights + attn_mask + + if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + enc_dec_attn_constraint_mask.unsqueeze(2).bool(), + -1e9, + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + -1e9, + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = utils.softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) + + if reset_attn_weight is not None: + if reset_attn_weight: + self.last_attn_probs = attn_probs.detach() + else: + assert self.last_attn_probs is not None + attn_probs = self.last_attn_probs + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + else: + attn_weights = None + + return attn, (attn_weights, attn_logits) + + def in_proj_qkv(self, query): + return self._in_proj(query).chunk(3, dim=-1) + + def in_proj_q(self, query): + if self.qkv_same_dim: + return self._in_proj(query, end=self.embed_dim) + else: + bias = self.in_proj_bias + if bias is not None: + bias = bias[:self.embed_dim] + return F.linear(query, self.q_proj_weight, bias) + + def in_proj_k(self, key): + if self.qkv_same_dim: + return self._in_proj(key, start=self.embed_dim, end=2 * self.embed_dim) + else: + weight = self.k_proj_weight + bias = self.in_proj_bias + if bias is not None: + bias = bias[self.embed_dim:2 * self.embed_dim] + return F.linear(key, weight, bias) + + def in_proj_v(self, value): + if self.qkv_same_dim: + return self._in_proj(value, start=2 * self.embed_dim) + else: + weight = self.v_proj_weight + bias = self.in_proj_bias + if bias is not None: + bias = bias[2 * self.embed_dim:] + return F.linear(value, weight, bias) + + def _in_proj(self, input, start=0, end=None): + weight = self.in_proj_weight + bias = self.in_proj_bias + weight = weight[start:end, :] + if bias is not None: + bias = bias[start:end] + return F.linear(input, weight, bias) + + + def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz): + return attn_weights + + +class Swish(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class CustomSwish(nn.Module): + def forward(self, input_tensor): + return Swish.apply(input_tensor) + + +class TransformerFFNLayer(nn.Module): + def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'): + super().__init__() + self.kernel_size = kernel_size + self.dropout = dropout + self.act = act + if padding == 'SAME': + self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2) + elif padding == 'LEFT': + self.ffn_1 = nn.Sequential( + nn.ConstantPad1d((kernel_size - 1, 0), 0.0), + nn.Conv1d(hidden_size, filter_size, kernel_size) + ) + self.ffn_2 = Linear(filter_size, hidden_size) + if self.act == 'swish': + self.swish_fn = CustomSwish() + + def forward(self, x, incremental_state=None): + # x: T x B x C + if incremental_state is not None: + assert incremental_state is None, 'Nar-generation does not allow this.' + exit(1) + + x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1) + x = x * self.kernel_size ** -0.5 + + if incremental_state is not None: + x = x[-1:] + if self.act == 'gelu': + x = F.gelu(x) + if self.act == 'relu': + x = F.relu(x) + if self.act == 'swish': + x = self.swish_fn(x) + x = F.dropout(x, self.dropout, training=self.training) + x = self.ffn_2(x) + return x + + +class BatchNorm1dTBC(nn.Module): + def __init__(self, c): + super(BatchNorm1dTBC, self).__init__() + self.bn = nn.BatchNorm1d(c) + + def forward(self, x): + """ + + :param x: [T, B, C] + :return: [T, B, C] + """ + x = x.permute(1, 2, 0) # [B, C, T] + x = self.bn(x) # [B, C, T] + x = x.permute(2, 0, 1) # [T, B, C] + return x + + +class EncSALayer(nn.Module): + def __init__(self, c, num_heads, dropout, attention_dropout=0.1, + relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'): + super().__init__() + self.c = c + self.dropout = dropout + self.num_heads = num_heads + if num_heads > 0: + if norm == 'ln': + self.layer_norm1 = LayerNorm(c) + elif norm == 'bn': + self.layer_norm1 = BatchNorm1dTBC(c) + self.self_attn = MultiheadAttention( + self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False, + ) + if norm == 'ln': + self.layer_norm2 = LayerNorm(c) + elif norm == 'bn': + self.layer_norm2 = BatchNorm1dTBC(c) + self.ffn = TransformerFFNLayer( + c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act) + + def forward(self, x, encoder_padding_mask=None, **kwargs): + layer_norm_training = kwargs.get('layer_norm_training', None) + if layer_norm_training is not None: + self.layer_norm1.training = layer_norm_training + self.layer_norm2.training = layer_norm_training + if self.num_heads > 0: + residual = x + x = self.layer_norm1(x) + x, _, = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask + ) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None] + + residual = x + x = self.layer_norm2(x) + x = self.ffn(x) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None] + return x + + +class DecSALayer(nn.Module): + def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, kernel_size=9, act='gelu'): + super().__init__() + self.c = c + self.dropout = dropout + self.layer_norm1 = LayerNorm(c) + self.self_attn = MultiheadAttention( + c, num_heads, self_attention=True, dropout=attention_dropout, bias=False + ) + self.layer_norm2 = LayerNorm(c) + self.encoder_attn = MultiheadAttention( + c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False, + ) + self.layer_norm3 = LayerNorm(c) + self.ffn = TransformerFFNLayer( + c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act) + + def forward( + self, + x, + encoder_out=None, + encoder_padding_mask=None, + incremental_state=None, + self_attn_mask=None, + self_attn_padding_mask=None, + attn_out=None, + reset_attn_weight=None, + **kwargs, + ): + layer_norm_training = kwargs.get('layer_norm_training', None) + if layer_norm_training is not None: + self.layer_norm1.training = layer_norm_training + self.layer_norm2.training = layer_norm_training + self.layer_norm3.training = layer_norm_training + residual = x + x = self.layer_norm1(x) + x, _ = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + attn_mask=self_attn_mask + ) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + + residual = x + x = self.layer_norm2(x) + if encoder_out is not None: + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + enc_dec_attn_constraint_mask=None, #utils.get_incremental_state(self, incremental_state, 'enc_dec_attn_constraint_mask'), + reset_attn_weight=reset_attn_weight + ) + attn_logits = attn[1] + else: + assert attn_out is not None + x = self.encoder_attn.in_proj_v(attn_out.transpose(0, 1)) + attn_logits = None + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + + residual = x + x = self.layer_norm3(x) + x = self.ffn(x, incremental_state=incremental_state) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + # if len(attn_logits.size()) > 3: + # indices = attn_logits.softmax(-1).max(-1).values.sum(-1).argmax(-1) + # attn_logits = attn_logits.gather(1, + # indices[:, None, None, None].repeat(1, 1, attn_logits.size(-2), attn_logits.size(-1))).squeeze(1) + return x, attn_logits diff --git a/modules/commons/espnet_positional_embedding.py b/modules/commons/espnet_positional_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..74decb6ab300951490ae08a4b93041a0542b5bb7 --- /dev/null +++ b/modules/commons/espnet_positional_embedding.py @@ -0,0 +1,113 @@ +import math +import torch + + +class PositionalEncoding(torch.nn.Module): + """Positional encoding. + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + reverse (bool): Whether to reverse the input position. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False): + """Construct an PositionalEncoding object.""" + super(PositionalEncoding, self).__init__() + self.d_model = d_model + self.reverse = reverse + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + + def extend_pe(self, x): + """Reset the positional encodings.""" + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + if self.reverse: + position = torch.arange( + x.size(1) - 1, -1, -1.0, dtype=torch.float32 + ).unsqueeze(1) + else: + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) + * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +class ScaledPositionalEncoding(PositionalEncoding): + """Scaled positional encoding module. + See Sec. 3.2 https://arxiv.org/abs/1809.08895 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Initialize class.""" + super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len) + self.alpha = torch.nn.Parameter(torch.tensor(1.0)) + + def reset_parameters(self): + """Reset parameters.""" + self.alpha.data = torch.tensor(1.0) + + def forward(self, x): + """Add positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + """ + self.extend_pe(x) + x = x + self.alpha * self.pe[:, : x.size(1)] + return self.dropout(x) + + +class RelPositionalEncoding(PositionalEncoding): + """Relative positional encoding module. + See : Appendix B in https://arxiv.org/abs/1901.02860 + Args: + d_model (int): Embedding dimension. + dropout_rate (float): Dropout rate. + max_len (int): Maximum input length. + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Initialize class.""" + super().__init__(d_model, dropout_rate, max_len, reverse=True) + + def forward(self, x): + """Compute positional encoding. + Args: + x (torch.Tensor): Input tensor (batch, time, `*`). + Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[:, : x.size(1)] + return self.dropout(x) + self.dropout(pos_emb) \ No newline at end of file diff --git a/modules/commons/ssim.py b/modules/commons/ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..0d0241f267ef58b24979e022b05f2a9adf768826 --- /dev/null +++ b/modules/commons/ssim.py @@ -0,0 +1,391 @@ +# ''' +# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py +# ''' +# +# import torch +# import torch.jit +# import torch.nn.functional as F +# +# +# @torch.jit.script +# def create_window(window_size: int, sigma: float, channel: int): +# ''' +# Create 1-D gauss kernel +# :param window_size: the size of gauss kernel +# :param sigma: sigma of normal distribution +# :param channel: input channel +# :return: 1D kernel +# ''' +# coords = torch.arange(window_size, dtype=torch.float) +# coords -= window_size // 2 +# +# g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) +# g /= g.sum() +# +# g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1) +# return g +# +# +# @torch.jit.script +# def _gaussian_filter(x, window_1d, use_padding: bool): +# ''' +# Blur input with 1-D kernel +# :param x: batch of tensors to be blured +# :param window_1d: 1-D gauss kernel +# :param use_padding: padding image before conv +# :return: blured tensors +# ''' +# C = x.shape[1] +# padding = 0 +# if use_padding: +# window_size = window_1d.shape[3] +# padding = window_size // 2 +# out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C) +# out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C) +# return out +# +# +# @torch.jit.script +# def ssim(X, Y, window, data_range: float, use_padding: bool = False): +# ''' +# Calculate ssim index for X and Y +# :param X: images [B, C, H, N_bins] +# :param Y: images [B, C, H, N_bins] +# :param window: 1-D gauss kernel +# :param data_range: value range of input images. (usually 1.0 or 255) +# :param use_padding: padding image before conv +# :return: +# ''' +# +# K1 = 0.01 +# K2 = 0.03 +# compensation = 1.0 +# +# C1 = (K1 * data_range) ** 2 +# C2 = (K2 * data_range) ** 2 +# +# mu1 = _gaussian_filter(X, window, use_padding) +# mu2 = _gaussian_filter(Y, window, use_padding) +# sigma1_sq = _gaussian_filter(X * X, window, use_padding) +# sigma2_sq = _gaussian_filter(Y * Y, window, use_padding) +# sigma12 = _gaussian_filter(X * Y, window, use_padding) +# +# mu1_sq = mu1.pow(2) +# mu2_sq = mu2.pow(2) +# mu1_mu2 = mu1 * mu2 +# +# sigma1_sq = compensation * (sigma1_sq - mu1_sq) +# sigma2_sq = compensation * (sigma2_sq - mu2_sq) +# sigma12 = compensation * (sigma12 - mu1_mu2) +# +# cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) +# # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan. +# cs_map = cs_map.clamp_min(0.) +# ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map +# +# ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW +# cs = cs_map.mean(dim=(1, 2, 3)) +# +# return ssim_val, cs +# +# +# @torch.jit.script +# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8): +# ''' +# interface of ms-ssim +# :param X: a batch of images, (N,C,H,W) +# :param Y: a batch of images, (N,C,H,W) +# :param window: 1-D gauss kernel +# :param data_range: value range of input images. (usually 1.0 or 255) +# :param weights: weights for different levels +# :param use_padding: padding image before conv +# :param eps: use for avoid grad nan. +# :return: +# ''' +# levels = weights.shape[0] +# cs_vals = [] +# ssim_vals = [] +# for _ in range(levels): +# ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding) +# # Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf. +# ssim_val = ssim_val.clamp_min(eps) +# cs = cs.clamp_min(eps) +# cs_vals.append(cs) +# +# ssim_vals.append(ssim_val) +# padding = (X.shape[2] % 2, X.shape[3] % 2) +# X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding) +# Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding) +# +# cs_vals = torch.stack(cs_vals, dim=0) +# ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0) +# return ms_ssim_val +# +# +# class SSIM(torch.jit.ScriptModule): +# __constants__ = ['data_range', 'use_padding'] +# +# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False): +# ''' +# :param window_size: the size of gauss kernel +# :param window_sigma: sigma of normal distribution +# :param data_range: value range of input images. (usually 1.0 or 255) +# :param channel: input channels (default: 3) +# :param use_padding: padding image before conv +# ''' +# super().__init__() +# assert window_size % 2 == 1, 'Window size must be odd.' +# window = create_window(window_size, window_sigma, channel) +# self.register_buffer('window', window) +# self.data_range = data_range +# self.use_padding = use_padding +# +# @torch.jit.script_method +# def forward(self, X, Y): +# r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding) +# return r[0] +# +# +# class MS_SSIM(torch.jit.ScriptModule): +# __constants__ = ['data_range', 'use_padding', 'eps'] +# +# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None, +# levels=None, eps=1e-8): +# ''' +# class for ms-ssim +# :param window_size: the size of gauss kernel +# :param window_sigma: sigma of normal distribution +# :param data_range: value range of input images. (usually 1.0 or 255) +# :param channel: input channels +# :param use_padding: padding image before conv +# :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) +# :param levels: number of downsampling +# :param eps: Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf. +# ''' +# super().__init__() +# assert window_size % 2 == 1, 'Window size must be odd.' +# self.data_range = data_range +# self.use_padding = use_padding +# self.eps = eps +# +# window = create_window(window_size, window_sigma, channel) +# self.register_buffer('window', window) +# +# if weights is None: +# weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333] +# weights = torch.tensor(weights, dtype=torch.float) +# +# if levels is not None: +# weights = weights[:levels] +# weights = weights / weights.sum() +# +# self.register_buffer('weights', weights) +# +# @torch.jit.script_method +# def forward(self, X, Y): +# return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights, +# use_padding=self.use_padding, eps=self.eps) +# +# +# if __name__ == '__main__': +# print('Simple Test') +# im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda') +# img1 = im / 255 +# img2 = img1 * 0.5 +# +# losser = SSIM(data_range=1.).cuda() +# loss = losser(img1, img2).mean() +# +# losser2 = MS_SSIM(data_range=1.).cuda() +# loss2 = losser2(img1, img2).mean() +# +# print(loss.item()) +# print(loss2.item()) +# +# if __name__ == '__main__': +# print('Training Test') +# import cv2 +# import torch.optim +# import numpy as np +# import imageio +# import time +# +# out_test_video = False +# # 最好不要直接输出gif图,会非常大,最好先输出mkv文件后用ffmpeg转换到GIF +# video_use_gif = False +# +# im = cv2.imread('test_img1.jpg', 1) +# t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255. +# +# if out_test_video: +# if video_use_gif: +# fps = 0.5 +# out_wh = (im.shape[1] // 2, im.shape[0] // 2) +# suffix = '.gif' +# else: +# fps = 5 +# out_wh = (im.shape[1], im.shape[0]) +# suffix = '.mkv' +# video_last_time = time.perf_counter() +# video = imageio.get_writer('ssim_test' + suffix, fps=fps) +# +# # 测试ssim +# print('Training SSIM') +# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255. +# rand_im.requires_grad = True +# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8) +# losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda() +# ssim_score = 0 +# while ssim_score < 0.999: +# optim.zero_grad() +# loss = losser(rand_im, t_im) +# (-loss).sum().backward() +# ssim_score = loss.item() +# optim.step() +# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0] +# r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2) +# +# if out_test_video: +# if time.perf_counter() - video_last_time > 1. / fps: +# video_last_time = time.perf_counter() +# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB) +# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA) +# if isinstance(out_frame, cv2.UMat): +# out_frame = out_frame.get() +# video.append_data(out_frame) +# +# cv2.imshow('ssim', r_im) +# cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score) +# cv2.waitKey(1) +# +# if out_test_video: +# video.close() +# +# # 测试ms_ssim +# if out_test_video: +# if video_use_gif: +# fps = 0.5 +# out_wh = (im.shape[1] // 2, im.shape[0] // 2) +# suffix = '.gif' +# else: +# fps = 5 +# out_wh = (im.shape[1], im.shape[0]) +# suffix = '.mkv' +# video_last_time = time.perf_counter() +# video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps) +# +# print('Training MS_SSIM') +# rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255. +# rand_im.requires_grad = True +# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8) +# losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda() +# ssim_score = 0 +# while ssim_score < 0.999: +# optim.zero_grad() +# loss = losser(rand_im, t_im) +# (-loss).sum().backward() +# ssim_score = loss.item() +# optim.step() +# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0] +# r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2) +# +# if out_test_video: +# if time.perf_counter() - video_last_time > 1. / fps: +# video_last_time = time.perf_counter() +# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB) +# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA) +# if isinstance(out_frame, cv2.UMat): +# out_frame = out_frame.get() +# video.append_data(out_frame) +# +# cv2.imshow('ms_ssim', r_im) +# cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score) +# cv2.waitKey(1) +# +# if out_test_video: +# video.close() + +""" +Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim +""" + +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1) + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + + +window = None + + +def ssim(img1, img2, window_size=11, size_average=True): + (_, channel, _, _) = img1.size() + global window + if window is None: + window = create_window(window_size, channel) + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + return _ssim(img1, img2, window, window_size, channel, size_average) diff --git a/modules/diffsinger_midi/fs2.py b/modules/diffsinger_midi/fs2.py new file mode 100644 index 0000000000000000000000000000000000000000..8ddf2aa42bfb6109cd41d149fa7a8059e7e186c1 --- /dev/null +++ b/modules/diffsinger_midi/fs2.py @@ -0,0 +1,118 @@ +from modules.commons.common_layers import * +from modules.commons.common_layers import Embedding +from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \ + EnergyPredictor, FastspeechEncoder +from utils.cwt import cwt2f0 +from utils.hparams import hparams +from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0 +from modules.fastspeech.fs2 import FastSpeech2 + + +class FastspeechMIDIEncoder(FastspeechEncoder): + def forward_embedding(self, txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding): + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(txt_tokens) + x = x + midi_embedding + midi_dur_embedding + slur_embedding + if hparams['use_pos_embed']: + if hparams.get('rel_pos') is not None and hparams['rel_pos']: + x = self.embed_positions(x) + else: + positions = self.embed_positions(txt_tokens) + x = x + positions + x = F.dropout(x, p=self.dropout, training=self.training) + return x + + def forward(self, txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding): + """ + + :param txt_tokens: [B, T] + :return: { + 'encoder_out': [T x B x C] + } + """ + encoder_padding_mask = txt_tokens.eq(self.padding_idx).data + x = self.forward_embedding(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding) # [B, T, H] + x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask) + return x + + +FS_ENCODERS = { + 'fft': lambda hp, embed_tokens, d: FastspeechMIDIEncoder( + embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'], + num_heads=hp['num_heads']), +} + + +class FastSpeech2MIDI(FastSpeech2): + def __init__(self, dictionary, out_dims=None): + super().__init__(dictionary, out_dims) + del self.encoder + self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary) + self.midi_embed = Embedding(300, self.hidden_size, self.padding_idx) + self.midi_dur_layer = Linear(1, self.hidden_size) + self.is_slur_embed = Embedding(2, self.hidden_size) + + def forward(self, txt_tokens, mel2ph=None, spk_embed=None, + ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False, + spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs): + ret = {} + + midi_embedding = self.midi_embed(kwargs['pitch_midi']) + midi_dur_embedding, slur_embedding = 0, 0 + if kwargs.get('midi_dur') is not None: + midi_dur_embedding = self.midi_dur_layer(kwargs['midi_dur'][:, :, None]) # [B, T, 1] -> [B, T, H] + if kwargs.get('is_slur') is not None: + slur_embedding = self.is_slur_embed(kwargs['is_slur']) + encoder_out = self.encoder(txt_tokens, midi_embedding, midi_dur_embedding, slur_embedding) # [B, T, C] + src_nonpadding = (txt_tokens > 0).float()[:, :, None] + + # add ref style embed + # Not implemented + # variance encoder + var_embed = 0 + + # encoder_out_dur denotes encoder outputs for duration predictor + # in speech adaptation, duration predictor use old speaker embedding + if hparams['use_spk_embed']: + spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :] + elif hparams['use_spk_id']: + spk_embed_id = spk_embed + if spk_embed_dur_id is None: + spk_embed_dur_id = spk_embed_id + if spk_embed_f0_id is None: + spk_embed_f0_id = spk_embed_id + spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :] + spk_embed_dur = spk_embed_f0 = spk_embed + if hparams['use_split_spk_id']: + spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :] + spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :] + else: + spk_embed_dur = spk_embed_f0 = spk_embed = 0 + + # add dur + dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding + + mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret) + + decoder_inp = F.pad(encoder_out, [0, 0, 1, 0]) + + mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]]) + decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H] + + tgt_nonpadding = (mel2ph > 0).float()[:, :, None] + + # add pitch and energy embed + pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding + if hparams['use_pitch_embed']: + pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding + decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph) + if hparams['use_energy_embed']: + decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret) + + ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding + + if skip_decoder: + return ret + ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs) + + return ret diff --git a/modules/fastspeech/fs2.py b/modules/fastspeech/fs2.py new file mode 100644 index 0000000000000000000000000000000000000000..52b4ac4aaa7ae49f06736a038bde83ca2cfa8483 --- /dev/null +++ b/modules/fastspeech/fs2.py @@ -0,0 +1,255 @@ +from modules.commons.common_layers import * +from modules.commons.common_layers import Embedding +from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \ + EnergyPredictor, FastspeechEncoder +from utils.cwt import cwt2f0 +from utils.hparams import hparams +from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0 + +FS_ENCODERS = { + 'fft': lambda hp, embed_tokens, d: FastspeechEncoder( + embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'], + num_heads=hp['num_heads']), +} + +FS_DECODERS = { + 'fft': lambda hp: FastspeechDecoder( + hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']), +} + + +class FastSpeech2(nn.Module): + def __init__(self, dictionary, out_dims=None): + super().__init__() + self.dictionary = dictionary + self.padding_idx = dictionary.pad() + self.enc_layers = hparams['enc_layers'] + self.dec_layers = hparams['dec_layers'] + self.hidden_size = hparams['hidden_size'] + self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size) + self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary) + self.decoder = FS_DECODERS[hparams['decoder_type']](hparams) + self.out_dims = out_dims + if out_dims is None: + self.out_dims = hparams['audio_num_mel_bins'] + self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True) + + if hparams['use_spk_id']: + self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size) + if hparams['use_split_spk_id']: + self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size) + self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size) + elif hparams['use_spk_embed']: + self.spk_embed_proj = Linear(256, self.hidden_size, bias=True) + predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size + self.dur_predictor = DurationPredictor( + self.hidden_size, + n_chans=predictor_hidden, + n_layers=hparams['dur_predictor_layers'], + dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'], + kernel_size=hparams['dur_predictor_kernel']) + self.length_regulator = LengthRegulator() + if hparams['use_pitch_embed']: + self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx) + if hparams['pitch_type'] == 'cwt': + h = hparams['cwt_hidden_size'] + cwt_out_dims = 10 + if hparams['use_uv']: + cwt_out_dims = cwt_out_dims + 1 + self.cwt_predictor = nn.Sequential( + nn.Linear(self.hidden_size, h), + PitchPredictor( + h, + n_chans=predictor_hidden, + n_layers=hparams['predictor_layers'], + dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims, + padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])) + self.cwt_stats_layers = nn.Sequential( + nn.Linear(self.hidden_size, h), nn.ReLU(), + nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2) + ) + else: + self.pitch_predictor = PitchPredictor( + self.hidden_size, + n_chans=predictor_hidden, + n_layers=hparams['predictor_layers'], + dropout_rate=hparams['predictor_dropout'], + odim=2 if hparams['pitch_type'] == 'frame' else 1, + padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']) + if hparams['use_energy_embed']: + self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx) + self.energy_predictor = EnergyPredictor( + self.hidden_size, + n_chans=predictor_hidden, + n_layers=hparams['predictor_layers'], + dropout_rate=hparams['predictor_dropout'], odim=1, + padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']) + + def build_embedding(self, dictionary, embed_dim): + num_embeddings = len(dictionary) + emb = Embedding(num_embeddings, embed_dim, self.padding_idx) + return emb + + def forward(self, txt_tokens, mel2ph=None, spk_embed=None, + ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False, + spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs): + ret = {} + encoder_out = self.encoder(txt_tokens) # [B, T, C] + src_nonpadding = (txt_tokens > 0).float()[:, :, None] + + # add ref style embed + # Not implemented + # variance encoder + var_embed = 0 + + # encoder_out_dur denotes encoder outputs for duration predictor + # in speech adaptation, duration predictor use old speaker embedding + if hparams['use_spk_embed']: + spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :] + elif hparams['use_spk_id']: + spk_embed_id = spk_embed + if spk_embed_dur_id is None: + spk_embed_dur_id = spk_embed_id + if spk_embed_f0_id is None: + spk_embed_f0_id = spk_embed_id + spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :] + spk_embed_dur = spk_embed_f0 = spk_embed + if hparams['use_split_spk_id']: + spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :] + spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :] + else: + spk_embed_dur = spk_embed_f0 = spk_embed = 0 + + # add dur + dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding + + mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret) + + decoder_inp = F.pad(encoder_out, [0, 0, 1, 0]) + + mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]]) + decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H] + + tgt_nonpadding = (mel2ph > 0).float()[:, :, None] + + # add pitch and energy embed + pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding + if hparams['use_pitch_embed']: + pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding + decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph) + if hparams['use_energy_embed']: + decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret) + + ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding + + if skip_decoder: + return ret + ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs) + + return ret + + def add_dur(self, dur_input, mel2ph, txt_tokens, ret): + """ + + :param dur_input: [B, T_txt, H] + :param mel2ph: [B, T_mel] + :param txt_tokens: [B, T_txt] + :param ret: + :return: + """ + src_padding = txt_tokens == 0 + dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach()) + if mel2ph is None: + dur, xs = self.dur_predictor.inference(dur_input, src_padding) + ret['dur'] = xs + ret['dur_choice'] = dur + mel2ph = self.length_regulator(dur, src_padding).detach() + # from modules.fastspeech.fake_modules import FakeLengthRegulator + # fake_lr = FakeLengthRegulator() + # fake_mel2ph = fake_lr(dur, (1 - src_padding.long()).sum(-1))[..., 0].detach() + # print(mel2ph == fake_mel2ph) + else: + ret['dur'] = self.dur_predictor(dur_input, src_padding) + ret['mel2ph'] = mel2ph + return mel2ph + + def add_energy(self, decoder_inp, energy, ret): + decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach()) + ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0] + if energy is None: + energy = energy_pred + energy = torch.clamp(energy * 256 // 4, max=255).long() + energy_embed = self.energy_embed(energy) + return energy_embed + + def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None): + if hparams['pitch_type'] == 'ph': + pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach()) + pitch_padding = encoder_out.sum().abs() == 0 + ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp) + if f0 is None: + f0 = pitch_pred[:, :, 0] + ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding) + pitch = f0_to_coarse(f0_denorm) # start from 0 [B, T_txt] + pitch = F.pad(pitch, [1, 0]) + pitch = torch.gather(pitch, 1, mel2ph) # [B, T_mel] + pitch_embed = self.pitch_embed(pitch) + return pitch_embed + decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach()) + + pitch_padding = mel2ph == 0 + + if hparams['pitch_type'] == 'cwt': + pitch_padding = None + ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp) + stats_out = self.cwt_stats_layers(encoder_out[:, 0, :]) # [B, 2] + mean = ret['f0_mean'] = stats_out[:, 0] + std = ret['f0_std'] = stats_out[:, 1] + cwt_spec = cwt_out[:, :, :10] + if f0 is None: + std = std * hparams['cwt_std_scale'] + f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph) + if hparams['use_uv']: + assert cwt_out.shape[-1] == 11 + uv = cwt_out[:, :, -1] > 0 + elif hparams['pitch_ar']: + ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None) + if f0 is None: + f0 = pitch_pred[:, :, 0] + else: + ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp) + if f0 is None: + f0 = pitch_pred[:, :, 0] + if hparams['use_uv'] and uv is None: + uv = pitch_pred[:, :, 1] > 0 + ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding) + if pitch_padding is not None: + f0[pitch_padding] = 0 + + pitch = f0_to_coarse(f0_denorm) # start from 0 + pitch_embed = self.pitch_embed(pitch) + return pitch_embed + + def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs): + x = decoder_inp # [B, T, H] + x = self.decoder(x) + x = self.mel_out(x) + return x * tgt_nonpadding + + def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph): + f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales']) + f0 = torch.cat( + [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1) + f0_norm = norm_f0(f0, None, hparams) + return f0_norm + + def out2mel(self, out): + return out + + @staticmethod + def mel_norm(x): + return (x + 5.5) / (6.3 / 2) - 1 + + @staticmethod + def mel_denorm(x): + return (x + 1) * (6.3 / 2) - 5.5 diff --git a/modules/fastspeech/pe.py b/modules/fastspeech/pe.py new file mode 100644 index 0000000000000000000000000000000000000000..d85989bc41d85ce119bbcc16a958a253ef290930 --- /dev/null +++ b/modules/fastspeech/pe.py @@ -0,0 +1,149 @@ +from modules.commons.common_layers import * +from utils.hparams import hparams +from modules.fastspeech.tts_modules import PitchPredictor +from utils.pitch_utils import denorm_f0 + + +class Prenet(nn.Module): + def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None): + super(Prenet, self).__init__() + padding = kernel // 2 + self.layers = [] + self.strides = strides if strides is not None else [1] * n_layers + for l in range(n_layers): + self.layers.append(nn.Sequential( + nn.Conv1d(in_dim, out_dim, kernel_size=kernel, padding=padding, stride=self.strides[l]), + nn.ReLU(), + nn.BatchNorm1d(out_dim) + )) + in_dim = out_dim + self.layers = nn.ModuleList(self.layers) + self.out_proj = nn.Linear(out_dim, out_dim) + + def forward(self, x): + """ + + :param x: [B, T, 80] + :return: [L, B, T, H], [B, T, H] + """ + padding_mask = x.abs().sum(-1).eq(0).data # [B, T] + nonpadding_mask_TB = 1 - padding_mask.float()[:, None, :] # [B, 1, T] + x = x.transpose(1, 2) + hiddens = [] + for i, l in enumerate(self.layers): + nonpadding_mask_TB = nonpadding_mask_TB[:, :, ::self.strides[i]] + x = l(x) * nonpadding_mask_TB + hiddens.append(x) + hiddens = torch.stack(hiddens, 0) # [L, B, H, T] + hiddens = hiddens.transpose(2, 3) # [L, B, T, H] + x = self.out_proj(x.transpose(1, 2)) # [B, T, H] + x = x * nonpadding_mask_TB.transpose(1, 2) + return hiddens, x + + +class ConvBlock(nn.Module): + def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0): + super().__init__() + self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride) + self.norm = norm + if self.norm == 'bn': + self.norm = nn.BatchNorm1d(n_chans) + elif self.norm == 'in': + self.norm = nn.InstanceNorm1d(n_chans, affine=True) + elif self.norm == 'gn': + self.norm = nn.GroupNorm(n_chans // 16, n_chans) + elif self.norm == 'ln': + self.norm = LayerNorm(n_chans // 16, n_chans) + elif self.norm == 'wn': + self.conv = torch.nn.utils.weight_norm(self.conv.conv) + self.dropout = nn.Dropout(dropout) + self.relu = nn.ReLU() + + def forward(self, x): + """ + + :param x: [B, C, T] + :return: [B, C, T] + """ + x = self.conv(x) + if not isinstance(self.norm, str): + if self.norm == 'none': + pass + elif self.norm == 'ln': + x = self.norm(x.transpose(1, 2)).transpose(1, 2) + else: + x = self.norm(x) + x = self.relu(x) + x = self.dropout(x) + return x + + +class ConvStacks(nn.Module): + def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn', + dropout=0, strides=None, res=True): + super().__init__() + self.conv = torch.nn.ModuleList() + self.kernel_size = kernel_size + self.res = res + self.in_proj = Linear(idim, n_chans) + if strides is None: + strides = [1] * n_layers + else: + assert len(strides) == n_layers + for idx in range(n_layers): + self.conv.append(ConvBlock( + n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout)) + self.out_proj = Linear(n_chans, odim) + + def forward(self, x, return_hiddens=False): + """ + + :param x: [B, T, H] + :return: [B, T, H] + """ + x = self.in_proj(x) + x = x.transpose(1, -1) # (B, idim, Tmax) + hiddens = [] + for f in self.conv: + x_ = f(x) + x = x + x_ if self.res else x_ # (B, C, Tmax) + hiddens.append(x) + x = x.transpose(1, -1) + x = self.out_proj(x) # (B, Tmax, H) + if return_hiddens: + hiddens = torch.stack(hiddens, 1) # [B, L, C, T] + return x, hiddens + return x + + +class PitchExtractor(nn.Module): + def __init__(self, n_mel_bins=80, conv_layers=2): + super().__init__() + self.hidden_size = 256 + self.predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size + self.conv_layers = conv_layers + + self.mel_prenet = Prenet(n_mel_bins, self.hidden_size, strides=[1, 1, 1]) + if self.conv_layers > 0: + self.mel_encoder = ConvStacks( + idim=self.hidden_size, n_chans=self.hidden_size, odim=self.hidden_size, n_layers=self.conv_layers) + self.pitch_predictor = PitchPredictor( + self.hidden_size, n_chans=self.predictor_hidden, + n_layers=5, dropout_rate=0.5, odim=2, + padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']) + + def forward(self, mel_input=None): + ret = {} + mel_hidden = self.mel_prenet(mel_input)[1] + if self.conv_layers > 0: + mel_hidden = self.mel_encoder(mel_hidden) + + ret['pitch_pred'] = pitch_pred = self.pitch_predictor(mel_hidden) + + pitch_padding = mel_input.abs().sum(-1) == 0 + use_uv = hparams['pitch_type'] == 'frame' and hparams['use_uv'] + + ret['f0_denorm_pred'] = denorm_f0( + pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None, + hparams, pitch_padding=pitch_padding) + return ret \ No newline at end of file diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..195eff279de781dd2565cfb2da65533c58f6c332 --- /dev/null +++ b/modules/fastspeech/tts_modules.py @@ -0,0 +1,357 @@ +import logging +import math + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from modules.commons.espnet_positional_embedding import RelPositionalEncoding +from modules.commons.common_layers import SinusoidalPositionalEmbedding, Linear, EncSALayer, DecSALayer, BatchNorm1dTBC +from utils.hparams import hparams + +DEFAULT_MAX_SOURCE_POSITIONS = 2000 +DEFAULT_MAX_TARGET_POSITIONS = 2000 + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'): + super().__init__() + self.hidden_size = hidden_size + self.dropout = dropout + self.num_heads = num_heads + self.op = EncSALayer( + hidden_size, num_heads, dropout=dropout, + attention_dropout=0.0, relu_dropout=dropout, + kernel_size=kernel_size + if kernel_size is not None else hparams['enc_ffn_kernel_size'], + padding=hparams['ffn_padding'], + norm=norm, act=hparams['ffn_act']) + + def forward(self, x, **kwargs): + return self.op(x, **kwargs) + + +###################### +# fastspeech modules +###################### +class LayerNorm(torch.nn.LayerNorm): + """Layer normalization module. + :param int nout: output dim size + :param int dim: dimension to be normalized + """ + + def __init__(self, nout, dim=-1): + """Construct an LayerNorm object.""" + super(LayerNorm, self).__init__(nout, eps=1e-12) + self.dim = dim + + def forward(self, x): + """Apply layer normalization. + :param torch.Tensor x: input tensor + :return: layer normalized tensor + :rtype torch.Tensor + """ + if self.dim == -1: + return super(LayerNorm, self).forward(x) + return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) + + +class DurationPredictor(torch.nn.Module): + """Duration predictor module. + This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder. + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + Note: + The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`, + the outputs are calculated in log domain but in `inference`, those are calculated in linear domain. + """ + + def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'): + """Initilize duration predictor module. + Args: + idim (int): Input dimension. + n_layers (int, optional): Number of convolutional layers. + n_chans (int, optional): Number of channels of convolutional layers. + kernel_size (int, optional): Kernel size of convolutional layers. + dropout_rate (float, optional): Dropout rate. + offset (float, optional): Offset value to avoid nan in log domain. + """ + super(DurationPredictor, self).__init__() + self.offset = offset + self.conv = torch.nn.ModuleList() + self.kernel_size = kernel_size + self.padding = padding + for idx in range(n_layers): + in_chans = idim if idx == 0 else n_chans + self.conv += [torch.nn.Sequential( + torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2) + if padding == 'SAME' + else (kernel_size - 1, 0), 0), + torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0), + torch.nn.ReLU(), + LayerNorm(n_chans, dim=1), + torch.nn.Dropout(dropout_rate) + )] + if hparams['dur_loss'] in ['mse', 'huber']: + odims = 1 + elif hparams['dur_loss'] == 'mog': + odims = 15 + elif hparams['dur_loss'] == 'crf': + odims = 32 + from torchcrf import CRF + self.crf = CRF(odims, batch_first=True) + self.linear = torch.nn.Linear(n_chans, odims) + + def _forward(self, xs, x_masks=None, is_inference=False): + xs = xs.transpose(1, -1) # (B, idim, Tmax) + for f in self.conv: + xs = f(xs) # (B, C, Tmax) + if x_masks is not None: + xs = xs * (1 - x_masks.float())[:, None, :] + + xs = self.linear(xs.transpose(1, -1)) # [B, T, C] + xs = xs * (1 - x_masks.float())[:, :, None] # (B, T, C) + if is_inference: + return self.out2dur(xs), xs + else: + if hparams['dur_loss'] in ['mse']: + xs = xs.squeeze(-1) # (B, Tmax) + return xs + + def out2dur(self, xs): + if hparams['dur_loss'] in ['mse']: + # NOTE: calculate in log domain + xs = xs.squeeze(-1) # (B, Tmax) + dur = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value + elif hparams['dur_loss'] == 'mog': + return NotImplementedError + elif hparams['dur_loss'] == 'crf': + dur = torch.LongTensor(self.crf.decode(xs)).cuda() + return dur + + def forward(self, xs, x_masks=None): + """Calculate forward propagation. + Args: + xs (Tensor): Batch of input sequences (B, Tmax, idim). + x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax). + Returns: + Tensor: Batch of predicted durations in log domain (B, Tmax). + """ + return self._forward(xs, x_masks, False) + + def inference(self, xs, x_masks=None): + """Inference duration. + Args: + xs (Tensor): Batch of input sequences (B, Tmax, idim). + x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax). + Returns: + LongTensor: Batch of predicted durations in linear domain (B, Tmax). + """ + return self._forward(xs, x_masks, True) + + +class LengthRegulator(torch.nn.Module): + def __init__(self, pad_value=0.0): + super(LengthRegulator, self).__init__() + self.pad_value = pad_value + + def forward(self, dur, dur_padding=None, alpha=1.0): + """ + Example (no batch dim version): + 1. dur = [2,2,3] + 2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4] + 3. token_mask = [[1,1,0,0,0,0,0], + [0,0,1,1,0,0,0], + [0,0,0,0,1,1,1]] + 4. token_idx * token_mask = [[1,1,0,0,0,0,0], + [0,0,2,2,0,0,0], + [0,0,0,0,3,3,3]] + 5. (token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3] + + :param dur: Batch of durations of each frame (B, T_txt) + :param dur_padding: Batch of padding of each frame (B, T_txt) + :param alpha: duration rescale coefficient + :return: + mel2ph (B, T_speech) + """ + assert alpha > 0 + dur = torch.round(dur.float() * alpha).long() + if dur_padding is not None: + dur = dur * (1 - dur_padding.long()) + token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device) + dur_cumsum = torch.cumsum(dur, 1) + dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0) + + pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device) + token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None]) + mel2ph = (token_idx * token_mask.long()).sum(1) + return mel2ph + + +class PitchPredictor(torch.nn.Module): + def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5, + dropout_rate=0.1, padding='SAME'): + """Initilize pitch predictor module. + Args: + idim (int): Input dimension. + n_layers (int, optional): Number of convolutional layers. + n_chans (int, optional): Number of channels of convolutional layers. + kernel_size (int, optional): Kernel size of convolutional layers. + dropout_rate (float, optional): Dropout rate. + """ + super(PitchPredictor, self).__init__() + self.conv = torch.nn.ModuleList() + self.kernel_size = kernel_size + self.padding = padding + for idx in range(n_layers): + in_chans = idim if idx == 0 else n_chans + self.conv += [torch.nn.Sequential( + torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2) + if padding == 'SAME' + else (kernel_size - 1, 0), 0), + torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0), + torch.nn.ReLU(), + LayerNorm(n_chans, dim=1), + torch.nn.Dropout(dropout_rate) + )] + self.linear = torch.nn.Linear(n_chans, odim) + self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096) + self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) + + def forward(self, xs): + """ + + :param xs: [B, T, H] + :return: [B, T, H] + """ + positions = self.pos_embed_alpha * self.embed_positions(xs[..., 0]) + xs = xs + positions + xs = xs.transpose(1, -1) # (B, idim, Tmax) + for f in self.conv: + xs = f(xs) # (B, C, Tmax) + # NOTE: calculate in log domain + xs = self.linear(xs.transpose(1, -1)) # (B, Tmax, H) + return xs + + +class EnergyPredictor(PitchPredictor): + pass + + +def mel2ph_to_dur(mel2ph, T_txt, max_dur=None): + B, _ = mel2ph.shape + dur = mel2ph.new_zeros(B, T_txt + 1).scatter_add(1, mel2ph, torch.ones_like(mel2ph)) + dur = dur[:, 1:] + if max_dur is not None: + dur = dur.clamp(max=max_dur) + return dur + + +class FFTBlocks(nn.Module): + def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, num_heads=2, + use_pos_embed=True, use_last_norm=True, norm='ln', use_pos_embed_alpha=True): + super().__init__() + self.num_layers = num_layers + embed_dim = self.hidden_size = hidden_size + self.dropout = dropout if dropout is not None else hparams['dropout'] + self.use_pos_embed = use_pos_embed + self.use_last_norm = use_last_norm + if use_pos_embed: + self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS + self.padding_idx = 0 + self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1 + self.embed_positions = SinusoidalPositionalEmbedding( + embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS, + ) + + self.layers = nn.ModuleList([]) + self.layers.extend([ + TransformerEncoderLayer(self.hidden_size, self.dropout, + kernel_size=ffn_kernel_size, num_heads=num_heads) + for _ in range(self.num_layers) + ]) + if self.use_last_norm: + if norm == 'ln': + self.layer_norm = nn.LayerNorm(embed_dim) + elif norm == 'bn': + self.layer_norm = BatchNorm1dTBC(embed_dim) + else: + self.layer_norm = None + + def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False): + """ + :param x: [B, T, C] + :param padding_mask: [B, T] + :return: [B, T, C] or [L, B, T, C] + """ + padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask + nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1] + if self.use_pos_embed: + positions = self.pos_embed_alpha * self.embed_positions(x[..., 0]) + x = x + positions + x = F.dropout(x, p=self.dropout, training=self.training) + # B x T x C -> T x B x C + x = x.transpose(0, 1) * nonpadding_mask_TB + hiddens = [] + for layer in self.layers: + x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB + hiddens.append(x) + if self.use_last_norm: + x = self.layer_norm(x) * nonpadding_mask_TB + if return_hiddens: + x = torch.stack(hiddens, 0) # [L, T, B, C] + x = x.transpose(1, 2) # [L, B, T, C] + else: + x = x.transpose(0, 1) # [B, T, C] + return x + + +class FastspeechEncoder(FFTBlocks): + def __init__(self, embed_tokens, hidden_size=None, num_layers=None, kernel_size=None, num_heads=2): + hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size + kernel_size = hparams['enc_ffn_kernel_size'] if kernel_size is None else kernel_size + num_layers = hparams['dec_layers'] if num_layers is None else num_layers + super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads, + use_pos_embed=False) # use_pos_embed_alpha for compatibility + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(hidden_size) + self.padding_idx = 0 + if hparams.get('rel_pos') is not None and hparams['rel_pos']: + self.embed_positions = RelPositionalEncoding(hidden_size, dropout_rate=0.0) + else: + self.embed_positions = SinusoidalPositionalEmbedding( + hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS, + ) + + def forward(self, txt_tokens): + """ + + :param txt_tokens: [B, T] + :return: { + 'encoder_out': [T x B x C] + } + """ + encoder_padding_mask = txt_tokens.eq(self.padding_idx).data + x = self.forward_embedding(txt_tokens) # [B, T, H] + x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask) + return x + + def forward_embedding(self, txt_tokens): + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(txt_tokens) + if hparams['use_pos_embed']: + positions = self.embed_positions(txt_tokens) + x = x + positions + x = F.dropout(x, p=self.dropout, training=self.training) + return x + + +class FastspeechDecoder(FFTBlocks): + def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None): + num_heads = hparams['num_heads'] if num_heads is None else num_heads + hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size + kernel_size = hparams['dec_ffn_kernel_size'] if kernel_size is None else kernel_size + num_layers = hparams['dec_layers'] if num_layers is None else num_layers + super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads) + diff --git a/modules/hifigan/hifigan.py b/modules/hifigan/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..32826d51408fa391ccffb1e91dc1434bfe833d55 --- /dev/null +++ b/modules/hifigan/hifigan.py @@ -0,0 +1,370 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork +from modules.parallel_wavegan.models.source import SourceModuleHnNSF +import numpy as np + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Conv1d1x1(Conv1d): + """1x1 Conv1d with customized initialization.""" + + def __init__(self, in_channels, out_channels, bias): + """Initialize 1x1 Conv1d module.""" + super(Conv1d1x1, self).__init__(in_channels, out_channels, + kernel_size=1, padding=0, + dilation=1, bias=bias) + + +class HifiGanGenerator(torch.nn.Module): + def __init__(self, h, c_out=1): + super(HifiGanGenerator, self).__init__() + self.h = h + self.num_kernels = len(h['resblock_kernel_sizes']) + self.num_upsamples = len(h['upsample_rates']) + + if h['use_pitch_embed']: + self.harmonic_num = 8 + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates'])) + self.m_source = SourceModuleHnNSF( + sampling_rate=h['audio_sample_rate'], + harmonic_num=self.harmonic_num) + self.noise_convs = nn.ModuleList() + self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3)) + resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])): + c_cur = h['upsample_initial_channel'] // (2 ** (i + 1)) + self.ups.append(weight_norm( + ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2))) + if h['use_pitch_embed']: + if i + 1 < len(h['upsample_rates']): + stride_f0 = np.prod(h['upsample_rates'][i + 1:]) + self.noise_convs.append(Conv1d( + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h['upsample_initial_channel'] // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x, f0=None): + if f0 is not None: + # harmonic-source signal, noise-source signal, uv flag + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) + har_source, noi_source, uv = self.m_source(f0) + har_source = har_source.transpose(1, 2) + + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + if f0 is not None: + x_source = self.noise_convs[i](har_source) + x_source = torch.nn.functional.relu(x_source) + tmp_shape = x_source.shape[1] + x_source = torch.nn.functional.layer_norm(x_source.transpose(1, -1), (tmp_shape, )).transpose(1, -1) + x = x + x_source + xs = None + for j in range(self.num_kernels): + xs_ = self.resblocks[i * self.num_kernels + j](x) + if xs is None: + xs = xs_ + else: + xs += xs_ + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1): + super(DiscriminatorP, self).__init__() + self.use_cond = use_cond + if use_cond: + from utils.hparams import hparams + t = hparams['hop_size'] + self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2) + c_in = 2 + + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x, mel): + fmap = [] + if self.use_cond: + x_mel = self.cond_net(mel) + x = torch.cat([x_mel, x], 1) + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_cond=False, c_in=1): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(2, use_cond=use_cond, c_in=c_in), + DiscriminatorP(3, use_cond=use_cond, c_in=c_in), + DiscriminatorP(5, use_cond=use_cond, c_in=c_in), + DiscriminatorP(7, use_cond=use_cond, c_in=c_in), + DiscriminatorP(11, use_cond=use_cond, c_in=c_in), + ]) + + def forward(self, y, y_hat, mel=None): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y, mel) + y_d_g, fmap_g = d(y_hat, mel) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1): + super(DiscriminatorS, self).__init__() + self.use_cond = use_cond + if use_cond: + t = np.prod(upsample_rates) + self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2) + c_in = 2 + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(c_in, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x, mel): + if self.use_cond: + x_mel = self.cond_net(mel) + x = torch.cat([x_mel, x], 1) + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self, use_cond=False, c_in=1): + super(MultiScaleDiscriminator, self).__init__() + from utils.hparams import hparams + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True, use_cond=use_cond, + upsample_rates=[4, 4, hparams['hop_size'] // 16], + c_in=c_in), + DiscriminatorS(use_cond=use_cond, + upsample_rates=[4, 4, hparams['hop_size'] // 32], + c_in=c_in), + DiscriminatorS(use_cond=use_cond, + upsample_rates=[4, 4, hparams['hop_size'] // 64], + c_in=c_in), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=1), + AvgPool1d(4, 2, padding=1) + ]) + + def forward(self, y, y_hat, mel=None): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y, mel) + y_d_g, fmap_g = d(y_hat, mel) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + r_losses = 0 + g_losses = 0 + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + r_losses += r_loss + g_losses += g_loss + r_losses = r_losses / len(disc_real_outputs) + g_losses = g_losses / len(disc_real_outputs) + return r_losses, g_losses + + +def cond_discriminator_loss(outputs): + loss = 0 + for dg in outputs: + g_loss = torch.mean(dg ** 2) + loss += g_loss + loss = loss / len(outputs) + return loss + + +def generator_loss(disc_outputs): + loss = 0 + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + loss += l + loss = loss / len(disc_outputs) + return loss + diff --git a/modules/hifigan/mel_utils.py b/modules/hifigan/mel_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..04c1e3ea5de2cd24bbb14ab72206539a8d37d9c0 --- /dev/null +++ b/modules/hifigan/mel_utils.py @@ -0,0 +1,81 @@ +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, hparams, center=False, complex=False): + # hop_size: 512 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) + # win_size: 2048 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate) + # fmin: 55 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) + # fmax: 10000 # To be increased/reduced depending on data. + # fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter + # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, + n_fft = hparams['fft_size'] + num_mels = hparams['audio_num_mel_bins'] + sampling_rate = hparams['audio_sample_rate'] + hop_size = hparams['hop_size'] + win_size = hparams['win_size'] + fmin = hparams['fmin'] + fmax = hparams['fmax'] + y = y.clamp(min=-1., max=1.) + global mel_basis, hann_window + if fmax not in mel_basis: + mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax) + mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device) + hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device) + + y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), + mode='reflect') + y = y.squeeze(1) + + spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)], + center=center, pad_mode='reflect', normalized=False, onesided=True) + + if not complex: + spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9)) + spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec) + spec = spectral_normalize_torch(spec) + else: + B, C, T, _ = spec.shape + spec = spec.transpose(1, 2) # [B, T, n_fft, 2] + return spec + diff --git a/modules/parallel_wavegan/__init__.py b/modules/parallel_wavegan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/parallel_wavegan/layers/__init__.py b/modules/parallel_wavegan/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e477f51116a3157781b1aefefbaf32fe4d4bd1f0 --- /dev/null +++ b/modules/parallel_wavegan/layers/__init__.py @@ -0,0 +1,5 @@ +from .causal_conv import * # NOQA +from .pqmf import * # NOQA +from .residual_block import * # NOQA +from modules.parallel_wavegan.layers.residual_stack import * # NOQA +from .upsample import * # NOQA diff --git a/modules/parallel_wavegan/layers/causal_conv.py b/modules/parallel_wavegan/layers/causal_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..fca77daf65f234e6fbe355ed148fc8f0ee85038a --- /dev/null +++ b/modules/parallel_wavegan/layers/causal_conv.py @@ -0,0 +1,56 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Causal convolusion layer modules.""" + + +import torch + + +class CausalConv1d(torch.nn.Module): + """CausalConv1d module with customized initialization.""" + + def __init__(self, in_channels, out_channels, kernel_size, + dilation=1, bias=True, pad="ConstantPad1d", pad_params={"value": 0.0}): + """Initialize CausalConv1d module.""" + super(CausalConv1d, self).__init__() + self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params) + self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, + dilation=dilation, bias=bias) + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T). + + Returns: + Tensor: Output tensor (B, out_channels, T). + + """ + return self.conv(self.pad(x))[:, :, :x.size(2)] + + +class CausalConvTranspose1d(torch.nn.Module): + """CausalConvTranspose1d module with customized initialization.""" + + def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True): + """Initialize CausalConvTranspose1d module.""" + super(CausalConvTranspose1d, self).__init__() + self.deconv = torch.nn.ConvTranspose1d( + in_channels, out_channels, kernel_size, stride, bias=bias) + self.stride = stride + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, in_channels, T_in). + + Returns: + Tensor: Output tensor (B, out_channels, T_out). + + """ + return self.deconv(x)[:, :, :-self.stride] diff --git a/modules/parallel_wavegan/layers/pqmf.py b/modules/parallel_wavegan/layers/pqmf.py new file mode 100644 index 0000000000000000000000000000000000000000..ac21074fd32a370a099fa2facb62cfd3253d7579 --- /dev/null +++ b/modules/parallel_wavegan/layers/pqmf.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Pseudo QMF modules.""" + +import numpy as np +import torch +import torch.nn.functional as F + +from scipy.signal import kaiser + + +def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0): + """Design prototype filter for PQMF. + + This method is based on `A Kaiser window approach for the design of prototype + filters of cosine modulated filterbanks`_. + + Args: + taps (int): The number of filter taps. + cutoff_ratio (float): Cut-off frequency ratio. + beta (float): Beta coefficient for kaiser window. + + Returns: + ndarray: Impluse response of prototype filter (taps + 1,). + + .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`: + https://ieeexplore.ieee.org/abstract/document/681427 + + """ + # check the arguments are valid + assert taps % 2 == 0, "The number of taps mush be even number." + assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0." + + # make initial filter + omega_c = np.pi * cutoff_ratio + with np.errstate(invalid='ignore'): + h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \ + / (np.pi * (np.arange(taps + 1) - 0.5 * taps)) + h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form + + # apply kaiser window + w = kaiser(taps + 1, beta) + h = h_i * w + + return h + + +class PQMF(torch.nn.Module): + """PQMF module. + + This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_. + + .. _`Near-perfect-reconstruction pseudo-QMF banks`: + https://ieeexplore.ieee.org/document/258122 + + """ + + def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0): + """Initilize PQMF module. + + Args: + subbands (int): The number of subbands. + taps (int): The number of filter taps. + cutoff_ratio (float): Cut-off frequency ratio. + beta (float): Beta coefficient for kaiser window. + + """ + super(PQMF, self).__init__() + + # define filter coefficient + h_proto = design_prototype_filter(taps, cutoff_ratio, beta) + h_analysis = np.zeros((subbands, len(h_proto))) + h_synthesis = np.zeros((subbands, len(h_proto))) + for k in range(subbands): + h_analysis[k] = 2 * h_proto * np.cos( + (2 * k + 1) * (np.pi / (2 * subbands)) * + (np.arange(taps + 1) - ((taps - 1) / 2)) + + (-1) ** k * np.pi / 4) + h_synthesis[k] = 2 * h_proto * np.cos( + (2 * k + 1) * (np.pi / (2 * subbands)) * + (np.arange(taps + 1) - ((taps - 1) / 2)) - + (-1) ** k * np.pi / 4) + + # convert to tensor + analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1) + synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0) + + # register coefficients as beffer + self.register_buffer("analysis_filter", analysis_filter) + self.register_buffer("synthesis_filter", synthesis_filter) + + # filter for downsampling & upsampling + updown_filter = torch.zeros((subbands, subbands, subbands)).float() + for k in range(subbands): + updown_filter[k, k, 0] = 1.0 + self.register_buffer("updown_filter", updown_filter) + self.subbands = subbands + + # keep padding info + self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) + + def analysis(self, x): + """Analysis with PQMF. + + Args: + x (Tensor): Input tensor (B, 1, T). + + Returns: + Tensor: Output tensor (B, subbands, T // subbands). + + """ + x = F.conv1d(self.pad_fn(x), self.analysis_filter) + return F.conv1d(x, self.updown_filter, stride=self.subbands) + + def synthesis(self, x): + """Synthesis with PQMF. + + Args: + x (Tensor): Input tensor (B, subbands, T // subbands). + + Returns: + Tensor: Output tensor (B, 1, T). + + """ + x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands) + return F.conv1d(self.pad_fn(x), self.synthesis_filter) diff --git a/modules/parallel_wavegan/layers/residual_block.py b/modules/parallel_wavegan/layers/residual_block.py new file mode 100644 index 0000000000000000000000000000000000000000..7a267a86c1fa521c2824addf9dda304c43f1ff1f --- /dev/null +++ b/modules/parallel_wavegan/layers/residual_block.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- + +"""Residual block module in WaveNet. + +This code is modified from https://github.com/r9y9/wavenet_vocoder. + +""" + +import math + +import torch +import torch.nn.functional as F + + +class Conv1d(torch.nn.Conv1d): + """Conv1d module with customized initialization.""" + + def __init__(self, *args, **kwargs): + """Initialize Conv1d module.""" + super(Conv1d, self).__init__(*args, **kwargs) + + def reset_parameters(self): + """Reset parameters.""" + torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") + if self.bias is not None: + torch.nn.init.constant_(self.bias, 0.0) + + +class Conv1d1x1(Conv1d): + """1x1 Conv1d with customized initialization.""" + + def __init__(self, in_channels, out_channels, bias): + """Initialize 1x1 Conv1d module.""" + super(Conv1d1x1, self).__init__(in_channels, out_channels, + kernel_size=1, padding=0, + dilation=1, bias=bias) + + +class ResidualBlock(torch.nn.Module): + """Residual block module in WaveNet.""" + + def __init__(self, + kernel_size=3, + residual_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + dilation=1, + bias=True, + use_causal_conv=False + ): + """Initialize ResidualBlock module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + residual_channels (int): Number of channels for residual connection. + skip_channels (int): Number of channels for skip connection. + aux_channels (int): Local conditioning channels i.e. auxiliary input dimension. + dropout (float): Dropout probability. + dilation (int): Dilation factor. + bias (bool): Whether to add bias parameter in convolution layers. + use_causal_conv (bool): Whether to use use_causal_conv or non-use_causal_conv convolution. + + """ + super(ResidualBlock, self).__init__() + self.dropout = dropout + # no future time stamps available + if use_causal_conv: + padding = (kernel_size - 1) * dilation + else: + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + padding = (kernel_size - 1) // 2 * dilation + self.use_causal_conv = use_causal_conv + + # dilation conv + self.conv = Conv1d(residual_channels, gate_channels, kernel_size, + padding=padding, dilation=dilation, bias=bias) + + # local conditioning + if aux_channels > 0: + self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) + else: + self.conv1x1_aux = None + + # conv output is split into two groups + gate_out_channels = gate_channels // 2 + self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias) + self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias) + + def forward(self, x, c): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, residual_channels, T). + c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T). + + Returns: + Tensor: Output tensor for residual connection (B, residual_channels, T). + Tensor: Output tensor for skip connection (B, skip_channels, T). + + """ + residual = x + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv(x) + + # remove future time steps if use_causal_conv conv + x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x + + # split into two part for gated activation + splitdim = 1 + xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) + + # local conditioning + if c is not None: + assert self.conv1x1_aux is not None + c = self.conv1x1_aux(c) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ca, xb + cb + + x = torch.tanh(xa) * torch.sigmoid(xb) + + # for skip connection + s = self.conv1x1_skip(x) + + # for residual connection + x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5) + + return x, s diff --git a/modules/parallel_wavegan/layers/residual_stack.py b/modules/parallel_wavegan/layers/residual_stack.py new file mode 100644 index 0000000000000000000000000000000000000000..6e07c8803ad348dd923f6b7c0f7aff14aab9cf78 --- /dev/null +++ b/modules/parallel_wavegan/layers/residual_stack.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Residual stack module in MelGAN.""" + +import torch + +from . import CausalConv1d + + +class ResidualStack(torch.nn.Module): + """Residual stack module introduced in MelGAN.""" + + def __init__(self, + kernel_size=3, + channels=32, + dilation=1, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + pad="ReflectionPad1d", + pad_params={}, + use_causal_conv=False, + ): + """Initialize ResidualStack module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + channels (int): Number of channels of convolution layers. + dilation (int): Dilation factor. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + pad (str): Padding function module name before dilated convolution layer. + pad_params (dict): Hyperparameters for padding function. + use_causal_conv (bool): Whether to use causal convolution. + + """ + super(ResidualStack, self).__init__() + + # defile residual stack part + if not use_causal_conv: + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + self.stack = torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params), + torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + torch.nn.Conv1d(channels, channels, 1, bias=bias), + ) + else: + self.stack = torch.nn.Sequential( + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + CausalConv1d(channels, channels, kernel_size, dilation=dilation, + bias=bias, pad=pad, pad_params=pad_params), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + torch.nn.Conv1d(channels, channels, 1, bias=bias), + ) + + # defile extra layer for skip connection + self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias) + + def forward(self, c): + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, channels, T). + + Returns: + Tensor: Output tensor (B, chennels, T). + + """ + return self.stack(c) + self.skip_layer(c) diff --git a/modules/parallel_wavegan/layers/tf_layers.py b/modules/parallel_wavegan/layers/tf_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..c0f46bd755c161cda2ac904fe37f3f3c6357a88d --- /dev/null +++ b/modules/parallel_wavegan/layers/tf_layers.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 MINH ANH (@dathudeptrai) +# MIT License (https://opensource.org/licenses/MIT) + +"""Tensorflow Layer modules complatible with pytorch.""" + +import tensorflow as tf + + +class TFReflectionPad1d(tf.keras.layers.Layer): + """Tensorflow ReflectionPad1d module.""" + + def __init__(self, padding_size): + """Initialize TFReflectionPad1d module. + + Args: + padding_size (int): Padding size. + + """ + super(TFReflectionPad1d, self).__init__() + self.padding_size = padding_size + + @tf.function + def call(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, T, 1, C). + + Returns: + Tensor: Padded tensor (B, T + 2 * padding_size, 1, C). + + """ + return tf.pad(x, [[0, 0], [self.padding_size, self.padding_size], [0, 0], [0, 0]], "REFLECT") + + +class TFConvTranspose1d(tf.keras.layers.Layer): + """Tensorflow ConvTranspose1d module.""" + + def __init__(self, channels, kernel_size, stride, padding): + """Initialize TFConvTranspose1d( module. + + Args: + channels (int): Number of channels. + kernel_size (int): kernel size. + strides (int): Stride width. + padding (str): Padding type ("same" or "valid"). + + """ + super(TFConvTranspose1d, self).__init__() + self.conv1d_transpose = tf.keras.layers.Conv2DTranspose( + filters=channels, + kernel_size=(kernel_size, 1), + strides=(stride, 1), + padding=padding, + ) + + @tf.function + def call(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, T, 1, C). + + Returns: + Tensors: Output tensor (B, T', 1, C'). + + """ + x = self.conv1d_transpose(x) + return x + + +class TFResidualStack(tf.keras.layers.Layer): + """Tensorflow ResidualStack module.""" + + def __init__(self, + kernel_size, + channels, + dilation, + bias, + nonlinear_activation, + nonlinear_activation_params, + padding, + ): + """Initialize TFResidualStack module. + + Args: + kernel_size (int): Kernel size. + channles (int): Number of channels. + dilation (int): Dilation ine. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + padding (str): Padding type ("same" or "valid"). + + """ + super(TFResidualStack, self).__init__() + self.block = [ + getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params), + TFReflectionPad1d(dilation), + tf.keras.layers.Conv2D( + filters=channels, + kernel_size=(kernel_size, 1), + dilation_rate=(dilation, 1), + use_bias=bias, + padding="valid", + ), + getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params), + tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias) + ] + self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias) + + @tf.function + def call(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, T, 1, C). + + Returns: + Tensor: Output tensor (B, T, 1, C). + + """ + _x = tf.identity(x) + for i, layer in enumerate(self.block): + _x = layer(_x) + shortcut = self.shortcut(x) + return shortcut + _x diff --git a/modules/parallel_wavegan/layers/upsample.py b/modules/parallel_wavegan/layers/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..18c6397c420a81fadc5320e3a48f3249534decd8 --- /dev/null +++ b/modules/parallel_wavegan/layers/upsample.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- + +"""Upsampling module. + +This code is modified from https://github.com/r9y9/wavenet_vocoder. + +""" + +import numpy as np +import torch +import torch.nn.functional as F + +from . import Conv1d + + +class Stretch2d(torch.nn.Module): + """Stretch2d module.""" + + def __init__(self, x_scale, y_scale, mode="nearest"): + """Initialize Stretch2d module. + + Args: + x_scale (int): X scaling factor (Time axis in spectrogram). + y_scale (int): Y scaling factor (Frequency axis in spectrogram). + mode (str): Interpolation mode. + + """ + super(Stretch2d, self).__init__() + self.x_scale = x_scale + self.y_scale = y_scale + self.mode = mode + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, C, F, T). + + Returns: + Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), + + """ + return F.interpolate( + x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) + + +class Conv2d(torch.nn.Conv2d): + """Conv2d module with customized initialization.""" + + def __init__(self, *args, **kwargs): + """Initialize Conv2d module.""" + super(Conv2d, self).__init__(*args, **kwargs) + + def reset_parameters(self): + """Reset parameters.""" + self.weight.data.fill_(1. / np.prod(self.kernel_size)) + if self.bias is not None: + torch.nn.init.constant_(self.bias, 0.0) + + +class UpsampleNetwork(torch.nn.Module): + """Upsampling network module.""" + + def __init__(self, + upsample_scales, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + use_causal_conv=False, + ): + """Initialize upsampling network module. + + Args: + upsample_scales (list): List of upsampling scales. + nonlinear_activation (str): Activation function name. + nonlinear_activation_params (dict): Arguments for specified activation function. + interpolate_mode (str): Interpolation mode. + freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. + + """ + super(UpsampleNetwork, self).__init__() + self.use_causal_conv = use_causal_conv + self.up_layers = torch.nn.ModuleList() + for scale in upsample_scales: + # interpolation layer + stretch = Stretch2d(scale, 1, interpolate_mode) + self.up_layers += [stretch] + + # conv layer + assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size." + freq_axis_padding = (freq_axis_kernel_size - 1) // 2 + kernel_size = (freq_axis_kernel_size, scale * 2 + 1) + if use_causal_conv: + padding = (freq_axis_padding, scale * 2) + else: + padding = (freq_axis_padding, scale) + conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.up_layers += [conv] + + # nonlinear + if nonlinear_activation is not None: + nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params) + self.up_layers += [nonlinear] + + def forward(self, c): + """Calculate forward propagation. + + Args: + c : Input tensor (B, C, T). + + Returns: + Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales). + + """ + c = c.unsqueeze(1) # (B, 1, C, T) + for f in self.up_layers: + if self.use_causal_conv and isinstance(f, Conv2d): + c = f(c)[..., :c.size(-1)] + else: + c = f(c) + return c.squeeze(1) # (B, C, T') + + +class ConvInUpsampleNetwork(torch.nn.Module): + """Convolution + upsampling network module.""" + + def __init__(self, + upsample_scales, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + aux_channels=80, + aux_context_window=0, + use_causal_conv=False + ): + """Initialize convolution + upsampling network module. + + Args: + upsample_scales (list): List of upsampling scales. + nonlinear_activation (str): Activation function name. + nonlinear_activation_params (dict): Arguments for specified activation function. + mode (str): Interpolation mode. + freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. + aux_channels (int): Number of channels of pre-convolutional layer. + aux_context_window (int): Context window size of the pre-convolutional layer. + use_causal_conv (bool): Whether to use causal structure. + + """ + super(ConvInUpsampleNetwork, self).__init__() + self.aux_context_window = aux_context_window + self.use_causal_conv = use_causal_conv and aux_context_window > 0 + # To capture wide-context information in conditional features + kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1 + # NOTE(kan-bayashi): Here do not use padding because the input is already padded + self.conv_in = Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False) + self.upsample = UpsampleNetwork( + upsample_scales=upsample_scales, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + interpolate_mode=interpolate_mode, + freq_axis_kernel_size=freq_axis_kernel_size, + use_causal_conv=use_causal_conv, + ) + + def forward(self, c): + """Calculate forward propagation. + + Args: + c : Input tensor (B, C, T'). + + Returns: + Tensor: Upsampled tensor (B, C, T), + where T = (T' - aux_context_window * 2) * prod(upsample_scales). + + Note: + The length of inputs considers the context window size. + + """ + c_ = self.conv_in(c) + c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_ + return self.upsample(c) diff --git a/modules/parallel_wavegan/losses/__init__.py b/modules/parallel_wavegan/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b03080a907cb5cb4b316ceb74866ddbc406b33bf --- /dev/null +++ b/modules/parallel_wavegan/losses/__init__.py @@ -0,0 +1 @@ +from .stft_loss import * # NOQA diff --git a/modules/parallel_wavegan/losses/stft_loss.py b/modules/parallel_wavegan/losses/stft_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..74d2aa21ad30ba094c406366e652067462f49cd2 --- /dev/null +++ b/modules/parallel_wavegan/losses/stft_loss.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""STFT-based Loss modules.""" + +import torch +import torch.nn.functional as F + + +def stft(x, fft_size, hop_size, win_length, window): + """Perform STFT and convert to magnitude spectrogram. + + Args: + x (Tensor): Input signal tensor (B, T). + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (str): Window function type. + + Returns: + Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). + + """ + x_stft = torch.stft(x, fft_size, hop_size, win_length, window) + real = x_stft[..., 0] + imag = x_stft[..., 1] + + # NOTE(kan-bayashi): clamp is needed to avoid nan or inf + return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) + + +class SpectralConvergengeLoss(torch.nn.Module): + """Spectral convergence loss module.""" + + def __init__(self): + """Initilize spectral convergence loss module.""" + super(SpectralConvergengeLoss, self).__init__() + + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + + Returns: + Tensor: Spectral convergence loss value. + + """ + return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") + + +class LogSTFTMagnitudeLoss(torch.nn.Module): + """Log STFT magnitude loss module.""" + + def __init__(self): + """Initilize los STFT magnitude loss module.""" + super(LogSTFTMagnitudeLoss, self).__init__() + + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + + Returns: + Tensor: Log STFT magnitude loss value. + + """ + return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) + + +class STFTLoss(torch.nn.Module): + """STFT loss module.""" + + def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"): + """Initialize STFT loss module.""" + super(STFTLoss, self).__init__() + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.window = getattr(torch, window)(win_length) + self.spectral_convergenge_loss = SpectralConvergengeLoss() + self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() + + def forward(self, x, y): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + + Returns: + Tensor: Spectral convergence loss value. + Tensor: Log STFT magnitude loss value. + + """ + x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) + y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) + sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) + mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) + + return sc_loss, mag_loss + + +class MultiResolutionSTFTLoss(torch.nn.Module): + """Multi resolution STFT loss module.""" + + def __init__(self, + fft_sizes=[1024, 2048, 512], + hop_sizes=[120, 240, 50], + win_lengths=[600, 1200, 240], + window="hann_window"): + """Initialize Multi resolution STFT loss module. + + Args: + fft_sizes (list): List of FFT sizes. + hop_sizes (list): List of hop sizes. + win_lengths (list): List of window lengths. + window (str): Window function type. + + """ + super(MultiResolutionSTFTLoss, self).__init__() + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses += [STFTLoss(fs, ss, wl, window)] + + def forward(self, x, y): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + + Returns: + Tensor: Multi resolution spectral convergence loss value. + Tensor: Multi resolution log STFT magnitude loss value. + + """ + sc_loss = 0.0 + mag_loss = 0.0 + for f in self.stft_losses: + sc_l, mag_l = f(x, y) + sc_loss += sc_l + mag_loss += mag_l + sc_loss /= len(self.stft_losses) + mag_loss /= len(self.stft_losses) + + return sc_loss, mag_loss diff --git a/modules/parallel_wavegan/models/__init__.py b/modules/parallel_wavegan/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4803ba6b2a0afc8022e756ae5b3f4c7403c3c1bd --- /dev/null +++ b/modules/parallel_wavegan/models/__init__.py @@ -0,0 +1,2 @@ +from .melgan import * # NOQA +from .parallel_wavegan import * # NOQA diff --git a/modules/parallel_wavegan/models/melgan.py b/modules/parallel_wavegan/models/melgan.py new file mode 100644 index 0000000000000000000000000000000000000000..e021ae4817a8c1c97338e61b00b230c881836fd8 --- /dev/null +++ b/modules/parallel_wavegan/models/melgan.py @@ -0,0 +1,427 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""MelGAN Modules.""" + +import logging + +import numpy as np +import torch + +from modules.parallel_wavegan.layers import CausalConv1d +from modules.parallel_wavegan.layers import CausalConvTranspose1d +from modules.parallel_wavegan.layers import ResidualStack + + +class MelGANGenerator(torch.nn.Module): + """MelGAN generator module.""" + + def __init__(self, + in_channels=80, + out_channels=1, + kernel_size=7, + channels=512, + bias=True, + upsample_scales=[8, 8, 2, 2], + stack_kernel_size=3, + stacks=3, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + pad="ReflectionPad1d", + pad_params={}, + use_final_nonlinear_activation=True, + use_weight_norm=True, + use_causal_conv=False, + ): + """Initialize MelGANGenerator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Kernel size of initial and final conv layer. + channels (int): Initial number of channels for conv layer. + bias (bool): Whether to add bias parameter in convolution layers. + upsample_scales (list): List of upsampling scales. + stack_kernel_size (int): Kernel size of dilated conv layers in residual stack. + stacks (int): Number of stacks in a single residual stack. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + pad (str): Padding function module name before dilated convolution layer. + pad_params (dict): Hyperparameters for padding function. + use_final_nonlinear_activation (torch.nn.Module): Activation function for the final layer. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_causal_conv (bool): Whether to use causal convolution. + + """ + super(MelGANGenerator, self).__init__() + + # check hyper parameters is valid + assert channels >= np.prod(upsample_scales) + assert channels % (2 ** len(upsample_scales)) == 0 + if not use_causal_conv: + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + + # add initial layer + layers = [] + if not use_causal_conv: + layers += [ + getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params), + torch.nn.Conv1d(in_channels, channels, kernel_size, bias=bias), + ] + else: + layers += [ + CausalConv1d(in_channels, channels, kernel_size, + bias=bias, pad=pad, pad_params=pad_params), + ] + + for i, upsample_scale in enumerate(upsample_scales): + # add upsampling layer + layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)] + if not use_causal_conv: + layers += [ + torch.nn.ConvTranspose1d( + channels // (2 ** i), + channels // (2 ** (i + 1)), + upsample_scale * 2, + stride=upsample_scale, + padding=upsample_scale // 2 + upsample_scale % 2, + output_padding=upsample_scale % 2, + bias=bias, + ) + ] + else: + layers += [ + CausalConvTranspose1d( + channels // (2 ** i), + channels // (2 ** (i + 1)), + upsample_scale * 2, + stride=upsample_scale, + bias=bias, + ) + ] + + # add residual stack + for j in range(stacks): + layers += [ + ResidualStack( + kernel_size=stack_kernel_size, + channels=channels // (2 ** (i + 1)), + dilation=stack_kernel_size ** j, + bias=bias, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + pad=pad, + pad_params=pad_params, + use_causal_conv=use_causal_conv, + ) + ] + + # add final layer + layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)] + if not use_causal_conv: + layers += [ + getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params), + torch.nn.Conv1d(channels // (2 ** (i + 1)), out_channels, kernel_size, bias=bias), + ] + else: + layers += [ + CausalConv1d(channels // (2 ** (i + 1)), out_channels, kernel_size, + bias=bias, pad=pad, pad_params=pad_params), + ] + if use_final_nonlinear_activation: + layers += [torch.nn.Tanh()] + + # define the model as a single function + self.melgan = torch.nn.Sequential(*layers) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + def forward(self, c): + """Calculate forward propagation. + + Args: + c (Tensor): Input tensor (B, channels, T). + + Returns: + Tensor: Output tensor (B, 1, T ** prod(upsample_scales)). + + """ + return self.melgan(c) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + """Reset parameters. + + This initialization follows official implementation manner. + https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py + + """ + def _reset_parameters(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): + m.weight.data.normal_(0.0, 0.02) + logging.debug(f"Reset parameters in {m}.") + + self.apply(_reset_parameters) + + +class MelGANDiscriminator(torch.nn.Module): + """MelGAN discriminator module.""" + + def __init__(self, + in_channels=1, + out_channels=1, + kernel_sizes=[5, 3], + channels=16, + max_downsample_channels=1024, + bias=True, + downsample_scales=[4, 4, 4, 4], + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + pad="ReflectionPad1d", + pad_params={}, + ): + """Initilize MelGAN discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_sizes (list): List of two kernel sizes. The prod will be used for the first conv layer, + and the first and the second kernel sizes will be used for the last two layers. + For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15, + the last two layers' kernel size will be 5 and 3, respectively. + channels (int): Initial number of channels for conv layer. + max_downsample_channels (int): Maximum number of channels for downsampling layers. + bias (bool): Whether to add bias parameter in convolution layers. + downsample_scales (list): List of downsampling scales. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + pad (str): Padding function module name before dilated convolution layer. + pad_params (dict): Hyperparameters for padding function. + + """ + super(MelGANDiscriminator, self).__init__() + self.layers = torch.nn.ModuleList() + + # check kernel size is valid + assert len(kernel_sizes) == 2 + assert kernel_sizes[0] % 2 == 1 + assert kernel_sizes[1] % 2 == 1 + + # add first layer + self.layers += [ + torch.nn.Sequential( + getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params), + torch.nn.Conv1d(in_channels, channels, np.prod(kernel_sizes), bias=bias), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + + # add downsample layers + in_chs = channels + for downsample_scale in downsample_scales: + out_chs = min(in_chs * downsample_scale, max_downsample_channels) + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, out_chs, + kernel_size=downsample_scale * 10 + 1, + stride=downsample_scale, + padding=downsample_scale * 5, + groups=in_chs // 4, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + in_chs = out_chs + + # add final layers + out_chs = min(in_chs * 2, max_downsample_channels) + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, out_chs, kernel_sizes[0], + padding=(kernel_sizes[0] - 1) // 2, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + self.layers += [ + torch.nn.Conv1d( + out_chs, out_channels, kernel_sizes[1], + padding=(kernel_sizes[1] - 1) // 2, + bias=bias, + ), + ] + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of output tensors of each layer. + + """ + outs = [] + for f in self.layers: + x = f(x) + outs += [x] + + return outs + + +class MelGANMultiScaleDiscriminator(torch.nn.Module): + """MelGAN multi-scale discriminator module.""" + + def __init__(self, + in_channels=1, + out_channels=1, + scales=3, + downsample_pooling="AvgPool1d", + # follow the official implementation setting + downsample_pooling_params={ + "kernel_size": 4, + "stride": 2, + "padding": 1, + "count_include_pad": False, + }, + kernel_sizes=[5, 3], + channels=16, + max_downsample_channels=1024, + bias=True, + downsample_scales=[4, 4, 4, 4], + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + pad="ReflectionPad1d", + pad_params={}, + use_weight_norm=True, + ): + """Initilize MelGAN multi-scale discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + downsample_pooling (str): Pooling module name for downsampling of the inputs. + downsample_pooling_params (dict): Parameters for the above pooling module. + kernel_sizes (list): List of two kernel sizes. The sum will be used for the first conv layer, + and the first and the second kernel sizes will be used for the last two layers. + channels (int): Initial number of channels for conv layer. + max_downsample_channels (int): Maximum number of channels for downsampling layers. + bias (bool): Whether to add bias parameter in convolution layers. + downsample_scales (list): List of downsampling scales. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + pad (str): Padding function module name before dilated convolution layer. + pad_params (dict): Hyperparameters for padding function. + use_causal_conv (bool): Whether to use causal convolution. + + """ + super(MelGANMultiScaleDiscriminator, self).__init__() + self.discriminators = torch.nn.ModuleList() + + # add discriminators + for _ in range(scales): + self.discriminators += [ + MelGANDiscriminator( + in_channels=in_channels, + out_channels=out_channels, + kernel_sizes=kernel_sizes, + channels=channels, + max_downsample_channels=max_downsample_channels, + bias=bias, + downsample_scales=downsample_scales, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + pad=pad, + pad_params=pad_params, + ) + ] + self.pooling = getattr(torch.nn, downsample_pooling)(**downsample_pooling_params) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of list of each discriminator outputs, which consists of each layer output tensors. + + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + x = self.pooling(x) + + return outs + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + """Reset parameters. + + This initialization follows official implementation manner. + https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py + + """ + def _reset_parameters(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): + m.weight.data.normal_(0.0, 0.02) + logging.debug(f"Reset parameters in {m}.") + + self.apply(_reset_parameters) diff --git a/modules/parallel_wavegan/models/parallel_wavegan.py b/modules/parallel_wavegan/models/parallel_wavegan.py new file mode 100644 index 0000000000000000000000000000000000000000..c63b59f67aa48342179415c1d1beac68574a5498 --- /dev/null +++ b/modules/parallel_wavegan/models/parallel_wavegan.py @@ -0,0 +1,434 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Parallel WaveGAN Modules.""" + +import logging +import math + +import torch +from torch import nn + +from modules.parallel_wavegan.layers import Conv1d +from modules.parallel_wavegan.layers import Conv1d1x1 +from modules.parallel_wavegan.layers import ResidualBlock +from modules.parallel_wavegan.layers import upsample +from modules.parallel_wavegan import models + + +class ParallelWaveGANGenerator(torch.nn.Module): + """Parallel WaveGAN Generator module.""" + + def __init__(self, + in_channels=1, + out_channels=1, + kernel_size=3, + layers=30, + stacks=3, + residual_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + aux_context_window=2, + dropout=0.0, + bias=True, + use_weight_norm=True, + use_causal_conv=False, + upsample_conditional_features=True, + upsample_net="ConvInUpsampleNetwork", + upsample_params={"upsample_scales": [4, 4, 4, 4]}, + use_pitch_embed=False, + ): + """Initialize Parallel WaveGAN Generator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Kernel size of dilated convolution. + layers (int): Number of residual block layers. + stacks (int): Number of stacks i.e., dilation cycles. + residual_channels (int): Number of channels in residual conv. + gate_channels (int): Number of channels in gated conv. + skip_channels (int): Number of channels in skip conv. + aux_channels (int): Number of channels for auxiliary feature conv. + aux_context_window (int): Context window size for auxiliary feature. + dropout (float): Dropout rate. 0.0 means no dropout applied. + bias (bool): Whether to use bias parameter in conv layer. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_causal_conv (bool): Whether to use causal structure. + upsample_conditional_features (bool): Whether to use upsampling network. + upsample_net (str): Upsampling network architecture. + upsample_params (dict): Upsampling network parameters. + + """ + super(ParallelWaveGANGenerator, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.aux_channels = aux_channels + self.layers = layers + self.stacks = stacks + self.kernel_size = kernel_size + + # check the number of layers and stacks + assert layers % stacks == 0 + layers_per_stack = layers // stacks + + # define first convolution + self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True) + + # define conv + upsampling network + if upsample_conditional_features: + upsample_params.update({ + "use_causal_conv": use_causal_conv, + }) + if upsample_net == "MelGANGenerator": + assert aux_context_window == 0 + upsample_params.update({ + "use_weight_norm": False, # not to apply twice + "use_final_nonlinear_activation": False, + }) + self.upsample_net = getattr(models, upsample_net)(**upsample_params) + else: + if upsample_net == "ConvInUpsampleNetwork": + upsample_params.update({ + "aux_channels": aux_channels, + "aux_context_window": aux_context_window, + }) + self.upsample_net = getattr(upsample, upsample_net)(**upsample_params) + else: + self.upsample_net = None + + # define residual blocks + self.conv_layers = torch.nn.ModuleList() + for layer in range(layers): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + residual_channels=residual_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=aux_channels, + dilation=dilation, + dropout=dropout, + bias=bias, + use_causal_conv=use_causal_conv, + ) + self.conv_layers += [conv] + + # define output layers + self.last_conv_layers = torch.nn.ModuleList([ + torch.nn.ReLU(inplace=True), + Conv1d1x1(skip_channels, skip_channels, bias=True), + torch.nn.ReLU(inplace=True), + Conv1d1x1(skip_channels, out_channels, bias=True), + ]) + + self.use_pitch_embed = use_pitch_embed + if use_pitch_embed: + self.pitch_embed = nn.Embedding(300, aux_channels, 0) + self.c_proj = nn.Linear(2 * aux_channels, aux_channels) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, x, c=None, pitch=None, **kwargs): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, C_in, T). + c (Tensor): Local conditioning auxiliary features (B, C ,T'). + pitch (Tensor): Local conditioning pitch (B, T'). + + Returns: + Tensor: Output tensor (B, C_out, T) + + """ + # perform upsampling + if c is not None and self.upsample_net is not None: + if self.use_pitch_embed: + p = self.pitch_embed(pitch) + c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2) + c = self.upsample_net(c) + assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1)) + + # encode to hidden representation + x = self.first_conv(x) + skips = 0 + for f in self.conv_layers: + x, h = f(x, c) + skips += h + skips *= math.sqrt(1.0 / len(self.conv_layers)) + + # apply final layers + x = skips + for f in self.last_conv_layers: + x = f(x) + + return x + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + @staticmethod + def _get_receptive_field_size(layers, stacks, kernel_size, + dilation=lambda x: 2 ** x): + assert layers % stacks == 0 + layers_per_cycle = layers // stacks + dilations = [dilation(i % layers_per_cycle) for i in range(layers)] + return (kernel_size - 1) * sum(dilations) + 1 + + @property + def receptive_field_size(self): + """Return receptive field size.""" + return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size) + + +class ParallelWaveGANDiscriminator(torch.nn.Module): + """Parallel WaveGAN Discriminator module.""" + + def __init__(self, + in_channels=1, + out_channels=1, + kernel_size=3, + layers=10, + conv_channels=64, + dilation_factor=1, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + bias=True, + use_weight_norm=True, + ): + """Initialize Parallel WaveGAN Discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Number of output channels. + layers (int): Number of conv layers. + conv_channels (int): Number of chnn layers. + dilation_factor (int): Dilation factor. For example, if dilation_factor = 2, + the dilation will be 2, 4, 8, ..., and so on. + nonlinear_activation (str): Nonlinear function after each conv. + nonlinear_activation_params (dict): Nonlinear function parameters + bias (bool): Whether to use bias parameter in conv. + use_weight_norm (bool) Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + + """ + super(ParallelWaveGANDiscriminator, self).__init__() + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + assert dilation_factor > 0, "Dilation factor must be > 0." + self.conv_layers = torch.nn.ModuleList() + conv_in_channels = in_channels + for i in range(layers - 1): + if i == 0: + dilation = 1 + else: + dilation = i if dilation_factor == 1 else dilation_factor ** i + conv_in_channels = conv_channels + padding = (kernel_size - 1) // 2 * dilation + conv_layer = [ + Conv1d(conv_in_channels, conv_channels, + kernel_size=kernel_size, padding=padding, + dilation=dilation, bias=bias), + getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params) + ] + self.conv_layers += conv_layer + padding = (kernel_size - 1) // 2 + last_conv_layer = Conv1d( + conv_in_channels, out_channels, + kernel_size=kernel_size, padding=padding, bias=bias) + self.conv_layers += [last_conv_layer] + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + Tensor: Output tensor (B, 1, T) + + """ + for f in self.conv_layers: + x = f(x) + return x + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + +class ResidualParallelWaveGANDiscriminator(torch.nn.Module): + """Parallel WaveGAN Discriminator module.""" + + def __init__(self, + in_channels=1, + out_channels=1, + kernel_size=3, + layers=30, + stacks=3, + residual_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + use_weight_norm=True, + use_causal_conv=False, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ): + """Initialize Parallel WaveGAN Discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Kernel size of dilated convolution. + layers (int): Number of residual block layers. + stacks (int): Number of stacks i.e., dilation cycles. + residual_channels (int): Number of channels in residual conv. + gate_channels (int): Number of channels in gated conv. + skip_channels (int): Number of channels in skip conv. + dropout (float): Dropout rate. 0.0 means no dropout applied. + bias (bool): Whether to use bias parameter in conv. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_causal_conv (bool): Whether to use causal structure. + nonlinear_activation_params (dict): Nonlinear function parameters + + """ + super(ResidualParallelWaveGANDiscriminator, self).__init__() + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + + self.in_channels = in_channels + self.out_channels = out_channels + self.layers = layers + self.stacks = stacks + self.kernel_size = kernel_size + + # check the number of layers and stacks + assert layers % stacks == 0 + layers_per_stack = layers // stacks + + # define first convolution + self.first_conv = torch.nn.Sequential( + Conv1d1x1(in_channels, residual_channels, bias=True), + getattr(torch.nn, nonlinear_activation)( + inplace=True, **nonlinear_activation_params), + ) + + # define residual blocks + self.conv_layers = torch.nn.ModuleList() + for layer in range(layers): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + residual_channels=residual_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=-1, + dilation=dilation, + dropout=dropout, + bias=bias, + use_causal_conv=use_causal_conv, + ) + self.conv_layers += [conv] + + # define output layers + self.last_conv_layers = torch.nn.ModuleList([ + getattr(torch.nn, nonlinear_activation)( + inplace=True, **nonlinear_activation_params), + Conv1d1x1(skip_channels, skip_channels, bias=True), + getattr(torch.nn, nonlinear_activation)( + inplace=True, **nonlinear_activation_params), + Conv1d1x1(skip_channels, out_channels, bias=True), + ]) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + Tensor: Output tensor (B, 1, T) + + """ + x = self.first_conv(x) + + skips = 0 + for f in self.conv_layers: + x, h = f(x, None) + skips += h + skips *= math.sqrt(1.0 / len(self.conv_layers)) + + # apply final layers + x = skips + for f in self.last_conv_layers: + x = f(x) + return x + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) diff --git a/modules/parallel_wavegan/models/source.py b/modules/parallel_wavegan/models/source.py new file mode 100644 index 0000000000000000000000000000000000000000..cf741734587bd2040dde7e3a275a1456720c977c --- /dev/null +++ b/modules/parallel_wavegan/models/source.py @@ -0,0 +1,405 @@ +import torch +import numpy as np +import sys +import torch.nn.functional as torch_nn_func + + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + voiced_threshold=0, + flag_for_pulse=False): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + + def _f02uv(self, f0): + # generate uv signal + uv = torch.ones_like(f0) + uv = uv * (f0 > self.voiced_threshold) + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \ + device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + # for normal case + + # To prevent torch.cumsum numerical overflow, + # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. + # Buffer tmp_over_one_idx indicates the time step to add -1. + # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi + tmp_over_one = torch.cumsum(rad_values, 1) % 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - + tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) + * 2 * np.pi) + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. + i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + with torch.no_grad(): + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, + device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2) + + # generate sine waveforms + sine_waves = self._f02sine(f0_buf) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class PulseGen(torch.nn.Module): + """ Definition of Pulse train generator + + There are many ways to implement pulse generator. + Here, PulseGen is based on SinGen. For a perfect + """ + def __init__(self, samp_rate, pulse_amp = 0.1, + noise_std = 0.003, voiced_threshold = 0): + super(PulseGen, self).__init__() + self.pulse_amp = pulse_amp + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.noise_std = noise_std + self.l_sinegen = SineGen(self.sampling_rate, harmonic_num=0, \ + sine_amp=self.pulse_amp, noise_std=0, \ + voiced_threshold=self.voiced_threshold, \ + flag_for_pulse=True) + + def forward(self, f0): + """ Pulse train generator + pulse_train, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output pulse_train: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + + Note: self.l_sine doesn't make sure that the initial phase of + a voiced segment is np.pi, the first pulse in a voiced segment + may not be at the first time step within a voiced segment + """ + with torch.no_grad(): + sine_wav, uv, noise = self.l_sinegen(f0) + + # sine without additive noise + pure_sine = sine_wav - noise + + # step t corresponds to a pulse if + # sine[t] > sine[t+1] & sine[t] > sine[t-1] + # & sine[t-1], sine[t+1], and sine[t] are voiced + # or + # sine[t] is voiced, sine[t-1] is unvoiced + # we use torch.roll to simulate sine[t+1] and sine[t-1] + sine_1 = torch.roll(pure_sine, shifts=1, dims=1) + uv_1 = torch.roll(uv, shifts=1, dims=1) + uv_1[:, 0, :] = 0 + sine_2 = torch.roll(pure_sine, shifts=-1, dims=1) + uv_2 = torch.roll(uv, shifts=-1, dims=1) + uv_2[:, -1, :] = 0 + + loc = (pure_sine > sine_1) * (pure_sine > sine_2) \ + * (uv_1 > 0) * (uv_2 > 0) * (uv > 0) \ + + (uv_1 < 1) * (uv > 0) + + # pulse train without noise + pulse_train = pure_sine * loc + + # additive noise to pulse train + # note that noise from sinegen is zero in voiced regions + pulse_noise = torch.randn_like(pure_sine) * self.noise_std + + # with additive noise on pulse, and unvoiced regions + pulse_train += pulse_noise * loc + pulse_noise * (1 - uv) + return pulse_train, sine_wav, uv, pulse_noise + + +class SignalsConv1d(torch.nn.Module): + """ Filtering input signal with time invariant filter + Note: FIRFilter conducted filtering given fixed FIR weight + SignalsConv1d convolves two signals + Note: this is based on torch.nn.functional.conv1d + + """ + + def __init__(self): + super(SignalsConv1d, self).__init__() + + def forward(self, signal, system_ir): + """ output = forward(signal, system_ir) + + signal: (batchsize, length1, dim) + system_ir: (length2, dim) + + output: (batchsize, length1, dim) + """ + if signal.shape[-1] != system_ir.shape[-1]: + print("Error: SignalsConv1d expects shape:") + print("signal (batchsize, length1, dim)") + print("system_id (batchsize, length2, dim)") + print("But received signal: {:s}".format(str(signal.shape))) + print(" system_ir: {:s}".format(str(system_ir.shape))) + sys.exit(1) + padding_length = system_ir.shape[0] - 1 + groups = signal.shape[-1] + + # pad signal on the left + signal_pad = torch_nn_func.pad(signal.permute(0, 2, 1), \ + (padding_length, 0)) + # prepare system impulse response as (dim, 1, length2) + # also flip the impulse response + ir = torch.flip(system_ir.unsqueeze(1).permute(2, 1, 0), \ + dims=[2]) + # convolute + output = torch_nn_func.conv1d(signal_pad, ir, groups=groups) + return output.permute(0, 2, 1) + + +class CyclicNoiseGen_v1(torch.nn.Module): + """ CyclicnoiseGen_v1 + Cyclic noise with a single parameter of beta. + Pytorch v1 implementation assumes f_t is also fixed + """ + + def __init__(self, samp_rate, + noise_std=0.003, voiced_threshold=0): + super(CyclicNoiseGen_v1, self).__init__() + self.samp_rate = samp_rate + self.noise_std = noise_std + self.voiced_threshold = voiced_threshold + + self.l_pulse = PulseGen(samp_rate, pulse_amp=1.0, + noise_std=noise_std, + voiced_threshold=voiced_threshold) + self.l_conv = SignalsConv1d() + + def noise_decay(self, beta, f0mean): + """ decayed_noise = noise_decay(beta, f0mean) + decayed_noise = n[t]exp(-t * f_mean / beta / samp_rate) + + beta: (dim=1) or (batchsize=1, 1, dim=1) + f0mean (batchsize=1, 1, dim=1) + + decayed_noise (batchsize=1, length, dim=1) + """ + with torch.no_grad(): + # exp(-1.0 n / T) < 0.01 => n > -log(0.01)*T = 4.60*T + # truncate the noise when decayed by -40 dB + length = 4.6 * self.samp_rate / f0mean + length = length.int() + time_idx = torch.arange(0, length, device=beta.device) + time_idx = time_idx.unsqueeze(0).unsqueeze(2) + time_idx = time_idx.repeat(beta.shape[0], 1, beta.shape[2]) + + noise = torch.randn(time_idx.shape, device=beta.device) + + # due to Pytorch implementation, use f0_mean as the f0 factor + decay = torch.exp(-time_idx * f0mean / beta / self.samp_rate) + return noise * self.noise_std * decay + + def forward(self, f0s, beta): + """ Producde cyclic-noise + """ + # pulse train + pulse_train, sine_wav, uv, noise = self.l_pulse(f0s) + pure_pulse = pulse_train - noise + + # decayed_noise (length, dim=1) + if (uv < 1).all(): + # all unvoiced + cyc_noise = torch.zeros_like(sine_wav) + else: + f0mean = f0s[uv > 0].mean() + + decayed_noise = self.noise_decay(beta, f0mean)[0, :, :] + # convolute + cyc_noise = self.l_conv(pure_pulse, decayed_noise) + + # add noise in invoiced segments + cyc_noise = cyc_noise + noise * (1.0 - uv) + return cyc_noise, pulse_train, sine_wav, uv, noise + + +class SourceModuleCycNoise_v1(torch.nn.Module): + """ SourceModuleCycNoise_v1 + SourceModule(sampling_rate, noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + + noise_std: std of Gaussian noise (default: 0.003) + voiced_threshold: threshold to set U/V given F0 (default: 0) + + cyc, noise, uv = SourceModuleCycNoise_v1(F0_upsampled, beta) + F0_upsampled (batchsize, length, 1) + beta (1) + cyc (batchsize, length, 1) + noise (batchsize, length, 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, noise_std=0.003, voiced_threshod=0): + super(SourceModuleCycNoise_v1, self).__init__() + self.sampling_rate = sampling_rate + self.noise_std = noise_std + self.l_cyc_gen = CyclicNoiseGen_v1(sampling_rate, noise_std, + voiced_threshod) + + def forward(self, f0_upsamped, beta): + """ + cyc, noise, uv = SourceModuleCycNoise_v1(F0, beta) + F0_upsampled (batchsize, length, 1) + beta (1) + cyc (batchsize, length, 1) + noise (batchsize, length, 1) + uv (batchsize, length, 1) + """ + # source for harmonic branch + cyc, pulse, sine, uv, add_noi = self.l_cyc_gen(f0_upsamped, beta) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.noise_std / 3 + return cyc, noise, uv + + +class SourceModuleHnNSF(torch.nn.Module): + """ SourceModule for hn-nsf + SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0) + sampling_rate: sampling_rate in Hz + harmonic_num: number of harmonic above F0 (default: 0) + sine_amp: amplitude of sine source signal (default: 0.1) + add_noise_std: std of additive Gaussian noise (default: 0.003) + note that amplitude of noise in unvoiced is decided + by sine_amp + voiced_threshold: threhold to set U/V given F0 (default: 0) + + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + uv (batchsize, length, 1) + """ + + def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1, + add_noise_std=0.003, voiced_threshod=0): + super(SourceModuleHnNSF, self).__init__() + + self.sine_amp = sine_amp + self.noise_std = add_noise_std + + # to produce sine waveforms + self.l_sin_gen = SineGen(sampling_rate, harmonic_num, + sine_amp, add_noise_std, voiced_threshod) + + # to merge source harmonics into a single excitation + self.l_linear = torch.nn.Linear(harmonic_num + 1, 1) + self.l_tanh = torch.nn.Tanh() + + def forward(self, x): + """ + Sine_source, noise_source = SourceModuleHnNSF(F0_sampled) + F0_sampled (batchsize, length, 1) + Sine_source (batchsize, length, 1) + noise_source (batchsize, length 1) + """ + # source for harmonic branch + sine_wavs, uv, _ = self.l_sin_gen(x) + sine_merge = self.l_tanh(self.l_linear(sine_wavs)) + + # source for noise branch, in the same shape as uv + noise = torch.randn_like(uv) * self.sine_amp / 3 + return sine_merge, noise, uv + + +if __name__ == '__main__': + source = SourceModuleCycNoise_v1(24000) + x = torch.randn(16, 25600, 1) + + diff --git a/modules/parallel_wavegan/optimizers/__init__.py b/modules/parallel_wavegan/optimizers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a0e0c5932838281e912079e5784d84d43444a61a --- /dev/null +++ b/modules/parallel_wavegan/optimizers/__init__.py @@ -0,0 +1,2 @@ +from torch.optim import * # NOQA +from .radam import * # NOQA diff --git a/modules/parallel_wavegan/optimizers/radam.py b/modules/parallel_wavegan/optimizers/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..e805d7e34921bee436e1e7fd9e1f753c7609186b --- /dev/null +++ b/modules/parallel_wavegan/optimizers/radam.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +"""RAdam optimizer. + +This code is drived from https://github.com/LiyuanLucasLiu/RAdam. +""" + +import math +import torch + +from torch.optim.optimizer import Optimizer + + +class RAdam(Optimizer): + """Rectified Adam optimizer.""" + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + """Initilize RAdam optimizer.""" + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.buffer = [[None, None, None] for ind in range(10)] + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + """Set state.""" + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + """Run one step.""" + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + buffered = self.buffer[int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) # NOQA + else: + step_size = 1.0 / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) + else: + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + + p.data.copy_(p_data_fp32) + + return loss diff --git a/modules/parallel_wavegan/stft_loss.py b/modules/parallel_wavegan/stft_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..229e6c777dc9ec7f710842d1e648dba1189ec8b4 --- /dev/null +++ b/modules/parallel_wavegan/stft_loss.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""STFT-based Loss modules.""" +import librosa +import torch + +from modules.parallel_wavegan.losses import LogSTFTMagnitudeLoss, SpectralConvergengeLoss, stft + + +class STFTLoss(torch.nn.Module): + """STFT loss module.""" + + def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", + use_mel_loss=False): + """Initialize STFT loss module.""" + super(STFTLoss, self).__init__() + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.window = getattr(torch, window)(win_length) + self.spectral_convergenge_loss = SpectralConvergengeLoss() + self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() + self.use_mel_loss = use_mel_loss + self.mel_basis = None + + def forward(self, x, y): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + + Returns: + Tensor: Spectral convergence loss value. + Tensor: Log STFT magnitude loss value. + + """ + x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) + y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) + if self.use_mel_loss: + if self.mel_basis is None: + self.mel_basis = torch.from_numpy(librosa.filters.mel(22050, self.fft_size, 80)).cuda().T + x_mag = x_mag @ self.mel_basis + y_mag = y_mag @ self.mel_basis + + sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) + mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) + + return sc_loss, mag_loss + + +class MultiResolutionSTFTLoss(torch.nn.Module): + """Multi resolution STFT loss module.""" + + def __init__(self, + fft_sizes=[1024, 2048, 512], + hop_sizes=[120, 240, 50], + win_lengths=[600, 1200, 240], + window="hann_window", + use_mel_loss=False): + """Initialize Multi resolution STFT loss module. + + Args: + fft_sizes (list): List of FFT sizes. + hop_sizes (list): List of hop sizes. + win_lengths (list): List of window lengths. + window (str): Window function type. + + """ + super(MultiResolutionSTFTLoss, self).__init__() + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses += [STFTLoss(fs, ss, wl, window, use_mel_loss)] + + def forward(self, x, y): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + + Returns: + Tensor: Multi resolution spectral convergence loss value. + Tensor: Multi resolution log STFT magnitude loss value. + + """ + sc_loss = 0.0 + mag_loss = 0.0 + for f in self.stft_losses: + sc_l, mag_l = f(x, y) + sc_loss += sc_l + mag_loss += mag_l + sc_loss /= len(self.stft_losses) + mag_loss /= len(self.stft_losses) + + return sc_loss, mag_loss diff --git a/modules/parallel_wavegan/utils/__init__.py b/modules/parallel_wavegan/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8fa95a020706b5412c3959fbf6e5980019c0d5f --- /dev/null +++ b/modules/parallel_wavegan/utils/__init__.py @@ -0,0 +1 @@ +from .utils import * # NOQA diff --git a/modules/parallel_wavegan/utils/utils.py b/modules/parallel_wavegan/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d48a5ed28e8555d4b8cfb15fdee86426bbb9e368 --- /dev/null +++ b/modules/parallel_wavegan/utils/utils.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Utility functions.""" + +import fnmatch +import logging +import os +import sys + +import h5py +import numpy as np + + +def find_files(root_dir, query="*.wav", include_root_dir=True): + """Find files recursively. + + Args: + root_dir (str): Root root_dir to find. + query (str): Query to find. + include_root_dir (bool): If False, root_dir name is not included. + + Returns: + list: List of found filenames. + + """ + files = [] + for root, dirnames, filenames in os.walk(root_dir, followlinks=True): + for filename in fnmatch.filter(filenames, query): + files.append(os.path.join(root, filename)) + if not include_root_dir: + files = [file_.replace(root_dir + "/", "") for file_ in files] + + return files + + +def read_hdf5(hdf5_name, hdf5_path): + """Read hdf5 dataset. + + Args: + hdf5_name (str): Filename of hdf5 file. + hdf5_path (str): Dataset name in hdf5 file. + + Return: + any: Dataset values. + + """ + if not os.path.exists(hdf5_name): + logging.error(f"There is no such a hdf5 file ({hdf5_name}).") + sys.exit(1) + + hdf5_file = h5py.File(hdf5_name, "r") + + if hdf5_path not in hdf5_file: + logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})") + sys.exit(1) + + hdf5_data = hdf5_file[hdf5_path][()] + hdf5_file.close() + + return hdf5_data + + +def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True): + """Write dataset to hdf5. + + Args: + hdf5_name (str): Hdf5 dataset filename. + hdf5_path (str): Dataset path in hdf5. + write_data (ndarray): Data to write. + is_overwrite (bool): Whether to overwrite dataset. + + """ + # convert to numpy array + write_data = np.array(write_data) + + # check folder existence + folder_name, _ = os.path.split(hdf5_name) + if not os.path.exists(folder_name) and len(folder_name) != 0: + os.makedirs(folder_name) + + # check hdf5 existence + if os.path.exists(hdf5_name): + # if already exists, open with r+ mode + hdf5_file = h5py.File(hdf5_name, "r+") + # check dataset existence + if hdf5_path in hdf5_file: + if is_overwrite: + logging.warning("Dataset in hdf5 file already exists. " + "recreate dataset in hdf5.") + hdf5_file.__delitem__(hdf5_path) + else: + logging.error("Dataset in hdf5 file already exists. " + "if you want to overwrite, please set is_overwrite = True.") + hdf5_file.close() + sys.exit(1) + else: + # if not exists, open with w mode + hdf5_file = h5py.File(hdf5_name, "w") + + # write data to hdf5 + hdf5_file.create_dataset(hdf5_path, data=write_data) + hdf5_file.flush() + hdf5_file.close() + + +class HDF5ScpLoader(object): + """Loader class for a fests.scp file of hdf5 file. + + Examples: + key1 /some/path/a.h5:feats + key2 /some/path/b.h5:feats + key3 /some/path/c.h5:feats + key4 /some/path/d.h5:feats + ... + >>> loader = HDF5ScpLoader("hdf5.scp") + >>> array = loader["key1"] + + key1 /some/path/a.h5 + key2 /some/path/b.h5 + key3 /some/path/c.h5 + key4 /some/path/d.h5 + ... + >>> loader = HDF5ScpLoader("hdf5.scp", "feats") + >>> array = loader["key1"] + + """ + + def __init__(self, feats_scp, default_hdf5_path="feats"): + """Initialize HDF5 scp loader. + + Args: + feats_scp (str): Kaldi-style feats.scp file with hdf5 format. + default_hdf5_path (str): Path in hdf5 file. If the scp contain the info, not used. + + """ + self.default_hdf5_path = default_hdf5_path + with open(feats_scp) as f: + lines = [line.replace("\n", "") for line in f.readlines()] + self.data = {} + for line in lines: + key, value = line.split() + self.data[key] = value + + def get_path(self, key): + """Get hdf5 file path for a given key.""" + return self.data[key] + + def __getitem__(self, key): + """Get ndarray for a given key.""" + p = self.data[key] + if ":" in p: + return read_hdf5(*p.split(":")) + else: + return read_hdf5(p, self.default_hdf5_path) + + def __len__(self): + """Return the length of the scp file.""" + return len(self.data) + + def __iter__(self): + """Return the iterator of the scp file.""" + return iter(self.data) + + def keys(self): + """Return the keys of the scp file.""" + return self.data.keys() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..4289235c6fa00f28f44a26dba879031b8982ff49 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,118 @@ +absl-py==0.11.0 +alignment==1.0.10 +altgraph==0.17 +appdirs==1.4.4 +async-timeout==3.0.1 +audioread==2.1.9 +backcall==0.2.0 +blinker==1.4 +brotlipy==0.7.0 +cachetools==4.2.0 +certifi==2020.12.5 +cffi==1.14.4 +chardet==4.0.0 +click==7.1.2 +cycler==0.10.0 +Cython==0.29.21 +cytoolz==0.11.0 +decorator==4.4.2 +Distance==0.1.3 +einops==0.3.0 +et-xmlfile==1.0.1 +fsspec==0.8.4 +future==0.18.2 +g2p-en==2.1.0 +g2pM==0.1.2.5 +google-auth==1.24.0 +google-auth-oauthlib==0.4.2 +grpcio==1.34.0 +h5py==3.1.0 +horology==1.1.0 +httplib2==0.18.1 +idna==2.10 +imageio==2.9.0 +inflect==5.0.2 +ipdb==0.13.4 +ipython==7.19.0 +ipython-genutils==0.2.0 +jdcal==1.4.1 +jedi==0.17.2 +jieba==0.42.1 +jiwer==2.2.0 +joblib==1.0.0 +kiwisolver==1.3.1 +librosa==0.8.0 +llvmlite==0.31.0 +Markdown==3.3.3 +matplotlib==3.3.3 +miditoolkit==0.1.7 +mido==1.2.9 +music21==5.7.2 +networkx==2.5 +nltk==3.5 +numba==0.48.0 +numpy==1.19.4 +oauth2client==4.1.3 +oauthlib==3.1.0 +olefile==0.46 +packaging==20.7 +pandas==1.2.0 +parso==0.7.1 +patsy==0.5.1 +pexpect==4.8.0 +pickleshare==0.7.5 +Pillow==8.0.1 +pooch==1.3.0 +praat-parselmouth==0.3.3 +prompt-toolkit==3.0.8 +protobuf==3.13.0 +ptyprocess==0.6.0 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycparser==2.20 +pycwt==0.3.0a22 +Pygments==2.7.3 +PyInstaller==3.6 +PyJWT==1.7.1 +pyloudnorm==0.1.0 +pyparsing==2.4.7 +pypinyin==0.39.0 +PySocks==1.7.1 +python-dateutil==2.8.1 +python-Levenshtein==0.12.0 +pytorch-lightning==0.7.1 +pytz==2020.5 +PyWavelets==1.1.1 +pyworld==0.2.12 +PyYAML==5.3.1 +regex==2020.11.13 +requests==2.25.1 +requests-oauthlib==1.3.0 +resampy==0.2.2 +Resemblyzer==0.1.1.dev0 +rsa==4.6 +scikit-image==0.16.2 +scikit-learn==0.22.2.post1 +scipy==1.5.4 +six==1.15.0 +SoundFile==0.10.3.post1 +stopit==1.1.1 +tensorboard==2.4.0 +tensorboard-plugin-wit==1.7.0 +tensorboardX==2.1 +TextGrid==1.5 +threadpoolctl==2.1.0 +toolz==0.11.1 +torch==1.6.0 +torchaudio==0.6.0 +torchvision==0.7.0 +tqdm==4.54.1 +traitlets==5.0.5 +typing==3.7.4.3 +urllib3==1.26.2 +uuid==1.30 +wcwidth==0.2.5 +webencodings==0.5.1 +webrtcvad==2.0.10 +Werkzeug==1.0.1 +pretty-midi==0.2.9 diff --git a/tasks/base_task.py b/tasks/base_task.py new file mode 100644 index 0000000000000000000000000000000000000000..b74d25c85ce8a86865c5d5a09f3f92579ffb2074 --- /dev/null +++ b/tasks/base_task.py @@ -0,0 +1,360 @@ +import glob +import re +import subprocess +from datetime import datetime + +import matplotlib + +matplotlib.use('Agg') + +from utils.hparams import hparams, set_hparams +import random +import sys +import numpy as np +import torch.distributed as dist +from pytorch_lightning.loggers import TensorBoardLogger +from utils.pl_utils import LatestModelCheckpoint, BaseTrainer, data_loader, DDP +from torch import nn +import torch.utils.data +import utils +import logging +import os + +torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system')) + +log_format = '%(asctime)s %(message)s' +logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=log_format, datefmt='%m/%d %I:%M:%S %p') + + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, shuffle): + super().__init__() + self.hparams = hparams + self.shuffle = shuffle + self.sort_by_len = hparams['sort_by_len'] + self.sizes = None + + @property + def _sizes(self): + return self.sizes + + def __getitem__(self, index): + raise NotImplementedError + + def collater(self, samples): + raise NotImplementedError + + def __len__(self): + return len(self._sizes) + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + size = min(self._sizes[index], hparams['max_frames']) + return size + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)) + if self.sort_by_len: + indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')] + # 先random, 然后稳定排序, 保证排序后同长度的数据顺序是依照random permutation的 (被其随机打乱). + else: + indices = np.arange(len(self)) + return indices + + @property + def num_workers(self): + return int(os.getenv('NUM_WORKERS', hparams['ds_workers'])) + + +class BaseTask(nn.Module): + def __init__(self, *args, **kwargs): + # dataset configs + super(BaseTask, self).__init__(*args, **kwargs) + self.current_epoch = 0 + self.global_step = 0 + self.loaded_optimizer_states_dict = {} + self.trainer = None + self.logger = None + self.on_gpu = False + self.use_dp = False + self.use_ddp = False + self.example_input_array = None + + self.max_tokens = hparams['max_tokens'] + self.max_sentences = hparams['max_sentences'] + self.max_eval_tokens = hparams['max_eval_tokens'] + if self.max_eval_tokens == -1: + hparams['max_eval_tokens'] = self.max_eval_tokens = self.max_tokens + self.max_eval_sentences = hparams['max_eval_sentences'] + if self.max_eval_sentences == -1: + hparams['max_eval_sentences'] = self.max_eval_sentences = self.max_sentences + + self.model = None + self.training_losses_meter = None + + ########### + # Training, validation and testing + ########### + def build_model(self): + raise NotImplementedError + + def load_ckpt(self, ckpt_base_dir, current_model_name=None, model_name='model', force=True, strict=True): + # This function is updated on 2021.12.13 + if current_model_name is None: + current_model_name = model_name + utils.load_ckpt(self.__getattr__(current_model_name), ckpt_base_dir, current_model_name, force, strict) + + def on_epoch_start(self): + self.training_losses_meter = {'total_loss': utils.AvgrageMeter()} + + def _training_step(self, sample, batch_idx, optimizer_idx): + """ + + :param sample: + :param batch_idx: + :return: total loss: torch.Tensor, loss_log: dict + """ + raise NotImplementedError + + def training_step(self, sample, batch_idx, optimizer_idx=-1): + loss_ret = self._training_step(sample, batch_idx, optimizer_idx) + self.opt_idx = optimizer_idx + if loss_ret is None: + return {'loss': None} + total_loss, log_outputs = loss_ret + log_outputs = utils.tensors_to_scalars(log_outputs) + for k, v in log_outputs.items(): + if k not in self.training_losses_meter: + self.training_losses_meter[k] = utils.AvgrageMeter() + if not np.isnan(v): + self.training_losses_meter[k].update(v) + self.training_losses_meter['total_loss'].update(total_loss.item()) + + try: + log_outputs['lr'] = self.scheduler.get_lr() + if isinstance(log_outputs['lr'], list): + log_outputs['lr'] = log_outputs['lr'][0] + except: + pass + + # log_outputs['all_loss'] = total_loss.item() + progress_bar_log = log_outputs + tb_log = {f'tr/{k}': v for k, v in log_outputs.items()} + return { + 'loss': total_loss, + 'progress_bar': progress_bar_log, + 'log': tb_log + } + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx): + optimizer.step() + optimizer.zero_grad() + if self.scheduler is not None: + self.scheduler.step(self.global_step // hparams['accumulate_grad_batches']) + + def on_epoch_end(self): + loss_outputs = {k: round(v.avg, 4) for k, v in self.training_losses_meter.items()} + print(f"\n==============\n " + f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. {loss_outputs}" + f"\n==============\n") + + def validation_step(self, sample, batch_idx): + """ + + :param sample: + :param batch_idx: + :return: output: dict + """ + raise NotImplementedError + + def _validation_end(self, outputs): + """ + + :param outputs: + :return: loss_output: dict + """ + raise NotImplementedError + + def validation_end(self, outputs): + loss_output = self._validation_end(outputs) + print(f"\n==============\n " + f"valid results: {loss_output}" + f"\n==============\n") + return { + 'log': {f'val/{k}': v for k, v in loss_output.items()}, + 'val_loss': loss_output['total_loss'] + } + + def build_scheduler(self, optimizer): + raise NotImplementedError + + def build_optimizer(self, model): + raise NotImplementedError + + def configure_optimizers(self): + optm = self.build_optimizer(self.model) + self.scheduler = self.build_scheduler(optm) + return [optm] + + def test_start(self): + pass + + def test_step(self, sample, batch_idx): + return self.validation_step(sample, batch_idx) + + def test_end(self, outputs): + return self.validation_end(outputs) + + ########### + # Running configuration + ########### + + @classmethod + def start(cls): + set_hparams() + os.environ['MASTER_PORT'] = str(random.randint(15000, 30000)) + random.seed(hparams['seed']) + np.random.seed(hparams['seed']) + task = cls() + work_dir = hparams['work_dir'] + trainer = BaseTrainer(checkpoint_callback=LatestModelCheckpoint( + filepath=work_dir, + verbose=True, + monitor='val_loss', + mode='min', + num_ckpt_keep=hparams['num_ckpt_keep'], + save_best=hparams['save_best'], + period=1 if hparams['save_ckpt'] else 100000 + ), + logger=TensorBoardLogger( + save_dir=work_dir, + name='lightning_logs', + version='lastest' + ), + gradient_clip_val=hparams['clip_grad_norm'], + val_check_interval=hparams['val_check_interval'], + row_log_interval=hparams['log_interval'], + max_updates=hparams['max_updates'], + num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams[ + 'validate'] else 10000, + accumulate_grad_batches=hparams['accumulate_grad_batches']) + if not hparams['infer']: # train + t = datetime.now().strftime('%Y%m%d%H%M%S') + code_dir = f'{work_dir}/codes/{t}' + subprocess.check_call(f'mkdir -p "{code_dir}"', shell=True) + for c in hparams['save_codes']: + subprocess.check_call(f'cp -r "{c}" "{code_dir}/"', shell=True) + print(f"| Copied codes to {code_dir}.") + trainer.checkpoint_callback.task = task + trainer.fit(task) + else: + trainer.test(task) + + def configure_ddp(self, model, device_ids): + model = DDP( + model, + device_ids=device_ids, + find_unused_parameters=True + ) + if dist.get_rank() != 0 and not hparams['debug']: + sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") + random.seed(hparams['seed']) + np.random.seed(hparams['seed']) + return model + + def training_end(self, *args, **kwargs): + return None + + def init_ddp_connection(self, proc_rank, world_size): + set_hparams(print_hparams=False) + # guarantees unique ports across jobs from same grid search + default_port = 12910 + # if user gave a port number, use that one instead + try: + default_port = os.environ['MASTER_PORT'] + except Exception: + os.environ['MASTER_PORT'] = str(default_port) + + # figure out the root node addr + root_node = '127.0.0.2' + root_node = self.trainer.resolve_root_node_address(root_node) + os.environ['MASTER_ADDR'] = root_node + dist.init_process_group('nccl', rank=proc_rank, world_size=world_size) + + @data_loader + def train_dataloader(self): + return None + + @data_loader + def test_dataloader(self): + return None + + @data_loader + def val_dataloader(self): + return None + + def on_load_checkpoint(self, checkpoint): + pass + + def on_save_checkpoint(self, checkpoint): + pass + + def on_sanity_check_start(self): + pass + + def on_train_start(self): + pass + + def on_train_end(self): + pass + + def on_batch_start(self, batch): + pass + + def on_batch_end(self): + pass + + def on_pre_performance_check(self): + pass + + def on_post_performance_check(self): + pass + + def on_before_zero_grad(self, optimizer): + pass + + def on_after_backward(self): + pass + + def backward(self, loss, optimizer): + loss.backward() + + def grad_norm(self, norm_type): + results = {} + total_norm = 0 + for name, p in self.named_parameters(): + if p.requires_grad: + try: + param_norm = p.grad.data.norm(norm_type) + total_norm += param_norm ** norm_type + norm = param_norm ** (1 / norm_type) + + grad = round(norm.data.cpu().numpy().flatten()[0], 3) + results['grad_{}_norm_{}'.format(norm_type, name)] = grad + except Exception: + # this param had no grad + pass + + total_norm = total_norm ** (1. / norm_type) + grad = round(total_norm.data.cpu().numpy().flatten()[0], 3) + results['grad_{}_norm_total'.format(norm_type)] = grad + return results diff --git a/tasks/run.py b/tasks/run.py new file mode 100644 index 0000000000000000000000000000000000000000..82c7559cec873eebf7c2c0ab6554895e21de7e7c --- /dev/null +++ b/tasks/run.py @@ -0,0 +1,15 @@ +import importlib +from utils.hparams import set_hparams, hparams + + +def run_task(): + assert hparams['task_cls'] != '' + pkg = ".".join(hparams["task_cls"].split(".")[:-1]) + cls_name = hparams["task_cls"].split(".")[-1] + task_cls = getattr(importlib.import_module(pkg), cls_name) + task_cls.start() + + +if __name__ == '__main__': + set_hparams() + run_task() diff --git a/tasks/tts/fs2.py b/tasks/tts/fs2.py new file mode 100644 index 0000000000000000000000000000000000000000..32fb54f5bda486ece04598673cff367f5d8844fa --- /dev/null +++ b/tasks/tts/fs2.py @@ -0,0 +1,512 @@ +import matplotlib + +matplotlib.use('Agg') + +from utils import audio +import matplotlib.pyplot as plt +from data_gen.tts.data_gen_utils import get_pitch +from tasks.tts.fs2_utils import FastSpeechDataset +from utils.cwt import cwt2f0 +from utils.pl_utils import data_loader +import os +from multiprocessing.pool import Pool +from tqdm import tqdm +from modules.fastspeech.tts_modules import mel2ph_to_dur +from utils.hparams import hparams +from utils.plot import spec_to_figure, dur_to_figure, f0_to_figure +from utils.pitch_utils import denorm_f0 +from modules.fastspeech.fs2 import FastSpeech2 +from tasks.tts.tts import TtsTask +import torch +import torch.optim +import torch.utils.data +import torch.nn.functional as F +import utils +import torch.distributions +import numpy as np +from modules.commons.ssim import ssim + +class FastSpeech2Task(TtsTask): + def __init__(self): + super(FastSpeech2Task, self).__init__() + self.dataset_cls = FastSpeechDataset + self.mse_loss_fn = torch.nn.MSELoss() + mel_losses = hparams['mel_loss'].split("|") + self.loss_and_lambda = {} + for i, l in enumerate(mel_losses): + if l == '': + continue + if ':' in l: + l, lbd = l.split(":") + lbd = float(lbd) + else: + lbd = 1.0 + self.loss_and_lambda[l] = lbd + print("| Mel losses:", self.loss_and_lambda) + self.sil_ph = self.phone_encoder.sil_phonemes() + + @data_loader + def train_dataloader(self): + train_dataset = self.dataset_cls(hparams['train_set_name'], shuffle=True) + return self.build_dataloader(train_dataset, True, self.max_tokens, self.max_sentences, + endless=hparams['endless_ds']) + + @data_loader + def val_dataloader(self): + valid_dataset = self.dataset_cls(hparams['valid_set_name'], shuffle=False) + return self.build_dataloader(valid_dataset, False, self.max_eval_tokens, self.max_eval_sentences) + + @data_loader + def test_dataloader(self): + test_dataset = self.dataset_cls(hparams['test_set_name'], shuffle=False) + return self.build_dataloader(test_dataset, False, self.max_eval_tokens, + self.max_eval_sentences, batch_by_size=False) + + def build_tts_model(self): + self.model = FastSpeech2(self.phone_encoder) + + def build_model(self): + self.build_tts_model() + if hparams['load_ckpt'] != '': + self.load_ckpt(hparams['load_ckpt'], strict=True) + utils.print_arch(self.model) + return self.model + + def _training_step(self, sample, batch_idx, _): + loss_output = self.run_model(self.model, sample) + total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad]) + loss_output['batch_size'] = sample['txt_tokens'].size()[0] + return total_loss, loss_output + + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True) + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + mel_out = self.model.out2mel(model_out['mel_out']) + outputs = utils.tensors_to_scalars(outputs) + # if sample['mels'].shape[0] == 1: + # self.add_laplace_var(mel_out, sample['mels'], outputs) + if batch_idx < hparams['num_valid_plots']: + self.plot_mel(batch_idx, sample['mels'], mel_out) + self.plot_dur(batch_idx, sample, model_out) + if hparams['use_pitch_embed']: + self.plot_pitch(batch_idx, sample, model_out) + return outputs + + def _validation_end(self, outputs): + all_losses_meter = { + 'total_loss': utils.AvgrageMeter(), + } + for output in outputs: + n = output['nsamples'] + for k, v in output['losses'].items(): + if k not in all_losses_meter: + all_losses_meter[k] = utils.AvgrageMeter() + all_losses_meter[k].update(v, n) + all_losses_meter['total_loss'].update(output['total_loss'], n) + return {k: round(v.avg, 4) for k, v in all_losses_meter.items()} + + def run_model(self, model, sample, return_output=False): + txt_tokens = sample['txt_tokens'] # [B, T_t] + target = sample['mels'] # [B, T_s, 80] + mel2ph = sample['mel2ph'] # [B, T_s] + f0 = sample['f0'] + uv = sample['uv'] + energy = sample['energy'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + if hparams['pitch_type'] == 'cwt': + cwt_spec = sample[f'cwt_spec'] + f0_mean = sample['f0_mean'] + f0_std = sample['f0_std'] + sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph) + + output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, + ref_mels=target, f0=f0, uv=uv, energy=energy, infer=False) + + losses = {} + self.add_mel_loss(output['mel_out'], target, losses) + self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses) + if hparams['use_pitch_embed']: + self.add_pitch_loss(output, sample, losses) + if hparams['use_energy_embed']: + self.add_energy_loss(output['energy_pred'], energy, losses) + if not return_output: + return losses + else: + return losses, output + + ############ + # losses + ############ + def add_mel_loss(self, mel_out, target, losses, postfix='', mel_mix_loss=None): + if mel_mix_loss is None: + for loss_name, lbd in self.loss_and_lambda.items(): + if 'l1' == loss_name: + l = self.l1_loss(mel_out, target) + elif 'mse' == loss_name: + raise NotImplementedError + elif 'ssim' == loss_name: + l = self.ssim_loss(mel_out, target) + elif 'gdl' == loss_name: + raise NotImplementedError + losses[f'{loss_name}{postfix}'] = l * lbd + else: + raise NotImplementedError + + def l1_loss(self, decoder_output, target): + # decoder_output : B x T x n_mel + # target : B x T x n_mel + l1_loss = F.l1_loss(decoder_output, target, reduction='none') + weights = self.weights_nonzero_speech(target) + l1_loss = (l1_loss * weights).sum() / weights.sum() + return l1_loss + + def ssim_loss(self, decoder_output, target, bias=6.0): + # decoder_output : B x T x n_mel + # target : B x T x n_mel + assert decoder_output.shape == target.shape + weights = self.weights_nonzero_speech(target) + decoder_output = decoder_output[:, None] + bias + target = target[:, None] + bias + ssim_loss = 1 - ssim(decoder_output, target, size_average=False) + ssim_loss = (ssim_loss * weights).sum() / weights.sum() + return ssim_loss + + def add_dur_loss(self, dur_pred, mel2ph, txt_tokens, losses=None): + """ + + :param dur_pred: [B, T], float, log scale + :param mel2ph: [B, T] + :param txt_tokens: [B, T] + :param losses: + :return: + """ + B, T = txt_tokens.shape + nonpadding = (txt_tokens != 0).float() + dur_gt = mel2ph_to_dur(mel2ph, T).float() * nonpadding + is_sil = torch.zeros_like(txt_tokens).bool() + for p in self.sil_ph: + is_sil = is_sil | (txt_tokens == self.phone_encoder.encode(p)[0]) + is_sil = is_sil.float() # [B, T_txt] + + # phone duration loss + if hparams['dur_loss'] == 'mse': + losses['pdur'] = F.mse_loss(dur_pred, (dur_gt + 1).log(), reduction='none') + losses['pdur'] = (losses['pdur'] * nonpadding).sum() / nonpadding.sum() + dur_pred = (dur_pred.exp() - 1).clamp(min=0) + elif hparams['dur_loss'] == 'mog': + return NotImplementedError + elif hparams['dur_loss'] == 'crf': + losses['pdur'] = -self.model.dur_predictor.crf( + dur_pred, dur_gt.long().clamp(min=0, max=31), mask=nonpadding > 0, reduction='mean') + losses['pdur'] = losses['pdur'] * hparams['lambda_ph_dur'] + + # use linear scale for sent and word duration + if hparams['lambda_word_dur'] > 0: + word_id = (is_sil.cumsum(-1) * (1 - is_sil)).long() + word_dur_p = dur_pred.new_zeros([B, word_id.max() + 1]).scatter_add(1, word_id, dur_pred)[:, 1:] + word_dur_g = dur_gt.new_zeros([B, word_id.max() + 1]).scatter_add(1, word_id, dur_gt)[:, 1:] + wdur_loss = F.mse_loss((word_dur_p + 1).log(), (word_dur_g + 1).log(), reduction='none') + word_nonpadding = (word_dur_g > 0).float() + wdur_loss = (wdur_loss * word_nonpadding).sum() / word_nonpadding.sum() + losses['wdur'] = wdur_loss * hparams['lambda_word_dur'] + if hparams['lambda_sent_dur'] > 0: + sent_dur_p = dur_pred.sum(-1) + sent_dur_g = dur_gt.sum(-1) + sdur_loss = F.mse_loss((sent_dur_p + 1).log(), (sent_dur_g + 1).log(), reduction='mean') + losses['sdur'] = sdur_loss.mean() * hparams['lambda_sent_dur'] + + def add_pitch_loss(self, output, sample, losses): + if hparams['pitch_type'] == 'ph': + nonpadding = (sample['txt_tokens'] != 0).float() + pitch_loss_fn = F.l1_loss if hparams['pitch_loss'] == 'l1' else F.mse_loss + losses['f0'] = (pitch_loss_fn(output['pitch_pred'][:, :, 0], sample['f0'], + reduction='none') * nonpadding).sum() \ + / nonpadding.sum() * hparams['lambda_f0'] + return + mel2ph = sample['mel2ph'] # [B, T_s] + f0 = sample['f0'] + uv = sample['uv'] + nonpadding = (mel2ph != 0).float() + if hparams['pitch_type'] == 'cwt': + cwt_spec = sample[f'cwt_spec'] + f0_mean = sample['f0_mean'] + f0_std = sample['f0_std'] + cwt_pred = output['cwt'][:, :, :10] + f0_mean_pred = output['f0_mean'] + f0_std_pred = output['f0_std'] + losses['C'] = self.cwt_loss(cwt_pred, cwt_spec) * hparams['lambda_f0'] + if hparams['use_uv']: + assert output['cwt'].shape[-1] == 11 + uv_pred = output['cwt'][:, :, -1] + losses['uv'] = (F.binary_cross_entropy_with_logits(uv_pred, uv, reduction='none') * nonpadding) \ + .sum() / nonpadding.sum() * hparams['lambda_uv'] + losses['f0_mean'] = F.l1_loss(f0_mean_pred, f0_mean) * hparams['lambda_f0'] + losses['f0_std'] = F.l1_loss(f0_std_pred, f0_std) * hparams['lambda_f0'] + if hparams['cwt_add_f0_loss']: + f0_cwt_ = self.model.cwt2f0_norm(cwt_pred, f0_mean_pred, f0_std_pred, mel2ph) + self.add_f0_loss(f0_cwt_[:, :, None], f0, uv, losses, nonpadding=nonpadding) + elif hparams['pitch_type'] == 'frame': + self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding) + + def add_f0_loss(self, p_pred, f0, uv, losses, nonpadding): + assert p_pred[..., 0].shape == f0.shape + if hparams['use_uv']: + assert p_pred[..., 1].shape == uv.shape + losses['uv'] = (F.binary_cross_entropy_with_logits( + p_pred[:, :, 1], uv, reduction='none') * nonpadding).sum() \ + / nonpadding.sum() * hparams['lambda_uv'] + nonpadding = nonpadding * (uv == 0).float() + + f0_pred = p_pred[:, :, 0] + if hparams['pitch_loss'] in ['l1', 'l2']: + pitch_loss_fn = F.l1_loss if hparams['pitch_loss'] == 'l1' else F.mse_loss + losses['f0'] = (pitch_loss_fn(f0_pred, f0, reduction='none') * nonpadding).sum() \ + / nonpadding.sum() * hparams['lambda_f0'] + elif hparams['pitch_loss'] == 'ssim': + return NotImplementedError + + def cwt_loss(self, cwt_p, cwt_g): + if hparams['cwt_loss'] == 'l1': + return F.l1_loss(cwt_p, cwt_g) + if hparams['cwt_loss'] == 'l2': + return F.mse_loss(cwt_p, cwt_g) + if hparams['cwt_loss'] == 'ssim': + return self.ssim_loss(cwt_p, cwt_g, 20) + + def add_energy_loss(self, energy_pred, energy, losses): + nonpadding = (energy != 0).float() + loss = (F.mse_loss(energy_pred, energy, reduction='none') * nonpadding).sum() / nonpadding.sum() + loss = loss * hparams['lambda_energy'] + losses['e'] = loss + + + ############ + # validation plots + ############ + def plot_mel(self, batch_idx, spec, spec_out, name=None): + spec_cat = torch.cat([spec, spec_out], -1) + name = f'mel_{batch_idx}' if name is None else name + vmin = hparams['mel_vmin'] + vmax = hparams['mel_vmax'] + self.logger.experiment.add_figure(name, spec_to_figure(spec_cat[0], vmin, vmax), self.global_step) + + def plot_dur(self, batch_idx, sample, model_out): + T_txt = sample['txt_tokens'].shape[1] + dur_gt = mel2ph_to_dur(sample['mel2ph'], T_txt)[0] + dur_pred = self.model.dur_predictor.out2dur(model_out['dur']).float() + txt = self.phone_encoder.decode(sample['txt_tokens'][0].cpu().numpy()) + txt = txt.split(" ") + self.logger.experiment.add_figure( + f'dur_{batch_idx}', dur_to_figure(dur_gt, dur_pred, txt), self.global_step) + + def plot_pitch(self, batch_idx, sample, model_out): + f0 = sample['f0'] + if hparams['pitch_type'] == 'ph': + mel2ph = sample['mel2ph'] + f0 = self.expand_f0_ph(f0, mel2ph) + f0_pred = self.expand_f0_ph(model_out['pitch_pred'][:, :, 0], mel2ph) + self.logger.experiment.add_figure( + f'f0_{batch_idx}', f0_to_figure(f0[0], None, f0_pred[0]), self.global_step) + return + f0 = denorm_f0(f0, sample['uv'], hparams) + if hparams['pitch_type'] == 'cwt': + # cwt + cwt_out = model_out['cwt'] + cwt_spec = cwt_out[:, :, :10] + cwt = torch.cat([cwt_spec, sample['cwt_spec']], -1) + self.logger.experiment.add_figure(f'cwt_{batch_idx}', spec_to_figure(cwt[0]), self.global_step) + # f0 + f0_pred = cwt2f0(cwt_spec, model_out['f0_mean'], model_out['f0_std'], hparams['cwt_scales']) + if hparams['use_uv']: + assert cwt_out.shape[-1] == 11 + uv_pred = cwt_out[:, :, -1] > 0 + f0_pred[uv_pred > 0] = 0 + f0_cwt = denorm_f0(sample['f0_cwt'], sample['uv'], hparams) + self.logger.experiment.add_figure( + f'f0_{batch_idx}', f0_to_figure(f0[0], f0_cwt[0], f0_pred[0]), self.global_step) + elif hparams['pitch_type'] == 'frame': + # f0 + uv_pred = model_out['pitch_pred'][:, :, 1] > 0 + pitch_pred = denorm_f0(model_out['pitch_pred'][:, :, 0], uv_pred, hparams) + self.logger.experiment.add_figure( + f'f0_{batch_idx}', f0_to_figure(f0[0], None, pitch_pred[0]), self.global_step) + + ############ + # infer + ############ + def test_step(self, sample, batch_idx): + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + txt_tokens = sample['txt_tokens'] + mel2ph, uv, f0 = None, None, None + ref_mels = None + if hparams['profile_infer']: + pass + else: + if hparams['use_gt_dur']: + mel2ph = sample['mel2ph'] + if hparams['use_gt_f0']: + f0 = sample['f0'] + uv = sample['uv'] + print('Here using gt f0!!') + if hparams.get('use_midi') is not None and hparams['use_midi']: + outputs = self.model( + txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, ref_mels=ref_mels, infer=True, + pitch_midi=sample['pitch_midi'], midi_dur=sample.get('midi_dur'), is_slur=sample.get('is_slur')) + else: + outputs = self.model( + txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, ref_mels=ref_mels, infer=True) + sample['outputs'] = self.model.out2mel(outputs['mel_out']) + sample['mel2ph_pred'] = outputs['mel2ph'] + if hparams.get('pe_enable') is not None and hparams['pe_enable']: + sample['f0'] = self.pe(sample['mels'])['f0_denorm_pred'] # pe predict from GT mel + sample['f0_pred'] = self.pe(sample['outputs'])['f0_denorm_pred'] # pe predict from Pred mel + else: + sample['f0'] = denorm_f0(sample['f0'], sample['uv'], hparams) + sample['f0_pred'] = outputs.get('f0_denorm') + + return self.after_infer(sample) + + def after_infer(self, predictions): + if self.saving_result_pool is None and not hparams['profile_infer']: + self.saving_result_pool = Pool(min(int(os.getenv('N_PROC', os.cpu_count())), 16)) + self.saving_results_futures = [] + predictions = utils.unpack_dict_to_list(predictions) + t = tqdm(predictions) + for num_predictions, prediction in enumerate(t): + for k, v in prediction.items(): + if type(v) is torch.Tensor: + prediction[k] = v.cpu().numpy() + + item_name = prediction.get('item_name') + text = prediction.get('text').replace(":", "%3A")[:80] + + # remove paddings + mel_gt = prediction["mels"] + mel_gt_mask = np.abs(mel_gt).sum(-1) > 0 + mel_gt = mel_gt[mel_gt_mask] + mel2ph_gt = prediction.get("mel2ph") + mel2ph_gt = mel2ph_gt[mel_gt_mask] if mel2ph_gt is not None else None + mel_pred = prediction["outputs"] + mel_pred_mask = np.abs(mel_pred).sum(-1) > 0 + mel_pred = mel_pred[mel_pred_mask] + mel_gt = np.clip(mel_gt, hparams['mel_vmin'], hparams['mel_vmax']) + mel_pred = np.clip(mel_pred, hparams['mel_vmin'], hparams['mel_vmax']) + + mel2ph_pred = prediction.get("mel2ph_pred") + if mel2ph_pred is not None: + if len(mel2ph_pred) > len(mel_pred_mask): + mel2ph_pred = mel2ph_pred[:len(mel_pred_mask)] + mel2ph_pred = mel2ph_pred[mel_pred_mask] + + f0_gt = prediction.get("f0") + f0_pred = prediction.get("f0_pred") + if f0_pred is not None: + f0_gt = f0_gt[mel_gt_mask] + if len(f0_pred) > len(mel_pred_mask): + f0_pred = f0_pred[:len(mel_pred_mask)] + f0_pred = f0_pred[mel_pred_mask] + + str_phs = None + if self.phone_encoder is not None and 'txt_tokens' in prediction: + str_phs = self.phone_encoder.decode(prediction['txt_tokens'], strip_padding=True) + gen_dir = os.path.join(hparams['work_dir'], + f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') + wav_pred = self.vocoder.spec2wav(mel_pred, f0=f0_pred) + if not hparams['profile_infer']: + os.makedirs(gen_dir, exist_ok=True) + os.makedirs(f'{gen_dir}/wavs', exist_ok=True) + os.makedirs(f'{gen_dir}/plot', exist_ok=True) + os.makedirs(os.path.join(hparams['work_dir'], 'P_mels_npy'), exist_ok=True) + os.makedirs(os.path.join(hparams['work_dir'], 'G_mels_npy'), exist_ok=True) + self.saving_results_futures.append( + self.saving_result_pool.apply_async(self.save_result, args=[ + wav_pred, mel_pred, 'P', item_name, text, gen_dir, str_phs, mel2ph_pred, f0_gt, f0_pred])) + + if mel_gt is not None and hparams['save_gt']: + wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0_gt) + self.saving_results_futures.append( + self.saving_result_pool.apply_async(self.save_result, args=[ + wav_gt, mel_gt, 'G', item_name, text, gen_dir, str_phs, mel2ph_gt, f0_gt, f0_pred])) + if hparams['save_f0']: + import matplotlib.pyplot as plt + # f0_pred_, _ = get_pitch(wav_pred, mel_pred, hparams) + f0_pred_ = f0_pred + f0_gt_, _ = get_pitch(wav_gt, mel_gt, hparams) + fig = plt.figure() + plt.plot(f0_pred_, label=r'$f0_P$') + plt.plot(f0_gt_, label=r'$f0_G$') + if hparams.get('pe_enable') is not None and hparams['pe_enable']: + # f0_midi = prediction.get("f0_midi") + # f0_midi = f0_midi[mel_gt_mask] + # plt.plot(f0_midi, label=r'$f0_M$') + pass + plt.legend() + plt.tight_layout() + plt.savefig(f'{gen_dir}/plot/[F0][{item_name}]{text}.png', format='png') + plt.close(fig) + + t.set_description( + f"Pred_shape: {mel_pred.shape}, gt_shape: {mel_gt.shape}") + else: + if 'gen_wav_time' not in self.stats: + self.stats['gen_wav_time'] = 0 + self.stats['gen_wav_time'] += len(wav_pred) / hparams['audio_sample_rate'] + print('gen_wav_time: ', self.stats['gen_wav_time']) + + return {} + + @staticmethod + def save_result(wav_out, mel, prefix, item_name, text, gen_dir, str_phs=None, mel2ph=None, gt_f0=None, pred_f0=None): + item_name = item_name.replace('/', '-') + base_fn = f'[{item_name}][{prefix}]' + + if text is not None: + base_fn += text + base_fn += ('-' + hparams['exp_name']) + np.save(os.path.join(hparams['work_dir'], f'{prefix}_mels_npy', item_name), mel) + audio.save_wav(wav_out, f'{gen_dir}/wavs/{base_fn}.wav', hparams['audio_sample_rate'], + norm=hparams['out_wav_norm']) + fig = plt.figure(figsize=(14, 10)) + spec_vmin = hparams['mel_vmin'] + spec_vmax = hparams['mel_vmax'] + heatmap = plt.pcolor(mel.T, vmin=spec_vmin, vmax=spec_vmax) + fig.colorbar(heatmap) + if hparams.get('pe_enable') is not None and hparams['pe_enable']: + gt_f0 = (gt_f0 - 100) / (800 - 100) * 80 * (gt_f0 > 0) + pred_f0 = (pred_f0 - 100) / (800 - 100) * 80 * (pred_f0 > 0) + plt.plot(pred_f0, c='white', linewidth=1, alpha=0.6) + plt.plot(gt_f0, c='red', linewidth=1, alpha=0.6) + else: + f0, _ = get_pitch(wav_out, mel, hparams) + f0 = (f0 - 100) / (800 - 100) * 80 * (f0 > 0) + plt.plot(f0, c='white', linewidth=1, alpha=0.6) + if mel2ph is not None and str_phs is not None: + decoded_txt = str_phs.split(" ") + dur = mel2ph_to_dur(torch.LongTensor(mel2ph)[None, :], len(decoded_txt))[0].numpy() + dur = [0] + list(np.cumsum(dur)) + for i in range(len(dur) - 1): + shift = (i % 20) + 1 + plt.text(dur[i], shift, decoded_txt[i]) + plt.hlines(shift, dur[i], dur[i + 1], colors='b' if decoded_txt[i] != '|' else 'black') + plt.vlines(dur[i], 0, 5, colors='b' if decoded_txt[i] != '|' else 'black', + alpha=1, linewidth=1) + plt.tight_layout() + plt.savefig(f'{gen_dir}/plot/{base_fn}.png', format='png', dpi=1000) + plt.close(fig) + + ############## + # utils + ############## + @staticmethod + def expand_f0_ph(f0, mel2ph): + f0 = denorm_f0(f0, None, hparams) + f0 = F.pad(f0, [1, 0]) + f0 = torch.gather(f0, 1, mel2ph) # [B, T_mel] + return f0 + + +if __name__ == '__main__': + FastSpeech2Task.start() diff --git a/tasks/tts/fs2_utils.py b/tasks/tts/fs2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..092550863d2fd72f008cc790bc6d950340e68182 --- /dev/null +++ b/tasks/tts/fs2_utils.py @@ -0,0 +1,173 @@ +import matplotlib + +matplotlib.use('Agg') + +import glob +import importlib +from utils.cwt import get_lf0_cwt +import os +import torch.optim +import torch.utils.data +from utils.indexed_datasets import IndexedDataset +from utils.pitch_utils import norm_interp_f0 +import numpy as np +from tasks.base_task import BaseDataset +import torch +import torch.optim +import torch.utils.data +import utils +import torch.distributions +from utils.hparams import hparams + + +class FastSpeechDataset(BaseDataset): + def __init__(self, prefix, shuffle=False): + super().__init__(shuffle) + self.data_dir = hparams['binary_data_dir'] + self.prefix = prefix + self.hparams = hparams + self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy') + self.indexed_ds = None + # self.name2spk_id={} + + # pitch stats + f0_stats_fn = f'{self.data_dir}/train_f0s_mean_std.npy' + if os.path.exists(f0_stats_fn): + hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = np.load(f0_stats_fn) + hparams['f0_mean'] = float(hparams['f0_mean']) + hparams['f0_std'] = float(hparams['f0_std']) + else: + hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = None, None + + if prefix == 'test': + if hparams['test_input_dir'] != '': + self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir']) + else: + if hparams['num_test_samples'] > 0: + self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids'] + self.sizes = [self.sizes[i] for i in self.avail_idxs] + + if hparams['pitch_type'] == 'cwt': + _, hparams['cwt_scales'] = get_lf0_cwt(np.ones(10)) + + def _get_item(self, index): + if hasattr(self, 'avail_idxs') and self.avail_idxs is not None: + index = self.avail_idxs[index] + if self.indexed_ds is None: + self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}') + return self.indexed_ds[index] + + def __getitem__(self, index): + hparams = self.hparams + item = self._get_item(index) + max_frames = hparams['max_frames'] + spec = torch.Tensor(item['mel'])[:max_frames] + energy = (spec.exp() ** 2).sum(-1).sqrt() + mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None + f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams) + phone = torch.LongTensor(item['phone'][:hparams['max_input_tokens']]) + pitch = torch.LongTensor(item.get("pitch"))[:max_frames] + # print(item.keys(), item['mel'].shape, spec.shape) + sample = { + "id": index, + "item_name": item['item_name'], + "text": item['txt'], + "txt_token": phone, + "mel": spec, + "pitch": pitch, + "energy": energy, + "f0": f0, + "uv": uv, + "mel2ph": mel2ph, + "mel_nonpadding": spec.abs().sum(-1) > 0, + } + if self.hparams['use_spk_embed']: + sample["spk_embed"] = torch.Tensor(item['spk_embed']) + if self.hparams['use_spk_id']: + sample["spk_id"] = item['spk_id'] + # sample['spk_id'] = 0 + # for key in self.name2spk_id.keys(): + # if key in item['item_name']: + # sample['spk_id'] = self.name2spk_id[key] + # break + if self.hparams['pitch_type'] == 'cwt': + cwt_spec = torch.Tensor(item['cwt_spec'])[:max_frames] + f0_mean = item.get('f0_mean', item.get('cwt_mean')) + f0_std = item.get('f0_std', item.get('cwt_std')) + sample.update({"cwt_spec": cwt_spec, "f0_mean": f0_mean, "f0_std": f0_std}) + elif self.hparams['pitch_type'] == 'ph': + f0_phlevel_sum = torch.zeros_like(phone).float().scatter_add(0, mel2ph - 1, f0) + f0_phlevel_num = torch.zeros_like(phone).float().scatter_add( + 0, mel2ph - 1, torch.ones_like(f0)).clamp_min(1) + sample["f0_ph"] = f0_phlevel_sum / f0_phlevel_num + return sample + + def collater(self, samples): + if len(samples) == 0: + return {} + id = torch.LongTensor([s['id'] for s in samples]) + item_names = [s['item_name'] for s in samples] + text = [s['text'] for s in samples] + txt_tokens = utils.collate_1d([s['txt_token'] for s in samples], 0) + f0 = utils.collate_1d([s['f0'] for s in samples], 0.0) + pitch = utils.collate_1d([s['pitch'] for s in samples]) + uv = utils.collate_1d([s['uv'] for s in samples]) + energy = utils.collate_1d([s['energy'] for s in samples], 0.0) + mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \ + if samples[0]['mel2ph'] is not None else None + mels = utils.collate_2d([s['mel'] for s in samples], 0.0) + txt_lengths = torch.LongTensor([s['txt_token'].numel() for s in samples]) + mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples]) + + batch = { + 'id': id, + 'item_name': item_names, + 'nsamples': len(samples), + 'text': text, + 'txt_tokens': txt_tokens, + 'txt_lengths': txt_lengths, + 'mels': mels, + 'mel_lengths': mel_lengths, + 'mel2ph': mel2ph, + 'energy': energy, + 'pitch': pitch, + 'f0': f0, + 'uv': uv, + } + + if self.hparams['use_spk_embed']: + spk_embed = torch.stack([s['spk_embed'] for s in samples]) + batch['spk_embed'] = spk_embed + if self.hparams['use_spk_id']: + spk_ids = torch.LongTensor([s['spk_id'] for s in samples]) + batch['spk_ids'] = spk_ids + if self.hparams['pitch_type'] == 'cwt': + cwt_spec = utils.collate_2d([s['cwt_spec'] for s in samples]) + f0_mean = torch.Tensor([s['f0_mean'] for s in samples]) + f0_std = torch.Tensor([s['f0_std'] for s in samples]) + batch.update({'cwt_spec': cwt_spec, 'f0_mean': f0_mean, 'f0_std': f0_std}) + elif self.hparams['pitch_type'] == 'ph': + batch['f0'] = utils.collate_1d([s['f0_ph'] for s in samples]) + + return batch + + def load_test_inputs(self, test_input_dir, spk_id=0): + inp_wav_paths = glob.glob(f'{test_input_dir}/*.wav') + glob.glob(f'{test_input_dir}/*.mp3') + sizes = [] + items = [] + + binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizerr.BaseBinarizer') + pkg = ".".join(binarizer_cls.split(".")[:-1]) + cls_name = binarizer_cls.split(".")[-1] + binarizer_cls = getattr(importlib.import_module(pkg), cls_name) + binarization_args = hparams['binarization_args'] + + for wav_fn in inp_wav_paths: + item_name = os.path.basename(wav_fn) + ph = txt = tg_fn = '' + wav_fn = wav_fn + encoder = None + item = binarizer_cls.process_item(item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args) + items.append(item) + sizes.append(item['len']) + return items, sizes diff --git a/tasks/tts/pe.py b/tasks/tts/pe.py new file mode 100644 index 0000000000000000000000000000000000000000..3880c80d0820c36e044c00bd38a07fd3cce73323 --- /dev/null +++ b/tasks/tts/pe.py @@ -0,0 +1,155 @@ +import matplotlib +matplotlib.use('Agg') + +import torch +import numpy as np +import os + +from tasks.base_task import BaseDataset +from tasks.tts.fs2 import FastSpeech2Task +from modules.fastspeech.pe import PitchExtractor +import utils +from utils.indexed_datasets import IndexedDataset +from utils.hparams import hparams +from utils.plot import f0_to_figure +from utils.pitch_utils import norm_interp_f0, denorm_f0 + + +class PeDataset(BaseDataset): + def __init__(self, prefix, shuffle=False): + super().__init__(shuffle) + self.data_dir = hparams['binary_data_dir'] + self.prefix = prefix + self.hparams = hparams + self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy') + self.indexed_ds = None + + # pitch stats + f0_stats_fn = f'{self.data_dir}/train_f0s_mean_std.npy' + if os.path.exists(f0_stats_fn): + hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = np.load(f0_stats_fn) + hparams['f0_mean'] = float(hparams['f0_mean']) + hparams['f0_std'] = float(hparams['f0_std']) + else: + hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = None, None + + if prefix == 'test': + if hparams['num_test_samples'] > 0: + self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids'] + self.sizes = [self.sizes[i] for i in self.avail_idxs] + + def _get_item(self, index): + if hasattr(self, 'avail_idxs') and self.avail_idxs is not None: + index = self.avail_idxs[index] + if self.indexed_ds is None: + self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}') + return self.indexed_ds[index] + + def __getitem__(self, index): + hparams = self.hparams + item = self._get_item(index) + max_frames = hparams['max_frames'] + spec = torch.Tensor(item['mel'])[:max_frames] + # mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None + f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams) + pitch = torch.LongTensor(item.get("pitch"))[:max_frames] + # print(item.keys(), item['mel'].shape, spec.shape) + sample = { + "id": index, + "item_name": item['item_name'], + "text": item['txt'], + "mel": spec, + "pitch": pitch, + "f0": f0, + "uv": uv, + # "mel2ph": mel2ph, + # "mel_nonpadding": spec.abs().sum(-1) > 0, + } + return sample + + def collater(self, samples): + if len(samples) == 0: + return {} + id = torch.LongTensor([s['id'] for s in samples]) + item_names = [s['item_name'] for s in samples] + text = [s['text'] for s in samples] + f0 = utils.collate_1d([s['f0'] for s in samples], 0.0) + pitch = utils.collate_1d([s['pitch'] for s in samples]) + uv = utils.collate_1d([s['uv'] for s in samples]) + mels = utils.collate_2d([s['mel'] for s in samples], 0.0) + mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples]) + # mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \ + # if samples[0]['mel2ph'] is not None else None + # mel_nonpaddings = utils.collate_1d([s['mel_nonpadding'].float() for s in samples], 0.0) + + batch = { + 'id': id, + 'item_name': item_names, + 'nsamples': len(samples), + 'text': text, + 'mels': mels, + 'mel_lengths': mel_lengths, + 'pitch': pitch, + # 'mel2ph': mel2ph, + # 'mel_nonpaddings': mel_nonpaddings, + 'f0': f0, + 'uv': uv, + } + return batch + + +class PitchExtractionTask(FastSpeech2Task): + def __init__(self): + super().__init__() + self.dataset_cls = PeDataset + + def build_tts_model(self): + self.model = PitchExtractor(conv_layers=hparams['pitch_extractor_conv_layers']) + + # def build_scheduler(self, optimizer): + # return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5) + def _training_step(self, sample, batch_idx, _): + loss_output = self.run_model(self.model, sample) + total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad]) + loss_output['batch_size'] = sample['mels'].size()[0] + return total_loss, loss_output + + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=True) + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + outputs = utils.tensors_to_scalars(outputs) + if batch_idx < hparams['num_valid_plots']: + self.plot_pitch(batch_idx, model_out, sample) + return outputs + + def run_model(self, model, sample, return_output=False, infer=False): + f0 = sample['f0'] + uv = sample['uv'] + output = model(sample['mels']) + losses = {} + self.add_pitch_loss(output, sample, losses) + if not return_output: + return losses + else: + return losses, output + + def plot_pitch(self, batch_idx, model_out, sample): + gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams) + self.logger.experiment.add_figure( + f'f0_{batch_idx}', + f0_to_figure(gt_f0[0], None, model_out['f0_denorm_pred'][0]), + self.global_step) + + def add_pitch_loss(self, output, sample, losses): + # mel2ph = sample['mel2ph'] # [B, T_s] + mel = sample['mels'] + f0 = sample['f0'] + uv = sample['uv'] + # nonpadding = (mel2ph != 0).float() if hparams['pitch_type'] == 'frame' \ + # else (sample['txt_tokens'] != 0).float() + nonpadding = (mel.abs().sum(-1) > 0).float() # sample['mel_nonpaddings'] + # print(nonpadding[0][-8:], nonpadding.shape) + self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding) \ No newline at end of file diff --git a/tasks/tts/tts.py b/tasks/tts/tts.py new file mode 100644 index 0000000000000000000000000000000000000000..f803c1e738137cb1eca19a1943196abd2884c0a5 --- /dev/null +++ b/tasks/tts/tts.py @@ -0,0 +1,131 @@ +from multiprocessing.pool import Pool + +import matplotlib + +from utils.pl_utils import data_loader +from utils.training_utils import RSQRTSchedule +from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder +from modules.fastspeech.pe import PitchExtractor + +matplotlib.use('Agg') +import os +import numpy as np +from tqdm import tqdm +import torch.distributed as dist + +from tasks.base_task import BaseTask +from utils.hparams import hparams +from utils.text_encoder import TokenTextEncoder +import json + +import torch +import torch.optim +import torch.utils.data +import utils + + + +class TtsTask(BaseTask): + def __init__(self, *args, **kwargs): + self.vocoder = None + self.phone_encoder = self.build_phone_encoder(hparams['binary_data_dir']) + self.padding_idx = self.phone_encoder.pad() + self.eos_idx = self.phone_encoder.eos() + self.seg_idx = self.phone_encoder.seg() + self.saving_result_pool = None + self.saving_results_futures = None + self.stats = {} + super().__init__(*args, **kwargs) + + def build_scheduler(self, optimizer): + return RSQRTSchedule(optimizer) + + def build_optimizer(self, model): + self.optimizer = optimizer = torch.optim.AdamW( + model.parameters(), + lr=hparams['lr']) + return optimizer + + def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None, + required_batch_size_multiple=-1, endless=False, batch_by_size=True): + devices_cnt = torch.cuda.device_count() + if devices_cnt == 0: + devices_cnt = 1 + if required_batch_size_multiple == -1: + required_batch_size_multiple = devices_cnt + + def shuffle_batches(batches): + np.random.shuffle(batches) + return batches + + if max_tokens is not None: + max_tokens *= devices_cnt + if max_sentences is not None: + max_sentences *= devices_cnt + indices = dataset.ordered_indices() + if batch_by_size: + batch_sampler = utils.batch_by_size( + indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + else: + batch_sampler = [] + for i in range(0, len(indices), max_sentences): + batch_sampler.append(indices[i:i + max_sentences]) + + if shuffle: + batches = shuffle_batches(list(batch_sampler)) + if endless: + batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))] + else: + batches = batch_sampler + if endless: + batches = [b for _ in range(1000) for b in batches] + num_workers = dataset.num_workers + if self.trainer.use_ddp: + num_replicas = dist.get_world_size() + rank = dist.get_rank() + batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0] + return torch.utils.data.DataLoader(dataset, + collate_fn=dataset.collater, + batch_sampler=batches, + num_workers=num_workers, + pin_memory=False) + + def build_phone_encoder(self, data_dir): + phone_list_file = os.path.join(data_dir, 'phone_set.json') + + phone_list = json.load(open(phone_list_file)) + return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',') + + def build_optimizer(self, model): + self.optimizer = optimizer = torch.optim.AdamW( + model.parameters(), + lr=hparams['lr']) + return optimizer + + def test_start(self): + self.saving_result_pool = Pool(8) + self.saving_results_futures = [] + self.vocoder: BaseVocoder = get_vocoder_cls(hparams)() + if hparams.get('pe_enable') is not None and hparams['pe_enable']: + self.pe = PitchExtractor().cuda() + utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True) + self.pe.eval() + def test_end(self, outputs): + self.saving_result_pool.close() + [f.get() for f in tqdm(self.saving_results_futures)] + self.saving_result_pool.join() + return {} + + ########## + # utils + ########## + def weights_nonzero_speech(self, target): + # target : B x T x mel + # Assign weight 1.0 to all labels except for padding (id=0). + dim = target.size(-1) + return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim) + +if __name__ == '__main__': + TtsTask.start() diff --git a/usr/__init__.py b/usr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/usr/configs/base.yaml b/usr/configs/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f5c13e63109eada7eaa55fd6b48fe88df5f51739 --- /dev/null +++ b/usr/configs/base.yaml @@ -0,0 +1,24 @@ +task_cls: usr.task.DiffFsTask +pitch_type: frame +timesteps: 100 +dilation_cycle_length: 1 +residual_layers: 20 +residual_channels: 256 +lr: 0.001 +decay_steps: 50000 +keep_bins: 80 +spec_min: [ ] +spec_max: [ ] + +content_cond_steps: [ ] # [ 0, 10000 ] +spk_cond_steps: [ ] # [ 0, 10000 ] +# train and eval +fs2_ckpt: '' +max_updates: 400000 +# max_updates: 200000 +use_gt_dur: true +use_gt_f0: true +gen_tgt_spk_id: -1 +max_sentences: 48 +num_sanity_val_steps: 1 +num_valid_plots: 1 diff --git a/usr/configs/lj_ds_beta6.yaml b/usr/configs/lj_ds_beta6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a0fcebbeede3983efb0e1466ac31cdc2373c8daa --- /dev/null +++ b/usr/configs/lj_ds_beta6.yaml @@ -0,0 +1,43 @@ +base_config: + - configs/tts/lj/fs2.yaml + - ./base.yaml +# spec_min and spec_max are calculated on the training set. +spec_min: [ -4.7574, -4.6783, -4.6431, -4.5832, -4.5390, -4.6771, -4.8089, -4.7672, + -4.5784, -4.7755, -4.7150, -4.8919, -4.8271, -4.7389, -4.6047, -4.7759, + -4.6799, -4.8201, -4.7823, -4.8262, -4.7857, -4.7545, -4.9358, -4.9733, + -5.1134, -5.1395, -4.9016, -4.8434, -5.0189, -4.8460, -5.0529, -4.9510, + -5.0217, -5.0049, -5.1831, -5.1445, -5.1015, -5.0281, -4.9887, -4.9916, + -4.9785, -4.9071, -4.9488, -5.0342, -4.9332, -5.0650, -4.8924, -5.0875, + -5.0483, -5.0848, -5.1809, -5.0677, -5.0015, -5.0792, -5.0636, -5.2413, + -5.1421, -5.1710, -5.3256, -5.0511, -5.1186, -5.0057, -5.0446, -5.1173, + -5.0325, -5.1085, -5.0053, -5.0755, -5.1176, -5.1004, -5.2153, -5.2757, + -5.3025, -5.2867, -5.2918, -5.3328, -5.2731, -5.2985, -5.2400, -5.2211 ] +spec_max: [ -0.5982, -0.0778, 0.1205, 0.2747, 0.4657, 0.5123, 0.5684, 0.7093, + 0.6461, 0.6420, 0.7316, 0.7715, 0.7681, 0.8349, 0.7815, 0.7591, + 0.7910, 0.7433, 0.7352, 0.6869, 0.6854, 0.6623, 0.5353, 0.6492, + 0.6909, 0.6106, 0.5761, 0.5936, 0.5638, 0.4054, 0.4545, 0.3589, + 0.3037, 0.3380, 0.1599, 0.2433, 0.2741, 0.2130, 0.1569, 0.1911, + 0.2324, 0.1586, 0.1221, 0.0341, -0.0558, 0.0553, -0.1153, -0.0933, + -0.1171, -0.0050, -0.1519, -0.1629, -0.0522, -0.0739, -0.2069, -0.2405, + -0.1244, -0.2116, -0.1361, -0.1575, -0.1442, 0.0513, -0.1567, -0.2000, + 0.0086, -0.0698, 0.1385, 0.0941, 0.1864, 0.1225, 0.2176, 0.2566, + 0.1670, 0.1007, 0.1444, 0.0888, 0.1998, 0.2414, 0.2932, 0.3047 ] + +task_cls: usr.diffspeech_task.DiffSpeechTask +vocoder: vocoders.hifigan.HifiGAN +vocoder_ckpt: checkpoints/0414_hifi_lj_1 +num_valid_plots: 10 +use_gt_dur: false +use_gt_f0: false +pitch_type: cwt +pitch_extractor: 'parselmouth' +max_updates: 160000 +lr: 0.001 +timesteps: 100 +K_step: 71 +diff_loss_type: l1 +diff_decoder_type: 'wavenet' +schedule_type: 'linear' +max_beta: 0.06 +fs2_ckpt: checkpoints/fs2_lj_1/model_ckpt_steps_150000.ckpt +save_gt: true \ No newline at end of file diff --git a/usr/configs/m4singer/base.yaml b/usr/configs/m4singer/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..98fc9dc92d3129d6a1a76fb4ec2de23ef87bfa8a --- /dev/null +++ b/usr/configs/m4singer/base.yaml @@ -0,0 +1,77 @@ +base_config: + - usr/configs/popcs_ds_beta6.yaml +binarizer_cls: data_gen.singing.binarize.M4SingerBinarizer + +raw_data_dir: 'data/raw/m4singer' +processed_data_dir: 'xxx' +binary_data_dir: 'data/binary/m4singer' +datasets: [ + 'm4singer', +] +test_prefixes: [ + 'Alto-2#岁月神偷', + 'Alto-2#奇妙能力歌', + 'Tenor-1#一千年以后', + 'Tenor-1#童话', + 'Tenor-2#消愁', + 'Tenor-2#一荤一素', + 'Soprano-1#念奴娇赤壁怀古', + 'Soprano-1#问春', +] +num_spk: 20 +vocoder: vocoders.hifigan.HifiGAN +vocoder_ckpt: checkpoints/m4singer_hifigan +pe_enable: true +pe_ckpt: 'checkpoints/m4singer_pe' + +mel_vmin: -6. +mel_vmax: 1.5 +wav2spec_eps: 1e-6 +audio_sample_rate: 24000 +hop_size: 128 # Hop size. +fft_size: 512 # FFT size. +win_size: 512 # FFT size. +fmin: 30 +fmax: 12000 +min_level_db: -120 + +use_pitch_embed: true +use_spk_embed: false +use_spk_id: true +use_midi: true +use_gt_f0: false +use_gt_dur: false + +lambda_f0: 1.0 +lambda_uv: 1.0 +#lambda_energy: 0.1 +lambda_ph_dur: 1.0 +lambda_sent_dur: 1.0 +lambda_word_dur: 1.0 +predictor_grad: 0.1 +hidden_size: 256 + +binarization_args: + with_wav: false + with_spk_embed: true + with_align: true + +fs2_ckpt: '' + +use_nsf: true + +# config for experiments +max_frames: 5000 +max_tokens: 40000 +max_sentences: 12 +predictor_layers: 5 +rel_pos: true +dur_predictor_layers: 5 +dur_predictor_kernel: 3 + +num_valid_plots: 10 +save_gt: true + +spec_max: [-0.3894500136375427, -0.3796464204788208, -0.2914905250072479, -0.15550297498703003, -0.08502643555402756, 0.10698417574167252, -0.0739326998591423, -0.0541548952460289, 0.15501998364925385, 0.06483431905508041, 0.03054228238761425, -0.013737732544541359, -0.004876468330621719, 0.04368264228105545, 0.13329921662807465, 0.16471388936042786, 0.04605761915445328, -0.05680707097053528, 0.0542571023106575, -0.0076539707370102406, -0.00953489076346159, -0.04434828832745552, 0.001293870504014194, -0.12238839268684387, 0.06418416649103165, 0.02843189612030983, 0.08505241572856903, 0.07062800228595734, 0.00120724702719599, -0.07675088942050934, 0.03785804659128189, 0.04890783503651619, -0.06888376921415329, -0.0839693546295166, -0.17545585334300995, -0.2911079525947571, -0.4238220453262329, -0.262084037065506, -0.3002263605594635, -0.3845032751560211, -0.3906497061252594, -0.6550108790397644, -0.7810799479484558, -0.7503029704093933, -0.7995198965072632, -0.8092347383499146, -0.6196113228797913, -0.6684317588806152, -0.7735874056816101, -0.8324533104896545, -0.9601566791534424, -0.955253541469574, -0.748817503452301, -0.9106167554855347, -0.9707801342010498, -1.053107500076294, -1.0448424816131592, -1.1082794666290283, -1.1296544075012207, -1.071642279624939, -1.1003081798553467, -1.166810154914856, -1.1408926248550415, -1.1330615282058716, -1.1167492866516113, -1.0716774463653564, -1.035891056060791, -1.0092483758926392, -0.9675999879837036, -0.938962996006012, -1.0120564699172974, -0.9777995347976685, -1.029313564300537, -0.9459163546562195, -0.8519706130027771, -0.7751091122627258, -0.7933766841888428, -0.9019735455513, -0.9983296990394592, -1.505873441696167] +spec_min: [-6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0, -6.0] + diff --git a/usr/configs/m4singer/diff.yaml b/usr/configs/m4singer/diff.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fb890454119bdfab8e5dcda8070e2b96d26fa65a --- /dev/null +++ b/usr/configs/m4singer/diff.yaml @@ -0,0 +1,33 @@ +base_config: + - usr/configs/m4singer/base.yaml + +use_midi: true # for midi exp +use_gt_f0: false # for midi exp +use_gt_dur: false # for further midi exp +lambda_ph_dur: 1.0 +lambda_sent_dur: 1.0 +lambda_word_dur: 1.0 +predictor_grad: 0.1 + +fs2_ckpt: 'checkpoints/m4singer_fs2_e2e' +task_cls: usr.diffsinger_task.DiffSingerMIDITask + +# for diffusion schedule +timesteps: 1000 +K_step: 1000 +max_beta: 0.02 +max_tokens: 36000 +max_updates: 900000 +max_sentences: 28 +gaussian_start: True +pndm_speedup: 5 + +use_pitch_embed: false +decay_steps: 100000 +lambda_f0: 0. +lambda_uv: 0. +dilation_cycle_length: 4 +residual_layers: 20 +residual_channels: 256 +rel_pos: true +pe_enable: true diff --git a/usr/configs/m4singer/fs2.yaml b/usr/configs/m4singer/fs2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..060603d7aaa42ace4ff73cbcf3fff5bfdbf3afb9 --- /dev/null +++ b/usr/configs/m4singer/fs2.yaml @@ -0,0 +1,16 @@ +base_config: + - configs/singing/fs2.yaml + - usr/configs/m4singer/base.yaml +task_cls: usr.diffsinger_task.AuxDecoderMIDITask + +max_frames: 5000 +max_tokens: 40000 +rel_pos: true + +use_pitch_embed: false +num_valid_plots: 10 +max_updates: 320000 +lr: 1 +max_sentences: 12 +save_gt: true +pe_enable: true diff --git a/usr/configs/popcs_ds_beta6.yaml b/usr/configs/popcs_ds_beta6.yaml new file mode 100644 index 0000000000000000000000000000000000000000..699e9f7d1d1efee2885b8a3b57eb8080473b073a --- /dev/null +++ b/usr/configs/popcs_ds_beta6.yaml @@ -0,0 +1,70 @@ +base_config: + - configs/tts/fs2.yaml + - configs/singing/base.yaml + - ./base.yaml + +audio_sample_rate: 24000 +hop_size: 128 # Hop size. +fft_size: 512 # FFT size. +win_size: 512 # FFT size. +fmin: 30 +fmax: 12000 +min_level_db: -120 + +binarization_args: + with_wav: true + with_spk_embed: false + with_align: true +raw_data_dir: 'data/raw/popcs' +processed_data_dir: 'data/processed/popcs' +binary_data_dir: 'data/binary/popcs-pmf0' +num_spk: 1 +datasets: [ + 'popcs', +] +test_prefixes: [ + 'popcs-说散就散', + 'popcs-隐形的翅膀', +] + +spec_min: [-6.8276, -7.0270, -6.8142, -7.1429, -7.6669, -7.6000, -7.1148, -6.9640, + -6.8414, -6.6596, -6.6880, -6.7439, -6.7986, -7.4940, -7.7845, -7.6586, + -6.9288, -6.7639, -6.9118, -6.8246, -6.7183, -7.1769, -6.9794, -7.4513, + -7.3422, -7.5623, -6.9610, -6.8158, -6.9595, -6.8403, -6.5688, -6.6356, + -7.0209, -6.5002, -6.7819, -6.5232, -6.6927, -6.5701, -6.5531, -6.7069, + -6.6462, -6.4523, -6.5954, -6.4264, -6.4487, -6.7070, -6.4025, -6.3042, + -6.4008, -6.3857, -6.3903, -6.3094, -6.2491, -6.3518, -6.3566, -6.4168, + -6.2481, -6.3624, -6.2858, -6.2575, -6.3638, -6.4520, -6.1835, -6.2754, + -6.1253, -6.1645, -6.0638, -6.1262, -6.0710, -6.1039, -6.4428, -6.1363, + -6.1054, -6.1252, -6.1797, -6.0235, -6.0758, -5.9453, -6.0213, -6.0446] +spec_max: [ 0.2645, 0.0583, -0.2344, -0.0184, 0.1227, 0.1533, 0.1103, 0.1212, + 0.2421, 0.1809, 0.2134, 0.3161, 0.3301, 0.3289, 0.2667, 0.2421, + 0.2581, 0.2600, 0.1394, 0.1907, 0.1082, 0.1474, 0.1680, 0.2550, + 0.1057, 0.0826, 0.0423, 0.1203, -0.0701, -0.0056, 0.0477, -0.0639, + -0.0272, -0.0728, -0.1648, -0.0855, -0.2652, -0.1998, -0.1547, -0.2167, + -0.4181, -0.5463, -0.4161, -0.4733, -0.6518, -0.5387, -0.4290, -0.4191, + -0.4151, -0.3042, -0.3810, -0.4160, -0.4496, -0.2847, -0.4676, -0.4658, + -0.4931, -0.4885, -0.5547, -0.5481, -0.6948, -0.7968, -0.8455, -0.8392, + -0.8770, -0.9520, -0.8749, -0.7297, -0.8374, -0.8667, -0.7157, -0.9035, + -0.9219, -0.8801, -0.9298, -0.9009, -0.9604, -1.0537, -1.0781, -1.3766] + +task_cls: usr.diffsinger_task.DiffSingerTask +#vocoder: usr.singingvocoder.highgan.HighGAN +#vocoder_ckpt: checkpoints/h_2_model/checkpoint-530000steps.pkl +vocoder: vocoders.hifigan.HifiGAN +vocoder_ckpt: checkpoints/0109_hifigan_bigpopcs_hop128 + +pitch_extractor: 'parselmouth' +# config for experiments +use_spk_embed: false +num_valid_plots: 10 +max_updates: 160000 +lr: 0.001 +timesteps: 100 +K_step: 51 +diff_loss_type: l1 +diff_decoder_type: 'wavenet' +schedule_type: 'linear' +max_beta: 0.06 +fs2_ckpt: '' +use_nsf: true \ No newline at end of file diff --git a/usr/configs/popcs_ds_beta6_offline.yaml b/usr/configs/popcs_ds_beta6_offline.yaml new file mode 100644 index 0000000000000000000000000000000000000000..84d15d9f5e56043dca851c4a3e0ff21bc2babfb3 --- /dev/null +++ b/usr/configs/popcs_ds_beta6_offline.yaml @@ -0,0 +1,12 @@ +base_config: + - ./popcs_ds_beta6.yaml + +fs2_ckpt: checkpoints/popcs_fs2_pmf0_1230/model_ckpt_steps_160000.ckpt # to be infer +num_valid_plots: 0 +task_cls: usr.diffsinger_task.DiffSingerOfflineTask + +# tmp: +#pe_enable: true +#pe_ckpt: '' +vocoder: vocoders.hifigan.HifiGAN +vocoder_ckpt: checkpoints/0109_hifigan_bigpopcs_hop128 \ No newline at end of file diff --git a/usr/configs/popcs_fs2.yaml b/usr/configs/popcs_fs2.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a6b08e9a79049934d5efc30ce6eba37117467019 --- /dev/null +++ b/usr/configs/popcs_fs2.yaml @@ -0,0 +1,44 @@ +base_config: + - configs/singing/fs2.yaml + +audio_sample_rate: 24000 +hop_size: 128 # Hop size. +fft_size: 512 # FFT size. +win_size: 512 # FFT size. +fmin: 30 +fmax: 12000 +min_level_db: -120 + +binarization_args: + with_wav: true + with_spk_embed: false + with_align: true +raw_data_dir: 'data/raw/popcs' +processed_data_dir: 'data/processed/popcs' +binary_data_dir: 'data/binary/popcs-pmf0' +num_spk: 1 +datasets: [ + 'popcs', +] +test_prefixes: [ + 'popcs-说散就散', + 'popcs-隐形的翅膀', +] + +task_cls: tasks.tts.fs2.FastSpeech2Task +#vocoder: usr.singingvocoder.highgan.HighGAN +#vocoder_ckpt: checkpoints/h_2_model/checkpoint-530000steps.pkl +vocoder: vocoders.hifigan.HifiGAN +vocoder_ckpt: checkpoints/0109_hifigan_bigpopcs_hop128 +use_nsf: true + +# config for experiments +max_tokens: 18000 +use_spk_embed: false +num_valid_plots: 10 +max_updates: 160000 +save_gt: true + +# tmp: +#pe_enable: true +#pe_ckpt: '' \ No newline at end of file diff --git a/usr/diff/candidate_decoder.py b/usr/diff/candidate_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..133a51a61942027c255841e2638e296238c07a30 --- /dev/null +++ b/usr/diff/candidate_decoder.py @@ -0,0 +1,96 @@ +from modules.fastspeech.tts_modules import FastspeechDecoder +# from modules.fastspeech.fast_tacotron import DecoderRNN +# from modules.fastspeech.speedy_speech.speedy_speech import ConvBlocks +# from modules.fastspeech.conformer.conformer import ConformerDecoder +import torch +from torch.nn import functional as F +import torch.nn as nn +import math +from utils.hparams import hparams +from .diffusion import Mish +Linear = nn.Linear + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +def Conv1d(*args, **kwargs): + layer = nn.Conv1d(*args, **kwargs) + nn.init.kaiming_normal_(layer.weight) + return layer + + +class FFT(FastspeechDecoder): + def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None): + super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads) + dim = hparams['residual_channels'] + self.input_projection = Conv1d(hparams['audio_num_mel_bins'], dim, 1) + self.diffusion_embedding = SinusoidalPosEmb(dim) + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), + Mish(), + nn.Linear(dim * 4, dim) + ) + self.get_mel_out = Linear(hparams['hidden_size'], 80, bias=True) + self.get_decode_inp = Linear(hparams['hidden_size'] + dim + dim, + hparams['hidden_size']) # hs + dim + 80 -> hs + + def forward(self, spec, diffusion_step, cond, padding_mask=None, attn_mask=None, return_hiddens=False): + """ + :param spec: [B, 1, 80, T] + :param diffusion_step: [B, 1] + :param cond: [B, M, T] + :return: + """ + x = spec[:, 0] + x = self.input_projection(x).permute([0, 2, 1]) # [B, T, residual_channel] + diffusion_step = self.diffusion_embedding(diffusion_step) + diffusion_step = self.mlp(diffusion_step) # [B, dim] + cond = cond.permute([0, 2, 1]) # [B, T, M] + + seq_len = cond.shape[1] # [T_mel] + time_embed = diffusion_step[:, None, :] # [B, 1, dim] + time_embed = time_embed.repeat([1, seq_len, 1]) # # [B, T, dim] + + decoder_inp = torch.cat([x, cond, time_embed], dim=-1) # [B, T, dim + H + dim] + decoder_inp = self.get_decode_inp(decoder_inp) # [B, T, H] + x = decoder_inp + + ''' + Required x: [B, T, C] + :return: [B, T, C] or [L, B, T, C] + ''' + padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask + nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1] + if self.use_pos_embed: + positions = self.pos_embed_alpha * self.embed_positions(x[..., 0]) + x = x + positions + x = F.dropout(x, p=self.dropout, training=self.training) + # B x T x C -> T x B x C + x = x.transpose(0, 1) * nonpadding_mask_TB + hiddens = [] + for layer in self.layers: + x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB + hiddens.append(x) + if self.use_last_norm: + x = self.layer_norm(x) * nonpadding_mask_TB + if return_hiddens: + x = torch.stack(hiddens, 0) # [L, T, B, C] + x = x.transpose(1, 2) # [L, B, T, C] + else: + x = x.transpose(0, 1) # [B, T, C] + + x = self.get_mel_out(x).permute([0, 2, 1]) # [B, 80, T] + return x[:, None, :, :] \ No newline at end of file diff --git a/usr/diff/diffusion.py b/usr/diff/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..c30976ab258feff830c2fa1a2d70876cb1d76eda --- /dev/null +++ b/usr/diff/diffusion.py @@ -0,0 +1,334 @@ +import math +import random +from functools import partial +from inspect import isfunction +from pathlib import Path +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from tqdm import tqdm +from einops import rearrange + +from modules.fastspeech.fs2 import FastSpeech2 +from modules.diffsinger_midi.fs2 import FastSpeech2MIDI +from utils.hparams import hparams + + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def cycle(dl): + while True: + for data in dl: + yield data + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Mish(nn.Module): + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class Upsample(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.ConvTranspose2d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Downsample(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.Conv2d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + self.g = nn.Parameter(torch.zeros(1)) + + def forward(self, x): + return self.fn(x) * self.g + + +# building block modules + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.block = nn.Sequential( + nn.Conv2d(dim, dim_out, 3, padding=1), + nn.GroupNorm(groups, dim_out), + Mish() + ) + + def forward(self, x): + return self.block(x) + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim, groups=8): + super().__init__() + self.mlp = nn.Sequential( + Mish(), + nn.Linear(time_emb_dim, dim_out) + ) + + self.block1 = Block(dim, dim_out) + self.block2 = Block(dim_out, dim_out) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb): + h = self.block1(x) + h += self.mlp(time_emb)[:, :, None, None] + h = self.block2(h) + return h + self.res_conv(x) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + + +# gaussian diffusion trainer class + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = np.linspace(0, steps, steps) + alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return np.clip(betas, a_min=0, a_max=0.999) + + +class GaussianDiffusion(nn.Module): + def __init__(self, phone_encoder, out_dims, denoise_fn, + timesteps=1000, loss_type='l1', betas=None, spec_min=None, spec_max=None): + super().__init__() + self.denoise_fn = denoise_fn + if hparams.get('use_midi') is not None and hparams['use_midi']: + self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims) + else: + self.fs2 = FastSpeech2(phone_encoder, out_dims) + self.fs2.decoder = None + self.mel_bins = out_dims + + if exists(betas): + betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas + else: + betas = cosine_beta_schedule(timesteps) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']]) + self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']]) + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, cond, clip_denoised: bool): + noise_pred = self.denoise_fn(x, t, cond=cond) + x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) + + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def p_losses(self, x_start, t, cond, noise=None, nonpadding=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + x_recon = self.denoise_fn(x_noisy, t, cond) + + if self.loss_type == 'l1': + if nonpadding is not None: + loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean() + else: + # print('are you sure w/o nonpadding?') + loss = (noise - x_recon).abs().mean() + + elif self.loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else: + raise NotImplementedError() + + return loss + + def forward(self, txt_tokens, mel2ph=None, spk_embed=None, + ref_mels=None, f0=None, uv=None, energy=None, infer=False): + b, *_, device = *txt_tokens.shape, txt_tokens.device + ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy, + skip_decoder=True, infer=infer) + cond = ret['decoder_inp'].transpose(1, 2) + if not infer: + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + x = ref_mels + x = self.norm_spec(x) + x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] + nonpadding = (mel2ph != 0).float() + ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding) + else: + t = self.num_timesteps + shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2]) + x = torch.randn(shape, device=device) + for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + x = x[:, 0].transpose(1, 2) + ret['mel_out'] = self.denorm_spec(x) + + return ret + + def norm_spec(self, x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def denorm_spec(self, x): + return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min + + def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph): + return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph) + + def out2mel(self, x): + return x diff --git a/usr/diff/net.py b/usr/diff/net.py new file mode 100644 index 0000000000000000000000000000000000000000..b8811115eafb4f27165cf4d89c67c0d9455aac9d --- /dev/null +++ b/usr/diff/net.py @@ -0,0 +1,130 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from math import sqrt + +from .diffusion import Mish +from utils.hparams import hparams + +Linear = nn.Linear +ConvTranspose2d = nn.ConvTranspose2d + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + def override(self, attrs): + if isinstance(attrs, dict): + self.__dict__.update(**attrs) + elif isinstance(attrs, (list, tuple, set)): + for attr in attrs: + self.override(attr) + elif attrs is not None: + raise NotImplementedError + return self + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +def Conv1d(*args, **kwargs): + layer = nn.Conv1d(*args, **kwargs) + nn.init.kaiming_normal_(layer.weight) + return layer + + +@torch.jit.script +def silu(x): + return x * torch.sigmoid(x) + + +class ResidualBlock(nn.Module): + def __init__(self, encoder_hidden, residual_channels, dilation): + super().__init__() + self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) + self.diffusion_projection = Linear(residual_channels, residual_channels) + self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1) + self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) + + def forward(self, x, conditioner, diffusion_step): + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + conditioner = self.conditioner_projection(conditioner) + y = x + diffusion_step + + y = self.dilated_conv(y) + conditioner + + gate, filter = torch.chunk(y, 2, dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter) + + y = self.output_projection(y) + residual, skip = torch.chunk(y, 2, dim=1) + return (x + residual) / sqrt(2.0), skip + + +class DiffNet(nn.Module): + def __init__(self, in_dims=80): + super().__init__() + self.params = params = AttrDict( + # Model params + encoder_hidden=hparams['hidden_size'], + residual_layers=hparams['residual_layers'], + residual_channels=hparams['residual_channels'], + dilation_cycle_length=hparams['dilation_cycle_length'], + ) + self.input_projection = Conv1d(in_dims, params.residual_channels, 1) + self.diffusion_embedding = SinusoidalPosEmb(params.residual_channels) + dim = params.residual_channels + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), + Mish(), + nn.Linear(dim * 4, dim) + ) + self.residual_layers = nn.ModuleList([ + ResidualBlock(params.encoder_hidden, params.residual_channels, 2 ** (i % params.dilation_cycle_length)) + for i in range(params.residual_layers) + ]) + self.skip_projection = Conv1d(params.residual_channels, params.residual_channels, 1) + self.output_projection = Conv1d(params.residual_channels, in_dims, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, spec, diffusion_step, cond): + """ + + :param spec: [B, 1, M, T] + :param diffusion_step: [B, 1] + :param cond: [B, M, T] + :return: + """ + x = spec[:, 0] + x = self.input_projection(x) # x [B, residual_channel, T] + + x = F.relu(x) + diffusion_step = self.diffusion_embedding(diffusion_step) + diffusion_step = self.mlp(diffusion_step) + skip = [] + for layer_id, layer in enumerate(self.residual_layers): + x, skip_connection = layer(x, cond, diffusion_step) + skip.append(skip_connection) + + x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) # [B, 80, T] + return x[:, None, :, :] diff --git a/usr/diff/shallow_diffusion_tts.py b/usr/diff/shallow_diffusion_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..835c57efffae63df1a70165d8fb10e507070435a --- /dev/null +++ b/usr/diff/shallow_diffusion_tts.py @@ -0,0 +1,320 @@ +import math +import random +from collections import deque +from functools import partial +from inspect import isfunction +from pathlib import Path +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from tqdm import tqdm +from einops import rearrange + +from modules.fastspeech.fs2 import FastSpeech2 +from modules.diffsinger_midi.fs2 import FastSpeech2MIDI +from utils.hparams import hparams + + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +# gaussian diffusion trainer class + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + + +def linear_beta_schedule(timesteps, max_beta=hparams.get('max_beta', 0.01)): + """ + linear schedule + """ + betas = np.linspace(1e-4, max_beta, timesteps) + return betas + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = np.linspace(0, steps, steps) + alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return np.clip(betas, a_min=0, a_max=0.999) + + +beta_schedule = { + "cosine": cosine_beta_schedule, + "linear": linear_beta_schedule, +} + + +class GaussianDiffusion(nn.Module): + def __init__(self, phone_encoder, out_dims, denoise_fn, + timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None, spec_min=None, spec_max=None): + super().__init__() + self.denoise_fn = denoise_fn + if hparams.get('use_midi') is not None and hparams['use_midi']: + self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims) + else: + self.fs2 = FastSpeech2(phone_encoder, out_dims) + self.mel_bins = out_dims + + if exists(betas): + betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas + else: + if 'schedule_type' in hparams.keys(): + betas = beta_schedule[hparams['schedule_type']](timesteps) + else: + betas = cosine_beta_schedule(timesteps) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.K_step = K_step + self.loss_type = loss_type + + self.noise_list = deque(maxlen=4) + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']]) + self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']]) + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, cond, clip_denoised: bool): + noise_pred = self.denoise_fn(x, t, cond=cond) + x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) + + if clip_denoised: + x_recon.clamp_(-1., 1.) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False): + """ + Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778). + """ + + def get_x_pred(x, noise_t, t): + a_t = extract(self.alphas_cumprod, t, x.shape) + a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape) + a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt() + + x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t) + x_pred = x + x_delta + + return x_pred + + noise_list = self.noise_list + noise_pred = self.denoise_fn(x, t, cond=cond) + + if len(noise_list) == 0: + x_pred = get_x_pred(x, noise_pred, t) + noise_pred_prev = self.denoise_fn(x_pred, max(t-interval, 0), cond=cond) + noise_pred_prime = (noise_pred + noise_pred_prev) / 2 + elif len(noise_list) == 1: + noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2 + elif len(noise_list) == 2: + noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12 + elif len(noise_list) >= 3: + noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24 + + x_prev = get_x_pred(x, noise_pred_prime, t) + noise_list.append(noise_pred) + + return x_prev + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def p_losses(self, x_start, t, cond, noise=None, nonpadding=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + x_recon = self.denoise_fn(x_noisy, t, cond) + + if self.loss_type == 'l1': + if nonpadding is not None: + loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean() + else: + # print('are you sure w/o nonpadding?') + loss = (noise - x_recon).abs().mean() + + elif self.loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else: + raise NotImplementedError() + + return loss + + def forward(self, txt_tokens, mel2ph=None, spk_embed=None, + ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs): + b, *_, device = *txt_tokens.shape, txt_tokens.device + ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy, + skip_decoder=(not infer), infer=infer, **kwargs) + cond = ret['decoder_inp'].transpose(1, 2) + + if not infer: + t = torch.randint(0, self.K_step, (b,), device=device).long() + x = ref_mels + x = self.norm_spec(x) + x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] + ret['diff_loss'] = self.p_losses(x, t, cond) + # nonpadding = (mel2ph != 0).float() + # ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding) + else: + ret['fs2_mel'] = ret['mel_out'] + fs2_mels = ret['mel_out'] + t = self.K_step + fs2_mels = self.norm_spec(fs2_mels) + fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :] + + x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long()) + if hparams.get('gaussian_start') is not None and hparams['gaussian_start']: + print('===> gaussion start.') + shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2]) + x = torch.randn(shape, device=device) + + if hparams.get('pndm_speedup'): + self.noise_list = deque(maxlen=4) + iteration_interval = hparams['pndm_speedup'] + for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step', + total=t // iteration_interval): + x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), iteration_interval, + cond) + else: + for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + x = x[:, 0].transpose(1, 2) + if mel2ph is not None: # for singing + ret['mel_out'] = self.denorm_spec(x) * ((mel2ph > 0).float()[:, :, None]) + else: + ret['mel_out'] = self.denorm_spec(x) + return ret + + def norm_spec(self, x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def denorm_spec(self, x): + return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min + + def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph): + return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph) + + def out2mel(self, x): + return x + + +class OfflineGaussianDiffusion(GaussianDiffusion): + def forward(self, txt_tokens, mel2ph=None, spk_embed=None, + ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs): + b, *_, device = *txt_tokens.shape, txt_tokens.device + + ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy, + skip_decoder=True, infer=True, **kwargs) + cond = ret['decoder_inp'].transpose(1, 2) + fs2_mels = ref_mels[1] + ref_mels = ref_mels[0] + + if not infer: + t = torch.randint(0, self.K_step, (b,), device=device).long() + x = ref_mels + x = self.norm_spec(x) + x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] + ret['diff_loss'] = self.p_losses(x, t, cond) + else: + t = self.K_step + fs2_mels = self.norm_spec(fs2_mels) + fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :] + + x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long()) + + if hparams.get('gaussian_start') is not None and hparams['gaussian_start']: + print('===> gaussion start.') + shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2]) + x = torch.randn(shape, device=device) + for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + x = x[:, 0].transpose(1, 2) + ret['mel_out'] = self.denorm_spec(x) + return ret diff --git a/usr/diffsinger_task.py b/usr/diffsinger_task.py new file mode 100644 index 0000000000000000000000000000000000000000..aa500a7c410d5c86bdd7fa339f73dc49817ada3b --- /dev/null +++ b/usr/diffsinger_task.py @@ -0,0 +1,490 @@ +import torch + +import utils +from utils.hparams import hparams +from .diff.net import DiffNet +from .diff.shallow_diffusion_tts import GaussianDiffusion, OfflineGaussianDiffusion +from .diffspeech_task import DiffSpeechTask +from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder +from modules.fastspeech.pe import PitchExtractor +from modules.fastspeech.fs2 import FastSpeech2 +from modules.diffsinger_midi.fs2 import FastSpeech2MIDI +from modules.fastspeech.tts_modules import mel2ph_to_dur + +from usr.diff.candidate_decoder import FFT +from utils.pitch_utils import denorm_f0 +from tasks.tts.fs2_utils import FastSpeechDataset +from tasks.tts.fs2 import FastSpeech2Task + +import numpy as np +import os +import torch.nn.functional as F + +DIFF_DECODERS = { + 'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']), + 'fft': lambda hp: FFT( + hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']), +} + + +class DiffSingerTask(DiffSpeechTask): + def __init__(self): + super(DiffSingerTask, self).__init__() + self.dataset_cls = FastSpeechDataset + self.vocoder: BaseVocoder = get_vocoder_cls(hparams)() + if hparams.get('pe_enable') is not None and hparams['pe_enable']: + self.pe = PitchExtractor().cuda() + utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True) + self.pe.eval() + + def build_tts_model(self): + # import torch + # from tqdm import tqdm + # v_min = torch.ones([80]) * 100 + # v_max = torch.ones([80]) * -100 + # for i, ds in enumerate(tqdm(self.dataset_cls('train'))): + # v_max = torch.max(torch.max(ds['mel'].reshape(-1, 80), 0)[0], v_max) + # v_min = torch.min(torch.min(ds['mel'].reshape(-1, 80), 0)[0], v_min) + # if i % 100 == 0: + # print(i, v_min, v_max) + # print('final', v_min, v_max) + mel_bins = hparams['audio_num_mel_bins'] + self.model = GaussianDiffusion( + phone_encoder=self.phone_encoder, + out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams), + timesteps=hparams['timesteps'], + K_step=hparams['K_step'], + loss_type=hparams['diff_loss_type'], + spec_min=hparams['spec_min'], spec_max=hparams['spec_max'], + ) + if hparams['fs2_ckpt'] != '': + utils.load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True) + # self.model.fs2.decoder = None + #for k, v in self.model.fs2.named_parameters(): + # v.requires_grad = False + + def validation_step(self, sample, batch_idx): + outputs = {} + txt_tokens = sample['txt_tokens'] # [B, T_t] + + target = sample['mels'] # [B, T_s, 80] + energy = sample['energy'] + # fs2_mel = sample['fs2_mels'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + mel2ph = sample['mel2ph'] + f0 = sample['f0'] + uv = sample['uv'] + + outputs['losses'] = {} + + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False) + + + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + outputs = utils.tensors_to_scalars(outputs) + if batch_idx < hparams['num_valid_plots']: + model_out = self.model( + txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, energy=energy, ref_mels=None, infer=True) + + if hparams.get('pe_enable') is not None and hparams['pe_enable']: + gt_f0 = self.pe(sample['mels'])['f0_denorm_pred'] # pe predict from GT mel + pred_f0 = self.pe(model_out['mel_out'])['f0_denorm_pred'] # pe predict from Pred mel + else: + gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams) + pred_f0 = model_out.get('f0_denorm') + self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True, gt_f0=gt_f0, f0=pred_f0) + self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'], name=f'diffmel_{batch_idx}') + self.plot_mel(batch_idx, sample['mels'], model_out['fs2_mel'], name=f'fs2mel_{batch_idx}') + return outputs + + +class ShallowDiffusionOfflineDataset(FastSpeechDataset): + def __getitem__(self, index): + sample = super(ShallowDiffusionOfflineDataset, self).__getitem__(index) + item = self._get_item(index) + + if self.prefix != 'train' and hparams['fs2_ckpt'] != '': + fs2_ckpt = os.path.dirname(hparams['fs2_ckpt']) + item_name = item['item_name'] + fs2_mel = torch.Tensor(np.load(f'{fs2_ckpt}/P_mels_npy/{item_name}.npy')) # ~M generated by FFT-singer. + sample['fs2_mel'] = fs2_mel + return sample + + def collater(self, samples): + batch = super(ShallowDiffusionOfflineDataset, self).collater(samples) + if self.prefix != 'train' and hparams['fs2_ckpt'] != '': + batch['fs2_mels'] = utils.collate_2d([s['fs2_mel'] for s in samples], 0.0) + return batch + + +class DiffSingerOfflineTask(DiffSingerTask): + def __init__(self): + super(DiffSingerOfflineTask, self).__init__() + self.dataset_cls = ShallowDiffusionOfflineDataset + + def build_tts_model(self): + mel_bins = hparams['audio_num_mel_bins'] + self.model = OfflineGaussianDiffusion( + phone_encoder=self.phone_encoder, + out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams), + timesteps=hparams['timesteps'], + K_step=hparams['K_step'], + loss_type=hparams['diff_loss_type'], + spec_min=hparams['spec_min'], spec_max=hparams['spec_max'], + ) + # if hparams['fs2_ckpt'] != '': + # utils.load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True) + # self.model.fs2.decoder = None + + def run_model(self, model, sample, return_output=False, infer=False): + txt_tokens = sample['txt_tokens'] # [B, T_t] + target = sample['mels'] # [B, T_s, 80] + mel2ph = sample['mel2ph'] # [B, T_s] + f0 = sample['f0'] + uv = sample['uv'] + energy = sample['energy'] + fs2_mel = None #sample['fs2_mels'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + if hparams['pitch_type'] == 'cwt': + cwt_spec = sample[f'cwt_spec'] + f0_mean = sample['f0_mean'] + f0_std = sample['f0_std'] + sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph) + + output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, + ref_mels=[target, fs2_mel], f0=f0, uv=uv, energy=energy, infer=infer) + + losses = {} + if 'diff_loss' in output: + losses['mel'] = output['diff_loss'] + # self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses) + # if hparams['use_pitch_embed']: + # self.add_pitch_loss(output, sample, losses) + if hparams['use_energy_embed']: + self.add_energy_loss(output['energy_pred'], energy, losses) + + if not return_output: + return losses + else: + return losses, output + + def validation_step(self, sample, batch_idx): + outputs = {} + txt_tokens = sample['txt_tokens'] # [B, T_t] + + target = sample['mels'] # [B, T_s, 80] + energy = sample['energy'] + # fs2_mel = sample['fs2_mels'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + mel2ph = sample['mel2ph'] + f0 = sample['f0'] + uv = sample['uv'] + + outputs['losses'] = {} + + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False) + + + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + outputs = utils.tensors_to_scalars(outputs) + if batch_idx < hparams['num_valid_plots']: + fs2_mel = sample['fs2_mels'] + model_out = self.model( + txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, energy=energy, + ref_mels=[None, fs2_mel], infer=True) + if hparams.get('pe_enable') is not None and hparams['pe_enable']: + gt_f0 = self.pe(sample['mels'])['f0_denorm_pred'] # pe predict from GT mel + pred_f0 = self.pe(model_out['mel_out'])['f0_denorm_pred'] # pe predict from Pred mel + else: + gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams) + pred_f0 = model_out.get('f0_denorm') + self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True, gt_f0=gt_f0, f0=pred_f0) + self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'], name=f'diffmel_{batch_idx}') + self.plot_mel(batch_idx, sample['mels'], fs2_mel, name=f'fs2mel_{batch_idx}') + return outputs + + def test_step(self, sample, batch_idx): + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + txt_tokens = sample['txt_tokens'] + energy = sample['energy'] + if hparams['profile_infer']: + pass + else: + mel2ph, uv, f0 = None, None, None + if hparams['use_gt_dur']: + mel2ph = sample['mel2ph'] + if hparams['use_gt_f0']: + f0 = sample['f0'] + uv = sample['uv'] + fs2_mel = sample['fs2_mels'] + outputs = self.model( + txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, ref_mels=[None, fs2_mel], energy=energy, + infer=True) + sample['outputs'] = self.model.out2mel(outputs['mel_out']) + sample['mel2ph_pred'] = outputs['mel2ph'] + + if hparams.get('pe_enable') is not None and hparams['pe_enable']: + sample['f0'] = self.pe(sample['mels'])['f0_denorm_pred'] # pe predict from GT mel + sample['f0_pred'] = self.pe(sample['outputs'])['f0_denorm_pred'] # pe predict from Pred mel + else: + sample['f0'] = denorm_f0(sample['f0'], sample['uv'], hparams) + sample['f0_pred'] = outputs.get('f0_denorm') + return self.after_infer(sample) + + +class MIDIDataset(FastSpeechDataset): + def __getitem__(self, index): + sample = super(MIDIDataset, self).__getitem__(index) + item = self._get_item(index) + sample['f0_midi'] = torch.FloatTensor(item['f0_midi']) + sample['pitch_midi'] = torch.LongTensor(item['pitch_midi'])[:hparams['max_frames']] + + return sample + + def collater(self, samples): + batch = super(MIDIDataset, self).collater(samples) + batch['f0_midi'] = utils.collate_1d([s['f0_midi'] for s in samples], 0.0) + batch['pitch_midi'] = utils.collate_1d([s['pitch_midi'] for s in samples], 0) + # print((batch['pitch_midi'] == f0_to_coarse(batch['f0_midi'])).all()) + return batch + + +class M4SingerDataset(FastSpeechDataset): + def __getitem__(self, index): + sample = super(M4SingerDataset, self).__getitem__(index) + item = self._get_item(index) + sample['pitch_midi'] = torch.LongTensor(item['pitch_midi']) + sample['midi_dur'] = torch.FloatTensor(item['midi_dur']) + sample['is_slur'] = torch.LongTensor(item['is_slur']) + sample['word_boundary'] = torch.LongTensor(item['word_boundary']) + return sample + + def collater(self, samples): + batch = super(M4SingerDataset, self).collater(samples) + batch['pitch_midi'] = utils.collate_1d([s['pitch_midi'] for s in samples], 0) + batch['midi_dur'] = utils.collate_1d([s['midi_dur'] for s in samples], 0) + batch['is_slur'] = utils.collate_1d([s['is_slur'] for s in samples], 0) + batch['word_boundary'] = utils.collate_1d([s['word_boundary'] for s in samples], 0) + return batch + + +class DiffSingerMIDITask(DiffSingerTask): + def __init__(self): + super(DiffSingerMIDITask, self).__init__() + # self.dataset_cls = MIDIDataset + self.dataset_cls = M4SingerDataset + + def run_model(self, model, sample, return_output=False, infer=False): + txt_tokens = sample['txt_tokens'] # [B, T_t] + target = sample['mels'] # [B, T_s, 80] + # mel2ph = sample['mel2ph'] if hparams['use_gt_dur'] else None # [B, T_s] + mel2ph = sample['mel2ph'] + if hparams.get('switch_midi2f0_step') is not None and self.global_step > hparams['switch_midi2f0_step']: + f0 = None + uv = None + else: + f0 = sample['f0'] + uv = sample['uv'] + energy = sample['energy'] + + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + if hparams['pitch_type'] == 'cwt': + cwt_spec = sample[f'cwt_spec'] + f0_mean = sample['f0_mean'] + f0_std = sample['f0_std'] + sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph) + + output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, + ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer, pitch_midi=sample['pitch_midi'], + midi_dur=sample.get('midi_dur'), is_slur=sample.get('is_slur')) + + losses = {} + if 'diff_loss' in output: + losses['mel'] = output['diff_loss'] + self.add_dur_loss(output['dur'], mel2ph, txt_tokens, sample['word_boundary'], losses=losses) + if hparams['use_pitch_embed']: + self.add_pitch_loss(output, sample, losses) + if hparams['use_energy_embed']: + self.add_energy_loss(output['energy_pred'], energy, losses) + if not return_output: + return losses + else: + return losses, output + + def validation_step(self, sample, batch_idx): + outputs = {} + txt_tokens = sample['txt_tokens'] # [B, T_t] + + target = sample['mels'] # [B, T_s, 80] + energy = sample['energy'] + # fs2_mel = sample['fs2_mels'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + mel2ph = sample['mel2ph'] + + outputs['losses'] = {} + + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False) + + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + outputs = utils.tensors_to_scalars(outputs) + if batch_idx % 20 == 0 and batch_idx // 20 < hparams['num_valid_plots']: + model_out = self.model( + txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=None, uv=None, energy=energy, ref_mels=None, infer=True, + pitch_midi=sample['pitch_midi'], midi_dur=sample.get('midi_dur'), is_slur=sample.get('is_slur')) + + if hparams.get('pe_enable') is not None and hparams['pe_enable']: + gt_f0 = self.pe(sample['mels'])['f0_denorm_pred'] # pe predict from GT mel + pred_f0 = self.pe(model_out['mel_out'])['f0_denorm_pred'] # pe predict from Pred mel + else: + gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams) + pred_f0 = model_out.get('f0_denorm') + self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True, gt_f0=gt_f0, f0=pred_f0) + self.plot_mel(batch_idx, sample['mels'], model_out['mel_out'], name=f'diffmel_{batch_idx}') + self.plot_mel(batch_idx, sample['mels'], model_out['fs2_mel'], name=f'fs2mel_{batch_idx}') + if hparams['use_pitch_embed']: + self.plot_pitch(batch_idx, sample, model_out) + return outputs + + def add_dur_loss(self, dur_pred, mel2ph, txt_tokens, wdb, losses=None): + """ + :param dur_pred: [B, T], float, log scale + :param mel2ph: [B, T] + :param txt_tokens: [B, T] + :param losses: + :return: + """ + B, T = txt_tokens.shape + nonpadding = (txt_tokens != 0).float() + dur_gt = mel2ph_to_dur(mel2ph, T).float() * nonpadding + is_sil = torch.zeros_like(txt_tokens).bool() + for p in self.sil_ph: + is_sil = is_sil | (txt_tokens == self.phone_encoder.encode(p)[0]) + is_sil = is_sil.float() # [B, T_txt] + + # phone duration loss + if hparams['dur_loss'] == 'mse': + losses['pdur'] = F.mse_loss(dur_pred, (dur_gt + 1).log(), reduction='none') + losses['pdur'] = (losses['pdur'] * nonpadding).sum() / nonpadding.sum() + dur_pred = (dur_pred.exp() - 1).clamp(min=0) + else: + raise NotImplementedError + + # use linear scale for sent and word duration + if hparams['lambda_word_dur'] > 0: + idx = F.pad(wdb.cumsum(axis=1), (1, 0))[:, :-1] + # word_dur_g = dur_gt.new_zeros([B, idx.max() + 1]).scatter_(1, idx, midi_dur) # midi_dur can be implied by add gt-ph_dur + word_dur_p = dur_pred.new_zeros([B, idx.max() + 1]).scatter_add(1, idx, dur_pred) + word_dur_g = dur_gt.new_zeros([B, idx.max() + 1]).scatter_add(1, idx, dur_gt) + wdur_loss = F.mse_loss((word_dur_p + 1).log(), (word_dur_g + 1).log(), reduction='none') + word_nonpadding = (word_dur_g > 0).float() + wdur_loss = (wdur_loss * word_nonpadding).sum() / word_nonpadding.sum() + losses['wdur'] = wdur_loss * hparams['lambda_word_dur'] + if hparams['lambda_sent_dur'] > 0: + sent_dur_p = dur_pred.sum(-1) + sent_dur_g = dur_gt.sum(-1) + sdur_loss = F.mse_loss((sent_dur_p + 1).log(), (sent_dur_g + 1).log(), reduction='mean') + losses['sdur'] = sdur_loss.mean() * hparams['lambda_sent_dur'] + + +class AuxDecoderMIDITask(FastSpeech2Task): + def __init__(self): + super().__init__() + # self.dataset_cls = MIDIDataset + self.dataset_cls = M4SingerDataset + + def build_tts_model(self): + if hparams.get('use_midi') is not None and hparams['use_midi']: + self.model = FastSpeech2MIDI(self.phone_encoder) + else: + self.model = FastSpeech2(self.phone_encoder) + + def run_model(self, model, sample, return_output=False): + txt_tokens = sample['txt_tokens'] # [B, T_t] + target = sample['mels'] # [B, T_s, 80] + mel2ph = sample['mel2ph'] # [B, T_s] + f0 = sample['f0'] + uv = sample['uv'] + energy = sample['energy'] + + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + if hparams['pitch_type'] == 'cwt': + cwt_spec = sample[f'cwt_spec'] + f0_mean = sample['f0_mean'] + f0_std = sample['f0_std'] + sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph) + + output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, + ref_mels=target, f0=f0, uv=uv, energy=energy, infer=False, pitch_midi=sample['pitch_midi'], + midi_dur=sample.get('midi_dur'), is_slur=sample.get('is_slur')) + + losses = {} + self.add_mel_loss(output['mel_out'], target, losses) + self.add_dur_loss(output['dur'], mel2ph, txt_tokens, sample['word_boundary'], losses=losses) + if hparams['use_pitch_embed']: + self.add_pitch_loss(output, sample, losses) + if hparams['use_energy_embed']: + self.add_energy_loss(output['energy_pred'], energy, losses) + if not return_output: + return losses + else: + return losses, output + + def add_dur_loss(self, dur_pred, mel2ph, txt_tokens, wdb, losses=None): + """ + :param dur_pred: [B, T], float, log scale + :param mel2ph: [B, T] + :param txt_tokens: [B, T] + :param losses: + :return: + """ + B, T = txt_tokens.shape + nonpadding = (txt_tokens != 0).float() + dur_gt = mel2ph_to_dur(mel2ph, T).float() * nonpadding + is_sil = torch.zeros_like(txt_tokens).bool() + for p in self.sil_ph: + is_sil = is_sil | (txt_tokens == self.phone_encoder.encode(p)[0]) + is_sil = is_sil.float() # [B, T_txt] + + # phone duration loss + if hparams['dur_loss'] == 'mse': + losses['pdur'] = F.mse_loss(dur_pred, (dur_gt + 1).log(), reduction='none') + losses['pdur'] = (losses['pdur'] * nonpadding).sum() / nonpadding.sum() + dur_pred = (dur_pred.exp() - 1).clamp(min=0) + else: + raise NotImplementedError + + # use linear scale for sent and word duration + if hparams['lambda_word_dur'] > 0: + idx = F.pad(wdb.cumsum(axis=1), (1, 0))[:, :-1] + # word_dur_g = dur_gt.new_zeros([B, idx.max() + 1]).scatter_(1, idx, midi_dur) # midi_dur can be implied by add gt-ph_dur + word_dur_p = dur_pred.new_zeros([B, idx.max() + 1]).scatter_add(1, idx, dur_pred) + word_dur_g = dur_gt.new_zeros([B, idx.max() + 1]).scatter_add(1, idx, dur_gt) + wdur_loss = F.mse_loss((word_dur_p + 1).log(), (word_dur_g + 1).log(), reduction='none') + word_nonpadding = (word_dur_g > 0).float() + wdur_loss = (wdur_loss * word_nonpadding).sum() / word_nonpadding.sum() + losses['wdur'] = wdur_loss * hparams['lambda_word_dur'] + if hparams['lambda_sent_dur'] > 0: + sent_dur_p = dur_pred.sum(-1) + sent_dur_g = dur_gt.sum(-1) + sdur_loss = F.mse_loss((sent_dur_p + 1).log(), (sent_dur_g + 1).log(), reduction='mean') + losses['sdur'] = sdur_loss.mean() * hparams['lambda_sent_dur'] + + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True) + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + mel_out = self.model.out2mel(model_out['mel_out']) + outputs = utils.tensors_to_scalars(outputs) + # if sample['mels'].shape[0] == 1: + # self.add_laplace_var(mel_out, sample['mels'], outputs) + if batch_idx < hparams['num_valid_plots']: + self.plot_mel(batch_idx, sample['mels'], mel_out) + self.plot_dur(batch_idx, sample, model_out) + if hparams['use_pitch_embed']: + self.plot_pitch(batch_idx, sample, model_out) + return outputs \ No newline at end of file diff --git a/usr/diffspeech_task.py b/usr/diffspeech_task.py new file mode 100644 index 0000000000000000000000000000000000000000..05c313f94d07e91a94996a30bedd27b28c8cb04a --- /dev/null +++ b/usr/diffspeech_task.py @@ -0,0 +1,122 @@ +import torch + +import utils +from utils.hparams import hparams +from .diff.net import DiffNet +from .diff.shallow_diffusion_tts import GaussianDiffusion +from .task import DiffFsTask +from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder +from utils.pitch_utils import denorm_f0 +from tasks.tts.fs2_utils import FastSpeechDataset + +DIFF_DECODERS = { + 'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']), +} + + +class DiffSpeechTask(DiffFsTask): + def __init__(self): + super(DiffSpeechTask, self).__init__() + self.dataset_cls = FastSpeechDataset + self.vocoder: BaseVocoder = get_vocoder_cls(hparams)() + + def build_tts_model(self): + mel_bins = hparams['audio_num_mel_bins'] + self.model = GaussianDiffusion( + phone_encoder=self.phone_encoder, + out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams), + timesteps=hparams['timesteps'], + K_step=hparams['K_step'], + loss_type=hparams['diff_loss_type'], + spec_min=hparams['spec_min'], spec_max=hparams['spec_max'], + ) + if hparams['fs2_ckpt'] != '': + utils.load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True) + # self.model.fs2.decoder = None + for k, v in self.model.fs2.named_parameters(): + if not 'predictor' in k: + v.requires_grad = False + + def build_optimizer(self, model): + self.optimizer = optimizer = torch.optim.AdamW( + filter(lambda p: p.requires_grad, model.parameters()), + lr=hparams['lr'], + betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']), + weight_decay=hparams['weight_decay']) + return optimizer + + def run_model(self, model, sample, return_output=False, infer=False): + txt_tokens = sample['txt_tokens'] # [B, T_t] + target = sample['mels'] # [B, T_s, 80] + # mel2ph = sample['mel2ph'] if hparams['use_gt_dur'] else None # [B, T_s] + mel2ph = sample['mel2ph'] + f0 = sample['f0'] + uv = sample['uv'] + energy = sample['energy'] + # fs2_mel = sample['fs2_mels'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + if hparams['pitch_type'] == 'cwt': + cwt_spec = sample[f'cwt_spec'] + f0_mean = sample['f0_mean'] + f0_std = sample['f0_std'] + sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph) + + output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, + ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer) + + losses = {} + if 'diff_loss' in output: + losses['mel'] = output['diff_loss'] + self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses) + if hparams['use_pitch_embed']: + self.add_pitch_loss(output, sample, losses) + if hparams['use_energy_embed']: + self.add_energy_loss(output['energy_pred'], energy, losses) + if not return_output: + return losses + else: + return losses, output + + def validation_step(self, sample, batch_idx): + outputs = {} + txt_tokens = sample['txt_tokens'] # [B, T_t] + + energy = sample['energy'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + mel2ph = sample['mel2ph'] + f0 = sample['f0'] + uv = sample['uv'] + + outputs['losses'] = {} + + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False) + + + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + outputs = utils.tensors_to_scalars(outputs) + if batch_idx < hparams['num_valid_plots']: + # model_out = self.model( + # txt_tokens, spk_embed=spk_embed, mel2ph=None, f0=None, uv=None, energy=None, ref_mels=None, infer=True) + # self.plot_mel(batch_idx, model_out['mel_out'], model_out['fs2_mel'], name=f'diffspeech_vs_fs2_{batch_idx}') + model_out = self.model( + txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, energy=energy, ref_mels=None, infer=True) + gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams) + self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True, gt_f0=gt_f0, f0=model_out.get('f0_denorm')) + self.plot_mel(batch_idx, sample['mels'], model_out['mel_out']) + return outputs + + ############ + # validation plots + ############ + def plot_wav(self, batch_idx, gt_wav, wav_out, is_mel=False, gt_f0=None, f0=None, name=None): + gt_wav = gt_wav[0].cpu().numpy() + wav_out = wav_out[0].cpu().numpy() + gt_f0 = gt_f0[0].cpu().numpy() + f0 = f0[0].cpu().numpy() + if is_mel: + gt_wav = self.vocoder.spec2wav(gt_wav, f0=gt_f0) + wav_out = self.vocoder.spec2wav(wav_out, f0=f0) + self.logger.experiment.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step) + self.logger.experiment.add_audio(f'wav_{batch_idx}', wav_out, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step) + diff --git a/usr/task.py b/usr/task.py new file mode 100644 index 0000000000000000000000000000000000000000..3e34db228938351dfe65f6b5da1318457b00be20 --- /dev/null +++ b/usr/task.py @@ -0,0 +1,84 @@ +import torch + +import utils +from .diff.diffusion import GaussianDiffusion +from .diff.net import DiffNet +from tasks.tts.fs2 import FastSpeech2Task +from utils.hparams import hparams + + +DIFF_DECODERS = { + 'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']), +} + + +class DiffFsTask(FastSpeech2Task): + def build_tts_model(self): + mel_bins = hparams['audio_num_mel_bins'] + self.model = GaussianDiffusion( + phone_encoder=self.phone_encoder, + out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams), + timesteps=hparams['timesteps'], + loss_type=hparams['diff_loss_type'], + spec_min=hparams['spec_min'], spec_max=hparams['spec_max'], + ) + + def run_model(self, model, sample, return_output=False, infer=False): + txt_tokens = sample['txt_tokens'] # [B, T_t] + target = sample['mels'] # [B, T_s, 80] + mel2ph = sample['mel2ph'] # [B, T_s] + f0 = sample['f0'] + uv = sample['uv'] + energy = sample['energy'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + if hparams['pitch_type'] == 'cwt': + cwt_spec = sample[f'cwt_spec'] + f0_mean = sample['f0_mean'] + f0_std = sample['f0_std'] + sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph) + + output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, + ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer) + + losses = {} + if 'diff_loss' in output: + losses['mel'] = output['diff_loss'] + self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses) + if hparams['use_pitch_embed']: + self.add_pitch_loss(output, sample, losses) + if hparams['use_energy_embed']: + self.add_energy_loss(output['energy_pred'], energy, losses) + if not return_output: + return losses + else: + return losses, output + + def _training_step(self, sample, batch_idx, _): + log_outputs = self.run_model(self.model, sample) + total_loss = sum([v for v in log_outputs.values() if isinstance(v, torch.Tensor) and v.requires_grad]) + log_outputs['batch_size'] = sample['txt_tokens'].size()[0] + log_outputs['lr'] = self.scheduler.get_lr()[0] + return total_loss, log_outputs + + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False) + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + outputs = utils.tensors_to_scalars(outputs) + if batch_idx < hparams['num_valid_plots']: + _, model_out = self.run_model(self.model, sample, return_output=True, infer=True) + self.plot_mel(batch_idx, sample['mels'], model_out['mel_out']) + return outputs + + def build_scheduler(self, optimizer): + return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5) + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx): + if optimizer is None: + return + optimizer.step() + optimizer.zero_grad() + if self.scheduler is not None: + self.scheduler.step(self.global_step // hparams['accumulate_grad_batches']) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea5c5a67e038c2213247dfb905942882c090a77 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,250 @@ +import glob +import logging +import re +import time +from collections import defaultdict +import os +import sys +import shutil +import types +import numpy as np +import torch +import torch.nn.functional as F +import torch.distributed as dist +from torch import nn + + +def tensors_to_scalars(metrics): + new_metrics = {} + for k, v in metrics.items(): + if isinstance(v, torch.Tensor): + v = v.item() + if type(v) is dict: + v = tensors_to_scalars(v) + new_metrics[k] = v + return new_metrics + + +class AvgrageMeter(object): + + def __init__(self): + self.reset() + + def reset(self): + self.avg = 0 + self.sum = 0 + self.cnt = 0 + + def update(self, val, n=1): + self.sum += val * n + self.cnt += n + self.avg = self.sum / self.cnt + + +def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1): + """Convert a list of 1d tensors into a padded 2d tensor.""" + size = max(v.size(0) for v in values) if max_len is None else max_len + res = values[0].new(len(values), size).fill_(pad_idx) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + if shift_right: + dst[1:] = src[:-1] + dst[0] = shift_id + else: + dst.copy_(src) + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) + return res + + +def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None): + """Convert a list of 2d tensors into a padded 3d tensor.""" + size = max(v.size(0) for v in values) if max_len is None else max_len + res = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + if shift_right: + dst[1:] = src[:-1] + else: + dst.copy_(src) + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) + return res + + +def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + if len(batch) == 0: + return 0 + if len(batch) == max_sentences: + return 1 + if num_tokens > max_tokens: + return 1 + return 0 + + +def batch_by_size( + indices, num_tokens_fn, max_tokens=None, max_sentences=None, + required_batch_size_multiple=1, distributed=False +): + """ + Yield mini-batches of indices bucketed by size. Batches may contain + sequences of different lengths. + + Args: + indices (List[int]): ordered list of dataset indices + num_tokens_fn (callable): function that returns the number of tokens at + a given index + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). + """ + max_tokens = max_tokens if max_tokens is not None else sys.maxsize + max_sentences = max_sentences if max_sentences is not None else sys.maxsize + bsz_mult = required_batch_size_multiple + + if isinstance(indices, types.GeneratorType): + indices = np.fromiter(indices, dtype=np.int64, count=-1) + + sample_len = 0 + sample_lens = [] + batch = [] + batches = [] + for i in range(len(indices)): + idx = indices[i] + num_tokens = num_tokens_fn(idx) + sample_lens.append(num_tokens) + sample_len = max(sample_len, num_tokens) + assert sample_len <= max_tokens, ( + "sentence at index {} of size {} exceeds max_tokens " + "limit of {}!".format(idx, sample_len, max_tokens) + ) + num_tokens = (len(batch) + 1) * sample_len + + if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + mod_len = max( + bsz_mult * (len(batch) // bsz_mult), + len(batch) % bsz_mult, + ) + batches.append(batch[:mod_len]) + batch = batch[mod_len:] + sample_lens = sample_lens[mod_len:] + sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 + batch.append(idx) + if len(batch) > 0: + batches.append(batch) + return batches + + +def make_positions(tensor, padding_idx): + """Replace non-padding symbols with their position numbers. + + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. + mask = tensor.ne(padding_idx).int() + return ( + torch.cumsum(mask, dim=1).type_as(mask) * mask + ).long() + padding_idx + + +def softmax(x, dim): + return F.softmax(x, dim=dim, dtype=torch.float32) + + +def unpack_dict_to_list(samples): + samples_ = [] + bsz = samples.get('outputs').size(0) + for i in range(bsz): + res = {} + for k, v in samples.items(): + try: + res[k] = v[i] + except: + pass + samples_.append(res) + return samples_ + + +def load_ckpt(cur_model, ckpt_base_dir, prefix_in_ckpt='model', force=True, strict=True): + if os.path.isfile(ckpt_base_dir): + base_dir = os.path.dirname(ckpt_base_dir) + checkpoint_path = [ckpt_base_dir] + else: + base_dir = ckpt_base_dir + checkpoint_path = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key= + lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0])) + if len(checkpoint_path) > 0: + checkpoint_path = checkpoint_path[-1] + state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"] + state_dict = {k[len(prefix_in_ckpt) + 1:]: v for k, v in state_dict.items() + if k.startswith(f'{prefix_in_ckpt}.')} + if not strict: + cur_model_state_dict = cur_model.state_dict() + unmatched_keys = [] + for key, param in state_dict.items(): + if key in cur_model_state_dict: + new_param = cur_model_state_dict[key] + if new_param.shape != param.shape: + unmatched_keys.append(key) + print("| Unmatched keys: ", key, new_param.shape, param.shape) + for key in unmatched_keys: + del state_dict[key] + cur_model.load_state_dict(state_dict, strict=strict) + print(f"| load '{prefix_in_ckpt}' from '{checkpoint_path}'.") + else: + e_msg = f"| ckpt not found in {base_dir}." + if force: + assert False, e_msg + else: + print(e_msg) + + +def remove_padding(x, padding_idx=0): + if x is None: + return None + assert len(x.shape) in [1, 2] + if len(x.shape) == 2: # [T, H] + return x[np.abs(x).sum(-1) != padding_idx] + elif len(x.shape) == 1: # [T] + return x[x != padding_idx] + + +class Timer: + timer_map = {} + + def __init__(self, name, print_time=False): + if name not in Timer.timer_map: + Timer.timer_map[name] = 0 + self.name = name + self.print_time = print_time + + def __enter__(self): + self.t = time.time() + + def __exit__(self, exc_type, exc_val, exc_tb): + Timer.timer_map[self.name] += time.time() - self.t + if self.print_time: + print(self.name, Timer.timer_map[self.name]) + + +def print_arch(model, model_name='model'): + print(f"| {model_name} Arch: ", model) + num_params(model, model_name=model_name) + + +def num_params(model, print_out=True, model_name="model"): + parameters = filter(lambda p: p.requires_grad, model.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + if print_out: + print(f'| {model_name} Trainable Parameters: %.3fM' % parameters) + return parameters diff --git a/utils/audio.py b/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..aba7ab926cf793d085bbdc70c97f376001183fe1 --- /dev/null +++ b/utils/audio.py @@ -0,0 +1,56 @@ +import subprocess +import matplotlib + +matplotlib.use('Agg') +import librosa +import librosa.filters +import numpy as np +from scipy import signal +from scipy.io import wavfile + + +def save_wav(wav, path, sr, norm=False): + if norm: + wav = wav / np.abs(wav).max() + wav *= 32767 + # proposed by @dsmiller + wavfile.write(path, sr, wav.astype(np.int16)) + + +def get_hop_size(hparams): + hop_size = hparams['hop_size'] + if hop_size is None: + assert hparams['frame_shift_ms'] is not None + hop_size = int(hparams['frame_shift_ms'] / 1000 * hparams['audio_sample_rate']) + return hop_size + + +########################################################################################### +def _stft(y, hparams): + return librosa.stft(y=y, n_fft=hparams['fft_size'], hop_length=get_hop_size(hparams), + win_length=hparams['win_size'], pad_mode='constant') + + +def _istft(y, hparams): + return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams['win_size']) + + +def librosa_pad_lr(x, fsize, fshift, pad_sides=1): + '''compute right padding (final frame) or both sides padding (first and final frames) + ''' + assert pad_sides in (1, 2) + # return int(fsize // 2) + pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0] + if pad_sides == 1: + return 0, pad + else: + return pad // 2, pad // 2 + pad % 2 + + +# Conversions +def amp_to_db(x): + return 20 * np.log10(np.maximum(1e-5, x)) + + +def normalize(S, hparams): + return (S - hparams['min_level_db']) / -hparams['min_level_db'] diff --git a/utils/cwt.py b/utils/cwt.py new file mode 100644 index 0000000000000000000000000000000000000000..1a08461b9e422aac614438e6240b7355b8e4bb2c --- /dev/null +++ b/utils/cwt.py @@ -0,0 +1,146 @@ +import librosa +import numpy as np +from pycwt import wavelet +from scipy.interpolate import interp1d + + +def load_wav(wav_file, sr): + wav, _ = librosa.load(wav_file, sr=sr, mono=True) + return wav + + +def convert_continuos_f0(f0): + '''CONVERT F0 TO CONTINUOUS F0 + Args: + f0 (ndarray): original f0 sequence with the shape (T) + Return: + (ndarray): continuous f0 with the shape (T) + ''' + # get uv information as binary + f0 = np.copy(f0) + uv = np.float32(f0 != 0) + + # get start and end of f0 + if (f0 == 0).all(): + print("| all of the f0 values are 0.") + return uv, f0 + start_f0 = f0[f0 != 0][0] + end_f0 = f0[f0 != 0][-1] + + # padding start and end of f0 sequence + start_idx = np.where(f0 == start_f0)[0][0] + end_idx = np.where(f0 == end_f0)[0][-1] + f0[:start_idx] = start_f0 + f0[end_idx:] = end_f0 + + # get non-zero frame index + nz_frames = np.where(f0 != 0)[0] + + # perform linear interpolation + f = interp1d(nz_frames, f0[nz_frames]) + cont_f0 = f(np.arange(0, f0.shape[0])) + + return uv, cont_f0 + + +def get_cont_lf0(f0, frame_period=5.0): + uv, cont_f0_lpf = convert_continuos_f0(f0) + # cont_f0_lpf = low_pass_filter(cont_f0_lpf, int(1.0 / (frame_period * 0.001)), cutoff=20) + cont_lf0_lpf = np.log(cont_f0_lpf) + return uv, cont_lf0_lpf + + +def get_lf0_cwt(lf0): + ''' + input: + signal of shape (N) + output: + Wavelet_lf0 of shape(10, N), scales of shape(10) + ''' + mother = wavelet.MexicanHat() + dt = 0.005 + dj = 1 + s0 = dt * 2 + J = 9 + + Wavelet_lf0, scales, _, _, _, _ = wavelet.cwt(np.squeeze(lf0), dt, dj, s0, J, mother) + # Wavelet.shape => (J + 1, len(lf0)) + Wavelet_lf0 = np.real(Wavelet_lf0).T + return Wavelet_lf0, scales + + +def norm_scale(Wavelet_lf0): + Wavelet_lf0_norm = np.zeros((Wavelet_lf0.shape[0], Wavelet_lf0.shape[1])) + mean = Wavelet_lf0.mean(0)[None, :] + std = Wavelet_lf0.std(0)[None, :] + Wavelet_lf0_norm = (Wavelet_lf0 - mean) / std + return Wavelet_lf0_norm, mean, std + + +def normalize_cwt_lf0(f0, mean, std): + uv, cont_lf0_lpf = get_cont_lf0(f0) + cont_lf0_norm = (cont_lf0_lpf - mean) / std + Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_norm) + Wavelet_lf0_norm, _, _ = norm_scale(Wavelet_lf0) + + return Wavelet_lf0_norm + + +def get_lf0_cwt_norm(f0s, mean, std): + uvs = list() + cont_lf0_lpfs = list() + cont_lf0_lpf_norms = list() + Wavelet_lf0s = list() + Wavelet_lf0s_norm = list() + scaless = list() + + means = list() + stds = list() + for f0 in f0s: + uv, cont_lf0_lpf = get_cont_lf0(f0) + cont_lf0_lpf_norm = (cont_lf0_lpf - mean) / std + + Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm) # [560,10] + Wavelet_lf0_norm, mean_scale, std_scale = norm_scale(Wavelet_lf0) # [560,10],[1,10],[1,10] + + Wavelet_lf0s_norm.append(Wavelet_lf0_norm) + uvs.append(uv) + cont_lf0_lpfs.append(cont_lf0_lpf) + cont_lf0_lpf_norms.append(cont_lf0_lpf_norm) + Wavelet_lf0s.append(Wavelet_lf0) + scaless.append(scales) + means.append(mean_scale) + stds.append(std_scale) + + return Wavelet_lf0s_norm, scaless, means, stds + + +def inverse_cwt_torch(Wavelet_lf0, scales): + import torch + b = ((torch.arange(0, len(scales)).float().to(Wavelet_lf0.device)[None, None, :] + 1 + 2.5) ** (-2.5)) + lf0_rec = Wavelet_lf0 * b + lf0_rec_sum = lf0_rec.sum(-1) + lf0_rec_sum = (lf0_rec_sum - lf0_rec_sum.mean(-1, keepdim=True)) / lf0_rec_sum.std(-1, keepdim=True) + return lf0_rec_sum + + +def inverse_cwt(Wavelet_lf0, scales): + b = ((np.arange(0, len(scales))[None, None, :] + 1 + 2.5) ** (-2.5)) + lf0_rec = Wavelet_lf0 * b + lf0_rec_sum = lf0_rec.sum(-1) + lf0_rec_sum = (lf0_rec_sum - lf0_rec_sum.mean(-1, keepdims=True)) / lf0_rec_sum.std(-1, keepdims=True) + return lf0_rec_sum + + +def cwt2f0(cwt_spec, mean, std, cwt_scales): + assert len(mean.shape) == 1 and len(std.shape) == 1 and len(cwt_spec.shape) == 3 + import torch + if isinstance(cwt_spec, torch.Tensor): + f0 = inverse_cwt_torch(cwt_spec, cwt_scales) + f0 = f0 * std[:, None] + mean[:, None] + f0 = f0.exp() # [B, T] + else: + f0 = inverse_cwt(cwt_spec, cwt_scales) + f0 = f0 * std[:, None] + mean[:, None] + f0 = np.exp(f0) # [B, T] + return f0 diff --git a/utils/hparams.py b/utils/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..920cee4d730025580d696a2365a7f208e71eff63 --- /dev/null +++ b/utils/hparams.py @@ -0,0 +1,122 @@ +import argparse +import os +import yaml + +global_print_hparams = True +hparams = {} + + +class Args: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + self.__setattr__(k, v) + + +def override_config(old_config: dict, new_config: dict): + for k, v in new_config.items(): + if isinstance(v, dict) and k in old_config: + override_config(old_config[k], new_config[k]) + else: + old_config[k] = v + + +def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True): + if config == '': + parser = argparse.ArgumentParser(description='neural music') + parser.add_argument('--config', type=str, default='', + help='location of the data corpus') + parser.add_argument('--exp_name', type=str, default='', help='exp_name') + parser.add_argument('--hparams', type=str, default='', + help='location of the data corpus') + parser.add_argument('--infer', action='store_true', help='infer') + parser.add_argument('--validate', action='store_true', help='validate') + parser.add_argument('--reset', action='store_true', help='reset hparams') + parser.add_argument('--debug', action='store_true', help='debug') + args, unknown = parser.parse_known_args() + else: + args = Args(config=config, exp_name=exp_name, hparams=hparams_str, + infer=False, validate=False, reset=False, debug=False) + args_work_dir = '' + if args.exp_name != '': + args.work_dir = args.exp_name + args_work_dir = f'checkpoints/{args.work_dir}' + + config_chains = [] + loaded_config = set() + + def load_config(config_fn): # deep first + with open(config_fn) as f: + hparams_ = yaml.safe_load(f) + loaded_config.add(config_fn) + if 'base_config' in hparams_: + ret_hparams = {} + if not isinstance(hparams_['base_config'], list): + hparams_['base_config'] = [hparams_['base_config']] + for c in hparams_['base_config']: + if c not in loaded_config: + if c.startswith('.'): + c = f'{os.path.dirname(config_fn)}/{c}' + c = os.path.normpath(c) + override_config(ret_hparams, load_config(c)) + override_config(ret_hparams, hparams_) + else: + ret_hparams = hparams_ + config_chains.append(config_fn) + return ret_hparams + + global hparams + assert args.config != '' or args_work_dir != '' + saved_hparams = {} + if args_work_dir != 'checkpoints/': + ckpt_config_path = f'{args_work_dir}/config.yaml' + if os.path.exists(ckpt_config_path): + try: + with open(ckpt_config_path) as f: + saved_hparams.update(yaml.safe_load(f)) + except: + pass + if args.config == '': + args.config = ckpt_config_path + + hparams_ = {} + + hparams_.update(load_config(args.config)) + + if not args.reset: + hparams_.update(saved_hparams) + hparams_['work_dir'] = args_work_dir + + if args.hparams != "": + for new_hparam in args.hparams.split(","): + k, v = new_hparam.split("=") + if v in ['True', 'False'] or type(hparams_[k]) == bool: + hparams_[k] = eval(v) + else: + hparams_[k] = type(hparams_[k])(v) + + if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer: + os.makedirs(hparams_['work_dir'], exist_ok=True) + with open(ckpt_config_path, 'w') as f: + yaml.safe_dump(hparams_, f) + + hparams_['infer'] = args.infer + hparams_['debug'] = args.debug + hparams_['validate'] = args.validate + global global_print_hparams + if global_hparams: + hparams.clear() + hparams.update(hparams_) + + if print_hparams and global_print_hparams and global_hparams: + print('| Hparams chains: ', config_chains) + print('| Hparams: ') + for i, (k, v) in enumerate(sorted(hparams_.items())): + print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "") + print("") + global_print_hparams = False + # print(hparams_.keys()) + if hparams.get('exp_name') is None: + hparams['exp_name'] = args.exp_name + if hparams_.get('exp_name') is None: + hparams_['exp_name'] = args.exp_name + return hparams_ diff --git a/utils/indexed_datasets.py b/utils/indexed_datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..e15632be30d6296a3c9aa80a1f351058003698b3 --- /dev/null +++ b/utils/indexed_datasets.py @@ -0,0 +1,71 @@ +import pickle +from copy import deepcopy + +import numpy as np + + +class IndexedDataset: + def __init__(self, path, num_cache=1): + super().__init__() + self.path = path + self.data_file = None + self.data_offsets = np.load(f"{path}.idx", allow_pickle=True).item()['offsets'] + self.data_file = open(f"{path}.data", 'rb', buffering=-1) + self.cache = [] + self.num_cache = num_cache + + def check_index(self, i): + if i < 0 or i >= len(self.data_offsets) - 1: + raise IndexError('index out of range') + + def __del__(self): + if self.data_file: + self.data_file.close() + + def __getitem__(self, i): + self.check_index(i) + if self.num_cache > 0: + for c in self.cache: + if c[0] == i: + return c[1] + self.data_file.seek(self.data_offsets[i]) + b = self.data_file.read(self.data_offsets[i + 1] - self.data_offsets[i]) + item = pickle.loads(b) + if self.num_cache > 0: + self.cache = [(i, deepcopy(item))] + self.cache[:-1] + return item + + def __len__(self): + return len(self.data_offsets) - 1 + +class IndexedDatasetBuilder: + def __init__(self, path): + self.path = path + self.out_file = open(f"{path}.data", 'wb') + self.byte_offsets = [0] + + def add_item(self, item): + s = pickle.dumps(item) + bytes = self.out_file.write(s) + self.byte_offsets.append(self.byte_offsets[-1] + bytes) + + def finalize(self): + self.out_file.close() + np.save(open(f"{self.path}.idx", 'wb'), {'offsets': self.byte_offsets}) + + +if __name__ == "__main__": + import random + from tqdm import tqdm + ds_path = '/tmp/indexed_ds_example' + size = 100 + items = [{"a": np.random.normal(size=[10000, 10]), + "b": np.random.normal(size=[10000, 10])} for i in range(size)] + builder = IndexedDatasetBuilder(ds_path) + for i in tqdm(range(size)): + builder.add_item(items[i]) + builder.finalize() + ds = IndexedDataset(ds_path) + for i in tqdm(range(10000)): + idx = random.randint(0, size - 1) + assert (ds[idx]['a'] == items[idx]['a']).all() diff --git a/utils/multiprocess_utils.py b/utils/multiprocess_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..24876c4ca777f09d1c1e1b75674cd7aaf37a75a6 --- /dev/null +++ b/utils/multiprocess_utils.py @@ -0,0 +1,47 @@ +import os +import traceback +from multiprocessing import Queue, Process + + +def chunked_worker(worker_id, map_func, args, results_queue=None, init_ctx_func=None): + ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None + for job_idx, arg in args: + try: + if ctx is not None: + res = map_func(*arg, ctx=ctx) + else: + res = map_func(*arg) + results_queue.put((job_idx, res)) + except: + traceback.print_exc() + results_queue.put((job_idx, None)) + +def chunked_multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, q_max_size=1000): + args = zip(range(len(args)), args) + args = list(args) + n_jobs = len(args) + if num_workers is None: + num_workers = int(os.getenv('N_PROC', os.cpu_count())) + results_queues = [] + if ordered: + for i in range(num_workers): + results_queues.append(Queue(maxsize=q_max_size // num_workers)) + else: + results_queue = Queue(maxsize=q_max_size) + for i in range(num_workers): + results_queues.append(results_queue) + workers = [] + for i in range(num_workers): + args_worker = args[i::num_workers] + p = Process(target=chunked_worker, args=( + i, map_func, args_worker, results_queues[i], init_ctx_func), daemon=True) + workers.append(p) + p.start() + for n_finished in range(n_jobs): + results_queue = results_queues[n_finished % num_workers] + job_idx, res = results_queue.get() + assert job_idx == n_finished or not ordered, (job_idx, n_finished) + yield res + for w in workers: + w.join() + w.close() diff --git a/utils/pitch_utils.py b/utils/pitch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f7fd166abd3a03bac5909e498669b482447435cf --- /dev/null +++ b/utils/pitch_utils.py @@ -0,0 +1,76 @@ +######### +# world +########## +import librosa +import numpy as np +import torch + +gamma = 0 +mcepInput = 3 # 0 for dB, 3 for magnitude +alpha = 0.45 +en_floor = 10 ** (-80 / 20) +FFT_SIZE = 2048 + + +f0_bin = 256 +f0_max = 1100.0 +f0_min = 50.0 +f0_mel_min = 1127 * np.log(1 + f0_min / 700) +f0_mel_max = 1127 * np.log(1 + f0_max / 700) + + +def f0_to_coarse(f0): + is_torch = isinstance(f0, torch.Tensor) + f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700) + f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1 + + f0_mel[f0_mel <= 1] = 1 + f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1 + f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(np.int) + assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min()) + return f0_coarse + + +def norm_f0(f0, uv, hparams): + is_torch = isinstance(f0, torch.Tensor) + if hparams['pitch_norm'] == 'standard': + f0 = (f0 - hparams['f0_mean']) / hparams['f0_std'] + if hparams['pitch_norm'] == 'log': + f0 = torch.log2(f0) if is_torch else np.log2(f0) + if uv is not None and hparams['use_uv']: + f0[uv > 0] = 0 + return f0 + + +def norm_interp_f0(f0, hparams): + is_torch = isinstance(f0, torch.Tensor) + if is_torch: + device = f0.device + f0 = f0.data.cpu().numpy() + uv = f0 == 0 + f0 = norm_f0(f0, uv, hparams) + if sum(uv) == len(f0): + f0[uv] = 0 + elif sum(uv) > 0: + f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv]) + uv = torch.FloatTensor(uv) + f0 = torch.FloatTensor(f0) + if is_torch: + f0 = f0.to(device) + return f0, uv + + +def denorm_f0(f0, uv, hparams, pitch_padding=None, min=None, max=None): + if hparams['pitch_norm'] == 'standard': + f0 = f0 * hparams['f0_std'] + hparams['f0_mean'] + if hparams['pitch_norm'] == 'log': + f0 = 2 ** f0 + if min is not None: + f0 = f0.clamp(min=min) + if max is not None: + f0 = f0.clamp(max=max) + if uv is not None and hparams['use_uv']: + f0[uv > 0] = 0 + if pitch_padding is not None: + f0[pitch_padding] = 0 + return f0 diff --git a/utils/pl_utils.py b/utils/pl_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..adab6edb140bcc29272f09b31dbfe1727dc5e68e --- /dev/null +++ b/utils/pl_utils.py @@ -0,0 +1,1619 @@ +import matplotlib +from torch.nn import DataParallel +from torch.nn.parallel import DistributedDataParallel + +matplotlib.use('Agg') +import glob +import itertools +import subprocess +import threading +import traceback + +from pytorch_lightning.callbacks import GradientAccumulationScheduler +from pytorch_lightning.callbacks import ModelCheckpoint + +from functools import wraps +from torch.cuda._utils import _get_device_index +import numpy as np +import torch.optim +import torch.utils.data +import copy +import logging +import os +import re +import sys +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import tqdm +from torch.optim.optimizer import Optimizer + + +def get_a_var(obj): # pragma: no cover + if isinstance(obj, torch.Tensor): + return obj + + if isinstance(obj, list) or isinstance(obj, tuple): + for result in map(get_a_var, obj): + if isinstance(result, torch.Tensor): + return result + if isinstance(obj, dict): + for result in map(get_a_var, obj.items()): + if isinstance(result, torch.Tensor): + return result + return None + + +def data_loader(fn): + """ + Decorator to make any fx with this use the lazy property + :param fn: + :return: + """ + + wraps(fn) + attr_name = '_lazy_' + fn.__name__ + + def _get_data_loader(self): + try: + value = getattr(self, attr_name) + except AttributeError: + try: + value = fn(self) # Lazy evaluation, done only once. + if ( + value is not None and + not isinstance(value, list) and + fn.__name__ in ['test_dataloader', 'val_dataloader'] + ): + value = [value] + except AttributeError as e: + # Guard against AttributeError suppression. (Issue #142) + traceback.print_exc() + error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e) + raise RuntimeError(error) from e + setattr(self, attr_name, value) # Memoize evaluation. + return value + + return _get_data_loader + + +def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: no cover + r"""Applies each `module` in :attr:`modules` in parallel on arguments + contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) + on each of :attr:`devices`. + + Args: + modules (Module): modules to be parallelized + inputs (tensor): inputs to the modules + devices (list of int or torch.device): CUDA devices + + :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and + :attr:`devices` (if given) should all have same length. Moreover, each + element of :attr:`inputs` can either be a single object as the only argument + to a module, or a collection of positional arguments. + """ + assert len(modules) == len(inputs) + if kwargs_tup is not None: + assert len(modules) == len(kwargs_tup) + else: + kwargs_tup = ({},) * len(modules) + if devices is not None: + assert len(modules) == len(devices) + else: + devices = [None] * len(modules) + devices = list(map(lambda x: _get_device_index(x, True), devices)) + lock = threading.Lock() + results = {} + grad_enabled = torch.is_grad_enabled() + + def _worker(i, module, input, kwargs, device=None): + torch.set_grad_enabled(grad_enabled) + if device is None: + device = get_a_var(input).get_device() + try: + with torch.cuda.device(device): + # this also avoids accidental slicing of `input` if it is a Tensor + if not isinstance(input, (list, tuple)): + input = (input,) + + # --------------- + # CHANGE + if module.training: + output = module.training_step(*input, **kwargs) + + elif module.testing: + output = module.test_step(*input, **kwargs) + + else: + output = module.validation_step(*input, **kwargs) + # --------------- + + with lock: + results[i] = output + except Exception as e: + with lock: + results[i] = e + + # make sure each module knows what training state it's in... + # fixes weird bug where copies are out of sync + root_m = modules[0] + for m in modules[1:]: + m.training = root_m.training + m.testing = root_m.testing + + if len(modules) > 1: + threads = [threading.Thread(target=_worker, + args=(i, module, input, kwargs, device)) + for i, (module, input, kwargs, device) in + enumerate(zip(modules, inputs, kwargs_tup, devices))] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + else: + _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) + + outputs = [] + for i in range(len(inputs)): + output = results[i] + if isinstance(output, Exception): + raise output + outputs.append(output) + return outputs + + +def _find_tensors(obj): # pragma: no cover + r""" + Recursively find all tensors contained in the specified object. + """ + if isinstance(obj, torch.Tensor): + return [obj] + if isinstance(obj, (list, tuple)): + return itertools.chain(*map(_find_tensors, obj)) + if isinstance(obj, dict): + return itertools.chain(*map(_find_tensors, obj.values())) + return [] + + +class DDP(DistributedDataParallel): + """ + Override the forward call in lightning so it goes to training and validation step respectively + """ + + def parallel_apply(self, replicas, inputs, kwargs): + return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) + + def forward(self, *inputs, **kwargs): # pragma: no cover + self._sync_params() + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + # -------------- + # LIGHTNING MOD + # -------------- + # normal + # output = self.module(*inputs[0], **kwargs[0]) + # lightning + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + else: + # normal + output = self.module(*inputs, **kwargs) + + if torch.is_grad_enabled(): + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + return output + + +class DP(DataParallel): + """ + Override the forward call in lightning so it goes to training and validation step respectively + """ + + def forward(self, *inputs, **kwargs): + if not self.device_ids: + return self.module(*inputs, **kwargs) + + for t in itertools.chain(self.module.parameters(), self.module.buffers()): + if t.device != self.src_device_obj: + raise RuntimeError("module must have its parameters and buffers " + "on device {} (device_ids[0]) but found one of " + "them on device: {}".format(self.src_device_obj, t.device)) + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + # lightning + if self.module.training: + return self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + return self.module.test_step(*inputs[0], **kwargs[0]) + else: + return self.module.validation_step(*inputs[0], **kwargs[0]) + + replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) + outputs = self.parallel_apply(replicas, inputs, kwargs) + return self.gather(outputs, self.output_device) + + def parallel_apply(self, replicas, inputs, kwargs): + return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) + + +class GradientAccumulationScheduler: + def __init__(self, scheduling: dict): + if scheduling == {}: # empty dict error + raise TypeError("Empty dict cannot be interpreted correct") + + for key in scheduling.keys(): + if not isinstance(key, int) or not isinstance(scheduling[key], int): + raise TypeError("All epoches and accumulation factor must be integers") + + minimal_epoch = min(scheduling.keys()) + if minimal_epoch < 1: + msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct" + raise IndexError(msg) + elif minimal_epoch != 1: # if user didnt define first epoch accumulation factor + scheduling.update({1: 1}) + + self.scheduling = scheduling + self.epochs = sorted(scheduling.keys()) + + def on_epoch_begin(self, epoch, trainer): + epoch += 1 # indexing epochs from 1 + for i in reversed(range(len(self.epochs))): + if epoch >= self.epochs[i]: + trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i]) + break + + +class LatestModelCheckpoint(ModelCheckpoint): + def __init__(self, filepath, monitor='val_loss', verbose=0, num_ckpt_keep=5, + save_weights_only=False, mode='auto', period=1, prefix='model', save_best=True): + super(ModelCheckpoint, self).__init__() + self.monitor = monitor + self.verbose = verbose + self.filepath = filepath + os.makedirs(filepath, exist_ok=True) + self.num_ckpt_keep = num_ckpt_keep + self.save_best = save_best + self.save_weights_only = save_weights_only + self.period = period + self.epochs_since_last_check = 0 + self.prefix = prefix + self.best_k_models = {} + # {filename: monitor} + self.kth_best_model = '' + self.save_top_k = 1 + self.task = None + if mode == 'min': + self.monitor_op = np.less + self.best = np.Inf + self.mode = 'min' + elif mode == 'max': + self.monitor_op = np.greater + self.best = -np.Inf + self.mode = 'max' + else: + if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): + self.monitor_op = np.greater + self.best = -np.Inf + self.mode = 'max' + else: + self.monitor_op = np.less + self.best = np.Inf + self.mode = 'min' + if os.path.exists(f'{self.filepath}/best_valid.npy'): + self.best = np.load(f'{self.filepath}/best_valid.npy')[0] + + def get_all_ckpts(self): + return sorted(glob.glob(f'{self.filepath}/{self.prefix}_ckpt_steps_*.ckpt'), + key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0])) + + def on_epoch_end(self, epoch, logs=None): + logs = logs or {} + self.epochs_since_last_check += 1 + best_filepath = f'{self.filepath}/{self.prefix}_ckpt_best.pt' + if self.epochs_since_last_check >= self.period: + self.epochs_since_last_check = 0 + filepath = f'{self.filepath}/{self.prefix}_ckpt_steps_{self.task.global_step}.ckpt' + if self.verbose > 0: + logging.info(f'Epoch {epoch:05d}@{self.task.global_step}: saving model to {filepath}') + self._save_model(filepath) + for old_ckpt in self.get_all_ckpts()[self.num_ckpt_keep:]: + subprocess.check_call(f'rm -rf "{old_ckpt}"', shell=True) + if self.verbose > 0: + logging.info(f'Delete ckpt: {os.path.basename(old_ckpt)}') + current = logs.get(self.monitor) + if current is not None and self.save_best: + if self.monitor_op(current, self.best): + self.best = current + if self.verbose > 0: + logging.info( + f'Epoch {epoch:05d}@{self.task.global_step}: {self.monitor} reached' + f' {current:0.5f} (best {self.best:0.5f}), saving model to' + f' {best_filepath} as top 1') + self._save_model(best_filepath) + np.save(f'{self.filepath}/best_valid.npy', [self.best]) + + +class BaseTrainer: + def __init__( + self, + logger=True, + checkpoint_callback=True, + default_save_path=None, + gradient_clip_val=0, + process_position=0, + gpus=-1, + log_gpu_memory=None, + show_progress_bar=True, + track_grad_norm=-1, + check_val_every_n_epoch=1, + accumulate_grad_batches=1, + max_updates=1000, + min_epochs=1, + val_check_interval=1.0, + log_save_interval=100, + row_log_interval=10, + print_nan_grads=False, + weights_summary='full', + num_sanity_val_steps=5, + resume_from_checkpoint=None, + ): + self.log_gpu_memory = log_gpu_memory + self.gradient_clip_val = gradient_clip_val + self.check_val_every_n_epoch = check_val_every_n_epoch + self.track_grad_norm = track_grad_norm + self.on_gpu = True if (gpus and torch.cuda.is_available()) else False + self.process_position = process_position + self.weights_summary = weights_summary + self.max_updates = max_updates + self.min_epochs = min_epochs + self.num_sanity_val_steps = num_sanity_val_steps + self.print_nan_grads = print_nan_grads + self.resume_from_checkpoint = resume_from_checkpoint + self.default_save_path = default_save_path + + # training bookeeping + self.total_batch_idx = 0 + self.running_loss = [] + self.avg_loss = 0 + self.batch_idx = 0 + self.tqdm_metrics = {} + self.callback_metrics = {} + self.num_val_batches = 0 + self.num_training_batches = 0 + self.num_test_batches = 0 + self.get_train_dataloader = None + self.get_test_dataloaders = None + self.get_val_dataloaders = None + self.is_iterable_train_dataloader = False + + # training state + self.model = None + self.testing = False + self.disable_validation = False + self.lr_schedulers = [] + self.optimizers = None + self.global_step = 0 + self.current_epoch = 0 + self.total_batches = 0 + + # configure checkpoint callback + self.checkpoint_callback = checkpoint_callback + self.checkpoint_callback.save_function = self.save_checkpoint + self.weights_save_path = self.checkpoint_callback.filepath + + # accumulated grads + self.configure_accumulated_gradients(accumulate_grad_batches) + + # allow int, string and gpu list + self.data_parallel_device_ids = [ + int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != ''] + if len(self.data_parallel_device_ids) == 0: + self.root_gpu = None + self.on_gpu = False + else: + self.root_gpu = self.data_parallel_device_ids[0] + self.on_gpu = True + + # distributed backend choice + self.use_ddp = False + self.use_dp = False + self.single_gpu = False + self.distributed_backend = 'ddp' if self.num_gpus > 0 else 'dp' + self.set_distributed_mode(self.distributed_backend) + + self.proc_rank = 0 + self.world_size = 1 + self.node_rank = 0 + + # can't init progress bar here because starting a new process + # means the progress_bar won't survive pickling + self.show_progress_bar = show_progress_bar + + # logging + self.log_save_interval = log_save_interval + self.val_check_interval = val_check_interval + self.logger = logger + self.logger.rank = 0 + self.row_log_interval = row_log_interval + + @property + def num_gpus(self): + gpus = self.data_parallel_device_ids + if gpus is None: + return 0 + else: + return len(gpus) + + @property + def data_parallel(self): + return self.use_dp or self.use_ddp + + def get_model(self): + is_dp_module = isinstance(self.model, (DDP, DP)) + model = self.model.module if is_dp_module else self.model + return model + + # ----------------------------- + # MODEL TRAINING + # ----------------------------- + def fit(self, model): + if self.use_ddp: + mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,)) + else: + model.model = model.build_model() + if not self.testing: + self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + if self.use_dp: + model.cuda(self.root_gpu) + model = DP(model, device_ids=self.data_parallel_device_ids) + elif self.single_gpu: + model.cuda(self.root_gpu) + self.run_pretrain_routine(model) + return 1 + + def init_optimizers(self, optimizers): + + # single optimizer + if isinstance(optimizers, Optimizer): + return [optimizers], [] + + # two lists + elif len(optimizers) == 2 and isinstance(optimizers[0], list): + optimizers, lr_schedulers = optimizers + return optimizers, lr_schedulers + + # single list or tuple + elif isinstance(optimizers, list) or isinstance(optimizers, tuple): + return optimizers[0], [] + + def run_pretrain_routine(self, model): + """Sanity check a few things before starting actual training. + + :param model: + """ + ref_model = model + if self.data_parallel: + ref_model = model.module + + # give model convenience properties + ref_model.trainer = self + + # set local properties on the model + self.copy_trainer_model_properties(ref_model) + + # link up experiment object + if self.logger is not None: + ref_model.logger = self.logger + self.logger.save() + + if self.use_ddp: + dist.barrier() + + # set up checkpoint callback + # self.configure_checkpoint_callback() + + # transfer data loaders from model + self.get_dataloaders(ref_model) + + # track model now. + # if cluster resets state, the model will update with the saved weights + self.model = model + + # restore training and model before hpc call + self.restore_weights(model) + + # when testing requested only run test and return + if self.testing: + self.run_evaluation(test=True) + return + + # check if we should run validation during training + self.disable_validation = self.num_val_batches == 0 + + # run tiny validation (if validation defined) + # to make sure program won't crash during val + ref_model.on_sanity_check_start() + ref_model.on_train_start() + if not self.disable_validation and self.num_sanity_val_steps > 0: + # init progress bars for validation sanity check + pbar = tqdm.tqdm(desc='Validation sanity check', + total=self.num_sanity_val_steps * len(self.get_val_dataloaders()), + leave=False, position=2 * self.process_position, + disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch') + self.main_progress_bar = pbar + # dummy validation progress bar + self.val_progress_bar = tqdm.tqdm(disable=True) + + self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing) + + # close progress bars + self.main_progress_bar.close() + self.val_progress_bar.close() + + # init progress bar + pbar = tqdm.tqdm(leave=True, position=2 * self.process_position, + disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch', + file=sys.stdout) + self.main_progress_bar = pbar + + # clear cache before training + if self.on_gpu: + torch.cuda.empty_cache() + + # CORE TRAINING LOOP + self.train() + + def test(self, model): + self.testing = True + self.fit(model) + + @property + def training_tqdm_dict(self): + tqdm_dict = { + 'step': '{}'.format(self.global_step), + } + tqdm_dict.update(self.tqdm_metrics) + return tqdm_dict + + # -------------------- + # restore ckpt + # -------------------- + def restore_weights(self, model): + """ + To restore weights we have two cases. + First, attempt to restore hpc weights. If successful, don't restore + other weights. + + Otherwise, try to restore actual weights + :param model: + :return: + """ + # clear cache before restore + if self.on_gpu: + torch.cuda.empty_cache() + + if self.resume_from_checkpoint is not None: + self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu) + else: + # restore weights if same exp version + self.restore_state_if_checkpoint_exists(model) + + # wait for all models to restore weights + if self.use_ddp: + # wait for all processes to catch up + dist.barrier() + + # clear cache after restore + if self.on_gpu: + torch.cuda.empty_cache() + + def restore_state_if_checkpoint_exists(self, model): + did_restore = False + + # do nothing if there's not dir or callback + no_ckpt_callback = (self.checkpoint_callback is None) or (not self.checkpoint_callback) + if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath): + return did_restore + + # restore trainer state and model if there is a weight for this experiment + last_steps = -1 + last_ckpt_name = None + + # find last epoch + checkpoints = os.listdir(self.checkpoint_callback.filepath) + for name in checkpoints: + if '.ckpt' in name and not name.endswith('part'): + if 'steps_' in name: + steps = name.split('steps_')[1] + steps = int(re.sub('[^0-9]', '', steps)) + + if steps > last_steps: + last_steps = steps + last_ckpt_name = name + + # restore last checkpoint + if last_ckpt_name is not None: + last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name) + self.restore(last_ckpt_path, self.on_gpu) + logging.info(f'model and trainer restored from checkpoint: {last_ckpt_path}') + did_restore = True + + return did_restore + + def restore(self, checkpoint_path, on_gpu): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # load model state + model = self.get_model() + + # load the state_dict on the model automatically + model.load_state_dict(checkpoint['state_dict'], strict=False) + if on_gpu: + model.cuda(self.root_gpu) + # load training state (affects trainer only) + self.restore_training_state(checkpoint) + model.global_step = self.global_step + del checkpoint + + try: + if dist.is_initialized() and dist.get_rank() > 0: + return + except Exception as e: + print(e) + return + + def restore_training_state(self, checkpoint): + """ + Restore trainer state. + Model will get its change to update + :param checkpoint: + :return: + """ + if self.checkpoint_callback is not None and self.checkpoint_callback is not False: + self.checkpoint_callback.best = checkpoint['checkpoint_callback_best'] + + self.global_step = checkpoint['global_step'] + self.current_epoch = checkpoint['epoch'] + + if self.testing: + return + + # restore the optimizers + optimizer_states = checkpoint['optimizer_states'] + for optimizer, opt_state in zip(self.optimizers, optimizer_states): + if optimizer is None: + return + optimizer.load_state_dict(opt_state) + + # move optimizer to GPU 1 weight at a time + # avoids OOM + if self.root_gpu is not None: + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda(self.root_gpu) + + # restore the lr schedulers + if 'lr_schedulers' in checkpoint: + lr_schedulers = checkpoint['lr_schedulers'] + for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers): + scheduler.load_state_dict(lrs_state) + + # -------------------- + # MODEL SAVE CHECKPOINT + # -------------------- + def _atomic_save(self, checkpoint, filepath): + """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. + + This will create a temporary checkpoint with a suffix of ``.part``, then copy it to the final location once + saving is finished. + + Args: + checkpoint (object): The object to save. + Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save`` + accepts. + filepath (str|pathlib.Path): The path to which the checkpoint will be saved. + This points to the file that the checkpoint will be stored in. + """ + tmp_path = str(filepath) + ".part" + torch.save(checkpoint, tmp_path) + os.replace(tmp_path, filepath) + + def save_checkpoint(self, filepath): + checkpoint = self.dump_checkpoint() + self._atomic_save(checkpoint, filepath) + + def dump_checkpoint(self): + + checkpoint = { + 'epoch': self.current_epoch, + 'global_step': self.global_step + } + + if self.checkpoint_callback is not None and self.checkpoint_callback is not False: + checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best + + # save optimizers + optimizer_states = [] + for i, optimizer in enumerate(self.optimizers): + if optimizer is not None: + optimizer_states.append(optimizer.state_dict()) + + checkpoint['optimizer_states'] = optimizer_states + + # save lr schedulers + lr_schedulers = [] + for i, scheduler in enumerate(self.lr_schedulers): + lr_schedulers.append(scheduler.state_dict()) + + checkpoint['lr_schedulers'] = lr_schedulers + + # add the hparams and state_dict from the model + model = self.get_model() + checkpoint['state_dict'] = model.state_dict() + # give the model a chance to add a few things + model.on_save_checkpoint(checkpoint) + + return checkpoint + + def copy_trainer_model_properties(self, model): + if isinstance(model, DP): + ref_model = model.module + elif isinstance(model, DDP): + ref_model = model.module + else: + ref_model = model + + for m in [model, ref_model]: + m.trainer = self + m.on_gpu = self.on_gpu + m.use_dp = self.use_dp + m.use_ddp = self.use_ddp + m.testing = self.testing + m.single_gpu = self.single_gpu + + def transfer_batch_to_gpu(self, batch, gpu_id): + # base case: object can be directly moved using `cuda` or `to` + if callable(getattr(batch, 'cuda', None)): + return batch.cuda(gpu_id, non_blocking=True) + + elif callable(getattr(batch, 'to', None)): + return batch.to(torch.device('cuda', gpu_id), non_blocking=True) + + # when list + elif isinstance(batch, list): + for i, x in enumerate(batch): + batch[i] = self.transfer_batch_to_gpu(x, gpu_id) + return batch + + # when tuple + elif isinstance(batch, tuple): + batch = list(batch) + for i, x in enumerate(batch): + batch[i] = self.transfer_batch_to_gpu(x, gpu_id) + return tuple(batch) + + # when dict + elif isinstance(batch, dict): + for k, v in batch.items(): + batch[k] = self.transfer_batch_to_gpu(v, gpu_id) + + return batch + + # nothing matches, return the value as is without transform + return batch + + def set_distributed_mode(self, distributed_backend): + # skip for CPU + if self.num_gpus == 0: + return + + # single GPU case + # in single gpu case we allow ddp so we can train on multiple + # nodes, 1 gpu per node + elif self.num_gpus == 1: + self.single_gpu = True + self.use_dp = False + self.use_ddp = False + self.root_gpu = 0 + self.data_parallel_device_ids = [0] + else: + if distributed_backend is not None: + self.use_dp = distributed_backend == 'dp' + self.use_ddp = distributed_backend == 'ddp' + elif distributed_backend is None: + self.use_dp = True + self.use_ddp = False + + logging.info(f'gpu available: {torch.cuda.is_available()}, used: {self.on_gpu}') + + def ddp_train(self, gpu_idx, model): + """ + Entry point into a DP thread + :param gpu_idx: + :param model: + :param cluster_obj: + :return: + """ + # otherwise default to node rank 0 + self.node_rank = 0 + + # show progressbar only on progress_rank 0 + self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_idx == 0 + + # determine which process we are and world size + if self.use_ddp: + self.proc_rank = self.node_rank * self.num_gpus + gpu_idx + self.world_size = self.num_gpus + + # let the exp know the rank to avoid overwriting logs + if self.logger is not None: + self.logger.rank = self.proc_rank + + # set up server using proc 0's ip address + # try to init for 20 times at max in case ports are taken + # where to store ip_table + model.trainer = self + model.init_ddp_connection(self.proc_rank, self.world_size) + + # CHOOSE OPTIMIZER + # allow for lr schedulers as well + model.model = model.build_model() + if not self.testing: + self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + + # MODEL + # copy model to each gpu + if self.distributed_backend == 'ddp': + torch.cuda.set_device(gpu_idx) + model.cuda(gpu_idx) + + # set model properties before going into wrapper + self.copy_trainer_model_properties(model) + + # override root GPU + self.root_gpu = gpu_idx + + if self.distributed_backend == 'ddp': + device_ids = [gpu_idx] + else: + device_ids = None + + # allow user to configure ddp + model = model.configure_ddp(model, device_ids) + + # continue training routine + self.run_pretrain_routine(model) + + def resolve_root_node_address(self, root_node): + if '[' in root_node: + name = root_node.split('[')[0] + number = root_node.split(',')[0] + if '-' in number: + number = number.split('-')[0] + + number = re.sub('[^0-9]', '', number) + root_node = name + number + + return root_node + + def log_metrics(self, metrics, grad_norm_dic, step=None): + """Logs the metric dict passed in. + + :param metrics: + :param grad_norm_dic: + """ + # added metrics by Lightning for convenience + metrics['epoch'] = self.current_epoch + + # add norms + metrics.update(grad_norm_dic) + + # turn all tensors to scalars + scalar_metrics = self.metrics_to_scalars(metrics) + + step = step if step is not None else self.global_step + # log actual metrics + if self.proc_rank == 0 and self.logger is not None: + self.logger.log_metrics(scalar_metrics, step=step) + self.logger.save() + + def add_tqdm_metrics(self, metrics): + for k, v in metrics.items(): + if type(v) is torch.Tensor: + v = v.item() + + self.tqdm_metrics[k] = v + + def metrics_to_scalars(self, metrics): + new_metrics = {} + for k, v in metrics.items(): + if isinstance(v, torch.Tensor): + v = v.item() + + if type(v) is dict: + v = self.metrics_to_scalars(v) + + new_metrics[k] = v + + return new_metrics + + def process_output(self, output, train=False): + """Reduces output according to the training mode. + + Separates loss from logging and tqdm metrics + :param output: + :return: + """ + # --------------- + # EXTRACT CALLBACK KEYS + # --------------- + # all keys not progress_bar or log are candidates for callbacks + callback_metrics = {} + for k, v in output.items(): + if k not in ['progress_bar', 'log', 'hiddens']: + callback_metrics[k] = v + + if train and self.use_dp: + num_gpus = self.num_gpus + callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus) + + for k, v in callback_metrics.items(): + if isinstance(v, torch.Tensor): + callback_metrics[k] = v.item() + + # --------------- + # EXTRACT PROGRESS BAR KEYS + # --------------- + try: + progress_output = output['progress_bar'] + + # reduce progress metrics for tqdm when using dp + if train and self.use_dp: + num_gpus = self.num_gpus + progress_output = self.reduce_distributed_output(progress_output, num_gpus) + + progress_bar_metrics = progress_output + except Exception: + progress_bar_metrics = {} + + # --------------- + # EXTRACT LOGGING KEYS + # --------------- + # extract metrics to log to experiment + try: + log_output = output['log'] + + # reduce progress metrics for tqdm when using dp + if train and self.use_dp: + num_gpus = self.num_gpus + log_output = self.reduce_distributed_output(log_output, num_gpus) + + log_metrics = log_output + except Exception: + log_metrics = {} + + # --------------- + # EXTRACT LOSS + # --------------- + # if output dict doesn't have the keyword loss + # then assume the output=loss if scalar + loss = None + if train: + try: + loss = output['loss'] + except Exception: + if type(output) is torch.Tensor: + loss = output + else: + raise RuntimeError( + 'No `loss` value in the dictionary returned from `model.training_step()`.' + ) + + # when using dp need to reduce the loss + if self.use_dp: + loss = self.reduce_distributed_output(loss, self.num_gpus) + + # --------------- + # EXTRACT HIDDEN + # --------------- + hiddens = output.get('hiddens') + + # use every metric passed in as a candidate for callback + callback_metrics.update(progress_bar_metrics) + callback_metrics.update(log_metrics) + + # convert tensors to numpy + for k, v in callback_metrics.items(): + if isinstance(v, torch.Tensor): + callback_metrics[k] = v.item() + + return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens + + def reduce_distributed_output(self, output, num_gpus): + if num_gpus <= 1: + return output + + # when using DP, we get one output per gpu + # average outputs and return + if type(output) is torch.Tensor: + return output.mean() + + for k, v in output.items(): + # recurse on nested dics + if isinstance(output[k], dict): + output[k] = self.reduce_distributed_output(output[k], num_gpus) + + # do nothing when there's a scalar + elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0: + pass + + # reduce only metrics that have the same number of gpus + elif output[k].size(0) == num_gpus: + reduced = torch.mean(output[k]) + output[k] = reduced + return output + + def clip_gradients(self): + if self.gradient_clip_val > 0: + model = self.get_model() + torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val) + + def print_nan_gradients(self): + model = self.get_model() + for param in model.parameters(): + if (param.grad is not None) and torch.isnan(param.grad.float()).any(): + logging.info(param, param.grad) + + def configure_accumulated_gradients(self, accumulate_grad_batches): + self.accumulate_grad_batches = None + + if isinstance(accumulate_grad_batches, dict): + self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches) + elif isinstance(accumulate_grad_batches, int): + schedule = {1: accumulate_grad_batches} + self.accumulation_scheduler = GradientAccumulationScheduler(schedule) + else: + raise TypeError("Gradient accumulation supports only int and dict types") + + def get_dataloaders(self, model): + if not self.testing: + self.init_train_dataloader(model) + self.init_val_dataloader(model) + else: + self.init_test_dataloader(model) + + if self.use_ddp: + dist.barrier() + if not self.testing: + self.get_train_dataloader() + self.get_val_dataloaders() + else: + self.get_test_dataloaders() + + def init_train_dataloader(self, model): + self.fisrt_epoch = True + self.get_train_dataloader = model.train_dataloader + if isinstance(self.get_train_dataloader(), torch.utils.data.DataLoader): + self.num_training_batches = len(self.get_train_dataloader()) + self.num_training_batches = int(self.num_training_batches) + else: + self.num_training_batches = float('inf') + self.is_iterable_train_dataloader = True + if isinstance(self.val_check_interval, int): + self.val_check_batch = self.val_check_interval + else: + self._percent_range_check('val_check_interval') + self.val_check_batch = int(self.num_training_batches * self.val_check_interval) + self.val_check_batch = max(1, self.val_check_batch) + + def init_val_dataloader(self, model): + self.get_val_dataloaders = model.val_dataloader + self.num_val_batches = 0 + if self.get_val_dataloaders() is not None: + if isinstance(self.get_val_dataloaders()[0], torch.utils.data.DataLoader): + self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders()) + self.num_val_batches = int(self.num_val_batches) + else: + self.num_val_batches = float('inf') + + def init_test_dataloader(self, model): + self.get_test_dataloaders = model.test_dataloader + if self.get_test_dataloaders() is not None: + if isinstance(self.get_test_dataloaders()[0], torch.utils.data.DataLoader): + self.num_test_batches = sum(len(dataloader) for dataloader in self.get_test_dataloaders()) + self.num_test_batches = int(self.num_test_batches) + else: + self.num_test_batches = float('inf') + + def evaluate(self, model, dataloaders, max_batches, test=False): + """Run evaluation code. + + :param model: PT model + :param dataloaders: list of PT dataloaders + :param max_batches: Scalar + :param test: boolean + :return: + """ + # enable eval mode + model.zero_grad() + model.eval() + + # copy properties for forward overrides + self.copy_trainer_model_properties(model) + + # disable gradients to save memory + torch.set_grad_enabled(False) + + if test: + self.get_model().test_start() + # bookkeeping + outputs = [] + + # run training + for dataloader_idx, dataloader in enumerate(dataloaders): + dl_outputs = [] + for batch_idx, batch in enumerate(dataloader): + + if batch is None: # pragma: no cover + continue + + # stop short when on fast_dev_run (sets max_batch=1) + if batch_idx >= max_batches: + break + + # ----------------- + # RUN EVALUATION STEP + # ----------------- + output = self.evaluation_forward(model, + batch, + batch_idx, + dataloader_idx, + test) + + # track outputs for collation + dl_outputs.append(output) + + # batch done + if test: + self.test_progress_bar.update(1) + else: + self.val_progress_bar.update(1) + outputs.append(dl_outputs) + + # with a single dataloader don't pass an array + if len(dataloaders) == 1: + outputs = outputs[0] + + # give model a chance to do something with the outputs (and method defined) + model = self.get_model() + if test: + eval_results_ = model.test_end(outputs) + else: + eval_results_ = model.validation_end(outputs) + eval_results = eval_results_ + + # enable train mode again + model.train() + + # enable gradients to save memory + torch.set_grad_enabled(True) + + return eval_results + + def run_evaluation(self, test=False): + # when testing make sure user defined a test step + model = self.get_model() + model.on_pre_performance_check() + + # select dataloaders + if test: + dataloaders = self.get_test_dataloaders() + max_batches = self.num_test_batches + else: + # val + dataloaders = self.get_val_dataloaders() + max_batches = self.num_val_batches + + # init validation or test progress bar + # main progress bar will already be closed when testing so initial position is free + position = 2 * self.process_position + (not test) + desc = 'Testing' if test else 'Validating' + pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position, + disable=not self.show_progress_bar, dynamic_ncols=True, + unit='batch', file=sys.stdout) + setattr(self, f'{"test" if test else "val"}_progress_bar', pbar) + + # run evaluation + eval_results = self.evaluate(self.model, + dataloaders, + max_batches, + test) + if eval_results is not None: + _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output( + eval_results) + + # add metrics to prog bar + self.add_tqdm_metrics(prog_bar_metrics) + + # log metrics + self.log_metrics(log_metrics, {}) + + # track metrics for callbacks + self.callback_metrics.update(callback_metrics) + + # hook + model.on_post_performance_check() + + # add model specific metrics + tqdm_metrics = self.training_tqdm_dict + if not test: + self.main_progress_bar.set_postfix(**tqdm_metrics) + + # close progress bar + if test: + self.test_progress_bar.close() + else: + self.val_progress_bar.close() + + # model checkpointing + if self.proc_rank == 0 and self.checkpoint_callback is not None and not test: + self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, + logs=self.callback_metrics) + + def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False): + # make dataloader_idx arg in validation_step optional + args = [batch, batch_idx] + + if test and len(self.get_test_dataloaders()) > 1: + args.append(dataloader_idx) + + elif not test and len(self.get_val_dataloaders()) > 1: + args.append(dataloader_idx) + + # handle DP, DDP forward + if self.use_ddp or self.use_dp: + output = model(*args) + return output + + # single GPU + if self.single_gpu: + # for single GPU put inputs on gpu manually + root_gpu = 0 + if isinstance(self.data_parallel_device_ids, list): + root_gpu = self.data_parallel_device_ids[0] + batch = self.transfer_batch_to_gpu(batch, root_gpu) + args[0] = batch + + # CPU + if test: + output = model.test_step(*args) + else: + output = model.validation_step(*args) + + return output + + def train(self): + model = self.get_model() + # run all epochs + for epoch in range(self.current_epoch, 1000000): + # set seed for distributed sampler (enables shuffling for each epoch) + if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'): + self.get_train_dataloader().sampler.set_epoch(epoch) + + # get model + model = self.get_model() + + # update training progress in trainer and model + model.current_epoch = epoch + self.current_epoch = epoch + + total_val_batches = 0 + if not self.disable_validation: + # val can be checked multiple times in epoch + is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 + val_checks_per_epoch = self.num_training_batches // self.val_check_batch + val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0 + total_val_batches = self.num_val_batches * val_checks_per_epoch + + # total batches includes multiple val checks + self.total_batches = self.num_training_batches + total_val_batches + self.batch_loss_value = 0 # accumulated grads + + if self.is_iterable_train_dataloader: + # for iterable train loader, the progress bar never ends + num_iterations = None + else: + num_iterations = self.total_batches + + # reset progress bar + # .reset() doesn't work on disabled progress bar so we should check + desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else '' + self.main_progress_bar.set_description(desc) + + # changing gradient according accumulation_scheduler + self.accumulation_scheduler.on_epoch_begin(epoch, self) + + # ----------------- + # RUN TNG EPOCH + # ----------------- + self.run_training_epoch() + + # update LR schedulers + if self.lr_schedulers is not None: + for lr_scheduler in self.lr_schedulers: + lr_scheduler.step(epoch=self.current_epoch) + + self.main_progress_bar.close() + + model.on_train_end() + + if self.logger is not None: + self.logger.finalize("success") + + def run_training_epoch(self): + # before epoch hook + if self.is_function_implemented('on_epoch_start'): + model = self.get_model() + model.on_epoch_start() + + # run epoch + for batch_idx, batch in enumerate(self.get_train_dataloader()): + # stop epoch if we limited the number of training batches + if batch_idx >= self.num_training_batches: + break + + self.batch_idx = batch_idx + + model = self.get_model() + model.global_step = self.global_step + + # --------------- + # RUN TRAIN STEP + # --------------- + output = self.run_training_batch(batch, batch_idx) + batch_result, grad_norm_dic, batch_step_metrics = output + + # when returning -1 from train_step, we end epoch early + early_stop_epoch = batch_result == -1 + + # --------------- + # RUN VAL STEP + # --------------- + should_check_val = ( + not self.disable_validation and self.global_step % self.val_check_batch == 0 and not self.fisrt_epoch) + self.fisrt_epoch = False + + if should_check_val: + self.run_evaluation(test=self.testing) + + # when logs should be saved + should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch + if should_save_log: + if self.proc_rank == 0 and self.logger is not None: + self.logger.save() + + # when metrics should be logged + should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch + if should_log_metrics: + # logs user requested information to logger + self.log_metrics(batch_step_metrics, grad_norm_dic) + + self.global_step += 1 + self.total_batch_idx += 1 + + # end epoch early + # stop when the flag is changed or we've gone past the amount + # requested in the batches + if early_stop_epoch: + break + if self.global_step > self.max_updates: + print("| Training end..") + exit() + + # epoch end hook + if self.is_function_implemented('on_epoch_end'): + model = self.get_model() + model.on_epoch_end() + + def run_training_batch(self, batch, batch_idx): + # track grad norms + grad_norm_dic = {} + + # track all metrics for callbacks + all_callback_metrics = [] + + # track metrics to log + all_log_metrics = [] + + if batch is None: + return 0, grad_norm_dic, {} + + # hook + if self.is_function_implemented('on_batch_start'): + model_ref = self.get_model() + response = model_ref.on_batch_start(batch) + + if response == -1: + return -1, grad_norm_dic, {} + + splits = [batch] + self.hiddens = None + for split_idx, split_batch in enumerate(splits): + self.split_idx = split_idx + + # call training_step once per optimizer + for opt_idx, optimizer in enumerate(self.optimizers): + if optimizer is None: + continue + # make sure only the gradients of the current optimizer's paramaters are calculated + # in the training step to prevent dangling gradients in multiple-optimizer setup. + if len(self.optimizers) > 1: + for param in self.get_model().parameters(): + param.requires_grad = False + for group in optimizer.param_groups: + for param in group['params']: + param.requires_grad = True + + # wrap the forward step in a closure so second order methods work + def optimizer_closure(): + # forward pass + output = self.training_forward( + split_batch, batch_idx, opt_idx, self.hiddens) + + closure_loss = output[0] + progress_bar_metrics = output[1] + log_metrics = output[2] + callback_metrics = output[3] + self.hiddens = output[4] + if closure_loss is None: + return None + + # accumulate loss + # (if accumulate_grad_batches = 1 no effect) + closure_loss = closure_loss / self.accumulate_grad_batches + + # backward pass + model_ref = self.get_model() + if closure_loss.requires_grad: + model_ref.backward(closure_loss, optimizer) + + # track metrics for callbacks + all_callback_metrics.append(callback_metrics) + + # track progress bar metrics + self.add_tqdm_metrics(progress_bar_metrics) + all_log_metrics.append(log_metrics) + + # insert after step hook + if self.is_function_implemented('on_after_backward'): + model_ref = self.get_model() + model_ref.on_after_backward() + + return closure_loss + + # calculate loss + loss = optimizer_closure() + if loss is None: + continue + + # nan grads + if self.print_nan_grads: + self.print_nan_gradients() + + # track total loss for logging (avoid mem leaks) + self.batch_loss_value += loss.item() + + # gradient update with accumulated gradients + if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: + + # track gradient norms when requested + if batch_idx % self.row_log_interval == 0: + if self.track_grad_norm > 0: + model = self.get_model() + grad_norm_dic = model.grad_norm( + self.track_grad_norm) + + # clip gradients + self.clip_gradients() + + # calls .step(), .zero_grad() + # override function to modify this behavior + model = self.get_model() + model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx) + + # calculate running loss for display + self.running_loss.append(self.batch_loss_value) + self.batch_loss_value = 0 + self.avg_loss = np.mean(self.running_loss[-100:]) + + # activate batch end hook + if self.is_function_implemented('on_batch_end'): + model = self.get_model() + model.on_batch_end() + + # update progress bar + self.main_progress_bar.update(1) + self.main_progress_bar.set_postfix(**self.training_tqdm_dict) + + # collapse all metrics into one dict + all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} + + # track all metrics for callbacks + self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()}) + + return 0, grad_norm_dic, all_log_metrics + + def training_forward(self, batch, batch_idx, opt_idx, hiddens): + """ + Handle forward for each training case (distributed, single gpu, etc...) + :param batch: + :param batch_idx: + :return: + """ + # --------------- + # FORWARD + # --------------- + # enable not needing to add opt_idx to training_step + args = [batch, batch_idx, opt_idx] + + # distributed forward + if self.use_ddp or self.use_dp: + output = self.model(*args) + # single GPU forward + elif self.single_gpu: + gpu_id = 0 + if isinstance(self.data_parallel_device_ids, list): + gpu_id = self.data_parallel_device_ids[0] + batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id) + args[0] = batch + output = self.model.training_step(*args) + # CPU forward + else: + output = self.model.training_step(*args) + + # allow any mode to define training_end + model_ref = self.get_model() + output_ = model_ref.training_end(output) + if output_ is not None: + output = output_ + + # format and reduce outputs accordingly + output = self.process_output(output, train=True) + + return output + + # --------------- + # Utils + # --------------- + def is_function_implemented(self, f_name): + model = self.get_model() + f_op = getattr(model, f_name, None) + return callable(f_op) + + def _percent_range_check(self, name): + value = getattr(self, name) + msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}." + if name == "val_check_interval": + msg += " If you want to disable validation set `val_percent_check` to 0.0 instead." + + if not 0. <= value <= 1.: + raise ValueError(msg) diff --git a/utils/plot.py b/utils/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..bdca62a8cd80869c707890cd9febd39966cd3658 --- /dev/null +++ b/utils/plot.py @@ -0,0 +1,56 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch + +LINE_COLORS = ['w', 'r', 'y', 'cyan', 'm', 'b', 'lime'] + + +def spec_to_figure(spec, vmin=None, vmax=None): + if isinstance(spec, torch.Tensor): + spec = spec.cpu().numpy() + fig = plt.figure(figsize=(12, 6)) + plt.pcolor(spec.T, vmin=vmin, vmax=vmax) + return fig + + +def spec_f0_to_figure(spec, f0s, figsize=None): + max_y = spec.shape[1] + if isinstance(spec, torch.Tensor): + spec = spec.detach().cpu().numpy() + f0s = {k: f0.detach().cpu().numpy() for k, f0 in f0s.items()} + f0s = {k: f0 / 10 for k, f0 in f0s.items()} + fig = plt.figure(figsize=(12, 6) if figsize is None else figsize) + plt.pcolor(spec.T) + for i, (k, f0) in enumerate(f0s.items()): + plt.plot(f0.clip(0, max_y), label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.8) + plt.legend() + return fig + + +def dur_to_figure(dur_gt, dur_pred, txt): + dur_gt = dur_gt.long().cpu().numpy() + dur_pred = dur_pred.long().cpu().numpy() + dur_gt = np.cumsum(dur_gt) + dur_pred = np.cumsum(dur_pred) + fig = plt.figure(figsize=(12, 6)) + for i in range(len(dur_gt)): + shift = (i % 8) + 1 + plt.text(dur_gt[i], shift, txt[i]) + plt.text(dur_pred[i], 10 + shift, txt[i]) + plt.vlines(dur_gt[i], 0, 10, colors='b') # blue is gt + plt.vlines(dur_pred[i], 10, 20, colors='r') # red is pred + return fig + + +def f0_to_figure(f0_gt, f0_cwt=None, f0_pred=None): + fig = plt.figure() + f0_gt = f0_gt.cpu().numpy() + plt.plot(f0_gt, color='r', label='gt') + if f0_cwt is not None: + f0_cwt = f0_cwt.cpu().numpy() + plt.plot(f0_cwt, color='b', label='cwt') + if f0_pred is not None: + f0_pred = f0_pred.cpu().numpy() + plt.plot(f0_pred, color='green', label='pred') + plt.legend() + return fig diff --git a/utils/text_encoder.py b/utils/text_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d9e0758abc7b4e1f452481cba9715df08ceab543 --- /dev/null +++ b/utils/text_encoder.py @@ -0,0 +1,304 @@ +import re +import six +from six.moves import range # pylint: disable=redefined-builtin + +PAD = "" +EOS = "" +UNK = "" +SEG = "|" +RESERVED_TOKENS = [PAD, EOS, UNK] +NUM_RESERVED_TOKENS = len(RESERVED_TOKENS) +PAD_ID = RESERVED_TOKENS.index(PAD) # Normally 0 +EOS_ID = RESERVED_TOKENS.index(EOS) # Normally 1 +UNK_ID = RESERVED_TOKENS.index(UNK) # Normally 2 + +if six.PY2: + RESERVED_TOKENS_BYTES = RESERVED_TOKENS +else: + RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")] + +# Regular expression for unescaping token strings. +# '\u' is converted to '_' +# '\\' is converted to '\' +# '\213;' is converted to unichr(213) +_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);") +_ESCAPE_CHARS = set(u"\\_u;0123456789") + + +def strip_ids(ids, ids_to_strip): + """Strip ids_to_strip from the end ids.""" + ids = list(ids) + while ids and ids[-1] in ids_to_strip: + ids.pop() + return ids + + +class TextEncoder(object): + """Base class for converting from ints to/from human readable strings.""" + + def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS): + self._num_reserved_ids = num_reserved_ids + + @property + def num_reserved_ids(self): + return self._num_reserved_ids + + def encode(self, s): + """Transform a human-readable string into a sequence of int ids. + + The ids should be in the range [num_reserved_ids, vocab_size). Ids [0, + num_reserved_ids) are reserved. + + EOS is not appended. + + Args: + s: human-readable string to be converted. + + Returns: + ids: list of integers + """ + return [int(w) + self._num_reserved_ids for w in s.split()] + + def decode(self, ids, strip_extraneous=False): + """Transform a sequence of int ids into a human-readable string. + + EOS is not expected in ids. + + Args: + ids: list of integers to be converted. + strip_extraneous: bool, whether to strip off extraneous tokens + (EOS and PAD). + + Returns: + s: human-readable string. + """ + if strip_extraneous: + ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) + return " ".join(self.decode_list(ids)) + + def decode_list(self, ids): + """Transform a sequence of int ids into a their string versions. + + This method supports transforming individual input/output ids to their + string versions so that sequence to/from text conversions can be visualized + in a human readable format. + + Args: + ids: list of integers to be converted. + + Returns: + strs: list of human-readable string. + """ + decoded_ids = [] + for id_ in ids: + if 0 <= id_ < self._num_reserved_ids: + decoded_ids.append(RESERVED_TOKENS[int(id_)]) + else: + decoded_ids.append(id_ - self._num_reserved_ids) + return [str(d) for d in decoded_ids] + + @property + def vocab_size(self): + raise NotImplementedError() + + +class ByteTextEncoder(TextEncoder): + """Encodes each byte to an id. For 8-bit strings only.""" + + def encode(self, s): + numres = self._num_reserved_ids + if six.PY2: + if isinstance(s, unicode): + s = s.encode("utf-8") + return [ord(c) + numres for c in s] + # Python3: explicitly convert to UTF-8 + return [c + numres for c in s.encode("utf-8")] + + def decode(self, ids, strip_extraneous=False): + if strip_extraneous: + ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) + numres = self._num_reserved_ids + decoded_ids = [] + int2byte = six.int2byte + for id_ in ids: + if 0 <= id_ < numres: + decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) + else: + decoded_ids.append(int2byte(id_ - numres)) + if six.PY2: + return "".join(decoded_ids) + # Python3: join byte arrays and then decode string + return b"".join(decoded_ids).decode("utf-8", "replace") + + def decode_list(self, ids): + numres = self._num_reserved_ids + decoded_ids = [] + int2byte = six.int2byte + for id_ in ids: + if 0 <= id_ < numres: + decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) + else: + decoded_ids.append(int2byte(id_ - numres)) + # Python3: join byte arrays and then decode string + return decoded_ids + + @property + def vocab_size(self): + return 2**8 + self._num_reserved_ids + + +class ByteTextEncoderWithEos(ByteTextEncoder): + """Encodes each byte to an id and appends the EOS token.""" + + def encode(self, s): + return super(ByteTextEncoderWithEos, self).encode(s) + [EOS_ID] + + +class TokenTextEncoder(TextEncoder): + """Encoder based on a user-supplied vocabulary (file or list).""" + + def __init__(self, + vocab_filename, + reverse=False, + vocab_list=None, + replace_oov=None, + num_reserved_ids=NUM_RESERVED_TOKENS): + """Initialize from a file or list, one token per line. + + Handling of reserved tokens works as follows: + - When initializing from a list, we add reserved tokens to the vocab. + - When initializing from a file, we do not add reserved tokens to the vocab. + - When saving vocab files, we save reserved tokens to the file. + + Args: + vocab_filename: If not None, the full filename to read vocab from. If this + is not None, then vocab_list should be None. + reverse: Boolean indicating if tokens should be reversed during encoding + and decoding. + vocab_list: If not None, a list of elements of the vocabulary. If this is + not None, then vocab_filename should be None. + replace_oov: If not None, every out-of-vocabulary token seen when + encoding will be replaced by this string (which must be in vocab). + num_reserved_ids: Number of IDs to save for reserved tokens like . + """ + super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) + self._reverse = reverse + self._replace_oov = replace_oov + if vocab_filename: + self._init_vocab_from_file(vocab_filename) + else: + assert vocab_list is not None + self._init_vocab_from_list(vocab_list) + self.pad_index = self._token_to_id[PAD] + self.eos_index = self._token_to_id[EOS] + self.unk_index = self._token_to_id[UNK] + self.seg_index = self._token_to_id[SEG] if SEG in self._token_to_id else self.eos_index + + def encode(self, s): + """Converts a space-separated string of tokens to a list of ids.""" + sentence = s + tokens = sentence.strip().split() + if self._replace_oov is not None: + tokens = [t if t in self._token_to_id else self._replace_oov + for t in tokens] + ret = [self._token_to_id[tok] for tok in tokens] + return ret[::-1] if self._reverse else ret + + def decode(self, ids, strip_eos=False, strip_padding=False): + if strip_padding and self.pad() in list(ids): + pad_pos = list(ids).index(self.pad()) + ids = ids[:pad_pos] + if strip_eos and self.eos() in list(ids): + eos_pos = list(ids).index(self.eos()) + ids = ids[:eos_pos] + return " ".join(self.decode_list(ids)) + + def decode_list(self, ids): + seq = reversed(ids) if self._reverse else ids + return [self._safe_id_to_token(i) for i in seq] + + @property + def vocab_size(self): + return len(self._id_to_token) + + def __len__(self): + return self.vocab_size + + def _safe_id_to_token(self, idx): + return self._id_to_token.get(idx, "ID_%d" % idx) + + def _init_vocab_from_file(self, filename): + """Load vocab from a file. + + Args: + filename: The file to load vocabulary from. + """ + with open(filename) as f: + tokens = [token.strip() for token in f.readlines()] + + def token_gen(): + for token in tokens: + yield token + + self._init_vocab(token_gen(), add_reserved_tokens=False) + + def _init_vocab_from_list(self, vocab_list): + """Initialize tokens from a list of tokens. + + It is ok if reserved tokens appear in the vocab list. They will be + removed. The set of tokens in vocab_list should be unique. + + Args: + vocab_list: A list of tokens. + """ + def token_gen(): + for token in vocab_list: + if token not in RESERVED_TOKENS: + yield token + + self._init_vocab(token_gen()) + + def _init_vocab(self, token_generator, add_reserved_tokens=True): + """Initialize vocabulary with tokens from token_generator.""" + + self._id_to_token = {} + non_reserved_start_index = 0 + + if add_reserved_tokens: + self._id_to_token.update(enumerate(RESERVED_TOKENS)) + non_reserved_start_index = len(RESERVED_TOKENS) + + self._id_to_token.update( + enumerate(token_generator, start=non_reserved_start_index)) + + # _token_to_id is the reverse of _id_to_token + self._token_to_id = dict((v, k) + for k, v in six.iteritems(self._id_to_token)) + + def pad(self): + return self.pad_index + + def eos(self): + return self.eos_index + + def unk(self): + return self.unk_index + + def seg(self): + return self.seg_index + + def store_to_file(self, filename): + """Write vocab file to disk. + + Vocab files have one token per line. The file ends in a newline. Reserved + tokens are written to the vocab file as well. + + Args: + filename: Full path of the file to store the vocab to. + """ + with open(filename, "w") as f: + for i in range(len(self._id_to_token)): + f.write(self._id_to_token[i] + "\n") + + def sil_phonemes(self): + return [p for p in self._id_to_token.values() if not p[0].isalpha()] diff --git a/utils/text_norm.py b/utils/text_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..d0973cebc91e0525aeb6657e70012a1d37b5e6ff --- /dev/null +++ b/utils/text_norm.py @@ -0,0 +1,790 @@ +# coding=utf-8 +# Authors: +# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) +# 2019.9 Jiayu DU +# +# requirements: +# - python 3.X +# notes: python 2.X WILL fail or produce misleading results + +import sys, os, argparse, codecs, string, re + +# ================================================================================ # +# basic constant +# ================================================================================ # +CHINESE_DIGIS = u'零一二三四五六七八九' +BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖' +BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖' +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万' +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬' +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载' +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載' +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万' +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬' + +ZERO_ALT = u'〇' +ONE_ALT = u'幺' +TWO_ALTS = [u'两', u'兩'] + +POSITIVE = [u'正', u'正'] +NEGATIVE = [u'负', u'負'] +POINT = [u'点', u'點'] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +# 中文数字系统类型 +NUMBERING_TYPES = ['low', 'mid', 'high'] + +CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \ + '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)' +CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)' +COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \ + '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \ + '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \ + '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \ + '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \ + '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)' + +# punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git) +CHINESE_PUNC_STOP = '!?。。' +CHINESE_PUNC_NON_STOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏' +CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP + + +# ================================================================================ # +# basic class +# ================================================================================ # +class ChineseChar(object): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + # self.__repr__ = self.__str__ + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. '陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super(ChineseNumberUnit, self).__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return '10^{}'.format(self.power) + + @classmethod + def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): + + if small_unit: + return ChineseNumberUnit(power=index + 1, + simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit(power=index + 8, + simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit(power=(index + 2) * 4, + simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit(power=pow(2, index + 3), + simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]) + else: + raise ValueError( + 'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type)) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None): + super(ChineseNumberDigit, self).__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super(ChineseMath, self).__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath + + +class NumberSystem(object): + """ + 中文数字系统 + """ + pass + + +class MathSymbol(object): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. + positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(object): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v + + +# ================================================================================ # +# basic utils +# ================================================================================ # +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. + 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip( + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL) + larger_units = [CNU.create(i, v, numbering_type, False) + for i, v in enumerate(all_larger_units)] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip( + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL) + smaller_units = [CNU.create(i, v, small_unit=True) + for i, v in enumerate(all_smaller_units)] + # digis + chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, + BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x) + point_cn = CM(POINT[0], POINT[1], '.', lambda x, + y: float(str(x) + '.' + str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, '' + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], \ + [get_symbol(c, system) for c in dec_string] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU): + integer_symbols.append( + CNU(integer_symbols[-2].power - 1, None, None, None, None)) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power: + result[-i - 1] = CNU(result[-i - 1].power + + current_unit.power, None, None, None, None) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. '两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list(map(lambda v: v * + pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = ''.join([str(d.value) for d in dec_part]) + if dec_part: + return '{0}.{1}'.format(int_str, dec_str) + else: + return int_str + + +def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False, + traditional=False, alt_zero=False, alt_one=False, alt_two=True, + use_zeros=True, use_units=True): + def get_value(value_string, use_zeros=True): + + striped_string = value_string.lstrip('0') + + # record nothing if all zeros + if not striped_string: + return [] + + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + + # recursively record multiple digits + else: + result_unit = next(u for u in reversed( + system.units) if u.power < len(striped_string)) + result_string = value_string[:-result_unit.power] + return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:]) + + system = create_system(numbering_type) + + int_dec = number_string.split('.') + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError( + "invalid input num string with more than one dot: {}".format(number_string)) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, + system.digits[2].big_s, system.digits[2].big_t) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = result_symbols[i + + 1] if i < len(result_symbols) - 1 else None + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))): + if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)): + result_symbols[i] = liang + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = 'big_' + if traditional: + attr_name += 't' + else: + attr_name += 's' + else: + if traditional: + attr_name = 'traditional' + else: + attr_name = 'simplified' + + result = ''.join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace( + getattr(system.digits[0], attr_name), system.digits[0].alt_s) + + if alt_one: + result = result.replace( + getattr(system.digits[1], attr_name), system.digits[1].alt_s) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + result + + # ^10, 11, .., 19 + if len(result) >= 2 and result[1] in [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0], + SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \ + result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]: + result = result[1:] + + return result + + +# ================================================================================ # +# different types of rewriters +# ================================================================================ # +class Cardinal: + """ + CARDINAL类 + """ + + def __init__(self, cardinal=None, chntext=None): + self.cardinal = cardinal + self.chntext = chntext + + def chntext2cardinal(self): + return chn2num(self.chntext) + + def cardinal2chntext(self): + return num2chn(self.cardinal) + + +class Digit: + """ + DIGIT类 + """ + + def __init__(self, digit=None, chntext=None): + self.digit = digit + self.chntext = chntext + + # def chntext2digit(self): + # return chn2num(self.chntext) + + def digit2chntext(self): + return num2chn(self.digit, alt_two=False, use_units=False) + + +class TelePhone: + """ + TELEPHONE类 + """ + + def __init__(self, telephone=None, raw_chntext=None, chntext=None): + self.telephone = telephone + self.raw_chntext = raw_chntext + self.chntext = chntext + + # def chntext2telephone(self): + # sil_parts = self.raw_chntext.split('') + # self.telephone = '-'.join([ + # str(chn2num(p)) for p in sil_parts + # ]) + # return self.telephone + + def telephone2chntext(self, fixed=False): + + if fixed: + sil_parts = self.telephone.split('-') + self.raw_chntext = ''.join([ + num2chn(part, alt_two=False, use_units=False) for part in sil_parts + ]) + self.chntext = self.raw_chntext.replace('', '') + else: + sp_parts = self.telephone.strip('+').split() + self.raw_chntext = ''.join([ + num2chn(part, alt_two=False, use_units=False) for part in sp_parts + ]) + self.chntext = self.raw_chntext.replace('', '') + return self.chntext + + +class Fraction: + """ + FRACTION类 + """ + + def __init__(self, fraction=None, chntext=None): + self.fraction = fraction + self.chntext = chntext + + def chntext2fraction(self): + denominator, numerator = self.chntext.split('分之') + return chn2num(numerator) + '/' + chn2num(denominator) + + def fraction2chntext(self): + numerator, denominator = self.fraction.split('/') + return num2chn(denominator) + '分之' + num2chn(numerator) + + +class Date: + """ + DATE类 + """ + + def __init__(self, date=None, chntext=None): + self.date = date + self.chntext = chntext + + # def chntext2date(self): + # chntext = self.chntext + # try: + # year, other = chntext.strip().split('年', maxsplit=1) + # year = Digit(chntext=year).digit2chntext() + '年' + # except ValueError: + # other = chntext + # year = '' + # if other: + # try: + # month, day = other.strip().split('月', maxsplit=1) + # month = Cardinal(chntext=month).chntext2cardinal() + '月' + # except ValueError: + # day = chntext + # month = '' + # if day: + # day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1] + # else: + # month = '' + # day = '' + # date = year + month + day + # self.date = date + # return self.date + + def date2chntext(self): + date = self.date + try: + year, other = date.strip().split('年', 1) + year = Digit(digit=year).digit2chntext() + '年' + except ValueError: + other = date + year = '' + if other: + try: + month, day = other.strip().split('月', 1) + month = Cardinal(cardinal=month).cardinal2chntext() + '月' + except ValueError: + day = date + month = '' + if day: + day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1] + else: + month = '' + day = '' + chntext = year + month + day + self.chntext = chntext + return self.chntext + + +class Money: + """ + MONEY类 + """ + + def __init__(self, money=None, chntext=None): + self.money = money + self.chntext = chntext + + # def chntext2money(self): + # return self.money + + def money2chntext(self): + money = self.money + pattern = re.compile(r'(\d+(\.\d+)?)') + matchers = pattern.findall(money) + if matchers: + for matcher in matchers: + money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()) + self.chntext = money + return self.chntext + + +class Percentage: + """ + PERCENTAGE类 + """ + + def __init__(self, percentage=None, chntext=None): + self.percentage = percentage + self.chntext = chntext + + def chntext2percentage(self): + return chn2num(self.chntext.strip().strip('百分之')) + '%' + + def percentage2chntext(self): + return '百分之' + num2chn(self.percentage.strip().strip('%')) + + +# ================================================================================ # +# NSW Normalizer +# ================================================================================ # +class NSWNormalizer: + def __init__(self, raw_text): + self.raw_text = '^' + raw_text + '$' + self.norm_text = '' + + def _particular(self): + text = self.norm_text + pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") + matchers = pattern.findall(text) + if matchers: + # print('particular') + for matcher in matchers: + text = text.replace(matcher[0], matcher[1] + '2' + matcher[2], 1) + self.norm_text = text + return self.norm_text + + def normalize(self, remove_punc=True): + text = self.raw_text + + # 规范化日期 + pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)") + matchers = pattern.findall(text) + if matchers: + # print('date') + for matcher in matchers: + text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) + + # 规范化金钱 + pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)") + matchers = pattern.findall(text) + if matchers: + # print('money') + for matcher in matchers: + text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1) + + # 规范化固话/手机号码 + # 手机 + # http://www.jihaoba.com/news/show/13680 + # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 + # 联通:130、131、132、156、155、186、185、176 + # 电信:133、153、189、180、181、177 + pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") + matchers = pattern.findall(text) + if matchers: + # print('telephone') + for matcher in matchers: + text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1) + # 固话 + pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") + matchers = pattern.findall(text) + if matchers: + # print('fixed telephone') + for matcher in matchers: + text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1) + + # 规范化分数 + pattern = re.compile(r"(\d+/\d+)") + matchers = pattern.findall(text) + if matchers: + # print('fraction') + for matcher in matchers: + text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1) + + # 规范化百分数 + text = text.replace('%', '%') + pattern = re.compile(r"(\d+(\.\d+)?%)") + matchers = pattern.findall(text) + if matchers: + # print('percentage') + for matcher in matchers: + text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1) + + # 规范化纯数+量词 + pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" + COM_QUANTIFIERS) + matchers = pattern.findall(text) + if matchers: + # print('cardinal+quantifier') + for matcher in matchers: + text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) + + # 规范化数字编号 + pattern = re.compile(r"(\d{4,32})") + matchers = pattern.findall(text) + if matchers: + # print('digit') + for matcher in matchers: + text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) + + # 规范化纯数 + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(text) + if matchers: + # print('cardinal') + for matcher in matchers: + text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) + + self.norm_text = text + self._particular() + + text = self.norm_text.lstrip('^').rstrip('$') + if remove_punc: + # Punctuations removal + old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations + new_chars = ' ' * len(old_chars) + del_chars = '' + text = text.translate(str.maketrans(old_chars, new_chars, del_chars)) + return text + + +def nsw_test_case(raw_text): + print('I:' + raw_text) + print('O:' + NSWNormalizer(raw_text).normalize()) + print('') + + +def nsw_test(): + nsw_test_case('固话:0595-23865596或23880880。') + nsw_test_case('固话:0595-23865596或23880880。') + nsw_test_case('手机:+86 19859213959或15659451527。') + nsw_test_case('分数:32477/76391。') + nsw_test_case('百分数:80.03%。') + nsw_test_case('编号:31520181154418。') + nsw_test_case('纯数:2983.07克或12345.60米。') + nsw_test_case('日期:1999年2月20日或09年3月15号。') + nsw_test_case('金钱:12块5,34.5元,20.1万') + nsw_test_case('特殊:O2O或B2C。') + nsw_test_case('3456万吨') + nsw_test_case('2938个') + nsw_test_case('938') + nsw_test_case('今天吃了115个小笼包231个馒头') + nsw_test_case('有62%的概率') + + +if __name__ == '__main__': + # nsw_test() + + p = argparse.ArgumentParser() + p.add_argument('ifile', help='input filename, assume utf-8 encoding') + p.add_argument('ofile', help='output filename') + p.add_argument('--to_upper', action='store_true', help='convert to upper case') + p.add_argument('--to_lower', action='store_true', help='convert to lower case') + p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.") + p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines') + args = p.parse_args() + + ifile = codecs.open(args.ifile, 'r', 'utf8') + ofile = codecs.open(args.ofile, 'w+', 'utf8') + + n = 0 + for l in ifile: + key = '' + text = '' + if args.has_key: + cols = l.split(maxsplit=1) + key = cols[0] + if len(cols) == 2: + text = cols[1] + else: + text = '' + else: + text = l + + # cases + if args.to_upper and args.to_lower: + sys.stderr.write('text norm: to_upper OR to_lower?') + exit(1) + if args.to_upper: + text = text.upper() + if args.to_lower: + text = text.lower() + + # NSW(Non-Standard-Word) normalization + text = NSWNormalizer(text).normalize() + + # + if args.has_key: + ofile.write(key + '\t' + text) + else: + ofile.write(text) + + n += 1 + if n % args.log_interval == 0: + sys.stderr.write("text norm: {} lines done.\n".format(n)) + + sys.stderr.write("text norm: {} lines done in total.\n".format(n)) + + ifile.close() + ofile.close() diff --git a/utils/training_utils.py b/utils/training_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..409b15388790b1aadb24632313bdd1f41b4b06ac --- /dev/null +++ b/utils/training_utils.py @@ -0,0 +1,27 @@ +from utils.hparams import hparams + + +class RSQRTSchedule(object): + def __init__(self, optimizer): + super().__init__() + self.optimizer = optimizer + self.constant_lr = hparams['lr'] + self.warmup_updates = hparams['warmup_updates'] + self.hidden_size = hparams['hidden_size'] + self.lr = hparams['lr'] + for param_group in optimizer.param_groups: + param_group['lr'] = self.lr + self.step(0) + + def step(self, num_updates): + constant_lr = self.constant_lr + warmup = min(num_updates / self.warmup_updates, 1.0) + rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5 + rsqrt_hidden = self.hidden_size ** -0.5 + self.lr = max(constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7) + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.lr + return self.lr + + def get_lr(self): + return self.optimizer.param_groups[0]['lr'] diff --git a/utils/tts_utils.py b/utils/tts_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5df7c177f934f9c07865eb4f5b053aff9620696d --- /dev/null +++ b/utils/tts_utils.py @@ -0,0 +1,24 @@ +import torch +import torch.nn.functional as F +from collections import defaultdict + + +def make_positions(tensor, padding_idx): + """Replace non-padding symbols with their position numbers. + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. + mask = tensor.ne(padding_idx).int() + return ( + torch.cumsum(mask, dim=1).type_as(mask) * mask + ).long() + padding_idx + +def fill_with_neg_inf2(t): + """FP16-compatible function that fills a tensor with -inf.""" + return t.float().fill_(-1e8).type_as(t) + +def softmax(x, dim): + return F.softmax(x, dim=dim, dtype=torch.float32) diff --git a/vocoders/__init__.py b/vocoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..66c318857ce48048437dede7072901ad6471b8fc --- /dev/null +++ b/vocoders/__init__.py @@ -0,0 +1 @@ +from vocoders import hifigan diff --git a/vocoders/base_vocoder.py b/vocoders/base_vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fe49a9e4f790ecdc5e76d60a23f96602b59fc48d --- /dev/null +++ b/vocoders/base_vocoder.py @@ -0,0 +1,39 @@ +import importlib +VOCODERS = {} + + +def register_vocoder(cls): + VOCODERS[cls.__name__.lower()] = cls + VOCODERS[cls.__name__] = cls + return cls + + +def get_vocoder_cls(hparams): + if hparams['vocoder'] in VOCODERS: + return VOCODERS[hparams['vocoder']] + else: + vocoder_cls = hparams['vocoder'] + pkg = ".".join(vocoder_cls.split(".")[:-1]) + cls_name = vocoder_cls.split(".")[-1] + vocoder_cls = getattr(importlib.import_module(pkg), cls_name) + return vocoder_cls + + +class BaseVocoder: + def spec2wav(self, mel): + """ + + :param mel: [T, 80] + :return: wav: [T'] + """ + + raise NotImplementedError + + @staticmethod + def wav2spec(wav_fn): + """ + + :param wav_fn: str + :return: wav, mel: [T, 80] + """ + raise NotImplementedError diff --git a/vocoders/hifigan.py b/vocoders/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..810d3c931b556387f8a2e85537f4964add1e76b0 --- /dev/null +++ b/vocoders/hifigan.py @@ -0,0 +1,76 @@ +import glob +import json +import os +import re + +import librosa +import torch + +import utils +from modules.hifigan.hifigan import HifiGanGenerator +from utils.hparams import hparams, set_hparams +from vocoders.base_vocoder import register_vocoder +from vocoders.pwg import PWG +from vocoders.vocoder_utils import denoise + + +def load_model(config_path, checkpoint_path): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + ckpt_dict = torch.load(checkpoint_path, map_location="cpu") + if '.yaml' in config_path: + config = set_hparams(config_path, global_hparams=False) + state = ckpt_dict["state_dict"]["model_gen"] + elif '.json' in config_path: + config = json.load(open(config_path, 'r')) + state = ckpt_dict["generator"] + + model = HifiGanGenerator(config) + model.load_state_dict(state, strict=True) + model.remove_weight_norm() + model = model.eval().to(device) + print(f"| Loaded model parameters from {checkpoint_path}.") + print(f"| HifiGAN device: {device}.") + return model, config, device + + +total_time = 0 + + +@register_vocoder +class HifiGAN(PWG): + def __init__(self): + base_dir = hparams['vocoder_ckpt'] + config_path = f'{base_dir}/config.yaml' + if os.path.exists(config_path): + ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key= + lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1] + print('| load HifiGAN: ', ckpt) + self.model, self.config, self.device = load_model(config_path=config_path, checkpoint_path=ckpt) + else: + config_path = f'{base_dir}/config.json' + ckpt = f'{base_dir}/generator_v1' + if os.path.exists(config_path): + self.model, self.config, self.device = load_model(config_path=config_path, checkpoint_path=ckpt) + + def spec2wav(self, mel, **kwargs): + device = self.device + with torch.no_grad(): + c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(device) + with utils.Timer('hifigan', print_time=hparams['profile_infer']): + f0 = kwargs.get('f0') + if f0 is not None and hparams.get('use_nsf'): + f0 = torch.FloatTensor(f0[None, :]).to(device) + y = self.model(c, f0).view(-1) + else: + y = self.model(c).view(-1) + wav_out = y.cpu().numpy() + if hparams.get('vocoder_denoise_c', 0.0) > 0: + wav_out = denoise(wav_out, v=hparams['vocoder_denoise_c']) + return wav_out + + # @staticmethod + # def wav2spec(wav_fn, **kwargs): + # wav, _ = librosa.core.load(wav_fn, sr=hparams['audio_sample_rate']) + # wav_torch = torch.FloatTensor(wav)[None, :] + # mel = mel_spectrogram(wav_torch, hparams).numpy()[0] + # return wav, mel.T diff --git a/vocoders/pwg.py b/vocoders/pwg.py new file mode 100644 index 0000000000000000000000000000000000000000..ca9b6891ab2ba5cb413eeca97a41534e5db129d5 --- /dev/null +++ b/vocoders/pwg.py @@ -0,0 +1,137 @@ +import glob +import re +import librosa +import torch +import yaml +from sklearn.preprocessing import StandardScaler +from torch import nn +from modules.parallel_wavegan.models import ParallelWaveGANGenerator +from modules.parallel_wavegan.utils import read_hdf5 +from utils.hparams import hparams +from utils.pitch_utils import f0_to_coarse +from vocoders.base_vocoder import BaseVocoder, register_vocoder +import numpy as np + + +def load_pwg_model(config_path, checkpoint_path, stats_path): + # load config + with open(config_path) as f: + config = yaml.load(f, Loader=yaml.Loader) + + # setup + if torch.cuda.is_available(): + device = torch.device("cuda") + else: + device = torch.device("cpu") + model = ParallelWaveGANGenerator(**config["generator_params"]) + + ckpt_dict = torch.load(checkpoint_path, map_location="cpu") + if 'state_dict' not in ckpt_dict: # official vocoder + model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["model"]["generator"]) + scaler = StandardScaler() + if config["format"] == "hdf5": + scaler.mean_ = read_hdf5(stats_path, "mean") + scaler.scale_ = read_hdf5(stats_path, "scale") + elif config["format"] == "npy": + scaler.mean_ = np.load(stats_path)[0] + scaler.scale_ = np.load(stats_path)[1] + else: + raise ValueError("support only hdf5 or npy format.") + else: # custom PWG vocoder + fake_task = nn.Module() + fake_task.model_gen = model + fake_task.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["state_dict"], strict=False) + scaler = None + + model.remove_weight_norm() + model = model.eval().to(device) + print(f"| Loaded model parameters from {checkpoint_path}.") + print(f"| PWG device: {device}.") + return model, scaler, config, device + + +@register_vocoder +class PWG(BaseVocoder): + def __init__(self): + if hparams['vocoder_ckpt'] == '': # load LJSpeech PWG pretrained model + base_dir = 'wavegan_pretrained' + ckpts = glob.glob(f'{base_dir}/checkpoint-*steps.pkl') + ckpt = sorted(ckpts, key= + lambda x: int(re.findall(f'{base_dir}/checkpoint-(\d+)steps.pkl', x)[0]))[-1] + config_path = f'{base_dir}/config.yaml' + print('| load PWG: ', ckpt) + self.model, self.scaler, self.config, self.device = load_pwg_model( + config_path=config_path, + checkpoint_path=ckpt, + stats_path=f'{base_dir}/stats.h5', + ) + else: + base_dir = hparams['vocoder_ckpt'] + print(base_dir) + config_path = f'{base_dir}/config.yaml' + ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'), key= + lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1] + print('| load PWG: ', ckpt) + self.scaler = None + self.model, _, self.config, self.device = load_pwg_model( + config_path=config_path, + checkpoint_path=ckpt, + stats_path=f'{base_dir}/stats.h5', + ) + + def spec2wav(self, mel, **kwargs): + # start generation + config = self.config + device = self.device + pad_size = (config["generator_params"]["aux_context_window"], + config["generator_params"]["aux_context_window"]) + c = mel + if self.scaler is not None: + c = self.scaler.transform(c) + + with torch.no_grad(): + z = torch.randn(1, 1, c.shape[0] * config["hop_size"]).to(device) + c = np.pad(c, (pad_size, (0, 0)), "edge") + c = torch.FloatTensor(c).unsqueeze(0).transpose(2, 1).to(device) + p = kwargs.get('f0') + if p is not None: + p = f0_to_coarse(p) + p = np.pad(p, (pad_size,), "edge") + p = torch.LongTensor(p[None, :]).to(device) + y = self.model(z, c, p).view(-1) + wav_out = y.cpu().numpy() + return wav_out + + @staticmethod + def wav2spec(wav_fn, return_linear=False): + from data_gen.tts.data_gen_utils import process_utterance + res = process_utterance( + wav_fn, fft_size=hparams['fft_size'], + hop_size=hparams['hop_size'], + win_length=hparams['win_size'], + num_mels=hparams['audio_num_mel_bins'], + fmin=hparams['fmin'], + fmax=hparams['fmax'], + sample_rate=hparams['audio_sample_rate'], + loud_norm=hparams['loud_norm'], + min_level_db=hparams['min_level_db'], + return_linear=return_linear, vocoder='pwg', eps=float(hparams.get('wav2spec_eps', 1e-10))) + if return_linear: + return res[0], res[1].T, res[2].T # [T, 80], [T, n_fft] + else: + return res[0], res[1].T + + @staticmethod + def wav2mfcc(wav_fn): + fft_size = hparams['fft_size'] + hop_size = hparams['hop_size'] + win_length = hparams['win_size'] + sample_rate = hparams['audio_sample_rate'] + wav, _ = librosa.core.load(wav_fn, sr=sample_rate) + mfcc = librosa.feature.mfcc(y=wav, sr=sample_rate, n_mfcc=13, + n_fft=fft_size, hop_length=hop_size, + win_length=win_length, pad_mode="constant", power=1.0) + mfcc_delta = librosa.feature.delta(mfcc, order=1) + mfcc_delta_delta = librosa.feature.delta(mfcc, order=2) + mfcc = np.concatenate([mfcc, mfcc_delta, mfcc_delta_delta]).T + return mfcc diff --git a/vocoders/vocoder_utils.py b/vocoders/vocoder_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..98792f79b1ccdf9a136e8b9526b8e2ea65a4c004 --- /dev/null +++ b/vocoders/vocoder_utils.py @@ -0,0 +1,15 @@ +import librosa + +from utils.hparams import hparams +import numpy as np + + +def denoise(wav, v=0): + spec = librosa.stft(y=wav, n_fft=hparams['fft_size'], hop_length=hparams['hop_size'], + win_length=hparams['win_size'], pad_mode='constant') + spec_m = np.abs(spec) + spec_m = np.clip(spec_m - v, a_min=0, a_max=None) + spec_a = np.angle(spec) + + return librosa.istft(spec_m * np.exp(1j * spec_a), hop_length=hparams['hop_size'], + win_length=hparams['win_size'])