diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..8693ea7bdbee695fd0c934b72ea9485f5ffe7147 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,32 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8ca5096f34244d1078a811f3c6d21d19eb9d3a44 --- /dev/null +++ b/.gitignore @@ -0,0 +1,151 @@ +### Project ignore + +/ParallelWaveGAN +/wavegan_pretrained* +/pretrained_models +rsync +.idea +.DS_Store +bak +tmp +*.tar.gz +# mfa and kaldi +kaldi_align/exp +mfa +montreal-forced-aligner +mos +nbs +/configs_usr/* +!/configs_usr/.gitkeep +/fast_transformers +/rnnoise +/usr/* +!/usr/.gitkeep + +# Created by .ignore support plugin (hsz.mobi) +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..563c90c56cc4bea1f10c6307123319d92441b556
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Jinglin Liu
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7db9b17d3c23482bd97236d4cd83de574d8f1561 --- /dev/null +++ b/README.md @@ -0,0 +1,10 @@ +--- +title: ProDiff +emoji: 🤗 +colorFrom: yellow +colorTo: orange +sdk: gradio +app_file: "inference/gradio/infer.py" +pinned: false +--- + diff --git a/checkpoints/FastDiff/config.yaml b/checkpoints/FastDiff/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6cf5325101f782ab41eff945a16ab23a32f73d65 --- /dev/null +++ b/checkpoints/FastDiff/config.yaml @@ -0,0 +1,149 @@ +N: '' +T: 1000 +accumulate_grad_batches: 1 +amp: false +audio_channels: 1 +audio_num_mel_bins: 80 +audio_sample_rate: 22050 +aux_context_window: 0 +beta_0: 1.0e-06 +beta_T: 0.01 +binarization_args: + reset_phone_dict: true + reset_word_dict: true + shuffle: false + trim_eos_bos: false + with_align: false + with_f0: false + with_f0cwt: false + with_linear: false + with_spk_embed: false + with_spk_id: true + with_txt: false + with_wav: true + with_word: false +binarizer_cls: data_gen.tts.vocoder_binarizer.VocoderBinarizer +binary_data_dir: data/binary/LJSpeech +check_val_every_n_epoch: 10 +clip_grad_norm: 1 +clip_grad_value: 0 +cond_channels: 80 +debug: false +dec_ffn_kernel_size: 9 +dec_layers: 4 +dict_dir: '' +diffusion_step_embed_dim_in: 128 +diffusion_step_embed_dim_mid: 512 +diffusion_step_embed_dim_out: 512 +disc_start_steps: 40000 +discriminator_grad_norm: 1 +dropout: 0.0 +ds_workers: 1 +enc_ffn_kernel_size: 9 +enc_layers: 4 +endless_ds: true +eval_max_batches: -1 +ffn_act: gelu +ffn_padding: SAME +fft_size: 1024 +fmax: 7600 +fmin: 80 +frames_multiple: 1 +gen_dir_name: '' +generator_grad_norm: 10 +griffin_lim_iters: 60 +hidden_size: 256 +hop_size: 256 +infer: false +inner_channels: 32 +kpnet_conv_size: 3 +kpnet_hidden_channels: 64 +load_ckpt: '' +loud_norm: false +lr: 2e-4 +lvc_kernel_size: 3 +lvc_layers_each_block: 4 +max_epochs: 1000 +max_frames: 1548 +max_input_tokens: 1550 +max_samples: 25600 +max_sentences: 20 +max_tokens: 30000 +max_updates: 1000000 +max_valid_sentences: 1 +max_valid_tokens: 60000 +mel_loss: l1 +mel_vmax: 1.5 +mel_vmin: -6 +mfa_version: 2 +min_frames: 0 +min_level_db: -100 +noise_schedule: '' +num_ckpt_keep: 3 +num_heads: 2 +num_mels: 80 +num_sanity_val_steps: -1 +num_spk: 400 +num_test_samples: 0 +num_valid_plots: 10 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.98 +out_wav_norm: false +pitch_extractor: parselmouth +pre_align_args: + allow_no_txt: false + denoise: false + nsample_per_mfa_group: 1000 + sox_resample: false + sox_to_wav: false + trim_sil: false + txt_processor: en + use_tone: true +pre_align_cls: egs.datasets.audio.pre_align.PreAlign +print_nan_grads: false +processed_data_dir: data/processed/LJSpeech +profile_infer: false +raw_data_dir: data/raw/LJSpeech-1.1 +ref_level_db: 20 +rename_tmux: true +resume_from_checkpoint: 0 +save_best: true +save_codes: [] +save_f0: false +save_gt: true +scheduler: rsqrt +seed: 1234 +sort_by_len: true +task_cls: modules.FastDiff.task.FastDiff.FastDiffTask +tb_log_interval: 100 +test_ids: [] +test_input_dir: '' +test_mel_dir: '' +test_num: 100 +test_set_name: test +train_set_name: train +train_sets: '' +upsample_ratios: +- 8 +- 8 +- 4 +use_pitch_embed: false +use_spk_embed: false +use_spk_id: false +use_split_spk_id: false +use_wav: true +use_weight_norm: true +use_word_input: false +val_check_interval: 2000 +valid_infer_interval: 10000 +valid_monitor_key: val_loss +valid_monitor_mode: min 
+valid_set_name: valid +vocoder_denoise_c: 0.0 +warmup_updates: 8000 +weight_decay: 0 +win_length: null +win_size: 1024 +window: hann +word_size: 30000 +work_dir: checkpoints/FastDiff diff --git a/checkpoints/FastDiff/model_ckpt_steps_500000.ckpt b/checkpoints/FastDiff/model_ckpt_steps_500000.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..226e9f776d4950cf1711a08420a51c0c1aa0c526 --- /dev/null +++ b/checkpoints/FastDiff/model_ckpt_steps_500000.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ee7b6022e525c71a6025b41eeeafff9d6186b52cba76b580d6986bc8674902f3 +size 183951271 diff --git a/checkpoints/ProDiff/config.yaml b/checkpoints/ProDiff/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..27aa0559523fb9aa3fc0c92ef77a07ab9eece19f --- /dev/null +++ b/checkpoints/ProDiff/config.yaml @@ -0,0 +1,205 @@ +accumulate_grad_batches: 1 +amp: false +audio_num_mel_bins: 80 +audio_sample_rate: 22050 +base_config: +- ./base.yaml +binarization_args: + reset_phone_dict: true + reset_word_dict: true + shuffle: false + trim_eos_bos: false + trim_sil: false + with_align: true + with_f0: true + with_f0cwt: false + with_linear: false + with_spk_embed: false + with_spk_id: true + with_txt: true + with_wav: false + with_word: true +binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer +binary_data_dir: data/binary/LJSpeech +check_val_every_n_epoch: 10 +clip_grad_norm: 1 +clip_grad_value: 0 +conv_use_pos: false +cwt_add_f0_loss: false +cwt_hidden_size: 128 +cwt_layers: 2 +cwt_loss: l1 +cwt_std_scale: 0.8 +debug: false +dec_dilations: +- 1 +- 1 +- 1 +- 1 +dec_ffn_kernel_size: 9 +dec_inp_add_noise: false +dec_kernel_size: 5 +dec_layers: 4 +dec_num_heads: 2 +decoder_rnn_dim: 0 +decoder_type: fft +dict_dir: '' +diff_decoder_type: wavenet +diff_loss_type: l1 +dilation_cycle_length: 1 +dropout: 0.1 +ds_workers: 2 +dur_enc_hidden_stride_kernel: +- 0,2,3 +- 0,2,3 +- 0,1,3 +dur_loss: mse +dur_predictor_kernel: 3 +dur_predictor_layers: 2 +enc_dec_norm: ln +enc_dilations: +- 1 +- 1 +- 1 +- 1 +enc_ffn_kernel_size: 9 +enc_kernel_size: 5 +enc_layers: 4 +encoder_K: 8 +encoder_type: fft +endless_ds: true +ffn_act: gelu +ffn_hidden_size: 1024 +ffn_padding: SAME +fft_size: 1024 +fmax: 7600 +fmin: 80 +frames_multiple: 1 +gen_dir_name: '' +gen_tgt_spk_id: -1 +griffin_lim_iters: 60 +hidden_size: 256 +hop_size: 256 +infer: false +keep_bins: 80 +lambda_commit: 0.25 +lambda_energy: 0.1 +lambda_f0: 1.0 +lambda_ph_dur: 0.1 +lambda_sent_dur: 1.0 +lambda_uv: 1.0 +lambda_word_dur: 1.0 +layers_in_block: 2 +load_ckpt: '' +loud_norm: false +lr: 1.0 +max_beta: 0.06 +max_epochs: 1000 +max_frames: 1548 +max_input_tokens: 1550 +max_sentences: 48 +max_tokens: 32000 +max_updates: 200000 +max_valid_sentences: 1 +max_valid_tokens: 60000 +mel_loss: ssim:0.5|l1:0.5 +mel_vmax: 1.5 +mel_vmin: -6 +min_frames: 0 +min_level_db: -100 +num_ckpt_keep: 3 +num_heads: 2 +num_sanity_val_steps: -1 +num_spk: 1 +num_test_samples: 0 +num_valid_plots: 10 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.98 +out_wav_norm: false +pitch_ar: false +pitch_embed_type: 0 +pitch_enc_hidden_stride_kernel: +- 0,2,5 +- 0,2,5 +- 0,2,5 +pitch_extractor: parselmouth +pitch_loss: l1 +pitch_norm: standard +pitch_ssim_win: 11 +pitch_type: frame +pre_align_args: + allow_no_txt: false + denoise: false + sox_resample: false + sox_to_wav: false + trim_sil: false + txt_processor: en + use_tone: true +pre_align_cls: '' +predictor_dropout: 0.5 +predictor_grad: 0.1 +predictor_hidden: -1 +predictor_kernel: 5 
+predictor_layers: 2 +pretrain_fs_ckpt: '' +print_nan_grads: false +processed_data_dir: data/processed/LJSpeech +profile_infer: false +raw_data_dir: data/raw/LJSpeech +ref_hidden_stride_kernel: +- 0,3,5 +- 0,3,5 +- 0,2,5 +- 0,2,5 +- 0,2,5 +ref_level_db: 20 +ref_norm_layer: bn +rename_tmux: true +residual_channels: 256 +residual_layers: 20 +resume_from_checkpoint: 0 +save_best: true +save_codes: [] +save_f0: false +save_gt: true +schedule_type: vpsde +scheduler: rsqrt +seed: 1234 +sil_add_noise: false +sort_by_len: true +spec_max: [] +spec_min: [] +task_cls: modules.ProDiff.task.ProDiff_task.ProDiff_Task +tb_log_interval: 100 +teacher_ckpt: checkpoints/ProDiff_Teacher/model_ckpt_steps_188000.ckpt +test_ids: [] +test_input_dir: '' +test_num: 100 +test_set_name: test +timesteps: 4 +train_set_name: train +train_sets: '' +use_cond_disc: true +use_energy_embed: true +use_gt_dur: true +use_gt_f0: true +use_pitch_embed: true +use_pos_embed: true +use_ref_enc: false +use_spk_embed: false +use_spk_id: false +use_split_spk_id: false +use_uv: true +use_var_enc: false +val_check_interval: 2000 +valid_infer_interval: 10000 +valid_monitor_key: val_loss +valid_monitor_mode: min +valid_set_name: valid +var_enc_vq_codes: 64 +vocoder_denoise_c: 0.0 +warmup_updates: 2000 +weight_decay: 0 +win_size: 1024 +word_size: 30000 +work_dir: checkpoints/ProDiff diff --git a/checkpoints/ProDiff/model_ckpt_steps_200000.ckpt b/checkpoints/ProDiff/model_ckpt_steps_200000.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..2e8e845b8cd31da66a969c32db251f7bd55af470 --- /dev/null +++ b/checkpoints/ProDiff/model_ckpt_steps_200000.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8cc8aad355c297b010e2c362341f736b3477744af76e02f6c9965409a7e9113a +size 349055740 diff --git a/checkpoints/ProDiff_Teacher/config.yaml b/checkpoints/ProDiff_Teacher/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd6983c1d8889954b0a1dfdf34c66f72a4407bbf --- /dev/null +++ b/checkpoints/ProDiff_Teacher/config.yaml @@ -0,0 +1,205 @@ +accumulate_grad_batches: 1 +amp: false +audio_num_mel_bins: 80 +audio_sample_rate: 22050 +base_config: +- ./base.yaml +binarization_args: + reset_phone_dict: true + reset_word_dict: true + shuffle: false + trim_eos_bos: false + trim_sil: false + with_align: true + with_f0: true + with_f0cwt: false + with_linear: false + with_spk_embed: false + with_spk_id: true + with_txt: true + with_wav: false + with_word: true +binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer +binary_data_dir: data/binary/LJSpeech +check_val_every_n_epoch: 10 +clip_grad_norm: 1 +clip_grad_value: 0 +conv_use_pos: false +cwt_add_f0_loss: false +cwt_hidden_size: 128 +cwt_layers: 2 +cwt_loss: l1 +cwt_std_scale: 0.8 +debug: false +dec_dilations: +- 1 +- 1 +- 1 +- 1 +dec_ffn_kernel_size: 9 +dec_inp_add_noise: false +dec_kernel_size: 5 +dec_layers: 4 +dec_num_heads: 2 +decoder_rnn_dim: 0 +decoder_type: fft +dict_dir: '' +diff_decoder_type: wavenet +diff_loss_type: l1 +dilation_cycle_length: 1 +dropout: 0.1 +ds_workers: 2 +dur_enc_hidden_stride_kernel: +- 0,2,3 +- 0,2,3 +- 0,1,3 +dur_loss: mse +dur_predictor_kernel: 3 +dur_predictor_layers: 2 +enc_dec_norm: ln +enc_dilations: +- 1 +- 1 +- 1 +- 1 +enc_ffn_kernel_size: 9 +enc_kernel_size: 5 +enc_layers: 4 +encoder_K: 8 +encoder_type: fft +endless_ds: true +ffn_act: gelu +ffn_hidden_size: 1024 +ffn_padding: SAME +fft_size: 1024 +fmax: 7600 +fmin: 80 +frames_multiple: 1 +gen_dir_name: '' +gen_tgt_spk_id: -1 +griffin_lim_iters: 60 
+hidden_size: 256
+hop_size: 256
+infer: false
+keep_bins: 80
+lambda_commit: 0.25
+lambda_energy: 0.1
+lambda_f0: 1.0
+lambda_ph_dur: 0.1
+lambda_sent_dur: 1.0
+lambda_uv: 1.0
+lambda_word_dur: 1.0
+layers_in_block: 2
+load_ckpt: ''
+loud_norm: false
+lr: 1.0
+max_beta: 0.06
+max_epochs: 1000
+max_frames: 1548
+max_input_tokens: 1550
+max_sentences: 48
+max_tokens: 32000
+max_updates: 200000
+max_valid_sentences: 1
+max_valid_tokens: 60000
+mel_loss: ssim:0.5|l1:0.5
+mel_vmax: 1.5
+mel_vmin: -6
+min_frames: 0
+min_level_db: -100
+num_ckpt_keep: 3
+num_heads: 2
+num_sanity_val_steps: -1
+num_spk: 1
+num_test_samples: 20
+num_valid_plots: 10
+optimizer_adam_beta1: 0.9
+optimizer_adam_beta2: 0.98
+out_wav_norm: false
+pitch_ar: false
+pitch_embed_type: 0
+pitch_enc_hidden_stride_kernel:
+- 0,2,5
+- 0,2,5
+- 0,2,5
+pitch_extractor: parselmouth
+pitch_loss: l1
+pitch_norm: standard
+pitch_ssim_win: 11
+pitch_type: frame
+pre_align_args:
+  allow_no_txt: false
+  denoise: false
+  sox_resample: false
+  sox_to_wav: false
+  trim_sil: false
+  txt_processor: en
+  use_tone: true
+pre_align_cls: egs.datasets.audio.lj.pre_align.LJPreAlign
+predictor_dropout: 0.5
+predictor_grad: 0.1
+predictor_hidden: -1
+predictor_kernel: 5
+predictor_layers: 2
+pretrain_fs_ckpt: ''
+print_nan_grads: false
+processed_data_dir: data/processed/LJSpeech
+profile_infer: false
+raw_data_dir: data/raw/LJSpeech
+ref_hidden_stride_kernel:
+- 0,3,5
+- 0,3,5
+- 0,2,5
+- 0,2,5
+- 0,2,5
+ref_level_db: 20
+ref_norm_layer: bn
+rename_tmux: true
+residual_channels: 256
+residual_layers: 20
+resume_from_checkpoint: 0
+save_best: true
+save_codes: []
+save_f0: false
+save_gt: true
+schedule_type: vpsde
+scheduler: rsqrt
+seed: 1234
+sil_add_noise: false
+sort_by_len: true
+spec_max: []
+spec_min: []
+task_cls: modules.ProDiff.task.ProDiff_teacher_task.ProDiff_teacher_Task
+tb_log_interval: 100
+test_ids: []
+test_input_dir: ''
+test_num: 100
+test_set_name: test
+timescale: 1
+timesteps: 4
+train_set_name: train
+train_sets: ''
+use_cond_disc: true
+use_energy_embed: true
+use_gt_dur: true
+use_gt_f0: true
+use_pitch_embed: true
+use_pos_embed: true
+use_ref_enc: false
+use_spk_embed: false
+use_spk_id: false
+use_split_spk_id: false
+use_uv: true
+use_var_enc: false
+val_check_interval: 2000
+valid_infer_interval: 10000
+valid_monitor_key: val_loss
+valid_monitor_mode: min
+valid_set_name: valid
+var_enc_vq_codes: 64
+vocoder_denoise_c: 0.0
+warmup_updates: 2000
+weight_decay: 0
+win_size: 1024
+word_size: 30000
+work_dir: checkpoints/ProDiff_Teacher1
diff --git a/checkpoints/ProDiff_Teacher/model_ckpt_steps_188000.ckpt b/checkpoints/ProDiff_Teacher/model_ckpt_steps_188000.ckpt
new file mode 100644
index 0000000000000000000000000000000000000000..a3ca44b2160a0ff0ade91e6e52a6c06b87f0d1f7
--- /dev/null
+++ b/checkpoints/ProDiff_Teacher/model_ckpt_steps_188000.ckpt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:5d3d02a215431c69dd54c1413b9a02cdc32795e2039ad9be857b12e85c470eea
+size 342252871
diff --git a/data/binary/LJSpeech/phone_set.json b/data/binary/LJSpeech/phone_set.json
new file mode 100644
index 0000000000000000000000000000000000000000..1d097037b62382bae4893512cc291d8699c0b049
--- /dev/null
+++ b/data/binary/LJSpeech/phone_set.json
@@ -0,0 +1 @@
+["!", ",", ".", ":", ";", "<BOS>", "<EOS>", "?", "AA0", "AA1", "AA2", "AE0", "AE1", "AE2", "AH0", "AH1", "AH2", "AO0", "AO1", "AO2", "AW0", "AW1", "AW2", "AY0", "AY1", "AY2", "B", "CH", "D", "DH", "EH0", "EH1", "EH2", "ER0", "ER1", "ER2", "EY0", "EY1", "EY2", "F", "G", "HH",
"IH0", "IH1", "IH2", "IY0", "IY1", "IY2", "JH", "K", "L", "M", "N", "NG", "OW0", "OW1", "OW2", "OY0", "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH0", "UH1", "UH2", "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH", "|"] \ No newline at end of file diff --git a/data/binary/LJSpeech/spk_map.json b/data/binary/LJSpeech/spk_map.json new file mode 100644 index 0000000000000000000000000000000000000000..15bba8f120494a14ecd1308c6f534ec7e4322391 --- /dev/null +++ b/data/binary/LJSpeech/spk_map.json @@ -0,0 +1 @@ +{"SPK1": 0} \ No newline at end of file diff --git a/data/binary/LJSpeech/train_f0s_mean_std.npy b/data/binary/LJSpeech/train_f0s_mean_std.npy new file mode 100644 index 0000000000000000000000000000000000000000..42b6fc952934d7e7aedc61c7975ef437c4981d08 --- /dev/null +++ b/data/binary/LJSpeech/train_f0s_mean_std.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8790d5a84d77143690ae71a1f1e7fc81359e69ead263dc440366f2164c739efd +size 144 diff --git a/data_gen/tts/base_binarizer.py b/data_gen/tts/base_binarizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b30a20c1cdc3403214ff527d68a50806befafeb9 --- /dev/null +++ b/data_gen/tts/base_binarizer.py @@ -0,0 +1,224 @@ +import os +os.environ["OMP_NUM_THREADS"] = "1" + +from utils.multiprocess_utils import chunked_multiprocess_run +import random +import traceback +import json +from resemblyzer import VoiceEncoder +from tqdm import tqdm +from data_gen.tts.data_gen_utils import get_mel2ph, get_pitch, build_phone_encoder +from utils.hparams import set_hparams, hparams +import numpy as np +from utils.indexed_datasets import IndexedDatasetBuilder +from vocoders.base_vocoder import VOCODERS +import pandas as pd + + +class BinarizationError(Exception): + pass + + +class BaseBinarizer: + def __init__(self, processed_data_dir=None): + if processed_data_dir is None: + processed_data_dir = hparams['processed_data_dir'] + self.processed_data_dirs = processed_data_dir.split(",") + self.binarization_args = hparams['binarization_args'] + self.pre_align_args = hparams['pre_align_args'] + self.forced_align = self.pre_align_args['forced_align'] + tg_dir = None + if self.forced_align == 'mfa': + tg_dir = 'mfa_outputs' + if self.forced_align == 'kaldi': + tg_dir = 'kaldi_outputs' + self.item2txt = {} + self.item2ph = {} + self.item2wavfn = {} + self.item2tgfn = {} + self.item2spk = {} + for ds_id, processed_data_dir in enumerate(self.processed_data_dirs): + self.meta_df = pd.read_csv(f"{processed_data_dir}/metadata_phone.csv", dtype=str) + for r_idx, r in self.meta_df.iterrows(): + item_name = raw_item_name = r['item_name'] + if len(self.processed_data_dirs) > 1: + item_name = f'ds{ds_id}_{item_name}' + self.item2txt[item_name] = r['txt'] + self.item2ph[item_name] = r['ph'] + self.item2wavfn[item_name] = os.path.join(hparams['raw_data_dir'], 'wavs', os.path.basename(r['wav_fn']).split('_')[1]) + self.item2spk[item_name] = r.get('spk', 'SPK1') + if len(self.processed_data_dirs) > 1: + self.item2spk[item_name] = f"ds{ds_id}_{self.item2spk[item_name]}" + if tg_dir is not None: + self.item2tgfn[item_name] = f"{processed_data_dir}/{tg_dir}/{raw_item_name}.TextGrid" + self.item_names = sorted(list(self.item2txt.keys())) + if self.binarization_args['shuffle']: + random.seed(1234) + random.shuffle(self.item_names) + + @property + def train_item_names(self): + return self.item_names[hparams['test_num']+hparams['valid_num']:] + + @property + def valid_item_names(self): + return self.item_names[0: 
hparams['test_num']+hparams['valid_num']] # + + @property + def test_item_names(self): + return self.item_names[0: hparams['test_num']] # Audios for MOS testing are in 'test_ids' + + def build_spk_map(self): + spk_map = set() + for item_name in self.item_names: + spk_name = self.item2spk[item_name] + spk_map.add(spk_name) + spk_map = {x: i for i, x in enumerate(sorted(list(spk_map)))} + assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map) + return spk_map + + def item_name2spk_id(self, item_name): + return self.spk_map[self.item2spk[item_name]] + + def _phone_encoder(self): + ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json" + ph_set = [] + if hparams['reset_phone_dict'] or not os.path.exists(ph_set_fn): + for processed_data_dir in self.processed_data_dirs: + ph_set += [x.split(' ')[0] for x in open(f'{processed_data_dir}/dict.txt').readlines()] + ph_set = sorted(set(ph_set)) + json.dump(ph_set, open(ph_set_fn, 'w')) + else: + ph_set = json.load(open(ph_set_fn, 'r')) + print("| phone set: ", ph_set) + return build_phone_encoder(hparams['binary_data_dir']) + + def meta_data(self, prefix): + if prefix == 'valid': + item_names = self.valid_item_names + elif prefix == 'test': + item_names = self.test_item_names + else: + item_names = self.train_item_names + for item_name in item_names: + ph = self.item2ph[item_name] + txt = self.item2txt[item_name] + tg_fn = self.item2tgfn.get(item_name) + wav_fn = self.item2wavfn[item_name] + spk_id = self.item_name2spk_id(item_name) + yield item_name, ph, txt, tg_fn, wav_fn, spk_id + + def process(self): + os.makedirs(hparams['binary_data_dir'], exist_ok=True) + self.spk_map = self.build_spk_map() + print("| spk_map: ", self.spk_map) + spk_map_fn = f"{hparams['binary_data_dir']}/spk_map.json" + json.dump(self.spk_map, open(spk_map_fn, 'w')) + + self.phone_encoder = self._phone_encoder() + self.process_data('valid') + self.process_data('test') + self.process_data('train') + + def process_data(self, prefix): + data_dir = hparams['binary_data_dir'] + args = [] + builder = IndexedDatasetBuilder(f'{data_dir}/{prefix}') + lengths = [] + f0s = [] + total_sec = 0 + if self.binarization_args['with_spk_embed']: + voice_encoder = VoiceEncoder().cuda() + + meta_data = list(self.meta_data(prefix)) + for m in meta_data: + args.append(list(m) + [self.phone_encoder, self.binarization_args]) + num_workers = int(os.getenv('N_PROC', os.cpu_count() // 3)) + for f_id, (_, item) in enumerate( + zip(tqdm(meta_data), chunked_multiprocess_run(self.process_item, args, num_workers=num_workers))): + if item is None: + continue + item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \ + if self.binarization_args['with_spk_embed'] else None + if not self.binarization_args['with_wav'] and 'wav' in item: + print("del wav") + del item['wav'] + builder.add_item(item) + lengths.append(item['len']) + total_sec += item['sec'] + if item.get('f0') is not None: + f0s.append(item['f0']) + builder.finalize() + np.save(f'{data_dir}/{prefix}_lengths.npy', lengths) + if len(f0s) > 0: + f0s = np.concatenate(f0s, 0) + f0s = f0s[f0s != 0] + np.save(f'{data_dir}/{prefix}_f0s_mean_std.npy', [np.mean(f0s).item(), np.std(f0s).item()]) + print(f"| {prefix} total duration: {total_sec:.3f}s") + + @classmethod + def process_item(cls, item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, binarization_args): + if hparams['vocoder'] in VOCODERS: + wav, mel = VOCODERS[hparams['vocoder']].wav2spec(wav_fn) + else: + wav, mel = 
VOCODERS[hparams['vocoder'].split('.')[-1]].wav2spec(wav_fn) + res = { + 'item_name': item_name, 'txt': txt, 'ph': ph, 'mel': mel, 'wav': wav, 'wav_fn': wav_fn, + 'sec': len(wav) / hparams['audio_sample_rate'], 'len': mel.shape[0], 'spk_id': spk_id + } + try: + if binarization_args['with_f0']: + cls.get_pitch(wav, mel, res) + if binarization_args['with_f0cwt']: + cls.get_f0cwt(res['f0'], res) + if binarization_args['with_txt']: + try: + phone_encoded = res['phone'] = encoder.encode(ph) + except: + traceback.print_exc() + raise BinarizationError(f"Empty phoneme") + if binarization_args['with_align']: + cls.get_align(tg_fn, ph, mel, phone_encoded, res) + except BinarizationError as e: + print(f"| Skip item ({e}). item_name: {item_name}, wav_fn: {wav_fn}") + return None + return res + + @staticmethod + def get_align(tg_fn, ph, mel, phone_encoded, res): + if tg_fn is not None and os.path.exists(tg_fn): + mel2ph, dur = get_mel2ph(tg_fn, ph, mel, hparams) + else: + raise BinarizationError(f"Align not found") + if mel2ph.max() - 1 >= len(phone_encoded): + raise BinarizationError( + f"Align does not match: mel2ph.max() - 1: {mel2ph.max() - 1}, len(phone_encoded): {len(phone_encoded)}") + res['mel2ph'] = mel2ph + res['dur'] = dur + + @staticmethod + def get_pitch(wav, mel, res): + f0, pitch_coarse = get_pitch(wav, mel, hparams) + if sum(f0) == 0: + raise BinarizationError("Empty f0") + res['f0'] = f0 + res['pitch'] = pitch_coarse + + @staticmethod + def get_f0cwt(f0, res): + from utils.cwt import get_cont_lf0, get_lf0_cwt + uv, cont_lf0_lpf = get_cont_lf0(f0) + logf0s_mean_org, logf0s_std_org = np.mean(cont_lf0_lpf), np.std(cont_lf0_lpf) + cont_lf0_lpf_norm = (cont_lf0_lpf - logf0s_mean_org) / logf0s_std_org + Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm) + if np.any(np.isnan(Wavelet_lf0)): + raise BinarizationError("NaN CWT") + res['cwt_spec'] = Wavelet_lf0 + res['cwt_scales'] = scales + res['f0_mean'] = logf0s_mean_org + res['f0_std'] = logf0s_std_org + + +if __name__ == "__main__": + set_hparams() + BaseBinarizer().process() diff --git a/data_gen/tts/base_preprocess.py b/data_gen/tts/base_preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..6c0b2cda06076d32b4eda800b134415e20d0f730 --- /dev/null +++ b/data_gen/tts/base_preprocess.py @@ -0,0 +1,245 @@ +import json +import os +import random +import re +import traceback +from collections import Counter +from functools import partial + +import librosa +from tqdm import tqdm +from data_gen.tts.txt_processors.base_text_processor import get_txt_processor_cls +from data_gen.tts.wav_processors.base_processor import get_wav_processor_cls +from utils.hparams import hparams +from utils.multiprocess_utils import multiprocess_run_tqdm +from utils.os_utils import link_file, move_file, remove_file +from data_gen.tts.data_gen_utils import is_sil_phoneme, build_token_encoder + + +class BasePreprocessor: + def __init__(self): + self.preprocess_args = hparams['preprocess_args'] + txt_processor = self.preprocess_args['txt_processor'] + self.txt_processor = get_txt_processor_cls(txt_processor) + self.raw_data_dir = hparams['raw_data_dir'] + self.processed_dir = hparams['processed_data_dir'] + self.spk_map_fn = f"{self.processed_dir}/spk_map.json" + + def meta_data(self): + """ + :return: {'item_name': Str, 'wav_fn': Str, 'txt': Str, 'spk_name': Str, 'txt_loader': None or Func} + """ + raise NotImplementedError + + def process(self): + processed_dir = self.processed_dir + wav_processed_tmp_dir = f'{processed_dir}/processed_tmp' + 
remove_file(wav_processed_tmp_dir) + os.makedirs(wav_processed_tmp_dir, exist_ok=True) + wav_processed_dir = f'{processed_dir}/{self.wav_processed_dirname}' + remove_file(wav_processed_dir) + os.makedirs(wav_processed_dir, exist_ok=True) + + meta_data = list(tqdm(self.meta_data(), desc='Load meta data')) + item_names = [d['item_name'] for d in meta_data] + assert len(item_names) == len(set(item_names)), 'Key `item_name` should be Unique.' + + # preprocess data + phone_list = [] + word_list = [] + spk_names = set() + process_item = partial(self.preprocess_first_pass, + txt_processor=self.txt_processor, + wav_processed_dir=wav_processed_dir, + wav_processed_tmp=wav_processed_tmp_dir, + preprocess_args=self.preprocess_args) + items = [] + args = [{ + 'item_name': item_raw['item_name'], + 'txt_raw': item_raw['txt'], + 'wav_fn': item_raw['wav_fn'], + 'txt_loader': item_raw.get('txt_loader'), + 'others': item_raw.get('others', None) + } for item_raw in meta_data] + for item_, (item_id, item) in zip(meta_data, multiprocess_run_tqdm(process_item, args, desc='Preprocess')): + if item is not None: + item_.update(item) + item = item_ + if 'txt_loader' in item: + del item['txt_loader'] + item['id'] = item_id + item['spk_name'] = item.get('spk_name', '') + item['others'] = item.get('others', None) + phone_list += item['ph'].split(" ") + word_list += item['word'].split(" ") + spk_names.add(item['spk_name']) + items.append(item) + + # add encoded tokens + ph_encoder, word_encoder = self._phone_encoder(phone_list), self._word_encoder(word_list) + spk_map = self.build_spk_map(spk_names) + args = [{ + 'ph': item['ph'], 'word': item['word'], 'spk_name': item['spk_name'], + 'word_encoder': word_encoder, 'ph_encoder': ph_encoder, 'spk_map': spk_map + } for item in items] + for idx, item_new_kv in multiprocess_run_tqdm(self.preprocess_second_pass, args, desc='Add encoded tokens'): + items[idx].update(item_new_kv) + + # build mfa data + if self.preprocess_args['use_mfa']: + mfa_dict = set() + mfa_input_dir = f'{processed_dir}/mfa_inputs' + remove_file(mfa_input_dir) + # group MFA inputs for better parallelism + mfa_groups = [i // self.preprocess_args['nsample_per_mfa_group'] for i in range(len(items))] + if self.preprocess_args['mfa_group_shuffle']: + random.seed(hparams['seed']) + random.shuffle(mfa_groups) + args = [{ + 'item': item, 'mfa_input_dir': mfa_input_dir, + 'mfa_group': mfa_group, 'wav_processed_tmp': wav_processed_tmp_dir, + 'preprocess_args': self.preprocess_args + } for item, mfa_group in zip(items, mfa_groups)] + for i, (ph_gb_word_nosil, new_wav_align_fn) in multiprocess_run_tqdm( + self.build_mfa_inputs, args, desc='Build MFA data'): + items[i]['wav_align_fn'] = new_wav_align_fn + for w in ph_gb_word_nosil.split(" "): + mfa_dict.add(f"{w} {w.replace('_', ' ')}") + mfa_dict = sorted(mfa_dict) + with open(f'{processed_dir}/mfa_dict.txt', 'w') as f: + f.writelines([f'{l}\n' for l in mfa_dict]) + with open(f"{processed_dir}/{self.meta_csv_filename}.json", 'w') as f: + f.write(re.sub(r'\n\s+([\d+\]])', r'\1', json.dumps(items, ensure_ascii=False, sort_keys=False, indent=1))) + remove_file(wav_processed_tmp_dir) + + @classmethod + def preprocess_first_pass(cls, item_name, txt_raw, txt_processor, + wav_fn, wav_processed_dir, wav_processed_tmp, + preprocess_args, txt_loader=None, others=None): + try: + if txt_loader is not None: + txt_raw = txt_loader(txt_raw) + ph, txt, word, ph2word, ph_gb_word = cls.txt_to_ph(txt_processor, txt_raw, preprocess_args) + wav_fn, wav_align_fn = cls.process_wav( + 
item_name, wav_fn,
+                hparams['processed_data_dir'],
+                wav_processed_tmp, preprocess_args)
+
+            # wav for binarization
+            ext = os.path.splitext(wav_fn)[1]
+            os.makedirs(wav_processed_dir, exist_ok=True)
+            new_wav_fn = f"{wav_processed_dir}/{item_name}{ext}"
+            move_link_func = move_file if os.path.dirname(wav_fn) == wav_processed_tmp else link_file
+            move_link_func(wav_fn, new_wav_fn)
+            return {
+                'txt': txt, 'txt_raw': txt_raw, 'ph': ph,
+                'word': word, 'ph2word': ph2word, 'ph_gb_word': ph_gb_word,
+                'wav_fn': new_wav_fn, 'wav_align_fn': wav_align_fn,
+                'others': others
+            }
+        except:
+            traceback.print_exc()
+            print(f"| Error is caught. item_name: {item_name}.")
+            return None
+
+    @staticmethod
+    def txt_to_ph(txt_processor, txt_raw, preprocess_args):
+        txt_struct, txt = txt_processor.process(txt_raw, preprocess_args)
+        ph = [p for w in txt_struct for p in w[1]]
+        # also derive the word-level views; preprocess_first_pass unpacks five values
+        ph_gb_word = ["_".join(w[1]) for w in txt_struct]
+        words = [w[0] for w in txt_struct]
+        # word_id=0 is reserved for padding
+        ph2word = [w_id + 1 for w_id, w in enumerate(txt_struct) for _ in w[1]]
+        return " ".join(ph), txt, " ".join(words), ph2word, " ".join(ph_gb_word)
+
+    @staticmethod
+    def process_wav(item_name, wav_fn, processed_dir, wav_processed_tmp, preprocess_args):
+        processors = [get_wav_processor_cls(v) for v in preprocess_args['wav_processors']]
+        processors = [k() for k in processors if k is not None]
+        if len(processors) >= 1:
+            sr_file = librosa.core.get_samplerate(wav_fn)
+            output_fn_for_align = None
+            ext = os.path.splitext(wav_fn)[1]
+            input_fn = f"{wav_processed_tmp}/{item_name}{ext}"
+            link_file(wav_fn, input_fn)
+            for p in processors:
+                outputs = p.process(input_fn, sr_file, wav_processed_tmp, processed_dir, item_name, preprocess_args)
+                if len(outputs) == 3:
+                    input_fn, sr, output_fn_for_align = outputs
+                else:
+                    input_fn, sr = outputs
+            return input_fn, output_fn_for_align
+        else:
+            return wav_fn, wav_fn
+
+    def _phone_encoder(self, ph_set):
+        ph_set_fn = f"{self.processed_dir}/phone_set.json"
+        if self.preprocess_args['reset_phone_dict'] or not os.path.exists(ph_set_fn):
+            ph_set = sorted(set(ph_set))
+            json.dump(ph_set, open(ph_set_fn, 'w'), ensure_ascii=False)
+            print("| Build phone set: ", ph_set)
+        else:
+            ph_set = json.load(open(ph_set_fn, 'r'))
+            print("| Load phone set: ", ph_set)
+        return build_token_encoder(ph_set_fn)
+
+    def _word_encoder(self, word_set):
+        word_set_fn = f"{self.processed_dir}/word_set.json"
+        if self.preprocess_args['reset_word_dict']:
+            word_set = Counter(word_set)
+            total_words = sum(word_set.values())
+            word_set = word_set.most_common(hparams['word_dict_size'])
+            num_unk_words = total_words - sum([x[1] for x in word_set])
+            word_set = ['<BOS>', '<EOS>'] + [x[0] for x in word_set]
+            word_set = sorted(set(word_set))
+            json.dump(word_set, open(word_set_fn, 'w'), ensure_ascii=False)
+            print(f"| Build word set. Size: {len(word_set)}, #total words: {total_words},"
+                  f" #unk_words: {num_unk_words}, word_set[:10]:, {word_set[:10]}.")
+        else:
+            word_set = json.load(open(word_set_fn, 'r'))
+            print("| Load word set. 
Size: ", len(word_set), word_set[:10]) + return build_token_encoder(word_set_fn) + + @classmethod + def preprocess_second_pass(cls, word, ph, spk_name, word_encoder, ph_encoder, spk_map): + word_token = word_encoder.encode(word) + ph_token = ph_encoder.encode(ph) + spk_id = spk_map[spk_name] + return {'word_token': word_token, 'ph_token': ph_token, 'spk_id': spk_id} + + def build_spk_map(self, spk_names): + spk_map = {x: i for i, x in enumerate(sorted(list(spk_names)))} + assert len(spk_map) == 0 or len(spk_map) <= hparams['num_spk'], len(spk_map) + print(f"| Number of spks: {len(spk_map)}, spk_map: {spk_map}") + json.dump(spk_map, open(self.spk_map_fn, 'w'), ensure_ascii=False) + return spk_map + + @classmethod + def build_mfa_inputs(cls, item, mfa_input_dir, mfa_group, wav_processed_tmp, preprocess_args): + item_name = item['item_name'] + wav_align_fn = item['wav_align_fn'] + ph_gb_word = item['ph_gb_word'] + ext = os.path.splitext(wav_align_fn)[1] + mfa_input_group_dir = f'{mfa_input_dir}/{mfa_group}' + os.makedirs(mfa_input_group_dir, exist_ok=True) + new_wav_align_fn = f"{mfa_input_group_dir}/{item_name}{ext}" + move_link_func = move_file if os.path.dirname(wav_align_fn) == wav_processed_tmp else link_file + move_link_func(wav_align_fn, new_wav_align_fn) + ph_gb_word_nosil = " ".join(["_".join([p for p in w.split("_") if not is_sil_phoneme(p)]) + for w in ph_gb_word.split(" ") if not is_sil_phoneme(w)]) + with open(f'{mfa_input_group_dir}/{item_name}.lab', 'w') as f_txt: + f_txt.write(ph_gb_word_nosil) + return ph_gb_word_nosil, new_wav_align_fn + + def load_spk_map(self, base_dir): + spk_map_fn = f"{base_dir}/spk_map.json" + spk_map = json.load(open(spk_map_fn, 'r')) + return spk_map + + def load_dict(self, base_dir): + ph_encoder = build_token_encoder(f'{base_dir}/phone_set.json') + return ph_encoder + + @property + def meta_csv_filename(self): + return 'metadata' + + @property + def wav_processed_dirname(self): + return 'wav_processed' \ No newline at end of file diff --git a/data_gen/tts/bin/binarize.py b/data_gen/tts/bin/binarize.py new file mode 100644 index 0000000000000000000000000000000000000000..4bd3c1f69fa59ed52fdd32eb80e746dedbae7535 --- /dev/null +++ b/data_gen/tts/bin/binarize.py @@ -0,0 +1,20 @@ +import os + +os.environ["OMP_NUM_THREADS"] = "1" + +import importlib +from utils.hparams import set_hparams, hparams + + +def binarize(): + binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer') + pkg = ".".join(binarizer_cls.split(".")[:-1]) + cls_name = binarizer_cls.split(".")[-1] + binarizer_cls = getattr(importlib.import_module(pkg), cls_name) + print("| Binarizer: ", binarizer_cls) + binarizer_cls().process() + + +if __name__ == '__main__': + set_hparams() + binarize() diff --git a/data_gen/tts/data_gen_utils.py b/data_gen/tts/data_gen_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b6bf10862cf3f9a8b2aee560ae5d44eabbf00bc --- /dev/null +++ b/data_gen/tts/data_gen_utils.py @@ -0,0 +1,352 @@ +import warnings + +warnings.filterwarnings("ignore") + +# import parselmouth +import os +import torch +from skimage.transform import resize +from utils.text_encoder import TokenTextEncoder +from utils.pitch_utils import f0_to_coarse +import struct +import webrtcvad +from scipy.ndimage.morphology import binary_dilation +import librosa +import numpy as np +from utils import audio +import pyloudnorm as pyln +import re +import json +from collections import OrderedDict + +PUNCS = '!,.?;:' + +int16_max = (2 ** 15) - 1 + + +def 
trim_long_silences(path, sr=None, return_raw_wav=False, norm=True, vad_max_silence_length=12): + """ + Ensures that segments without voice in the waveform remain no longer than a + threshold determined by the VAD parameters in params.py. + :param wav: the raw waveform as a numpy array of floats + :param vad_max_silence_length: Maximum number of consecutive silent frames a segment can have. + :return: the same waveform with silences trimmed away (length <= original wav length) + """ + + ## Voice Activation Detection + # Window size of the VAD. Must be either 10, 20 or 30 milliseconds. + # This sets the granularity of the VAD. Should not need to be changed. + sampling_rate = 16000 + wav_raw, sr = librosa.core.load(path, sr=sr) + + if norm: + meter = pyln.Meter(sr) # create BS.1770 meter + loudness = meter.integrated_loudness(wav_raw) + wav_raw = pyln.normalize.loudness(wav_raw, loudness, -20.0) + if np.abs(wav_raw).max() > 1.0: + wav_raw = wav_raw / np.abs(wav_raw).max() + + wav = librosa.resample(wav_raw, sr, sampling_rate, res_type='kaiser_best') + + vad_window_length = 30 # In milliseconds + # Number of frames to average together when performing the moving average smoothing. + # The larger this value, the larger the VAD variations must be to not get smoothed out. + vad_moving_average_width = 8 + + # Compute the voice detection window size + samples_per_window = (vad_window_length * sampling_rate) // 1000 + + # Trim the end of the audio to have a multiple of the window size + wav = wav[:len(wav) - (len(wav) % samples_per_window)] + + # Convert the float waveform to 16-bit mono PCM + pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16)) + + # Perform voice activation detection + voice_flags = [] + vad = webrtcvad.Vad(mode=3) + for window_start in range(0, len(wav), samples_per_window): + window_end = window_start + samples_per_window + voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2], + sample_rate=sampling_rate)) + voice_flags = np.array(voice_flags) + + # Smooth the voice detection with a moving average + def moving_average(array, width): + array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2))) + ret = np.cumsum(array_padded, dtype=float) + ret[width:] = ret[width:] - ret[:-width] + return ret[width - 1:] / width + + audio_mask = moving_average(voice_flags, vad_moving_average_width) + audio_mask = np.round(audio_mask).astype(np.bool) + + # Dilate the voiced regions + audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1)) + audio_mask = np.repeat(audio_mask, samples_per_window) + audio_mask = resize(audio_mask, (len(wav_raw),)) > 0 + if return_raw_wav: + return wav_raw, audio_mask, sr + return wav_raw[audio_mask], audio_mask, sr + + +def process_utterance(wav_path, + fft_size=1024, + hop_size=256, + win_length=1024, + window="hann", + num_mels=80, + fmin=80, + fmax=7600, + eps=1e-6, + sample_rate=22050, + loud_norm=False, + min_level_db=-100, + return_linear=False, + trim_long_sil=False, vocoder='pwg'): + if isinstance(wav_path, str): + if trim_long_sil: + wav, _, _ = trim_long_silences(wav_path, sample_rate) + else: + wav, _ = librosa.core.load(wav_path, sr=sample_rate) + else: + wav = wav_path + + if loud_norm: + meter = pyln.Meter(sample_rate) # create BS.1770 meter + loudness = meter.integrated_loudness(wav) + wav = pyln.normalize.loudness(wav, loudness, -22.0) + if np.abs(wav).max() > 1: + wav = wav / np.abs(wav).max() + + # get amplitude spectrogram + x_stft = 
librosa.stft(wav, n_fft=fft_size, hop_length=hop_size,
+                          win_length=win_length, window=window, pad_mode="constant")
+    spc = np.abs(x_stft)  # (n_bins, T)
+
+    # get mel basis
+    fmin = 0 if fmin == -1 else fmin
+    fmax = sample_rate / 2 if fmax == -1 else fmax
+    mel_basis = librosa.filters.mel(sample_rate, fft_size, num_mels, fmin, fmax)
+    mel = mel_basis @ spc
+
+    if vocoder == 'pwg':
+        mel = np.log10(np.maximum(eps, mel))  # (n_mel_bins, T)
+    else:
+        assert False, f'"{vocoder}" is not in ["pwg"].'
+
+    l_pad, r_pad = audio.librosa_pad_lr(wav, fft_size, hop_size, 1)
+    wav = np.pad(wav, (l_pad, r_pad), mode='constant', constant_values=0.0)
+    wav = wav[:mel.shape[1] * hop_size]
+
+    if not return_linear:
+        return wav, mel
+    else:
+        spc = audio.amp_to_db(spc)
+        spc = audio.normalize(spc, {'min_level_db': min_level_db})
+        return wav, mel, spc
+
+
+def get_pitch(wav_data, mel, hparams):
+    """
+
+    :param wav_data: [T]
+    :param mel: [T, 80]
+    :param hparams:
+    :return:
+    """
+    import parselmouth  # imported lazily here, since the module-level import above is commented out
+    time_step = hparams['hop_size'] / hparams['audio_sample_rate'] * 1000
+    f0_min = 80
+    f0_max = 750
+
+    if hparams['hop_size'] == 128:
+        pad_size = 4
+    elif hparams['hop_size'] == 256:
+        pad_size = 2
+    else:
+        assert False
+
+    f0 = parselmouth.Sound(wav_data, hparams['audio_sample_rate']).to_pitch_ac(
+        time_step=time_step / 1000, voicing_threshold=0.6,
+        pitch_floor=f0_min, pitch_ceiling=f0_max).selected_array['frequency']
+    lpad = pad_size * 2
+    rpad = len(mel) - len(f0) - lpad
+    f0 = np.pad(f0, [[lpad, rpad]], mode='constant')
+    # mel and f0 are extracted by 2 different libraries. we should force them to have the same length.
+    # Attention: we find that new version of some libraries could cause ``rpad'' to be a negative value...
+    # Just to be sure, we recommend users to set up the same environments as them in requirements_auto.txt (by Anaconda)
+    delta_l = len(mel) - len(f0)
+    assert np.abs(delta_l) <= 8
+    if delta_l > 0:
+        f0 = np.concatenate([f0, [f0[-1]] * delta_l], 0)
+    f0 = f0[:len(mel)]
+    pitch_coarse = f0_to_coarse(f0)
+    return f0, pitch_coarse
+
+
+def remove_empty_lines(text):
+    """remove empty lines"""
+    assert (len(text) > 0)
+    assert (isinstance(text, list))
+    text = [t.strip() for t in text]
+    if "" in text:
+        text.remove("")
+    return text
+
+
+class TextGrid(object):
+    def __init__(self, text):
+        text = remove_empty_lines(text)
+        self.text = text
+        self.line_count = 0
+        self._get_type()
+        self._get_time_intval()
+        self._get_size()
+        self.tier_list = []
+        self._get_item_list()
+
+    def _extract_pattern(self, pattern, inc):
+        """
+        Parameters
+        ----------
+        pattern : regex to extract pattern
+        inc : increment of line count after extraction
+        Returns
+        -------
+        group : extracted info
+        """
+        try:
+            group = re.match(pattern, self.text[self.line_count]).group(1)
+            self.line_count += inc
+        except AttributeError:
+            raise ValueError("File format error at line %d:%s" % (self.line_count, self.text[self.line_count]))
+        return group
+
+    def _get_type(self):
+        self.file_type = self._extract_pattern(r"File type = \"(.*)\"", 2)
+
+    def _get_time_intval(self):
+        self.xmin = self._extract_pattern(r"xmin = (.*)", 1)
+        self.xmax = self._extract_pattern(r"xmax = (.*)", 2)
+
+    def _get_size(self):
+        self.size = int(self._extract_pattern(r"size = (.*)", 2))
+
+    def _get_item_list(self):
+        """Only supports IntervalTier currently"""
+        for itemIdx in range(1, self.size + 1):
+            tier = OrderedDict()
+            item_list = []
+            tier_idx = self._extract_pattern(r"item \[(.*)\]:", 1)
+            tier_class = self._extract_pattern(r"class = \"(.*)\"", 1)
+            if 
tier_class != "IntervalTier": + raise NotImplementedError("Only IntervalTier class is supported currently") + tier_name = self._extract_pattern(r"name = \"(.*)\"", 1) + tier_xmin = self._extract_pattern(r"xmin = (.*)", 1) + tier_xmax = self._extract_pattern(r"xmax = (.*)", 1) + tier_size = self._extract_pattern(r"intervals: size = (.*)", 1) + for i in range(int(tier_size)): + item = OrderedDict() + item["idx"] = self._extract_pattern(r"intervals \[(.*)\]", 1) + item["xmin"] = self._extract_pattern(r"xmin = (.*)", 1) + item["xmax"] = self._extract_pattern(r"xmax = (.*)", 1) + item["text"] = self._extract_pattern(r"text = \"(.*)\"", 1) + item_list.append(item) + tier["idx"] = tier_idx + tier["class"] = tier_class + tier["name"] = tier_name + tier["xmin"] = tier_xmin + tier["xmax"] = tier_xmax + tier["size"] = tier_size + tier["items"] = item_list + self.tier_list.append(tier) + + def toJson(self): + _json = OrderedDict() + _json["file_type"] = self.file_type + _json["xmin"] = self.xmin + _json["xmax"] = self.xmax + _json["size"] = self.size + _json["tiers"] = self.tier_list + return json.dumps(_json, ensure_ascii=False, indent=2) + + +def get_mel2ph(tg_fn, ph, mel, hparams): + ph_list = ph.split(" ") + with open(tg_fn, "r") as f: + tg = f.readlines() + tg = remove_empty_lines(tg) + tg = TextGrid(tg) + tg = json.loads(tg.toJson()) + split = np.ones(len(ph_list) + 1, np.float) * -1 + tg_idx = 0 + ph_idx = 0 + tg_align = [x for x in tg['tiers'][-1]['items']] + tg_align_ = [] + for x in tg_align: + x['xmin'] = float(x['xmin']) + x['xmax'] = float(x['xmax']) + if x['text'] in ['sil', 'sp', '', 'SIL', 'PUNC']: + x['text'] = '' + if len(tg_align_) > 0 and tg_align_[-1]['text'] == '': + tg_align_[-1]['xmax'] = x['xmax'] + continue + tg_align_.append(x) + tg_align = tg_align_ + tg_len = len([x for x in tg_align if x['text'] != '']) + ph_len = len([x for x in ph_list if not is_sil_phoneme(x)]) + assert tg_len == ph_len, (tg_len, ph_len, tg_align, ph_list, tg_fn) + while tg_idx < len(tg_align) or ph_idx < len(ph_list): + if tg_idx == len(tg_align) and is_sil_phoneme(ph_list[ph_idx]): + split[ph_idx] = 1e8 + ph_idx += 1 + continue + x = tg_align[tg_idx] + if x['text'] == '' and ph_idx == len(ph_list): + tg_idx += 1 + continue + assert ph_idx < len(ph_list), (tg_len, ph_len, tg_align, ph_list, tg_fn) + ph = ph_list[ph_idx] + if x['text'] == '' and not is_sil_phoneme(ph): + assert False, (ph_list, tg_align) + if x['text'] != '' and is_sil_phoneme(ph): + ph_idx += 1 + else: + assert (x['text'] == '' and is_sil_phoneme(ph)) \ + or x['text'].lower() == ph.lower() \ + or x['text'].lower() == 'sil', (x['text'], ph) + split[ph_idx] = x['xmin'] + if ph_idx > 0 and split[ph_idx - 1] == -1 and is_sil_phoneme(ph_list[ph_idx - 1]): + split[ph_idx - 1] = split[ph_idx] + ph_idx += 1 + tg_idx += 1 + assert tg_idx == len(tg_align), (tg_idx, [x['text'] for x in tg_align]) + assert ph_idx >= len(ph_list) - 1, (ph_idx, ph_list, len(ph_list), [x['text'] for x in tg_align], tg_fn) + mel2ph = np.zeros([mel.shape[0]], np.int) + split[0] = 0 + split[-1] = 1e8 + for i in range(len(split) - 1): + assert split[i] != -1 and split[i] <= split[i + 1], (split[:-1],) + split = [int(s * hparams['audio_sample_rate'] / hparams['hop_size'] + 0.5) for s in split] + for ph_idx in range(len(ph_list)): + mel2ph[split[ph_idx]:split[ph_idx + 1]] = ph_idx + 1 + mel2ph_torch = torch.from_numpy(mel2ph) + T_t = len(ph_list) + dur = mel2ph_torch.new_zeros([T_t + 1]).scatter_add(0, mel2ph_torch, torch.ones_like(mel2ph_torch)) + dur = dur[1:].numpy() + 
return mel2ph, dur
+
+
+def build_phone_encoder(data_dir):
+    phone_list_file = os.path.join(data_dir, 'phone_set.json')
+    phone_list = json.load(open(phone_list_file))
+    return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',')
+
+
+def is_sil_phoneme(p):
+    return not p[0].isalpha()
+
+
+def build_token_encoder(token_list_file):
+    token_list = json.load(open(token_list_file))
+    return TokenTextEncoder(None, vocab_list=token_list, replace_oov='<UNK>')
diff --git a/data_gen/tts/txt_processors/__init__.py b/data_gen/tts/txt_processors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bff3e9af7d634363116c6605f22a52aad614dea
--- /dev/null
+++ b/data_gen/tts/txt_processors/__init__.py
@@ -0,0 +1 @@
+from . import en
\ No newline at end of file
diff --git a/data_gen/tts/txt_processors/base_text_processor.py b/data_gen/tts/txt_processors/base_text_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..69d51201dcb191c1c208ae1c87a34b5c97e6307f
--- /dev/null
+++ b/data_gen/tts/txt_processors/base_text_processor.py
@@ -0,0 +1,47 @@
+from data_gen.tts.data_gen_utils import is_sil_phoneme
+
+REGISTERED_TEXT_PROCESSORS = {}
+
+def register_txt_processors(name):
+    def _f(cls):
+        REGISTERED_TEXT_PROCESSORS[name] = cls
+        return cls
+
+    return _f
+
+
+def get_txt_processor_cls(name):
+    return REGISTERED_TEXT_PROCESSORS.get(name, None)
+
+
+class BaseTxtProcessor:
+    @staticmethod
+    def sp_phonemes():
+        return ['|']
+
+    @classmethod
+    def process(cls, txt, preprocess_args):
+        raise NotImplementedError
+
+    @classmethod
+    def postprocess(cls, txt_struct, preprocess_args):
+        # remove sil phoneme in head and tail
+        while len(txt_struct) > 0 and is_sil_phoneme(txt_struct[0][0]):
+            txt_struct = txt_struct[1:]
+        while len(txt_struct) > 0 and is_sil_phoneme(txt_struct[-1][0]):
+            txt_struct = txt_struct[:-1]
+        if preprocess_args['with_phsep']:
+            txt_struct = cls.add_bdr(txt_struct)
+        if preprocess_args['add_eos_bos']:
+            txt_struct = [["<BOS>", ["<BOS>"]]] + txt_struct + [["<EOS>", ["<EOS>"]]]
+        return txt_struct
+
+    @classmethod
+    def add_bdr(cls, txt_struct):
+        txt_struct_ = []
+        for i, ts in enumerate(txt_struct):
+            txt_struct_.append(ts)
+            if i != len(txt_struct) - 1 and \
+                    not is_sil_phoneme(txt_struct[i][0]) and not is_sil_phoneme(txt_struct[i + 1][0]):
+                txt_struct_.append(['|', ['|']])
+        return txt_struct_
\ No newline at end of file
diff --git a/data_gen/tts/txt_processors/en.py b/data_gen/tts/txt_processors/en.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f755d5ab1f2cf4407daee08cc3639a05e941a97
--- /dev/null
+++ b/data_gen/tts/txt_processors/en.py
@@ -0,0 +1,77 @@
+import re
+import unicodedata
+
+from g2p_en import G2p
+from g2p_en.expand import normalize_numbers
+from nltk import pos_tag
+from nltk.tokenize import TweetTokenizer
+
+from data_gen.tts.txt_processors.base_text_processor import BaseTxtProcessor, register_txt_processors
+from data_gen.tts.data_gen_utils import is_sil_phoneme, PUNCS
+
+class EnG2p(G2p):
+    word_tokenize = TweetTokenizer().tokenize
+
+    def __call__(self, text):
+        # preprocessing
+        words = EnG2p.word_tokenize(text)
+        tokens = pos_tag(words)  # tuples of (word, tag)
+
+        # steps
+        prons = []
+        for word, pos in tokens:
+            if re.search("[a-z]", word) is None:
+                pron = [word]
+
+            elif word in self.homograph2features:  # Check homograph
+                pron1, pron2, pos1 = self.homograph2features[word]
+                if pos.startswith(pos1):
+                    pron = pron1
+                else:
+                    pron = pron2
+            elif word in self.cmu:  # lookup CMU dict
+                pron = 
self.cmu[word][0] + else: # predict for oov + pron = self.predict(word) + + prons.extend(pron) + prons.extend([" "]) + + return prons[:-1] + + +@register_txt_processors('en') +class TxtProcessor(BaseTxtProcessor): + g2p = EnG2p() + + @staticmethod + def preprocess_text(text): + text = normalize_numbers(text) + text = ''.join(char for char in unicodedata.normalize('NFD', text) + if unicodedata.category(char) != 'Mn') # Strip accents + text = text.lower() + text = re.sub("[\'\"()]+", "", text) + text = re.sub("[-]+", " ", text) + text = re.sub(f"[^ a-z{PUNCS}]", "", text) + text = re.sub(f" ?([{PUNCS}]) ?", r"\1", text) # !! -> ! + text = re.sub(f"([{PUNCS}])+", r"\1", text) # !! -> ! + text = text.replace("i.e.", "that is") + text = text.replace("i.e.", "that is") + text = text.replace("etc.", "etc") + text = re.sub(f"([{PUNCS}])", r" \1 ", text) + text = re.sub(rf"\s+", r" ", text) + return text + + @classmethod + def process(cls, txt, preprocess_args): + txt = cls.preprocess_text(txt).strip() + phs = cls.g2p(txt) + txt_struct = [[w, []] for w in txt.split(" ")] + i_word = 0 + for p in phs: + if p == ' ': + i_word += 1 + else: + txt_struct[i_word][1].append(p) + txt_struct = cls.postprocess(txt_struct, preprocess_args) + return txt_struct, txt \ No newline at end of file diff --git a/data_gen/tts/wav_processors/__init__.py b/data_gen/tts/wav_processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4be97b377dcb95a0e6bceb876ac0ce93c8290249 --- /dev/null +++ b/data_gen/tts/wav_processors/__init__.py @@ -0,0 +1,2 @@ +from . import base_processor +from . import common_processors diff --git a/data_gen/tts/wav_processors/base_processor.py b/data_gen/tts/wav_processors/base_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..e8200dc58a9388ac94a5ec34b8a65f75e380255b --- /dev/null +++ b/data_gen/tts/wav_processors/base_processor.py @@ -0,0 +1,25 @@ +REGISTERED_WAV_PROCESSORS = {} + + +def register_wav_processors(name): + def _f(cls): + REGISTERED_WAV_PROCESSORS[name] = cls + return cls + + return _f + + +def get_wav_processor_cls(name): + return REGISTERED_WAV_PROCESSORS.get(name, None) + + +class BaseWavProcessor: + @property + def name(self): + raise NotImplementedError + + def output_fn(self, input_fn): + return f'{input_fn[:-4]}_{self.name}.wav' + + def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args): + raise NotImplementedError diff --git a/data_gen/tts/wav_processors/common_processors.py b/data_gen/tts/wav_processors/common_processors.py new file mode 100644 index 0000000000000000000000000000000000000000..de0b49f4a31cb6737f2cffc6c8d010d88d11c853 --- /dev/null +++ b/data_gen/tts/wav_processors/common_processors.py @@ -0,0 +1,86 @@ +import os +import subprocess +import librosa +import numpy as np +from data_gen.tts.wav_processors.base_processor import BaseWavProcessor, register_wav_processors +from data_gen.tts.data_gen_utils import trim_long_silences +from utils.audio import save_wav +from utils.rnnoise import rnnoise +from utils.hparams import hparams + + +@register_wav_processors(name='sox_to_wav') +class ConvertToWavProcessor(BaseWavProcessor): + @property + def name(self): + return 'ToWav' + + def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args): + if input_fn[-4:] == '.wav': + return input_fn, sr + else: + output_fn = self.output_fn(input_fn) + subprocess.check_call(f'sox -v 0.95 "{input_fn}" -t wav "{output_fn}"', shell=True) + return output_fn, sr + + 
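+# NOTE: a rough sketch of how these registered processors are consumed —
+# BasePreprocessor.process_wav (base_preprocess.py above) iterates the names in
+# preprocess_args['wav_processors'] and chains them, roughly:
+#
+#   for name in preprocess_args['wav_processors']:   # e.g. ['sox_to_wav', 'sox_resample']
+#       cls = get_wav_processor_cls(name)
+#       if cls is not None:
+#           outputs = cls().process(input_fn, sr, tmp_dir, processed_dir,
+#                                   item_name, preprocess_args)
+#           input_fn, sr = outputs[:2]                # some processors also return an
+#                                                     # alignment wav path as a third value
+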
+@register_wav_processors(name='sox_resample')
+class ResampleProcessor(BaseWavProcessor):
+    @property
+    def name(self):
+        return 'Resample'
+
+    def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args):
+        output_fn = self.output_fn(input_fn)
+        sr_file = librosa.core.get_samplerate(input_fn)
+        if sr != sr_file:
+            subprocess.check_call(f'sox -v 0.95 "{input_fn}" -r{sr} "{output_fn}"', shell=True)
+            y, _ = librosa.core.load(output_fn, sr=sr)  # load the sox-resampled file, not the original
+            y, _ = librosa.effects.trim(y)
+            save_wav(y, output_fn, sr)
+            return output_fn, sr
+        else:
+            return input_fn, sr
+
+
+@register_wav_processors(name='trim_sil')
+class TrimSILProcessor(BaseWavProcessor):
+    @property
+    def name(self):
+        return 'TrimSIL'
+
+    def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args):
+        output_fn = self.output_fn(input_fn)
+        y, _ = librosa.core.load(input_fn, sr=sr)
+        y, _ = librosa.effects.trim(y)
+        save_wav(y, output_fn, sr)
+        return output_fn, sr  # return (path, sr) like the other processors
+
+
+@register_wav_processors(name='trim_all_sil')
+class TrimAllSILProcessor(BaseWavProcessor):
+    @property
+    def name(self):
+        return 'TrimAllSIL'  # distinct name, so output files don't collide with TrimSILProcessor's
+
+    def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args):
+        output_fn = self.output_fn(input_fn)
+        y, audio_mask, _ = trim_long_silences(
+            input_fn, vad_max_silence_length=preprocess_args.get('vad_max_silence_length', 12))
+        save_wav(y, output_fn, sr)
+        if preprocess_args['save_sil_mask']:
+            os.makedirs(f'{processed_dir}/sil_mask', exist_ok=True)
+            np.save(f'{processed_dir}/sil_mask/{item_name}.npy', audio_mask)
+        return output_fn, sr
+
+
+@register_wav_processors(name='denoise')
+class DenoiseProcessor(BaseWavProcessor):
+    @property
+    def name(self):
+        return 'Denoise'
+
+    def process(self, input_fn, sr, tmp_dir, processed_dir, item_name, preprocess_args):
+        output_fn = self.output_fn(input_fn)
+        rnnoise(input_fn, output_fn, out_sample_rate=sr)
+        return output_fn, sr
diff --git a/egs/datasets/audio/libritts/base_text2mel.yaml b/egs/datasets/audio/libritts/base_text2mel.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..85c389a45a7bc2be4927867d07fa881754814a2c
--- /dev/null
+++ b/egs/datasets/audio/libritts/base_text2mel.yaml
@@ -0,0 +1,14 @@
+raw_data_dir: 'data/raw/LibriTTS'
+processed_data_dir: 'data/processed/libritts'
+binary_data_dir: 'data/binary/libritts'
+pre_align_cls: egs.datasets.audio.libritts.pre_align.LibrittsPreAlign
+binarization_args:
+  shuffle: true
+use_spk_id: true
+test_num: 200
+num_spk: 2320
+pitch_type: frame
+min_frames: 128
+num_test_samples: 30
+mel_loss: "ssim:0.5|l1:0.5"
+vocoder_ckpt: ''
\ No newline at end of file
diff --git a/egs/datasets/audio/libritts/fs2.yaml b/egs/datasets/audio/libritts/fs2.yaml
new file mode 100755
index 0000000000000000000000000000000000000000..95ae09730837aa10606f9771f41f40429b99d9ac
--- /dev/null
+++ b/egs/datasets/audio/libritts/fs2.yaml
@@ -0,0 +1,3 @@
+base_config:
+  - egs/egs_bases/tts/fs2.yaml
+  - ./base_text2mel.yaml
diff --git a/egs/datasets/audio/libritts/pre_align.py b/egs/datasets/audio/libritts/pre_align.py
new file mode 100755
index 0000000000000000000000000000000000000000..335b43d913edb02fb02e10c2479aa3dd9e07bb2f
--- /dev/null
+++ b/egs/datasets/audio/libritts/pre_align.py
@@ -0,0 +1,18 @@
+import os
+
+from data_gen.tts.base_pre_align import BasePreAlign
+import glob
+
+
+class LibrittsPreAlign(BasePreAlign):
+    def meta_data(self):
+        wav_fns = sorted(glob.glob(f'{self.raw_data_dir}/*/*/*/*.wav'))
+        for wav_fn in wav_fns:
+            item_name = os.path.basename(wav_fn)[:-4]
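+            # LibriTTS item names look like "<speaker>_<chapter>_<para>_<utt>",
+            # so the first underscore-separated field below is the speaker id.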
+ txt_fn = f'{wav_fn[:-4]}.normalized.txt' + spk = item_name.split("_")[0] + yield item_name, wav_fn, (self.load_txt, txt_fn), spk + + +if __name__ == "__main__": + LibrittsPreAlign().process() diff --git a/egs/datasets/audio/libritts/pwg.yaml b/egs/datasets/audio/libritts/pwg.yaml new file mode 100755 index 0000000000000000000000000000000000000000..a0fd70869274c3f8a7ac02a9ca5d7e1202dc895d --- /dev/null +++ b/egs/datasets/audio/libritts/pwg.yaml @@ -0,0 +1,8 @@ +base_config: egs/egs_bases/tts/vocoder/pwg.yaml +raw_data_dir: 'data/raw/LibriTTS' +processed_data_dir: 'data/processed/libritts' +binary_data_dir: 'data/binary/libritts_wav' +generator_params: + kernel_size: 5 +num_spk: 400 +max_samples: 20480 diff --git a/egs/datasets/audio/lj/base_mel2wav.yaml b/egs/datasets/audio/lj/base_mel2wav.yaml new file mode 100755 index 0000000000000000000000000000000000000000..df4355bc38d0568c0b3acfa4f7fc040cef5995d6 --- /dev/null +++ b/egs/datasets/audio/lj/base_mel2wav.yaml @@ -0,0 +1,5 @@ +raw_data_dir: 'data/raw/LJSpeech-1.1' +processed_data_dir: 'data/processed/ljspeech' +binary_data_dir: 'data/binary/ljspeech_wav' +binarization_args: + with_spk_embed: false \ No newline at end of file diff --git a/egs/datasets/audio/lj/pre_align.py b/egs/datasets/audio/lj/pre_align.py new file mode 100755 index 0000000000000000000000000000000000000000..847b9f87b4e74cd634dd5bb2313f78afd5602ad7 --- /dev/null +++ b/egs/datasets/audio/lj/pre_align.py @@ -0,0 +1,13 @@ +from data_gen.tts.base_preprocess import BasePreprocessor + + +class LJPreAlign(BasePreprocessor): + def meta_data(self): + for l in open(f'{self.raw_data_dir}/metadata.csv').readlines(): + item_name, _, txt = l.strip().split("|") + wav_fn = f"{self.raw_data_dir}/wavs/{item_name}.wav" + yield item_name, wav_fn, txt, 'SPK1' + + +if __name__ == "__main__": + LJPreAlign().process() diff --git a/egs/datasets/audio/lj/pwg.yaml b/egs/datasets/audio/lj/pwg.yaml new file mode 100755 index 0000000000000000000000000000000000000000..e0c6dc6da4367bad9bafd1a7ba492a5cfb18a347 --- /dev/null +++ b/egs/datasets/audio/lj/pwg.yaml @@ -0,0 +1,3 @@ +base_config: + - egs/egs_bases/tts/vocoder/pwg.yaml + - ./base_mel2wav.yaml \ No newline at end of file diff --git a/egs/datasets/audio/vctk/base_mel2wav.yaml b/egs/datasets/audio/vctk/base_mel2wav.yaml new file mode 100755 index 0000000000000000000000000000000000000000..b5210a1361259d30449430f70849061d17a8e59f --- /dev/null +++ b/egs/datasets/audio/vctk/base_mel2wav.yaml @@ -0,0 +1,3 @@ +raw_data_dir: 'data/raw/VCTK-Corpus' +processed_data_dir: 'data/processed/vctk' +binary_data_dir: 'data/binary/vctk_wav' diff --git a/egs/datasets/audio/vctk/fs2.yaml b/egs/datasets/audio/vctk/fs2.yaml new file mode 100755 index 0000000000000000000000000000000000000000..49bb983bd98caca74af0ea70414e71aa34334b38 --- /dev/null +++ b/egs/datasets/audio/vctk/fs2.yaml @@ -0,0 +1,12 @@ +base_config: + - egs/egs_bases/tts/fs2.yaml +raw_data_dir: 'data/raw/VCTK-Corpus' +processed_data_dir: 'data/processed/vctk' +binary_data_dir: 'data/binary/vctk' +pre_align_cls: egs.datasets.audio.vctk.pre_align.VCTKPreAlign +use_spk_id: true +test_num: 200 +num_spk: 400 +binarization_args: + shuffle: true + trim_eos_bos: true \ No newline at end of file diff --git a/egs/datasets/audio/vctk/pre_align.py b/egs/datasets/audio/vctk/pre_align.py new file mode 100755 index 0000000000000000000000000000000000000000..a03b3e12af245fa603403432f4487c53e8b13eab --- /dev/null +++ b/egs/datasets/audio/vctk/pre_align.py @@ -0,0 +1,22 @@ +import os + +from data_gen.tts.base_pre_align 
import BasePreAlign +import glob + + +class VCTKPreAlign(BasePreAlign): + def meta_data(self): + wav_fns = glob.glob(f'{self.raw_data_dir}/wav48/*/*.wav') + for wav_fn in wav_fns: + item_name = os.path.basename(wav_fn)[:-4] + spk = item_name.split("_")[0] + txt_fn = wav_fn.split("/") + txt_fn[-1] = f'{item_name}.txt' + txt_fn[-3] = f'txt' + txt_fn = "/".join(txt_fn) + if os.path.exists(txt_fn) and os.path.exists(wav_fn): + yield item_name, wav_fn, (self.load_txt, txt_fn), spk + + +if __name__ == "__main__": + VCTKPreAlign().process() diff --git a/egs/datasets/audio/vctk/pwg.yaml b/egs/datasets/audio/vctk/pwg.yaml new file mode 100755 index 0000000000000000000000000000000000000000..8c7c557fee02b0a62881f36a2abe010ede0d5f39 --- /dev/null +++ b/egs/datasets/audio/vctk/pwg.yaml @@ -0,0 +1,6 @@ +base_config: + - egs/egs_bases/tts/vocoder/pwg.yaml + - ./base_mel2wav.yaml + +num_spk: 400 +max_samples: 20480 diff --git a/egs/egs_bases/config_base.yaml b/egs/egs_bases/config_base.yaml new file mode 100755 index 0000000000000000000000000000000000000000..39240ceb8cd52e4c0a146aa579f1ea046c727c84 --- /dev/null +++ b/egs/egs_bases/config_base.yaml @@ -0,0 +1,46 @@ +# task +binary_data_dir: '' +work_dir: '' # experiment directory. +infer: false # inference +amp: false +seed: 1234 +debug: false +save_codes: [] +# - configs +# - modules +# - tasks +# - utils +# - usr + +############# +# dataset +############# +ds_workers: 1 +test_num: 100 +endless_ds: false +sort_by_len: true + +######### +# train and eval +######### +print_nan_grads: false +load_ckpt: '' +save_best: true +num_ckpt_keep: 3 +clip_grad_norm: 0 +accumulate_grad_batches: 1 +tb_log_interval: 100 +num_sanity_val_steps: 5 # steps of validation at the beginning +check_val_every_n_epoch: 10 +val_check_interval: 2000 +valid_monitor_key: 'val_loss' +valid_monitor_mode: 'min' +max_epochs: 1000 +max_updates: 1000000 +max_tokens: 31250 +max_sentences: 100000 +max_valid_tokens: -1 +max_valid_sentences: -1 +test_input_dir: '' +resume_from_checkpoint: 0 +rename_tmux: true \ No newline at end of file diff --git a/egs/egs_bases/tts/base.yaml b/egs/egs_bases/tts/base.yaml new file mode 100755 index 0000000000000000000000000000000000000000..255e4cce1bb52d58d24443bc277621d70732583f --- /dev/null +++ b/egs/egs_bases/tts/base.yaml @@ -0,0 +1,112 @@ +# task +base_config: ../config_base.yaml +task_cls: '' +############# +# dataset +############# +raw_data_dir: '' +processed_data_dir: '' +binary_data_dir: '' +dict_dir: '' +pre_align_cls: '' +binarizer_cls: data_gen.tts.base_binarizer.BaseBinarizer +pre_align_args: + txt_processor: en + use_tone: true # for ZH + sox_resample: false + sox_to_wav: false + allow_no_txt: false + trim_sil: false + denoise: false +binarization_args: + shuffle: false + with_txt: true + with_wav: false + with_align: true + with_spk_embed: false + with_spk_id: true + with_f0: true + with_f0cwt: false + with_linear: false + with_word: true + trim_sil: false + trim_eos_bos: false + reset_phone_dict: true + reset_word_dict: true +word_size: 30000 +pitch_extractor: parselmouth + +loud_norm: false +endless_ds: true + +test_num: 100 +min_frames: 0 +max_frames: 1548 +frames_multiple: 1 +max_input_tokens: 1550 +audio_num_mel_bins: 80 +audio_sample_rate: 22050 +hop_size: 256 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) +win_size: 1024 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate) +fmin: 80 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. 
Pitch info: male~[65, 260], female~[100, 525]) +fmax: 7600 # To be increased/reduced depending on data. +fft_size: 1024 # Extra window size is filled with 0 paddings to match this parameter +min_level_db: -100 +ref_level_db: 20 +griffin_lim_iters: 60 +num_spk: 1 +mel_vmin: -6 +mel_vmax: 1.5 +ds_workers: 1 + +######### +# model +######### +dropout: 0.1 +enc_layers: 4 +dec_layers: 4 +hidden_size: 256 +num_heads: 2 +enc_ffn_kernel_size: 9 +dec_ffn_kernel_size: 9 +ffn_act: gelu +ffn_padding: 'SAME' +use_spk_id: false +use_split_spk_id: false +use_spk_embed: false + + +########### +# optimization +########### +lr: 2.0 +scheduler: rsqrt # rsqrt|none +warmup_updates: 8000 +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.98 +weight_decay: 0 +clip_grad_norm: 1 +clip_grad_value: 0 + + +########### +# train and eval +########### +max_tokens: 30000 +max_sentences: 100000 +max_valid_sentences: 1 +max_valid_tokens: 60000 +valid_infer_interval: 10000 +train_set_name: 'train' +train_sets: '' +valid_set_name: 'valid' +test_set_name: 'test' +num_test_samples: 0 +num_valid_plots: 10 +test_ids: [ ] +vocoder_denoise_c: 0.0 +profile_infer: false +out_wav_norm: false +save_gt: true +save_f0: false +gen_dir_name: '' \ No newline at end of file diff --git a/egs/egs_bases/tts/fs2.yaml b/egs/egs_bases/tts/fs2.yaml new file mode 100755 index 0000000000000000000000000000000000000000..c200a50ad2e04ff685bd30c7d5c3be69ff7cdf3f --- /dev/null +++ b/egs/egs_bases/tts/fs2.yaml @@ -0,0 +1,102 @@ +base_config: ./base.yaml +task_cls: tasks.tts.fs2.FastSpeech2Task + +# model +hidden_size: 256 +dropout: 0.1 +encoder_type: fft # rel_fft|fft|tacotron|tacotron2|conformer +decoder_type: fft # fft|rnn|conv|conformer|wn + +# rnn enc/dec +encoder_K: 8 +decoder_rnn_dim: 0 # for rnn decoder, 0 -> hidden_size * 2 + +# fft enc/dec +use_pos_embed: true +dec_num_heads: 2 +dec_layers: 4 +ffn_hidden_size: 1024 +enc_ffn_kernel_size: 9 +dec_ffn_kernel_size: 9 + +# conv enc/dec +enc_dec_norm: ln +conv_use_pos: false +layers_in_block: 2 +enc_dilations: [ 1, 1, 1, 1 ] +enc_kernel_size: 5 +dec_dilations: [ 1, 1, 1, 1 ] # for conv decoder +dec_kernel_size: 5 +dur_loss: mse # huber|mol + +# duration +predictor_hidden: -1 +predictor_kernel: 5 +predictor_layers: 2 +dur_predictor_kernel: 3 +dur_predictor_layers: 2 +predictor_dropout: 0.5 + +# pitch and energy +pitch_norm: standard # standard|log +use_pitch_embed: true +pitch_type: frame # frame|ph|cwt +use_uv: true +cwt_hidden_size: 128 +cwt_layers: 2 +cwt_loss: l1 +cwt_add_f0_loss: false +cwt_std_scale: 0.8 + +pitch_ar: false +pitch_embed_type: 0 +pitch_loss: 'l1' # l1|l2|ssim +pitch_ssim_win: 11 +use_energy_embed: false + +# reference encoder and speaker embedding +use_ref_enc: false +use_var_enc: false +lambda_commit: 0.25 +var_enc_vq_codes: 64 +ref_norm_layer: bn +dec_inp_add_noise: false +sil_add_noise: false +ref_hidden_stride_kernel: + - 0,3,5 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size + - 0,3,5 + - 0,2,5 + - 0,2,5 + - 0,2,5 +pitch_enc_hidden_stride_kernel: + - 0,2,5 # conv_hidden_size, conv_stride, conv_kernel_size. conv_hidden_size=0: use hidden_size + - 0,2,5 + - 0,2,5 +dur_enc_hidden_stride_kernel: + - 0,2,3 # conv_hidden_size, conv_stride, conv_kernel_size. 
conv_hidden_size=0: use hidden_size + - 0,2,3 + - 0,1,3 + +# mel +mel_loss: l1:0.5|ssim:0.5 # l1|l2|gdl|ssim or l1:0.5|ssim:0.5 + +# loss lambda +lambda_f0: 1.0 +lambda_uv: 1.0 +lambda_energy: 0.1 +lambda_ph_dur: 0.1 +lambda_sent_dur: 1.0 +lambda_word_dur: 1.0 +predictor_grad: 0.1 + +# train and eval +pretrain_fs_ckpt: '' +warmup_updates: 2000 +max_tokens: 32000 +max_sentences: 100000 +max_valid_sentences: 1 +max_updates: 120000 +use_gt_dur: false +use_gt_f0: false +ds_workers: 2 +lr: 1.0 diff --git a/egs/egs_bases/tts/vocoder/base.yaml b/egs/egs_bases/tts/vocoder/base.yaml new file mode 100755 index 0000000000000000000000000000000000000000..92a1a74e269f5245af09a1bc00a5998f069b6801 --- /dev/null +++ b/egs/egs_bases/tts/vocoder/base.yaml @@ -0,0 +1,34 @@ +base_config: ../base.yaml +binarization_args: + with_wav: true + with_spk_embed: false + with_align: false + with_word: false + with_txt: false + +########### +# train and eval +########### +max_samples: 25600 +max_sentences: 5 +max_valid_sentences: 1 +max_updates: 1000000 +val_check_interval: 2000 + +########################################################### +# FEATURE EXTRACTION SETTING # +########################################################### +fft_size: 1024 # FFT size. +hop_size: 256 # Hop size. +win_length: null # Window length. +# If set to null, it will be the same as fft_size. +window: "hann" # Window function. +num_mels: 80 # Number of mel basis. +fmin: 80 # Minimum freq in mel basis calculation. +fmax: 7600 # Maximum frequency in mel basis calculation. +aux_context_window: 0 # Context window size for auxiliary feature. +use_pitch_embed: false + +generator_grad_norm: 10 # Generator's gradient norm. +discriminator_grad_norm: 1 # Discriminator's gradient norm. +disc_start_steps: 40000 # Number of steps to start to train discriminator. diff --git a/egs/egs_bases/tts/vocoder/pwg.yaml b/egs/egs_bases/tts/vocoder/pwg.yaml new file mode 100755 index 0000000000000000000000000000000000000000..2d95bbd92abcdf70fb7a38b15877509390696fd2 --- /dev/null +++ b/egs/egs_bases/tts/vocoder/pwg.yaml @@ -0,0 +1,82 @@ +base_config: ./base.yaml +task_cls: tasks.vocoder.pwg.PwgTask + +aux_context_window: 2 # Context window size for auxiliary feature. +use_pitch_embed: false +########################################################### +# GENERATOR NETWORK ARCHITECTURE SETTING # +########################################################### +generator_params: + in_channels: 1 # Number of input channels. + out_channels: 1 # Number of output channels. + kernel_size: 3 # Kernel size of dilated convolution. + layers: 30 # Number of residual block layers. + stacks: 3 # Number of stacks i.e., dilation cycles. + residual_channels: 64 # Number of channels in residual conv. + gate_channels: 128 # Number of channels in gated conv. + skip_channels: 64 # Number of channels in skip conv. + aux_channels: 80 # Number of channels for auxiliary feature conv. + # Must be the same as num_mels. + # If set to 2, previous 2 and future 2 frames will be considered. + dropout: 0.0 # Dropout rate. 0.0 means no dropout applied. + use_weight_norm: true # Whether to use weight norm. + # If set to true, it will be applied to all of the conv layers. + upsample_net: "ConvInUpsampleNetwork" # Upsampling network architecture. + upsample_params: # Upsampling network parameters. + upsample_scales: [4, 4, 4, 4] # Upsampling scales. Prodcut of these must be the same as hop size. 
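+    # (4 * 4 * 4 * 4 = 256, i.e. the hop_size these configs use for mel extraction.)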
+  use_pitch_embed: false
+  use_nsf: false
+###########################################################
+#        DISCRIMINATOR NETWORK ARCHITECTURE SETTING       #
+###########################################################
+discriminator_params:
+  in_channels: 1        # Number of input channels.
+  out_channels: 1       # Number of output channels.
+  kernel_size: 3        # Kernel size of conv layers.
+  layers: 10            # Number of conv layers.
+  conv_channels: 64     # Number of conv channels.
+  bias: true            # Whether to use bias parameter in conv.
+  use_weight_norm: true # Whether to use weight norm.
+                        # If set to true, it will be applied to all of the conv layers.
+  nonlinear_activation: "LeakyReLU" # Nonlinear function after each conv.
+  nonlinear_activation_params:      # Nonlinear function parameters
+    negative_slope: 0.2             # Alpha in LeakyReLU.
+rerun_gen: true
+
+###########################################################
+#                   STFT LOSS SETTING                     #
+###########################################################
+stft_loss_params:
+  fft_sizes: [1024, 2048, 512]  # List of FFT sizes for STFT-based loss.
+  hop_sizes: [120, 240, 50]     # List of hop sizes for STFT-based loss.
+  win_lengths: [600, 1200, 240] # List of window lengths for STFT-based loss.
+  window: "hann_window"         # Window function for STFT-based loss.
+use_mel_loss: false
+
+###########################################################
+#               ADVERSARIAL LOSS SETTING                  #
+###########################################################
+lambda_adv: 4.0  # Loss balancing coefficient.
+
+###########################################################
+#             OPTIMIZER & SCHEDULER SETTING               #
+###########################################################
+generator_optimizer_params:
+  lr: 0.0001           # Generator's learning rate.
+  eps: 1.0e-6          # Generator's epsilon.
+  weight_decay: 0.0    # Generator's weight decay coefficient.
+generator_scheduler_params:
+  step_size: 200000    # Generator's scheduler step size.
+  gamma: 0.5           # Generator's scheduler gamma.
+                       # At each step size, lr will be multiplied by this parameter.
+generator_grad_norm: 10  # Generator's gradient norm.
+discriminator_optimizer_params:
+  lr: 0.00005          # Discriminator's learning rate.
+  eps: 1.0e-6          # Discriminator's epsilon.
+  weight_decay: 0.0    # Discriminator's weight decay coefficient.
+discriminator_scheduler_params:
+  step_size: 200000    # Discriminator's scheduler step size.
+  gamma: 0.5           # Discriminator's scheduler gamma.
+                       # At each step size, lr will be multiplied by this parameter.
+discriminator_grad_norm: 1  # Discriminator's gradient norm.
+disc_start_steps: 40000     # Number of steps to start to train discriminator.
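+# Note (sketch, not code from this repo): each (fft_size, hop_size, win_length)
+# triple above defines one STFT resolution; a multi-resolution STFT loss
+# averages a spectral-convergence term and a log-magnitude L1 term over the
+# three resolutions, and once disc_start_steps is reached the generator loss
+# additionally gets lambda_adv times the adversarial term.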
diff --git a/inference/ProDiff.py b/inference/ProDiff.py
new file mode 100644
index 0000000000000000000000000000000000000000..945497f353e50c7b9314476a615f4b1b94a2a940
--- /dev/null
+++ b/inference/ProDiff.py
@@ -0,0 +1,49 @@
+import torch
+from inference.base_tts_infer import BaseTTSInfer
+from utils.ckpt_utils import load_ckpt, get_last_checkpoint
+from utils.hparams import hparams
+from modules.ProDiff.model.ProDiff import GaussianDiffusion
+from usr.diff.net import DiffNet
+import os
+import numpy as np
+from functools import partial
+
+class ProDiffInfer(BaseTTSInfer):
+    def build_model(self):
+        f0_stats_fn = f'{hparams["binary_data_dir"]}/train_f0s_mean_std.npy'
+        if os.path.exists(f0_stats_fn):
+            hparams['f0_mean'], hparams['f0_std'] = np.load(f0_stats_fn)
+            hparams['f0_mean'] = float(hparams['f0_mean'])
+            hparams['f0_std'] = float(hparams['f0_std'])
+        model = GaussianDiffusion(
+            phone_encoder=self.ph_encoder,
+            out_dims=80, denoise_fn=DiffNet(hparams['audio_num_mel_bins']),
+            timesteps=hparams['timesteps'],
+            loss_type=hparams['diff_loss_type'],
+            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+        )
+        checkpoint = torch.load(hparams['teacher_ckpt'], map_location='cpu')["state_dict"]['model']
+        teacher_timesteps = int(checkpoint['timesteps'].item())
+        teacher_timescales = int(checkpoint['timescale'].item())
+        # progressive distillation: the student halves the teacher's number of
+        # diffusion steps and doubles the timescale accordingly
+        student_timesteps = teacher_timesteps // 2
+        student_timescales = teacher_timescales * 2
+        to_torch = partial(torch.tensor, dtype=torch.float32)
+        model.register_buffer('timesteps', to_torch(student_timesteps))   # number of student diffusion steps
+        model.register_buffer('timescale', to_torch(student_timescales))  # student step-size scale
+        model.eval()
+        load_ckpt(model, hparams['work_dir'], 'model')
+        return model
+
+    def forward_model(self, inp):
+        sample = self.input_to_batch(inp)
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+        with torch.no_grad():
+            output = self.model(txt_tokens, infer=True)
+            mel_out = output['mel_out']
+            wav_out = self.run_vocoder(mel_out)
+        wav_out = wav_out.squeeze().cpu().numpy()
+        return wav_out
+
+
+if __name__ == '__main__':
+    ProDiffInfer.example_run()
diff --git a/inference/ProDiff_Teacher.py b/inference/ProDiff_Teacher.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e10278e1864d0709c8667a7dd2106dfb4b1cd38
--- /dev/null
+++ b/inference/ProDiff_Teacher.py
@@ -0,0 +1,41 @@
+import torch
+from inference.base_tts_infer import BaseTTSInfer
+from utils.ckpt_utils import load_ckpt, get_last_checkpoint
+from utils.hparams import hparams
+from modules.ProDiff.model.ProDiff_teacher import GaussianDiffusion
+from usr.diff.net import DiffNet
+import os
+import numpy as np
+
+class ProDiffTeacherInfer(BaseTTSInfer):
+    def build_model(self):
+        f0_stats_fn = f'{hparams["binary_data_dir"]}/train_f0s_mean_std.npy'
+        if os.path.exists(f0_stats_fn):
+            hparams['f0_mean'], hparams['f0_std'] = np.load(f0_stats_fn)
+            hparams['f0_mean'] = float(hparams['f0_mean'])
+            hparams['f0_std'] = float(hparams['f0_std'])
+        model = GaussianDiffusion(
+            phone_encoder=self.ph_encoder,
+            out_dims=80, denoise_fn=DiffNet(hparams['audio_num_mel_bins']),
+            timesteps=hparams['timesteps'],
+            loss_type=hparams['diff_loss_type'],
+            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+        )
+
+        model.eval()
+        load_ckpt(model, hparams['work_dir'], 'model')
+        return model
+
+    def forward_model(self, inp):
+        sample = self.input_to_batch(inp)
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+        with torch.no_grad():
+            output = self.model(txt_tokens, infer=True)
+            mel_out = output['mel_out']
+            wav_out = self.run_vocoder(mel_out)
+        wav_out = wav_out.squeeze().cpu().numpy()
+        return wav_out
+
+
+if __name__ == '__main__':
+    ProDiffTeacherInfer.example_run()
diff --git a/inference/base_tts_infer.py b/inference/base_tts_infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f34f207ace872dc6f075cf645a5692c536c640b6
--- /dev/null
+++ b/inference/base_tts_infer.py
@@ -0,0 +1,167 @@
+import os
+
+import torch
+
+import numpy as np
+
+from tasks.tts.dataset_utils import FastSpeechWordDataset
+from tasks.tts.tts_utils import load_data_preprocessor
+from modules.FastDiff.module.util import compute_hyperparams_given_schedule, sampling_given_noise_schedule
+from modules.FastDiff.module.FastDiff_model import FastDiff
+from utils.ckpt_utils import load_ckpt
+from utils.hparams import set_hparams
+
+
+class BaseTTSInfer:
+    def __init__(self, hparams, device=None):
+        if device is None:
+            device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.hparams = hparams
+        self.device = device
+        self.data_dir = hparams['binary_data_dir']
+        self.preprocessor, self.preprocess_args = load_data_preprocessor()
+        self.ph_encoder = self.preprocessor.load_dict(self.data_dir)
+        self.spk_map = self.preprocessor.load_spk_map(self.data_dir)
+        self.ds_cls = FastSpeechWordDataset
+        self.model = self.build_model()
+        self.model.eval()
+        self.model.to(self.device)
+        self.vocoder, self.diffusion_hyperparams, self.noise_schedule = self.build_vocoder()
+        self.vocoder.eval()
+        self.vocoder.to(self.device)
+
+    def build_model(self):
+        raise NotImplementedError
+
+    def forward_model(self, inp):
+        raise NotImplementedError
+
+    def build_vocoder(self):
+        base_dir = self.hparams['vocoder_ckpt']
+        config_path = f'{base_dir}/config.yaml'
+        config = set_hparams(config_path, global_hparams=False)
+        vocoder = FastDiff(audio_channels=config['audio_channels'],
+                           inner_channels=config['inner_channels'],
+                           cond_channels=config['cond_channels'],
+                           upsample_ratios=config['upsample_ratios'],
+                           lvc_layers_each_block=config['lvc_layers_each_block'],
+                           lvc_kernel_size=config['lvc_kernel_size'],
+                           kpnet_hidden_channels=config['kpnet_hidden_channels'],
+                           kpnet_conv_size=config['kpnet_conv_size'],
+                           dropout=config['dropout'],
+                           diffusion_step_embed_dim_in=config['diffusion_step_embed_dim_in'],
+                           diffusion_step_embed_dim_mid=config['diffusion_step_embed_dim_mid'],
+                           diffusion_step_embed_dim_out=config['diffusion_step_embed_dim_out'],
+                           use_weight_norm=config['use_weight_norm'])
+        load_ckpt(vocoder, base_dir, 'model')
+
+        # Init hyperparameters by linear schedule
+        noise_schedule = torch.linspace(float(config["beta_0"]), float(config["beta_T"]), int(config["T"]))
+        diffusion_hyperparams = compute_hyperparams_given_schedule(noise_schedule)
+
+        if config['noise_schedule'] != '':
+            noise_schedule = config['noise_schedule']
+            if isinstance(noise_schedule, list):
+                noise_schedule = torch.FloatTensor(noise_schedule)
+        else:
+            # Select Schedule
+            try:
+                reverse_step = int(self.hparams.get('N'))
+            except (TypeError, ValueError):
+                print('Please specify N (the number of reverse iterations) in the config file. Denoising with 4 iterations for now.')
+                reverse_step = 4
+            if reverse_step == 1000:
+                noise_schedule = torch.linspace(0.000001, 0.01, 1000)
+            elif reverse_step == 200:
+                noise_schedule = torch.linspace(0.0001, 0.02, 200)
+
+            # Below are schedules derived by the noise predictor.
+            # We will release the noise predictor training & noise scheduling code soon. Please stay tuned!
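+            # (Each hand-picked list below is a per-step variance (beta) schedule,
+            #  ordered from the smallest to the largest noise level; the fewer the
+            #  reverse steps, the more aggressive each step's beta must be.)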
+            elif reverse_step == 8:
+                noise_schedule = [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513,
+                                  0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593,
+                                  0.5]
+            elif reverse_step == 6:
+                noise_schedule = [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984,
+                                  0.006634317338466644, 0.09357017278671265, 0.6000000238418579]
+            elif reverse_step == 4:
+                noise_schedule = [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01]
+            elif reverse_step == 3:
+                noise_schedule = [9.0000e-05, 9.0000e-03, 6.0000e-01]
+            else:
+                raise NotImplementedError
+
+        if isinstance(noise_schedule, list):
+            noise_schedule = torch.FloatTensor(noise_schedule)
+
+        return vocoder, diffusion_hyperparams, noise_schedule
+
+    def run_vocoder(self, c):
+        c = c.transpose(2, 1)
+        audio_length = c.shape[-1] * self.hparams["hop_size"]
+        y = sampling_given_noise_schedule(
+            self.vocoder, (1, 1, audio_length), self.diffusion_hyperparams, self.noise_schedule, condition=c, ddim=False, return_sequence=False)
+        return y
+
+    def preprocess_input(self, inp):
+        """
+        :param inp: {'text': str, 'item_name': (str, optional), 'spk_name': (str, optional)}
+        :return: a preprocessed item dict with phonemes, phoneme tokens and speaker id
+        """
+        preprocessor, preprocess_args = self.preprocessor, self.preprocess_args
+        text_raw = inp['text']
+        item_name = inp.get('item_name', '')
+        spk_name = inp.get('spk_name', 'SPK1')
+        ph, txt = preprocessor.txt_to_ph(
+            preprocessor.txt_processor, text_raw, preprocess_args)
+        ph_token = self.ph_encoder.encode(ph)
+        spk_id = self.spk_map[spk_name]
+        item = {'item_name': item_name, 'text': txt, 'ph': ph, 'spk_id': spk_id, 'ph_token': ph_token}
+        item['ph_len'] = len(item['ph_token'])
+        return item
+
+    def input_to_batch(self, item):
+        item_names = [item['item_name']]
+        text = [item['text']]
+        ph = [item['ph']]
+        txt_tokens = torch.LongTensor(item['ph_token'])[None, :].to(self.device)
+        txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
+        # wrap the scalar spk_id in a list: torch.LongTensor(int) would allocate
+        # an uninitialized tensor of that size instead of a 1-element tensor
+        spk_ids = torch.LongTensor([item['spk_id']])[None, :].to(self.device)
+        batch = {
+            'item_name': item_names,
+            'text': text,
+            'ph': ph,
+            'txt_tokens': txt_tokens,
+            'txt_lengths': txt_lengths,
+            'spk_ids': spk_ids,
+        }
+        return batch
+
+    def postprocess_output(self, output):
+        return output
+
+    def infer_once(self, inp):
+        inp = self.preprocess_input(inp)
+        output = self.forward_model(inp)
+        output = self.postprocess_output(output)
+        return output
+
+    @classmethod
+    def example_run(cls):
+        from utils.hparams import set_hparams
+        from utils.hparams import hparams as hp
+        from utils.audio import save_wav
+
+        set_hparams()
+        inp = {
+            'text': hp['text']
+        }
+        infer_ins = cls(hp)
+        out = infer_ins.infer_once(inp)
+        os.makedirs('infer_out', exist_ok=True)
+        save_wav(out, f'infer_out/{hp["text"]}.wav', hp['audio_sample_rate'])
diff --git a/inference/gradio/gradio_settings.yaml b/inference/gradio/gradio_settings.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6cefb0cd6edbd36e5c544532adf5de7048e5f047
--- /dev/null
+++ b/inference/gradio/gradio_settings.yaml
@@ -0,0 +1,41 @@
+title: 'Extremely-Fast diffusion text-to-speech synthesis pipeline with ProDiff and FastDiff'
+description: |
+  Gradio demo for **2-iter** ProDiff and **4-iter** FastDiff. To use it, simply enter your text, or click one of the examples to load it. **This space is running on CPU, inference will be slower.**
+
+  ## Key Features
+  - **Extremely-Fast** diffusion text-to-speech synthesis pipeline for potential **industrial deployment**.
+  - **Tutorial and code base** for speech diffusion models.
+  - More **supported diffusion mechanisms** (e.g., guided diffusion) will be available.
+
+
+article: |
+  ## Reference
+  Link to ProDiff Github REPO
+
+  If you find this code useful in your research, please cite our work:
+  ```
+  @inproceedings{huang2022prodiff,
+    title={ProDiff: Progressive Fast Diffusion Model For High-Quality Text-to-Speech},
+    author={Huang, Rongjie and Zhao, Zhou and Liu, Huadai and Liu, Jinglin and Cui, Chenye and Ren, Yi},
+    booktitle={Proceedings of the 30th ACM International Conference on Multimedia},
+    year={2022}
+  }
+
+  @inproceedings{huang2022fastdiff,
+    title={FastDiff: A Fast Conditional Diffusion Model for High-Quality Speech Synthesis},
+    author={Huang, Rongjie and Lam, Max WY and Wang, Jun and Su, Dan and Yu, Dong and Ren, Yi and Zhao, Zhou},
+    booktitle = {Proceedings of the Thirty-First International Joint Conference on Artificial Intelligence, {IJCAI-22}},
+    year={2022}
+  }
+  ```
+
+  ## Disclaimer
+  Any organization or individual is prohibited from using any technology mentioned in this paper to generate someone's speech without his/her consent, including but not limited to government leaders, political figures, and celebrities. If you do not comply with this item, you could be in violation of copyright laws.
+
+example_inputs:
+  - |-
+    the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.
+  - |-
+    Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition.
+inference_cls: inference.ProDiff.ProDiffInfer
+exp_name: ProDiff
+config: modules/ProDiff/config/prodiff.yaml
\ No newline at end of file
diff --git a/inference/gradio/infer.py b/inference/gradio/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..27acc399c78e024672013cd03048448c22e59df4
--- /dev/null
+++ b/inference/gradio/infer.py
@@ -0,0 +1,69 @@
+import importlib
+import re
+
+import gradio as gr
+import yaml
+from gradio.inputs import Textbox
+
+from inference.base_tts_infer import BaseTTSInfer
+from utils.hparams import set_hparams
+from utils.hparams import hparams as hp
+import numpy as np
+
+from data_gen.tts.data_gen_utils import is_sil_phoneme, PUNCS
+
+class GradioInfer:
+    def __init__(self, exp_name, config, inference_cls, title, description, article, example_inputs):
+        self.exp_name = exp_name
+        self.config = config
+        self.title = title
+        self.description = description
+        self.article = article
+        self.example_inputs = example_inputs
+        pkg = ".".join(inference_cls.split(".")[:-1])
+        cls_name = inference_cls.split(".")[-1]
+        self.inference_cls = getattr(importlib.import_module(pkg), cls_name)
+
+    def greet(self, text):
+        # split the input into sentences on punctuation, then synthesize chunks
+        # of up to ~400 characters and concatenate the audio with short pauses
+        sents = re.split(rf'([{PUNCS}])', text.replace('\n', ','))
+        if sents[-1] not in list(PUNCS):
+            sents = sents + ['.']
+        audio_outs = []
+        s = ""
+        for i in range(0, len(sents), 2):
+            if len(sents[i]) > 0:
+                s += sents[i] + sents[i + 1]
+            if len(s) >= 400 or (i >= len(sents) - 2 and len(s) > 0):
+                audio_out = self.infer_ins.infer_once({
+                    'text': s
+                })
+                audio_out = audio_out * 32767
+                audio_out = audio_out.astype(np.int16)
+                audio_outs.append(audio_out)
+                audio_outs.append(np.zeros(int(hp['audio_sample_rate'] * 0.3)).astype(np.int16))
+                s = ""
+        audio_outs = np.concatenate(audio_outs)
+        return hp['audio_sample_rate'], audio_outs
+
+    def run(self):
+        # set_hparams(...) loads the yaml config into the global hparams dict
+        # (imported above as hp); the inference class reads everything from it
+        set_hparams(exp_name=self.exp_name, config=self.config)
+        
infer_cls = self.inference_cls + self.infer_ins: BaseTTSInfer = infer_cls(hp) + example_inputs = self.example_inputs + iface = gr.Interface(fn=self.greet, + inputs=Textbox( + lines=10, placeholder=None, default=example_inputs[0], label="input text"), + outputs="audio", + allow_flagging="never", + title=self.title, + description=self.description, + article=self.article, + examples=example_inputs, + enable_queue=True) + iface.launch(share=True,cache_examples=True) + + +if __name__ == '__main__': + gradio_config = yaml.safe_load(open('inference/gradio/gradio_settings.yaml')) + g = GradioInfer(**gradio_config) + g.run() diff --git a/modules/FastDiff/config/FastDiff.yaml b/modules/FastDiff/config/FastDiff.yaml new file mode 100644 index 0000000000000000000000000000000000000000..af6f1d1511deea17dd0d5fc08bf0270682a2623e --- /dev/null +++ b/modules/FastDiff/config/FastDiff.yaml @@ -0,0 +1,7 @@ +base_config: + - ./base.yaml + +audio_sample_rate: 22050 +raw_data_dir: 'data/raw/LJSpeech-1.1' +processed_data_dir: 'data/processed/LJSpeech' +binary_data_dir: 'data/binary/LJSpeech' diff --git a/modules/FastDiff/config/FastDiff_libritts.yaml b/modules/FastDiff/config/FastDiff_libritts.yaml new file mode 100644 index 0000000000000000000000000000000000000000..372f3e07e7ecbe546ee91a6ac32f5bd370de81d9 --- /dev/null +++ b/modules/FastDiff/config/FastDiff_libritts.yaml @@ -0,0 +1,7 @@ +base_config: + - ./base.yaml + +audio_sample_rate: 22050 +raw_data_dir: 'data/raw/LibriTTS' +processed_data_dir: 'data/processed/LibriTTS' +binary_data_dir: 'data/binary/LibriTTS' \ No newline at end of file diff --git a/modules/FastDiff/config/FastDiff_sc09.yaml b/modules/FastDiff/config/FastDiff_sc09.yaml new file mode 100644 index 0000000000000000000000000000000000000000..23322fd8b6d50595ff1e2ae4b3d080ed707ade74 --- /dev/null +++ b/modules/FastDiff/config/FastDiff_sc09.yaml @@ -0,0 +1,25 @@ +base_config: + - egs/egs_bases/tts/vocoder/base.yaml + - egs/datasets/audio/lj/base_mel2wav.yaml + - ./base.yaml + +#raw_data_dir: '/home1/huangrongjie/dataset/sc09/data/' +#processed_data_dir: 'data/processed/SC09' +#binary_data_dir: 'data/binary/SC09' + +raw_data_dir: '/home1/huangrongjie/Project/AdaGrad/data/raw/SC09/' +processed_data_dir: 'data/processed/SC09_ten_processed' +binary_data_dir: 'data/binary/SC09_ten_processed' + +pre_align_cls: egs.datasets.audio.sc09.pre_align.Sc09PreAlign +audio_sample_rate: 16000 +max_samples: 12800 + +pre_align_args: + sox_resample: false + sox_to_wav: false + allow_no_txt: true + trim_sil: true + denoise: true + +loud_norm: true \ No newline at end of file diff --git a/modules/FastDiff/config/FastDiff_tacotron.yaml b/modules/FastDiff/config/FastDiff_tacotron.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1c180b02789ac6a38b22b88496f000dcc259b330 --- /dev/null +++ b/modules/FastDiff/config/FastDiff_tacotron.yaml @@ -0,0 +1,58 @@ +base_config: + - egs/egs_bases/tts/vocoder/pwg.yaml + - egs/egs_bases/tts/base_mel2wav.yaml + - egs/datasets/audio/lj/pwg.yaml + +raw_data_dir: 'data/raw/LJSpeech-1.1' +processed_data_dir: 'data/processed/LJSpeech_FastDiff' +#binary_data_dir: 'data/binary/LJSpeech_Taco' +binary_data_dir: /apdcephfs/private_nlphuang/preprocess/AdaGrad/data/binary/LJSpeech_Taco + +binarizer_cls: data_gen.tts.vocoder_binarizer.VocoderBinarizer +pre_align_cls: egs.datasets.audio.lj.pre_align.LJPreAlign +task_cls: modules.FastDiff.task.FastDiff.FastDiffTask +binarization_args: + with_wav: true + with_spk_embed: false + with_align: false + with_word: false + with_txt: 
false + with_f0: false + +# data +num_spk: 400 +max_samples: 25600 +aux_context_window: 0 +max_sentences: 20 +test_input_dir: '' # 'wavs' # wav->wav inference +test_mel_dir: '' # 'mels' # mel->wav inference +use_wav: True # mel->wav inference + +# training +num_sanity_val_steps: -1 +max_updates: 1000000 +lr: 2e-4 +weight_decay: 0 + +# FastDiff +audio_channels: 1 +inner_channels: 32 +cond_channels: 80 +upsample_ratios: [8, 8, 4] +lvc_layers_each_block: 4 +lvc_kernel_size: 3 +kpnet_hidden_channels: 64 +kpnet_conv_size: 3 +dropout: 0.0 +diffusion_step_embed_dim_in: 128 +diffusion_step_embed_dim_mid: 512 +diffusion_step_embed_dim_out: 512 +use_weight_norm: True + +# Diffusion +T: 1000 +beta_0: 0.000001 +beta_T: 0.01 +noise_schedule: '' +N: '' + diff --git a/modules/FastDiff/config/FastDiff_vctk.yaml b/modules/FastDiff/config/FastDiff_vctk.yaml new file mode 100644 index 0000000000000000000000000000000000000000..54bb0000db9351348d7923c921e1040bfd48890e --- /dev/null +++ b/modules/FastDiff/config/FastDiff_vctk.yaml @@ -0,0 +1,7 @@ +base_config: + - ./base.yaml + +audio_sample_rate: 22050 +raw_data_dir: 'data/raw/VCTK' +processed_data_dir: 'data/processed/VCTK' +binary_data_dir: 'data/binary/VCTK' \ No newline at end of file diff --git a/modules/FastDiff/config/base.yaml b/modules/FastDiff/config/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7d8ace1c05811117c1798d1ef49f307d0d7929e5 --- /dev/null +++ b/modules/FastDiff/config/base.yaml @@ -0,0 +1,157 @@ +############# +# Custom dataset preprocess +############# +audio_num_mel_bins: 80 +audio_sample_rate: 22050 +hop_size: 256 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) +win_size: 1024 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate) +fmin: 80 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) +fmax: 7600 # To be increased/reduced depending on data. 
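+# (At 22050 Hz, hop_size 256 is ~11.6 ms per frame and win_size 1024 is ~46 ms;
+#  the "275" and "1100" in the comments above are what exact 12.5 ms / 50 ms would be.)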
+fft_size: 1024 # Extra window size is filled with 0 paddings to match this parameter +min_level_db: -100 +ref_level_db: 20 +griffin_lim_iters: 60 +num_spk: 1 # number of speakers +mel_vmin: -6 +mel_vmax: 1.5 + +############# +# FastDiff Model +############# +audio_channels: 1 +inner_channels: 32 +cond_channels: 80 +upsample_ratios: [8, 8, 4] +lvc_layers_each_block: 4 +lvc_kernel_size: 3 +kpnet_hidden_channels: 64 +kpnet_conv_size: 3 +dropout: 0.0 +diffusion_step_embed_dim_in: 128 +diffusion_step_embed_dim_mid: 512 +diffusion_step_embed_dim_out: 512 +use_weight_norm: True + +########### +# Diffusion +########### +T: 1000 +beta_0: 0.000001 +beta_T: 0.01 +noise_schedule: '' +N: '' + + +########### +# train and eval +########### +task_cls: modules.FastDiff.task.FastDiff.FastDiffTask +max_updates: 1000000 # max training steps +max_samples: 25600 # audio length in training +max_sentences: 20 # max batch size in training +num_sanity_val_steps: -1 +max_valid_sentences: 1 +valid_infer_interval: 10000 +val_check_interval: 2000 +num_test_samples: 0 +num_valid_plots: 10 + + +############# +# Stage 1 of data processing +############# +pre_align_cls: egs.datasets.audio.pre_align.PreAlign +pre_align_args: + nsample_per_mfa_group: 1000 + txt_processor: en + use_tone: true # for ZH + sox_resample: false + sox_to_wav: false + allow_no_txt: true + trim_sil: false + denoise: false + + +############# +# Stage 2 of data processing +############# +binarizer_cls: data_gen.tts.vocoder_binarizer.VocoderBinarizer +binarization_args: + with_wav: true + with_spk_embed: false + with_align: false + with_word: false + with_txt: false + with_f0: false + shuffle: false + with_spk_id: true + with_f0cwt: false + with_linear: false + trim_eos_bos: false + reset_phone_dict: true + reset_word_dict: true + + +########### +# optimization +########### +lr: 2e-4 # learning rate +weight_decay: 0 +scheduler: rsqrt # rsqrt|none +optimizer_adam_beta1: 0.9 +optimizer_adam_beta2: 0.98 +clip_grad_norm: 1 +clip_grad_value: 0 + +############# +# Setting for this Pytorch framework +############# +max_input_tokens: 1550 +frames_multiple: 1 +use_word_input: false +vocoder: FastDiff +vocoder_ckpt: checkpoints/FastDiff +vocoder_denoise_c: 0.0 +max_tokens: 30000 +max_valid_tokens: 60000 +test_ids: [ ] +profile_infer: false +out_wav_norm: false +save_gt: true +save_f0: false +aux_context_window: 0 +test_input_dir: '' # 'wavs' # wav->wav inference +test_mel_dir: '' # 'mels' # mel->wav inference +use_wav: True # mel->wav inference +pitch_extractor: parselmouth +loud_norm: false +endless_ds: true +test_num: 100 +min_frames: 0 +max_frames: 1548 +ds_workers: 1 +gen_dir_name: '' +accumulate_grad_batches: 1 +tb_log_interval: 100 +print_nan_grads: false +work_dir: '' # experiment directory. 
+infer: false # inference +amp: false +debug: false +save_codes: [] +save_best: true +num_ckpt_keep: 3 +sort_by_len: true +load_ckpt: '' +check_val_every_n_epoch: 10 +max_epochs: 1000 +eval_max_batches: -1 +resume_from_checkpoint: 0 +rename_tmux: true +valid_monitor_key: 'val_loss' +valid_monitor_mode: 'min' +train_set_name: 'train' +train_sets: '' +valid_set_name: 'valid' +test_set_name: 'test' +seed: 1234 \ No newline at end of file diff --git a/modules/FastDiff/module/FastDiff_model.py b/modules/FastDiff/module/FastDiff_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e52a488bf53432eadc79f3127a64e46bda4d4532 --- /dev/null +++ b/modules/FastDiff/module/FastDiff_model.py @@ -0,0 +1,123 @@ +import torch.nn as nn +import torch +import logging +from modules.FastDiff.module.modules import DiffusionDBlock, TimeAware_LVCBlock +from modules.FastDiff.module.util import calc_diffusion_step_embedding + +def swish(x): + return x * torch.sigmoid(x) + +class FastDiff(nn.Module): + """FastDiff module.""" + + def __init__(self, + audio_channels=1, + inner_channels=32, + cond_channels=80, + upsample_ratios=[8, 8, 4], + lvc_layers_each_block=4, + lvc_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + dropout=0.0, + diffusion_step_embed_dim_in=128, + diffusion_step_embed_dim_mid=512, + diffusion_step_embed_dim_out=512, + use_weight_norm=True): + super().__init__() + + self.diffusion_step_embed_dim_in = diffusion_step_embed_dim_in + + self.audio_channels = audio_channels + self.cond_channels = cond_channels + self.lvc_block_nums = len(upsample_ratios) + self.first_audio_conv = nn.Conv1d(1, inner_channels, + kernel_size=7, padding=(7 - 1) // 2, + dilation=1, bias=True) + + # define residual blocks + self.lvc_blocks = nn.ModuleList() + self.downsample = nn.ModuleList() + + # the layer-specific fc for noise scale embedding + self.fc_t = nn.ModuleList() + self.fc_t1 = nn.Linear(diffusion_step_embed_dim_in, diffusion_step_embed_dim_mid) + self.fc_t2 = nn.Linear(diffusion_step_embed_dim_mid, diffusion_step_embed_dim_out) + + cond_hop_length = 1 + for n in range(self.lvc_block_nums): + cond_hop_length = cond_hop_length * upsample_ratios[n] + lvcb = TimeAware_LVCBlock( + in_channels=inner_channels, + cond_channels=cond_channels, + upsample_ratio=upsample_ratios[n], + conv_layers=lvc_layers_each_block, + conv_kernel_size=lvc_kernel_size, + cond_hop_length=cond_hop_length, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=dropout, + noise_scale_embed_dim_out=diffusion_step_embed_dim_out + ) + self.lvc_blocks += [lvcb] + self.downsample.append(DiffusionDBlock(inner_channels, inner_channels, upsample_ratios[self.lvc_block_nums-n-1])) + + + # define output layers + self.final_conv = nn.Sequential(nn.Conv1d(inner_channels, audio_channels, kernel_size=7, padding=(7 - 1) // 2, + dilation=1, bias=True)) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, data): + """Calculate forward propagation. + Args: + x (Tensor): Input noise signal (B, 1, T). + c (Tensor): Local conditioning auxiliary features (B, C ,T'). 
+ Returns: + Tensor: Output tensor (B, out_channels, T) + """ + audio, c, diffusion_steps = data + + # embed diffusion step t + diffusion_step_embed = calc_diffusion_step_embedding(diffusion_steps, self.diffusion_step_embed_dim_in) + diffusion_step_embed = swish(self.fc_t1(diffusion_step_embed)) + diffusion_step_embed = swish(self.fc_t2(diffusion_step_embed)) + + audio = self.first_audio_conv(audio) + downsample = [] + for down_layer in self.downsample: + downsample.append(audio) + audio = down_layer(audio) + + x = audio + for n, audio_down in enumerate(reversed(downsample)): + x = self.lvc_blocks[n]((x, audio_down, c, diffusion_step_embed)) + + # apply final layers + x = self.final_conv(x) + + return x + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + diff --git a/modules/FastDiff/module/WaveNet.py b/modules/FastDiff/module/WaveNet.py new file mode 100644 index 0000000000000000000000000000000000000000..15f5fdc75ff696646c86551642deaebf2dd89ead --- /dev/null +++ b/modules/FastDiff/module/WaveNet.py @@ -0,0 +1,189 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from modules.FastDiff.module.util import calc_noise_scale_embedding +def swish(x): + return x * torch.sigmoid(x) + + +# dilated conv layer with kaiming_normal initialization +# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py +class Conv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1): + super(Conv, self).__init__() + self.padding = dilation * (kernel_size - 1) // 2 + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation, padding=self.padding) + self.conv = nn.utils.weight_norm(self.conv) + nn.init.kaiming_normal_(self.conv.weight) + + def forward(self, x): + out = self.conv(x) + return out + + +# conv1x1 layer with zero initialization +# from https://github.com/ksw0306/FloWaveNet/blob/master/modules.py but the scale parameter is removed +class ZeroConv1d(nn.Module): + def __init__(self, in_channel, out_channel): + super(ZeroConv1d, self).__init__() + self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1, padding=0) + self.conv.weight.data.zero_() + self.conv.bias.data.zero_() + + def forward(self, x): + out = self.conv(x) + return out + + +# every residual block (named residual layer in paper) +# contains one noncausal dilated conv +class Residual_block(nn.Module): + def __init__(self, res_channels, skip_channels, dilation, + noise_scale_embed_dim_out, multiband=True): + super(Residual_block, self).__init__() + self.res_channels = res_channels + + # the layer-specific fc for noise scale embedding + self.fc_t = nn.Linear(noise_scale_embed_dim_out, self.res_channels) + + # dilated conv layer + self.dilated_conv_layer = Conv(self.res_channels, 2 * self.res_channels, kernel_size=3, dilation=dilation) + + # add mel spectrogram upsampler and conditioner conv1x1 layer + self.upsample_conv2d = torch.nn.ModuleList() + if 
multiband is True: + params = 8 + else: + params = 16 + for s in [params, params]: ####### Very Important!!!!! ####### + conv_trans2d = torch.nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s)) + conv_trans2d = torch.nn.utils.weight_norm(conv_trans2d) + torch.nn.init.kaiming_normal_(conv_trans2d.weight) + self.upsample_conv2d.append(conv_trans2d) + self.mel_conv = Conv(80, 2 * self.res_channels, kernel_size=1) # 80 is mel bands + + # residual conv1x1 layer, connect to next residual layer + self.res_conv = nn.Conv1d(res_channels, res_channels, kernel_size=1) + self.res_conv = nn.utils.weight_norm(self.res_conv) + nn.init.kaiming_normal_(self.res_conv.weight) + + # skip conv1x1 layer, add to all skip outputs through skip connections + self.skip_conv = nn.Conv1d(res_channels, skip_channels, kernel_size=1) + self.skip_conv = nn.utils.weight_norm(self.skip_conv) + nn.init.kaiming_normal_(self.skip_conv.weight) + + def forward(self, input_data): + x, mel_spec, noise_scale_embed = input_data + h = x + B, C, L = x.shape # B, res_channels, L + assert C == self.res_channels + + # add in noise scale embedding + part_t = self.fc_t(noise_scale_embed) + part_t = part_t.view([B, self.res_channels, 1]) + h += part_t + + # dilated conv layer + h = self.dilated_conv_layer(h) + + # add mel spectrogram as (local) conditioner + assert mel_spec is not None + + # Upsample spectrogram to size of audio + mel_spec = torch.unsqueeze(mel_spec, dim=1) # (B, 1, 80, T') + mel_spec = F.leaky_relu(self.upsample_conv2d[0](mel_spec), 0.4) + mel_spec = F.leaky_relu(self.upsample_conv2d[1](mel_spec), 0.4) + mel_spec = torch.squeeze(mel_spec, dim=1) + + assert(mel_spec.size(2) >= L) + if mel_spec.size(2) > L: + mel_spec = mel_spec[:, :, :L] + + mel_spec = self.mel_conv(mel_spec) + h += mel_spec + + # gated-tanh nonlinearity + out = torch.tanh(h[:,:self.res_channels,:]) * torch.sigmoid(h[:,self.res_channels:,:]) + + # residual and skip outputs + res = self.res_conv(out) + assert x.shape == res.shape + skip = self.skip_conv(out) + + return (x + res) * math.sqrt(0.5), skip # normalize for training stability + + +class Residual_group(nn.Module): + def __init__(self, res_channels, skip_channels, num_res_layers, dilation_cycle, + noise_scale_embed_dim_in, + noise_scale_embed_dim_mid, + noise_scale_embed_dim_out, multiband): + super(Residual_group, self).__init__() + self.num_res_layers = num_res_layers + self.noise_scale_embed_dim_in = noise_scale_embed_dim_in + + # the shared two fc layers for noise scale embedding + self.fc_t1 = nn.Linear(noise_scale_embed_dim_in, noise_scale_embed_dim_mid) + self.fc_t2 = nn.Linear(noise_scale_embed_dim_mid, noise_scale_embed_dim_out) + + # stack all residual blocks with dilations 1, 2, ... , 512, ... 
, 1, 2, ..., 512 + self.residual_blocks = nn.ModuleList() + for n in range(self.num_res_layers): + self.residual_blocks.append(Residual_block(res_channels, skip_channels, + dilation=2 ** (n % dilation_cycle), + noise_scale_embed_dim_out=noise_scale_embed_dim_out, multiband=multiband)) + + def forward(self, input_data): + x, mel_spectrogram, noise_scales = input_data + + # embed noise scale + noise_scale_embed = calc_noise_scale_embedding(noise_scales, self.noise_scale_embed_dim_in) + noise_scale_embed = swish(self.fc_t1(noise_scale_embed)) + noise_scale_embed = swish(self.fc_t2(noise_scale_embed)) + + # pass all residual layers + h = x + skip = 0 + for n in range(self.num_res_layers): + h, skip_n = self.residual_blocks[n]((h, mel_spectrogram, noise_scale_embed)) # use the output from last residual layer + skip += skip_n # accumulate all skip outputs + + return skip * math.sqrt(1.0 / self.num_res_layers) # normalize for training stability + + +class WaveNet_vocoder(nn.Module): + def __init__(self, in_channels, res_channels, skip_channels, out_channels, + num_res_layers, dilation_cycle, + noise_scale_embed_dim_in, + noise_scale_embed_dim_mid, + noise_scale_embed_dim_out, multiband): + super(WaveNet_vocoder, self).__init__() + + # initial conv1x1 with relu + self.init_conv = nn.Sequential(Conv(in_channels, res_channels, kernel_size=1), nn.ReLU()) + + # all residual layers + self.residual_layer = Residual_group(res_channels=res_channels, + skip_channels=skip_channels, + num_res_layers=num_res_layers, + dilation_cycle=dilation_cycle, + noise_scale_embed_dim_in=noise_scale_embed_dim_in, + noise_scale_embed_dim_mid=noise_scale_embed_dim_mid, + noise_scale_embed_dim_out=noise_scale_embed_dim_out, multiband=multiband) + + # final conv1x1 -> relu -> zeroconv1x1 + self.final_conv = nn.Sequential(Conv(skip_channels, skip_channels, kernel_size=1), + nn.ReLU(), + ZeroConv1d(skip_channels, out_channels)) + + def forward(self, input_data): + audio, mel_spectrogram, noise_scales = input_data # b x band x T, b x 80 x T', b x 1 + x = audio + x = self.init_conv(x) + x = self.residual_layer((x, mel_spectrogram, noise_scales)) + x = self.final_conv(x) + + return x + diff --git a/modules/FastDiff/module/modules.py b/modules/FastDiff/module/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..29b0f42123b10a0518093c23592277b9622b5266 --- /dev/null +++ b/modules/FastDiff/module/modules.py @@ -0,0 +1,343 @@ +import math +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +from torch.nn import Conv1d + +LRELU_SLOPE = 0.1 + + + +def get_sinusoid_encoding_table(n_position, d_hid, padding_idx=None): + ''' Sinusoid position encoding table ''' + + def cal_angle(position, hid_idx): + return position / np.power(10000, 2 * (hid_idx // 2) / d_hid) + + def get_posi_angle_vec(position): + return [cal_angle(position, hid_j) for hid_j in range(d_hid)] + + sinusoid_table = np.array([get_posi_angle_vec(pos_i) + for pos_i in range(n_position)]) + + sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i + sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1 + + if padding_idx is not None: + # zero vector for padding dimension + sinusoid_table[padding_idx] = 0. + + return torch.FloatTensor(sinusoid_table) + + +def overlap_and_add(signal, frame_step): + """Reconstructs a signal from a framed representation. 
+
+    Adds potentially overlapping frames of a signal with shape
+    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
+    The resulting tensor has shape `[..., output_size]` where
+
+        output_size = (frames - 1) * frame_step + frame_length
+
+    Args:
+        signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
+        frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
+
+    Returns:
+        A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
+        output_size = (frames - 1) * frame_step + frame_length
+
+    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
+    """
+    outer_dimensions = signal.size()[:-2]
+    frames, frame_length = signal.size()[-2:]
+
+    # gcd = Greatest Common Divisor
+    subframe_length = math.gcd(frame_length, frame_step)
+    subframe_step = frame_step // subframe_length
+    subframes_per_frame = frame_length // subframe_length
+    output_size = frame_step * (frames - 1) + frame_length
+    output_subframes = output_size // subframe_length
+
+    subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
+
+    frame = torch.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)
+    frame = signal.new_tensor(frame).long()  # signal may be on GPU or CPU
+    frame = frame.contiguous().view(-1)
+
+    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
+    device_of_result = result.device
+    result.index_add_(-2, frame.to(device_of_result), subframe_signal)
+    result = result.view(*outer_dimensions, -1)
+    return result
+
+
+class LastLayer(nn.Module):
+    def __init__(self, in_channels, out_channels,
+                 nonlinear_activation, nonlinear_activation_params,
+                 pad, kernel_size, pad_params, bias):
+        super(LastLayer, self).__init__()
+        self.activation = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)
+        self.pad = getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params)
+        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size, bias=bias)
+
+    def forward(self, x):
+        x = self.activation(x)
+        x = self.pad(x)
+        x = self.conv(x)
+        return x
+
+
+class WeightConv1d(Conv1d):
+    """Conv1d module with customized initialization."""
+
+    def __init__(self, *args, **kwargs):
+        """Initialize Conv1d module."""
+        # call Conv1d's own __init__; super(Conv1d, self) would skip it and
+        # break instantiation by invoking _ConvNd with the wrong arguments
+        super(WeightConv1d, self).__init__(*args, **kwargs)
+
+    def reset_parameters(self):
+        """Reset parameters."""
+        torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu")
+        if self.bias is not None:
+            torch.nn.init.constant_(self.bias, 0.0)
+
+
+class Conv1d1x1(Conv1d):
+    """1x1 Conv1d with customized initialization."""
+
+    def __init__(self, in_channels, out_channels, bias):
+        """Initialize 1x1 Conv1d module."""
+        super(Conv1d1x1, self).__init__(in_channels, out_channels,
+                                        kernel_size=1, padding=0,
+                                        dilation=1, bias=bias)
+
+class DiffusionDBlock(nn.Module):
+    def __init__(self, input_size, hidden_size, factor):
+        super().__init__()
+        self.factor = factor
+        self.residual_dense = Conv1d(input_size, hidden_size, 1)
+        self.conv = nn.ModuleList([
+            Conv1d(input_size, hidden_size, 3, dilation=1, padding=1),
+            Conv1d(hidden_size, hidden_size, 3, dilation=2, padding=2),
+            Conv1d(hidden_size, hidden_size, 3, dilation=4, padding=4),
+        ])
+
+    def forward(self, x):
+        size = x.shape[-1] // self.factor
+
+        residual = self.residual_dense(x)
+        residual = F.interpolate(residual, size=size)
+
+        x = F.interpolate(x, size=size)
+        for 
layer in self.conv: + x = F.leaky_relu(x, 0.2) + x = layer(x) + + return x + residual + + +class TimeAware_LVCBlock(torch.nn.Module): + ''' time-aware location-variable convolutions + ''' + def __init__(self, + in_channels, + cond_channels, + upsample_ratio, + conv_layers=4, + conv_kernel_size=3, + cond_hop_length=256, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + noise_scale_embed_dim_out=512 + ): + super().__init__() + + self.cond_hop_length = cond_hop_length + self.conv_layers = conv_layers + self.conv_kernel_size = conv_kernel_size + self.convs = torch.nn.ModuleList() + + self.upsample = torch.nn.ConvTranspose1d(in_channels, in_channels, + kernel_size=upsample_ratio*2, stride=upsample_ratio, + padding=upsample_ratio // 2 + upsample_ratio % 2, + output_padding=upsample_ratio % 2) + + + self.kernel_predictor = KernelPredictor( + cond_channels=cond_channels, + conv_in_channels=in_channels, + conv_out_channels=2 * in_channels, + conv_layers=conv_layers, + conv_kernel_size=conv_kernel_size, + kpnet_hidden_channels=kpnet_hidden_channels, + kpnet_conv_size=kpnet_conv_size, + kpnet_dropout=kpnet_dropout + ) + + # the layer-specific fc for noise scale embedding + self.fc_t = torch.nn.Linear(noise_scale_embed_dim_out, cond_channels) + + for i in range(conv_layers): + padding = (3 ** i) * int((conv_kernel_size - 1) / 2) + conv = torch.nn.Conv1d(in_channels, in_channels, kernel_size=conv_kernel_size, padding=padding, dilation=3 ** i) + + self.convs.append(conv) + + + def forward(self, data): + ''' forward propagation of the time-aware location-variable convolutions. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length) + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + + Returns: + Tensor: the output sequence (batch, in_channels, in_length) + ''' + x, audio_down, c, noise_embedding = data + batch, in_channels, in_length = x.shape + + noise = (self.fc_t(noise_embedding)).unsqueeze(-1) # (B, 80) + condition = c + noise # (B, 80, T) + kernels, bias = self.kernel_predictor(condition) + x = F.leaky_relu(x, 0.2) + x = self.upsample(x) + + for i in range(self.conv_layers): + x += audio_down + y = F.leaky_relu(x, 0.2) + y = self.convs[i](y) + y = F.leaky_relu(y, 0.2) + + k = kernels[:, i, :, :, :, :] + b = bias[:, i, :, :] + y = self.location_variable_convolution(y, k, b, 1, self.cond_hop_length) + x = x + torch.sigmoid(y[:, :in_channels, :]) * torch.tanh(y[:, in_channels:, :]) + return x + + def location_variable_convolution(self, x, kernel, bias, dilation, hop_size): + ''' perform location-variable convolution operation on the input sequence (x) using the local convolution kernl. + Time: 414 μs ± 309 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each), test on NVIDIA V100. + Args: + x (Tensor): the input sequence (batch, in_channels, in_length). + kernel (Tensor): the local convolution kernel (batch, in_channel, out_channels, kernel_size, kernel_length) + bias (Tensor): the bias for the local convolution (batch, out_channels, kernel_length) + dilation (int): the dilation of convolution. + hop_size (int): the hop_size of the conditioning sequence. + Returns: + (Tensor): the output sequence after performing local convolution. (batch, out_channels, in_length). 
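+
+        Note: conceptually, the input is unfolded into kernel_length windows of
+        hop_size samples, each window is convolved with its own predicted kernel
+        (the einsum below), and a per-frame bias is added before the windows are
+        reassembled into one sequence.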
+ ''' + batch, in_channels, in_length = x.shape + batch, in_channels, out_channels, kernel_size, kernel_length = kernel.shape + + + assert in_length == (kernel_length * hop_size), "length of (x, kernel) is not matched" + + padding = dilation * int((kernel_size - 1) / 2) + x = F.pad(x, (padding, padding), 'constant', 0) # (batch, in_channels, in_length + 2*padding) + x = x.unfold(2, hop_size + 2 * padding, hop_size) # (batch, in_channels, kernel_length, hop_size + 2*padding) + + if hop_size < dilation: + x = F.pad(x, (0, dilation), 'constant', 0) + x = x.unfold(3, dilation, + dilation) # (batch, in_channels, kernel_length, (hop_size + 2*padding)/dilation, dilation) + x = x[:, :, :, :, :hop_size] + x = x.transpose(3, 4) # (batch, in_channels, kernel_length, dilation, (hop_size + 2*padding)/dilation) + x = x.unfold(4, kernel_size, 1) # (batch, in_channels, kernel_length, dilation, _, kernel_size) + + o = torch.einsum('bildsk,biokl->bolsd', x, kernel) + o = o + bias.unsqueeze(-1).unsqueeze(-1) + o = o.contiguous().view(batch, out_channels, -1) + return o + + + +class KernelPredictor(torch.nn.Module): + ''' Kernel predictor for the time-aware location-variable convolutions + ''' + + def __init__(self, + cond_channels, + conv_in_channels, + conv_out_channels, + conv_layers, + conv_kernel_size=3, + kpnet_hidden_channels=64, + kpnet_conv_size=3, + kpnet_dropout=0.0, + kpnet_nonlinear_activation="LeakyReLU", + kpnet_nonlinear_activation_params={"negative_slope": 0.1} + ): + ''' + Args: + cond_channels (int): number of channel for the conditioning sequence, + conv_in_channels (int): number of channel for the input sequence, + conv_out_channels (int): number of channel for the output sequence, + conv_layers (int): + kpnet_ + ''' + super().__init__() + + self.conv_in_channels = conv_in_channels + self.conv_out_channels = conv_out_channels + self.conv_kernel_size = conv_kernel_size + self.conv_layers = conv_layers + + l_w = conv_in_channels * conv_out_channels * conv_kernel_size * conv_layers + l_b = conv_out_channels * conv_layers + + padding = (kpnet_conv_size - 1) // 2 + self.input_conv = torch.nn.Sequential( + torch.nn.Conv1d(cond_channels, kpnet_hidden_channels, 5, padding=(5 - 1) // 2, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.residual_conv = torch.nn.Sequential( + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Dropout(kpnet_dropout), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + torch.nn.Conv1d(kpnet_hidden_channels, kpnet_hidden_channels, kpnet_conv_size, padding=padding, 
bias=True), + getattr(torch.nn, kpnet_nonlinear_activation)(**kpnet_nonlinear_activation_params), + ) + + self.kernel_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_w, kpnet_conv_size, + padding=padding, bias=True) + self.bias_conv = torch.nn.Conv1d(kpnet_hidden_channels, l_b, kpnet_conv_size, padding=padding, + bias=True) + + def forward(self, c): + ''' + Args: + c (Tensor): the conditioning sequence (batch, cond_channels, cond_length) + Returns: + ''' + batch, cond_channels, cond_length = c.shape + + c = self.input_conv(c) + c = c + self.residual_conv(c) + k = self.kernel_conv(c) + b = self.bias_conv(c) + + kernels = k.contiguous().view(batch, + self.conv_layers, + self.conv_in_channels, + self.conv_out_channels, + self.conv_kernel_size, + cond_length) + bias = b.contiguous().view(batch, + self.conv_layers, + self.conv_out_channels, + cond_length) + return kernels, bias diff --git a/modules/FastDiff/module/util.py b/modules/FastDiff/module/util.py new file mode 100644 index 0000000000000000000000000000000000000000..3f3b5ff412c70ae6674596ed5e5903d347ad167b --- /dev/null +++ b/modules/FastDiff/module/util.py @@ -0,0 +1,429 @@ +import os +import numpy as np +import torch +import torch.nn as nn +import copy +from tqdm import tqdm +def flatten(v): + """ + Flatten a list of lists/tuples + """ + + return [x for y in v for x in y] + + +def rescale(x): + """ + Rescale a tensor to 0-1 + """ + + return (x - x.min()) / (x.max() - x.min()) + + +def find_max_epoch(path): + """ + Find maximum epoch/iteration in path, formatted ${n_iter}.pkl + E.g. 100000.pkl + + Parameters: + path (str): checkpoint path + + Returns: + maximum iteration, -1 if there is no (valid) checkpoint + """ + + files = os.listdir(path) + epoch = -1 + for f in files: + if len(f) <= 4: + continue + if f[-4:] == '.pkl': + try: + epoch = max(epoch, int(f[:-4])) + except: + continue + #print(path, epoch, flush=True) + return epoch + + +def print_size(net): + """ + Print the number of parameters of a network + """ + + if net is not None and isinstance(net, torch.nn.Module): + module_parameters = filter(lambda p: p.requires_grad, net.parameters()) + params = sum([np.prod(p.size()) for p in module_parameters]) + print("{} Parameters: {:.6f}M".format( + net.__class__.__name__, params / 1e6), flush=True) + + +# Utilities for diffusion models + +def std_normal(size): + """ + Generate the standard Gaussian variable of a certain size + """ + + return torch.normal(0, 1, size=size) + + +def calc_noise_scale_embedding(noise_scales, noise_scale_embed_dim_in): + """ + Embed a noise scale $t$ into a higher dimensional space + E.g. the embedding vector in the 128-dimensional space is + [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... 
, cos(t * 10^(63*4/63))] + + Parameters: + noise_scales (torch.long tensor, shape=(batchsize, 1)): + noise scales for batch data + noise_scale_embed_dim_in (int, default=128): + dimensionality of the embedding space for discrete noise scales + + Returns: + the embedding vectors (torch.tensor, shape=(batchsize, noise_scale_embed_dim_in)): + """ + + assert noise_scale_embed_dim_in % 2 == 0 + + half_dim = noise_scale_embed_dim_in // 2 + _embed = np.log(10000) / (half_dim - 1) + _embed = torch.exp(torch.arange(half_dim) * -_embed) + _embed = noise_scales * _embed + noise_scale_embed = torch.cat((torch.sin(_embed), + torch.cos(_embed)), 1) + + return noise_scale_embed + + +def calc_diffusion_hyperparams_given_beta(beta): + """ + Compute diffusion process hyperparameters + + Parameters: + beta (tensor): beta schedule + + Returns: + a dictionary of diffusion hyperparameters including: + T (int), beta/alpha/sigma (torch.tensor on cpu, shape=(T, )) + These cpu tensors are changed to cuda tensors on each individual gpu + """ + + T = len(beta) + alpha = 1 - beta + sigma = beta + 0 + for t in range(1, T): + alpha[t] *= alpha[t-1] # \alpha^2_t = \prod_{s=1}^t (1-\beta_s) + sigma[t] *= (1-alpha[t-1]) / (1-alpha[t]) # \sigma^2_t = \beta_t * (1-\alpha_{t-1}) / (1-\alpha_t) + alpha = torch.sqrt(alpha) + sigma = torch.sqrt(sigma) + + _dh = {} + _dh["T"], _dh["beta"], _dh["alpha"], _dh["sigma"] = T, beta, alpha, sigma + diffusion_hyperparams = _dh + return diffusion_hyperparams + + +def calc_diffusion_hyperparams(T, beta_0, beta_T, tau, N, beta_N, alpha_N, rho): + """ + Compute diffusion process hyperparameters + + Parameters: + T (int): number of noise scales + beta_0 and beta_T (float): beta schedule start/end value, + where any beta_t in the middle is linearly interpolated + + Returns: + a dictionary of diffusion hyperparameters including: + T (int), beta/alpha/sigma (torch.tensor on cpu, shape=(T, )) + These cpu tensors are changed to cuda tensors on each individual gpu + """ + + beta = torch.linspace(beta_0, beta_T, T) + alpha = 1 - beta + sigma = beta + 0 + for t in range(1, T): + alpha[t] *= alpha[t-1] # \alpha^2_t = \prod_{s=1}^t (1-\beta_s) + sigma[t] *= (1-alpha[t-1]) / (1-alpha[t]) # \sigma^2_t = \beta_t * (1-\alpha_{t-1}) / (1-\alpha_t) + alpha = torch.sqrt(alpha) + sigma = torch.sqrt(sigma) + + _dh = {} + _dh["T"], _dh["beta"], _dh["alpha"], _dh["sigma"] = T, beta, alpha, sigma + _dh["tau"], _dh["N"], _dh["betaN"], _dh["alphaN"], _dh["rho"] = tau, N, beta_N, alpha_N, rho + diffusion_hyperparams = _dh + return diffusion_hyperparams + + +def sampling_given_noise_schedule( + net, + size, + diffusion_hyperparams, + inference_noise_schedule, + condition=None, + ddim=False, + return_sequence=False): + """ + Perform the complete sampling step according to p(x_0|x_T) = \prod_{t=1}^T p_{\theta}(x_{t-1}|x_t) + + Parameters: + net (torch network): the wavenet models + size (tuple): size of tensor to be generated, + usually is (number of audios to generate, channels=1, length of audio) + diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams + note, the tensors need to be cuda tensors + condition (torch.tensor): ground truth mel spectrogram read from disk + None if used for unconditional generation + + Returns: + the generated audio(s) in torch.tensor, shape=size + """ + + _dh = diffusion_hyperparams + T, alpha = _dh["T"], _dh["alpha"] + assert len(alpha) == T + assert len(size) == 3 + + N = len(inference_noise_schedule) + beta_infer = 
inference_noise_schedule + alpha_infer = 1 - beta_infer + sigma_infer = beta_infer + 0 + for n in range(1, N): + alpha_infer[n] *= alpha_infer[n - 1] + sigma_infer[n] *= (1 - alpha_infer[n - 1]) / (1 - alpha_infer[n]) + alpha_infer = torch.sqrt(alpha_infer) + sigma_infer = torch.sqrt(sigma_infer) + + # Mapping noise scales to time steps + steps_infer = [] + for n in range(N): + step = map_noise_scale_to_time_step(alpha_infer[n], alpha) + if step >= 0: + steps_infer.append(step) + steps_infer = torch.FloatTensor(steps_infer) + + # N may change since alpha_infer can be out of the range of alpha + N = len(steps_infer) + + x = std_normal(size) + if return_sequence: + x_ = copy.deepcopy(x) + xs = [x_] + with torch.no_grad(): + for n in tqdm(range(N - 1, -1, -1), desc='FastDiff sample time step', total=N): + diffusion_steps = (steps_infer[n] * torch.ones((size[0], 1))) + epsilon_theta = net((x, condition, diffusion_steps,)) + if ddim: + alpha_next = alpha_infer[n] / (1 - beta_infer[n]).sqrt() + c1 = alpha_next / alpha_infer[n] + c2 = -(1 - alpha_infer[n] ** 2.).sqrt() * c1 + c3 = (1 - alpha_next ** 2.).sqrt() + x = c1 * x + c2 * epsilon_theta + c3 * epsilon_theta # std_normal(size) + else: + x -= beta_infer[n] / torch.sqrt(1 - alpha_infer[n] ** 2.) * epsilon_theta + x /= torch.sqrt(1 - beta_infer[n]) + if n > 0: + x = x + sigma_infer[n] * std_normal(size) + if return_sequence: + x_ = copy.deepcopy(x) + xs.append(x_) + if return_sequence: + return xs + return x + +def noise_scheduling(net, size, diffusion_hyperparams, condition=None, ddim=False): + """ + Perform the complete sampling step according to p(x_0|x_T) = \prod_{t=1}^T p_{\theta}(x_{t-1}|x_t) + + Parameters: + net (torch network): the wavenet models + size (tuple): size of tensor to be generated, + usually is (number of audios to generate, channels=1, length of audio) + diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams + note, the tensors need to be cuda tensors + condition (torch.tensor): ground truth mel spectrogram read from disk + None if used for unconditional generation + + Returns: + noise schedule: a list of noise scales in torch.tensor, length <= N + """ + + _dh = diffusion_hyperparams + N, betaN, alphaN, rho, alpha = _dh["N"], _dh["betaN"], _dh["alphaN"], _dh["rho"], _dh["alpha"] + + print('begin noise scheduling, maximum number of reverse steps = %d' % (N)) + + betas = [] + x = std_normal(size) + with torch.no_grad(): + beta_cur = torch.ones(1, 1, 1).cuda() * betaN + alpha_cur = torch.ones(1, 1, 1).cuda() * alphaN + for n in range(N - 1, -1, -1): + # print(n, beta_cur.squeeze().item(), alpha_cur.squeeze().item()) + step = map_noise_scale_to_time_step(alpha_cur.squeeze().item(), alpha) + if step >= 0: + betas.append(beta_cur.squeeze().item()) + diffusion_steps = (step * torch.ones((size[0], 1))).cuda() + epsilon_theta = net((x, condition, diffusion_steps,)) + if ddim: + alpha_nxt = alpha_cur / (1 - beta_cur).sqrt() + c1 = alpha_nxt / alpha_cur + c2 = -(1 - alpha_cur ** 2.).sqrt() * c1 + c3 = (1 - alpha_nxt ** 2.).sqrt() + x = c1 * x + c2 * epsilon_theta + c3 * epsilon_theta # std_normal(size) + else: + x -= beta_cur / torch.sqrt(1 - alpha_cur ** 2.) 
* epsilon_theta + x /= torch.sqrt(1 - beta_cur) + alpha_nxt, beta_nxt = alpha_cur, beta_cur + alpha_cur = alpha_nxt / (1 - beta_nxt).sqrt() + if alpha_cur > 1: + break + beta_cur = net.noise_pred( + x.squeeze(1), (beta_nxt.view(-1, 1), (1 - alpha_cur ** 2.).view(-1, 1))) + if beta_cur.squeeze().item() < rho: + break + return torch.FloatTensor(betas[::-1]).cuda() + + +def theta_timestep_loss(net, X, diffusion_hyperparams, reverse=False): + """ + Compute the training loss for learning theta + + Parameters: + net (torch network): the wavenet models + X (tuple, shape=(2,)): training data in tuple form (mel_spectrograms, audios) + mel_spectrograms: torch.tensor, shape is batchsize followed by each mel_spectrogram shape + audios: torch.tensor, shape=(batchsize, 1, length of audio) + diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams + note, the tensors need to be cuda tensors + + Returns: + theta loss + """ + assert type(X) == tuple and len(X) == 2 + loss_fn = nn.MSELoss() + + _dh = diffusion_hyperparams + T, alpha = _dh["T"], _dh["alpha"] + + mel_spectrogram, audio = X + B, C, L = audio.shape # B is batchsize, C=1, L is audio length + ts = torch.randint(T, size=(B, 1, 1)).cuda() # randomly sample steps from 1~T + z = std_normal(audio.shape) + delta = (1 - alpha[ts] ** 2.).sqrt() + alpha_cur = alpha[ts] + noisy_audio = alpha_cur * audio + delta * z # compute x_t from q(x_t|x_0) + epsilon_theta = net((noisy_audio, mel_spectrogram, ts.view(B, 1),)) + + if reverse: + x0 = (noisy_audio - delta * epsilon_theta) / alpha_cur + return loss_fn(epsilon_theta, z), x0 + + return loss_fn(epsilon_theta, z) + + +def phi_loss(net, X, diffusion_hyperparams): + """ + Compute the training loss for learning phi + Parameters: + net (torch network): the wavenet models + X (tuple, shape=(2,)): training data in tuple form (mel_spectrograms, audios) + mel_spectrograms: torch.tensor, shape is batchsize followed by each mel_spectrogram shape + audios: torch.tensor, shape=(batchsize, 1, length of audio) + diffusion_hyperparams (dict): dictionary of diffusion hyperparameters returned by calc_diffusion_hyperparams + note, the tensors need to be cuda tensors + + Returns: + phi loss + """ + assert type(X) == tuple and len(X) == 2 + _dh = diffusion_hyperparams + T, alpha, tau = _dh["T"], _dh["alpha"], _dh["tau"] + + mel_spectrogram, audio = X + B, C, L = audio.shape # B is batchsize, C=1, L is audio length + ts = torch.randint(tau, T - tau, size=(B,)).cuda() # randomly sample steps from 1~T + alpha_cur = alpha.index_select(0, ts).view(B, 1, 1) + alpha_nxt = alpha.index_select(0, ts + tau).view(B, 1, 1) + beta_nxt = 1 - (alpha_nxt / alpha_cur) ** 2. + delta = (1 - alpha_cur ** 2.).sqrt() + z = std_normal(audio.shape) + noisy_audio = alpha_cur * audio + delta * z # compute x_t from q(x_t|x_0) + epsilon_theta = net((noisy_audio, mel_spectrogram, ts.view(B, 1),)) + beta_est = net.noise_pred(noisy_audio.squeeze(1), (beta_nxt.view(B, 1), delta.view(B, 1) ** 2.)) + phi_loss = 1 / (2. * (delta ** 2. - beta_est)) * ( + delta * z - beta_est / delta * epsilon_theta) ** 2. + phi_loss += torch.log(1e-8 + delta ** 2. / (beta_est + 1e-8)) / 4. 
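+    # Together with the mean/ratio term below, this resembles the KL divergence between
+    # two Gaussians: the true transition noise level delta**2 and the predicted beta_est.
+    # This is the objective that trains the noise predictor (net.noise_pred) via beta_est.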
+ phi_loss = (torch.mean(phi_loss, -1, keepdim=True) + beta_est / delta ** 2 / 2.).mean() + + return phi_loss + + +def compute_hyperparams_given_schedule(beta): + """ + Compute diffusion process hyperparameters + + Parameters: + beta (tensor): beta schedule + + Returns: + a dictionary of diffusion hyperparameters including: + T (int), beta/alpha/sigma (torch.tensor on cpu, shape=(T, )) + These cpu tensors are changed to cuda tensors on each individual gpu + """ + + T = len(beta) + alpha = 1 - beta + sigma = beta + 0 + for t in range(1, T): + alpha[t] *= alpha[t - 1] # \alpha^2_t = \prod_{s=1}^t (1-\beta_s) + sigma[t] *= (1 - alpha[t - 1]) / (1 - alpha[t]) # \sigma^2_t = \beta_t * (1-\alpha_{t-1}) / (1-\alpha_t) + alpha = torch.sqrt(alpha) + sigma = torch.sqrt(sigma) + + _dh = {} + _dh["T"], _dh["beta"], _dh["alpha"], _dh["sigma"] = T, beta, alpha, sigma + diffusion_hyperparams = _dh + return diffusion_hyperparams + + + +def map_noise_scale_to_time_step(alpha_infer, alpha): + if alpha_infer < alpha[-1]: + return len(alpha) - 1 + if alpha_infer > alpha[0]: + return 0 + for t in range(len(alpha) - 1): + if alpha[t+1] <= alpha_infer <= alpha[t]: + step_diff = alpha[t] - alpha_infer + step_diff /= alpha[t] - alpha[t+1] + return t + step_diff.item() + return -1 + + +def calc_diffusion_step_embedding(diffusion_steps, diffusion_step_embed_dim_in): + """ + Embed a diffusion step $t$ into a higher dimensional space + E.g. the embedding vector in the 128-dimensional space is + [sin(t * 10^(0*4/63)), ... , sin(t * 10^(63*4/63)), cos(t * 10^(0*4/63)), ... , cos(t * 10^(63*4/63))] + + Parameters: + diffusion_steps (torch.long tensor, shape=(batchsize, 1)): + diffusion steps for batch data + diffusion_step_embed_dim_in (int, default=128): + dimensionality of the embedding space for discrete diffusion steps + + Returns: + the embedding vectors (torch.tensor, shape=(batchsize, diffusion_step_embed_dim_in)): + """ + + assert diffusion_step_embed_dim_in % 2 == 0 + + half_dim = diffusion_step_embed_dim_in // 2 + _embed = np.log(10000) / (half_dim - 1) + _embed = torch.exp(torch.arange(half_dim) * -_embed) + _embed = diffusion_steps * _embed + diffusion_step_embed = torch.cat((torch.sin(_embed), + torch.cos(_embed)), 1) + + return diffusion_step_embed \ No newline at end of file diff --git a/modules/FastDiff/task/FastDiff.py b/modules/FastDiff/task/FastDiff.py new file mode 100644 index 0000000000000000000000000000000000000000..c8902b4309ff45b4c1b88707e45c43238f52b795 --- /dev/null +++ b/modules/FastDiff/task/FastDiff.py @@ -0,0 +1,133 @@ +import os + +import torch +import utils +from modules.FastDiff.module.FastDiff_model import FastDiff +from tasks.vocoder.vocoder_base import VocoderBaseTask +from utils import audio +from utils.hparams import hparams +from modules.FastDiff.module.util import theta_timestep_loss, compute_hyperparams_given_schedule, sampling_given_noise_schedule + + +class FastDiffTask(VocoderBaseTask): + def __init__(self): + super(FastDiffTask, self).__init__() + + def build_model(self): + self.model = FastDiff(audio_channels=hparams['audio_channels'], + inner_channels=hparams['inner_channels'], + cond_channels=hparams['cond_channels'], + upsample_ratios=hparams['upsample_ratios'], + lvc_layers_each_block=hparams['lvc_layers_each_block'], + lvc_kernel_size=hparams['lvc_kernel_size'], + kpnet_hidden_channels=hparams['kpnet_hidden_channels'], + kpnet_conv_size=hparams['kpnet_conv_size'], + dropout=hparams['dropout'], + diffusion_step_embed_dim_in=hparams['diffusion_step_embed_dim_in'], + 
diffusion_step_embed_dim_mid=hparams['diffusion_step_embed_dim_mid'],
+                              diffusion_step_embed_dim_out=hparams['diffusion_step_embed_dim_out'],
+                              use_weight_norm=hparams['use_weight_norm'])
+        utils.print_arch(self.model)
+
+        # Init hyperparameters by linear schedule
+        noise_schedule = torch.linspace(float(hparams["beta_0"]), float(hparams["beta_T"]), int(hparams["T"])).cuda()
+        diffusion_hyperparams = compute_hyperparams_given_schedule(noise_schedule)
+
+        # map diffusion hyperparameters to gpu
+        for key in diffusion_hyperparams:
+            if key in ["beta", "alpha", "sigma"]:
+                diffusion_hyperparams[key] = diffusion_hyperparams[key].cuda()
+        self.diffusion_hyperparams = diffusion_hyperparams
+
+        return self.model
+
+    def _training_step(self, sample, batch_idx, optimizer_idx):
+        mels = sample['mels']
+        y = sample['wavs']
+        X = (mels, y)
+        loss = theta_timestep_loss(self.model, X, self.diffusion_hyperparams)
+        return loss, {'loss': loss}
+
+    def validation_step(self, sample, batch_idx):
+        mels = sample['mels']
+        y = sample['wavs']
+        X = (mels, y)
+        loss = theta_timestep_loss(self.model, X, self.diffusion_hyperparams)
+        return loss, {'loss': loss}
+
+    def test_step(self, sample, batch_idx):
+        mels = sample['mels']
+        y = sample['wavs']
+        loss_output = {}
+
+        if hparams['noise_schedule'] != '':
+            noise_schedule = hparams['noise_schedule']
+            if isinstance(noise_schedule, list):
+                noise_schedule = torch.FloatTensor(noise_schedule).cuda()
+        else:
+            # Select schedule
+            try:
+                reverse_step = int(hparams.get('N'))
+            except:
+                print('Please specify $N (the number of reverse iterations) in the config file. Denoising with 4 iterations for now.')
+                reverse_step = 4
+            if reverse_step == 1000:
+                noise_schedule = torch.linspace(0.000001, 0.01, 1000).cuda()
+            elif reverse_step == 200:
+                noise_schedule = torch.linspace(0.0001, 0.02, 200).cuda()
+
+            # Below are schedules derived by the noise predictor.
+            # We will release the code for noise predictor training and noise scheduling soon. Please stay tuned!
+ elif reverse_step == 8: + noise_schedule = [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513, + 0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593, 0.5] + elif reverse_step == 6: + noise_schedule = [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984, + 0.006634317338466644, 0.09357017278671265, 0.6000000238418579] + elif reverse_step == 4: + noise_schedule = [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01] + elif reverse_step == 3: + noise_schedule = [9.0000e-05, 9.0000e-03, 6.0000e-01] + else: + raise NotImplementedError + + if isinstance(noise_schedule, list): + noise_schedule = torch.FloatTensor(noise_schedule).cuda() + + audio_length = mels.shape[-1] * hparams["hop_size"] + # generate using DDPM reverse process + + y_ = sampling_given_noise_schedule( + self.model, (1, 1, audio_length), self.diffusion_hyperparams, noise_schedule, + condition=mels, ddim=False, return_sequence=False) + gen_dir = os.path.join(hparams['work_dir'], f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') + os.makedirs(gen_dir, exist_ok=True) + + if len(y) == 0: + # Inference from mel + for idx, (wav_pred, item_name) in enumerate(zip(y_, sample["item_name"])): + wav_pred = wav_pred / wav_pred.abs().max() + audio.save_wav(wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav', + hparams['audio_sample_rate']) + else: + for idx, (wav_pred, wav_gt, item_name) in enumerate(zip(y_, y, sample["item_name"])): + wav_gt = wav_gt / wav_gt.abs().max() + wav_pred = wav_pred / wav_pred.abs().max() + audio.save_wav(wav_gt.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_gt.wav', hparams['audio_sample_rate']) + audio.save_wav(wav_pred.view(-1).cpu().float().numpy(), f'{gen_dir}/{item_name}_pred.wav', hparams['audio_sample_rate']) + return loss_output + + def build_optimizer(self, model): + self.optimizer = optimizer = torch.optim.AdamW( + self.model.parameters(), + lr=float(hparams['lr']), weight_decay=float(hparams['weight_decay'])) + return optimizer + + def compute_rtf(self, sample, generation_time, sample_rate=22050): + """ + Computes RTF for a given sample. + """ + total_length = sample.shape[-1] + return float(generation_time * sample_rate / total_length) \ No newline at end of file diff --git a/modules/ProDiff/config/base.yaml b/modules/ProDiff/config/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..696546e3302a55ee97f633136ead9c44d3e45a12 --- /dev/null +++ b/modules/ProDiff/config/base.yaml @@ -0,0 +1,67 @@ +base_config: + - egs/egs_bases/tts/fs2.yaml + +# diffusion model +diff_decoder_type: 'wavenet' +dilation_cycle_length: 1 +residual_layers: 20 +residual_channels: 256 +keep_bins: 80 +spec_min: [ ] +spec_max: [ ] +diff_loss_type: l1 +timesteps: 100 +max_beta: 0.06 + +# train +max_sentences: 48 +max_updates: 200000 + + +# FastDiff vocoder +vocoder: FastDiff +N: 4 # denoising steps +vocoder_ckpt: checkpoints/FastDiff + +# eval +use_gt_dur: true +use_gt_f0: true +gen_tgt_spk_id: -1 +num_sanity_val_steps: -1 +num_valid_plots: 10 +use_cond_disc: true +save_gt: true +num_test_samples: 20 +max_valid_sentences: 1 +text: the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing. 
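+# NOTE: `text` above appears to be the default sentence synthesized during evaluation;
+# replace it to demo other inputs.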
+ + +# variation +pitch_type: frame +pitch_extractor: 'parselmouth' +use_pitch_embed: true +use_energy_embed: true +mel_loss: "ssim:0.5|l1:0.5" + + +# dataset +preprocess_cls: egs.datasets.audio.lj.pre_align.LJPreAlign +preprocess_args: + nsample_per_mfa_group: 1000 + # text process + txt_processor: en + use_mfa: true + with_phsep: true + reset_phone_dict: true + reset_word_dict: true + add_eos_bos: true + # mfa + mfa_group_shuffle: false + mfa_offset: 0.02 + # wav processors + wav_processors: [ ] + save_sil_mask: true + vad_max_silence_length: 12 + + + diff --git a/modules/ProDiff/config/prodiff.yaml b/modules/ProDiff/config/prodiff.yaml new file mode 100644 index 0000000000000000000000000000000000000000..64ee8eb61c97f2453726750e58cba71f08fef532 --- /dev/null +++ b/modules/ProDiff/config/prodiff.yaml @@ -0,0 +1,16 @@ +base_config: + - ./base.yaml + +raw_data_dir: 'data/raw/LJSpeech' +processed_data_dir: 'data/processed/LJSpeech' +binary_data_dir: 'data/binary/LJSpeech' + +task_cls: modules.ProDiff.task.ProDiff_task.ProDiff_Task + + +# diffusion +timesteps: 4 +teacher_ckpt: checkpoints/ProDiff_Teacher/model_ckpt_steps_188000.ckpt +diff_decoder_type: 'wavenet' +schedule_type: 'vpsde' + diff --git a/modules/ProDiff/config/prodiff_teacher.yaml b/modules/ProDiff/config/prodiff_teacher.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1e2e38cfc4a2a8895783664203098a94878688bf --- /dev/null +++ b/modules/ProDiff/config/prodiff_teacher.yaml @@ -0,0 +1,13 @@ +base_config: + - ./base.yaml + +raw_data_dir: 'data/raw/LJSpeech' +processed_data_dir: 'data/processed/LJSpeech' +binary_data_dir: 'data/binary/LJSpeech' + +task_cls: modules.ProDiff.task.ProDiff_teacher_task.ProDiff_teacher_Task + +# diffusion +timesteps: 4 +timescale: 1 +schedule_type: 'vpsde' diff --git a/modules/ProDiff/model/ProDiff.py b/modules/ProDiff/model/ProDiff.py new file mode 100644 index 0000000000000000000000000000000000000000..4caba78873c81691e92014c099bf9a36d0ba076b --- /dev/null +++ b/modules/ProDiff/model/ProDiff.py @@ -0,0 +1,210 @@ +import math +import random +from functools import partial +from usr.diff.shallow_diffusion_tts import * +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from tqdm import tqdm +from einops import rearrange + +from modules.fastspeech.fs2 import FastSpeech2 +from utils.hparams import hparams + + +class GaussianDiffusion(nn.Module): + def __init__(self, phone_encoder, out_dims, denoise_fn, teacher_steps=4, + timesteps=4, time_scale=1, loss_type='l1', betas=None, spec_min=None, spec_max=None): + super().__init__() + self.denoise_fn = denoise_fn + self.fs2 = FastSpeech2(phone_encoder, out_dims) + self.fs2.decoder = None + self.mel_bins = out_dims + + if exists(betas): + betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas + else: + betas = get_noise_schedule_list( + schedule_mode=hparams['schedule_type'], + timesteps=teacher_steps + 1, + min_beta=0.1, + max_beta=40, + s=0.008, + ) + + alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + self.time_scale = time_scale + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) # beta + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) # alphacum_t + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) # alphacum_{t-1} + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']]) + self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']]) + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. 
- self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def q_posterior_sample(self, x_start, x_t, t, repeat_noise=False): + b, *_, device = *x_start.shape, x_start.device + model_mean, _, model_log_variance = self.q_posterior(x_start=x_start, x_t=x_t, t=t) + noise = noise_like(x_start.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_start.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample(self, x_t, t, cond, spk_emb=None, clip_denoised=True, repeat_noise=False): + b, *_, device = *x_t.shape, x_t.device + x_0_pred = self.denoise_fn(x_t, t, cond) + + return self.q_posterior_sample(x_start=x_0_pred, x_t=x_t, t=t) + + def sample_q(self, x_0, ts, epsilon=None): + """ + Sample from q(x_t | x_0) for a batch of x_0. + """ + alpha, sigma = self.get_schedule(x_0, ts) + return alpha * x_0 + sigma * epsilon + + @torch.no_grad() + def p_sample_ddim(self, x_t, t, cond): + b, *_, device = *x_t.shape, x_t.device + x_0_pred = self.denoise_fn(x_t, t, cond) + alpha, sigma = self.get_schedule(x_t, t) + eps = (x_t - x_0_pred * alpha) / sigma + return self.sample_q(x_0_pred, t-self.time_scale, eps) + + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def get_schedule(self, x_t, t): + return extract(self.sqrt_alphas_cumprod, t, x_t.shape), extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) + + def diffuse_fn(self, x_start, t, noise=None): + x_start = self.norm_spec(x_start) + x_start = x_start.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] + zero_idx = t < 0 # for items where t is -1 + t[zero_idx] = 0 + noise = default(noise, lambda: torch.randn_like(x_start)) + out = self.q_sample(x_start=x_start, t=t, noise=noise) + out[zero_idx] = x_start[zero_idx] # set x_{-1} as the gt mel + return out + + def forward(self, txt_tokens, teacher_fn=None, mel2ph=None, spk_embed=None, + ref_mels=None, f0=None, uv=None, energy=None, infer=False): + b, *_, device = *txt_tokens.shape, txt_tokens.device + ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy, + skip_decoder=True, infer=infer) + cond = ret['decoder_inp'].transpose(1, 2) + if not infer: + with torch.no_grad(): + t = self.time_scale * torch.randint(1, self.num_timesteps+1, (b,), device=device).long() # [2, 4] + nonpadding = (mel2ph != 0).float().unsqueeze(1).unsqueeze(1) + noise = default(None, lambda: torch.randn_like(ref_mels.transpose(1, 2)[:, None, :, :])) + + # Diffusion + x_t = self.diffuse_fn(ref_mels, t, noise) * nonpadding + + # 2 steps of DDIM + x0_pred = 
teacher_fn.denoise_fn(x_t, t, cond) * nonpadding  # p(x_0|x_t,t) correct
+                alpha, sigma = self.get_schedule(x_t, t)
+                alpha_pre, sigma_pre = self.get_schedule(x_t, t - self.time_scale // 2)
+                alpha_pre_pre, sigma_pre_pre = self.get_schedule(x_t, t - self.time_scale)
+                x_t_pre = alpha_pre * x0_pred + sigma_pre / sigma * (x_t - alpha * x0_pred)  # correct
+                x0_pred1 = teacher_fn.denoise_fn(x_t_pre, t - self.time_scale // 2, cond) * nonpadding  # correct
+                x_t_pre_pre = alpha_pre_pre * x0_pred1 + sigma_pre_pre / sigma_pre * (
+                        x_t_pre - alpha_pre * x0_pred1)  # correct
+                x_target = (x_t_pre_pre - (sigma_pre_pre / sigma) * x_t) / (alpha_pre_pre - sigma_pre_pre / sigma * alpha) * nonpadding
+
+            x_pred = self.denoise_fn(x_t, t - self.time_scale, cond) * nonpadding  # student [0, 1]: 8 steps correct
+            x_t_prev = self.diffuse_fn(ref_mels, t - self.time_scale - 1, noise) * nonpadding  # teacher [-1, 1]
+            x_t_prev_pred = self.q_posterior_sample(x_pred, x_t, t - self.time_scale) * nonpadding  # [-1, 1] p(x_t-1|x_t,x_0,t)
+
+            if self.loss_type == 'l1':
+                if nonpadding is not None:  # [B, T]
+                    loss = ((x_pred - x_target).abs() * nonpadding).mean()  # [B, B, M, T].mean()
+                else:
+                    # print('are you sure w/o nonpadding?')
+                    loss = (x_pred - x_target).abs().mean()
+
+            elif self.loss_type == 'l2':
+                loss = F.mse_loss(x_pred, x_target)
+            else:
+                raise NotImplementedError()
+
+            ret['mel_out'] = loss  # [B, T, 80]
+            ret['x_t'] = x_t[:, 0].transpose(1, 2)
+            ret['x_t_prev'] = x_t_prev[:, 0].transpose(1, 2)
+            ret['x_t_prev_pred'] = x_t_prev_pred[:, 0].transpose(1, 2)
+            ret['t'] = t
+        else:
+            shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+            x = torch.randn(shape, device=device)  # noise
+            sample_steps = [self.time_scale * i for i in range(0, self.num_timesteps)]
+            for i in tqdm(reversed(sample_steps), desc='ProDiff sample time step', total=len(sample_steps)):
+                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)  # x(mel), t, condition(phoneme)
+            x = x[:, 0].transpose(1, 2)
+            # p_sample: 0.1805
+            ret['mel_out'] = self.denorm_spec(x)  # undo normalization
+        return ret
+
+    def norm_spec(self, x):
+        return x
+
+    def denorm_spec(self, x):
+        return x
+
+    def out2mel(self, x):
+        return x
\ No newline at end of file
diff --git a/modules/ProDiff/model/ProDiff_teacher.py b/modules/ProDiff/model/ProDiff_teacher.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa93b0976c672eef9771999962e8204e7668e9db
--- /dev/null
+++ b/modules/ProDiff/model/ProDiff_teacher.py
@@ -0,0 +1,190 @@
+import math
+import random
+from functools import partial
+from usr.diff.shallow_diffusion_tts import *
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from tqdm import tqdm
+from einops import rearrange
+
+from modules.fastspeech.fs2 import FastSpeech2
+from utils.hparams import hparams
+
+
+class GaussianDiffusion(nn.Module):
+    def __init__(self, phone_encoder, out_dims, denoise_fn,
+                 timesteps=1000, time_scale=1, loss_type='l1', betas=None, spec_min=None, spec_max=None):
+        super().__init__()
+        self.denoise_fn = denoise_fn
+        self.fs2 = FastSpeech2(phone_encoder, out_dims)
+        self.fs2.decoder = None
+        self.mel_bins = out_dims
+
+        if exists(betas):
+            betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
+        else:
+            betas = get_noise_schedule_list(
+                schedule_mode=hparams['schedule_type'],
+                timesteps=timesteps + 1,
+                min_beta=0.1,
+                max_beta=40,
+                s=0.008,
+            )
+
+        alphas = 1.
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + self.time_scale = time_scale + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('timesteps', to_torch(self.num_timesteps)) # beta + self.register_buffer('timescale', to_torch(self.time_scale)) # beta + self.register_buffer('betas', to_torch(betas)) # beta + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) # alphacum_t + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) # alphacum_{t-1} + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']]) + self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']]) + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. 
- self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def q_posterior_sample(self, x_start, x_t, t, repeat_noise=False): + b, *_, device = *x_start.shape, x_start.device + model_mean, _, model_log_variance = self.q_posterior(x_start=x_start, x_t=x_t, t=t) + noise = noise_like(x_start.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x_start.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.no_grad() + def p_sample(self, x_t, t, cond, spk_emb=None, clip_denoised=True, repeat_noise=False): + b, *_, device = *x_t.shape, x_t.device + x_0_pred = self.denoise_fn(x_t, t, cond) + + return self.q_posterior_sample(x_start=x_0_pred, x_t=x_t, t=t) + + @torch.no_grad() + def interpolate(self, x1, x2, t, cond, spk_emb, lam=0.5): + b, *_, device = *x1.shape, x1.device + t = default(t, self.num_timesteps - 1) + + assert x1.shape == x2.shape + + t_batched = torch.stack([torch.tensor(t, device=device)] * b) + xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) + + x = (1 - lam) * xt1 + lam * xt2 + for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond, spk_emb) + x = x[:, 0].transpose(1, 2) + return self.denorm_spec(x) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def diffuse_trace(self, x_start, mask): + b, *_, device = *x_start.shape, x_start.device + trace = [self.norm_spec(x_start).clamp_(-1., 1.) 
* ~mask.unsqueeze(-1)]
+        for t in range(self.num_timesteps):
+            t = torch.full((b,), t, device=device, dtype=torch.long)
+            trace.append(
+                self.diffuse_fn(x_start, t)[:, 0].transpose(1, 2) * ~mask.unsqueeze(-1)
+            )
+        return trace
+
+    def diffuse_fn(self, x_start, t, noise=None):
+        x_start = self.norm_spec(x_start)
+        x_start = x_start.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
+        zero_idx = t < 0  # for items where t is -1
+        t[zero_idx] = 0
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        out = self.q_sample(x_start=x_start, t=t, noise=noise)
+        out[zero_idx] = x_start[zero_idx]  # set x_{-1} as the gt mel
+        return out
+
+    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, infer=False):
+        b, *_, device = *txt_tokens.shape, txt_tokens.device
+        ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
+                       skip_decoder=True, infer=infer)
+        nonpadding = (ret['mel2ph'] != 0).float().unsqueeze(1).unsqueeze(1)  # [B, T]
+        cond = ret['decoder_inp'].transpose(1, 2)
+        if not infer:
+            t = torch.randint(0, self.num_timesteps + 1, (b,), device=device).long()
+            # Diffusion
+            x_t = self.diffuse_fn(ref_mels, t) * nonpadding
+
+            # Predict x_{start}
+            x_0_pred = self.denoise_fn(x_t, t, cond) * nonpadding
+
+            ret['mel_out'] = x_0_pred[:, 0].transpose(1, 2)  # [B, T, 80]
+        else:
+            t = self.num_timesteps  # total number of reverse steps
+            shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+            x = torch.randn(shape, device=device)  # noise
+            for i in tqdm(reversed(range(0, t)), desc='ProDiff sample time step', total=t):
+                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)  # x(mel), t, condition(phoneme)
+            x = x[:, 0].transpose(1, 2)
+            ret['mel_out'] = self.denorm_spec(x)  # undo normalization
+        return ret
+
+    def norm_spec(self, x):
+        return x
+
+    def denorm_spec(self, x):
+        return x
+
+    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
+        return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
+
+    def out2mel(self, x):
+        return x
diff --git a/modules/ProDiff/task/ProDiff_task.py b/modules/ProDiff/task/ProDiff_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..795752e4414a8d28dd12ecf020465f9f299b6d0f
--- /dev/null
+++ b/modules/ProDiff/task/ProDiff_task.py
@@ -0,0 +1,137 @@
+import torch
+from torch import nn
+import utils
+from functools import partial
+from utils.hparams import hparams
+from modules.ProDiff.model.ProDiff import GaussianDiffusion
+from usr.diff.net import DiffNet
+from tasks.tts.fs2 import FastSpeech2Task
+from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
+from utils.pitch_utils import denorm_f0
+from tasks.tts.fs2_utils import FastSpeechDataset
+DIFF_DECODERS = {
+    'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
+}
+
+
+class ProDiff_Task(FastSpeech2Task):
+    def __init__(self):
+        super(ProDiff_Task, self).__init__()
+        self.dataset_cls = FastSpeechDataset
+        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
+
+    def build_model(self):
+        self.build_tts_model()
+        if hparams['load_ckpt'] != '':
+            self.load_ckpt(hparams['load_ckpt'], strict=False)
+        utils.num_params(self.model, print_out=True, model_name="Generator: student")
+        utils.num_params(self.teacher, print_out=True, model_name="Generator: teacher")
+        if not hasattr(self, 'gen_params'):
+            self.gen_params = list(self.model.parameters())
+        return self.model
+
+    def build_tts_model(self):
+        mel_bins = hparams['audio_num_mel_bins']
+        checkpoint = torch.load(hparams['teacher_ckpt'],
map_location='cpu')["state_dict"]['model'] + teacher_timesteps = int(checkpoint['timesteps'].item()) + teacher_timescales = int(checkpoint['timescale'].item()) + student_timesteps = teacher_timesteps // 2 + student_timescales = teacher_timescales * 2 + + self.teacher = GaussianDiffusion( + phone_encoder=self.phone_encoder, + out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams), + loss_type=hparams['diff_loss_type'], + timesteps=teacher_timesteps, time_scale=teacher_timescales, + spec_min=hparams['spec_min'], spec_max=hparams['spec_max'], + ) + self.model = GaussianDiffusion( + phone_encoder=self.phone_encoder, + out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams), + timesteps=student_timesteps, time_scale=student_timescales, + loss_type=hparams['diff_loss_type'], + spec_min=hparams['spec_min'], spec_max=hparams['spec_max'], + ) + + utils.load_ckpt(self.teacher, hparams['teacher_ckpt'], 'model', strict=False) + utils.load_ckpt(self.model, hparams['teacher_ckpt'], 'model', strict=False) + to_torch = partial(torch.tensor, dtype=torch.float32) + self.model.num_timesteps = student_timesteps + self.model.time_scale = student_timescales + self.model.register_buffer('timesteps', to_torch(student_timesteps)) # beta + self.model.register_buffer('timescale', to_torch(student_timescales)) # beta + + for k, v in self.model.fs2.named_parameters(): + if not 'denoise_fn' in k: + v.requires_grad = False + + for param in self.teacher.parameters(): + param.requires_grad = False + + + def run_model(self, model, sample, return_output=False, infer=False): + txt_tokens = sample['txt_tokens'] # [B, T_t] + target = sample['mels'] # [B, T_s, 80] + # mel2ph = sample['mel2ph'] if hparams['use_gt_dur'] else None # [B, T_s] + mel2ph = sample['mel2ph'] + f0 = sample['f0'] + uv = sample['uv'] + energy = sample['energy'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + output = model(txt_tokens, self.teacher, mel2ph=mel2ph, spk_embed=spk_embed, + ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer) + + losses = {} + losses['l1'] = output['mel_out'] + self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses) + if hparams['use_pitch_embed']: + self.add_pitch_loss(output, sample, losses) + if hparams['use_energy_embed']: + self.add_energy_loss(output['energy_pred'], energy, losses) + if not return_output: + return losses + else: + return losses, output + + def validation_step(self, sample, batch_idx): + outputs = {} + txt_tokens = sample['txt_tokens'] # [B, T_t] + + energy = sample['energy'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + mel2ph = sample['mel2ph'] + f0 = sample['f0'] + uv = sample['uv'] + + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False) + + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + outputs = utils.tensors_to_scalars(outputs) + if batch_idx < hparams['num_valid_plots']: + # model_out = self.model( + # txt_tokens, spk_embed=spk_embed, mel2ph=None, f0=None, uv=None, energy=None, ref_mels=None, inference=True) + # self.plot_mel(batch_idx, model_out['mel_out'], model_out['fs2_mel'], name=f'diffspeech_vs_fs2_{batch_idx}') + model_out = self.model( + txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, energy=energy, ref_mels=None, infer=True) + gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams) + 
self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True, gt_f0=gt_f0, f0=model_out.get('f0_denorm')) + self.plot_mel(batch_idx, sample['mels'], model_out['mel_out']) + return outputs + + + ############ + # validation plots + ############ + def plot_wav(self, batch_idx, gt_wav, wav_out, is_mel=False, gt_f0=None, f0=None, name=None): + gt_wav = gt_wav[0].cpu().numpy() + wav_out = wav_out[0].cpu().numpy() + gt_f0 = gt_f0[0].cpu().numpy() + f0 = f0[0].cpu().numpy() + if is_mel: + gt_wav = self.vocoder.spec2wav(gt_wav, f0=gt_f0) + wav_out = self.vocoder.spec2wav(wav_out, f0=f0) + self.logger.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step) + self.logger.add_audio(f'wav_{batch_idx}', wav_out, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step) + diff --git a/modules/ProDiff/task/ProDiff_teacher_task.py b/modules/ProDiff/task/ProDiff_teacher_task.py new file mode 100644 index 0000000000000000000000000000000000000000..8e4f89645ce09c0486c18429b0571b5bf6605b79 --- /dev/null +++ b/modules/ProDiff/task/ProDiff_teacher_task.py @@ -0,0 +1,101 @@ +import torch + +import utils +from utils.hparams import hparams +from modules.ProDiff.model.ProDiff_teacher import GaussianDiffusion +from usr.diff.net import DiffNet +from tasks.tts.fs2 import FastSpeech2Task +from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder +from utils.pitch_utils import denorm_f0 +from tasks.tts.fs2_utils import FastSpeechDataset + +DIFF_DECODERS = { + 'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']), +} + + +class ProDiff_teacher_Task(FastSpeech2Task): + def __init__(self): + super(ProDiff_teacher_Task, self).__init__() + self.dataset_cls = FastSpeechDataset + self.vocoder: BaseVocoder = get_vocoder_cls(hparams)() + + def build_model(self): + self.build_tts_model() + utils.num_params(self.model) + return self.model + + def build_tts_model(self): + self.model = GaussianDiffusion( + phone_encoder=self.phone_encoder, + out_dims=hparams['audio_num_mel_bins'], denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams), + timesteps=hparams['timesteps'], time_scale=hparams['timescale'], + loss_type=hparams['diff_loss_type'], + spec_min=hparams['spec_min'], spec_max=hparams['spec_max'], + ) + + + def run_model(self, model, sample, return_output=False, infer=False): + txt_tokens = sample['txt_tokens'] # [B, T_t] + target = sample['mels'] # [B, T_s, 80] + mel2ph = sample['mel2ph'] + f0 = sample['f0'] + uv = sample['uv'] + energy = sample['energy'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, + ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer) + + losses = {} + self.add_mel_loss(output['mel_out'], target, losses) + self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses) + if hparams['use_pitch_embed']: + self.add_pitch_loss(output, sample, losses) + if hparams['use_energy_embed']: + self.add_energy_loss(output['energy_pred'], energy, losses) + if not return_output: + return losses + else: + return losses, output + + def validation_step(self, sample, batch_idx): + outputs = {} + txt_tokens = sample['txt_tokens'] # [B, T_t] + + energy = sample['energy'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + mel2ph = sample['mel2ph'] + f0 = sample['f0'] + uv = sample['uv'] + + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(self.model, 
sample, return_output=True, infer=False) + + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + outputs = utils.tensors_to_scalars(outputs) + if batch_idx < hparams['num_valid_plots']: + # model_out = self.model( + # txt_tokens, spk_embed=spk_embed, mel2ph=None, f0=None, uv=None, energy=None, ref_mels=None, inference=True) + # self.plot_mel(batch_idx, model_out['mel_out'], model_out['fs2_mel'], name=f'diffspeech_vs_fs2_{batch_idx}') + model_out = self.model( + txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, energy=energy, ref_mels=None, infer=True) + gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams) + self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True, gt_f0=gt_f0, f0=model_out.get('f0_denorm')) + self.plot_mel(batch_idx, sample['mels'], model_out['mel_out']) + return outputs + + ############ + # validation plots + ############ + def plot_wav(self, batch_idx, gt_wav, wav_out, is_mel=False, gt_f0=None, f0=None, name=None): + gt_wav = gt_wav[0].cpu().numpy() + wav_out = wav_out[0].cpu().numpy() + gt_f0 = gt_f0[0].cpu().numpy() + f0 = f0[0].cpu().numpy() + if is_mel: + gt_wav = self.vocoder.spec2wav(gt_wav, f0=gt_f0) + wav_out = self.vocoder.spec2wav(wav_out, f0=f0) + self.logger.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step) + self.logger.add_audio(f'wav_{batch_idx}', wav_out, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step) + diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py new file mode 100755 index 0000000000000000000000000000000000000000..fe8c664acd66ddc737ccc38b56d8bb077d636bf2 --- /dev/null +++ b/modules/commons/common_layers.py @@ -0,0 +1,971 @@ +import math +import torch +from torch import nn +from torch.nn import Parameter +import torch.onnx.operators +import torch.nn.functional as F +from utils.tts_utils import make_positions, softmax, get_incremental_state, set_incremental_state + + +class Reshape(nn.Module): + def __init__(self, *args): + super(Reshape, self).__init__() + self.shape = args + + def forward(self, x): + return x.view(self.shape) + + +class Permute(nn.Module): + def __init__(self, *args): + super(Permute, self).__init__() + self.args = args + + def forward(self, x): + return x.permute(self.args) + + +class LinearNorm(torch.nn.Module): + def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'): + super(LinearNorm, self).__init__() + self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias) + + torch.nn.init.xavier_uniform_( + self.linear_layer.weight, + gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, x): + return self.linear_layer(x) + + +class ConvNorm(torch.nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, + padding=None, dilation=1, bias=True, w_init_gain='linear'): + super(ConvNorm, self).__init__() + if padding is None: + assert (kernel_size % 2 == 1) + padding = int(dilation * (kernel_size - 1) / 2) + + self.conv = torch.nn.Conv1d(in_channels, out_channels, + kernel_size=kernel_size, stride=stride, + padding=padding, dilation=dilation, + bias=bias) + + torch.nn.init.xavier_uniform_( + self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain)) + + def forward(self, signal): + conv_signal = self.conv(signal) + 
return conv_signal + + +def Embedding(num_embeddings, embedding_dim, padding_idx=None): + m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx) + nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5) + if padding_idx is not None: + nn.init.constant_(m.weight[padding_idx], 0) + return m + + +class GroupNorm1DTBC(nn.GroupNorm): + def forward(self, input): + return super(GroupNorm1DTBC, self).forward(input.permute(1, 2, 0)).permute(2, 0, 1) + + +def LayerNorm(normalized_shape, eps=1e-5, elementwise_affine=True, export=False): + if not export and torch.cuda.is_available(): + try: + from apex.normalization import FusedLayerNorm + return FusedLayerNorm(normalized_shape, eps, elementwise_affine) + except ImportError: + pass + return torch.nn.LayerNorm(normalized_shape, eps, elementwise_affine) + + +def Linear(in_features, out_features, bias=True): + m = nn.Linear(in_features, out_features, bias) + nn.init.xavier_uniform_(m.weight) + if bias: + nn.init.constant_(m.bias, 0.) + return m + + +class SinusoidalPositionalEmbedding(nn.Module): + """This module produces sinusoidal positional embeddings of any length. + + Padding symbols are ignored. + """ + + def __init__(self, embedding_dim, padding_idx, init_size=1024): + super().__init__() + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.weights = SinusoidalPositionalEmbedding.get_embedding( + init_size, + embedding_dim, + padding_idx, + ) + self.register_buffer('_float_tensor', torch.FloatTensor(1)) + + @staticmethod + def get_embedding(num_embeddings, embedding_dim, padding_idx=None): + """Build sinusoidal embeddings. + + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
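+        With half_dim = embedding_dim // 2, position p is mapped to
+        sin(p * w_i) in channel i and cos(p * w_i) in channel half_dim + i,
+        where w_i = exp(-i * ln(10000) / (half_dim - 1)); the sine channels
+        are thus concatenated before the cosine channels rather than
+        interleaved as in the paper.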
+ """ + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) + emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) + if embedding_dim % 2 == 1: + # zero pad + emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1) + if padding_idx is not None: + emb[padding_idx, :] = 0 + return emb + + def forward(self, input, incremental_state=None, timestep=None, positions=None, **kwargs): + """Input is expected to be of size [bsz x seqlen].""" + bsz, seq_len = input.shape[:2] + max_pos = self.padding_idx + 1 + seq_len + if self.weights is None or max_pos > self.weights.size(0): + # recompute/expand embeddings if needed + self.weights = SinusoidalPositionalEmbedding.get_embedding( + max_pos, + self.embedding_dim, + self.padding_idx, + ) + self.weights = self.weights.to(self._float_tensor) + + if incremental_state is not None: + # positions is the same for every token when decoding a single step + pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len + return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1) + + positions = make_positions(input, self.padding_idx) if positions is None else positions + return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() + + def max_positions(self): + """Maximum number of supported positions.""" + return int(1e5) # an arbitrary large number + + +class ConvTBC(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding=0): + super(ConvTBC, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.padding = padding + + self.weight = torch.nn.Parameter(torch.Tensor( + self.kernel_size, in_channels, out_channels)) + self.bias = torch.nn.Parameter(torch.Tensor(out_channels)) + + def forward(self, input): + return torch.conv_tbc(input.contiguous(), self.weight, self.bias, self.padding) + + +class MultiheadAttention(nn.Module): + def __init__(self, embed_dim, num_heads, kdim=None, vdim=None, dropout=0., bias=True, + add_bias_kv=False, add_zero_attn=False, self_attention=False, + encoder_decoder_attention=False): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self.qkv_same_dim = self.kdim == embed_dim and self.vdim == embed_dim + + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" + self.scaling = self.head_dim ** -0.5 + + self.self_attention = self_attention + self.encoder_decoder_attention = encoder_decoder_attention + + assert not self.self_attention or self.qkv_same_dim, 'Self-attention requires query, key and ' \ + 'value to be of the same size' + + if self.qkv_same_dim: + self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim)) + else: + self.k_proj_weight = Parameter(torch.Tensor(embed_dim, self.kdim)) + self.v_proj_weight = Parameter(torch.Tensor(embed_dim, self.vdim)) + self.q_proj_weight = Parameter(torch.Tensor(embed_dim, embed_dim)) + + if bias: + self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim)) + else: + self.register_parameter('in_proj_bias', None) + + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + if add_bias_kv: + self.bias_k = 
Parameter(torch.Tensor(1, 1, embed_dim)) + self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim)) + else: + self.bias_k = self.bias_v = None + + self.add_zero_attn = add_zero_attn + + self.reset_parameters() + + self.enable_torch_version = False + if hasattr(F, "multi_head_attention_forward"): + self.enable_torch_version = True + else: + self.enable_torch_version = False + self.last_attn_probs = None + + def reset_parameters(self): + if self.qkv_same_dim: + nn.init.xavier_uniform_(self.in_proj_weight) + else: + nn.init.xavier_uniform_(self.k_proj_weight) + nn.init.xavier_uniform_(self.v_proj_weight) + nn.init.xavier_uniform_(self.q_proj_weight) + + nn.init.xavier_uniform_(self.out_proj.weight) + if self.in_proj_bias is not None: + nn.init.constant_(self.in_proj_bias, 0.) + nn.init.constant_(self.out_proj.bias, 0.) + if self.bias_k is not None: + nn.init.xavier_normal_(self.bias_k) + if self.bias_v is not None: + nn.init.xavier_normal_(self.bias_v) + + def forward( + self, + query, key, value, + key_padding_mask=None, + incremental_state=None, + need_weights=True, + static_kv=False, + attn_mask=None, + before_softmax=False, + need_head_weights=False, + enc_dec_attn_constraint_mask=None, + reset_attn_weight=None + ): + """Input shape: Time x Batch x Channel + + Args: + key_padding_mask (ByteTensor, optional): mask to exclude + keys that are pads, of shape `(batch, src_len)`, where + padding elements are indicated by 1s. + need_weights (bool, optional): return the attention weights, + averaged over heads (default: False). + attn_mask (ByteTensor, optional): typically used to + implement causal attention, where the mask prevents the + attention from looking forward in time (default: None). + before_softmax (bool, optional): return the raw attention + weights and values before the attention softmax. + need_head_weights (bool, optional): return the attention + weights for each head. Implies *need_weights*. Default: + return the average attention weights over all heads. 
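+            Returns:
+                attn (Tensor): attention output of shape `(tgt_len, batch, embed_dim)`.
+                attn_weights: on the manual path, a `(weights, logits)` tuple,
+                    with weights averaged over heads unless *need_head_weights*;
+                    on the fused `F.multi_head_attention_forward` path, that
+                    function's return value is passed through unchanged.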
+ """ + if need_head_weights: + need_weights = True + + tgt_len, bsz, embed_dim = query.size() + assert embed_dim == self.embed_dim + assert list(query.size()) == [tgt_len, bsz, embed_dim] + if self.enable_torch_version and incremental_state is None and not static_kv and reset_attn_weight is None: + if self.qkv_same_dim: + return F.multi_head_attention_forward(query, key, value, + self.embed_dim, self.num_heads, + self.in_proj_weight, + self.in_proj_bias, self.bias_k, self.bias_v, + self.add_zero_attn, self.dropout, + self.out_proj.weight, self.out_proj.bias, + self.training, key_padding_mask, need_weights, + attn_mask) + else: + return F.multi_head_attention_forward(query, key, value, + self.embed_dim, self.num_heads, + torch.empty([0]), + self.in_proj_bias, self.bias_k, self.bias_v, + self.add_zero_attn, self.dropout, + self.out_proj.weight, self.out_proj.bias, + self.training, key_padding_mask, need_weights, + attn_mask, use_separate_proj_weight=True, + q_proj_weight=self.q_proj_weight, + k_proj_weight=self.k_proj_weight, + v_proj_weight=self.v_proj_weight) + + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_key' in saved_state: + # previous time steps are cached - no need to recompute + # key and value if they are static + if static_kv: + assert self.encoder_decoder_attention and not self.self_attention + key = value = None + else: + saved_state = None + + if self.self_attention: + # self-attention + q, k, v = self.in_proj_qkv(query) + elif self.encoder_decoder_attention: + # encoder-decoder attention + q = self.in_proj_q(query) + if key is None: + assert value is None + k = v = None + else: + k = self.in_proj_k(key) + v = self.in_proj_v(key) + + else: + q = self.in_proj_q(query) + k = self.in_proj_k(key) + v = self.in_proj_v(value) + q *= self.scaling + + if self.bias_k is not None: + assert self.bias_v is not None + k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)]) + v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)]) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, key_padding_mask.new_zeros(key_padding_mask.size(0), 1)], dim=1) + + q = q.contiguous().view(tgt_len, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if k is not None: + k = k.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + if v is not None: + v = v.contiguous().view(-1, bsz * self.num_heads, self.head_dim).transpose(0, 1) + + if saved_state is not None: + # saved states are stored with shape (bsz, num_heads, seq_len, head_dim) + if 'prev_key' in saved_state: + prev_key = saved_state['prev_key'].view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + k = prev_key + else: + k = torch.cat((prev_key, k), dim=1) + if 'prev_value' in saved_state: + prev_value = saved_state['prev_value'].view(bsz * self.num_heads, -1, self.head_dim) + if static_kv: + v = prev_value + else: + v = torch.cat((prev_value, v), dim=1) + if 'prev_key_padding_mask' in saved_state and saved_state['prev_key_padding_mask'] is not None: + prev_key_padding_mask = saved_state['prev_key_padding_mask'] + if static_kv: + key_padding_mask = prev_key_padding_mask + else: + key_padding_mask = torch.cat((prev_key_padding_mask, key_padding_mask), dim=1) + + saved_state['prev_key'] = k.view(bsz, self.num_heads, -1, self.head_dim) + saved_state['prev_value'] = v.view(bsz, self.num_heads, -1, self.head_dim) + 
saved_state['prev_key_padding_mask'] = key_padding_mask + + self._set_input_buffer(incremental_state, saved_state) + + src_len = k.size(1) + + # This is part of a workaround to get around fork/join parallelism + # not supporting Optional types. + if key_padding_mask is not None and key_padding_mask.shape == torch.Size([]): + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.size(0) == bsz + assert key_padding_mask.size(1) == src_len + + if self.add_zero_attn: + src_len += 1 + k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1) + v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1) + if attn_mask is not None: + attn_mask = torch.cat([attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1) + if key_padding_mask is not None: + key_padding_mask = torch.cat( + [key_padding_mask, torch.zeros(key_padding_mask.size(0), 1).type_as(key_padding_mask)], dim=1) + + attn_weights = torch.bmm(q, k.transpose(1, 2)) + attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz) + + assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + + if attn_mask is not None: + if len(attn_mask.shape) == 2: + attn_mask = attn_mask.unsqueeze(0) + elif len(attn_mask.shape) == 3: + attn_mask = attn_mask[:, None].repeat([1, self.num_heads, 1, 1]).reshape( + bsz * self.num_heads, tgt_len, src_len) + attn_weights = attn_weights + attn_mask + + if enc_dec_attn_constraint_mask is not None: # bs x head x L_kv + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + enc_dec_attn_constraint_mask.unsqueeze(2).bool(), + -1e8, + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + if key_padding_mask is not None: + # don't attend to padding symbols + attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2), + -1e8, + ) + attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + + attn_logits = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + + if before_softmax: + return attn_weights, v + + attn_weights_float = softmax(attn_weights, dim=-1) + attn_weights = attn_weights_float.type_as(attn_weights) + attn_probs = F.dropout(attn_weights_float.type_as(attn_weights), p=self.dropout, training=self.training) + + if reset_attn_weight is not None: + if reset_attn_weight: + self.last_attn_probs = attn_probs.detach() + else: + assert self.last_attn_probs is not None + attn_probs = self.last_attn_probs + attn = torch.bmm(attn_probs, v) + assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) + attn = self.out_proj(attn) + + if need_weights: + attn_weights = attn_weights_float.view(bsz, self.num_heads, tgt_len, src_len).transpose(1, 0) + if not need_head_weights: + # average attention weights over heads + attn_weights = attn_weights.mean(dim=0) + else: + attn_weights = None + + return attn, (attn_weights, attn_logits) + + def in_proj_qkv(self, query): + return self._in_proj(query).chunk(3, dim=-1) + + def in_proj_q(self, query): + if self.qkv_same_dim: + return self._in_proj(query, end=self.embed_dim) + else: + bias = self.in_proj_bias + if bias is not None: + bias = bias[:self.embed_dim] + return F.linear(query, self.q_proj_weight, bias) + + def in_proj_k(self, key): + if self.qkv_same_dim: + return self._in_proj(key, 
start=self.embed_dim, end=2 * self.embed_dim) + else: + weight = self.k_proj_weight + bias = self.in_proj_bias + if bias is not None: + bias = bias[self.embed_dim:2 * self.embed_dim] + return F.linear(key, weight, bias) + + def in_proj_v(self, value): + if self.qkv_same_dim: + return self._in_proj(value, start=2 * self.embed_dim) + else: + weight = self.v_proj_weight + bias = self.in_proj_bias + if bias is not None: + bias = bias[2 * self.embed_dim:] + return F.linear(value, weight, bias) + + def _in_proj(self, input, start=0, end=None): + weight = self.in_proj_weight + bias = self.in_proj_bias + weight = weight[start:end, :] + if bias is not None: + bias = bias[start:end] + return F.linear(input, weight, bias) + + def _get_input_buffer(self, incremental_state): + return get_incremental_state( + self, + incremental_state, + 'attn_state', + ) or {} + + def _set_input_buffer(self, incremental_state, buffer): + set_incremental_state( + self, + incremental_state, + 'attn_state', + buffer, + ) + + def apply_sparse_mask(self, attn_weights, tgt_len, src_len, bsz): + return attn_weights + + def clear_buffer(self, incremental_state=None): + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_key' in saved_state: + del saved_state['prev_key'] + if 'prev_value' in saved_state: + del saved_state['prev_value'] + self._set_input_buffer(incremental_state, saved_state) + + +class Swish(torch.autograd.Function): + @staticmethod + def forward(ctx, i): + result = i * torch.sigmoid(i) + ctx.save_for_backward(i) + return result + + @staticmethod + def backward(ctx, grad_output): + i = ctx.saved_variables[0] + sigmoid_i = torch.sigmoid(i) + return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) + + +class CustomSwish(nn.Module): + def forward(self, input_tensor): + return Swish.apply(input_tensor) + + +class TransformerFFNLayer(nn.Module): + def __init__(self, hidden_size, filter_size, padding="SAME", kernel_size=1, dropout=0., act='gelu'): + super().__init__() + self.kernel_size = kernel_size + self.dropout = dropout + self.act = act + if padding == 'SAME': + self.ffn_1 = nn.Conv1d(hidden_size, filter_size, kernel_size, padding=kernel_size // 2) + elif padding == 'LEFT': + self.ffn_1 = nn.Sequential( + nn.ConstantPad1d((kernel_size - 1, 0), 0.0), + nn.Conv1d(hidden_size, filter_size, kernel_size) + ) + self.ffn_2 = Linear(filter_size, hidden_size) + if self.act == 'swish': + self.swish_fn = CustomSwish() + + def forward(self, x, incremental_state=None): + # x: T x B x C + if incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_input' in saved_state: + prev_input = saved_state['prev_input'] + x = torch.cat((prev_input, x), dim=0) + x = x[-self.kernel_size:] + saved_state['prev_input'] = x + self._set_input_buffer(incremental_state, saved_state) + + x = self.ffn_1(x.permute(1, 2, 0)).permute(2, 0, 1) + x = x * self.kernel_size ** -0.5 + + if incremental_state is not None: + x = x[-1:] + if self.act == 'gelu': + x = F.gelu(x) + if self.act == 'relu': + x = F.relu(x) + if self.act == 'swish': + x = self.swish_fn(x) + x = F.dropout(x, self.dropout, training=self.training) + x = self.ffn_2(x) + return x + + def _get_input_buffer(self, incremental_state): + return get_incremental_state( + self, + incremental_state, + 'f', + ) or {} + + def _set_input_buffer(self, incremental_state, buffer): + set_incremental_state( + self, + incremental_state, + 'f', + buffer, + ) + + def clear_buffer(self, incremental_state): + if 
incremental_state is not None: + saved_state = self._get_input_buffer(incremental_state) + if 'prev_input' in saved_state: + del saved_state['prev_input'] + self._set_input_buffer(incremental_state, saved_state) + + +class BatchNorm1dTBC(nn.Module): + def __init__(self, c): + super(BatchNorm1dTBC, self).__init__() + self.bn = nn.BatchNorm1d(c) + + def forward(self, x): + """ + + :param x: [T, B, C] + :return: [T, B, C] + """ + x = x.permute(1, 2, 0) # [B, C, T] + x = self.bn(x) # [B, C, T] + x = x.permute(2, 0, 1) # [T, B, C] + return x + + +class EncSALayer(nn.Module): + def __init__(self, c, num_heads, dropout, attention_dropout=0.1, + relu_dropout=0.1, kernel_size=9, padding='SAME', norm='ln', act='gelu'): + super().__init__() + self.c = c + self.dropout = dropout + self.num_heads = num_heads + if num_heads > 0: + if norm == 'ln': + self.layer_norm1 = LayerNorm(c) + elif norm == 'bn': + self.layer_norm1 = BatchNorm1dTBC(c) + elif norm == 'gn': + self.layer_norm1 = GroupNorm1DTBC(8, c) + self.self_attn = MultiheadAttention( + self.c, num_heads, self_attention=True, dropout=attention_dropout, bias=False) + if norm == 'ln': + self.layer_norm2 = LayerNorm(c) + elif norm == 'bn': + self.layer_norm2 = BatchNorm1dTBC(c) + elif norm == 'gn': + self.layer_norm2 = GroupNorm1DTBC(8, c) + self.ffn = TransformerFFNLayer( + c, 4 * c, kernel_size=kernel_size, dropout=relu_dropout, padding=padding, act=act) + + def forward(self, x, encoder_padding_mask=None, **kwargs): + layer_norm_training = kwargs.get('layer_norm_training', None) + if layer_norm_training is not None: + self.layer_norm1.training = layer_norm_training + self.layer_norm2.training = layer_norm_training + if self.num_heads > 0: + residual = x + x = self.layer_norm1(x) + x, _, = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=encoder_padding_mask + ) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None] + + residual = x + x = self.layer_norm2(x) + x = self.ffn(x) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + x = x * (1 - encoder_padding_mask.float()).transpose(0, 1)[..., None] + return x + + +class DecSALayer(nn.Module): + def __init__(self, c, num_heads, dropout, attention_dropout=0.1, relu_dropout=0.1, + kernel_size=9, act='gelu', norm='ln'): + super().__init__() + self.c = c + self.dropout = dropout + if norm == 'ln': + self.layer_norm1 = LayerNorm(c) + elif norm == 'gn': + self.layer_norm1 = GroupNorm1DTBC(8, c) + self.self_attn = MultiheadAttention( + c, num_heads, self_attention=True, dropout=attention_dropout, bias=False + ) + if norm == 'ln': + self.layer_norm2 = LayerNorm(c) + elif norm == 'gn': + self.layer_norm2 = GroupNorm1DTBC(8, c) + self.encoder_attn = MultiheadAttention( + c, num_heads, encoder_decoder_attention=True, dropout=attention_dropout, bias=False, + ) + if norm == 'ln': + self.layer_norm3 = LayerNorm(c) + elif norm == 'gn': + self.layer_norm3 = GroupNorm1DTBC(8, c) + self.ffn = TransformerFFNLayer( + c, 4 * c, padding='LEFT', kernel_size=kernel_size, dropout=relu_dropout, act=act) + + def forward( + self, + x, + encoder_out=None, + encoder_padding_mask=None, + incremental_state=None, + self_attn_mask=None, + self_attn_padding_mask=None, + attn_out=None, + reset_attn_weight=None, + **kwargs, + ): + layer_norm_training = kwargs.get('layer_norm_training', None) + if layer_norm_training is not None: + self.layer_norm1.training = layer_norm_training + self.layer_norm2.training = 
layer_norm_training + self.layer_norm3.training = layer_norm_training + residual = x + x = self.layer_norm1(x) + x, _ = self.self_attn( + query=x, + key=x, + value=x, + key_padding_mask=self_attn_padding_mask, + incremental_state=incremental_state, + attn_mask=self_attn_mask + ) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + + attn_logits = None + if encoder_out is not None or attn_out is not None: + residual = x + x = self.layer_norm2(x) + if encoder_out is not None: + x, attn = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + incremental_state=incremental_state, + static_kv=True, + enc_dec_attn_constraint_mask=get_incremental_state(self, incremental_state, + 'enc_dec_attn_constraint_mask'), + reset_attn_weight=reset_attn_weight + ) + attn_logits = attn[1] + elif attn_out is not None: + x = self.encoder_attn.in_proj_v(attn_out) + if encoder_out is not None or attn_out is not None: + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + + residual = x + x = self.layer_norm3(x) + x = self.ffn(x, incremental_state=incremental_state) + x = F.dropout(x, self.dropout, training=self.training) + x = residual + x + return x, attn_logits + + def clear_buffer(self, input, encoder_out=None, encoder_padding_mask=None, incremental_state=None): + self.encoder_attn.clear_buffer(incremental_state) + self.ffn.clear_buffer(incremental_state) + + def set_buffer(self, name, tensor, incremental_state): + return set_incremental_state(self, incremental_state, name, tensor) + + +class ConvBlock(nn.Module): + def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0): + super().__init__() + self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride) + self.norm = norm + if self.norm == 'bn': + self.norm = nn.BatchNorm1d(n_chans) + elif self.norm == 'in': + self.norm = nn.InstanceNorm1d(n_chans, affine=True) + elif self.norm == 'gn': + self.norm = nn.GroupNorm(n_chans // 16, n_chans) + elif self.norm == 'ln': + self.norm = LayerNorm(n_chans // 16, n_chans) + elif self.norm == 'wn': + self.conv = torch.nn.utils.weight_norm(self.conv.conv) + self.dropout = nn.Dropout(dropout) + self.relu = nn.ReLU() + + def forward(self, x): + """ + + :param x: [B, C, T] + :return: [B, C, T] + """ + x = self.conv(x) + if not isinstance(self.norm, str): + if self.norm == 'none': + pass + elif self.norm == 'ln': + x = self.norm(x.transpose(1, 2)).transpose(1, 2) + else: + x = self.norm(x) + x = self.relu(x) + x = self.dropout(x) + return x + + +class ConvStacks(nn.Module): + def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn', + dropout=0, strides=None, res=True): + super().__init__() + self.conv = torch.nn.ModuleList() + self.kernel_size = kernel_size + self.res = res + self.in_proj = Linear(idim, n_chans) + if strides is None: + strides = [1] * n_layers + else: + assert len(strides) == n_layers + for idx in range(n_layers): + self.conv.append(ConvBlock( + n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout)) + self.out_proj = Linear(n_chans, odim) + + def forward(self, x, return_hiddens=False): + """ + + :param x: [B, T, H] + :return: [B, T, H] + """ + x = self.in_proj(x) + x = x.transpose(1, -1) # (B, idim, Tmax) + hiddens = [] + for f in self.conv: + x_ = f(x) + x = x + x_ if self.res else x_ # (B, C, Tmax) + hiddens.append(x) + x = x.transpose(1, -1) + x = self.out_proj(x) # (B, Tmax, H) + if return_hiddens: + hiddens = 
torch.stack(hiddens, 1) # [B, L, C, T] + return x, hiddens + return x + + +class ConvGlobalStacks(nn.Module): + def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn', dropout=0, + strides=[2, 2, 2, 2, 2]): + super().__init__() + self.conv = torch.nn.ModuleList() + self.pooling = torch.nn.ModuleList() + self.kernel_size = kernel_size + self.in_proj = Linear(idim, n_chans) + for idx in range(n_layers): + self.conv.append(ConvBlock(n_chans, n_chans, kernel_size, stride=strides[idx], + norm=norm, dropout=dropout)) + self.pooling.append(nn.MaxPool1d(strides[idx])) + self.out_proj = Linear(n_chans, odim) + + def forward(self, x): + """ + + :param x: [B, T, H] + :return: [B, T, H] + """ + x = self.in_proj(x) + x = x.transpose(1, -1) # (B, idim, Tmax) + for f, p in zip(self.conv, self.pooling): + x = f(x) # (B, C, T) + x = x.transpose(1, -1) + x = self.out_proj(x.mean(1)) # (B, H) + return x + + +class ConvLSTMStacks(nn.Module): + def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=3, norm='gn', dropout=0): + super().__init__() + self.conv = torch.nn.ModuleList() + self.kernel_size = kernel_size + self.in_proj = Linear(idim, n_chans) + for idx in range(n_layers): + self.conv.append(ConvBlock(n_chans, n_chans, kernel_size, stride=1, norm=norm, dropout=dropout)) + self.lstm = nn.LSTM(n_chans, n_chans, 1, batch_first=True, bidirectional=True) + self.out_proj = Linear(n_chans * 2, odim) + + def forward(self, x): + """ + + :param x: [B, T, H] + :return: [B, T, H] + """ + x = self.in_proj(x) + x = x.transpose(1, -1) # (B, idim, Tmax) + for f in self.conv: + x = x + f(x) # (B, C, Tmax) + x = x.transpose(1, -1) + x, _ = self.lstm(x) # (B, Tmax, H*2) + x = self.out_proj(x) # (B, Tmax, H) + return x + + +class ResidualLayer(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, padding): + super(ResidualLayer, self).__init__() + self.conv1d_layer = nn.Sequential(nn.Conv1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding), + nn.InstanceNorm1d(num_features=out_channels, + affine=True)) + + self.conv_layer_gates = nn.Sequential(nn.Conv1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding), + nn.InstanceNorm1d(num_features=out_channels, + affine=True)) + + self.conv1d_out_layer = nn.Sequential(nn.Conv1d(in_channels=out_channels, + out_channels=in_channels, + kernel_size=kernel_size, + stride=1, + padding=padding), + nn.InstanceNorm1d(num_features=in_channels, + affine=True)) + + def forward(self, input): + """ + + :param input: [B, H, T] + :return: input: [B, H, T] + """ + h1_norm = self.conv1d_layer(input) + h1_gates_norm = self.conv_layer_gates(input) + + # GLU + h1_glu = h1_norm * torch.sigmoid(h1_gates_norm) + + h2_norm = self.conv1d_out_layer(h1_glu) + return input + h2_norm + + +class ConvGLUStacks(nn.Module): + def __init__(self, idim=80, n_layers=3, n_chans=256, odim=32, kernel_size=5, dropout=0): + super().__init__() + self.convs = [] + self.kernel_size = kernel_size + self.in_proj = Linear(idim, n_chans) + for idx in range(n_layers): + self.convs.append( + nn.Sequential(ResidualLayer( + n_chans, n_chans, kernel_size, kernel_size // 2), + nn.Dropout(dropout) + )) + self.convs = nn.Sequential(*self.convs) + self.out_proj = Linear(n_chans, odim) + + def forward(self, x): + """ + + :param x: [B, T, H] + :return: [B, T, H] + """ + x = self.in_proj(x) + x = x.transpose(1, -1) # (B, idim, Tmax) + x = self.convs(x) # 
(B, C, Tmax)
+        x = x.transpose(1, -1)
+        x = self.out_proj(x)  # (B, Tmax, H)
+        return x
diff --git a/modules/commons/espnet_positional_embedding.py b/modules/commons/espnet_positional_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..74decb6ab300951490ae08a4b93041a0542b5bb7
--- /dev/null
+++ b/modules/commons/espnet_positional_embedding.py
@@ -0,0 +1,113 @@
+import math
+import torch
+
+
+class PositionalEncoding(torch.nn.Module):
+    """Positional encoding.
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_len (int): Maximum input length.
+        reverse (bool): Whether to reverse the input position.
+    """
+
+    def __init__(self, d_model, dropout_rate, max_len=5000, reverse=False):
+        """Construct a PositionalEncoding object."""
+        super(PositionalEncoding, self).__init__()
+        self.d_model = d_model
+        self.reverse = reverse
+        self.xscale = math.sqrt(self.d_model)
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+        self.pe = None
+        self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+
+    def extend_pe(self, x):
+        """Reset the positional encodings."""
+        if self.pe is not None:
+            if self.pe.size(1) >= x.size(1):
+                if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+                return
+        pe = torch.zeros(x.size(1), self.d_model)
+        if self.reverse:
+            position = torch.arange(
+                x.size(1) - 1, -1, -1.0, dtype=torch.float32
+            ).unsqueeze(1)
+        else:
+            position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, self.d_model, 2, dtype=torch.float32)
+            * -(math.log(10000.0) / self.d_model)
+        )
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0)
+        self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+    def forward(self, x: torch.Tensor):
+        """Add positional encoding.
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+        """
+        self.extend_pe(x)
+        x = x * self.xscale + self.pe[:, : x.size(1)]
+        return self.dropout(x)
+
+
+class ScaledPositionalEncoding(PositionalEncoding):
+    """Scaled positional encoding module.
+    See Sec. 3.2 https://arxiv.org/abs/1809.08895
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_len (int): Maximum input length.
+    """
+
+    def __init__(self, d_model, dropout_rate, max_len=5000):
+        """Initialize class."""
+        super().__init__(d_model=d_model, dropout_rate=dropout_rate, max_len=max_len)
+        self.alpha = torch.nn.Parameter(torch.tensor(1.0))
+
+    def reset_parameters(self):
+        """Reset parameters."""
+        self.alpha.data = torch.tensor(1.0)
+
+    def forward(self, x):
+        """Add positional encoding.
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+        Returns:
+            torch.Tensor: Encoded tensor (batch, time, `*`).
+        """
+        self.extend_pe(x)
+        x = x + self.alpha * self.pe[:, : x.size(1)]
+        return self.dropout(x)
+
+
+class RelPositionalEncoding(PositionalEncoding):
+    """Relative positional encoding module.
+    See: Appendix B in https://arxiv.org/abs/1901.02860
+    Args:
+        d_model (int): Embedding dimension.
+        dropout_rate (float): Dropout rate.
+        max_len (int): Maximum input length.
+    """
+
+    def __init__(self, d_model, dropout_rate, max_len=5000):
+        """Initialize class."""
+        super().__init__(d_model, dropout_rate, max_len, reverse=True)
+
+    def forward(self, x):
+        """Compute positional encoding.
+        Args:
+            x (torch.Tensor): Input tensor (batch, time, `*`).
+ Returns: + torch.Tensor: Encoded tensor (batch, time, `*`). + torch.Tensor: Positional embedding tensor (1, time, `*`). + """ + self.extend_pe(x) + x = x * self.xscale + pos_emb = self.pe[:, : x.size(1)] + return self.dropout(x) + self.dropout(pos_emb) \ No newline at end of file diff --git a/modules/commons/ssim.py b/modules/commons/ssim.py new file mode 100644 index 0000000000000000000000000000000000000000..0d0241f267ef58b24979e022b05f2a9adf768826 --- /dev/null +++ b/modules/commons/ssim.py @@ -0,0 +1,391 @@ +# ''' +# https://github.com/One-sixth/ms_ssim_pytorch/blob/master/ssim.py +# ''' +# +# import torch +# import torch.jit +# import torch.nn.functional as F +# +# +# @torch.jit.script +# def create_window(window_size: int, sigma: float, channel: int): +# ''' +# Create 1-D gauss kernel +# :param window_size: the size of gauss kernel +# :param sigma: sigma of normal distribution +# :param channel: input channel +# :return: 1D kernel +# ''' +# coords = torch.arange(window_size, dtype=torch.float) +# coords -= window_size // 2 +# +# g = torch.exp(-(coords ** 2) / (2 * sigma ** 2)) +# g /= g.sum() +# +# g = g.reshape(1, 1, 1, -1).repeat(channel, 1, 1, 1) +# return g +# +# +# @torch.jit.script +# def _gaussian_filter(x, window_1d, use_padding: bool): +# ''' +# Blur input with 1-D kernel +# :param x: batch of tensors to be blured +# :param window_1d: 1-D gauss kernel +# :param use_padding: padding image before conv +# :return: blured tensors +# ''' +# C = x.shape[1] +# padding = 0 +# if use_padding: +# window_size = window_1d.shape[3] +# padding = window_size // 2 +# out = F.conv2d(x, window_1d, stride=1, padding=(0, padding), groups=C) +# out = F.conv2d(out, window_1d.transpose(2, 3), stride=1, padding=(padding, 0), groups=C) +# return out +# +# +# @torch.jit.script +# def ssim(X, Y, window, data_range: float, use_padding: bool = False): +# ''' +# Calculate ssim index for X and Y +# :param X: images [B, C, H, N_bins] +# :param Y: images [B, C, H, N_bins] +# :param window: 1-D gauss kernel +# :param data_range: value range of input images. (usually 1.0 or 255) +# :param use_padding: padding image before conv +# :return: +# ''' +# +# K1 = 0.01 +# K2 = 0.03 +# compensation = 1.0 +# +# C1 = (K1 * data_range) ** 2 +# C2 = (K2 * data_range) ** 2 +# +# mu1 = _gaussian_filter(X, window, use_padding) +# mu2 = _gaussian_filter(Y, window, use_padding) +# sigma1_sq = _gaussian_filter(X * X, window, use_padding) +# sigma2_sq = _gaussian_filter(Y * Y, window, use_padding) +# sigma12 = _gaussian_filter(X * Y, window, use_padding) +# +# mu1_sq = mu1.pow(2) +# mu2_sq = mu2.pow(2) +# mu1_mu2 = mu1 * mu2 +# +# sigma1_sq = compensation * (sigma1_sq - mu1_sq) +# sigma2_sq = compensation * (sigma2_sq - mu2_sq) +# sigma12 = compensation * (sigma12 - mu1_mu2) +# +# cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2) +# # Fixed the issue that the negative value of cs_map caused ms_ssim to output Nan. +# cs_map = cs_map.clamp_min(0.) +# ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map +# +# ssim_val = ssim_map.mean(dim=(1, 2, 3)) # reduce along CHW +# cs = cs_map.mean(dim=(1, 2, 3)) +# +# return ssim_val, cs +# +# +# @torch.jit.script +# def ms_ssim(X, Y, window, data_range: float, weights, use_padding: bool = False, eps: float = 1e-8): +# ''' +# interface of ms-ssim +# :param X: a batch of images, (N,C,H,W) +# :param Y: a batch of images, (N,C,H,W) +# :param window: 1-D gauss kernel +# :param data_range: value range of input images. 
(usually 1.0 or 255) +# :param weights: weights for different levels +# :param use_padding: padding image before conv +# :param eps: use for avoid grad nan. +# :return: +# ''' +# levels = weights.shape[0] +# cs_vals = [] +# ssim_vals = [] +# for _ in range(levels): +# ssim_val, cs = ssim(X, Y, window=window, data_range=data_range, use_padding=use_padding) +# # Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf. +# ssim_val = ssim_val.clamp_min(eps) +# cs = cs.clamp_min(eps) +# cs_vals.append(cs) +# +# ssim_vals.append(ssim_val) +# padding = (X.shape[2] % 2, X.shape[3] % 2) +# X = F.avg_pool2d(X, kernel_size=2, stride=2, padding=padding) +# Y = F.avg_pool2d(Y, kernel_size=2, stride=2, padding=padding) +# +# cs_vals = torch.stack(cs_vals, dim=0) +# ms_ssim_val = torch.prod((cs_vals[:-1] ** weights[:-1].unsqueeze(1)) * (ssim_vals[-1] ** weights[-1]), dim=0) +# return ms_ssim_val +# +# +# class SSIM(torch.jit.ScriptModule): +# __constants__ = ['data_range', 'use_padding'] +# +# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False): +# ''' +# :param window_size: the size of gauss kernel +# :param window_sigma: sigma of normal distribution +# :param data_range: value range of input images. (usually 1.0 or 255) +# :param channel: input channels (default: 3) +# :param use_padding: padding image before conv +# ''' +# super().__init__() +# assert window_size % 2 == 1, 'Window size must be odd.' +# window = create_window(window_size, window_sigma, channel) +# self.register_buffer('window', window) +# self.data_range = data_range +# self.use_padding = use_padding +# +# @torch.jit.script_method +# def forward(self, X, Y): +# r = ssim(X, Y, window=self.window, data_range=self.data_range, use_padding=self.use_padding) +# return r[0] +# +# +# class MS_SSIM(torch.jit.ScriptModule): +# __constants__ = ['data_range', 'use_padding', 'eps'] +# +# def __init__(self, window_size=11, window_sigma=1.5, data_range=255., channel=3, use_padding=False, weights=None, +# levels=None, eps=1e-8): +# ''' +# class for ms-ssim +# :param window_size: the size of gauss kernel +# :param window_sigma: sigma of normal distribution +# :param data_range: value range of input images. (usually 1.0 or 255) +# :param channel: input channels +# :param use_padding: padding image before conv +# :param weights: weights for different levels. (default [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]) +# :param levels: number of downsampling +# :param eps: Use for fix a issue. When c = a ** b and a is 0, c.backward() will cause the a.grad become inf. +# ''' +# super().__init__() +# assert window_size % 2 == 1, 'Window size must be odd.' 
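+#         # Note: the ms_ssim above is intended to combine the per-level
+#         # contrast terms with the coarsest-level ssim as the weighted product
+#         # prod_{l < L} cs_l ** w_l * ssim_L ** w_L (the standard MS-SSIM
+#         # form), downsampling X and Y by a factor of 2 between levels.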
+#         self.data_range = data_range
+#         self.use_padding = use_padding
+#         self.eps = eps
+#
+#         window = create_window(window_size, window_sigma, channel)
+#         self.register_buffer('window', window)
+#
+#         if weights is None:
+#             weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
+#         weights = torch.tensor(weights, dtype=torch.float)
+#
+#         if levels is not None:
+#             weights = weights[:levels]
+#             weights = weights / weights.sum()
+#
+#         self.register_buffer('weights', weights)
+#
+#     @torch.jit.script_method
+#     def forward(self, X, Y):
+#         return ms_ssim(X, Y, window=self.window, data_range=self.data_range, weights=self.weights,
+#                        use_padding=self.use_padding, eps=self.eps)
+#
+#
+# if __name__ == '__main__':
+#     print('Simple Test')
+#     im = torch.randint(0, 255, (5, 3, 256, 256), dtype=torch.float, device='cuda')
+#     img1 = im / 255
+#     img2 = img1 * 0.5
+#
+#     losser = SSIM(data_range=1.).cuda()
+#     loss = losser(img1, img2).mean()
+#
+#     losser2 = MS_SSIM(data_range=1.).cuda()
+#     loss2 = losser2(img1, img2).mean()
+#
+#     print(loss.item())
+#     print(loss2.item())
+#
+# if __name__ == '__main__':
+#     print('Training Test')
+#     import cv2
+#     import torch.optim
+#     import numpy as np
+#     import imageio
+#     import time
+#
+#     out_test_video = False
+#     # Better not to write the GIF directly, as it gets very large; write an mkv
+#     # file first, then convert it to a GIF with ffmpeg.
+#     video_use_gif = False
+#
+#     im = cv2.imread('test_img1.jpg', 1)
+#     t_im = torch.from_numpy(im).cuda().permute(2, 0, 1).float()[None] / 255.
+#
+#     if out_test_video:
+#         if video_use_gif:
+#             fps = 0.5
+#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
+#             suffix = '.gif'
+#         else:
+#             fps = 5
+#             out_wh = (im.shape[1], im.shape[0])
+#             suffix = '.mkv'
+#         video_last_time = time.perf_counter()
+#         video = imageio.get_writer('ssim_test' + suffix, fps=fps)
+#
+#     # Test SSIM
+#     print('Training SSIM')
+#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
+#     rand_im.requires_grad = True
+#     optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8)
+#     losser = SSIM(data_range=1., channel=t_im.shape[1]).cuda()
+#     ssim_score = 0
+#     while ssim_score < 0.999:
+#         optim.zero_grad()
+#         loss = losser(rand_im, t_im)
+#         (-loss).sum().backward()
+#         ssim_score = loss.item()
+#         optim.step()
+#         r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0]
+#         r_im = cv2.putText(r_im, 'ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2)
+#
+#         if out_test_video:
+#             if time.perf_counter() - video_last_time > 1. / fps:
+#                 video_last_time = time.perf_counter()
+#                 out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB)
+#                 out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA)
+#                 if isinstance(out_frame, cv2.UMat):
+#                     out_frame = out_frame.get()
+#                 video.append_data(out_frame)
+#
+#         cv2.imshow('ssim', r_im)
+#         cv2.setWindowTitle('ssim', 'ssim %f' % ssim_score)
+#         cv2.waitKey(1)
+#
+#     if out_test_video:
+#         video.close()
+#
+#     # Test MS-SSIM
+#     if out_test_video:
+#         if video_use_gif:
+#             fps = 0.5
+#             out_wh = (im.shape[1] // 2, im.shape[0] // 2)
+#             suffix = '.gif'
+#         else:
+#             fps = 5
+#             out_wh = (im.shape[1], im.shape[0])
+#             suffix = '.mkv'
+#         video_last_time = time.perf_counter()
+#         video = imageio.get_writer('ms_ssim_test' + suffix, fps=fps)
+#
+#     print('Training MS_SSIM')
+#     rand_im = torch.randint_like(t_im, 0, 255, dtype=torch.float32) / 255.
+# rand_im.requires_grad = True +# optim = torch.optim.Adam([rand_im], 0.003, eps=1e-8) +# losser = MS_SSIM(data_range=1., channel=t_im.shape[1]).cuda() +# ssim_score = 0 +# while ssim_score < 0.999: +# optim.zero_grad() +# loss = losser(rand_im, t_im) +# (-loss).sum().backward() +# ssim_score = loss.item() +# optim.step() +# r_im = np.transpose(rand_im.detach().cpu().numpy().clip(0, 1) * 255, [0, 2, 3, 1]).astype(np.uint8)[0] +# r_im = cv2.putText(r_im, 'ms_ssim %f' % ssim_score, (10, 30), cv2.FONT_HERSHEY_PLAIN, 2, (255, 0, 0), 2) +# +# if out_test_video: +# if time.perf_counter() - video_last_time > 1. / fps: +# video_last_time = time.perf_counter() +# out_frame = cv2.cvtColor(r_im, cv2.COLOR_BGR2RGB) +# out_frame = cv2.resize(out_frame, out_wh, interpolation=cv2.INTER_AREA) +# if isinstance(out_frame, cv2.UMat): +# out_frame = out_frame.get() +# video.append_data(out_frame) +# +# cv2.imshow('ms_ssim', r_im) +# cv2.setWindowTitle('ms_ssim', 'ms_ssim %f' % ssim_score) +# cv2.waitKey(1) +# +# if out_test_video: +# video.close() + +""" +Adapted from https://github.com/Po-Hsun-Su/pytorch-ssim +""" + +import torch +import torch.nn.functional as F +from torch.autograd import Variable +import numpy as np +from math import exp + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) + return gauss / gauss.sum() + + +def create_window(window_size, channel): + _1D_window = gaussian(window_size, 1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) + window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) + return window + + +def _ssim(img1, img2, window, window_size, channel, size_average=True): + mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) + mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq + sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq + sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 + + C1 = 0.01 ** 2 + C2 = 0.03 ** 2 + + ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) + + if size_average: + return ssim_map.mean() + else: + return ssim_map.mean(1) + + +class SSIM(torch.nn.Module): + def __init__(self, window_size=11, size_average=True): + super(SSIM, self).__init__() + self.window_size = window_size + self.size_average = size_average + self.channel = 1 + self.window = create_window(window_size, self.channel) + + def forward(self, img1, img2): + (_, channel, _, _) = img1.size() + + if channel == self.channel and self.window.data.type() == img1.data.type(): + window = self.window + else: + window = create_window(self.window_size, channel) + + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + + self.window = window + self.channel = channel + + return _ssim(img1, img2, window, self.window_size, channel, self.size_average) + + +window = None + + +def ssim(img1, img2, window_size=11, size_average=True): + (_, channel, _, _) = img1.size() + global window + if window is None: + window = create_window(window_size, channel) + if img1.is_cuda: + window = window.cuda(img1.get_device()) + window = window.type_as(img1) + return 
_ssim(img1, img2, window, window_size, channel, size_average) diff --git a/modules/fastspeech/fs2.py b/modules/fastspeech/fs2.py new file mode 100644 index 0000000000000000000000000000000000000000..52b4ac4aaa7ae49f06736a038bde83ca2cfa8483 --- /dev/null +++ b/modules/fastspeech/fs2.py @@ -0,0 +1,255 @@ +from modules.commons.common_layers import * +from modules.commons.common_layers import Embedding +from modules.fastspeech.tts_modules import FastspeechDecoder, DurationPredictor, LengthRegulator, PitchPredictor, \ + EnergyPredictor, FastspeechEncoder +from utils.cwt import cwt2f0 +from utils.hparams import hparams +from utils.pitch_utils import f0_to_coarse, denorm_f0, norm_f0 + +FS_ENCODERS = { + 'fft': lambda hp, embed_tokens, d: FastspeechEncoder( + embed_tokens, hp['hidden_size'], hp['enc_layers'], hp['enc_ffn_kernel_size'], + num_heads=hp['num_heads']), +} + +FS_DECODERS = { + 'fft': lambda hp: FastspeechDecoder( + hp['hidden_size'], hp['dec_layers'], hp['dec_ffn_kernel_size'], hp['num_heads']), +} + + +class FastSpeech2(nn.Module): + def __init__(self, dictionary, out_dims=None): + super().__init__() + self.dictionary = dictionary + self.padding_idx = dictionary.pad() + self.enc_layers = hparams['enc_layers'] + self.dec_layers = hparams['dec_layers'] + self.hidden_size = hparams['hidden_size'] + self.encoder_embed_tokens = self.build_embedding(self.dictionary, self.hidden_size) + self.encoder = FS_ENCODERS[hparams['encoder_type']](hparams, self.encoder_embed_tokens, self.dictionary) + self.decoder = FS_DECODERS[hparams['decoder_type']](hparams) + self.out_dims = out_dims + if out_dims is None: + self.out_dims = hparams['audio_num_mel_bins'] + self.mel_out = Linear(self.hidden_size, self.out_dims, bias=True) + + if hparams['use_spk_id']: + self.spk_embed_proj = Embedding(hparams['num_spk'] + 1, self.hidden_size) + if hparams['use_split_spk_id']: + self.spk_embed_f0 = Embedding(hparams['num_spk'] + 1, self.hidden_size) + self.spk_embed_dur = Embedding(hparams['num_spk'] + 1, self.hidden_size) + elif hparams['use_spk_embed']: + self.spk_embed_proj = Linear(256, self.hidden_size, bias=True) + predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size + self.dur_predictor = DurationPredictor( + self.hidden_size, + n_chans=predictor_hidden, + n_layers=hparams['dur_predictor_layers'], + dropout_rate=hparams['predictor_dropout'], padding=hparams['ffn_padding'], + kernel_size=hparams['dur_predictor_kernel']) + self.length_regulator = LengthRegulator() + if hparams['use_pitch_embed']: + self.pitch_embed = Embedding(300, self.hidden_size, self.padding_idx) + if hparams['pitch_type'] == 'cwt': + h = hparams['cwt_hidden_size'] + cwt_out_dims = 10 + if hparams['use_uv']: + cwt_out_dims = cwt_out_dims + 1 + self.cwt_predictor = nn.Sequential( + nn.Linear(self.hidden_size, h), + PitchPredictor( + h, + n_chans=predictor_hidden, + n_layers=hparams['predictor_layers'], + dropout_rate=hparams['predictor_dropout'], odim=cwt_out_dims, + padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel'])) + self.cwt_stats_layers = nn.Sequential( + nn.Linear(self.hidden_size, h), nn.ReLU(), + nn.Linear(h, h), nn.ReLU(), nn.Linear(h, 2) + ) + else: + self.pitch_predictor = PitchPredictor( + self.hidden_size, + n_chans=predictor_hidden, + n_layers=hparams['predictor_layers'], + dropout_rate=hparams['predictor_dropout'], + odim=2 if hparams['pitch_type'] == 'frame' else 1, + padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']) + if 
hparams['use_energy_embed']: + self.energy_embed = Embedding(256, self.hidden_size, self.padding_idx) + self.energy_predictor = EnergyPredictor( + self.hidden_size, + n_chans=predictor_hidden, + n_layers=hparams['predictor_layers'], + dropout_rate=hparams['predictor_dropout'], odim=1, + padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']) + + def build_embedding(self, dictionary, embed_dim): + num_embeddings = len(dictionary) + emb = Embedding(num_embeddings, embed_dim, self.padding_idx) + return emb + + def forward(self, txt_tokens, mel2ph=None, spk_embed=None, + ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False, + spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs): + ret = {} + encoder_out = self.encoder(txt_tokens) # [B, T, C] + src_nonpadding = (txt_tokens > 0).float()[:, :, None] + + # add ref style embed + # Not implemented + # variance encoder + var_embed = 0 + + # encoder_out_dur denotes encoder outputs for duration predictor + # in speech adaptation, duration predictor use old speaker embedding + if hparams['use_spk_embed']: + spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :] + elif hparams['use_spk_id']: + spk_embed_id = spk_embed + if spk_embed_dur_id is None: + spk_embed_dur_id = spk_embed_id + if spk_embed_f0_id is None: + spk_embed_f0_id = spk_embed_id + spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :] + spk_embed_dur = spk_embed_f0 = spk_embed + if hparams['use_split_spk_id']: + spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :] + spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :] + else: + spk_embed_dur = spk_embed_f0 = spk_embed = 0 + + # add dur + dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding + + mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret) + + decoder_inp = F.pad(encoder_out, [0, 0, 1, 0]) + + mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]]) + decoder_inp_origin = decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H] + + tgt_nonpadding = (mel2ph > 0).float()[:, :, None] + + # add pitch and energy embed + pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding + if hparams['use_pitch_embed']: + pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding + decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph) + if hparams['use_energy_embed']: + decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret) + + ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding + + if skip_decoder: + return ret + ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs) + + return ret + + def add_dur(self, dur_input, mel2ph, txt_tokens, ret): + """ + + :param dur_input: [B, T_txt, H] + :param mel2ph: [B, T_mel] + :param txt_tokens: [B, T_txt] + :param ret: + :return: + """ + src_padding = txt_tokens == 0 + dur_input = dur_input.detach() + hparams['predictor_grad'] * (dur_input - dur_input.detach()) + if mel2ph is None: + dur, xs = self.dur_predictor.inference(dur_input, src_padding) + ret['dur'] = xs + ret['dur_choice'] = dur + mel2ph = self.length_regulator(dur, src_padding).detach() + # from modules.fastspeech.fake_modules import FakeLengthRegulator + # fake_lr = FakeLengthRegulator() + # fake_mel2ph = fake_lr(dur, (1 - src_padding.long()).sum(-1))[..., 0].detach() + # print(mel2ph == fake_mel2ph) + else: + ret['dur'] = self.dur_predictor(dur_input, src_padding) + 
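+        # Training path: the ground-truth alignment mel2ph is kept and the
+        # predictor is supervised through ret['dur']; at inference (mel2ph is
+        # None) mel2ph is instead expanded from the predicted durations by the
+        # length regulator above.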
ret['mel2ph'] = mel2ph + return mel2ph + + def add_energy(self, decoder_inp, energy, ret): + decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach()) + ret['energy_pred'] = energy_pred = self.energy_predictor(decoder_inp)[:, :, 0] + if energy is None: + energy = energy_pred + energy = torch.clamp(energy * 256 // 4, max=255).long() + energy_embed = self.energy_embed(energy) + return energy_embed + + def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None): + if hparams['pitch_type'] == 'ph': + pitch_pred_inp = encoder_out.detach() + hparams['predictor_grad'] * (encoder_out - encoder_out.detach()) + pitch_padding = encoder_out.sum().abs() == 0 + ret['pitch_pred'] = pitch_pred = self.pitch_predictor(pitch_pred_inp) + if f0 is None: + f0 = pitch_pred[:, :, 0] + ret['f0_denorm'] = f0_denorm = denorm_f0(f0, None, hparams, pitch_padding=pitch_padding) + pitch = f0_to_coarse(f0_denorm) # start from 0 [B, T_txt] + pitch = F.pad(pitch, [1, 0]) + pitch = torch.gather(pitch, 1, mel2ph) # [B, T_mel] + pitch_embed = self.pitch_embed(pitch) + return pitch_embed + decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach()) + + pitch_padding = mel2ph == 0 + + if hparams['pitch_type'] == 'cwt': + pitch_padding = None + ret['cwt'] = cwt_out = self.cwt_predictor(decoder_inp) + stats_out = self.cwt_stats_layers(encoder_out[:, 0, :]) # [B, 2] + mean = ret['f0_mean'] = stats_out[:, 0] + std = ret['f0_std'] = stats_out[:, 1] + cwt_spec = cwt_out[:, :, :10] + if f0 is None: + std = std * hparams['cwt_std_scale'] + f0 = self.cwt2f0_norm(cwt_spec, mean, std, mel2ph) + if hparams['use_uv']: + assert cwt_out.shape[-1] == 11 + uv = cwt_out[:, :, -1] > 0 + elif hparams['pitch_ar']: + ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None) + if f0 is None: + f0 = pitch_pred[:, :, 0] + else: + ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp) + if f0 is None: + f0 = pitch_pred[:, :, 0] + if hparams['use_uv'] and uv is None: + uv = pitch_pred[:, :, 1] > 0 + ret['f0_denorm'] = f0_denorm = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding) + if pitch_padding is not None: + f0[pitch_padding] = 0 + + pitch = f0_to_coarse(f0_denorm) # start from 0 + pitch_embed = self.pitch_embed(pitch) + return pitch_embed + + def run_decoder(self, decoder_inp, tgt_nonpadding, ret, infer, **kwargs): + x = decoder_inp # [B, T, H] + x = self.decoder(x) + x = self.mel_out(x) + return x * tgt_nonpadding + + def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph): + f0 = cwt2f0(cwt_spec, mean, std, hparams['cwt_scales']) + f0 = torch.cat( + [f0] + [f0[:, -1:]] * (mel2ph.shape[1] - f0.shape[1]), 1) + f0_norm = norm_f0(f0, None, hparams) + return f0_norm + + def out2mel(self, out): + return out + + @staticmethod + def mel_norm(x): + return (x + 5.5) / (6.3 / 2) - 1 + + @staticmethod + def mel_denorm(x): + return (x + 1) * (6.3 / 2) - 5.5 diff --git a/modules/fastspeech/pe.py b/modules/fastspeech/pe.py new file mode 100644 index 0000000000000000000000000000000000000000..d9fa5098b378bb4ed10f97f05a9ff725d1d2239c --- /dev/null +++ b/modules/fastspeech/pe.py @@ -0,0 +1,149 @@ +from modules.commons.common_layers import * +from utils.hparams import hparams +from modules.fastspeech.tts_modules import PitchPredictor +from utils.pitch_utils import denorm_f0 + + +class Prenet(nn.Module): + def __init__(self, in_dim=80, out_dim=256, kernel=5, n_layers=3, strides=None): + super(Prenet, self).__init__() + 
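+        # A stack of Conv1d -> ReLU -> BatchNorm1d blocks over mel input
+        # [B, T, 80]; forward() returns both the per-layer hidden states
+        # [L, B, T, H] and the projected top-layer features [B, T, H].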
padding = kernel // 2 + self.layers = [] + self.strides = strides if strides is not None else [1] * n_layers + for l in range(n_layers): + self.layers.append(nn.Sequential( + nn.Conv1d(in_dim, out_dim, kernel_size=kernel, padding=padding, stride=self.strides[l]), + nn.ReLU(), + nn.BatchNorm1d(out_dim) + )) + in_dim = out_dim + self.layers = nn.ModuleList(self.layers) + self.out_proj = nn.Linear(out_dim, out_dim) + + def forward(self, x): + """ + + :param x: [B, T, 80] + :return: [L, B, T, H], [B, T, H] + """ + padding_mask = x.abs().sum(-1).eq(0).data # [B, T] + nonpadding_mask_TB = 1 - padding_mask.float()[:, None, :] # [B, 1, T] + x = x.transpose(1, 2) + hiddens = [] + for i, l in enumerate(self.layers): + nonpadding_mask_TB = nonpadding_mask_TB[:, :, ::self.strides[i]] + x = l(x) * nonpadding_mask_TB + hiddens.append(x) + hiddens = torch.stack(hiddens, 0) # [L, B, H, T] + hiddens = hiddens.transpose(2, 3) # [L, B, T, H] + x = self.out_proj(x.transpose(1, 2)) # [B, T, H] + x = x * nonpadding_mask_TB.transpose(1, 2) + return hiddens, x + + +class ConvBlock(nn.Module): + def __init__(self, idim=80, n_chans=256, kernel_size=3, stride=1, norm='gn', dropout=0): + super().__init__() + self.conv = ConvNorm(idim, n_chans, kernel_size, stride=stride) + self.norm = norm + if self.norm == 'bn': + self.norm = nn.BatchNorm1d(n_chans) + elif self.norm == 'in': + self.norm = nn.InstanceNorm1d(n_chans, affine=True) + elif self.norm == 'gn': + self.norm = nn.GroupNorm(n_chans // 16, n_chans) + elif self.norm == 'ln': + self.norm = LayerNorm(n_chans // 16, n_chans) + elif self.norm == 'wn': + self.conv = torch.nn.utils.weight_norm(self.conv.conv) + self.dropout = nn.Dropout(dropout) + self.relu = nn.ReLU() + + def forward(self, x): + """ + + :param x: [B, C, T] + :return: [B, C, T] + """ + x = self.conv(x) + if not isinstance(self.norm, str): + if self.norm == 'none': + pass + elif self.norm == 'ln': + x = self.norm(x.transpose(1, 2)).transpose(1, 2) + else: + x = self.norm(x) + x = self.relu(x) + x = self.dropout(x) + return x + + +class ConvStacks(nn.Module): + def __init__(self, idim=80, n_layers=5, n_chans=256, odim=32, kernel_size=5, norm='gn', + dropout=0, strides=None, res=True): + super().__init__() + self.conv = torch.nn.ModuleList() + self.kernel_size = kernel_size + self.res = res + self.in_proj = Linear(idim, n_chans) + if strides is None: + strides = [1] * n_layers + else: + assert len(strides) == n_layers + for idx in range(n_layers): + self.conv.append(ConvBlock( + n_chans, n_chans, kernel_size, stride=strides[idx], norm=norm, dropout=dropout)) + self.out_proj = Linear(n_chans, odim) + + def forward(self, x, return_hiddens=False): + """ + + :param x: [B, T, H] + :return: [B, T, H] + """ + x = self.in_proj(x) + x = x.transpose(1, -1) # (B, idim, Tmax) + hiddens = [] + for f in self.conv: + x_ = f(x) + x = x + x_ if self.res else x_ # (B, C, Tmax) + hiddens.append(x) + x = x.transpose(1, -1) + x = self.out_proj(x) # (B, Tmax, H) + if return_hiddens: + hiddens = torch.stack(hiddens, 1) # [B, L, C, T] + return x, hiddens + return x + + +class PitchExtractor(nn.Module): + def __init__(self, n_mel_bins=80, conv_layers=2): + super().__init__() + self.hidden_size = hparams['hidden_size'] + self.predictor_hidden = hparams['predictor_hidden'] if hparams['predictor_hidden'] > 0 else self.hidden_size + self.conv_layers = conv_layers + + self.mel_prenet = Prenet(n_mel_bins, self.hidden_size, strides=[1, 1, 1]) + if self.conv_layers > 0: + self.mel_encoder = ConvStacks( + idim=self.hidden_size, 
n_chans=self.hidden_size, odim=self.hidden_size, n_layers=self.conv_layers) + self.pitch_predictor = PitchPredictor( + self.hidden_size, n_chans=self.predictor_hidden, + n_layers=5, dropout_rate=0.1, odim=2, + padding=hparams['ffn_padding'], kernel_size=hparams['predictor_kernel']) + + def forward(self, mel_input=None): + ret = {} + mel_hidden = self.mel_prenet(mel_input)[1] + if self.conv_layers > 0: + mel_hidden = self.mel_encoder(mel_hidden) + + ret['pitch_pred'] = pitch_pred = self.pitch_predictor(mel_hidden) + + pitch_padding = mel_input.abs().sum(-1) == 0 + use_uv = hparams['pitch_type'] == 'frame' and hparams['use_uv'] + + ret['f0_denorm_pred'] = denorm_f0( + pitch_pred[:, :, 0], (pitch_pred[:, :, 1] > 0) if use_uv else None, + hparams, pitch_padding=pitch_padding) + return ret \ No newline at end of file diff --git a/modules/fastspeech/tts_modules.py b/modules/fastspeech/tts_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..195eff279de781dd2565cfb2da65533c58f6c332 --- /dev/null +++ b/modules/fastspeech/tts_modules.py @@ -0,0 +1,357 @@ +import logging +import math + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from modules.commons.espnet_positional_embedding import RelPositionalEncoding +from modules.commons.common_layers import SinusoidalPositionalEmbedding, Linear, EncSALayer, DecSALayer, BatchNorm1dTBC +from utils.hparams import hparams + +DEFAULT_MAX_SOURCE_POSITIONS = 2000 +DEFAULT_MAX_TARGET_POSITIONS = 2000 + + +class TransformerEncoderLayer(nn.Module): + def __init__(self, hidden_size, dropout, kernel_size=None, num_heads=2, norm='ln'): + super().__init__() + self.hidden_size = hidden_size + self.dropout = dropout + self.num_heads = num_heads + self.op = EncSALayer( + hidden_size, num_heads, dropout=dropout, + attention_dropout=0.0, relu_dropout=dropout, + kernel_size=kernel_size + if kernel_size is not None else hparams['enc_ffn_kernel_size'], + padding=hparams['ffn_padding'], + norm=norm, act=hparams['ffn_act']) + + def forward(self, x, **kwargs): + return self.op(x, **kwargs) + + +###################### +# fastspeech modules +###################### +class LayerNorm(torch.nn.LayerNorm): + """Layer normalization module. + :param int nout: output dim size + :param int dim: dimension to be normalized + """ + + def __init__(self, nout, dim=-1): + """Construct an LayerNorm object.""" + super(LayerNorm, self).__init__(nout, eps=1e-12) + self.dim = dim + + def forward(self, x): + """Apply layer normalization. + :param torch.Tensor x: input tensor + :return: layer normalized tensor + :rtype torch.Tensor + """ + if self.dim == -1: + return super(LayerNorm, self).forward(x) + return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1) + + +class DurationPredictor(torch.nn.Module): + """Duration predictor module. + This is a module of duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_. + The duration predictor predicts a duration of each frame in log domain from the hidden embeddings of encoder. + .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: + https://arxiv.org/pdf/1905.09263.pdf + Note: + The calculation domain of outputs is different between in `forward` and in `inference`. In `forward`, + the outputs are calculated in log domain but in `inference`, those are calculated in linear domain. 
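+    For example, a raw prediction p corresponds to round(exp(p) - offset) frames
+    at inference time, clamped to be non-negative.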
+ """ + + def __init__(self, idim, n_layers=2, n_chans=384, kernel_size=3, dropout_rate=0.1, offset=1.0, padding='SAME'): + """Initilize duration predictor module. + Args: + idim (int): Input dimension. + n_layers (int, optional): Number of convolutional layers. + n_chans (int, optional): Number of channels of convolutional layers. + kernel_size (int, optional): Kernel size of convolutional layers. + dropout_rate (float, optional): Dropout rate. + offset (float, optional): Offset value to avoid nan in log domain. + """ + super(DurationPredictor, self).__init__() + self.offset = offset + self.conv = torch.nn.ModuleList() + self.kernel_size = kernel_size + self.padding = padding + for idx in range(n_layers): + in_chans = idim if idx == 0 else n_chans + self.conv += [torch.nn.Sequential( + torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2) + if padding == 'SAME' + else (kernel_size - 1, 0), 0), + torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0), + torch.nn.ReLU(), + LayerNorm(n_chans, dim=1), + torch.nn.Dropout(dropout_rate) + )] + if hparams['dur_loss'] in ['mse', 'huber']: + odims = 1 + elif hparams['dur_loss'] == 'mog': + odims = 15 + elif hparams['dur_loss'] == 'crf': + odims = 32 + from torchcrf import CRF + self.crf = CRF(odims, batch_first=True) + self.linear = torch.nn.Linear(n_chans, odims) + + def _forward(self, xs, x_masks=None, is_inference=False): + xs = xs.transpose(1, -1) # (B, idim, Tmax) + for f in self.conv: + xs = f(xs) # (B, C, Tmax) + if x_masks is not None: + xs = xs * (1 - x_masks.float())[:, None, :] + + xs = self.linear(xs.transpose(1, -1)) # [B, T, C] + xs = xs * (1 - x_masks.float())[:, :, None] # (B, T, C) + if is_inference: + return self.out2dur(xs), xs + else: + if hparams['dur_loss'] in ['mse']: + xs = xs.squeeze(-1) # (B, Tmax) + return xs + + def out2dur(self, xs): + if hparams['dur_loss'] in ['mse']: + # NOTE: calculate in log domain + xs = xs.squeeze(-1) # (B, Tmax) + dur = torch.clamp(torch.round(xs.exp() - self.offset), min=0).long() # avoid negative value + elif hparams['dur_loss'] == 'mog': + return NotImplementedError + elif hparams['dur_loss'] == 'crf': + dur = torch.LongTensor(self.crf.decode(xs)).cuda() + return dur + + def forward(self, xs, x_masks=None): + """Calculate forward propagation. + Args: + xs (Tensor): Batch of input sequences (B, Tmax, idim). + x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax). + Returns: + Tensor: Batch of predicted durations in log domain (B, Tmax). + """ + return self._forward(xs, x_masks, False) + + def inference(self, xs, x_masks=None): + """Inference duration. + Args: + xs (Tensor): Batch of input sequences (B, Tmax, idim). + x_masks (ByteTensor, optional): Batch of masks indicating padded part (B, Tmax). + Returns: + LongTensor: Batch of predicted durations in linear domain (B, Tmax). + """ + return self._forward(xs, x_masks, True) + + +class LengthRegulator(torch.nn.Module): + def __init__(self, pad_value=0.0): + super(LengthRegulator, self).__init__() + self.pad_value = pad_value + + def forward(self, dur, dur_padding=None, alpha=1.0): + """ + Example (no batch dim version): + 1. dur = [2,2,3] + 2. token_idx = [[1],[2],[3]], dur_cumsum = [2,4,7], dur_cumsum_prev = [0,2,4] + 3. token_mask = [[1,1,0,0,0,0,0], + [0,0,1,1,0,0,0], + [0,0,0,0,1,1,1]] + 4. token_idx * token_mask = [[1,1,0,0,0,0,0], + [0,0,2,2,0,0,0], + [0,0,0,0,3,3,3]] + 5. 
(token_idx * token_mask).sum(0) = [1,1,2,2,3,3,3] + + :param dur: Batch of durations of each frame (B, T_txt) + :param dur_padding: Batch of padding of each frame (B, T_txt) + :param alpha: duration rescale coefficient + :return: + mel2ph (B, T_speech) + """ + assert alpha > 0 + dur = torch.round(dur.float() * alpha).long() + if dur_padding is not None: + dur = dur * (1 - dur_padding.long()) + token_idx = torch.arange(1, dur.shape[1] + 1)[None, :, None].to(dur.device) + dur_cumsum = torch.cumsum(dur, 1) + dur_cumsum_prev = F.pad(dur_cumsum, [1, -1], mode='constant', value=0) + + pos_idx = torch.arange(dur.sum(-1).max())[None, None].to(dur.device) + token_mask = (pos_idx >= dur_cumsum_prev[:, :, None]) & (pos_idx < dur_cumsum[:, :, None]) + mel2ph = (token_idx * token_mask.long()).sum(1) + return mel2ph + + +class PitchPredictor(torch.nn.Module): + def __init__(self, idim, n_layers=5, n_chans=384, odim=2, kernel_size=5, + dropout_rate=0.1, padding='SAME'): + """Initilize pitch predictor module. + Args: + idim (int): Input dimension. + n_layers (int, optional): Number of convolutional layers. + n_chans (int, optional): Number of channels of convolutional layers. + kernel_size (int, optional): Kernel size of convolutional layers. + dropout_rate (float, optional): Dropout rate. + """ + super(PitchPredictor, self).__init__() + self.conv = torch.nn.ModuleList() + self.kernel_size = kernel_size + self.padding = padding + for idx in range(n_layers): + in_chans = idim if idx == 0 else n_chans + self.conv += [torch.nn.Sequential( + torch.nn.ConstantPad1d(((kernel_size - 1) // 2, (kernel_size - 1) // 2) + if padding == 'SAME' + else (kernel_size - 1, 0), 0), + torch.nn.Conv1d(in_chans, n_chans, kernel_size, stride=1, padding=0), + torch.nn.ReLU(), + LayerNorm(n_chans, dim=1), + torch.nn.Dropout(dropout_rate) + )] + self.linear = torch.nn.Linear(n_chans, odim) + self.embed_positions = SinusoidalPositionalEmbedding(idim, 0, init_size=4096) + self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) + + def forward(self, xs): + """ + + :param xs: [B, T, H] + :return: [B, T, H] + """ + positions = self.pos_embed_alpha * self.embed_positions(xs[..., 0]) + xs = xs + positions + xs = xs.transpose(1, -1) # (B, idim, Tmax) + for f in self.conv: + xs = f(xs) # (B, C, Tmax) + # NOTE: calculate in log domain + xs = self.linear(xs.transpose(1, -1)) # (B, Tmax, H) + return xs + + +class EnergyPredictor(PitchPredictor): + pass + + +def mel2ph_to_dur(mel2ph, T_txt, max_dur=None): + B, _ = mel2ph.shape + dur = mel2ph.new_zeros(B, T_txt + 1).scatter_add(1, mel2ph, torch.ones_like(mel2ph)) + dur = dur[:, 1:] + if max_dur is not None: + dur = dur.clamp(max=max_dur) + return dur + + +class FFTBlocks(nn.Module): + def __init__(self, hidden_size, num_layers, ffn_kernel_size=9, dropout=None, num_heads=2, + use_pos_embed=True, use_last_norm=True, norm='ln', use_pos_embed_alpha=True): + super().__init__() + self.num_layers = num_layers + embed_dim = self.hidden_size = hidden_size + self.dropout = dropout if dropout is not None else hparams['dropout'] + self.use_pos_embed = use_pos_embed + self.use_last_norm = use_last_norm + if use_pos_embed: + self.max_source_positions = DEFAULT_MAX_TARGET_POSITIONS + self.padding_idx = 0 + self.pos_embed_alpha = nn.Parameter(torch.Tensor([1])) if use_pos_embed_alpha else 1 + self.embed_positions = SinusoidalPositionalEmbedding( + embed_dim, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS, + ) + + self.layers = nn.ModuleList([]) + self.layers.extend([ + 
TransformerEncoderLayer(self.hidden_size, self.dropout, + kernel_size=ffn_kernel_size, num_heads=num_heads) + for _ in range(self.num_layers) + ]) + if self.use_last_norm: + if norm == 'ln': + self.layer_norm = nn.LayerNorm(embed_dim) + elif norm == 'bn': + self.layer_norm = BatchNorm1dTBC(embed_dim) + else: + self.layer_norm = None + + def forward(self, x, padding_mask=None, attn_mask=None, return_hiddens=False): + """ + :param x: [B, T, C] + :param padding_mask: [B, T] + :return: [B, T, C] or [L, B, T, C] + """ + padding_mask = x.abs().sum(-1).eq(0).data if padding_mask is None else padding_mask + nonpadding_mask_TB = 1 - padding_mask.transpose(0, 1).float()[:, :, None] # [T, B, 1] + if self.use_pos_embed: + positions = self.pos_embed_alpha * self.embed_positions(x[..., 0]) + x = x + positions + x = F.dropout(x, p=self.dropout, training=self.training) + # B x T x C -> T x B x C + x = x.transpose(0, 1) * nonpadding_mask_TB + hiddens = [] + for layer in self.layers: + x = layer(x, encoder_padding_mask=padding_mask, attn_mask=attn_mask) * nonpadding_mask_TB + hiddens.append(x) + if self.use_last_norm: + x = self.layer_norm(x) * nonpadding_mask_TB + if return_hiddens: + x = torch.stack(hiddens, 0) # [L, T, B, C] + x = x.transpose(1, 2) # [L, B, T, C] + else: + x = x.transpose(0, 1) # [B, T, C] + return x + + +class FastspeechEncoder(FFTBlocks): + def __init__(self, embed_tokens, hidden_size=None, num_layers=None, kernel_size=None, num_heads=2): + hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size + kernel_size = hparams['enc_ffn_kernel_size'] if kernel_size is None else kernel_size + num_layers = hparams['dec_layers'] if num_layers is None else num_layers + super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads, + use_pos_embed=False) # use_pos_embed_alpha for compatibility + self.embed_tokens = embed_tokens + self.embed_scale = math.sqrt(hidden_size) + self.padding_idx = 0 + if hparams.get('rel_pos') is not None and hparams['rel_pos']: + self.embed_positions = RelPositionalEncoding(hidden_size, dropout_rate=0.0) + else: + self.embed_positions = SinusoidalPositionalEmbedding( + hidden_size, self.padding_idx, init_size=DEFAULT_MAX_TARGET_POSITIONS, + ) + + def forward(self, txt_tokens): + """ + + :param txt_tokens: [B, T] + :return: { + 'encoder_out': [T x B x C] + } + """ + encoder_padding_mask = txt_tokens.eq(self.padding_idx).data + x = self.forward_embedding(txt_tokens) # [B, T, H] + x = super(FastspeechEncoder, self).forward(x, encoder_padding_mask) + return x + + def forward_embedding(self, txt_tokens): + # embed tokens and positions + x = self.embed_scale * self.embed_tokens(txt_tokens) + if hparams['use_pos_embed']: + positions = self.embed_positions(txt_tokens) + x = x + positions + x = F.dropout(x, p=self.dropout, training=self.training) + return x + + +class FastspeechDecoder(FFTBlocks): + def __init__(self, hidden_size=None, num_layers=None, kernel_size=None, num_heads=None): + num_heads = hparams['num_heads'] if num_heads is None else num_heads + hidden_size = hparams['hidden_size'] if hidden_size is None else hidden_size + kernel_size = hparams['dec_ffn_kernel_size'] if kernel_size is None else kernel_size + num_layers = hparams['dec_layers'] if num_layers is None else num_layers + super().__init__(hidden_size, num_layers, kernel_size, num_heads=num_heads) + diff --git a/modules/hifigan/hifigan.py b/modules/hifigan/hifigan.py new file mode 100644 index 0000000000000000000000000000000000000000..ae7e61f56b00d60bcc49a18ece3edbe54746f7ea 
--- /dev/null +++ b/modules/hifigan/hifigan.py @@ -0,0 +1,365 @@ +import torch +import torch.nn.functional as F +import torch.nn as nn +from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d +from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm + +from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork +from modules.parallel_wavegan.models.source import SourceModuleHnNSF +import numpy as np + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def apply_weight_norm(m): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + weight_norm(m) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock1(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]))) + ]) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1, + padding=get_padding(kernel_size, 1))) + ]) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class ResBlock2(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)): + super(ResBlock2, self).__init__() + self.h = h + self.convs = nn.ModuleList([ + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]))), + weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]))) + ]) + self.convs.apply(init_weights) + + def forward(self, x): + for c in self.convs: + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs: + remove_weight_norm(l) + + +class Conv1d1x1(Conv1d): + """1x1 Conv1d with customized initialization.""" + + def __init__(self, in_channels, out_channels, bias): + """Initialize 1x1 Conv1d module.""" + super(Conv1d1x1, self).__init__(in_channels, out_channels, + kernel_size=1, padding=0, + dilation=1, bias=bias) + + +class HifiGanGenerator(torch.nn.Module): + def __init__(self, h, c_out=1): + super(HifiGanGenerator, self).__init__() + self.h = h + self.num_kernels = len(h['resblock_kernel_sizes']) + self.num_upsamples = len(h['upsample_rates']) + + if h['use_pitch_embed']: + self.harmonic_num = 8 + self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates'])) + self.m_source = 
SourceModuleHnNSF( + sampling_rate=h['audio_sample_rate'], + harmonic_num=self.harmonic_num) + self.noise_convs = nn.ModuleList() + self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3)) + resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2 + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])): + c_cur = h['upsample_initial_channel'] // (2 ** (i + 1)) + self.ups.append(weight_norm( + ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2))) + if h['use_pitch_embed']: + if i + 1 < len(h['upsample_rates']): + stride_f0 = np.prod(h['upsample_rates'][i + 1:]) + self.noise_convs.append(Conv1d( + 1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2)) + else: + self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1)) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h['upsample_initial_channel'] // (2 ** (i + 1)) + for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x, f0=None): + if f0 is not None: + # harmonic-source signal, noise-source signal, uv flag + f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) + har_source, noi_source, uv = self.m_source(f0) + har_source = har_source.transpose(1, 2) + + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + if f0 is not None: + x_source = self.noise_convs[i](har_source) + x = x + x_source + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) + + +class DiscriminatorP(torch.nn.Module): + def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1): + super(DiscriminatorP, self).__init__() + self.use_cond = use_cond + if use_cond: + from utils.hparams import hparams + t = hparams['hop_size'] + self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2) + c_in = 2 + + self.period = period + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x, mel): + fmap = [] + if self.use_cond: + x_mel = self.cond_net(mel) + x = torch.cat([x_mel, x], 1) + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), "reflect") + t = t + n_pad + x = 
x.view(b, c, t // self.period, self.period) + + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + def __init__(self, use_cond=False, c_in=1): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(2, use_cond=use_cond, c_in=c_in), + DiscriminatorP(3, use_cond=use_cond, c_in=c_in), + DiscriminatorP(5, use_cond=use_cond, c_in=c_in), + DiscriminatorP(7, use_cond=use_cond, c_in=c_in), + DiscriminatorP(11, use_cond=use_cond, c_in=c_in), + ]) + + def forward(self, y, y_hat, mel=None): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y, mel) + y_d_g, fmap_g = d(y_hat, mel) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1): + super(DiscriminatorS, self).__init__() + self.use_cond = use_cond + if use_cond: + t = np.prod(upsample_rates) + self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2) + c_in = 2 + norm_f = weight_norm if use_spectral_norm == False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(c_in, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x, mel): + if self.use_cond: + x_mel = self.cond_net(mel) + x = torch.cat([x_mel, x], 1) + fmap = [] + for l in self.convs: + x = l(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + def __init__(self, use_cond=False, c_in=1): + super(MultiScaleDiscriminator, self).__init__() + from utils.hparams import hparams + self.discriminators = nn.ModuleList([ + DiscriminatorS(use_spectral_norm=True, use_cond=use_cond, + upsample_rates=[4, 4, hparams['hop_size'] // 16], + c_in=c_in), + DiscriminatorS(use_cond=use_cond, + upsample_rates=[4, 4, hparams['hop_size'] // 32], + c_in=c_in), + DiscriminatorS(use_cond=use_cond, + upsample_rates=[4, 4, hparams['hop_size'] // 64], + c_in=c_in), + ]) + self.meanpools = nn.ModuleList([ + AvgPool1d(4, 2, padding=1), + AvgPool1d(4, 2, padding=1) + ]) + + def forward(self, y, y_hat, mel=None): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + y = self.meanpools[i - 1](y) + y_hat = self.meanpools[i - 1](y_hat) + y_d_r, fmap_r = d(y, mel) + y_d_g, fmap_g = d(y_hat, mel) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): 
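+    """LSGAN discriminator loss: the real-side term mean((1 - D(y))^2) and the
+    fake-side term mean(D(y_hat)^2), each averaged over the list of
+    sub-discriminator outputs.
+    """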
+ r_losses = 0 + g_losses = 0 + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr) ** 2) + g_loss = torch.mean(dg ** 2) + r_losses += r_loss + g_losses += g_loss + r_losses = r_losses / len(disc_real_outputs) + g_losses = g_losses / len(disc_real_outputs) + return r_losses, g_losses + + +def cond_discriminator_loss(outputs): + loss = 0 + for dg in outputs: + g_loss = torch.mean(dg ** 2) + loss += g_loss + loss = loss / len(outputs) + return loss + + +def generator_loss(disc_outputs): + loss = 0 + for dg in disc_outputs: + l = torch.mean((1 - dg) ** 2) + loss += l + loss = loss / len(disc_outputs) + return loss diff --git a/modules/hifigan/mel_utils.py b/modules/hifigan/mel_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..06e0f7d4d16fa3e4aefc8949347455f5a6e938da --- /dev/null +++ b/modules/hifigan/mel_utils.py @@ -0,0 +1,80 @@ +import numpy as np +import torch +import torch.utils.data +from librosa.filters import mel as librosa_mel_fn +from scipy.io.wavfile import read + +MAX_WAV_VALUE = 32768.0 + + +def load_wav(full_path): + sampling_rate, data = read(full_path) + return data, sampling_rate + + +def dynamic_range_compression(x, C=1, clip_val=1e-5): + return np.log(np.clip(x, a_min=clip_val, a_max=None) * C) + + +def dynamic_range_decompression(x, C=1): + return np.exp(x) / C + + +def dynamic_range_compression_torch(x, C=1, clip_val=1e-5): + return torch.log(torch.clamp(x, min=clip_val) * C) + + +def dynamic_range_decompression_torch(x, C=1): + return torch.exp(x) / C + + +def spectral_normalize_torch(magnitudes): + output = dynamic_range_compression_torch(magnitudes) + return output + + +def spectral_de_normalize_torch(magnitudes): + output = dynamic_range_decompression_torch(magnitudes) + return output + + +mel_basis = {} +hann_window = {} + + +def mel_spectrogram(y, hparams, center=False, complex=False): + # hop_size: 512 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) + # win_size: 2048 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate) + # fmin: 55 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) + # fmax: 10000 # To be increased/reduced depending on data. + # fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter + # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, + n_fft = hparams['fft_size'] + num_mels = hparams['audio_num_mel_bins'] + sampling_rate = hparams['audio_sample_rate'] + hop_size = hparams['hop_size'] + win_size = hparams['win_size'] + fmin = hparams['fmin'] + fmax = hparams['fmax'] + y = y.clamp(min=-1., max=1.) 
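+    # The mel filterbank and Hann window are cached per (fmax, device) key so
+    # repeated calls reuse them instead of rebuilding librosa's mel matrix.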
+    global mel_basis, hann_window
+    if str(fmax) + '_' + str(y.device) not in mel_basis:
+        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
+        mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+                                mode='reflect')
+    y = y.squeeze(1)
+
+    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+                      center=center, pad_mode='reflect', normalized=False, onesided=True)
+
+    if not complex:
+        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+        spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
+        spec = spectral_normalize_torch(spec)
+    else:
+        B, C, T, _ = spec.shape
+        spec = spec.transpose(1, 2)  # [B, T, n_fft, 2]
+    return spec
diff --git a/modules/parallel_wavegan/__init__.py b/modules/parallel_wavegan/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/modules/parallel_wavegan/layers/__init__.py b/modules/parallel_wavegan/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e477f51116a3157781b1aefefbaf32fe4d4bd1f0
--- /dev/null
+++ b/modules/parallel_wavegan/layers/__init__.py
@@ -0,0 +1,5 @@
+from .causal_conv import * # NOQA
+from .pqmf import * # NOQA
+from .residual_block import * # NOQA
+from .residual_stack import * # NOQA
+from .upsample import * # NOQA
diff --git a/modules/parallel_wavegan/layers/causal_conv.py b/modules/parallel_wavegan/layers/causal_conv.py
new file mode 100644
index 0000000000000000000000000000000000000000..fca77daf65f234e6fbe355ed148fc8f0ee85038a
--- /dev/null
+++ b/modules/parallel_wavegan/layers/causal_conv.py
@@ -0,0 +1,56 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 Tomoki Hayashi
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""Causal convolution layer modules."""
+
+
+import torch
+
+
+class CausalConv1d(torch.nn.Module):
+    """CausalConv1d module with customized initialization."""
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 dilation=1, bias=True, pad="ConstantPad1d", pad_params={"value": 0.0}):
+        """Initialize CausalConv1d module."""
+        super(CausalConv1d, self).__init__()
+        self.pad = getattr(torch.nn, pad)((kernel_size - 1) * dilation, **pad_params)
+        self.conv = torch.nn.Conv1d(in_channels, out_channels, kernel_size,
+                                    dilation=dilation, bias=bias)
+
+    def forward(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, in_channels, T).
+
+        Returns:
+            Tensor: Output tensor (B, out_channels, T).
+
+        """
+        return self.conv(self.pad(x))[:, :, :x.size(2)]
+
+
+class CausalConvTranspose1d(torch.nn.Module):
+    """CausalConvTranspose1d module with customized initialization."""
+
+    def __init__(self, in_channels, out_channels, kernel_size, stride, bias=True):
+        """Initialize CausalConvTranspose1d module."""
+        super(CausalConvTranspose1d, self).__init__()
+        self.deconv = torch.nn.ConvTranspose1d(
+            in_channels, out_channels, kernel_size, stride, bias=bias)
+        self.stride = stride
+
+    def forward(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, in_channels, T_in).
+
+        Returns:
+            Tensor: Output tensor (B, out_channels, T_out).
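+            (With kernel_size == 2 * stride, T_out = T_in * stride, since the
+            trailing `stride` samples are trimmed to keep the layer causal.)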
+ + """ + return self.deconv(x)[:, :, :-self.stride] diff --git a/modules/parallel_wavegan/layers/pqmf.py b/modules/parallel_wavegan/layers/pqmf.py new file mode 100644 index 0000000000000000000000000000000000000000..ac21074fd32a370a099fa2facb62cfd3253d7579 --- /dev/null +++ b/modules/parallel_wavegan/layers/pqmf.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Pseudo QMF modules.""" + +import numpy as np +import torch +import torch.nn.functional as F + +from scipy.signal import kaiser + + +def design_prototype_filter(taps=62, cutoff_ratio=0.15, beta=9.0): + """Design prototype filter for PQMF. + + This method is based on `A Kaiser window approach for the design of prototype + filters of cosine modulated filterbanks`_. + + Args: + taps (int): The number of filter taps. + cutoff_ratio (float): Cut-off frequency ratio. + beta (float): Beta coefficient for kaiser window. + + Returns: + ndarray: Impluse response of prototype filter (taps + 1,). + + .. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`: + https://ieeexplore.ieee.org/abstract/document/681427 + + """ + # check the arguments are valid + assert taps % 2 == 0, "The number of taps mush be even number." + assert 0.0 < cutoff_ratio < 1.0, "Cutoff ratio must be > 0.0 and < 1.0." + + # make initial filter + omega_c = np.pi * cutoff_ratio + with np.errstate(invalid='ignore'): + h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) \ + / (np.pi * (np.arange(taps + 1) - 0.5 * taps)) + h_i[taps // 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form + + # apply kaiser window + w = kaiser(taps + 1, beta) + h = h_i * w + + return h + + +class PQMF(torch.nn.Module): + """PQMF module. + + This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_. + + .. _`Near-perfect-reconstruction pseudo-QMF banks`: + https://ieeexplore.ieee.org/document/258122 + + """ + + def __init__(self, subbands=4, taps=62, cutoff_ratio=0.15, beta=9.0): + """Initilize PQMF module. + + Args: + subbands (int): The number of subbands. + taps (int): The number of filter taps. + cutoff_ratio (float): Cut-off frequency ratio. + beta (float): Beta coefficient for kaiser window. 
+ + """ + super(PQMF, self).__init__() + + # define filter coefficient + h_proto = design_prototype_filter(taps, cutoff_ratio, beta) + h_analysis = np.zeros((subbands, len(h_proto))) + h_synthesis = np.zeros((subbands, len(h_proto))) + for k in range(subbands): + h_analysis[k] = 2 * h_proto * np.cos( + (2 * k + 1) * (np.pi / (2 * subbands)) * + (np.arange(taps + 1) - ((taps - 1) / 2)) + + (-1) ** k * np.pi / 4) + h_synthesis[k] = 2 * h_proto * np.cos( + (2 * k + 1) * (np.pi / (2 * subbands)) * + (np.arange(taps + 1) - ((taps - 1) / 2)) - + (-1) ** k * np.pi / 4) + + # convert to tensor + analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1) + synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0) + + # register coefficients as beffer + self.register_buffer("analysis_filter", analysis_filter) + self.register_buffer("synthesis_filter", synthesis_filter) + + # filter for downsampling & upsampling + updown_filter = torch.zeros((subbands, subbands, subbands)).float() + for k in range(subbands): + updown_filter[k, k, 0] = 1.0 + self.register_buffer("updown_filter", updown_filter) + self.subbands = subbands + + # keep padding info + self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0) + + def analysis(self, x): + """Analysis with PQMF. + + Args: + x (Tensor): Input tensor (B, 1, T). + + Returns: + Tensor: Output tensor (B, subbands, T // subbands). + + """ + x = F.conv1d(self.pad_fn(x), self.analysis_filter) + return F.conv1d(x, self.updown_filter, stride=self.subbands) + + def synthesis(self, x): + """Synthesis with PQMF. + + Args: + x (Tensor): Input tensor (B, subbands, T // subbands). + + Returns: + Tensor: Output tensor (B, 1, T). + + """ + x = F.conv_transpose1d(x, self.updown_filter * self.subbands, stride=self.subbands) + return F.conv1d(self.pad_fn(x), self.synthesis_filter) diff --git a/modules/parallel_wavegan/layers/residual_block.py b/modules/parallel_wavegan/layers/residual_block.py new file mode 100644 index 0000000000000000000000000000000000000000..7a267a86c1fa521c2824addf9dda304c43f1ff1f --- /dev/null +++ b/modules/parallel_wavegan/layers/residual_block.py @@ -0,0 +1,129 @@ +# -*- coding: utf-8 -*- + +"""Residual block module in WaveNet. + +This code is modified from https://github.com/r9y9/wavenet_vocoder. + +""" + +import math + +import torch +import torch.nn.functional as F + + +class Conv1d(torch.nn.Conv1d): + """Conv1d module with customized initialization.""" + + def __init__(self, *args, **kwargs): + """Initialize Conv1d module.""" + super(Conv1d, self).__init__(*args, **kwargs) + + def reset_parameters(self): + """Reset parameters.""" + torch.nn.init.kaiming_normal_(self.weight, nonlinearity="relu") + if self.bias is not None: + torch.nn.init.constant_(self.bias, 0.0) + + +class Conv1d1x1(Conv1d): + """1x1 Conv1d with customized initialization.""" + + def __init__(self, in_channels, out_channels, bias): + """Initialize 1x1 Conv1d module.""" + super(Conv1d1x1, self).__init__(in_channels, out_channels, + kernel_size=1, padding=0, + dilation=1, bias=bias) + + +class ResidualBlock(torch.nn.Module): + """Residual block module in WaveNet.""" + + def __init__(self, + kernel_size=3, + residual_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + dropout=0.0, + dilation=1, + bias=True, + use_causal_conv=False + ): + """Initialize ResidualBlock module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + residual_channels (int): Number of channels for residual connection. 
+ skip_channels (int): Number of channels for skip connection. + aux_channels (int): Local conditioning channels i.e. auxiliary input dimension. + dropout (float): Dropout probability. + dilation (int): Dilation factor. + bias (bool): Whether to add bias parameter in convolution layers. + use_causal_conv (bool): Whether to use use_causal_conv or non-use_causal_conv convolution. + + """ + super(ResidualBlock, self).__init__() + self.dropout = dropout + # no future time stamps available + if use_causal_conv: + padding = (kernel_size - 1) * dilation + else: + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." + padding = (kernel_size - 1) // 2 * dilation + self.use_causal_conv = use_causal_conv + + # dilation conv + self.conv = Conv1d(residual_channels, gate_channels, kernel_size, + padding=padding, dilation=dilation, bias=bias) + + # local conditioning + if aux_channels > 0: + self.conv1x1_aux = Conv1d1x1(aux_channels, gate_channels, bias=False) + else: + self.conv1x1_aux = None + + # conv output is split into two groups + gate_out_channels = gate_channels // 2 + self.conv1x1_out = Conv1d1x1(gate_out_channels, residual_channels, bias=bias) + self.conv1x1_skip = Conv1d1x1(gate_out_channels, skip_channels, bias=bias) + + def forward(self, x, c): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, residual_channels, T). + c (Tensor): Local conditioning auxiliary tensor (B, aux_channels, T). + + Returns: + Tensor: Output tensor for residual connection (B, residual_channels, T). + Tensor: Output tensor for skip connection (B, skip_channels, T). + + """ + residual = x + x = F.dropout(x, p=self.dropout, training=self.training) + x = self.conv(x) + + # remove future time steps if use_causal_conv conv + x = x[:, :, :residual.size(-1)] if self.use_causal_conv else x + + # split into two part for gated activation + splitdim = 1 + xa, xb = x.split(x.size(splitdim) // 2, dim=splitdim) + + # local conditioning + if c is not None: + assert self.conv1x1_aux is not None + c = self.conv1x1_aux(c) + ca, cb = c.split(c.size(splitdim) // 2, dim=splitdim) + xa, xb = xa + ca, xb + cb + + x = torch.tanh(xa) * torch.sigmoid(xb) + + # for skip connection + s = self.conv1x1_skip(x) + + # for residual connection + x = (self.conv1x1_out(x) + residual) * math.sqrt(0.5) + + return x, s diff --git a/modules/parallel_wavegan/layers/residual_stack.py b/modules/parallel_wavegan/layers/residual_stack.py new file mode 100644 index 0000000000000000000000000000000000000000..6e07c8803ad348dd923f6b7c0f7aff14aab9cf78 --- /dev/null +++ b/modules/parallel_wavegan/layers/residual_stack.py @@ -0,0 +1,75 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Residual stack module in MelGAN.""" + +import torch + +from . import CausalConv1d + + +class ResidualStack(torch.nn.Module): + """Residual stack module introduced in MelGAN.""" + + def __init__(self, + kernel_size=3, + channels=32, + dilation=1, + bias=True, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + pad="ReflectionPad1d", + pad_params={}, + use_causal_conv=False, + ): + """Initialize ResidualStack module. + + Args: + kernel_size (int): Kernel size of dilation convolution layer. + channels (int): Number of channels of convolution layers. + dilation (int): Dilation factor. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. 
+            nonlinear_activation_params (dict): Hyperparameters for activation function.
+            pad (str): Padding function module name before dilated convolution layer.
+            pad_params (dict): Hyperparameters for padding function.
+            use_causal_conv (bool): Whether to use causal convolution.
+
+        """
+        super(ResidualStack, self).__init__()
+
+        # define residual stack part
+        if not use_causal_conv:
+            assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
+            self.stack = torch.nn.Sequential(
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                getattr(torch.nn, pad)((kernel_size - 1) // 2 * dilation, **pad_params),
+                torch.nn.Conv1d(channels, channels, kernel_size, dilation=dilation, bias=bias),
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                torch.nn.Conv1d(channels, channels, 1, bias=bias),
+            )
+        else:
+            self.stack = torch.nn.Sequential(
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                CausalConv1d(channels, channels, kernel_size, dilation=dilation,
+                             bias=bias, pad=pad, pad_params=pad_params),
+                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
+                torch.nn.Conv1d(channels, channels, 1, bias=bias),
+            )
+
+        # define extra layer for skip connection
+        self.skip_layer = torch.nn.Conv1d(channels, channels, 1, bias=bias)
+
+    def forward(self, c):
+        """Calculate forward propagation.
+
+        Args:
+            c (Tensor): Input tensor (B, channels, T).
+
+        Returns:
+            Tensor: Output tensor (B, channels, T).
+
+        """
+        return self.stack(c) + self.skip_layer(c)
diff --git a/modules/parallel_wavegan/layers/tf_layers.py b/modules/parallel_wavegan/layers/tf_layers.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0f46bd755c161cda2ac904fe37f3f3c6357a88d
--- /dev/null
+++ b/modules/parallel_wavegan/layers/tf_layers.py
@@ -0,0 +1,129 @@
+# -*- coding: utf-8 -*-
+
+# Copyright 2020 MINH ANH (@dathudeptrai)
+# MIT License (https://opensource.org/licenses/MIT)
+
+"""Tensorflow Layer modules compatible with pytorch."""
+
+import tensorflow as tf
+
+
+class TFReflectionPad1d(tf.keras.layers.Layer):
+    """Tensorflow ReflectionPad1d module."""
+
+    def __init__(self, padding_size):
+        """Initialize TFReflectionPad1d module.
+
+        Args:
+            padding_size (int): Padding size.
+
+        """
+        super(TFReflectionPad1d, self).__init__()
+        self.padding_size = padding_size
+
+    @tf.function
+    def call(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, T, 1, C).
+
+        Returns:
+            Tensor: Padded tensor (B, T + 2 * padding_size, 1, C).
+
+        """
+        return tf.pad(x, [[0, 0], [self.padding_size, self.padding_size], [0, 0], [0, 0]], "REFLECT")
+
+
+class TFConvTranspose1d(tf.keras.layers.Layer):
+    """Tensorflow ConvTranspose1d module."""
+
+    def __init__(self, channels, kernel_size, stride, padding):
+        """Initialize TFConvTranspose1d module.
+
+        Args:
+            channels (int): Number of channels.
+            kernel_size (int): Kernel size.
+            stride (int): Stride width.
+            padding (str): Padding type ("same" or "valid").
+
+        """
+        super(TFConvTranspose1d, self).__init__()
+        self.conv1d_transpose = tf.keras.layers.Conv2DTranspose(
+            filters=channels,
+            kernel_size=(kernel_size, 1),
+            strides=(stride, 1),
+            padding=padding,
+        )
+
+    @tf.function
+    def call(self, x):
+        """Calculate forward propagation.
+
+        Args:
+            x (Tensor): Input tensor (B, T, 1, C).
+
+        Returns:
+            Tensor: Output tensor (B, T', 1, C').
+ + """ + x = self.conv1d_transpose(x) + return x + + +class TFResidualStack(tf.keras.layers.Layer): + """Tensorflow ResidualStack module.""" + + def __init__(self, + kernel_size, + channels, + dilation, + bias, + nonlinear_activation, + nonlinear_activation_params, + padding, + ): + """Initialize TFResidualStack module. + + Args: + kernel_size (int): Kernel size. + channles (int): Number of channels. + dilation (int): Dilation ine. + bias (bool): Whether to add bias parameter in convolution layers. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + padding (str): Padding type ("same" or "valid"). + + """ + super(TFResidualStack, self).__init__() + self.block = [ + getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params), + TFReflectionPad1d(dilation), + tf.keras.layers.Conv2D( + filters=channels, + kernel_size=(kernel_size, 1), + dilation_rate=(dilation, 1), + use_bias=bias, + padding="valid", + ), + getattr(tf.keras.layers, nonlinear_activation)(**nonlinear_activation_params), + tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias) + ] + self.shortcut = tf.keras.layers.Conv2D(filters=channels, kernel_size=1, use_bias=bias) + + @tf.function + def call(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, T, 1, C). + + Returns: + Tensor: Output tensor (B, T, 1, C). + + """ + _x = tf.identity(x) + for i, layer in enumerate(self.block): + _x = layer(_x) + shortcut = self.shortcut(x) + return shortcut + _x diff --git a/modules/parallel_wavegan/layers/upsample.py b/modules/parallel_wavegan/layers/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..18c6397c420a81fadc5320e3a48f3249534decd8 --- /dev/null +++ b/modules/parallel_wavegan/layers/upsample.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- + +"""Upsampling module. + +This code is modified from https://github.com/r9y9/wavenet_vocoder. + +""" + +import numpy as np +import torch +import torch.nn.functional as F + +from . import Conv1d + + +class Stretch2d(torch.nn.Module): + """Stretch2d module.""" + + def __init__(self, x_scale, y_scale, mode="nearest"): + """Initialize Stretch2d module. + + Args: + x_scale (int): X scaling factor (Time axis in spectrogram). + y_scale (int): Y scaling factor (Frequency axis in spectrogram). + mode (str): Interpolation mode. + + """ + super(Stretch2d, self).__init__() + self.x_scale = x_scale + self.y_scale = y_scale + self.mode = mode + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input tensor (B, C, F, T). + + Returns: + Tensor: Interpolated tensor (B, C, F * y_scale, T * x_scale), + + """ + return F.interpolate( + x, scale_factor=(self.y_scale, self.x_scale), mode=self.mode) + + +class Conv2d(torch.nn.Conv2d): + """Conv2d module with customized initialization.""" + + def __init__(self, *args, **kwargs): + """Initialize Conv2d module.""" + super(Conv2d, self).__init__(*args, **kwargs) + + def reset_parameters(self): + """Reset parameters.""" + self.weight.data.fill_(1. / np.prod(self.kernel_size)) + if self.bias is not None: + torch.nn.init.constant_(self.bias, 0.0) + + +class UpsampleNetwork(torch.nn.Module): + """Upsampling network module.""" + + def __init__(self, + upsample_scales, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + use_causal_conv=False, + ): + """Initialize upsampling network module. 
+ + Args: + upsample_scales (list): List of upsampling scales. + nonlinear_activation (str): Activation function name. + nonlinear_activation_params (dict): Arguments for specified activation function. + interpolate_mode (str): Interpolation mode. + freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. + + """ + super(UpsampleNetwork, self).__init__() + self.use_causal_conv = use_causal_conv + self.up_layers = torch.nn.ModuleList() + for scale in upsample_scales: + # interpolation layer + stretch = Stretch2d(scale, 1, interpolate_mode) + self.up_layers += [stretch] + + # conv layer + assert (freq_axis_kernel_size - 1) % 2 == 0, "Not support even number freq axis kernel size." + freq_axis_padding = (freq_axis_kernel_size - 1) // 2 + kernel_size = (freq_axis_kernel_size, scale * 2 + 1) + if use_causal_conv: + padding = (freq_axis_padding, scale * 2) + else: + padding = (freq_axis_padding, scale) + conv = Conv2d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.up_layers += [conv] + + # nonlinear + if nonlinear_activation is not None: + nonlinear = getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params) + self.up_layers += [nonlinear] + + def forward(self, c): + """Calculate forward propagation. + + Args: + c : Input tensor (B, C, T). + + Returns: + Tensor: Upsampled tensor (B, C, T'), where T' = T * prod(upsample_scales). + + """ + c = c.unsqueeze(1) # (B, 1, C, T) + for f in self.up_layers: + if self.use_causal_conv and isinstance(f, Conv2d): + c = f(c)[..., :c.size(-1)] + else: + c = f(c) + return c.squeeze(1) # (B, C, T') + + +class ConvInUpsampleNetwork(torch.nn.Module): + """Convolution + upsampling network module.""" + + def __init__(self, + upsample_scales, + nonlinear_activation=None, + nonlinear_activation_params={}, + interpolate_mode="nearest", + freq_axis_kernel_size=1, + aux_channels=80, + aux_context_window=0, + use_causal_conv=False + ): + """Initialize convolution + upsampling network module. + + Args: + upsample_scales (list): List of upsampling scales. + nonlinear_activation (str): Activation function name. + nonlinear_activation_params (dict): Arguments for specified activation function. + mode (str): Interpolation mode. + freq_axis_kernel_size (int): Kernel size in the direction of frequency axis. + aux_channels (int): Number of channels of pre-convolutional layer. + aux_context_window (int): Context window size of the pre-convolutional layer. + use_causal_conv (bool): Whether to use causal structure. + + """ + super(ConvInUpsampleNetwork, self).__init__() + self.aux_context_window = aux_context_window + self.use_causal_conv = use_causal_conv and aux_context_window > 0 + # To capture wide-context information in conditional features + kernel_size = aux_context_window + 1 if use_causal_conv else 2 * aux_context_window + 1 + # NOTE(kan-bayashi): Here do not use padding because the input is already padded + self.conv_in = Conv1d(aux_channels, aux_channels, kernel_size=kernel_size, bias=False) + self.upsample = UpsampleNetwork( + upsample_scales=upsample_scales, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + interpolate_mode=interpolate_mode, + freq_axis_kernel_size=freq_axis_kernel_size, + use_causal_conv=use_causal_conv, + ) + + def forward(self, c): + """Calculate forward propagation. + + Args: + c : Input tensor (B, C, T'). + + Returns: + Tensor: Upsampled tensor (B, C, T), + where T = (T' - aux_context_window * 2) * prod(upsample_scales). 
+ + Note: + The length of inputs considers the context window size. + + """ + c_ = self.conv_in(c) + c = c_[:, :, :-self.aux_context_window] if self.use_causal_conv else c_ + return self.upsample(c) diff --git a/modules/parallel_wavegan/losses/__init__.py b/modules/parallel_wavegan/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b03080a907cb5cb4b316ceb74866ddbc406b33bf --- /dev/null +++ b/modules/parallel_wavegan/losses/__init__.py @@ -0,0 +1 @@ +from .stft_loss import * # NOQA diff --git a/modules/parallel_wavegan/losses/stft_loss.py b/modules/parallel_wavegan/losses/stft_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..74d2aa21ad30ba094c406366e652067462f49cd2 --- /dev/null +++ b/modules/parallel_wavegan/losses/stft_loss.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""STFT-based Loss modules.""" + +import torch +import torch.nn.functional as F + + +def stft(x, fft_size, hop_size, win_length, window): + """Perform STFT and convert to magnitude spectrogram. + + Args: + x (Tensor): Input signal tensor (B, T). + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (str): Window function type. + + Returns: + Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). + + """ + x_stft = torch.stft(x, fft_size, hop_size, win_length, window) + real = x_stft[..., 0] + imag = x_stft[..., 1] + + # NOTE(kan-bayashi): clamp is needed to avoid nan or inf + return torch.sqrt(torch.clamp(real ** 2 + imag ** 2, min=1e-7)).transpose(2, 1) + + +class SpectralConvergengeLoss(torch.nn.Module): + """Spectral convergence loss module.""" + + def __init__(self): + """Initilize spectral convergence loss module.""" + super(SpectralConvergengeLoss, self).__init__() + + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + + Returns: + Tensor: Spectral convergence loss value. + + """ + return torch.norm(y_mag - x_mag, p="fro") / torch.norm(y_mag, p="fro") + + +class LogSTFTMagnitudeLoss(torch.nn.Module): + """Log STFT magnitude loss module.""" + + def __init__(self): + """Initilize los STFT magnitude loss module.""" + super(LogSTFTMagnitudeLoss, self).__init__() + + def forward(self, x_mag, y_mag): + """Calculate forward propagation. + + Args: + x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins). + y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins). + + Returns: + Tensor: Log STFT magnitude loss value. + + """ + return F.l1_loss(torch.log(y_mag), torch.log(x_mag)) + + +class STFTLoss(torch.nn.Module): + """STFT loss module.""" + + def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window"): + """Initialize STFT loss module.""" + super(STFTLoss, self).__init__() + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.window = getattr(torch, window)(win_length) + self.spectral_convergenge_loss = SpectralConvergengeLoss() + self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() + + def forward(self, x, y): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + + Returns: + Tensor: Spectral convergence loss value. 
+ Tensor: Log STFT magnitude loss value. + + """ + x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) + y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) + sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) + mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) + + return sc_loss, mag_loss + + +class MultiResolutionSTFTLoss(torch.nn.Module): + """Multi resolution STFT loss module.""" + + def __init__(self, + fft_sizes=[1024, 2048, 512], + hop_sizes=[120, 240, 50], + win_lengths=[600, 1200, 240], + window="hann_window"): + """Initialize Multi resolution STFT loss module. + + Args: + fft_sizes (list): List of FFT sizes. + hop_sizes (list): List of hop sizes. + win_lengths (list): List of window lengths. + window (str): Window function type. + + """ + super(MultiResolutionSTFTLoss, self).__init__() + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses += [STFTLoss(fs, ss, wl, window)] + + def forward(self, x, y): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + + Returns: + Tensor: Multi resolution spectral convergence loss value. + Tensor: Multi resolution log STFT magnitude loss value. + + """ + sc_loss = 0.0 + mag_loss = 0.0 + for f in self.stft_losses: + sc_l, mag_l = f(x, y) + sc_loss += sc_l + mag_loss += mag_l + sc_loss /= len(self.stft_losses) + mag_loss /= len(self.stft_losses) + + return sc_loss, mag_loss diff --git a/modules/parallel_wavegan/models/__init__.py b/modules/parallel_wavegan/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4803ba6b2a0afc8022e756ae5b3f4c7403c3c1bd --- /dev/null +++ b/modules/parallel_wavegan/models/__init__.py @@ -0,0 +1,2 @@ +from .melgan import * # NOQA +from .parallel_wavegan import * # NOQA diff --git a/modules/parallel_wavegan/models/melgan.py b/modules/parallel_wavegan/models/melgan.py new file mode 100644 index 0000000000000000000000000000000000000000..e021ae4817a8c1c97338e61b00b230c881836fd8 --- /dev/null +++ b/modules/parallel_wavegan/models/melgan.py @@ -0,0 +1,427 @@ +# -*- coding: utf-8 -*- + +# Copyright 2020 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""MelGAN Modules.""" + +import logging + +import numpy as np +import torch + +from modules.parallel_wavegan.layers import CausalConv1d +from modules.parallel_wavegan.layers import CausalConvTranspose1d +from modules.parallel_wavegan.layers import ResidualStack + + +class MelGANGenerator(torch.nn.Module): + """MelGAN generator module.""" + + def __init__(self, + in_channels=80, + out_channels=1, + kernel_size=7, + channels=512, + bias=True, + upsample_scales=[8, 8, 2, 2], + stack_kernel_size=3, + stacks=3, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + pad="ReflectionPad1d", + pad_params={}, + use_final_nonlinear_activation=True, + use_weight_norm=True, + use_causal_conv=False, + ): + """Initialize MelGANGenerator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Kernel size of initial and final conv layer. + channels (int): Initial number of channels for conv layer. + bias (bool): Whether to add bias parameter in convolution layers. + upsample_scales (list): List of upsampling scales. 
+            stack_kernel_size (int): Kernel size of dilated conv layers in residual stack.
+            stacks (int): Number of stacks in a single residual stack.
+            nonlinear_activation (str): Activation function module name.
+            nonlinear_activation_params (dict): Hyperparameters for activation function.
+            pad (str): Padding function module name before dilated convolution layer.
+            pad_params (dict): Hyperparameters for padding function.
+            use_final_nonlinear_activation (bool): Whether to apply a Tanh activation after the final layer.
+            use_weight_norm (bool): Whether to use weight norm.
+                If set to true, it will be applied to all of the conv layers.
+            use_causal_conv (bool): Whether to use causal convolution.
+
+        """
+        super(MelGANGenerator, self).__init__()
+
+        # check that hyperparameters are valid
+        assert channels >= np.prod(upsample_scales)
+        assert channels % (2 ** len(upsample_scales)) == 0
+        if not use_causal_conv:
+            assert (kernel_size - 1) % 2 == 0, "Even-number kernel sizes are not supported."
+
+        # add initial layer
+        layers = []
+        if not use_causal_conv:
+            layers += [
+                getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
+                torch.nn.Conv1d(in_channels, channels, kernel_size, bias=bias),
+            ]
+        else:
+            layers += [
+                CausalConv1d(in_channels, channels, kernel_size,
+                             bias=bias, pad=pad, pad_params=pad_params),
+            ]
+
+        for i, upsample_scale in enumerate(upsample_scales):
+            # add upsampling layer
+            layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
+            if not use_causal_conv:
+                layers += [
+                    torch.nn.ConvTranspose1d(
+                        channels // (2 ** i),
+                        channels // (2 ** (i + 1)),
+                        upsample_scale * 2,
+                        stride=upsample_scale,
+                        padding=upsample_scale // 2 + upsample_scale % 2,
+                        output_padding=upsample_scale % 2,
+                        bias=bias,
+                    )
+                ]
+            else:
+                layers += [
+                    CausalConvTranspose1d(
+                        channels // (2 ** i),
+                        channels // (2 ** (i + 1)),
+                        upsample_scale * 2,
+                        stride=upsample_scale,
+                        bias=bias,
+                    )
+                ]
+
+            # add residual stack
+            for j in range(stacks):
+                layers += [
+                    ResidualStack(
+                        kernel_size=stack_kernel_size,
+                        channels=channels // (2 ** (i + 1)),
+                        dilation=stack_kernel_size ** j,
+                        bias=bias,
+                        nonlinear_activation=nonlinear_activation,
+                        nonlinear_activation_params=nonlinear_activation_params,
+                        pad=pad,
+                        pad_params=pad_params,
+                        use_causal_conv=use_causal_conv,
+                    )
+                ]
+
+        # add final layer
+        layers += [getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params)]
+        if not use_causal_conv:
+            layers += [
+                getattr(torch.nn, pad)((kernel_size - 1) // 2, **pad_params),
+                torch.nn.Conv1d(channels // (2 ** (i + 1)), out_channels, kernel_size, bias=bias),
+            ]
+        else:
+            layers += [
+                CausalConv1d(channels // (2 ** (i + 1)), out_channels, kernel_size,
+                             bias=bias, pad=pad, pad_params=pad_params),
+            ]
+        if use_final_nonlinear_activation:
+            layers += [torch.nn.Tanh()]
+
+        # define the model as a single function
+        self.melgan = torch.nn.Sequential(*layers)
+
+        # apply weight norm
+        if use_weight_norm:
+            self.apply_weight_norm()
+
+        # reset parameters
+        self.reset_parameters()
+
+    def forward(self, c):
+        """Calculate forward propagation.
+
+        Args:
+            c (Tensor): Input tensor (B, channels, T).
+
+        Returns:
+            Tensor: Output tensor (B, 1, T * prod(upsample_scales)).
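+
+        Example (an illustrative sketch with the default hyperparameters; not from the original codebase):
+            >>> g = MelGANGenerator()
+            >>> wav = g(torch.randn(1, 80, 10))  # 10 frames * prod([8, 8, 2, 2]) upsampling
+            >>> tuple(wav.shape)
+            (1, 1, 2560)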
+ + """ + return self.melgan(c) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + """Reset parameters. + + This initialization follows official implementation manner. + https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py + + """ + def _reset_parameters(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): + m.weight.data.normal_(0.0, 0.02) + logging.debug(f"Reset parameters in {m}.") + + self.apply(_reset_parameters) + + +class MelGANDiscriminator(torch.nn.Module): + """MelGAN discriminator module.""" + + def __init__(self, + in_channels=1, + out_channels=1, + kernel_sizes=[5, 3], + channels=16, + max_downsample_channels=1024, + bias=True, + downsample_scales=[4, 4, 4, 4], + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + pad="ReflectionPad1d", + pad_params={}, + ): + """Initilize MelGAN discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_sizes (list): List of two kernel sizes. The prod will be used for the first conv layer, + and the first and the second kernel sizes will be used for the last two layers. + For example if kernel_sizes = [5, 3], the first layer kernel size will be 5 * 3 = 15, + the last two layers' kernel size will be 5 and 3, respectively. + channels (int): Initial number of channels for conv layer. + max_downsample_channels (int): Maximum number of channels for downsampling layers. + bias (bool): Whether to add bias parameter in convolution layers. + downsample_scales (list): List of downsampling scales. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + pad (str): Padding function module name before dilated convolution layer. + pad_params (dict): Hyperparameters for padding function. 
+ + """ + super(MelGANDiscriminator, self).__init__() + self.layers = torch.nn.ModuleList() + + # check kernel size is valid + assert len(kernel_sizes) == 2 + assert kernel_sizes[0] % 2 == 1 + assert kernel_sizes[1] % 2 == 1 + + # add first layer + self.layers += [ + torch.nn.Sequential( + getattr(torch.nn, pad)((np.prod(kernel_sizes) - 1) // 2, **pad_params), + torch.nn.Conv1d(in_channels, channels, np.prod(kernel_sizes), bias=bias), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + + # add downsample layers + in_chs = channels + for downsample_scale in downsample_scales: + out_chs = min(in_chs * downsample_scale, max_downsample_channels) + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, out_chs, + kernel_size=downsample_scale * 10 + 1, + stride=downsample_scale, + padding=downsample_scale * 5, + groups=in_chs // 4, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + in_chs = out_chs + + # add final layers + out_chs = min(in_chs * 2, max_downsample_channels) + self.layers += [ + torch.nn.Sequential( + torch.nn.Conv1d( + in_chs, out_chs, kernel_sizes[0], + padding=(kernel_sizes[0] - 1) // 2, + bias=bias, + ), + getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params), + ) + ] + self.layers += [ + torch.nn.Conv1d( + out_chs, out_channels, kernel_sizes[1], + padding=(kernel_sizes[1] - 1) // 2, + bias=bias, + ), + ] + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of output tensors of each layer. + + """ + outs = [] + for f in self.layers: + x = f(x) + outs += [x] + + return outs + + +class MelGANMultiScaleDiscriminator(torch.nn.Module): + """MelGAN multi-scale discriminator module.""" + + def __init__(self, + in_channels=1, + out_channels=1, + scales=3, + downsample_pooling="AvgPool1d", + # follow the official implementation setting + downsample_pooling_params={ + "kernel_size": 4, + "stride": 2, + "padding": 1, + "count_include_pad": False, + }, + kernel_sizes=[5, 3], + channels=16, + max_downsample_channels=1024, + bias=True, + downsample_scales=[4, 4, 4, 4], + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + pad="ReflectionPad1d", + pad_params={}, + use_weight_norm=True, + ): + """Initilize MelGAN multi-scale discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + downsample_pooling (str): Pooling module name for downsampling of the inputs. + downsample_pooling_params (dict): Parameters for the above pooling module. + kernel_sizes (list): List of two kernel sizes. The sum will be used for the first conv layer, + and the first and the second kernel sizes will be used for the last two layers. + channels (int): Initial number of channels for conv layer. + max_downsample_channels (int): Maximum number of channels for downsampling layers. + bias (bool): Whether to add bias parameter in convolution layers. + downsample_scales (list): List of downsampling scales. + nonlinear_activation (str): Activation function module name. + nonlinear_activation_params (dict): Hyperparameters for activation function. + pad (str): Padding function module name before dilated convolution layer. + pad_params (dict): Hyperparameters for padding function. + use_causal_conv (bool): Whether to use causal convolution. 
+ + """ + super(MelGANMultiScaleDiscriminator, self).__init__() + self.discriminators = torch.nn.ModuleList() + + # add discriminators + for _ in range(scales): + self.discriminators += [ + MelGANDiscriminator( + in_channels=in_channels, + out_channels=out_channels, + kernel_sizes=kernel_sizes, + channels=channels, + max_downsample_channels=max_downsample_channels, + bias=bias, + downsample_scales=downsample_scales, + nonlinear_activation=nonlinear_activation, + nonlinear_activation_params=nonlinear_activation_params, + pad=pad, + pad_params=pad_params, + ) + ] + self.pooling = getattr(torch.nn, downsample_pooling)(**downsample_pooling_params) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + # reset parameters + self.reset_parameters() + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + List: List of list of each discriminator outputs, which consists of each layer output tensors. + + """ + outs = [] + for f in self.discriminators: + outs += [f(x)] + x = self.pooling(x) + + return outs + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def reset_parameters(self): + """Reset parameters. + + This initialization follows official implementation manner. 
+ https://github.com/descriptinc/melgan-neurips/blob/master/spec2wav/modules.py + + """ + def _reset_parameters(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.ConvTranspose1d): + m.weight.data.normal_(0.0, 0.02) + logging.debug(f"Reset parameters in {m}.") + + self.apply(_reset_parameters) diff --git a/modules/parallel_wavegan/models/parallel_wavegan.py b/modules/parallel_wavegan/models/parallel_wavegan.py new file mode 100644 index 0000000000000000000000000000000000000000..c63b59f67aa48342179415c1d1beac68574a5498 --- /dev/null +++ b/modules/parallel_wavegan/models/parallel_wavegan.py @@ -0,0 +1,434 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Parallel WaveGAN Modules.""" + +import logging +import math + +import torch +from torch import nn + +from modules.parallel_wavegan.layers import Conv1d +from modules.parallel_wavegan.layers import Conv1d1x1 +from modules.parallel_wavegan.layers import ResidualBlock +from modules.parallel_wavegan.layers import upsample +from modules.parallel_wavegan import models + + +class ParallelWaveGANGenerator(torch.nn.Module): + """Parallel WaveGAN Generator module.""" + + def __init__(self, + in_channels=1, + out_channels=1, + kernel_size=3, + layers=30, + stacks=3, + residual_channels=64, + gate_channels=128, + skip_channels=64, + aux_channels=80, + aux_context_window=2, + dropout=0.0, + bias=True, + use_weight_norm=True, + use_causal_conv=False, + upsample_conditional_features=True, + upsample_net="ConvInUpsampleNetwork", + upsample_params={"upsample_scales": [4, 4, 4, 4]}, + use_pitch_embed=False, + ): + """Initialize Parallel WaveGAN Generator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Kernel size of dilated convolution. + layers (int): Number of residual block layers. + stacks (int): Number of stacks i.e., dilation cycles. + residual_channels (int): Number of channels in residual conv. + gate_channels (int): Number of channels in gated conv. + skip_channels (int): Number of channels in skip conv. + aux_channels (int): Number of channels for auxiliary feature conv. + aux_context_window (int): Context window size for auxiliary feature. + dropout (float): Dropout rate. 0.0 means no dropout applied. + bias (bool): Whether to use bias parameter in conv layer. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_causal_conv (bool): Whether to use causal structure. + upsample_conditional_features (bool): Whether to use upsampling network. + upsample_net (str): Upsampling network architecture. + upsample_params (dict): Upsampling network parameters. 
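+            use_pitch_embed (bool): Whether to additionally condition on an embedded pitch sequence (editorial note: this parameter appears in the signature above but was undocumented).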
+ + """ + super(ParallelWaveGANGenerator, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.aux_channels = aux_channels + self.layers = layers + self.stacks = stacks + self.kernel_size = kernel_size + + # check the number of layers and stacks + assert layers % stacks == 0 + layers_per_stack = layers // stacks + + # define first convolution + self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True) + + # define conv + upsampling network + if upsample_conditional_features: + upsample_params.update({ + "use_causal_conv": use_causal_conv, + }) + if upsample_net == "MelGANGenerator": + assert aux_context_window == 0 + upsample_params.update({ + "use_weight_norm": False, # not to apply twice + "use_final_nonlinear_activation": False, + }) + self.upsample_net = getattr(models, upsample_net)(**upsample_params) + else: + if upsample_net == "ConvInUpsampleNetwork": + upsample_params.update({ + "aux_channels": aux_channels, + "aux_context_window": aux_context_window, + }) + self.upsample_net = getattr(upsample, upsample_net)(**upsample_params) + else: + self.upsample_net = None + + # define residual blocks + self.conv_layers = torch.nn.ModuleList() + for layer in range(layers): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + residual_channels=residual_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=aux_channels, + dilation=dilation, + dropout=dropout, + bias=bias, + use_causal_conv=use_causal_conv, + ) + self.conv_layers += [conv] + + # define output layers + self.last_conv_layers = torch.nn.ModuleList([ + torch.nn.ReLU(inplace=True), + Conv1d1x1(skip_channels, skip_channels, bias=True), + torch.nn.ReLU(inplace=True), + Conv1d1x1(skip_channels, out_channels, bias=True), + ]) + + self.use_pitch_embed = use_pitch_embed + if use_pitch_embed: + self.pitch_embed = nn.Embedding(300, aux_channels, 0) + self.c_proj = nn.Linear(2 * aux_channels, aux_channels) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, x, c=None, pitch=None, **kwargs): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, C_in, T). + c (Tensor): Local conditioning auxiliary features (B, C ,T'). + pitch (Tensor): Local conditioning pitch (B, T'). 
+
+        Returns:
+            Tensor: Output tensor (B, C_out, T).
+
+        """
+        # perform upsampling
+        if c is not None and self.upsample_net is not None:
+            if self.use_pitch_embed:
+                p = self.pitch_embed(pitch)
+                c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
+            c = self.upsample_net(c)
+            assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))
+
+        # encode to hidden representation
+        x = self.first_conv(x)
+        skips = 0
+        for f in self.conv_layers:
+            x, h = f(x, c)
+            skips += h
+        skips *= math.sqrt(1.0 / len(self.conv_layers))
+
+        # apply final layers
+        x = skips
+        for f in self.last_conv_layers:
+            x = f(x)
+
+        return x
+
+    def remove_weight_norm(self):
+        """Remove weight normalization from all of the layers."""
+        def _remove_weight_norm(m):
+            try:
+                logging.debug(f"Weight norm is removed from {m}.")
+                torch.nn.utils.remove_weight_norm(m)
+            except ValueError:  # this module didn't have weight norm
+                return
+
+        self.apply(_remove_weight_norm)
+
+    def apply_weight_norm(self):
+        """Apply weight normalization to all of the conv layers."""
+        def _apply_weight_norm(m):
+            if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
+                torch.nn.utils.weight_norm(m)
+                logging.debug(f"Weight norm is applied to {m}.")
+
+        self.apply(_apply_weight_norm)
+
+    @staticmethod
+    def _get_receptive_field_size(layers, stacks, kernel_size,
+                                  dilation=lambda x: 2 ** x):
+        assert layers % stacks == 0
+        layers_per_cycle = layers // stacks
+        dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
+        return (kernel_size - 1) * sum(dilations) + 1
+
+    @property
+    def receptive_field_size(self):
+        """Return receptive field size."""
+        return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
+
+
+class ParallelWaveGANDiscriminator(torch.nn.Module):
+    """Parallel WaveGAN Discriminator module."""
+
+    def __init__(self,
+                 in_channels=1,
+                 out_channels=1,
+                 kernel_size=3,
+                 layers=10,
+                 conv_channels=64,
+                 dilation_factor=1,
+                 nonlinear_activation="LeakyReLU",
+                 nonlinear_activation_params={"negative_slope": 0.2},
+                 bias=True,
+                 use_weight_norm=True,
+                 ):
+        """Initialize Parallel WaveGAN Discriminator module.
+
+        Args:
+            in_channels (int): Number of input channels.
+            out_channels (int): Number of output channels.
+            kernel_size (int): Kernel size of conv layers.
+            layers (int): Number of conv layers.
+            conv_channels (int): Number of channels in conv layers.
+            dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
+                the dilation will be 2, 4, 8, ..., and so on.
+            nonlinear_activation (str): Nonlinear function after each conv.
+            nonlinear_activation_params (dict): Nonlinear function parameters.
+            bias (bool): Whether to use bias parameter in conv.
+            use_weight_norm (bool): Whether to use weight norm.
+                If set to true, it will be applied to all of the conv layers.
+
+        """
+        super(ParallelWaveGANDiscriminator, self).__init__()
+        assert (kernel_size - 1) % 2 == 0, "Even-number kernel sizes are not supported."
+        assert dilation_factor > 0, "Dilation factor must be > 0."
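+        # NOTE (editorial): with dilation_factor == 1 the dilation grows linearly
+        # (1, 1, 2, 3, ...); otherwise it grows as dilation_factor ** i (see loop below).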
+ self.conv_layers = torch.nn.ModuleList() + conv_in_channels = in_channels + for i in range(layers - 1): + if i == 0: + dilation = 1 + else: + dilation = i if dilation_factor == 1 else dilation_factor ** i + conv_in_channels = conv_channels + padding = (kernel_size - 1) // 2 * dilation + conv_layer = [ + Conv1d(conv_in_channels, conv_channels, + kernel_size=kernel_size, padding=padding, + dilation=dilation, bias=bias), + getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params) + ] + self.conv_layers += conv_layer + padding = (kernel_size - 1) // 2 + last_conv_layer = Conv1d( + conv_in_channels, out_channels, + kernel_size=kernel_size, padding=padding, bias=bias) + self.conv_layers += [last_conv_layer] + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + Tensor: Output tensor (B, 1, T) + + """ + for f in self.conv_layers: + x = f(x) + return x + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) + + +class ResidualParallelWaveGANDiscriminator(torch.nn.Module): + """Parallel WaveGAN Discriminator module.""" + + def __init__(self, + in_channels=1, + out_channels=1, + kernel_size=3, + layers=30, + stacks=3, + residual_channels=64, + gate_channels=128, + skip_channels=64, + dropout=0.0, + bias=True, + use_weight_norm=True, + use_causal_conv=False, + nonlinear_activation="LeakyReLU", + nonlinear_activation_params={"negative_slope": 0.2}, + ): + """Initialize Parallel WaveGAN Discriminator module. + + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + kernel_size (int): Kernel size of dilated convolution. + layers (int): Number of residual block layers. + stacks (int): Number of stacks i.e., dilation cycles. + residual_channels (int): Number of channels in residual conv. + gate_channels (int): Number of channels in gated conv. + skip_channels (int): Number of channels in skip conv. + dropout (float): Dropout rate. 0.0 means no dropout applied. + bias (bool): Whether to use bias parameter in conv. + use_weight_norm (bool): Whether to use weight norm. + If set to true, it will be applied to all of the conv layers. + use_causal_conv (bool): Whether to use causal structure. + nonlinear_activation_params (dict): Nonlinear function parameters + + """ + super(ResidualParallelWaveGANDiscriminator, self).__init__() + assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size." 
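+
+        # NOTE (editorial): unlike ParallelWaveGANDiscriminator above, this variant
+        # reuses the WaveNet-style ResidualBlock stack; aux_channels=-1 below disables
+        # the conditioning input.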
+ + self.in_channels = in_channels + self.out_channels = out_channels + self.layers = layers + self.stacks = stacks + self.kernel_size = kernel_size + + # check the number of layers and stacks + assert layers % stacks == 0 + layers_per_stack = layers // stacks + + # define first convolution + self.first_conv = torch.nn.Sequential( + Conv1d1x1(in_channels, residual_channels, bias=True), + getattr(torch.nn, nonlinear_activation)( + inplace=True, **nonlinear_activation_params), + ) + + # define residual blocks + self.conv_layers = torch.nn.ModuleList() + for layer in range(layers): + dilation = 2 ** (layer % layers_per_stack) + conv = ResidualBlock( + kernel_size=kernel_size, + residual_channels=residual_channels, + gate_channels=gate_channels, + skip_channels=skip_channels, + aux_channels=-1, + dilation=dilation, + dropout=dropout, + bias=bias, + use_causal_conv=use_causal_conv, + ) + self.conv_layers += [conv] + + # define output layers + self.last_conv_layers = torch.nn.ModuleList([ + getattr(torch.nn, nonlinear_activation)( + inplace=True, **nonlinear_activation_params), + Conv1d1x1(skip_channels, skip_channels, bias=True), + getattr(torch.nn, nonlinear_activation)( + inplace=True, **nonlinear_activation_params), + Conv1d1x1(skip_channels, out_channels, bias=True), + ]) + + # apply weight norm + if use_weight_norm: + self.apply_weight_norm() + + def forward(self, x): + """Calculate forward propagation. + + Args: + x (Tensor): Input noise signal (B, 1, T). + + Returns: + Tensor: Output tensor (B, 1, T) + + """ + x = self.first_conv(x) + + skips = 0 + for f in self.conv_layers: + x, h = f(x, None) + skips += h + skips *= math.sqrt(1.0 / len(self.conv_layers)) + + # apply final layers + x = skips + for f in self.last_conv_layers: + x = f(x) + return x + + def apply_weight_norm(self): + """Apply weight normalization module from all of the layers.""" + def _apply_weight_norm(m): + if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d): + torch.nn.utils.weight_norm(m) + logging.debug(f"Weight norm is applied to {m}.") + + self.apply(_apply_weight_norm) + + def remove_weight_norm(self): + """Remove weight normalization module from all of the layers.""" + def _remove_weight_norm(m): + try: + logging.debug(f"Weight norm is removed from {m}.") + torch.nn.utils.remove_weight_norm(m) + except ValueError: # this module didn't have weight norm + return + + self.apply(_remove_weight_norm) diff --git a/modules/parallel_wavegan/models/source.py b/modules/parallel_wavegan/models/source.py new file mode 100644 index 0000000000000000000000000000000000000000..f2a006e53c0e2194036fd08ea9d6ed4d9a10d6cf --- /dev/null +++ b/modules/parallel_wavegan/models/source.py @@ -0,0 +1,538 @@ +import torch +import numpy as np +import sys +import torch.nn.functional as torch_nn_func + + +class SineGen(torch.nn.Module): + """ Definition of sine generator + SineGen(samp_rate, harmonic_num = 0, + sine_amp = 0.1, noise_std = 0.003, + voiced_threshold = 0, + flag_for_pulse=False) + + samp_rate: sampling rate in Hz + harmonic_num: number of harmonic overtones (default 0) + sine_amp: amplitude of sine-wavefrom (default 0.1) + noise_std: std of Gaussian noise (default 0.003) + voiced_thoreshold: F0 threshold for U/V classification (default 0) + flag_for_pulse: this SinGen is used inside PulseGen (default False) + + Note: when flag_for_pulse is True, the first time step of a voiced + segment is always sin(np.pi) or cos(0) + """ + + def __init__(self, samp_rate, harmonic_num=0, + sine_amp=0.1, noise_std=0.003, + 
voiced_threshold=0, + flag_for_pulse=False): + super(SineGen, self).__init__() + self.sine_amp = sine_amp + self.noise_std = noise_std + self.harmonic_num = harmonic_num + self.dim = self.harmonic_num + 1 + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.flag_for_pulse = flag_for_pulse + + def _f02uv(self, f0): + # generate uv signal + uv = torch.ones_like(f0) + uv = uv * (f0 > self.voiced_threshold) + return uv + + def _f02sine(self, f0_values): + """ f0_values: (batchsize, length, dim) + where dim indicates fundamental tone and overtones + """ + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \ + device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + # for normal case + + # To prevent torch.cumsum numerical overflow, + # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. + # Buffer tmp_over_one_idx indicates the time step to add -1. + # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi + tmp_over_one = torch.cumsum(rad_values, 1) % 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - + tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) + * 2 * np.pi) + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. + i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + with torch.no_grad(): + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, + device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2) + + # generate sine waveforms + sine_waves = self._f02sine(f0_buf) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . 
for voiced regions is self.noise_std + noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3 + noise = noise_amp * torch.randn_like(sine_waves) + + # first: set the unvoiced part to 0 by uv + # then: additive noise + sine_waves = sine_waves * uv + noise + return sine_waves, uv, noise + + +class PulseGen(torch.nn.Module): + """ Definition of Pulse train generator + + There are many ways to implement pulse generator. + Here, PulseGen is based on SinGen. For a perfect + """ + def __init__(self, samp_rate, pulse_amp = 0.1, + noise_std = 0.003, voiced_threshold = 0): + super(PulseGen, self).__init__() + self.pulse_amp = pulse_amp + self.sampling_rate = samp_rate + self.voiced_threshold = voiced_threshold + self.noise_std = noise_std + self.l_sinegen = SineGen(self.sampling_rate, harmonic_num=0, \ + sine_amp=self.pulse_amp, noise_std=0, \ + voiced_threshold=self.voiced_threshold, \ + flag_for_pulse=True) + + def forward(self, f0): + """ Pulse train generator + pulse_train, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output pulse_train: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + + Note: self.l_sine doesn't make sure that the initial phase of + a voiced segment is np.pi, the first pulse in a voiced segment + may not be at the first time step within a voiced segment + """ + with torch.no_grad(): + sine_wav, uv, noise = self.l_sinegen(f0) + + # sine without additive noise + pure_sine = sine_wav - noise + + # step t corresponds to a pulse if + # sine[t] > sine[t+1] & sine[t] > sine[t-1] + # & sine[t-1], sine[t+1], and sine[t] are voiced + # or + # sine[t] is voiced, sine[t-1] is unvoiced + # we use torch.roll to simulate sine[t+1] and sine[t-1] + sine_1 = torch.roll(pure_sine, shifts=1, dims=1) + uv_1 = torch.roll(uv, shifts=1, dims=1) + uv_1[:, 0, :] = 0 + sine_2 = torch.roll(pure_sine, shifts=-1, dims=1) + uv_2 = torch.roll(uv, shifts=-1, dims=1) + uv_2[:, -1, :] = 0 + + loc = (pure_sine > sine_1) * (pure_sine > sine_2) \ + * (uv_1 > 0) * (uv_2 > 0) * (uv > 0) \ + + (uv_1 < 1) * (uv > 0) + + # pulse train without noise + pulse_train = pure_sine * loc + + # additive noise to pulse train + # note that noise from sinegen is zero in voiced regions + pulse_noise = torch.randn_like(pure_sine) * self.noise_std + + # with additive noise on pulse, and unvoiced regions + pulse_train += pulse_noise * loc + pulse_noise * (1 - uv) + return pulse_train, sine_wav, uv, pulse_noise + + +class SignalsConv1d(torch.nn.Module): + """ Filtering input signal with time invariant filter + Note: FIRFilter conducted filtering given fixed FIR weight + SignalsConv1d convolves two signals + Note: this is based on torch.nn.functional.conv1d + + """ + + def __init__(self): + super(SignalsConv1d, self).__init__() + + def forward(self, signal, system_ir): + """ output = forward(signal, system_ir) + + signal: (batchsize, length1, dim) + system_ir: (length2, dim) + + output: (batchsize, length1, dim) + """ + if signal.shape[-1] != system_ir.shape[-1]: + print("Error: SignalsConv1d expects shape:") + print("signal (batchsize, length1, dim)") + print("system_id (batchsize, length2, dim)") + print("But received signal: {:s}".format(str(signal.shape))) + print(" system_ir: {:s}".format(str(system_ir.shape))) + sys.exit(1) + padding_length = system_ir.shape[0] - 1 + groups = signal.shape[-1] + + # pad signal on the left + signal_pad = torch_nn_func.pad(signal.permute(0, 2, 1), \ + (padding_length, 0)) + # prepare system impulse response 
as (dim, 1, length2)
+        # also flip the impulse response
+        ir = torch.flip(system_ir.unsqueeze(1).permute(2, 1, 0), \
+                        dims=[2])
+        # convolve
+        output = torch_nn_func.conv1d(signal_pad, ir, groups=groups)
+        return output.permute(0, 2, 1)
+
+
+class CyclicNoiseGen_v1(torch.nn.Module):
+    """ CyclicNoiseGen_v1
+    Cyclic noise with a single parameter of beta.
+    PyTorch v1 implementation assumes f_t is also fixed
+    """
+
+    def __init__(self, samp_rate,
+                 noise_std=0.003, voiced_threshold=0):
+        super(CyclicNoiseGen_v1, self).__init__()
+        self.samp_rate = samp_rate
+        self.noise_std = noise_std
+        self.voiced_threshold = voiced_threshold
+
+        self.l_pulse = PulseGen(samp_rate, pulse_amp=1.0,
+                                noise_std=noise_std,
+                                voiced_threshold=voiced_threshold)
+        self.l_conv = SignalsConv1d()
+
+    def noise_decay(self, beta, f0mean):
+        """ decayed_noise = noise_decay(beta, f0mean)
+        decayed_noise = n[t]exp(-t * f_mean / beta / samp_rate)
+
+        beta: (dim=1) or (batchsize=1, 1, dim=1)
+        f0mean (batchsize=1, 1, dim=1)
+
+        decayed_noise (batchsize=1, length, dim=1)
+        """
+        with torch.no_grad():
+            # exp(-1.0 n / T) < 0.01 => n > -log(0.01)*T = 4.60*T
+            # truncate the noise when decayed by -40 dB
+            length = 4.6 * self.samp_rate / f0mean
+            length = length.int()
+            time_idx = torch.arange(0, length, device=beta.device)
+            time_idx = time_idx.unsqueeze(0).unsqueeze(2)
+            time_idx = time_idx.repeat(beta.shape[0], 1, beta.shape[2])
+
+            noise = torch.randn(time_idx.shape, device=beta.device)
+
+            # due to the PyTorch implementation, use f0_mean as the f0 factor
+            decay = torch.exp(-time_idx * f0mean / beta / self.samp_rate)
+            return noise * self.noise_std * decay
+
+    def forward(self, f0s, beta):
+        """ Produce cyclic noise
+        """
+        # pulse train
+        pulse_train, sine_wav, uv, noise = self.l_pulse(f0s)
+        pure_pulse = pulse_train - noise
+
+        # decayed_noise (length, dim=1)
+        if (uv < 1).all():
+            # all unvoiced
+            cyc_noise = torch.zeros_like(sine_wav)
+        else:
+            f0mean = f0s[uv > 0].mean()
+
+            decayed_noise = self.noise_decay(beta, f0mean)[0, :, :]
+            # convolve
+            cyc_noise = self.l_conv(pure_pulse, decayed_noise)
+
+        # add noise in unvoiced segments
+        cyc_noise = cyc_noise + noise * (1.0 - uv)
+        return cyc_noise, pulse_train, sine_wav, uv, noise
""" + # convert to F0 in rad. The interger part n can be ignored + # because 2 * np.pi * n doesn't affect phase + rad_values = (f0_values / self.sampling_rate) % 1 + + # initial phase noise (no noise for fundamental component) + rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], \ + device=f0_values.device) + rand_ini[:, 0] = 0 + rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini + + # instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad) + if not self.flag_for_pulse: + # for normal case + + # To prevent torch.cumsum numerical overflow, + # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1. + # Buffer tmp_over_one_idx indicates the time step to add -1. + # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi + tmp_over_one = torch.cumsum(rad_values, 1) % 1 + tmp_over_one_idx = (tmp_over_one[:, 1:, :] - + tmp_over_one[:, :-1, :]) < 0 + cumsum_shift = torch.zeros_like(rad_values) + cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 + + sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) + * 2 * np.pi) + else: + # If necessary, make sure that the first time step of every + # voiced segments is sin(pi) or cos(0) + # This is used for pulse-train generation + + # identify the last time step in unvoiced segments + uv = self._f02uv(f0_values) + uv_1 = torch.roll(uv, shifts=-1, dims=1) + uv_1[:, -1, :] = 1 + u_loc = (uv < 1) * (uv_1 > 0) + + # get the instantanouse phase + tmp_cumsum = torch.cumsum(rad_values, dim=1) + # different batch needs to be processed differently + for idx in range(f0_values.shape[0]): + temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :] + temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :] + # stores the accumulation of i.phase within + # each voiced segments + tmp_cumsum[idx, :, :] = 0 + tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum + + # rad_values - tmp_cumsum: remove the accumulation of i.phase + # within the previous voiced segment. + i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1) + + # get the sines + sines = torch.cos(i_phase * 2 * np.pi) + return sines + + def forward(self, f0): + """ sine_tensor, uv = forward(f0) + input F0: tensor(batchsize=1, length, dim=1) + f0 for unvoiced steps should be 0 + output sine_tensor: tensor(batchsize=1, length, dim) + output uv: tensor(batchsize=1, length, 1) + """ + with torch.no_grad(): + f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, \ + device=f0.device) + # fundamental component + f0_buf[:, :, 0] = f0[:, :, 0] + for idx in np.arange(self.harmonic_num): + # idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic + f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2) + + # generate sine waveforms + sine_waves = self._f02sine(f0_buf) * self.sine_amp + + # generate uv signal + # uv = torch.ones(f0.shape) + # uv = uv * (f0 > self.voiced_threshold) + uv = self._f02uv(f0) + + # noise: for unvoiced should be similar to sine_amp + # std = self.sine_amp/3 -> max value ~ self.sine_amp + # . 
+
+
+class SourceModuleCycNoise_v1(torch.nn.Module):
+    """ SourceModuleCycNoise_v1
+    SourceModule(sampling_rate, noise_std=0.003, voiced_threshod=0)
+    sampling_rate: sampling_rate in Hz
+
+    noise_std: std of Gaussian noise (default: 0.003)
+    voiced_threshold: threshold to set U/V given F0 (default: 0)
+
+    cyc, noise, uv = SourceModuleCycNoise_v1(F0_upsampled, beta)
+    F0_upsampled (batchsize, length, 1)
+    beta (1)
+    cyc (batchsize, length, 1)
+    noise (batchsize, length, 1)
+    uv (batchsize, length, 1)
+    """
+
+    def __init__(self, sampling_rate, noise_std=0.003, voiced_threshod=0):
+        super(SourceModuleCycNoise_v1, self).__init__()
+        self.sampling_rate = sampling_rate
+        self.noise_std = noise_std
+        self.l_cyc_gen = CyclicNoiseGen_v1(sampling_rate, noise_std,
+                                           voiced_threshod)
+
+    def forward(self, f0_upsamped, beta):
+        """
+        cyc, noise, uv = SourceModuleCycNoise_v1(F0, beta)
+        F0_upsampled (batchsize, length, 1)
+        beta (1)
+        cyc (batchsize, length, 1)
+        noise (batchsize, length, 1)
+        uv (batchsize, length, 1)
+        """
+        # source for harmonic branch
+        cyc, pulse, sine, uv, add_noi = self.l_cyc_gen(f0_upsamped, beta)
+
+        # source for noise branch, in the same shape as uv
+        noise = torch.randn_like(uv) * self.noise_std / 3
+        return cyc, noise, uv
+
+
+class SourceModuleHnNSF(torch.nn.Module):
+    """ SourceModule for hn-nsf
+    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0)
+    sampling_rate: sampling_rate in Hz
+    harmonic_num: number of harmonics above F0 (default: 0)
+    sine_amp: amplitude of sine source signal (default: 0.1)
+    add_noise_std: std of additive Gaussian noise (default: 0.003)
+        note that the amplitude of noise in unvoiced regions is decided
+        by sine_amp
+    voiced_threshold: threshold to set U/V given F0 (default: 0)
+
+    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+    F0_sampled (batchsize, length, 1)
+    Sine_source (batchsize, length, 1)
+    noise_source (batchsize, length, 1)
+    uv (batchsize, length, 1)
+    """
+
+    def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
+                 add_noise_std=0.003, voiced_threshod=0):
+        super(SourceModuleHnNSF, self).__init__()
+
+        self.sine_amp = sine_amp
+        self.noise_std = add_noise_std
+
+        # to produce sine waveforms
+        self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
+                                 sine_amp, add_noise_std, voiced_threshod)
+
+        # to merge source harmonics into a single excitation
+        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+        self.l_tanh = torch.nn.Tanh()
+
+    def forward(self, x):
+        """
+        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+        F0_sampled (batchsize, length, 1)
+        Sine_source (batchsize, length, 1)
+        noise_source (batchsize, length, 1)
+        """
+        # source for harmonic branch
+        sine_wavs, uv, _ = self.l_sin_gen(x)
+        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+        # source for noise branch, in the same shape as uv
+        noise = torch.randn_like(uv) * self.sine_amp / 3
+        return sine_merge, noise, uv
+
+
+if __name__ == '__main__':
+    source = SourceModuleCycNoise_v1(24000)
+    x = torch.randn(16, 25600, 1)
+
+
diff --git a/modules/parallel_wavegan/optimizers/__init__.py b/modules/parallel_wavegan/optimizers/__init__.py new file mode 100644 index
0000000000000000000000000000000000000000..a0e0c5932838281e912079e5784d84d43444a61a --- /dev/null +++ b/modules/parallel_wavegan/optimizers/__init__.py @@ -0,0 +1,2 @@ +from torch.optim import * # NOQA +from .radam import * # NOQA diff --git a/modules/parallel_wavegan/optimizers/radam.py b/modules/parallel_wavegan/optimizers/radam.py new file mode 100644 index 0000000000000000000000000000000000000000..e805d7e34921bee436e1e7fd9e1f753c7609186b --- /dev/null +++ b/modules/parallel_wavegan/optimizers/radam.py @@ -0,0 +1,91 @@ +# -*- coding: utf-8 -*- + +"""RAdam optimizer. + +This code is drived from https://github.com/LiyuanLucasLiu/RAdam. +""" + +import math +import torch + +from torch.optim.optimizer import Optimizer + + +class RAdam(Optimizer): + """Rectified Adam optimizer.""" + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0): + """Initilize RAdam optimizer.""" + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + self.buffer = [[None, None, None] for ind in range(10)] + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + """Set state.""" + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + """Run one step.""" + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = torch.zeros_like(p_data_fp32) + state['exp_avg_sq'] = torch.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad) + exp_avg.mul_(beta1).add_(1 - beta1, grad) + + state['step'] += 1 + buffered = self.buffer[int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt( + (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) # NOQA + else: + step_size = 1.0 / (1 - beta1 ** state['step']) + buffered[2] = step_size + + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + + # more conservative since it's an approximated value + if N_sma >= 5: + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom) + else: + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + + p.data.copy_(p_data_fp32) + + return loss diff --git a/modules/parallel_wavegan/stft_loss.py b/modules/parallel_wavegan/stft_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..229e6c777dc9ec7f710842d1e648dba1189ec8b4 --- /dev/null +++ b/modules/parallel_wavegan/stft_loss.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""STFT-based Loss modules.""" +import 
librosa +import torch + +from modules.parallel_wavegan.losses import LogSTFTMagnitudeLoss, SpectralConvergengeLoss, stft + + +class STFTLoss(torch.nn.Module): + """STFT loss module.""" + + def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", + use_mel_loss=False): + """Initialize STFT loss module.""" + super(STFTLoss, self).__init__() + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + self.window = getattr(torch, window)(win_length) + self.spectral_convergenge_loss = SpectralConvergengeLoss() + self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss() + self.use_mel_loss = use_mel_loss + self.mel_basis = None + + def forward(self, x, y): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + + Returns: + Tensor: Spectral convergence loss value. + Tensor: Log STFT magnitude loss value. + + """ + x_mag = stft(x, self.fft_size, self.shift_size, self.win_length, self.window) + y_mag = stft(y, self.fft_size, self.shift_size, self.win_length, self.window) + if self.use_mel_loss: + if self.mel_basis is None: + self.mel_basis = torch.from_numpy(librosa.filters.mel(22050, self.fft_size, 80)).cuda().T + x_mag = x_mag @ self.mel_basis + y_mag = y_mag @ self.mel_basis + + sc_loss = self.spectral_convergenge_loss(x_mag, y_mag) + mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag) + + return sc_loss, mag_loss + + +class MultiResolutionSTFTLoss(torch.nn.Module): + """Multi resolution STFT loss module.""" + + def __init__(self, + fft_sizes=[1024, 2048, 512], + hop_sizes=[120, 240, 50], + win_lengths=[600, 1200, 240], + window="hann_window", + use_mel_loss=False): + """Initialize Multi resolution STFT loss module. + + Args: + fft_sizes (list): List of FFT sizes. + hop_sizes (list): List of hop sizes. + win_lengths (list): List of window lengths. + window (str): Window function type. + + """ + super(MultiResolutionSTFTLoss, self).__init__() + assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) + self.stft_losses = torch.nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.stft_losses += [STFTLoss(fs, ss, wl, window, use_mel_loss)] + + def forward(self, x, y): + """Calculate forward propagation. + + Args: + x (Tensor): Predicted signal (B, T). + y (Tensor): Groundtruth signal (B, T). + + Returns: + Tensor: Multi resolution spectral convergence loss value. + Tensor: Multi resolution log STFT magnitude loss value. 
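+
+        Example (an illustrative sketch; assumes a torch version supporting the legacy real-valued torch.stft call used above, and the default use_mel_loss=False):
+            >>> criterion = MultiResolutionSTFTLoss()
+            >>> sc, mag = criterion(torch.randn(2, 16000), torch.randn(2, 16000))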
+ + """ + sc_loss = 0.0 + mag_loss = 0.0 + for f in self.stft_losses: + sc_l, mag_l = f(x, y) + sc_loss += sc_l + mag_loss += mag_l + sc_loss /= len(self.stft_losses) + mag_loss /= len(self.stft_losses) + + return sc_loss, mag_loss diff --git a/modules/parallel_wavegan/utils/__init__.py b/modules/parallel_wavegan/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e8fa95a020706b5412c3959fbf6e5980019c0d5f --- /dev/null +++ b/modules/parallel_wavegan/utils/__init__.py @@ -0,0 +1 @@ +from .utils import * # NOQA diff --git a/modules/parallel_wavegan/utils/utils.py b/modules/parallel_wavegan/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d48a5ed28e8555d4b8cfb15fdee86426bbb9e368 --- /dev/null +++ b/modules/parallel_wavegan/utils/utils.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- + +# Copyright 2019 Tomoki Hayashi +# MIT License (https://opensource.org/licenses/MIT) + +"""Utility functions.""" + +import fnmatch +import logging +import os +import sys + +import h5py +import numpy as np + + +def find_files(root_dir, query="*.wav", include_root_dir=True): + """Find files recursively. + + Args: + root_dir (str): Root root_dir to find. + query (str): Query to find. + include_root_dir (bool): If False, root_dir name is not included. + + Returns: + list: List of found filenames. + + """ + files = [] + for root, dirnames, filenames in os.walk(root_dir, followlinks=True): + for filename in fnmatch.filter(filenames, query): + files.append(os.path.join(root, filename)) + if not include_root_dir: + files = [file_.replace(root_dir + "/", "") for file_ in files] + + return files + + +def read_hdf5(hdf5_name, hdf5_path): + """Read hdf5 dataset. + + Args: + hdf5_name (str): Filename of hdf5 file. + hdf5_path (str): Dataset name in hdf5 file. + + Return: + any: Dataset values. + + """ + if not os.path.exists(hdf5_name): + logging.error(f"There is no such a hdf5 file ({hdf5_name}).") + sys.exit(1) + + hdf5_file = h5py.File(hdf5_name, "r") + + if hdf5_path not in hdf5_file: + logging.error(f"There is no such a data in hdf5 file. ({hdf5_path})") + sys.exit(1) + + hdf5_data = hdf5_file[hdf5_path][()] + hdf5_file.close() + + return hdf5_data + + +def write_hdf5(hdf5_name, hdf5_path, write_data, is_overwrite=True): + """Write dataset to hdf5. + + Args: + hdf5_name (str): Hdf5 dataset filename. + hdf5_path (str): Dataset path in hdf5. + write_data (ndarray): Data to write. + is_overwrite (bool): Whether to overwrite dataset. + + """ + # convert to numpy array + write_data = np.array(write_data) + + # check folder existence + folder_name, _ = os.path.split(hdf5_name) + if not os.path.exists(folder_name) and len(folder_name) != 0: + os.makedirs(folder_name) + + # check hdf5 existence + if os.path.exists(hdf5_name): + # if already exists, open with r+ mode + hdf5_file = h5py.File(hdf5_name, "r+") + # check dataset existence + if hdf5_path in hdf5_file: + if is_overwrite: + logging.warning("Dataset in hdf5 file already exists. " + "recreate dataset in hdf5.") + hdf5_file.__delitem__(hdf5_path) + else: + logging.error("Dataset in hdf5 file already exists. " + "if you want to overwrite, please set is_overwrite = True.") + hdf5_file.close() + sys.exit(1) + else: + # if not exists, open with w mode + hdf5_file = h5py.File(hdf5_name, "w") + + # write data to hdf5 + hdf5_file.create_dataset(hdf5_path, data=write_data) + hdf5_file.flush() + hdf5_file.close() + + +class HDF5ScpLoader(object): + """Loader class for a fests.scp file of hdf5 file. 
+ + Examples: + key1 /some/path/a.h5:feats + key2 /some/path/b.h5:feats + key3 /some/path/c.h5:feats + key4 /some/path/d.h5:feats + ... + >>> loader = HDF5ScpLoader("hdf5.scp") + >>> array = loader["key1"] + + key1 /some/path/a.h5 + key2 /some/path/b.h5 + key3 /some/path/c.h5 + key4 /some/path/d.h5 + ... + >>> loader = HDF5ScpLoader("hdf5.scp", "feats") + >>> array = loader["key1"] + + """ + + def __init__(self, feats_scp, default_hdf5_path="feats"): + """Initialize HDF5 scp loader. + + Args: + feats_scp (str): Kaldi-style feats.scp file with hdf5 format. + default_hdf5_path (str): Path in hdf5 file. If the scp contain the info, not used. + + """ + self.default_hdf5_path = default_hdf5_path + with open(feats_scp) as f: + lines = [line.replace("\n", "") for line in f.readlines()] + self.data = {} + for line in lines: + key, value = line.split() + self.data[key] = value + + def get_path(self, key): + """Get hdf5 file path for a given key.""" + return self.data[key] + + def __getitem__(self, key): + """Get ndarray for a given key.""" + p = self.data[key] + if ":" in p: + return read_hdf5(*p.split(":")) + else: + return read_hdf5(p, self.default_hdf5_path) + + def __len__(self): + """Return the length of the scp file.""" + return len(self.data) + + def __iter__(self): + """Return the iterator of the scp file.""" + return iter(self.data) + + def keys(self): + """Return the keys of the scp file.""" + return self.data.keys() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..b8f177662183b701add1f77712a22612682e45e2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,23 @@ +matplotlib +librosa==0.8.0 +tqdm +pandas +numba==0.53.1 +PyYAML==5.3.1 +tensorboardX +pyloudnorm +setuptools>=41.0.0 +g2p_en +resemblyzer +webrtcvad +tensorboard==2.6.0 +scikit-image +textgrid +jiwer +pycwt +PyWavelets +praat-parselmouth==0.3.3 +jieba +einops +chardet +h5py diff --git a/tasks/base_task.py b/tasks/base_task.py new file mode 100644 index 0000000000000000000000000000000000000000..aa31903693c814af1e9a75cd64071e883dca4aa1 --- /dev/null +++ b/tasks/base_task.py @@ -0,0 +1,355 @@ +from itertools import chain + +from torch.utils.data import ConcatDataset +from torch.utils.tensorboard import SummaryWriter +import subprocess +import traceback +from datetime import datetime +from functools import wraps +from utils.hparams import hparams +import random +import sys +import numpy as np +from utils.trainer import Trainer +from torch import nn +import torch.utils.data +import utils +import logging +import os + +torch.multiprocessing.set_sharing_strategy(os.getenv('TORCH_SHARE_STRATEGY', 'file_system')) + +log_format = '%(asctime)s %(message)s' +logging.basicConfig(stream=sys.stdout, level=logging.INFO, + format=log_format, datefmt='%m/%d %I:%M:%S %p') + + +def data_loader(fn): + """ + Decorator to make any fx with this use the lazy property + :param fn: + :return: + """ + + wraps(fn) + attr_name = '_lazy_' + fn.__name__ + + def _get_data_loader(self): + try: + value = getattr(self, attr_name) + except AttributeError: + try: + value = fn(self) # Lazy evaluation, done only once. + except AttributeError as e: + # Guard against AttributeError suppression. (Issue #142) + traceback.print_exc() + error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e) + raise RuntimeError(error) from e + setattr(self, attr_name, value) # Memoize evaluation. 
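+            # NOTE (editorial): the result is memoized on the instance, so each
+            # dataloader below is constructed only once per task object.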
+ return value + + return _get_data_loader + + +class BaseDataset(torch.utils.data.Dataset): + def __init__(self, shuffle): + super().__init__() + self.hparams = hparams + self.shuffle = shuffle + self.sort_by_len = hparams['sort_by_len'] + self.sizes = None + + @property + def _sizes(self): + return self.sizes + + def __getitem__(self, index): + raise NotImplementedError + + def collater(self, samples): + raise NotImplementedError + + def __len__(self): + return len(self._sizes) + + def num_tokens(self, index): + return self.size(index) + + def size(self, index): + """Return an example's size as a float or tuple. This value is used when + filtering a dataset with ``--max-positions``.""" + return min(self._sizes[index], hparams['max_frames']) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.shuffle: + indices = np.random.permutation(len(self)) + if self.sort_by_len: + indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')] + else: + indices = np.arange(len(self)) + return indices + + @property + def num_workers(self): + return int(os.getenv('NUM_WORKERS', hparams['ds_workers'])) + + +class BaseConcatDataset(ConcatDataset): + def collater(self, samples): + return self.datasets[0].collater(samples) + + @property + def _sizes(self): + if not hasattr(self, 'sizes'): + self.sizes = list(chain.from_iterable([d._sizes for d in self.datasets])) + return self.sizes + + def size(self, index): + return min(self._sizes[index], hparams['max_frames']) + + def num_tokens(self, index): + return self.size(index) + + def ordered_indices(self): + """Return an ordered list of indices. Batches will be constructed based + on this order.""" + if self.datasets[0].shuffle: + indices = np.random.permutation(len(self)) + if self.datasets[0].sort_by_len: + indices = indices[np.argsort(np.array(self._sizes)[indices], kind='mergesort')] + else: + indices = np.arange(len(self)) + return indices + + @property + def num_workers(self): + return self.datasets[0].num_workers + + +class BaseTask(nn.Module): + def __init__(self, *args, **kwargs): + # dataset configs + super(BaseTask, self).__init__() + self.current_epoch = 0 + self.global_step = 0 + self.trainer = None + self.use_ddp = False + self.gradient_clip_norm = hparams['clip_grad_norm'] + self.gradient_clip_val = hparams.get('clip_grad_value', 0) + self.model = None + self.training_losses_meter = None + self.logger: SummaryWriter = None + + ###################### + # build model, dataloaders, optimizer, scheduler and tensorboard + ###################### + def build_model(self): + raise NotImplementedError + + @data_loader + def train_dataloader(self): + raise NotImplementedError + + @data_loader + def test_dataloader(self): + raise NotImplementedError + + @data_loader + def val_dataloader(self): + raise NotImplementedError + + def build_scheduler(self, optimizer): + return None + + def build_optimizer(self, model): + raise NotImplementedError + + def configure_optimizers(self): + optm = self.build_optimizer(self.model) + self.scheduler = self.build_scheduler(optm) + if isinstance(optm, (list, tuple)): + return optm + return [optm] + + def build_tensorboard(self, save_dir, name, version, **kwargs): + root_dir = os.path.join(save_dir, name) + os.makedirs(root_dir, exist_ok=True) + log_dir = os.path.join(root_dir, "version_" + str(version)) + self.logger = SummaryWriter(log_dir=log_dir, **kwargs) + + ###################### + # training + ###################### + 
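+    # Hook order during training (editorial summary, inferred from this file):
+    # on_train_start -> [on_epoch_start -> training_step* -> on_epoch_end]* -> on_train_end,
+    # with on_before_optimization / on_after_optimization wrapped around each optimizer step.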
def on_train_start(self): + pass + + def on_epoch_start(self): + self.training_losses_meter = {'total_loss': utils.AvgrageMeter()} + + def _training_step(self, sample, batch_idx, optimizer_idx): + """ + + :param sample: + :param batch_idx: + :return: total loss: torch.Tensor, loss_log: dict + """ + raise NotImplementedError + + def training_step(self, sample, batch_idx, optimizer_idx=-1): + """ + + :param sample: + :param batch_idx: + :param optimizer_idx: + :return: {'loss': torch.Tensor, 'progress_bar': dict, 'tb_log': dict} + """ + loss_ret = self._training_step(sample, batch_idx, optimizer_idx) + if loss_ret is None: + return {'loss': None} + total_loss, log_outputs = loss_ret + log_outputs = utils.tensors_to_scalars(log_outputs) + for k, v in log_outputs.items(): + if k not in self.training_losses_meter: + self.training_losses_meter[k] = utils.AvgrageMeter() + if not np.isnan(v): + self.training_losses_meter[k].update(v) + self.training_losses_meter['total_loss'].update(total_loss.item()) + + if optimizer_idx >= 0: + log_outputs[f'lr_{optimizer_idx}'] = self.trainer.optimizers[optimizer_idx].param_groups[0]['lr'] + + progress_bar_log = log_outputs + tb_log = {f'tr/{k}': v for k, v in log_outputs.items()} + return { + 'loss': total_loss, + 'progress_bar': progress_bar_log, + 'tb_log': tb_log + } + + def on_before_optimization(self, opt_idx): + if self.gradient_clip_norm > 0: + torch.nn.utils.clip_grad_norm_(self.parameters(), self.gradient_clip_norm) + if self.gradient_clip_val > 0: + torch.nn.utils.clip_grad_value_(self.parameters(), self.gradient_clip_val) + + def on_after_optimization(self, epoch, batch_idx, optimizer, optimizer_idx): + if self.scheduler is not None: + self.scheduler.step(self.global_step // hparams['accumulate_grad_batches']) + + def on_epoch_end(self): + loss_outputs = {k: round(v.avg, 4) for k, v in self.training_losses_meter.items()} + print(f"Epoch {self.current_epoch} ended. Steps: {self.global_step}. {loss_outputs}") + + def on_train_end(self): + pass + + ###################### + # validation + ###################### + def validation_step(self, sample, batch_idx): + """ + + :param sample: + :param batch_idx: + :return: output: {"losses": {...}, "total_loss": float, ...} or (total loss: torch.Tensor, loss_log: dict) + """ + raise NotImplementedError + + def validation_end(self, outputs): + """ + + :param outputs: + :return: loss_output: dict + """ + all_losses_meter = {'total_loss': utils.AvgrageMeter()} + for output in outputs: + if len(output) == 0 or output is None: + continue + if isinstance(output, dict): + assert 'losses' in output, 'Key "losses" should exist in validation output.' 
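# --- Minimal _training_step sketch (hypothetical subclass; model signature assumed) ---
# A concrete task only returns (total_loss, loss_log); training_step above then
# updates the loss meters, attaches the learning rate, and builds the
# progress-bar and tensorboard dicts. Returning None skips the batch.
import torch.nn.functional as F

class _MyTask(BaseTask):
    def _training_step(self, sample, batch_idx, optimizer_idx):
        mel_pred = self.model(sample['txt_tokens'])   # assumed forward signature
        loss = F.l1_loss(mel_pred, sample['mels'])
        return loss, {'mel_l1': loss}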
+ n = output.pop('nsamples', 1) + losses = utils.tensors_to_scalars(output['losses']) + total_loss = output.get('total_loss', sum(losses.values())) + else: + assert len(output) == 2, 'Validation output should only consist of two elements: (total_loss, losses)' + n = 1 + total_loss, losses = output + losses = utils.tensors_to_scalars(losses) + if isinstance(total_loss, torch.Tensor): + total_loss = total_loss.item() + for k, v in losses.items(): + if k not in all_losses_meter: + all_losses_meter[k] = utils.AvgrageMeter() + all_losses_meter[k].update(v, n) + all_losses_meter['total_loss'].update(total_loss, n) + loss_output = {k: round(v.avg, 4) for k, v in all_losses_meter.items()} + print(f"| Valid results: {loss_output}") + return { + 'tb_log': {f'val/{k}': v for k, v in loss_output.items()}, + 'val_loss': loss_output['total_loss'] + } + + ###################### + # testing + ###################### + def test_start(self): + pass + + def test_step(self, sample, batch_idx): + return self.validation_step(sample, batch_idx) + + def test_end(self, outputs): + return self.validation_end(outputs) + + ###################### + # utils + ###################### + def load_ckpt(self, ckpt_base_dir, current_model_name=None, model_name='model', force=True, strict=True): + if current_model_name is None: + current_model_name = model_name + utils.load_ckpt(self.__getattr__(current_model_name), ckpt_base_dir, current_model_name, force, strict) + + ###################### + # start training/testing + ###################### + @classmethod + def start(cls): + os.environ['MASTER_PORT'] = str(random.randint(15000, 30000)) + random.seed(hparams['seed']) + np.random.seed(hparams['seed']) + work_dir = hparams['work_dir'] + trainer = Trainer( + work_dir=work_dir, + val_check_interval=hparams['val_check_interval'], + tb_log_interval=hparams['tb_log_interval'], + max_updates=hparams['max_updates'], + num_sanity_val_steps=hparams['num_sanity_val_steps'] if not hparams['validate'] else 10000, + accumulate_grad_batches=hparams['accumulate_grad_batches'], + print_nan_grads=hparams['print_nan_grads'], + resume_from_checkpoint=hparams.get('resume_from_checkpoint', 0), + amp=hparams['amp'], + # save ckpt + monitor_key=hparams['valid_monitor_key'], + monitor_mode=hparams['valid_monitor_mode'], + num_ckpt_keep=hparams['num_ckpt_keep'], + save_best=hparams['save_best'], + seed=hparams['seed'], + debug=hparams['debug'] + ) + if not hparams['inference']: # train + if len(hparams['save_codes']) > 0: + t = datetime.now().strftime('%Y%m%d%H%M%S') + code_dir = f'{work_dir}/codes/{t}' + subprocess.check_call(f'mkdir -p "{code_dir}"', shell=True) + for c in hparams['save_codes']: + if os.path.exists(c): + subprocess.check_call(f'rsync -av --exclude=__pycache__ "{c}" "{code_dir}/"', shell=True) + print(f"| Copied codes to {code_dir}.") + trainer.fit(cls) + else: + trainer.test(cls) + + def on_keyboard_interrupt(self): + pass diff --git a/tasks/run.py b/tasks/run.py new file mode 100644 index 0000000000000000000000000000000000000000..82c7559cec873eebf7c2c0ab6554895e21de7e7c --- /dev/null +++ b/tasks/run.py @@ -0,0 +1,15 @@ +import importlib +from utils.hparams import set_hparams, hparams + + +def run_task(): + assert hparams['task_cls'] != '' + pkg = ".".join(hparams["task_cls"].split(".")[:-1]) + cls_name = hparams["task_cls"].split(".")[-1] + task_cls = getattr(importlib.import_module(pkg), cls_name) + task_cls.start() + + +if __name__ == '__main__': + set_hparams() + run_task() diff --git a/tasks/tts/dataset_utils.py 
b/tasks/tts/dataset_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..488e616dd63cb8fdf30c47e037a2acc21c41c7f3 --- /dev/null +++ b/tasks/tts/dataset_utils.py @@ -0,0 +1,260 @@ +from utils.cwt import get_lf0_cwt +import torch.optim +import torch.utils.data +import importlib +from utils.indexed_datasets import IndexedDataset +from utils.pitch_utils import norm_interp_f0, denorm_f0, f0_to_coarse +import numpy as np +from tasks.base_task import BaseDataset +import torch +import torch.optim +import torch.utils.data +import utils +import torch.distributions +from utils.hparams import hparams +from utils.pitch_utils import norm_interp_f0 +from resemblyzer import VoiceEncoder +import json +from data_gen.tts.data_gen_utils import build_phone_encoder + +class BaseTTSDataset(BaseDataset): + def __init__(self, prefix, shuffle=False, test_items=None, test_sizes=None, data_dir=None): + super().__init__(shuffle) + self.data_dir = hparams['binary_data_dir'] if data_dir is None else data_dir + self.prefix = prefix + self.hparams = hparams + self.indexed_ds = None + self.ext_mel2ph = None + + def load_size(): + self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy') + + if prefix == 'test' or hparams['inference']: + if test_items is not None: + self.indexed_ds, self.sizes = test_items, test_sizes + else: + load_size() + if hparams['num_test_samples'] > 0: + self.avail_idxs = [x for x in range(hparams['num_test_samples']) \ + if x < len(self.sizes)] + if len(hparams['test_ids']) > 0: + self.avail_idxs = hparams['test_ids'] + self.avail_idxs + else: + self.avail_idxs = list(range(len(self.sizes))) + else: + load_size() + self.avail_idxs = list(range(len(self.sizes))) + + if hparams['min_frames'] > 0: + self.avail_idxs = [ + x for x in self.avail_idxs if self.sizes[x] >= hparams['min_frames']] + self.sizes = [self.sizes[i] for i in self.avail_idxs] + + def _get_item(self, index): + if hasattr(self, 'avail_idxs') and self.avail_idxs is not None: + index = self.avail_idxs[index] + if self.indexed_ds is None: + self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}') + return self.indexed_ds[index] + + def __getitem__(self, index): + hparams = self.hparams + item = self._get_item(index) + assert len(item['mel']) == self.sizes[index], (len(item['mel']), self.sizes[index]) + max_frames = hparams['max_frames'] + spec = torch.Tensor(item['mel'])[:max_frames] + max_frames = spec.shape[0] // hparams['frames_multiple'] * hparams['frames_multiple'] + spec = spec[:max_frames] + phone = torch.LongTensor(item['phone'][:hparams['max_input_tokens']]) + sample = { + "id": index, + "item_name": item['item_name'], + "text": item['txt'], + "txt_token": phone, + "mel": spec, + "mel_nonpadding": spec.abs().sum(-1) > 0, + } + if hparams['use_spk_embed']: + sample["spk_embed"] = torch.Tensor(item['spk_embed']) + if hparams['use_spk_id']: + sample["spk_id"] = item['spk_id'] + return sample + + def collater(self, samples): + if len(samples) == 0: + return {} + hparams = self.hparams + id = torch.LongTensor([s['id'] for s in samples]) + item_names = [s['item_name'] for s in samples] + text = [s['text'] for s in samples] + txt_tokens = utils.collate_1d([s['txt_token'] for s in samples], 0) + mels = utils.collate_2d([s['mel'] for s in samples], 0.0) + txt_lengths = torch.LongTensor([s['txt_token'].numel() for s in samples]) + mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples]) + + batch = { + 'id': id, + 'item_name': item_names, + 'nsamples': len(samples), + 'text': text, + 
'txt_tokens': txt_tokens, + 'txt_lengths': txt_lengths, + 'mels': mels, + 'mel_lengths': mel_lengths, + } + + if hparams['use_spk_embed']: + spk_embed = torch.stack([s['spk_embed'] for s in samples]) + batch['spk_embed'] = spk_embed + if hparams['use_spk_id']: + spk_ids = torch.LongTensor([s['spk_id'] for s in samples]) + batch['spk_ids'] = spk_ids + return batch + + +class FastSpeechDataset(BaseTTSDataset): + def __init__(self, prefix, shuffle=False, test_items=None, test_sizes=None, data_dir=None): + super().__init__(prefix, shuffle, test_items, test_sizes, data_dir) + self.f0_mean, self.f0_std = hparams.get('f0_mean', None), hparams.get('f0_std', None) + if prefix == 'test' and hparams['test_input_dir'] != '': + self.data_dir = hparams['test_input_dir'] + self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}') + self.indexed_ds = sorted(self.indexed_ds, key=lambda item: item['item_name']) + items = {} + for i in range(len(self.indexed_ds)): + speaker = self.indexed_ds[i]['item_name'].split('_')[0] + if speaker not in items.keys(): + items[speaker] = [i] + else: + items[speaker].append(i) + sort_item = sorted(items.values(), key=lambda item_pre_speaker: len(item_pre_speaker), reverse=True) + self.avail_idxs = [n for a in sort_item for n in a][:hparams['num_test_samples']] + self.indexed_ds, self.sizes = self.load_test_inputs() + self.avail_idxs = [i for i in range(hparams['num_test_samples'])] + + if hparams['pitch_type'] == 'cwt': + _, hparams['cwt_scales'] = get_lf0_cwt(np.ones(10)) + + def __getitem__(self, index): + sample = super(FastSpeechDataset, self).__getitem__(index) + item = self._get_item(index) + hparams = self.hparams + max_frames = hparams['max_frames'] + spec = sample['mel'] + T = spec.shape[0] + phone = sample['txt_token'] + sample['energy'] = (spec.exp() ** 2).sum(-1).sqrt() + sample['mel2ph'] = mel2ph = torch.LongTensor(item['mel2ph'])[:T] if 'mel2ph' in item else None + if hparams['use_pitch_embed']: + assert 'f0' in item + if hparams.get('normalize_pitch', False): + f0 = item["f0"] + if len(f0 > 0) > 0 and f0[f0 > 0].std() > 0: + f0[f0 > 0] = (f0[f0 > 0] - f0[f0 > 0].mean()) / f0[f0 > 0].std() * hparams['f0_std'] + \ + hparams['f0_mean'] + f0[f0 > 0] = f0[f0 > 0].clip(min=60, max=500) + pitch = f0_to_coarse(f0) + pitch = torch.LongTensor(pitch[:max_frames]) + else: + pitch = torch.LongTensor(item.get("pitch"))[:max_frames] if "pitch" in item else None + f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams) + uv = torch.FloatTensor(uv) + f0 = torch.FloatTensor(f0) + if hparams['pitch_type'] == 'cwt': + cwt_spec = torch.Tensor(item['cwt_spec'])[:max_frames] + f0_mean = item.get('f0_mean', item.get('cwt_mean')) + f0_std = item.get('f0_std', item.get('cwt_std')) + sample.update({"cwt_spec": cwt_spec, "f0_mean": f0_mean, "f0_std": f0_std}) + elif hparams['pitch_type'] == 'ph': + if "f0_ph" in item: + f0 = torch.FloatTensor(item['f0_ph']) + else: + f0 = denorm_f0(f0, None, hparams) + f0_phlevel_sum = torch.zeros_like(phone).float().scatter_add(0, mel2ph - 1, f0) + f0_phlevel_num = torch.zeros_like(phone).float().scatter_add( + 0, mel2ph - 1, torch.ones_like(f0)).clamp_min(1) + f0_ph = f0_phlevel_sum / f0_phlevel_num + f0, uv = norm_interp_f0(f0_ph, hparams) + else: + f0 = uv = torch.zeros_like(mel2ph) + pitch = None + sample["f0"], sample["uv"], sample["pitch"] = f0, uv, pitch + if hparams['use_spk_embed']: + sample["spk_embed"] = torch.Tensor(item['spk_embed']) + if hparams['use_spk_id']: + sample["spk_id"] = item['spk_id'] + return sample + + def 
collater(self, samples): + if len(samples) == 0: + return {} + hparams = self.hparams + batch = super(FastSpeechDataset, self).collater(samples) + f0 = utils.collate_1d([s['f0'] for s in samples], 0.0) + pitch = utils.collate_1d([s['pitch'] for s in samples]) if samples[0]['pitch'] is not None else None + uv = utils.collate_1d([s['uv'] for s in samples]) + energy = utils.collate_1d([s['energy'] for s in samples], 0.0) + mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \ + if samples[0]['mel2ph'] is not None else None + batch.update({ + 'mel2ph': mel2ph, + 'energy': energy, + 'pitch': pitch, + 'f0': f0, + 'uv': uv, + }) + if hparams['pitch_type'] == 'cwt': + cwt_spec = utils.collate_2d([s['cwt_spec'] for s in samples]) + f0_mean = torch.Tensor([s['f0_mean'] for s in samples]) + f0_std = torch.Tensor([s['f0_std'] for s in samples]) + batch.update({'cwt_spec': cwt_spec, 'f0_mean': f0_mean, 'f0_std': f0_std}) + return batch + + def load_test_inputs(self): + binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizerr.BaseBinarizer') + pkg = ".".join(binarizer_cls.split(".")[:-1]) + cls_name = binarizer_cls.split(".")[-1] + binarizer_cls = getattr(importlib.import_module(pkg), cls_name) + ph_set_fn = f"{hparams['binary_data_dir']}/phone_set.json" + ph_set = json.load(open(ph_set_fn, 'r')) + print("| phone set: ", ph_set) + phone_encoder = build_phone_encoder(hparams['binary_data_dir']) + word_encoder = None + voice_encoder = VoiceEncoder().cuda() + encoder = [phone_encoder, word_encoder] + sizes = [] + items = [] + for i in range(len(self.avail_idxs)): + item = self._get_item(i) + + item2tgfn = f"{hparams['test_input_dir'].replace('binary', 'processed')}/mfa_outputs/{item['item_name']}.TextGrid" + item = binarizer_cls.process_item(item['item_name'], item['ph'], item['txt'], item2tgfn, + item['wav_fn'], item['spk_id'], encoder, hparams['binarization_args']) + item['spk_embed'] = voice_encoder.embed_utterance(item['wav']) \ + if hparams['binarization_args']['with_spk_embed'] else None # decide whether to save the speaker embedding file + items.append(item) + sizes.append(item['len']) + return items, sizes + +class FastSpeechWordDataset(FastSpeechDataset): + def __getitem__(self, index): + sample = super(FastSpeechWordDataset, self).__getitem__(index) + item = self._get_item(index) + max_frames = hparams['max_frames'] + sample["ph_words"] = item["ph_words"] + sample["word_tokens"] = torch.LongTensor(item["word_tokens"]) + sample["mel2word"] = torch.LongTensor(item.get("mel2word"))[:max_frames] + sample["ph2word"] = torch.LongTensor(item['ph2word'][:hparams['max_input_tokens']]) + return sample + + def collater(self, samples): + batch = super(FastSpeechWordDataset, self).collater(samples) + ph_words = [s['ph_words'] for s in samples] + batch['ph_words'] = ph_words + word_tokens = utils.collate_1d([s['word_tokens'] for s in samples], 0) + batch['word_tokens'] = word_tokens + mel2word = utils.collate_1d([s['mel2word'] for s in samples], 0) + batch['mel2word'] = mel2word + ph2word = utils.collate_1d([s['ph2word'] for s in samples], 0) + batch['ph2word'] = ph2word + return batch diff --git a/tasks/tts/fs2.py b/tasks/tts/fs2.py new file mode 100755 index 0000000000000000000000000000000000000000..473c514b523ecbd45acfdecdb33d7b633c59eb6c --- /dev/null +++ b/tasks/tts/fs2.py @@ -0,0 +1,292 @@ +import matplotlib +matplotlib.use('Agg') + +from tasks.tts.tts_base import TTSBaseTask +from vocoders.base_vocoder import get_vocoder_cls +from tasks.tts.dataset_utils import FastSpeechDataset +from modules.commons.ssim
import ssim +import os +from modules.fastspeech.tts_modules import mel2ph_to_dur +from utils.hparams import hparams +from utils.plot import spec_to_figure, dur_to_figure, f0_to_figure +from utils.pitch_utils import denorm_f0 +from modules.fastspeech.fs2 import FastSpeech2 +import torch +import torch.optim +import torch.utils.data +import torch.nn.functional as F +import utils +import torch.distributions +import numpy as np + + +class FastSpeech2Task(TTSBaseTask): + def __init__(self): + super(FastSpeech2Task, self).__init__() + self.dataset_cls = FastSpeechDataset + self.mse_loss_fn = torch.nn.MSELoss() + mel_losses = hparams['mel_loss'].split("|") + self.loss_and_lambda = {} + for i, l in enumerate(mel_losses): + if l == '': + continue + if ':' in l: + l, lbd = l.split(":") + lbd = float(lbd) + else: + lbd = 1.0 + self.loss_and_lambda[l] = lbd + print("| Mel losses:", self.loss_and_lambda) + self.sil_ph = self.phone_encoder.sil_phonemes() + f0_stats_fn = f'{hparams["binary_data_dir"]}/train_f0s_mean_std.npy' + if os.path.exists(f0_stats_fn): + hparams['f0_mean'], hparams['f0_std'] = np.load(f0_stats_fn) + hparams['f0_mean'] = float(hparams['f0_mean']) + hparams['f0_std'] = float(hparams['f0_std']) + + def build_tts_model(self): + self.model = FastSpeech2(self.phone_encoder) + + def build_model(self): + self.build_tts_model() + if hparams['load_ckpt'] != '': + self.load_ckpt(hparams['load_ckpt'], strict=False) + utils.print_arch(self.model) + return self.model + + def _training_step(self, sample, batch_idx, _): + loss_output = self.run_model(self.model, sample) + total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad]) + loss_output['batch_size'] = sample['txt_tokens'].size()[0] + return total_loss, loss_output + + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True) + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + mel_out = self.model.out2mel(model_out['mel_out']) + outputs = utils.tensors_to_scalars(outputs) + if self.global_step % hparams['valid_infer_interval'] == 0 \ + and batch_idx < hparams['num_valid_plots']: + vmin = hparams['mel_vmin'] + vmax = hparams['mel_vmax'] + self.plot_mel(batch_idx, sample['mels'], mel_out) + self.plot_dur(batch_idx, sample, model_out) + if hparams['use_pitch_embed']: + self.plot_pitch(batch_idx, sample, model_out) + if self.vocoder is None: + self.vocoder = get_vocoder_cls(hparams)() + if self.global_step > 0: + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + # with gt duration + model_out = self.model(sample['txt_tokens'], mel2ph=sample['mel2ph'], + spk_embed=spk_embed, infer=True) + wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu()) + self.logger.add_audio(f'wav_gtdur_{batch_idx}', wav_pred, self.global_step, + hparams['audio_sample_rate']) + self.logger.add_figure( + f'mel_gtdur_{batch_idx}', + spec_to_figure(model_out['mel_out'][0], vmin, vmax), self.global_step) + # with pred duration + model_out = self.model(sample['txt_tokens'], spk_embed=spk_embed, infer=True) + self.logger.add_figure( + f'mel_{batch_idx}', + spec_to_figure(model_out['mel_out'][0], vmin, vmax), self.global_step) + wav_pred = self.vocoder.spec2wav(model_out['mel_out'][0].cpu()) + self.logger.add_audio(f'wav_{batch_idx}', wav_pred, self.global_step, hparams['audio_sample_rate']) + # gt wav + if 
self.global_step <= hparams['valid_infer_interval']: + mel_gt = sample['mels'][0].cpu() + wav_gt = self.vocoder.spec2wav(mel_gt) + self.logger.add_audio(f'wav_gt_{batch_idx}', wav_gt, self.global_step, 22050) + return outputs + + def run_model(self, model, sample, return_output=False): + txt_tokens = sample['txt_tokens'] # [B, T_t] + target = sample['mels'] # [B, T_s, 80] + mel2ph = sample['mel2ph'] # [B, T_s] + f0 = sample['f0'] + uv = sample['uv'] + energy = sample['energy'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, + ref_mels=target, f0=f0, uv=uv, energy=energy, + tgt_mels=target, infer=False) + losses = {} + self.add_mel_loss(output['mel_out'], target, losses) + self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses) + if hparams['use_pitch_embed']: + self.add_pitch_loss(output, sample, losses) + if not return_output: + return losses + else: + return losses, output + + ############ + # losses + ############ + def add_mel_loss(self, mel_out, target, losses, postfix='', mel_mix_loss=None): + nonpadding = target.abs().sum(-1).ne(0).float() + for loss_name, lbd in self.loss_and_lambda.items(): + if 'l1' == loss_name: + l = self.l1_loss(mel_out, target) + elif 'mse' == loss_name: + l = self.mse_loss(mel_out, target) + elif 'ssim' == loss_name: + l = self.ssim_loss(mel_out, target) + elif 'gdl' == loss_name: + l = self.gdl_loss_fn(mel_out, target, nonpadding) \ + * self.loss_and_lambda['gdl'] + losses[f'{loss_name}{postfix}'] = l * lbd + + def l1_loss(self, decoder_output, target): + # decoder_output : B x T x n_mel + # target : B x T x n_mel + l1_loss = F.l1_loss(decoder_output, target, reduction='none') + weights = self.weights_nonzero_speech(target) + l1_loss = (l1_loss * weights).sum() / weights.sum() + return l1_loss + + def add_energy_loss(self, energy_pred, energy, losses): + nonpadding = (energy != 0).float() + loss = (F.mse_loss(energy_pred, energy, reduction='none') * nonpadding).sum() / nonpadding.sum() + loss = loss * hparams['lambda_energy'] + losses['e'] = loss + + def mse_loss(self, decoder_output, target): + # decoder_output : B x T x n_mel + # target : B x T x n_mel + assert decoder_output.shape == target.shape + mse_loss = F.mse_loss(decoder_output, target, reduction='none') + weights = self.weights_nonzero_speech(target) + mse_loss = (mse_loss * weights).sum() / weights.sum() + return mse_loss + + def ssim_loss(self, decoder_output, target, bias=6.0): + # decoder_output : B x T x n_mel + # target : B x T x n_mel + assert decoder_output.shape == target.shape + weights = self.weights_nonzero_speech(target) + decoder_output = decoder_output[:, None] + bias + target = target[:, None] + bias + ssim_loss = 1 - ssim(decoder_output, target, size_average=False) + ssim_loss = (ssim_loss * weights).sum() / weights.sum() + return ssim_loss + + def add_dur_loss(self, dur_pred, mel2ph, txt_tokens, losses=None): + """ + + :param dur_pred: [B, T], float, log scale + :param mel2ph: [B, T] + :param txt_tokens: [B, T] + :param losses: + :return: + """ + B, T = txt_tokens.shape + nonpadding = (txt_tokens != 0).float() + dur_gt = mel2ph_to_dur(mel2ph, T).float() * nonpadding + is_sil = torch.zeros_like(txt_tokens).bool() + for p in self.sil_ph: + is_sil = is_sil | (txt_tokens == self.phone_encoder.encode(p)[0]) + is_sil = is_sil.float() # [B, T_txt] + losses['pdur'] = F.mse_loss(dur_pred, (dur_gt + 1).log(), reduction='none') + losses['pdur'] = (losses['pdur'] * 
nonpadding).sum() / nonpadding.sum() + losses['pdur'] = losses['pdur'] * hparams['lambda_ph_dur'] + dur_pred = (dur_pred.exp() - 1).clamp(min=0) + # use linear scale for sent and word duration + if hparams['lambda_word_dur'] > 0: + word_id = (is_sil.cumsum(-1) * (1 - is_sil)).long() + word_dur_p = dur_pred.new_zeros([B, word_id.max() + 1]).scatter_add(1, word_id, dur_pred)[:, 1:] + word_dur_g = dur_gt.new_zeros([B, word_id.max() + 1]).scatter_add(1, word_id, dur_gt)[:, 1:] + wdur_loss = F.mse_loss((word_dur_p + 1).log(), (word_dur_g + 1).log(), reduction='none') + word_nonpadding = (word_dur_g > 0).float() + wdur_loss = (wdur_loss * word_nonpadding).sum() / word_nonpadding.sum() + losses['wdur'] = wdur_loss * hparams['lambda_word_dur'] + if hparams['lambda_sent_dur'] > 0: + sent_dur_p = dur_pred.sum(-1) + sent_dur_g = dur_gt.sum(-1) + sdur_loss = F.mse_loss((sent_dur_p + 1).log(), (sent_dur_g + 1).log(), reduction='mean') + losses['sdur'] = sdur_loss.mean() * hparams['lambda_sent_dur'] + + def add_pitch_loss(self, output, sample, losses): + mel2ph = sample['mel2ph'] # [B, T_s] + f0 = sample['f0'] + uv = sample['uv'] + nonpadding = (mel2ph != 0).float() if hparams['pitch_type'] == 'frame' \ + else (sample['txt_tokens'] != 0).float() + self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding) # output['pitch_pred']: [B, T, 2], f0: [B, T], uv: [B, T] + + def add_f0_loss(self, p_pred, f0, uv, losses, nonpadding, postfix=''): + assert p_pred[..., 0].shape == f0.shape + if hparams['use_uv'] and hparams['pitch_type'] == 'frame': + assert p_pred[..., 1].shape == uv.shape, (p_pred.shape, uv.shape) + losses[f'uv{postfix}'] = (F.binary_cross_entropy_with_logits( + p_pred[:, :, 1], uv, reduction='none') * nonpadding).sum() \ + / nonpadding.sum() * hparams['lambda_uv'] + nonpadding = nonpadding * (uv == 0).float() + f0_pred = p_pred[:, :, 0] + pitch_loss_fn = F.l1_loss if hparams['pitch_loss'] == 'l1' else F.mse_loss + losses[f'f0{postfix}'] = (pitch_loss_fn(f0_pred, f0, reduction='none') * nonpadding).sum() \ + / nonpadding.sum() * hparams['lambda_f0'] + + + ############ + # validation plots + ############ + def plot_dur(self, batch_idx, sample, model_out): + T_txt = sample['txt_tokens'].shape[1] + dur_gt = mel2ph_to_dur(sample['mel2ph'], T_txt)[0] + dur_pred = model_out['dur'] + if hasattr(self.model, 'out2dur'): + dur_pred = self.model.out2dur(model_out['dur']).float() + txt = self.phone_encoder.decode(sample['txt_tokens'][0].cpu().numpy()) + txt = txt.split(" ") + self.logger.add_figure( + f'dur_{batch_idx}', dur_to_figure(dur_gt, dur_pred, txt), self.global_step) + + def plot_pitch(self, batch_idx, sample, model_out): + self.logger.add_figure( + f'f0_{batch_idx}', + f0_to_figure(model_out['f0_denorm'][0], None, model_out['f0_denorm_pred'][0]), + self.global_step) + + ############ + # inference + ############ + def test_step(self, sample, batch_idx): + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + txt_tokens = sample['txt_tokens'] + mel2ph, uv, f0 = None, None, None + ref_mels = sample['mels'] + if hparams['use_gt_dur']: + mel2ph = sample['mel2ph'] + if hparams['use_gt_f0']: + f0 = sample['f0'] + uv = sample['uv'] + run_model = lambda: self.model( + txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, ref_mels=ref_mels, infer=True) + if hparams['profile_infer']: + mel2ph, uv, f0 = sample['mel2ph'], sample['uv'], sample['f0'] + with utils.Timer('fs', enable=True): + outputs = run_model() + if 'gen_wav_time' not in 
self.stats: + self.stats['gen_wav_time'] = 0 + wav_time = float(outputs["mels_out"].shape[1]) * hparams['hop_size'] / hparams["audio_sample_rate"] + self.stats['gen_wav_time'] += wav_time + print(f'[Timer] wav total seconds: {self.stats["gen_wav_time"]}') + from pytorch_memlab import LineProfiler + with LineProfiler(self.model.forward) as prof: + run_model() + prof.print_stats() + else: + outputs = run_model() + sample['outputs'] = self.model.out2mel(outputs['mel_out']) + sample['mel2ph_pred'] = outputs['mel2ph'] + if hparams['use_pitch_embed']: + sample['f0'] = denorm_f0(sample['f0'], sample['uv'], hparams) + if hparams['pitch_type'] == 'ph': + sample['f0'] = torch.gather(F.pad(sample['f0'], [1, 0]), 1, sample['mel2ph']) + sample['f0_pred'] = outputs.get('f0_denorm') + return self.after_infer(sample) diff --git a/tasks/tts/fs2_utils.py b/tasks/tts/fs2_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..092550863d2fd72f008cc790bc6d950340e68182 --- /dev/null +++ b/tasks/tts/fs2_utils.py @@ -0,0 +1,173 @@ +import matplotlib + +matplotlib.use('Agg') + +import glob +import importlib +from utils.cwt import get_lf0_cwt +import os +import torch.optim +import torch.utils.data +from utils.indexed_datasets import IndexedDataset +from utils.pitch_utils import norm_interp_f0 +import numpy as np +from tasks.base_task import BaseDataset +import torch +import torch.optim +import torch.utils.data +import utils +import torch.distributions +from utils.hparams import hparams + + +class FastSpeechDataset(BaseDataset): + def __init__(self, prefix, shuffle=False): + super().__init__(shuffle) + self.data_dir = hparams['binary_data_dir'] + self.prefix = prefix + self.hparams = hparams + self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy') + self.indexed_ds = None + # self.name2spk_id={} + + # pitch stats + f0_stats_fn = f'{self.data_dir}/train_f0s_mean_std.npy' + if os.path.exists(f0_stats_fn): + hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = np.load(f0_stats_fn) + hparams['f0_mean'] = float(hparams['f0_mean']) + hparams['f0_std'] = float(hparams['f0_std']) + else: + hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = None, None + + if prefix == 'test': + if hparams['test_input_dir'] != '': + self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir']) + else: + if hparams['num_test_samples'] > 0: + self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids'] + self.sizes = [self.sizes[i] for i in self.avail_idxs] + + if hparams['pitch_type'] == 'cwt': + _, hparams['cwt_scales'] = get_lf0_cwt(np.ones(10)) + + def _get_item(self, index): + if hasattr(self, 'avail_idxs') and self.avail_idxs is not None: + index = self.avail_idxs[index] + if self.indexed_ds is None: + self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}') + return self.indexed_ds[index] + + def __getitem__(self, index): + hparams = self.hparams + item = self._get_item(index) + max_frames = hparams['max_frames'] + spec = torch.Tensor(item['mel'])[:max_frames] + energy = (spec.exp() ** 2).sum(-1).sqrt() + mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None + f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams) + phone = torch.LongTensor(item['phone'][:hparams['max_input_tokens']]) + pitch = torch.LongTensor(item.get("pitch"))[:max_frames] + # print(item.keys(), item['mel'].shape, spec.shape) + sample = { + "id": index, + "item_name": item['item_name'], + "text": item['txt'], + "txt_token": 
phone, + "mel": spec, + "pitch": pitch, + "energy": energy, + "f0": f0, + "uv": uv, + "mel2ph": mel2ph, + "mel_nonpadding": spec.abs().sum(-1) > 0, + } + if self.hparams['use_spk_embed']: + sample["spk_embed"] = torch.Tensor(item['spk_embed']) + if self.hparams['use_spk_id']: + sample["spk_id"] = item['spk_id'] + # sample['spk_id'] = 0 + # for key in self.name2spk_id.keys(): + # if key in item['item_name']: + # sample['spk_id'] = self.name2spk_id[key] + # break + if self.hparams['pitch_type'] == 'cwt': + cwt_spec = torch.Tensor(item['cwt_spec'])[:max_frames] + f0_mean = item.get('f0_mean', item.get('cwt_mean')) + f0_std = item.get('f0_std', item.get('cwt_std')) + sample.update({"cwt_spec": cwt_spec, "f0_mean": f0_mean, "f0_std": f0_std}) + elif self.hparams['pitch_type'] == 'ph': + f0_phlevel_sum = torch.zeros_like(phone).float().scatter_add(0, mel2ph - 1, f0) + f0_phlevel_num = torch.zeros_like(phone).float().scatter_add( + 0, mel2ph - 1, torch.ones_like(f0)).clamp_min(1) + sample["f0_ph"] = f0_phlevel_sum / f0_phlevel_num + return sample + + def collater(self, samples): + if len(samples) == 0: + return {} + id = torch.LongTensor([s['id'] for s in samples]) + item_names = [s['item_name'] for s in samples] + text = [s['text'] for s in samples] + txt_tokens = utils.collate_1d([s['txt_token'] for s in samples], 0) + f0 = utils.collate_1d([s['f0'] for s in samples], 0.0) + pitch = utils.collate_1d([s['pitch'] for s in samples]) + uv = utils.collate_1d([s['uv'] for s in samples]) + energy = utils.collate_1d([s['energy'] for s in samples], 0.0) + mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \ + if samples[0]['mel2ph'] is not None else None + mels = utils.collate_2d([s['mel'] for s in samples], 0.0) + txt_lengths = torch.LongTensor([s['txt_token'].numel() for s in samples]) + mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples]) + + batch = { + 'id': id, + 'item_name': item_names, + 'nsamples': len(samples), + 'text': text, + 'txt_tokens': txt_tokens, + 'txt_lengths': txt_lengths, + 'mels': mels, + 'mel_lengths': mel_lengths, + 'mel2ph': mel2ph, + 'energy': energy, + 'pitch': pitch, + 'f0': f0, + 'uv': uv, + } + + if self.hparams['use_spk_embed']: + spk_embed = torch.stack([s['spk_embed'] for s in samples]) + batch['spk_embed'] = spk_embed + if self.hparams['use_spk_id']: + spk_ids = torch.LongTensor([s['spk_id'] for s in samples]) + batch['spk_ids'] = spk_ids + if self.hparams['pitch_type'] == 'cwt': + cwt_spec = utils.collate_2d([s['cwt_spec'] for s in samples]) + f0_mean = torch.Tensor([s['f0_mean'] for s in samples]) + f0_std = torch.Tensor([s['f0_std'] for s in samples]) + batch.update({'cwt_spec': cwt_spec, 'f0_mean': f0_mean, 'f0_std': f0_std}) + elif self.hparams['pitch_type'] == 'ph': + batch['f0'] = utils.collate_1d([s['f0_ph'] for s in samples]) + + return batch + + def load_test_inputs(self, test_input_dir, spk_id=0): + inp_wav_paths = glob.glob(f'{test_input_dir}/*.wav') + glob.glob(f'{test_input_dir}/*.mp3') + sizes = [] + items = [] + + binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizerr.BaseBinarizer') + pkg = ".".join(binarizer_cls.split(".")[:-1]) + cls_name = binarizer_cls.split(".")[-1] + binarizer_cls = getattr(importlib.import_module(pkg), cls_name) + binarization_args = hparams['binarization_args'] + + for wav_fn in inp_wav_paths: + item_name = os.path.basename(wav_fn) + ph = txt = tg_fn = '' + wav_fn = wav_fn + encoder = None + item = binarizer_cls.process_item(item_name, ph, txt, tg_fn, wav_fn, spk_id, encoder, 
binarization_args) + items.append(item) + sizes.append(item['len']) + return items, sizes diff --git a/tasks/tts/pe.py b/tasks/tts/pe.py new file mode 100644 index 0000000000000000000000000000000000000000..3880c80d0820c36e044c00bd38a07fd3cce73323 --- /dev/null +++ b/tasks/tts/pe.py @@ -0,0 +1,155 @@ +import matplotlib +matplotlib.use('Agg') + +import torch +import numpy as np +import os + +from tasks.base_task import BaseDataset +from tasks.tts.fs2 import FastSpeech2Task +from modules.fastspeech.pe import PitchExtractor +import utils +from utils.indexed_datasets import IndexedDataset +from utils.hparams import hparams +from utils.plot import f0_to_figure +from utils.pitch_utils import norm_interp_f0, denorm_f0 + + +class PeDataset(BaseDataset): + def __init__(self, prefix, shuffle=False): + super().__init__(shuffle) + self.data_dir = hparams['binary_data_dir'] + self.prefix = prefix + self.hparams = hparams + self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy') + self.indexed_ds = None + + # pitch stats + f0_stats_fn = f'{self.data_dir}/train_f0s_mean_std.npy' + if os.path.exists(f0_stats_fn): + hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = np.load(f0_stats_fn) + hparams['f0_mean'] = float(hparams['f0_mean']) + hparams['f0_std'] = float(hparams['f0_std']) + else: + hparams['f0_mean'], hparams['f0_std'] = self.f0_mean, self.f0_std = None, None + + if prefix == 'test': + if hparams['num_test_samples'] > 0: + self.avail_idxs = list(range(hparams['num_test_samples'])) + hparams['test_ids'] + self.sizes = [self.sizes[i] for i in self.avail_idxs] + + def _get_item(self, index): + if hasattr(self, 'avail_idxs') and self.avail_idxs is not None: + index = self.avail_idxs[index] + if self.indexed_ds is None: + self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}') + return self.indexed_ds[index] + + def __getitem__(self, index): + hparams = self.hparams + item = self._get_item(index) + max_frames = hparams['max_frames'] + spec = torch.Tensor(item['mel'])[:max_frames] + # mel2ph = torch.LongTensor(item['mel2ph'])[:max_frames] if 'mel2ph' in item else None + f0, uv = norm_interp_f0(item["f0"][:max_frames], hparams) + pitch = torch.LongTensor(item.get("pitch"))[:max_frames] + # print(item.keys(), item['mel'].shape, spec.shape) + sample = { + "id": index, + "item_name": item['item_name'], + "text": item['txt'], + "mel": spec, + "pitch": pitch, + "f0": f0, + "uv": uv, + # "mel2ph": mel2ph, + # "mel_nonpadding": spec.abs().sum(-1) > 0, + } + return sample + + def collater(self, samples): + if len(samples) == 0: + return {} + id = torch.LongTensor([s['id'] for s in samples]) + item_names = [s['item_name'] for s in samples] + text = [s['text'] for s in samples] + f0 = utils.collate_1d([s['f0'] for s in samples], 0.0) + pitch = utils.collate_1d([s['pitch'] for s in samples]) + uv = utils.collate_1d([s['uv'] for s in samples]) + mels = utils.collate_2d([s['mel'] for s in samples], 0.0) + mel_lengths = torch.LongTensor([s['mel'].shape[0] for s in samples]) + # mel2ph = utils.collate_1d([s['mel2ph'] for s in samples], 0.0) \ + # if samples[0]['mel2ph'] is not None else None + # mel_nonpaddings = utils.collate_1d([s['mel_nonpadding'].float() for s in samples], 0.0) + + batch = { + 'id': id, + 'item_name': item_names, + 'nsamples': len(samples), + 'text': text, + 'mels': mels, + 'mel_lengths': mel_lengths, + 'pitch': pitch, + # 'mel2ph': mel2ph, + # 'mel_nonpaddings': mel_nonpaddings, + 'f0': f0, + 'uv': uv, + } + return batch + + +class 
PitchExtractionTask(FastSpeech2Task): + def __init__(self): + super().__init__() + self.dataset_cls = PeDataset + + def build_tts_model(self): + self.model = PitchExtractor(conv_layers=hparams['pitch_extractor_conv_layers']) + + # def build_scheduler(self, optimizer): + # return torch.optim.lr_scheduler.StepLR(optimizer, hparams['decay_steps'], gamma=0.5) + def _training_step(self, sample, batch_idx, _): + loss_output = self.run_model(self.model, sample) + total_loss = sum([v for v in loss_output.values() if isinstance(v, torch.Tensor) and v.requires_grad]) + loss_output['batch_size'] = sample['mels'].size()[0] + return total_loss, loss_output + + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=True) + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + outputs = utils.tensors_to_scalars(outputs) + if batch_idx < hparams['num_valid_plots']: + self.plot_pitch(batch_idx, model_out, sample) + return outputs + + def run_model(self, model, sample, return_output=False, infer=False): + f0 = sample['f0'] + uv = sample['uv'] + output = model(sample['mels']) + losses = {} + self.add_pitch_loss(output, sample, losses) + if not return_output: + return losses + else: + return losses, output + + def plot_pitch(self, batch_idx, model_out, sample): + gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams) + self.logger.experiment.add_figure( + f'f0_{batch_idx}', + f0_to_figure(gt_f0[0], None, model_out['f0_denorm_pred'][0]), + self.global_step) + + def add_pitch_loss(self, output, sample, losses): + # mel2ph = sample['mel2ph'] # [B, T_s] + mel = sample['mels'] + f0 = sample['f0'] + uv = sample['uv'] + # nonpadding = (mel2ph != 0).float() if hparams['pitch_type'] == 'frame' \ + # else (sample['txt_tokens'] != 0).float() + nonpadding = (mel.abs().sum(-1) > 0).float() # sample['mel_nonpaddings'] + # print(nonpadding[0][-8:], nonpadding.shape) + self.add_f0_loss(output['pitch_pred'], f0, uv, losses, nonpadding=nonpadding) \ No newline at end of file diff --git a/tasks/tts/tts.py b/tasks/tts/tts.py new file mode 100644 index 0000000000000000000000000000000000000000..f803c1e738137cb1eca19a1943196abd2884c0a5 --- /dev/null +++ b/tasks/tts/tts.py @@ -0,0 +1,131 @@ +from multiprocessing.pool import Pool + +import matplotlib + +from utils.pl_utils import data_loader +from utils.training_utils import RSQRTSchedule +from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder +from modules.fastspeech.pe import PitchExtractor + +matplotlib.use('Agg') +import os +import numpy as np +from tqdm import tqdm +import torch.distributed as dist + +from tasks.base_task import BaseTask +from utils.hparams import hparams +from utils.text_encoder import TokenTextEncoder +import json + +import torch +import torch.optim +import torch.utils.data +import utils + + + +class TtsTask(BaseTask): + def __init__(self, *args, **kwargs): + self.vocoder = None + self.phone_encoder = self.build_phone_encoder(hparams['binary_data_dir']) + self.padding_idx = self.phone_encoder.pad() + self.eos_idx = self.phone_encoder.eos() + self.seg_idx = self.phone_encoder.seg() + self.saving_result_pool = None + self.saving_results_futures = None + self.stats = {} + super().__init__(*args, **kwargs) + + def build_scheduler(self, optimizer): + return RSQRTSchedule(optimizer) + + def build_optimizer(self, model): + self.optimizer = optimizer = torch.optim.AdamW( + 
model.parameters(), + lr=hparams['lr']) + return optimizer + + def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None, + required_batch_size_multiple=-1, endless=False, batch_by_size=True): + devices_cnt = torch.cuda.device_count() + if devices_cnt == 0: + devices_cnt = 1 + if required_batch_size_multiple == -1: + required_batch_size_multiple = devices_cnt + + def shuffle_batches(batches): + np.random.shuffle(batches) + return batches + + if max_tokens is not None: + max_tokens *= devices_cnt + if max_sentences is not None: + max_sentences *= devices_cnt + indices = dataset.ordered_indices() + if batch_by_size: + batch_sampler = utils.batch_by_size( + indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + else: + batch_sampler = [] + for i in range(0, len(indices), max_sentences): + batch_sampler.append(indices[i:i + max_sentences]) + + if shuffle: + batches = shuffle_batches(list(batch_sampler)) + if endless: + batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))] + else: + batches = batch_sampler + if endless: + batches = [b for _ in range(1000) for b in batches] + num_workers = dataset.num_workers + if self.trainer.use_ddp: + num_replicas = dist.get_world_size() + rank = dist.get_rank() + batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0] + return torch.utils.data.DataLoader(dataset, + collate_fn=dataset.collater, + batch_sampler=batches, + num_workers=num_workers, + pin_memory=False) + + def build_phone_encoder(self, data_dir): + phone_list_file = os.path.join(data_dir, 'phone_set.json') + + phone_list = json.load(open(phone_list_file)) + return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',') + + def build_optimizer(self, model): + self.optimizer = optimizer = torch.optim.AdamW( + model.parameters(), + lr=hparams['lr']) + return optimizer + + def test_start(self): + self.saving_result_pool = Pool(8) + self.saving_results_futures = [] + self.vocoder: BaseVocoder = get_vocoder_cls(hparams)() + if hparams.get('pe_enable') is not None and hparams['pe_enable']: + self.pe = PitchExtractor().cuda() + utils.load_ckpt(self.pe, hparams['pe_ckpt'], 'model', strict=True) + self.pe.eval() + def test_end(self, outputs): + self.saving_result_pool.close() + [f.get() for f in tqdm(self.saving_results_futures)] + self.saving_result_pool.join() + return {} + + ########## + # utils + ########## + def weights_nonzero_speech(self, target): + # target : B x T x mel + # Assign weight 1.0 to all labels except for padding (id=0). 
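# --- Worked example of the DDP batch sharding in build_dataloader above ---
# (illustrative numbers). Each rank keeps every num_replicas-th index of a
# batch, and batches whose size is not divisible by the world size are dropped
# so that all ranks process the same number of steps.
batches = [[0, 1, 2, 3], [4, 5, 6], [7, 8]]
num_replicas, rank = 2, 0
print([x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0])
# -> [[0, 2], [7]] on rank 0; rank 1 would keep [[1, 3], [8]]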
+ dim = target.size(-1) + return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim) + +if __name__ == '__main__': + TtsTask.start() diff --git a/tasks/tts/tts_base.py b/tasks/tts/tts_base.py new file mode 100755 index 0000000000000000000000000000000000000000..509740b54dbf23db6bafebd6bc46089ee83cf499 --- /dev/null +++ b/tasks/tts/tts_base.py @@ -0,0 +1,305 @@ +import filecmp + +import matplotlib + +from utils.plot import spec_to_figure + +matplotlib.use('Agg') + +from data_gen.tts.data_gen_utils import get_pitch +from modules.fastspeech.tts_modules import mel2ph_to_dur +from tasks.tts.dataset_utils import BaseTTSDataset +from utils.tts_utils import sequence_mask +from multiprocessing.pool import Pool +from tasks.base_task import data_loader, BaseConcatDataset +from utils.common_schedulers import RSQRTSchedule, NoneSchedule +from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder +import os +import numpy as np +from tqdm import tqdm +import torch.distributed as dist +from tasks.base_task import BaseTask +from utils.hparams import hparams +from utils.text_encoder import TokenTextEncoder +import json +import matplotlib.pyplot as plt +import torch +import torch.optim +import torch.utils.data +import utils +from utils import audio +import pandas as pd + + +class TTSBaseTask(BaseTask): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dataset_cls = BaseTTSDataset + self.max_tokens = hparams['max_tokens'] + self.max_sentences = hparams['max_sentences'] + self.max_valid_tokens = hparams['max_valid_tokens'] + if self.max_valid_tokens == -1: + hparams['max_valid_tokens'] = self.max_valid_tokens = self.max_tokens + self.max_valid_sentences = hparams['max_valid_sentences'] + if self.max_valid_sentences == -1: + hparams['max_valid_sentences'] = self.max_valid_sentences = self.max_sentences + self.vocoder = None + self.phone_encoder = self.build_phone_encoder(hparams['binary_data_dir']) + self.padding_idx = self.phone_encoder.pad() + self.eos_idx = self.phone_encoder.eos() + self.seg_idx = self.phone_encoder.seg() + self.saving_result_pool = None + self.saving_results_futures = None + self.stats = {} + + @data_loader + def train_dataloader(self): + if hparams['train_sets'] != '': + train_sets = hparams['train_sets'].split("|") + # check if all train_sets have the same spk map and dictionary + binary_data_dir = hparams['binary_data_dir'] + file_to_cmp = ['phone_set.json'] + if os.path.exists(f'{binary_data_dir}/word_set.json'): + file_to_cmp.append('word_set.json') + if hparams['use_spk_id']: + file_to_cmp.append('spk_map.json') + for f in file_to_cmp: + for ds_name in train_sets: + base_file = os.path.join(binary_data_dir, f) + ds_file = os.path.join(ds_name, f) + assert filecmp.cmp(base_file, ds_file), \ + f'{f} in {ds_name} is not same with that in {binary_data_dir}.' 
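# --- Worked example for weights_nonzero_speech above (sketch) ---
# A frame counts as padding iff its entire mel row is zero; padded frames get
# weight 0 and real frames weight 1.0, broadcast over all mel bins, so the
# masked L1/MSE/SSIM losses average only over real frames.
import torch
target = torch.zeros(1, 3, 4)   # B=1, T=3, 4 "mel" bins for brevity
target[0, 0] = 0.5              # only frame 0 carries speech
w = target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, target.size(-1))
print(w[:, :, 0])               # tensor([[1., 0., 0.]])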
+ train_dataset = BaseConcatDataset([ + self.dataset_cls(prefix='train', shuffle=True, data_dir=ds_name) for ds_name in train_sets]) + else: + train_dataset = self.dataset_cls(prefix=hparams['train_set_name'], shuffle=True) + return self.build_dataloader(train_dataset, True, self.max_tokens, self.max_sentences, + endless=hparams['endless_ds']) + + @data_loader + def val_dataloader(self): + valid_dataset = self.dataset_cls(prefix=hparams['valid_set_name'], shuffle=False) + return self.build_dataloader(valid_dataset, False, self.max_valid_tokens, self.max_valid_sentences) + + @data_loader + def test_dataloader(self): + test_dataset = self.dataset_cls(prefix=hparams['test_set_name'], shuffle=False) + self.test_dl = self.build_dataloader( + test_dataset, False, self.max_valid_tokens, + self.max_valid_sentences, batch_by_size=False) + return self.test_dl + + def build_dataloader(self, dataset, shuffle, max_tokens=None, max_sentences=None, + required_batch_size_multiple=-1, endless=False, batch_by_size=True): + devices_cnt = torch.cuda.device_count() + if devices_cnt == 0: + devices_cnt = 1 + if required_batch_size_multiple == -1: + required_batch_size_multiple = devices_cnt + + def shuffle_batches(batches): + np.random.shuffle(batches) + return batches + + if max_tokens is not None: + max_tokens *= devices_cnt + if max_sentences is not None: + max_sentences *= devices_cnt + indices = dataset.ordered_indices() + if batch_by_size: + batch_sampler = utils.batch_by_size( + indices, dataset.num_tokens, max_tokens=max_tokens, max_sentences=max_sentences, + required_batch_size_multiple=required_batch_size_multiple, + ) + else: + batch_sampler = [] + for i in range(0, len(indices), max_sentences): + batch_sampler.append(indices[i:i + max_sentences]) + + if shuffle: + batches = shuffle_batches(list(batch_sampler)) + if endless: + batches = [b for _ in range(1000) for b in shuffle_batches(list(batch_sampler))] + else: + batches = batch_sampler + if endless: + batches = [b for _ in range(1000) for b in batches] + num_workers = dataset.num_workers + if self.trainer.use_ddp: + num_replicas = dist.get_world_size() + rank = dist.get_rank() + batches = [x[rank::num_replicas] for x in batches if len(x) % num_replicas == 0] + return torch.utils.data.DataLoader(dataset, + collate_fn=dataset.collater, + batch_sampler=batches, + num_workers=num_workers, + pin_memory=False) + + def build_phone_encoder(self, data_dir): + phone_list_file = os.path.join(data_dir, 'phone_set.json') + phone_list = json.load(open(phone_list_file)) + return TokenTextEncoder(None, vocab_list=phone_list, replace_oov=',') + + def build_scheduler(self, optimizer): + if hparams['scheduler'] == 'rsqrt': + return RSQRTSchedule(optimizer) + else: + return NoneSchedule(optimizer) + + def build_optimizer(self, model): + self.optimizer = optimizer = torch.optim.AdamW( + model.parameters(), + lr=hparams['lr'], + betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']), + weight_decay=hparams['weight_decay']) + return optimizer + + def plot_mel(self, batch_idx, spec, spec_out, name=None): + spec_cat = torch.cat([spec, spec_out], -1) + name = f'mel_{batch_idx}' if name is None else name + vmin = hparams['mel_vmin'] + vmax = hparams['mel_vmax'] + self.logger.add_figure(name, spec_to_figure(spec_cat[0], vmin, vmax), self.global_step) + + def test_start(self): + self.saving_result_pool = Pool(min(int(os.getenv('N_PROC', os.cpu_count())), 16)) + self.saving_results_futures = [] + self.results_id = 0 + self.gen_dir = os.path.join( + 
hparams['work_dir'], + f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') + self.vocoder: BaseVocoder = get_vocoder_cls(hparams)() + + def after_infer(self, predictions, sil_start_frame=0): + predictions = utils.unpack_dict_to_list(predictions) + assert len(predictions) == 1, 'Only support batch_size=1 in inference.' + prediction = predictions[0] + prediction = utils.tensors_to_np(prediction) + item_name = prediction.get('item_name') + text = prediction.get('text') + ph_tokens = prediction.get('txt_tokens') + mel_gt = prediction["mels"] + mel2ph_gt = prediction.get("mel2ph") + mel2ph_gt = mel2ph_gt if mel2ph_gt is not None else None + mel_pred = prediction["outputs"] + mel2ph_pred = prediction.get("mel2ph_pred") + f0_gt = prediction.get("f0") + f0_pred = prediction.get("f0_pred") + + str_phs = None + if self.phone_encoder is not None and 'txt_tokens' in prediction: + str_phs = self.phone_encoder.decode(prediction['txt_tokens'], strip_padding=True) + + if 'encdec_attn' in prediction: + encdec_attn = prediction['encdec_attn'] + encdec_attn = encdec_attn[encdec_attn.max(-1).sum(-1).argmax(-1)] + txt_lengths = prediction.get('txt_lengths') + encdec_attn = encdec_attn.T[:txt_lengths, :len(mel_gt)] + else: + encdec_attn = None + + wav_pred = self.vocoder.spec2wav(mel_pred, f0=f0_pred) + wav_pred[:sil_start_frame * hparams['hop_size']] = 0 + gen_dir = self.gen_dir + base_fn = f'[{self.results_id:06d}][{item_name}][%s]' + # if text is not None: + # base_fn += text.replace(":", "%3A")[:80] + base_fn = base_fn.replace(' ', '_') + if not hparams['profile_infer']: + os.makedirs(gen_dir, exist_ok=True) + os.makedirs(f'{gen_dir}/wavs', exist_ok=True) + os.makedirs(f'{gen_dir}/plot', exist_ok=True) + if hparams.get('save_mel_npy', False): + os.makedirs(f'{gen_dir}/npy', exist_ok=True) + if 'encdec_attn' in prediction: + os.makedirs(f'{gen_dir}/attn_plot', exist_ok=True) + self.saving_results_futures.append( + self.saving_result_pool.apply_async(self.save_result, args=[ + wav_pred, mel_pred, base_fn % 'P', gen_dir, str_phs, mel2ph_pred, encdec_attn])) + + if mel_gt is not None and hparams['save_gt']: + wav_gt = self.vocoder.spec2wav(mel_gt, f0=f0_gt) + self.saving_results_futures.append( + self.saving_result_pool.apply_async(self.save_result, args=[ + wav_gt, mel_gt, base_fn % 'G', gen_dir, str_phs, mel2ph_gt])) + if hparams['save_f0']: + import matplotlib.pyplot as plt + f0_pred_, _ = get_pitch(wav_pred, mel_pred, hparams) + f0_gt_, _ = get_pitch(wav_gt, mel_gt, hparams) + fig = plt.figure() + plt.plot(f0_pred_, label=r'$\hat{f_0}$') + plt.plot(f0_gt_, label=r'$f_0$') + plt.legend() + plt.tight_layout() + plt.savefig(f'{gen_dir}/plot/[F0][{item_name}]{text}.png', format='png') + plt.close(fig) + print(f"Pred_shape: {mel_pred.shape}, gt_shape: {mel_gt.shape}") + self.results_id += 1 + return { + 'item_name': item_name, + 'text': text, + 'ph_tokens': self.phone_encoder.decode(ph_tokens.tolist()), + 'wav_fn_pred': base_fn % 'P', + 'wav_fn_gt': base_fn % 'G', + } + + @staticmethod + def save_result(wav_out, mel, base_fn, gen_dir, str_phs=None, mel2ph=None, alignment=None): + audio.save_wav(wav_out, f'{gen_dir}/wavs/{base_fn}.wav', hparams['audio_sample_rate'], + norm=hparams['out_wav_norm']) + fig = plt.figure(figsize=(14, 10)) + spec_vmin = hparams['mel_vmin'] + spec_vmax = hparams['mel_vmax'] + heatmap = plt.pcolor(mel.T, vmin=spec_vmin, vmax=spec_vmax) + fig.colorbar(heatmap) + f0, _ = get_pitch(wav_out, mel, hparams) + f0 = f0 / 10 * (f0 > 0) + plt.plot(f0, c='white', linewidth=1, 
alpha=0.6) + if mel2ph is not None and str_phs is not None: + decoded_txt = str_phs.split(" ") + dur = mel2ph_to_dur(torch.LongTensor(mel2ph)[None, :], len(decoded_txt))[0].numpy() + dur = [0] + list(np.cumsum(dur)) + for i in range(len(dur) - 1): + shift = (i % 20) + 1 + plt.text(dur[i], shift, decoded_txt[i]) + plt.hlines(shift, dur[i], dur[i + 1], colors='b' if decoded_txt[i] != '|' else 'black') + plt.vlines(dur[i], 0, 5, colors='b' if decoded_txt[i] != '|' else 'black', + alpha=1, linewidth=1) + plt.tight_layout() + plt.savefig(f'{gen_dir}/plot/{base_fn}.png', format='png') + plt.close(fig) + if hparams.get('save_mel_npy', False): + np.save(f'{gen_dir}/npy/{base_fn}', mel) + if alignment is not None: + fig, ax = plt.subplots(figsize=(12, 16)) + im = ax.imshow(alignment, aspect='auto', origin='lower', + interpolation='none') + decoded_txt = str_phs.split(" ") + ax.set_yticks(np.arange(len(decoded_txt))) + ax.set_yticklabels(list(decoded_txt), fontsize=6) + fig.colorbar(im, ax=ax) + fig.savefig(f'{gen_dir}/attn_plot/{base_fn}_attn.png', format='png') + plt.close(fig) + + def test_end(self, outputs): + pd.DataFrame(outputs).to_csv(f'{self.gen_dir}/meta.csv') + self.saving_result_pool.close() + [f.get() for f in tqdm(self.saving_results_futures)] + self.saving_result_pool.join() + return {} + + ########## + # utils + ########## + def weights_nonzero_speech(self, target): + # target : B x T x mel + # Assign weight 1.0 to all labels except for padding (id=0). + dim = target.size(-1) + return target.abs().sum(-1, keepdim=True).ne(0).float().repeat(1, 1, dim) + + def make_stop_target(self, target): + # target : B x T x mel + seq_mask = target.abs().sum(-1).ne(0).float() + seq_length = seq_mask.sum(1) + mask_r = 1 - sequence_mask(seq_length - 1, target.size(1)).float() + return seq_mask, mask_r diff --git a/tasks/tts/tts_utils.py b/tasks/tts/tts_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e13439ee72e4fda220605c5868b3159110d9129b --- /dev/null +++ b/tasks/tts/tts_utils.py @@ -0,0 +1,54 @@ +import importlib + +from data_gen.tts.base_binarizer import BaseBinarizer +from data_gen.tts.base_preprocess import BasePreprocessor +from data_gen.tts.txt_processors.base_text_processor import get_txt_processor_cls +from utils.hparams import hparams + + +def parse_dataset_configs(): + max_tokens = hparams['max_tokens'] + max_sentences = hparams['max_sentences'] + max_valid_tokens = hparams['max_valid_tokens'] + if max_valid_tokens == -1: + hparams['max_valid_tokens'] = max_valid_tokens = max_tokens + max_valid_sentences = hparams['max_valid_sentences'] + if max_valid_sentences == -1: + hparams['max_valid_sentences'] = max_valid_sentences = max_sentences + return max_tokens, max_sentences, max_valid_tokens, max_valid_sentences + + +def parse_mel_losses(): + mel_losses = hparams['mel_losses'].split("|") + loss_and_lambda = {} + for i, l in enumerate(mel_losses): + if l == '': + continue + if ':' in l: + l, lbd = l.split(":") + lbd = float(lbd) + else: + lbd = 1.0 + loss_and_lambda[l] = lbd + print("| Mel losses:", loss_and_lambda) + return loss_and_lambda + + +def load_data_preprocessor(): + preprocess_cls = hparams["preprocess_cls"] + pkg = ".".join(preprocess_cls.split(".")[:-1]) + cls_name = preprocess_cls.split(".")[-1] + preprocessor: BasePreprocessor = getattr(importlib.import_module(pkg), cls_name)() + preprocess_args = {} + preprocess_args.update(hparams['preprocess_args']) + return preprocessor, preprocess_args + + +def load_data_binarizer(): + binarizer_cls = 
hparams['binarizer_cls'] + pkg = ".".join(binarizer_cls.split(".")[:-1]) + cls_name = binarizer_cls.split(".")[-1] + binarizer: BaseBinarizer = getattr(importlib.import_module(pkg), cls_name)() + binarization_args = {} + binarization_args.update(hparams['binarization_args']) + return binarizer, binarization_args \ No newline at end of file diff --git a/tasks/vocoder/dataset_utils.py b/tasks/vocoder/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..05dcdaa524efde31575dd30b57b627d22744b53c --- /dev/null +++ b/tasks/vocoder/dataset_utils.py @@ -0,0 +1,204 @@ +import glob +import importlib +import os +from resemblyzer import VoiceEncoder +import numpy as np +import torch +import torch.distributed as dist +from torch.utils.data import DistributedSampler +import utils +from tasks.base_task import BaseDataset +from utils.hparams import hparams +from utils.indexed_datasets import IndexedDataset +from tqdm import tqdm + +class EndlessDistributedSampler(DistributedSampler): + def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): + if num_replicas is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = dist.get_world_size() + if rank is None: + if not dist.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = dist.get_rank() + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.shuffle = shuffle + + g = torch.Generator() + g.manual_seed(self.epoch) + if self.shuffle: + indices = [i for _ in range(1000) for i in torch.randperm( + len(self.dataset), generator=g).tolist()] + else: + indices = [i for _ in range(1000) for i in list(range(len(self.dataset)))] + indices = indices[:len(indices) // self.num_replicas * self.num_replicas] + indices = indices[self.rank::self.num_replicas] + self.indices = indices + + def __iter__(self): + return iter(self.indices) + + def __len__(self): + return len(self.indices) + + +class VocoderDataset(BaseDataset): + def __init__(self, prefix, shuffle=False): + super().__init__(shuffle) + self.hparams = hparams + self.prefix = prefix + self.data_dir = hparams['binary_data_dir'] + self.is_infer = prefix == 'test' + self.batch_max_frames = 0 if self.is_infer else hparams['max_samples'] // hparams['hop_size'] + self.aux_context_window = hparams['aux_context_window'] + self.hop_size = hparams['hop_size'] + if self.is_infer and hparams['test_input_dir'] != '': + self.indexed_ds, self.sizes = self.load_test_inputs(hparams['test_input_dir']) + self.avail_idxs = [i for i, _ in enumerate(self.sizes)] + elif self.is_infer and hparams['test_mel_dir'] != '': + self.indexed_ds, self.sizes = self.load_mel_inputs(hparams['test_mel_dir']) + self.avail_idxs = [i for i, _ in enumerate(self.sizes)] + else: + self.indexed_ds = None + self.sizes = np.load(f'{self.data_dir}/{self.prefix}_lengths.npy') + self.avail_idxs = [idx for idx, s in enumerate(self.sizes) if + s - 2 * self.aux_context_window > self.batch_max_frames] + print(f"| {len(self.sizes) - len(self.avail_idxs)} short items are skipped in {prefix} set.") + self.sizes = [s for idx, s in enumerate(self.sizes) if + s - 2 * self.aux_context_window > self.batch_max_frames] + + def _get_item(self, index): + if self.indexed_ds is None: + self.indexed_ds = IndexedDataset(f'{self.data_dir}/{self.prefix}') + item = self.indexed_ds[index] + return item + + def __getitem__(self, index): + index = self.avail_idxs[index] + item = 
self._get_item(index)
+        sample = {
+            "id": index,
+            "item_name": item['item_name'],
+            "mel": torch.FloatTensor(item['mel']),
+            "wav": torch.FloatTensor(item['wav'].astype(np.float32)),
+        }
+        if 'pitch' in item:
+            sample['pitch'] = torch.LongTensor(item['pitch'])
+            sample['f0'] = torch.FloatTensor(item['f0'])
+
+        if hparams.get('use_spk_embed', False):
+            sample["spk_embed"] = torch.Tensor(item['spk_embed'])
+        if hparams.get('use_emo_embed', False):
+            sample["emo_embed"] = torch.Tensor(item['emo_embed'])
+
+        return sample
+
+    def collater(self, batch):
+        if len(batch) == 0:
+            return {}
+
+        y_batch, c_batch, p_batch, f0_batch = [], [], [], []
+        item_name = []
+        have_pitch = 'pitch' in batch[0]
+        for idx in range(len(batch)):
+            item_name.append(batch[idx]['item_name'])
+            x = batch[idx]['wav'] if self.hparams['use_wav'] else None
+            c = batch[idx]['mel'].squeeze(0)
+            if have_pitch:
+                p = batch[idx]['pitch']
+                f0 = batch[idx]['f0']
+            if self.hparams['use_wav']:
+                self._assert_ready_for_upsampling(x, c, self.hop_size, 0)
+            if len(c) - 2 * self.aux_context_window > self.batch_max_frames:
+                # randomly pick a segment of batch_max_steps samples (batch_max_frames frames)
+                batch_max_frames = self.batch_max_frames if self.batch_max_frames != 0 \
+                    else len(c) - 2 * self.aux_context_window - 1
+                batch_max_steps = batch_max_frames * self.hop_size
+                interval_start = self.aux_context_window
+                interval_end = len(c) - batch_max_frames - self.aux_context_window
+                start_frame = np.random.randint(interval_start, interval_end)
+                start_step = start_frame * self.hop_size
+                if self.hparams['use_wav']:
+                    y = x[start_step: start_step + batch_max_steps]
+                c = c[start_frame - self.aux_context_window:
+                      start_frame + self.aux_context_window + batch_max_frames]
+                if have_pitch:
+                    p = p[start_frame - self.aux_context_window:
+                          start_frame + self.aux_context_window + batch_max_frames]
+                    f0 = f0[start_frame - self.aux_context_window:
+                            start_frame + self.aux_context_window + batch_max_frames]
+                if self.hparams['use_wav']:
+                    self._assert_ready_for_upsampling(y, c, self.hop_size, self.aux_context_window)
+            else:
+                print(f"Removed short sample from batch (mel length={len(c)}).")
+                continue
+            if self.hparams['use_wav']:
+                y_batch += [y.reshape(-1, 1)]  # [(T, 1), (T, 1), ...]
+            c_batch += [c]  # [(T', C), (T', C), ...]
+            if have_pitch:
+                p_batch += [p]  # [(T',), (T',), ...]
+                f0_batch += [f0]  # [(T',), (T',), ...]
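+            # Worked example with hypothetical values (not from any shipped
+            # config): hop_size=256, aux_context_window=2, batch_max_frames=40,
+            # and a 100-frame item. Then start_frame is drawn from [2, 58),
+            # c keeps 40 + 2*2 = 44 frames, y keeps 40 * 256 = 10240 samples,
+            # so len(y) == (len(c) - 2 * aux_context_window) * hop_size holds.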
+
+        # convert each batch to a tensor; assume every item in the batch now has the same length
+        if self.hparams['use_wav']:
+            y_batch = utils.collate_2d(y_batch, 0).transpose(2, 1)  # (B, 1, T)
+        c_batch = utils.collate_2d(c_batch, 0).transpose(2, 1)  # (B, C, T')
+        if have_pitch:
+            p_batch = utils.collate_1d(p_batch, 0)  # (B, T')
+            f0_batch = utils.collate_1d(f0_batch, 0)  # (B, T')
+        else:
+            p_batch, f0_batch = None, None
+
+        # make the input noise signal batch tensor
+        if self.hparams['use_wav']:
+            z_batch = torch.randn(y_batch.size())  # (B, 1, T)
+        else:
+            z_batch = []
+        return {
+            'z': z_batch,
+            'mels': c_batch,
+            'wavs': y_batch,
+            'pitches': p_batch,
+            'f0': f0_batch,
+            'item_name': item_name
+        }
+
+    @staticmethod
+    def _assert_ready_for_upsampling(x, c, hop_size, context_window):
+        """Assert that the audio and feature lengths are correctly adjusted for upsampling."""
+        assert len(x) == (len(c) - 2 * context_window) * hop_size
+
+    def load_test_inputs(self, test_input_dir, spk_id=0):
+        inp_wav_paths = sorted(glob.glob(f'{test_input_dir}/*.wav') + glob.glob(f'{test_input_dir}/**/*.mp3'))
+        sizes = []
+        items = []
+
+        binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
+        pkg = ".".join(binarizer_cls.split(".")[:-1])
+        cls_name = binarizer_cls.split(".")[-1]
+        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
+        binarization_args = hparams['binarization_args']
+
+        for wav_fn in inp_wav_paths:
+            item_name = wav_fn[len(test_input_dir) + 1:].replace("/", "_")
+            item = binarizer_cls.process_item(
+                item_name, wav_fn, binarization_args)
+            items.append(item)
+            sizes.append(item['len'])
+        return items, sizes
+
+    def load_mel_inputs(self, test_input_dir, spk_id=0):
+        inp_mel_paths = sorted(glob.glob(f'{test_input_dir}/*.npy'))
+        sizes = []
+        items = []
+
+        binarizer_cls = hparams.get("binarizer_cls", 'data_gen.tts.base_binarizer.BaseBinarizer')
+        pkg = ".".join(binarizer_cls.split(".")[:-1])
+        cls_name = binarizer_cls.split(".")[-1]
+        binarizer_cls = getattr(importlib.import_module(pkg), cls_name)
+        binarization_args = hparams['binarization_args']
+
+        for mel in inp_mel_paths:
+            mel_input = np.load(mel)
+            mel_input = torch.FloatTensor(mel_input)
+            item_name = mel[len(test_input_dir) + 1:].replace("/", "_")
+            item = binarizer_cls.process_mel_item(item_name, mel_input, None, binarization_args)
+            items.append(item)
+            sizes.append(item['len'])
+        return items, sizes
diff --git a/tasks/vocoder/vocoder_base.py b/tasks/vocoder/vocoder_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..04f45af60c8ac1c1f8303d091f8c6031ec8451bf
--- /dev/null
+++ b/tasks/vocoder/vocoder_base.py
@@ -0,0 +1,66 @@
+import os
+
+import torch
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
+
+from tasks.base_task import BaseTask
+from tasks.base_task import data_loader
+from tasks.vocoder.dataset_utils import VocoderDataset, EndlessDistributedSampler
+from utils.hparams import hparams
+
+
+class VocoderBaseTask(BaseTask):
+    def __init__(self):
+        super(VocoderBaseTask, self).__init__()
+        self.max_sentences = hparams['max_sentences']
+        self.max_valid_sentences = hparams['max_valid_sentences']
+        if self.max_valid_sentences == -1:
+            hparams['max_valid_sentences'] = self.max_valid_sentences = self.max_sentences
+        self.dataset_cls = VocoderDataset
+
+    @data_loader
+    def train_dataloader(self):
+        train_dataset = self.dataset_cls('train', shuffle=True)
+        return self.build_dataloader(train_dataset, True, self.max_sentences, hparams['endless_ds'])
+
+    
@data_loader + def val_dataloader(self): + valid_dataset = self.dataset_cls('valid', shuffle=False) + return self.build_dataloader(valid_dataset, False, self.max_valid_sentences) + + @data_loader + def test_dataloader(self): + test_dataset = self.dataset_cls('test', shuffle=False) + return self.build_dataloader(test_dataset, False, self.max_valid_sentences) + + def build_dataloader(self, dataset, shuffle, max_sentences, endless=False): + world_size = 1 + rank = 0 + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + sampler_cls = DistributedSampler if not endless else EndlessDistributedSampler + train_sampler = sampler_cls( + dataset=dataset, + num_replicas=world_size, + rank=rank, + shuffle=shuffle, + ) + return torch.utils.data.DataLoader( + dataset=dataset, + shuffle=False, + collate_fn=dataset.collater, + batch_size=max_sentences, + num_workers=dataset.num_workers, + sampler=train_sampler, + pin_memory=True, + ) + + def test_start(self): + self.gen_dir = os.path.join(hparams['work_dir'], + f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}') + os.makedirs(self.gen_dir, exist_ok=True) + + def test_end(self, outputs): + return {} diff --git a/usr/.gitkeep b/usr/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/usr/__init__.py b/usr/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/usr/diff/diffusion.py b/usr/diff/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..e874d64d4636c0b842392b91e92c7586770cbe58 --- /dev/null +++ b/usr/diff/diffusion.py @@ -0,0 +1,333 @@ +import math +import random +from functools import partial +from inspect import isfunction +from pathlib import Path +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from tqdm import tqdm +from einops import rearrange + +from modules.fastspeech.fs2 import FastSpeech2 +from utils.hparams import hparams + + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def cycle(dl): + while True: + for data in dl: + yield data + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class Mish(nn.Module): + def forward(self, x): + return x * torch.tanh(F.softplus(x)) + + +class Upsample(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.ConvTranspose2d(dim, dim, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Downsample(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.Conv2d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Rezero(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = 
fn
+        self.g = nn.Parameter(torch.zeros(1))
+
+    def forward(self, x):
+        return self.fn(x) * self.g
+
+
+# building block modules
+
+class Block(nn.Module):
+    def __init__(self, dim, dim_out, groups=8):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv2d(dim, dim_out, 3, padding=1),
+            nn.GroupNorm(groups, dim_out),
+            Mish()
+        )
+
+    def forward(self, x):
+        return self.block(x)
+
+
+class ResnetBlock(nn.Module):
+    def __init__(self, dim, dim_out, *, time_emb_dim, groups=8):
+        super().__init__()
+        self.mlp = nn.Sequential(
+            Mish(),
+            nn.Linear(time_emb_dim, dim_out)
+        )
+
+        self.block1 = Block(dim, dim_out)
+        self.block2 = Block(dim_out, dim_out)
+        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
+
+    def forward(self, x, time_emb):
+        h = self.block1(x)
+        h += self.mlp(time_emb)[:, :, None, None]
+        h = self.block2(h)
+        return h + self.res_conv(x)
+
+
+class LinearAttention(nn.Module):
+    def __init__(self, dim, heads=4, dim_head=32):
+        super().__init__()
+        self.heads = heads
+        hidden_dim = dim_head * heads
+        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
+        self.to_out = nn.Conv2d(hidden_dim, dim, 1)
+
+    def forward(self, x):
+        b, c, h, w = x.shape
+        qkv = self.to_qkv(x)
+        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
+        k = k.softmax(dim=-1)
+        context = torch.einsum('bhdn,bhen->bhde', k, v)
+        out = torch.einsum('bhde,bhdn->bhen', context, q)
+        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
+        return self.to_out(out)
+
+
+# gaussian diffusion trainer class
+
+def extract(a, t, x_shape):
+    # gather per-timestep coefficients a[t] and reshape them to broadcast over x_shape
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def noise_like(shape, device, repeat=False):
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    noise = lambda: torch.randn(shape, device=device)
+    return repeat_noise() if repeat else noise()
+
+
+def cosine_beta_schedule(timesteps, s=0.008):
+    """
+    cosine schedule
+    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+    """
+    steps = timesteps + 1
+    x = np.linspace(0, steps, steps)
+    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+    return np.clip(betas, a_min=0, a_max=0.999)
+
+
+class GaussianDiffusion(nn.Module):
+    def __init__(self, phone_encoder, out_dims, denoise_fn,
+                 timesteps=1000, loss_type='l1', betas=None, spec_min=None, spec_max=None):
+        super().__init__()
+        self.denoise_fn = denoise_fn
+        if hparams.get('use_midi'):
+            # FastSpeech2MIDI is not imported at the top of this file; it is assumed
+            # to live in modules.diffsinger_midi.fs2 (its location in the upstream
+            # DiffSinger repo) -- adjust this import if your tree differs.
+            from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
+            self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims)
+        else:
+            self.fs2 = FastSpeech2(phone_encoder, out_dims)
+        self.fs2.decoder = None
+        self.mel_bins = out_dims
+
+        if exists(betas):
+            betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
+        else:
+            betas = cosine_beta_schedule(timesteps)
+
+        alphas = 1. 
- betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']]) + self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']]) + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, cond, clip_denoised: bool): + noise_pred = self.denoise_fn(x, t, cond=cond) + x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) + + if clip_denoised: + x_recon.clamp_(-1., 1.) 
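+        # For reference, the math implemented by predict_start_from_noise (above)
+        # and q_posterior (below) is the standard DDPM posterior (Ho et al., 2020):
+        #   x0_hat = sqrt(1/alphas_cumprod[t]) * x_t - sqrt(1/alphas_cumprod[t] - 1) * eps_pred
+        #   mean   = posterior_mean_coef1[t] * x0_hat + posterior_mean_coef2[t] * x_t
+        #   var    = posterior_variance[t] = betas[t] * (1 - alphas_cumprod[t-1]) / (1 - alphas_cumprod[t])
+        # with all coefficients read from the buffers registered in __init__.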
+ + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def p_losses(self, x_start, t, cond, noise=None, nonpadding=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) + x_recon = self.denoise_fn(x_noisy, t, cond) + + if self.loss_type == 'l1': + if nonpadding is not None: + loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean() + else: + # print('are you sure w/o nonpadding?') + loss = (noise - x_recon).abs().mean() + + elif self.loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else: + raise NotImplementedError() + + return loss + + def forward(self, txt_tokens, mel2ph=None, spk_embed=None, + ref_mels=None, f0=None, uv=None, energy=None, infer=False): + b, *_, device = *txt_tokens.shape, txt_tokens.device + ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy, + skip_decoder=True, infer=infer) + cond = ret['decoder_inp'].transpose(1, 2) + if not infer: + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + x = ref_mels + x = self.norm_spec(x) + x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T] + nonpadding = (mel2ph != 0).float() + ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding) + else: + t = self.num_timesteps + shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2]) + x = torch.randn(shape, device=device) + for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t): + x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond) + x = x[:, 0].transpose(1, 2) + ret['mel_out'] = self.denorm_spec(x) + + return ret + + def norm_spec(self, x): + return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1 + + def denorm_spec(self, x): + return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min + + def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph): + return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph) + + def out2mel(self, x): + return x diff --git a/usr/diff/net.py b/usr/diff/net.py new file mode 100644 index 0000000000000000000000000000000000000000..b8811115eafb4f27165cf4d89c67c0d9455aac9d --- /dev/null +++ b/usr/diff/net.py @@ -0,0 +1,130 @@ +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from math import sqrt + +from .diffusion import Mish +from utils.hparams import hparams + +Linear = nn.Linear +ConvTranspose2d = nn.ConvTranspose2d + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + def override(self, attrs): + if isinstance(attrs, dict): + self.__dict__.update(**attrs) + elif isinstance(attrs, (list, tuple, 
set)): + for attr in attrs: + self.override(attr) + elif attrs is not None: + raise NotImplementedError + return self + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +def Conv1d(*args, **kwargs): + layer = nn.Conv1d(*args, **kwargs) + nn.init.kaiming_normal_(layer.weight) + return layer + + +@torch.jit.script +def silu(x): + return x * torch.sigmoid(x) + + +class ResidualBlock(nn.Module): + def __init__(self, encoder_hidden, residual_channels, dilation): + super().__init__() + self.dilated_conv = Conv1d(residual_channels, 2 * residual_channels, 3, padding=dilation, dilation=dilation) + self.diffusion_projection = Linear(residual_channels, residual_channels) + self.conditioner_projection = Conv1d(encoder_hidden, 2 * residual_channels, 1) + self.output_projection = Conv1d(residual_channels, 2 * residual_channels, 1) + + def forward(self, x, conditioner, diffusion_step): + diffusion_step = self.diffusion_projection(diffusion_step).unsqueeze(-1) + conditioner = self.conditioner_projection(conditioner) + y = x + diffusion_step + + y = self.dilated_conv(y) + conditioner + + gate, filter = torch.chunk(y, 2, dim=1) + y = torch.sigmoid(gate) * torch.tanh(filter) + + y = self.output_projection(y) + residual, skip = torch.chunk(y, 2, dim=1) + return (x + residual) / sqrt(2.0), skip + + +class DiffNet(nn.Module): + def __init__(self, in_dims=80): + super().__init__() + self.params = params = AttrDict( + # Model params + encoder_hidden=hparams['hidden_size'], + residual_layers=hparams['residual_layers'], + residual_channels=hparams['residual_channels'], + dilation_cycle_length=hparams['dilation_cycle_length'], + ) + self.input_projection = Conv1d(in_dims, params.residual_channels, 1) + self.diffusion_embedding = SinusoidalPosEmb(params.residual_channels) + dim = params.residual_channels + self.mlp = nn.Sequential( + nn.Linear(dim, dim * 4), + Mish(), + nn.Linear(dim * 4, dim) + ) + self.residual_layers = nn.ModuleList([ + ResidualBlock(params.encoder_hidden, params.residual_channels, 2 ** (i % params.dilation_cycle_length)) + for i in range(params.residual_layers) + ]) + self.skip_projection = Conv1d(params.residual_channels, params.residual_channels, 1) + self.output_projection = Conv1d(params.residual_channels, in_dims, 1) + nn.init.zeros_(self.output_projection.weight) + + def forward(self, spec, diffusion_step, cond): + """ + + :param spec: [B, 1, M, T] + :param diffusion_step: [B, 1] + :param cond: [B, M, T] + :return: + """ + x = spec[:, 0] + x = self.input_projection(x) # x [B, residual_channel, T] + + x = F.relu(x) + diffusion_step = self.diffusion_embedding(diffusion_step) + diffusion_step = self.mlp(diffusion_step) + skip = [] + for layer_id, layer in enumerate(self.residual_layers): + x, skip_connection = layer(x, cond, diffusion_step) + skip.append(skip_connection) + + x = torch.sum(torch.stack(skip), dim=0) / sqrt(len(self.residual_layers)) + x = self.skip_projection(x) + x = F.relu(x) + x = self.output_projection(x) # [B, 80, T] + return x[:, None, :, :] diff --git a/usr/diff/shallow_diffusion_tts.py b/usr/diff/shallow_diffusion_tts.py new file mode 100644 index 0000000000000000000000000000000000000000..d296fbc1c297a9703e004bd1d216ed34f0008446 --- /dev/null 
+++ b/usr/diff/shallow_diffusion_tts.py
@@ -0,0 +1,307 @@
+import math
+import random
+from functools import partial
+from inspect import isfunction
+from pathlib import Path
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+from tqdm import tqdm
+from einops import rearrange
+
+from modules.fastspeech.fs2 import FastSpeech2
+from utils.hparams import hparams
+
+
+def vpsde_beta_t(t, T, min_beta, max_beta):
+    t_coef = (2 * t - 1) / (T ** 2)
+    return 1. - np.exp(-min_beta / T - 0.5 * (max_beta - min_beta) * t_coef)
+
+
+def _logsnr_schedule_cosine(t, *, logsnr_min, logsnr_max):
+    b = np.arctan(np.exp(-0.5 * logsnr_max))
+    a = np.arctan(np.exp(-0.5 * logsnr_min)) - b
+    return -2. * np.log(np.tan(a * t + b))
+
+
+def get_noise_schedule_list(schedule_mode, timesteps, min_beta=0.0, max_beta=0.01, s=0.008):
+    if schedule_mode == "linear":
+        schedule_list = np.linspace(0.000001, 0.01, timesteps)
+    elif schedule_mode == "cosine":
+        steps = timesteps + 1
+        x = np.linspace(0, steps, steps)
+        alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
+        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+        schedule_list = np.clip(betas, a_min=0, a_max=0.999)
+    elif schedule_mode == "vpsde":
+        schedule_list = np.array([
+            vpsde_beta_t(t, timesteps, min_beta, max_beta) for t in range(1, timesteps + 1)])
+    elif schedule_mode == "logsnr":
+        # note: this branch returns log-SNR values rather than betas
+        schedule_list = np.array([
+            _logsnr_schedule_cosine(t / timesteps, logsnr_min=-20.0, logsnr_max=20.0) for t in range(1, timesteps + 1)])
+    else:
+        raise NotImplementedError
+    return schedule_list
+
+
+def exists(x):
+    return x is not None
+
+
+def default(val, d):
+    if exists(val):
+        return val
+    return d() if isfunction(d) else d
+
+
+# gaussian diffusion trainer class
+
+def extract(a, t, x_shape):
+    b, *_ = t.shape
+    out = a.gather(-1, t)
+    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def noise_like(shape, device, repeat=False):
+    repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+    noise = lambda: torch.randn(shape, device=device)
+    return repeat_noise() if repeat else noise()
+
+
+def linear_beta_schedule(timesteps, max_beta=hparams.get('max_beta', 0.01)):
+    """
+    linear schedule
+    """
+    betas = np.linspace(1e-4, max_beta, timesteps)
+    return betas
+
+
+def cosine_beta_schedule(timesteps, s=0.008):
+    """
+    cosine schedule
+    as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+    """
+    steps = timesteps + 1
+    x = np.linspace(0, steps, steps)
+    alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
+    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+    return np.clip(betas, a_min=0, a_max=0.999)
+
+
+beta_schedule = {
+    "cosine": cosine_beta_schedule,
+    "linear": linear_beta_schedule,
+}
+
+
+class GaussianDiffusion(nn.Module):
+    def __init__(self, phone_encoder, out_dims, denoise_fn,
+                 timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None,
+                 spec_min=None, spec_max=None):
+        super().__init__()
+        self.denoise_fn = denoise_fn
+        if hparams.get('use_midi'):
+            # FastSpeech2MIDI is not imported above; assumed location as in the
+            # upstream DiffSinger repo -- adjust if your tree differs.
+            from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
+            self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims)
+        else:
+            self.fs2 = FastSpeech2(phone_encoder, out_dims)
+        self.mel_bins = out_dims
+
+        if exists(betas):
+            betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
+        
else: + if 'schedule_type' in hparams.keys(): + betas = beta_schedule[hparams['schedule_type']](timesteps) + else: + betas = cosine_beta_schedule(timesteps) + + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.K_step = K_step + self.loss_type = loss_type + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + self.register_buffer('posterior_variance', to_torch(posterior_variance)) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20)))) + self.register_buffer('posterior_mean_coef1', to_torch( + betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) + self.register_buffer('posterior_mean_coef2', to_torch( + (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) + + self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']]) + self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']]) + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, cond, clip_denoised: bool): + noise_pred = self.denoise_fn(x, t, cond=cond) + x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred) + + if clip_denoised: + x_recon.clamp_(-1., 1.) 
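+        # Note: norm_spec/denorm_spec are identity functions in this file (see the
+        # commented-out min/max normalization near the bottom of the class), so the
+        # clamp above implicitly assumes the mels entering the diffusion are
+        # already normalized to [-1, 1] somewhere upstream.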
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.no_grad()
+    def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
+        noise = noise_like(x.shape, device, repeat_noise)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        return (
+                extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+                extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+        x_recon = self.denoise_fn(x_noisy, t, cond)
+
+        if self.loss_type == 'l1':
+            if nonpadding is not None:
+                loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
+            else:
+                # print('are you sure w/o nonpadding?')
+                loss = (noise - x_recon).abs().mean()
+
+        elif self.loss_type == 'l2':
+            loss = F.mse_loss(noise, x_recon)
+        else:
+            raise NotImplementedError()
+
+        return loss
+
+    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
+        b, *_, device = *txt_tokens.shape, txt_tokens.device
+        ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
+                       skip_decoder=(not infer), infer=infer, **kwargs)
+        cond = ret['decoder_inp'].transpose(1, 2)
+
+        if not infer:
+            t = torch.randint(0, self.K_step, (b,), device=device).long()
+            x = ref_mels
+            x = self.norm_spec(x)
+            x = x.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
+            ret['diff_loss'] = self.p_losses(x, t, cond)
+            # nonpadding = (mel2ph != 0).float()
+            # ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding)
+        else:
+            ret['fs2_mel'] = ret['mel_out']
+            fs2_mels = ret['mel_out']
+            t = self.K_step
+            fs2_mels = self.norm_spec(fs2_mels)
+            fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
+
+            x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
+            if hparams.get('gaussian_start'):
+                print('===> gaussian start.')
+                shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+                x = torch.randn(shape, device=device)
+            for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
+                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
+            x = x[:, 0].transpose(1, 2)
+            if mel2ph is not None:  # for singing
+                ret['mel_out'] = self.denorm_spec(x) * ((mel2ph > 0).float()[:, :, None])
+            else:
+                ret['mel_out'] = self.denorm_spec(x)
+        return ret
+
+    # def norm_spec(self, x):
+    #     return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
+    #
+    # def denorm_spec(self, x):
+    #     return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
+
+    def norm_spec(self, x):
+        return x
+
+    def denorm_spec(self, x):
+        return x
+
+    def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
+        return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
+
+    def out2mel(self, x):
+        return x
+
+
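+# Shallow-diffusion inference, in short: instead of starting the reverse chain
+# from pure noise at t = num_timesteps, the FastSpeech2 mel is diffused forward
+# to step K_step with q_sample, and only K_step reverse steps are run from there;
+# hparams['gaussian_start'] restores the classic start from pure Gaussian noise.
+# OfflineGaussianDiffusion below follows the same scheme but takes a precomputed
+# FastSpeech2 mel from ref_mels[1] instead of decoding one with self.fs2.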
+class OfflineGaussianDiffusion(GaussianDiffusion):
+    def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
+        b, *_, device = *txt_tokens.shape, txt_tokens.device
+
+        ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
+                       skip_decoder=True, infer=True, **kwargs)
+        cond = ret['decoder_inp'].transpose(1, 2)
+        fs2_mels = ref_mels[1]
+        ref_mels = ref_mels[0]
+
+        if not infer:
+            t = torch.randint(0, self.K_step, (b,), device=device).long()
+            x = ref_mels
+            x = self.norm_spec(x)
+            x = x.transpose(1, 2)[:, None, :, :]  # [B, 1, M, T]
+            ret['diff_loss'] = self.p_losses(x, t, cond)
+        else:
+            t = self.K_step
+            fs2_mels = self.norm_spec(fs2_mels)
+            fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
+
+            x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
+
+            if hparams.get('gaussian_start'):
+                print('===> gaussian start.')
+                shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
+                x = torch.randn(shape, device=device)
+            for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
+                x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
+            x = x[:, 0].transpose(1, 2)
+            ret['mel_out'] = self.denorm_spec(x)
+        return ret
\ No newline at end of file
diff --git a/usr/diffspeech_task.py b/usr/diffspeech_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4fca7e9e46fc378468188d58fc42bc989df824c
--- /dev/null
+++ b/usr/diffspeech_task.py
@@ -0,0 +1,122 @@
+import torch
+
+import utils
+from utils.hparams import hparams
+from .diff.net import DiffNet
+from .diff.shallow_diffusion_tts import GaussianDiffusion
+from .task import DiffFsTask
+from vocoders.base_vocoder import get_vocoder_cls, BaseVocoder
+from utils.pitch_utils import denorm_f0
+from tasks.tts.fs2_utils import FastSpeechDataset
+
+DIFF_DECODERS = {
+    'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']),
+}
+
+
+class DiffSpeechTask(DiffFsTask):
+    def __init__(self):
+        super(DiffSpeechTask, self).__init__()
+        self.dataset_cls = FastSpeechDataset
+        self.vocoder: BaseVocoder = get_vocoder_cls(hparams)()
+
+    def build_tts_model(self):
+        mel_bins = hparams['audio_num_mel_bins']
+        self.model = GaussianDiffusion(
+            phone_encoder=self.phone_encoder,
+            out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams),
+            timesteps=hparams['timesteps'],
+            K_step=hparams['K_step'],
+            loss_type=hparams['diff_loss_type'],
+            spec_min=hparams['spec_min'], spec_max=hparams['spec_max'],
+        )
+        if hparams['fs2_ckpt'] != '':
+            utils.load_ckpt(self.model.fs2, hparams['fs2_ckpt'], 'model', strict=True)
+            # self.model.fs2.decoder = None
+            for k, v in self.model.fs2.named_parameters():
+                if 'predictor' not in k:
+                    v.requires_grad = False
+
+    def build_optimizer(self, model):
+        self.optimizer = optimizer = torch.optim.AdamW(
+            filter(lambda p: p.requires_grad, model.parameters()),
+            lr=hparams['lr'],
+            betas=(hparams['optimizer_adam_beta1'], hparams['optimizer_adam_beta2']),
+            weight_decay=hparams['weight_decay'])
+        return optimizer
+
+    def run_model(self, model, sample, return_output=False, infer=False):
+        txt_tokens = sample['txt_tokens']  # [B, T_t]
+        target = sample['mels']  # [B, T_s, 80]
+        # mel2ph = sample['mel2ph'] if hparams['use_gt_dur'] else None  # [B, T_s]
+        mel2ph = sample['mel2ph']
+        f0 = sample['f0']
+        uv = sample['uv']
+        energy = sample['energy']
+        # fs2_mel = sample['fs2_mels']
+        spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids')
+        if 
hparams['pitch_type'] == 'cwt': + cwt_spec = sample[f'cwt_spec'] + f0_mean = sample['f0_mean'] + f0_std = sample['f0_std'] + sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph) + + output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, + ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer) + + losses = {} + if 'diff_loss' in output: + losses['mel'] = output['diff_loss'] + self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses) + if hparams['use_pitch_embed']: + self.add_pitch_loss(output, sample, losses) + if hparams['use_energy_embed']: + self.add_energy_loss(output['energy_pred'], energy, losses) + if not return_output: + return losses + else: + return losses, output + + def validation_step(self, sample, batch_idx): + outputs = {} + txt_tokens = sample['txt_tokens'] # [B, T_t] + + energy = sample['energy'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + mel2ph = sample['mel2ph'] + f0 = sample['f0'] + uv = sample['uv'] + + outputs['losses'] = {} + + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False) + + + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + outputs = utils.tensors_to_scalars(outputs) + if batch_idx < hparams['num_valid_plots']: + # model_out = self.model( + # txt_tokens, spk_embed=spk_embed, mel2ph=None, f0=None, uv=None, energy=None, ref_mels=None, inference=True) + # self.plot_mel(batch_idx, model_out['mel_out'], model_out['fs2_mel'], name=f'diffspeech_vs_fs2_{batch_idx}') + model_out = self.model( + txt_tokens, spk_embed=spk_embed, mel2ph=mel2ph, f0=f0, uv=uv, energy=energy, ref_mels=None, infer=True) + gt_f0 = denorm_f0(sample['f0'], sample['uv'], hparams) + self.plot_wav(batch_idx, sample['mels'], model_out['mel_out'], is_mel=True, gt_f0=gt_f0, f0=model_out.get('f0_denorm')) + self.plot_mel(batch_idx, sample['mels'], model_out['mel_out']) + return outputs + + ############ + # validation plots + ############ + def plot_wav(self, batch_idx, gt_wav, wav_out, is_mel=False, gt_f0=None, f0=None, name=None): + gt_wav = gt_wav[0].cpu().numpy() + wav_out = wav_out[0].cpu().numpy() + gt_f0 = gt_f0[0].cpu().numpy() + f0 = f0[0].cpu().numpy() + if is_mel: + gt_wav = self.vocoder.spec2wav(gt_wav, f0=gt_f0) + wav_out = self.vocoder.spec2wav(wav_out, f0=f0) + self.logger.experiment.add_audio(f'gt_{batch_idx}', gt_wav, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step) + self.logger.experiment.add_audio(f'wav_{batch_idx}', wav_out, sample_rate=hparams['audio_sample_rate'], global_step=self.global_step) + diff --git a/usr/task.py b/usr/task.py new file mode 100644 index 0000000000000000000000000000000000000000..f05d66f0a8f7aa5995c95c202af7fa81efb8a28f --- /dev/null +++ b/usr/task.py @@ -0,0 +1,73 @@ +import torch + +import utils +from .diff.diffusion import GaussianDiffusion +from .diff.net import DiffNet +from tasks.tts.fs2 import FastSpeech2Task +from utils.hparams import hparams + + +DIFF_DECODERS = { + 'wavenet': lambda hp: DiffNet(hp['audio_num_mel_bins']), +} + + +class DiffFsTask(FastSpeech2Task): + def build_tts_model(self): + mel_bins = hparams['audio_num_mel_bins'] + self.model = GaussianDiffusion( + phone_encoder=self.phone_encoder, + out_dims=mel_bins, denoise_fn=DIFF_DECODERS[hparams['diff_decoder_type']](hparams), + timesteps=hparams['timesteps'], + loss_type=hparams['diff_loss_type'], + spec_min=hparams['spec_min'], spec_max=hparams['spec_max'], + ) + + def 
run_model(self, model, sample, return_output=False, infer=False): + txt_tokens = sample['txt_tokens'] # [B, T_t] + target = sample['mels'] # [B, T_s, 80] + mel2ph = sample['mel2ph'] # [B, T_s] + f0 = sample['f0'] + uv = sample['uv'] + energy = sample['energy'] + spk_embed = sample.get('spk_embed') if not hparams['use_spk_id'] else sample.get('spk_ids') + if hparams['pitch_type'] == 'cwt': + cwt_spec = sample[f'cwt_spec'] + f0_mean = sample['f0_mean'] + f0_std = sample['f0_std'] + sample['f0_cwt'] = f0 = model.cwt2f0_norm(cwt_spec, f0_mean, f0_std, mel2ph) + + output = model(txt_tokens, mel2ph=mel2ph, spk_embed=spk_embed, + ref_mels=target, f0=f0, uv=uv, energy=energy, infer=infer) + + losses = {} + if 'diff_loss' in output: + losses['mel'] = output['diff_loss'] + self.add_dur_loss(output['dur'], mel2ph, txt_tokens, losses=losses) + if hparams['use_pitch_embed']: + self.add_pitch_loss(output, sample, losses) + if hparams['use_energy_embed']: + self.add_energy_loss(output['energy_pred'], energy, losses) + if not return_output: + return losses + else: + return losses, output + + def _training_step(self, sample, batch_idx, _): + log_outputs = self.run_model(self.model, sample) + total_loss = sum([v for v in log_outputs.values() if isinstance(v, torch.Tensor) and v.requires_grad]) + log_outputs['batch_size'] = sample['txt_tokens'].size()[0] + log_outputs['lr'] = self.scheduler.get_lr()[0] + return total_loss, log_outputs + + def validation_step(self, sample, batch_idx): + outputs = {} + outputs['losses'] = {} + outputs['losses'], model_out = self.run_model(self.model, sample, return_output=True, infer=False) + outputs['total_loss'] = sum(outputs['losses'].values()) + outputs['nsamples'] = sample['nsamples'] + outputs = utils.tensors_to_scalars(outputs) + if batch_idx < hparams['num_valid_plots']: + _, model_out = self.run_model(self.model, sample, return_output=True, infer=True) + self.plot_mel(batch_idx, sample['mels'], model_out['mel_out']) + return outputs diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100755 index 0000000000000000000000000000000000000000..63921bfea9a95629b15d90498677c6a22de9fec8 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,285 @@ +import time +import sys +import types + +import chardet +import numpy as np +import torch +import torch.distributed as dist +from utils.ckpt_utils import load_ckpt + + +def reduce_tensors(metrics): + new_metrics = {} + for k, v in metrics.items(): + if isinstance(v, torch.Tensor): + dist.all_reduce(v) + v = v / dist.get_world_size() + if type(v) is dict: + v = reduce_tensors(v) + new_metrics[k] = v + return new_metrics + + +def tensors_to_scalars(tensors): + if isinstance(tensors, torch.Tensor): + tensors = tensors.item() + return tensors + elif isinstance(tensors, dict): + new_tensors = {} + for k, v in tensors.items(): + v = tensors_to_scalars(v) + new_tensors[k] = v + return new_tensors + elif isinstance(tensors, list): + return [tensors_to_scalars(v) for v in tensors] + else: + return tensors + + +def tensors_to_np(tensors): + if isinstance(tensors, dict): + new_np = {} + for k, v in tensors.items(): + if isinstance(v, torch.Tensor): + v = v.cpu().numpy() + if type(v) is dict: + v = tensors_to_np(v) + new_np[k] = v + elif isinstance(tensors, list): + new_np = [] + for v in tensors: + if isinstance(v, torch.Tensor): + v = v.cpu().numpy() + if type(v) is dict: + v = tensors_to_np(v) + new_np.append(v) + elif isinstance(tensors, torch.Tensor): + v = tensors + if isinstance(v, torch.Tensor): + v = v.cpu().numpy() + if type(v) 
is dict: + v = tensors_to_np(v) + new_np = v + else: + raise Exception(f'tensors_to_np does not support type {type(tensors)}.') + return new_np + + +def move_to_cpu(tensors): + ret = {} + for k, v in tensors.items(): + if isinstance(v, torch.Tensor): + v = v.cpu() + if type(v) is dict: + v = move_to_cpu(v) + ret[k] = v + return ret + + +def move_to_cuda(batch, gpu_id=0): + # base case: object can be directly moved using `cuda` or `to` + if callable(getattr(batch, 'cuda', None)): + return batch.cuda(gpu_id, non_blocking=True) + elif callable(getattr(batch, 'to', None)): + return batch.to(torch.device('cuda', gpu_id), non_blocking=True) + elif isinstance(batch, list): + for i, x in enumerate(batch): + batch[i] = move_to_cuda(x, gpu_id) + return batch + elif isinstance(batch, tuple): + batch = list(batch) + for i, x in enumerate(batch): + batch[i] = move_to_cuda(x, gpu_id) + return tuple(batch) + elif isinstance(batch, dict): + for k, v in batch.items(): + batch[k] = move_to_cuda(v, gpu_id) + return batch + return batch + + +class AvgrageMeter(object): + + def __init__(self): + self.reset() + + def reset(self): + self.avg = 0 + self.sum = 0 + self.cnt = 0 + + def update(self, val, n=1): + self.sum += val * n + self.cnt += n + self.avg = self.sum / self.cnt + + +def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1): + """Convert a list of 1d tensors into a padded 2d tensor.""" + size = max(v.size(0) for v in values) if max_len is None else max_len + res = values[0].new(len(values), size).fill_(pad_idx) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + if shift_right: + dst[1:] = src[:-1] + dst[0] = shift_id + else: + dst.copy_(src) + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) + return res + + +def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None): + """Convert a list of 2d tensors into a padded 3d tensor.""" + size = max(v.size(0) for v in values) if max_len is None else max_len + res = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + if shift_right: + dst[1:] = src[:-1] + else: + dst.copy_(src) + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)]) + return res + + +def _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + if len(batch) == 0: + return 0 + if len(batch) == max_sentences: + return 1 + if num_tokens > max_tokens: + return 1 + return 0 + + +def batch_by_size( + indices, num_tokens_fn, max_tokens=None, max_sentences=None, + required_batch_size_multiple=1, distributed=False +): + """ + Yield mini-batches of indices bucketed by size. Batches may contain + sequences of different lengths. + + Args: + indices (List[int]): ordered list of dataset indices + num_tokens_fn (callable): function that returns the number of tokens at + a given index + max_tokens (int, optional): max number of tokens in each batch + (default: None). + max_sentences (int, optional): max number of sentences in each + batch (default: None). + required_batch_size_multiple (int, optional): require batch size to + be a multiple of N (default: 1). 
+ """ + max_tokens = max_tokens if max_tokens is not None else sys.maxsize + max_sentences = max_sentences if max_sentences is not None else sys.maxsize + bsz_mult = required_batch_size_multiple + + if isinstance(indices, types.GeneratorType): + indices = np.fromiter(indices, dtype=np.int64, count=-1) + + sample_len = 0 + sample_lens = [] + batch = [] + batches = [] + for i in range(len(indices)): + idx = indices[i] + num_tokens = num_tokens_fn(idx) + sample_lens.append(num_tokens) + sample_len = max(sample_len, num_tokens) + + assert sample_len <= max_tokens, ( + "sentence at index {} of size {} exceeds max_tokens " + "limit of {}!".format(idx, sample_len, max_tokens) + ) + num_tokens = (len(batch) + 1) * sample_len + + if _is_batch_full(batch, num_tokens, max_tokens, max_sentences): + mod_len = max( + bsz_mult * (len(batch) // bsz_mult), + len(batch) % bsz_mult, + ) + batches.append(batch[:mod_len]) + batch = batch[mod_len:] + sample_lens = sample_lens[mod_len:] + sample_len = max(sample_lens) if len(sample_lens) > 0 else 0 + batch.append(idx) + if len(batch) > 0: + batches.append(batch) + return batches + +def unpack_dict_to_list(samples): + samples_ = [] + bsz = samples.get('outputs').size(0) + for i in range(bsz): + res = {} + for k, v in samples.items(): + try: + res[k] = v[i] + except: + pass + samples_.append(res) + return samples_ + + +def remove_padding(x, padding_idx=0): + if x is None: + return None + assert len(x.shape) in [1, 2] + if len(x.shape) == 2: # [T, H] + return x[np.abs(x).sum(-1) != padding_idx] + elif len(x.shape) == 1: # [T] + return x[x != padding_idx] + + +class Timer: + timer_map = {} + + def __init__(self, name, enable=False): + if name not in Timer.timer_map: + Timer.timer_map[name] = 0 + self.name = name + self.enable = enable + + def __enter__(self): + if self.enable: + if torch.cuda.is_available(): + torch.cuda.synchronize() + self.t = time.time() + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.enable: + if torch.cuda.is_available(): + torch.cuda.synchronize() + Timer.timer_map[self.name] += time.time() - self.t + if self.enable: + print(f'[Timer] {self.name}: {Timer.timer_map[self.name]}') + + +def print_arch(model, model_name='model'): + print(f"| {model_name} Arch: ", model) + num_params(model, model_name=model_name) + + +def num_params(model, print_out=True, model_name="model"): + parameters = filter(lambda p: p.requires_grad, model.parameters()) + parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 + if print_out: + print(f'| {model_name} Trainable Parameters: %.3fM' % parameters) + return parameters + + +def get_encoding(file): + with open(file, 'rb') as f: + encoding = chardet.detect(f.read())['encoding'] + if encoding == 'GB2312': + encoding = 'GB18030' + return encoding diff --git a/utils/audio.py b/utils/audio.py new file mode 100644 index 0000000000000000000000000000000000000000..aba7ab926cf793d085bbdc70c97f376001183fe1 --- /dev/null +++ b/utils/audio.py @@ -0,0 +1,56 @@ +import subprocess +import matplotlib + +matplotlib.use('Agg') +import librosa +import librosa.filters +import numpy as np +from scipy import signal +from scipy.io import wavfile + + +def save_wav(wav, path, sr, norm=False): + if norm: + wav = wav / np.abs(wav).max() + wav *= 32767 + # proposed by @dsmiller + wavfile.write(path, sr, wav.astype(np.int16)) + + +def get_hop_size(hparams): + hop_size = hparams['hop_size'] + if hop_size is None: + assert hparams['frame_shift_ms'] is not None + hop_size = int(hparams['frame_shift_ms'] / 1000 * 
hparams['audio_sample_rate']) + return hop_size + + +########################################################################################### +def _stft(y, hparams): + return librosa.stft(y=y, n_fft=hparams['fft_size'], hop_length=get_hop_size(hparams), + win_length=hparams['win_size'], pad_mode='constant') + + +def _istft(y, hparams): + return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams['win_size']) + + +def librosa_pad_lr(x, fsize, fshift, pad_sides=1): + '''compute right padding (final frame) or both sides padding (first and final frames) + ''' + assert pad_sides in (1, 2) + # return int(fsize // 2) + pad = (x.shape[0] // fshift + 1) * fshift - x.shape[0] + if pad_sides == 1: + return 0, pad + else: + return pad // 2, pad // 2 + pad % 2 + + +# Conversions +def amp_to_db(x): + return 20 * np.log10(np.maximum(1e-5, x)) + + +def normalize(S, hparams): + return (S - hparams['min_level_db']) / -hparams['min_level_db'] diff --git a/utils/ckpt_utils.py b/utils/ckpt_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..fc321f9ba891ffffc374df65871c3085bf898afb --- /dev/null +++ b/utils/ckpt_utils.py @@ -0,0 +1,68 @@ +import glob +import logging +import os +import re +import torch + + +def get_last_checkpoint(work_dir, steps=None): + checkpoint = None + last_ckpt_path = None + ckpt_paths = get_all_ckpts(work_dir, steps) + if len(ckpt_paths) > 0: + last_ckpt_path = ckpt_paths[0] + checkpoint = torch.load(last_ckpt_path, map_location='cpu') + logging.info(f'load module from checkpoint: {last_ckpt_path}') + return checkpoint, last_ckpt_path + + +def get_all_ckpts(work_dir, steps=None): + if steps is None: + ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_*.ckpt' + else: + ckpt_path_pattern = f'{work_dir}/model_ckpt_steps_{steps}.ckpt' + return sorted(glob.glob(ckpt_path_pattern), + key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0])) + + +def load_ckpt(cur_model, ckpt_base_dir, model_name='model', force=True, strict=True): + if os.path.isfile(ckpt_base_dir): + base_dir = os.path.dirname(ckpt_base_dir) + ckpt_path = ckpt_base_dir + checkpoint = torch.load(ckpt_base_dir, map_location='cpu') + else: + base_dir = ckpt_base_dir + checkpoint, ckpt_path = get_last_checkpoint(ckpt_base_dir) + if checkpoint is not None: + state_dict = checkpoint["state_dict"] + if len([k for k in state_dict.keys() if '.' in k]) > 0: + state_dict = {k[len(model_name) + 1:]: v for k, v in state_dict.items() + if k.startswith(f'{model_name}.')} + else: + if '.' not in model_name: + state_dict = state_dict[model_name] + else: + base_model_name = model_name.split('.')[0] + rest_model_name = model_name[len(base_model_name) + 1:] + state_dict = { + k[len(rest_model_name) + 1:]: v for k, v in state_dict[base_model_name].items() + if k.startswith(f'{rest_model_name}.')} + if not strict: + cur_model_state_dict = cur_model.state_dict() + unmatched_keys = [] + for key, param in state_dict.items(): + if key in cur_model_state_dict: + new_param = cur_model_state_dict[key] + if new_param.shape != param.shape: + unmatched_keys.append(key) + print("| Unmatched keys: ", key, new_param.shape, param.shape) + for key in unmatched_keys: + del state_dict[key] + cur_model.load_state_dict(state_dict, strict=strict) + print(f"| load '{model_name}' from '{ckpt_path}'.") + else: + e_msg = f"| ckpt not found in {base_dir}." 
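+        # force=True makes a missing checkpoint fatal (the assert below fires);
+        # with force=False the message is only printed and the model keeps its
+        # current (typically randomly initialized) weights.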
+ if force:
+ raise FileNotFoundError(e_msg) # fail loudly; a bare `assert False` disappears under `python -O`
+ else:
+ print(e_msg)
diff --git a/utils/common_schedulers.py b/utils/common_schedulers.py
new file mode 100755
index 0000000000000000000000000000000000000000..41c6f4a9250b2d5954ce93cb7c04e7b55025cb51
--- /dev/null
+++ b/utils/common_schedulers.py
@@ -0,0 +1,50 @@
+from utils.hparams import hparams
+
+
+class NoneSchedule(object):
+ def __init__(self, optimizer):
+ super().__init__()
+ self.optimizer = optimizer
+ self.constant_lr = hparams['lr']
+ self.step(0)
+
+ def step(self, num_updates):
+ self.lr = self.constant_lr
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = self.lr
+ return self.lr
+
+ def get_lr(self):
+ return self.optimizer.param_groups[0]['lr']
+
+ def get_last_lr(self):
+ return self.get_lr()
+
+
+class RSQRTSchedule(object):
+ def __init__(self, optimizer):
+ super().__init__()
+ self.optimizer = optimizer
+ self.constant_lr = hparams['lr']
+ self.warmup_updates = hparams['warmup_updates']
+ self.hidden_size = hparams['hidden_size']
+ self.lr = hparams['lr']
+ for param_group in optimizer.param_groups:
+ param_group['lr'] = self.lr
+ self.step(0)
+
+ def step(self, num_updates):
+ # Transformer-style schedule: linear warmup, then inverse-sqrt decay,
+ # scaled by hidden_size ** -0.5 and floored at 1e-7.
+ constant_lr = self.constant_lr
+ warmup = min(num_updates / self.warmup_updates, 1.0)
+ rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5
+ rsqrt_hidden = self.hidden_size ** -0.5
+ self.lr = max(constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7)
+ for param_group in self.optimizer.param_groups:
+ param_group['lr'] = self.lr
+ return self.lr
+
+ def get_lr(self):
+ return self.optimizer.param_groups[0]['lr']
+
+ def get_last_lr(self):
+ return self.get_lr()
diff --git a/utils/cwt.py b/utils/cwt.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a08461b9e422aac614438e6240b7355b8e4bb2c
--- /dev/null
+++ b/utils/cwt.py
@@ -0,0 +1,146 @@
+import librosa
+import numpy as np
+from pycwt import wavelet
+from scipy.interpolate import interp1d
+
+
+def load_wav(wav_file, sr):
+ wav, _ = librosa.load(wav_file, sr=sr, mono=True)
+ return wav
+
+
+def convert_continuos_f0(f0):
+ '''Convert F0 to a continuous contour by interpolating unvoiced frames.
+ Args:
+ f0 (ndarray): original f0 sequence with the shape (T)
+ Return:
+ (ndarray, ndarray): binary voiced/unvoiced flags and continuous f0, both with the shape (T)
+ '''
+ # get uv information as binary
+ f0 = np.copy(f0)
+ uv = np.float32(f0 != 0)
+
+ # get start and end of f0
+ if (f0 == 0).all():
+ print("| all of the f0 values are 0.")
+ return uv, f0
+ start_f0 = f0[f0 != 0][0]
+ end_f0 = f0[f0 != 0][-1]
+
+ # padding start and end of f0 sequence
+ start_idx = np.where(f0 == start_f0)[0][0]
+ end_idx = np.where(f0 == end_f0)[0][-1]
+ f0[:start_idx] = start_f0
+ f0[end_idx:] = end_f0
+
+ # get non-zero frame index
+ nz_frames = np.where(f0 != 0)[0]
+
+ # perform linear interpolation
+ f = interp1d(nz_frames, f0[nz_frames])
+ cont_f0 = f(np.arange(0, f0.shape[0]))
+
+ return uv, cont_f0
+
+
+def get_cont_lf0(f0, frame_period=5.0):
+ uv, cont_f0_lpf = convert_continuos_f0(f0)
+ # cont_f0_lpf = low_pass_filter(cont_f0_lpf, int(1.0 / (frame_period * 0.001)), cutoff=20)
+ cont_lf0_lpf = np.log(cont_f0_lpf)
+ return uv, cont_lf0_lpf
+
+
+def get_lf0_cwt(lf0):
+ '''
+ input:
+ signal of shape (N,)
+ output:
+ Wavelet_lf0 of shape (N, 10), scales of shape (10,)
+ '''
+ mother = wavelet.MexicanHat()
+ dt = 0.005
+ dj = 1
+ s0 = dt * 2
+ J = 9
+
+ Wavelet_lf0, scales, _, _, _, _ = wavelet.cwt(np.squeeze(lf0), dt, dj, s0, J, mother)
+ # wavelet.cwt returns shape (J + 1, len(lf0)); transpose to (len(lf0), J + 1)
+ Wavelet_lf0 = np.real(Wavelet_lf0).T
+ return Wavelet_lf0, scales
+
+
+def 
norm_scale(Wavelet_lf0): + Wavelet_lf0_norm = np.zeros((Wavelet_lf0.shape[0], Wavelet_lf0.shape[1])) + mean = Wavelet_lf0.mean(0)[None, :] + std = Wavelet_lf0.std(0)[None, :] + Wavelet_lf0_norm = (Wavelet_lf0 - mean) / std + return Wavelet_lf0_norm, mean, std + + +def normalize_cwt_lf0(f0, mean, std): + uv, cont_lf0_lpf = get_cont_lf0(f0) + cont_lf0_norm = (cont_lf0_lpf - mean) / std + Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_norm) + Wavelet_lf0_norm, _, _ = norm_scale(Wavelet_lf0) + + return Wavelet_lf0_norm + + +def get_lf0_cwt_norm(f0s, mean, std): + uvs = list() + cont_lf0_lpfs = list() + cont_lf0_lpf_norms = list() + Wavelet_lf0s = list() + Wavelet_lf0s_norm = list() + scaless = list() + + means = list() + stds = list() + for f0 in f0s: + uv, cont_lf0_lpf = get_cont_lf0(f0) + cont_lf0_lpf_norm = (cont_lf0_lpf - mean) / std + + Wavelet_lf0, scales = get_lf0_cwt(cont_lf0_lpf_norm) # [560,10] + Wavelet_lf0_norm, mean_scale, std_scale = norm_scale(Wavelet_lf0) # [560,10],[1,10],[1,10] + + Wavelet_lf0s_norm.append(Wavelet_lf0_norm) + uvs.append(uv) + cont_lf0_lpfs.append(cont_lf0_lpf) + cont_lf0_lpf_norms.append(cont_lf0_lpf_norm) + Wavelet_lf0s.append(Wavelet_lf0) + scaless.append(scales) + means.append(mean_scale) + stds.append(std_scale) + + return Wavelet_lf0s_norm, scaless, means, stds + + +def inverse_cwt_torch(Wavelet_lf0, scales): + import torch + b = ((torch.arange(0, len(scales)).float().to(Wavelet_lf0.device)[None, None, :] + 1 + 2.5) ** (-2.5)) + lf0_rec = Wavelet_lf0 * b + lf0_rec_sum = lf0_rec.sum(-1) + lf0_rec_sum = (lf0_rec_sum - lf0_rec_sum.mean(-1, keepdim=True)) / lf0_rec_sum.std(-1, keepdim=True) + return lf0_rec_sum + + +def inverse_cwt(Wavelet_lf0, scales): + b = ((np.arange(0, len(scales))[None, None, :] + 1 + 2.5) ** (-2.5)) + lf0_rec = Wavelet_lf0 * b + lf0_rec_sum = lf0_rec.sum(-1) + lf0_rec_sum = (lf0_rec_sum - lf0_rec_sum.mean(-1, keepdims=True)) / lf0_rec_sum.std(-1, keepdims=True) + return lf0_rec_sum + + +def cwt2f0(cwt_spec, mean, std, cwt_scales): + assert len(mean.shape) == 1 and len(std.shape) == 1 and len(cwt_spec.shape) == 3 + import torch + if isinstance(cwt_spec, torch.Tensor): + f0 = inverse_cwt_torch(cwt_spec, cwt_scales) + f0 = f0 * std[:, None] + mean[:, None] + f0 = f0.exp() # [B, T] + else: + f0 = inverse_cwt(cwt_spec, cwt_scales) + f0 = f0 * std[:, None] + mean[:, None] + f0 = np.exp(f0) # [B, T] + return f0 diff --git a/utils/ddp_utils.py b/utils/ddp_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..4b529198c13a1ffc622baea6e5178407b24aee8f --- /dev/null +++ b/utils/ddp_utils.py @@ -0,0 +1,137 @@ +from torch.nn.parallel import DistributedDataParallel +from torch.nn.parallel.distributed import _find_tensors +import torch.optim +import torch.utils.data +import torch +from packaging import version + +class DDP(DistributedDataParallel): + """ + Override the forward call in lightning so it goes to training and validation step respectively + """ + + def forward(self, *inputs, **kwargs): # pragma: no cover + if version.parse(torch.__version__[:6]) < version.parse("1.11"): + self._sync_params() + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + assert len(self.device_ids) == 1 + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + if torch.is_grad_enabled(): + # We'll return the output object verbatim since it is a 
freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + from torch.nn.parallel.distributed import \ + logging, Join, _DDPSink, _tree_flatten_with_rref, _tree_unflatten_with_rref + with torch.autograd.profiler.record_function("DistributedDataParallel.forward"): + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.logger.set_runtime_stats_and_log() + self.num_iterations += 1 + self.reducer.prepare_for_forward() + + # Notify the join context that this process has not joined, if + # needed + work = Join.notify_join_context(self) + if work: + self.reducer._set_forward_pass_work_handle( + work, self._divide_by_initial_world_size + ) + + # Calling _rebuild_buckets before forward compuation, + # It may allocate new buckets before deallocating old buckets + # inside _rebuild_buckets. To save peak memory usage, + # call _rebuild_buckets before the peak memory usage increases + # during forward computation. + # This should be called only once during whole training period. + if torch.is_grad_enabled() and self.reducer._rebuild_buckets(): + logging.info("Reducer buckets have been rebuilt in this iteration.") + self._has_rebuilt_buckets = True + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + buffer_hook_registered = hasattr(self, 'buffer_hook') + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + + if self._join_config.enable: + # Notify joined ranks whether they should sync in backwards pass or not. + self._check_global_requires_backward_grad_sync(is_joined_rank=False) + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + + # sync params according to location (before/after forward) user + # specified as part of hook, if hook was specified. + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if torch.is_grad_enabled() and self.require_backward_grad_sync: + self.require_forward_param_sync = True + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters and not self.static_graph: + # Do not need to populate this for static graph. + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + self.require_forward_param_sync = False + + # TODO: DDPSink is currently enabled for unused parameter detection and + # static graph training for first iteration. 
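+ # (added note) DDPSink below makes grad-requiring outputs pass through an
+ # autograd hook so the reducer still fires for outputs that never reach the
+ # loss; this mirrors the upstream torch DDP forward this override copies.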
+ if (self.find_unused_parameters and not self.static_graph) or ( + self.static_graph and self.num_iterations == 1 + ): + state_dict = { + 'static_graph': self.static_graph, + 'num_iterations': self.num_iterations, + } + + output_tensor_list, treespec, output_is_rref = _tree_flatten_with_rref( + output + ) + output_placeholders = [None for _ in range(len(output_tensor_list))] + # Do not touch tensors that have no grad_fn, which can cause issues + # such as https://github.com/pytorch/pytorch/issues/60733 + for i, output in enumerate(output_tensor_list): + if torch.is_tensor(output) and output.grad_fn is None: + output_placeholders[i] = output + + # When find_unused_parameters=True, makes tensors which require grad + # run through the DDPSink backward pass. When not all outputs are + # used in loss, this makes those corresponding tensors receive + # undefined gradient which the reducer then handles to ensure + # param.grad field is not touched and we don't error out. + passthrough_tensor_list = _DDPSink.apply( + self.reducer, + state_dict, + *output_tensor_list, + ) + for i in range(len(output_placeholders)): + if output_placeholders[i] is None: + output_placeholders[i] = passthrough_tensor_list[i] + + # Reconstruct output data structure. + output = _tree_unflatten_with_rref( + output_placeholders, treespec, output_is_rref + ) + return output diff --git a/utils/hparams.py b/utils/hparams.py new file mode 100644 index 0000000000000000000000000000000000000000..7efa3025ec3b52949d7b20d432b3457fc60713c4 --- /dev/null +++ b/utils/hparams.py @@ -0,0 +1,121 @@ +import argparse +import os +import yaml + +global_print_hparams = True +hparams = {} + + +class Args: + def __init__(self, **kwargs): + for k, v in kwargs.items(): + self.__setattr__(k, v) + + +def override_config(old_config: dict, new_config: dict): + for k, v in new_config.items(): + if isinstance(v, dict) and k in old_config: + override_config(old_config[k], new_config[k]) + else: + old_config[k] = v + + +def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True): + if config == '': + parser = argparse.ArgumentParser(description='neural music') + parser.add_argument('--config', type=str, default='', + help='location of the data corpus') + parser.add_argument('--exp_name', type=str, default='', help='exp_name') + parser.add_argument('--hparams', type=str, default='', + help='location of the data corpus') + parser.add_argument('--inference', action='store_true', help='inference') + parser.add_argument('--validate', action='store_true', help='validate') + parser.add_argument('--reset', action='store_true', help='reset hparams') + parser.add_argument('--debug', action='store_true', help='debug') + args, unknown = parser.parse_known_args() + else: + args = Args(config=config, exp_name=exp_name, hparams=hparams_str, + infer=False, validate=False, reset=False, debug=False) + args_work_dir = '' + if args.exp_name != '': + args.work_dir = args.exp_name + args_work_dir = f'checkpoints/{args.work_dir}' + + config_chains = [] + loaded_config = set() + + def load_config(config_fn): # deep first + with open(config_fn) as f: + hparams_ = yaml.safe_load(f) + loaded_config.add(config_fn) + if 'base_config' in hparams_: + ret_hparams = {} + if not isinstance(hparams_['base_config'], list): + hparams_['base_config'] = [hparams_['base_config']] + for c in hparams_['base_config']: + if c not in loaded_config: + if c.startswith('.'): + c = f'{os.path.dirname(config_fn)}/{c}' + c = os.path.normpath(c) + 
override_config(ret_hparams, load_config(c))
+ override_config(ret_hparams, hparams_)
+ else:
+ ret_hparams = hparams_
+ config_chains.append(config_fn)
+ return ret_hparams
+
+ global hparams
+ assert args.config != '' or args_work_dir != ''
+ # the CLI parser defines `--inference` while the programmatic `Args` path sets `infer`;
+ # normalize here so `args.infer` exists on both paths
+ args.infer = getattr(args, 'infer', False) or getattr(args, 'inference', False)
+ saved_hparams = {}
+ if args_work_dir != 'checkpoints/':
+ ckpt_config_path = f'{args_work_dir}/config.yaml'
+ if os.path.exists(ckpt_config_path):
+ try:
+ with open(ckpt_config_path) as f:
+ saved_hparams.update(yaml.safe_load(f))
+ except Exception:
+ pass # an unreadable saved config is ignored and rebuilt below
+ if args.config == '':
+ args.config = ckpt_config_path
+
+ hparams_ = {}
+ hparams_.update(load_config(args.config))
+
+ if not args.reset:
+ hparams_.update(saved_hparams)
+ hparams_['work_dir'] = args_work_dir
+
+ if args.hparams != "":
+ for new_hparam in args.hparams.split(","):
+ k, v = new_hparam.split("=")
+ if v in ['True', 'False'] or type(hparams_[k]) == bool:
+ hparams_[k] = eval(v)
+ else:
+ hparams_[k] = type(hparams_[k])(v)
+
+ if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
+ os.makedirs(hparams_['work_dir'], exist_ok=True)
+ with open(ckpt_config_path, 'w') as f:
+ yaml.safe_dump(hparams_, f)
+
+ hparams_['inference'] = args.infer
+ hparams_['debug'] = args.debug
+ hparams_['validate'] = args.validate
+ global global_print_hparams
+ if global_hparams:
+ hparams.clear()
+ hparams.update(hparams_)
+
+ if print_hparams and global_print_hparams and global_hparams:
+ print('| Hparams chains: ', config_chains)
+ print('| Hparams: ')
+ for i, (k, v) in enumerate(sorted(hparams_.items())):
+ print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
+ print("")
+ global_print_hparams = False
+ # print(hparams_.keys())
+ if hparams.get('exp_name') is None:
+ hparams['exp_name'] = args.exp_name
+ if hparams_.get('exp_name') is None:
+ hparams_['exp_name'] = args.exp_name
+ return hparams_
diff --git a/utils/indexed_datasets.py b/utils/indexed_datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..e15632be30d6296a3c9aa80a1f351058003698b3
--- /dev/null
+++ b/utils/indexed_datasets.py
@@ -0,0 +1,71 @@
+import pickle
+from copy import deepcopy
+
+import numpy as np
+
+
+class IndexedDataset:
+ def __init__(self, path, num_cache=1):
+ super().__init__()
+ self.path = path
+ self.data_file = None
+ self.data_offsets = np.load(f"{path}.idx", allow_pickle=True).item()['offsets']
+ self.data_file = open(f"{path}.data", 'rb', buffering=-1)
+ self.cache = []
+ self.num_cache = num_cache
+
+ def check_index(self, i):
+ if i < 0 or i >= len(self.data_offsets) - 1:
+ raise IndexError('index out of range')
+
+ def __del__(self):
+ if self.data_file:
+ self.data_file.close()
+
+ def __getitem__(self, i):
+ self.check_index(i)
+ if self.num_cache > 0:
+ for c in self.cache:
+ if c[0] == i:
+ return c[1]
+ self.data_file.seek(self.data_offsets[i])
+ b = self.data_file.read(self.data_offsets[i + 1] - self.data_offsets[i])
+ item = pickle.loads(b)
+ if self.num_cache > 0:
+ self.cache = [(i, deepcopy(item))] + self.cache[:-1]
+ return item
+
+ def __len__(self):
+ return len(self.data_offsets) - 1
+
+class IndexedDatasetBuilder:
+ def __init__(self, path):
+ self.path = path
+ self.out_file = open(f"{path}.data", 'wb')
+ self.byte_offsets = [0]
+
+ def add_item(self, item):
+ s = pickle.dumps(item)
+ n_bytes = self.out_file.write(s) # renamed from `bytes` to avoid shadowing the builtin
+ self.byte_offsets.append(self.byte_offsets[-1] + n_bytes)
+
+ def finalize(self):
+ self.out_file.close()
+ np.save(open(f"{self.path}.idx", 'wb'), {'offsets': self.byte_offsets})
+
+
+if __name__ == 
"__main__": + import random + from tqdm import tqdm + ds_path = '/tmp/indexed_ds_example' + size = 100 + items = [{"a": np.random.normal(size=[10000, 10]), + "b": np.random.normal(size=[10000, 10])} for i in range(size)] + builder = IndexedDatasetBuilder(ds_path) + for i in tqdm(range(size)): + builder.add_item(items[i]) + builder.finalize() + ds = IndexedDataset(ds_path) + for i in tqdm(range(10000)): + idx = random.randint(0, size - 1) + assert (ds[idx]['a'] == items[idx]['a']).all() diff --git a/utils/multiprocess_utils.py b/utils/multiprocess_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7d3641a332eedfbaf27cda11dbd4a79b8a65072b --- /dev/null +++ b/utils/multiprocess_utils.py @@ -0,0 +1,143 @@ +import os +import traceback +from functools import partial +from tqdm import tqdm + + +def chunked_worker(worker_id, args_queue=None, results_queue=None, init_ctx_func=None): + ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None + while True: + args = args_queue.get() + if args == '': + return + job_idx, map_func, arg = args + try: + map_func_ = partial(map_func, ctx=ctx) if ctx is not None else map_func + if isinstance(arg, dict): + res = map_func_(**arg) + elif isinstance(arg, (list, tuple)): + res = map_func_(*arg) + else: + res = map_func_(arg) + results_queue.put((job_idx, res)) + except: + traceback.print_exc() + results_queue.put((job_idx, None)) + + +class MultiprocessManager: + def __init__(self, num_workers=None, init_ctx_func=None, multithread=False): + if multithread: + from multiprocessing.dummy import Queue, Process + else: + from multiprocessing import Queue, Process + if num_workers is None: + num_workers = int(os.getenv('N_PROC', os.cpu_count())) + self.num_workers = num_workers + self.results_queue = Queue(maxsize=-1) + self.args_queue = Queue(maxsize=-1) + self.workers = [] + self.total_jobs = 0 + for i in range(num_workers): + p = Process(target=chunked_worker, + args=(i, self.args_queue, self.results_queue, init_ctx_func), + daemon=True) + self.workers.append(p) + p.start() + + def add_job(self, func, args): + self.args_queue.put((self.total_jobs, func, args)) + self.total_jobs += 1 + + def get_results(self): + for w in range(self.num_workers): + self.args_queue.put("") + self.n_finished = 0 + while self.n_finished < self.total_jobs: + job_id, res = self.results_queue.get() + yield job_id, res + self.n_finished += 1 + for w in self.workers: + w.join() + + def __len__(self): + return self.total_jobs + + +def multiprocess_run_tqdm(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, + multithread=False, desc=None): + for i, res in tqdm(enumerate( + multiprocess_run(map_func, args, num_workers, ordered, init_ctx_func, multithread)), + total=len(args), desc=desc): + yield i, res + + +def multiprocess_run(map_func, args, num_workers=None, ordered=True, init_ctx_func=None, multithread=False): + """ + Multiprocessing running chunked jobs. 
+ Examples:
+ >>> for res in tqdm(multiprocess_run(job_func, args)):
+ >>> print(res)
+ :param map_func:
+ :param args:
+ :param num_workers:
+ :param ordered:
+ :param init_ctx_func:
+ :param multithread:
+ :return:
+ """
+ if num_workers is None:
+ num_workers = int(os.getenv('N_PROC', os.cpu_count()))
+ manager = MultiprocessManager(num_workers, init_ctx_func, multithread)
+ for arg in args:
+ manager.add_job(map_func, arg)
+ if ordered:
+ n_jobs = len(args)
+ _NOT_DONE = object() # unique sentinel; a job may legitimately return ''
+ results = [_NOT_DONE for _ in range(n_jobs)]
+ i_now = 0
+ for job_i, res in manager.get_results():
+ results[job_i] = res
+ while i_now < n_jobs and results[i_now] is not _NOT_DONE:
+ yield results[i_now]
+ i_now += 1
+ else:
+ for res in manager.get_results():
+ yield res
+
+
+def chunked_worker_run(worker_id, map_func, args, results_queue=None, init_ctx_func=None):
+ # worker for `chunked_multiprocess_run`: jobs are pre-sharded into `args`,
+ # unlike `chunked_worker` above, which pulls jobs from a shared queue
+ ctx = init_ctx_func(worker_id) if init_ctx_func is not None else None
+ for job_idx, arg in args:
+ try:
+ map_func_ = partial(map_func, ctx=ctx) if ctx is not None else map_func
+ if isinstance(arg, dict):
+ res = map_func_(**arg)
+ elif isinstance(arg, (list, tuple)):
+ res = map_func_(*arg)
+ else:
+ res = map_func_(arg)
+ results_queue.put((job_idx, res))
+ except Exception:
+ traceback.print_exc()
+ results_queue.put((job_idx, None))
+
+
+def chunked_multiprocess_run(
+ map_func, args, num_workers=None, ordered=True,
+ init_ctx_func=None, q_max_size=1000, multithread=False):
+ if multithread:
+ from multiprocessing.dummy import Queue, Process
+ else:
+ from multiprocessing import Queue, Process
+ args = zip(range(len(args)), args)
+ args = list(args)
+ n_jobs = len(args)
+ if num_workers is None:
+ num_workers = int(os.getenv('N_PROC', os.cpu_count()))
+ results_queues = []
+ if ordered:
+ for i in range(num_workers):
+ results_queues.append(Queue(maxsize=q_max_size // num_workers))
+ else:
+ results_queue = Queue(maxsize=q_max_size)
+ for i in range(num_workers):
+ results_queues.append(results_queue)
+ workers = []
+ for i in range(num_workers):
+ args_worker = args[i::num_workers]
+ # `chunked_worker_run` matches this argument layout; the previous target
+ # (`chunked_worker`) expected queue arguments and raised a TypeError here
+ p = Process(target=chunked_worker_run, args=(
+ i, map_func, args_worker, results_queues[i], init_ctx_func), daemon=True)
+ workers.append(p)
+ p.start()
+ for n_finished in range(n_jobs):
+ results_queue = results_queues[n_finished % num_workers]
+ job_idx, res = results_queue.get()
+ assert job_idx == n_finished or not ordered, (job_idx, n_finished)
+ yield res
+ for w in workers:
+ w.join()
+
diff --git a/utils/os_utils.py b/utils/os_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c78a44c04eadc3feb3c35f88c8a074f59ab23778
--- /dev/null
+++ b/utils/os_utils.py
@@ -0,0 +1,20 @@
+import os
+import subprocess
+
+
+def link_file(from_file, to_file):
+ subprocess.check_call(
+ f'ln -s "`realpath --relative-to="{os.path.dirname(to_file)}" "{from_file}"`" "{to_file}"', shell=True)
+
+
+def move_file(from_file, to_file):
+ subprocess.check_call(f'mv "{from_file}" "{to_file}"', shell=True)
+
+
+def copy_file(from_file, to_file):
+ subprocess.check_call(f'cp -r "{from_file}" "{to_file}"', shell=True)
+
+
+def remove_file(*fns):
+ for f in fns:
+ subprocess.check_call(f'rm -rf "{f}"', shell=True)
\ No newline at end of file
diff --git a/utils/pitch_utils.py b/utils/pitch_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7fd166abd3a03bac5909e498669b482447435cf
--- /dev/null
+++ b/utils/pitch_utils.py
@@ -0,0 +1,76 @@
+#########
+# world
+##########
+import librosa
+import numpy as np
+import torch
+
+gamma = 0
+mcepInput = 3 # 0 for dB, 3 for magnitude
+alpha = 0.45
+en_floor = 10 ** (-80 / 20)
+FFT_SIZE = 2048
+
+
+f0_bin = 256
+f0_max = 1100.0
+f0_min = 50.0
+f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+
+
+def f0_to_coarse(f0):
+ is_torch = isinstance(f0, torch.Tensor)
+ f0_mel = 1127 * (1 + f0 / 700).log() if is_torch else 1127 * np.log(1 + f0 / 700)
+ f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1
+
+ f0_mel[f0_mel <= 
1] = 1
+ f0_mel[f0_mel > f0_bin - 1] = f0_bin - 1
+ f0_coarse = (f0_mel + 0.5).long() if is_torch else np.rint(f0_mel).astype(int) # np.int was removed in NumPy 1.24
+ assert f0_coarse.max() <= 255 and f0_coarse.min() >= 1, (f0_coarse.max(), f0_coarse.min())
+ return f0_coarse
+
+
+def norm_f0(f0, uv, hparams):
+ is_torch = isinstance(f0, torch.Tensor)
+ if hparams['pitch_norm'] == 'standard':
+ f0 = (f0 - hparams['f0_mean']) / hparams['f0_std']
+ if hparams['pitch_norm'] == 'log':
+ f0 = torch.log2(f0) if is_torch else np.log2(f0)
+ if uv is not None and hparams['use_uv']:
+ f0[uv > 0] = 0
+ return f0
+
+
+def norm_interp_f0(f0, hparams):
+ is_torch = isinstance(f0, torch.Tensor)
+ if is_torch:
+ device = f0.device
+ f0 = f0.data.cpu().numpy()
+ uv = f0 == 0
+ f0 = norm_f0(f0, uv, hparams)
+ if sum(uv) == len(f0):
+ f0[uv] = 0
+ elif sum(uv) > 0:
+ f0[uv] = np.interp(np.where(uv)[0], np.where(~uv)[0], f0[~uv])
+ uv = torch.FloatTensor(uv)
+ f0 = torch.FloatTensor(f0)
+ if is_torch:
+ f0 = f0.to(device)
+ return f0, uv
+
+
+def denorm_f0(f0, uv, hparams, pitch_padding=None, min=None, max=None):
+ if hparams['pitch_norm'] == 'standard':
+ f0 = f0 * hparams['f0_std'] + hparams['f0_mean']
+ if hparams['pitch_norm'] == 'log':
+ f0 = 2 ** f0
+ if min is not None:
+ f0 = f0.clamp(min=min)
+ if max is not None:
+ f0 = f0.clamp(max=max)
+ if uv is not None and hparams['use_uv']:
+ f0[uv > 0] = 0
+ if pitch_padding is not None:
+ f0[pitch_padding] = 0
+ return f0
diff --git a/utils/pl_utils.py b/utils/pl_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..76a94ed6abe22e349c51c49afdbf052d52b8d98b
--- /dev/null
+++ b/utils/pl_utils.py
@@ -0,0 +1,1618 @@
+import matplotlib
+from torch.nn import DataParallel
+from torch.nn.parallel import DistributedDataParallel
+
+matplotlib.use('Agg')
+import glob
+import itertools
+import subprocess
+import threading
+import traceback
+
+from pytorch_lightning.callbacks import GradientAccumulationScheduler
+from pytorch_lightning.callbacks import ModelCheckpoint
+
+from functools import wraps
+from torch.cuda._utils import _get_device_index
+import numpy as np
+import torch.optim
+import torch.utils.data
+import copy
+import logging
+import os
+import re
+import sys
+import torch
+import torch.distributed as dist
+import torch.multiprocessing as mp
+import tqdm
+from torch.optim.optimizer import Optimizer
+
+
+def get_a_var(obj): # pragma: no cover
+ if isinstance(obj, torch.Tensor):
+ return obj
+
+ if isinstance(obj, list) or isinstance(obj, tuple):
+ for result in map(get_a_var, obj):
+ if isinstance(result, torch.Tensor):
+ return result
+ if isinstance(obj, dict):
+ for result in map(get_a_var, obj.items()):
+ if isinstance(result, torch.Tensor):
+ return result
+ return None
+
+
+def data_loader(fn):
+ """
+ Decorator that turns a dataloader factory into a lazily-evaluated property
+ :param fn:
+ :return:
+ """
+
+ attr_name = '_lazy_' + fn.__name__
+
+ @wraps(fn) # actually apply functools.wraps; the original bare `wraps(fn)` call was a no-op
+ def _get_data_loader(self):
+ try:
+ value = getattr(self, attr_name)
+ except AttributeError:
+ try:
+ value = fn(self) # Lazy evaluation, done only once.
+ if (
+ value is not None and
+ not isinstance(value, list) and
+ fn.__name__ in ['test_dataloader', 'val_dataloader']
+ ):
+ value = [value]
+ except AttributeError as e:
+ # Guard against AttributeError suppression. (Issue #142)
+ traceback.print_exc()
+ error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e)
+ raise RuntimeError(error) from e
+ setattr(self, attr_name, value) # Memoize evaluation. 
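+ # Later calls hit the getattr() fast path above, so each dataloader
+ # factory runs at most once per instance.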
+ return value + + return _get_data_loader + + +def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): # pragma: no cover + r"""Applies each `module` in :attr:`modules` in parallel on arguments + contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) + on each of :attr:`devices`. + + Args: + modules (Module): modules to be parallelized + inputs (tensor): inputs to the modules + devices (list of int or torch.device): CUDA devices + + :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and + :attr:`devices` (if given) should all have same length. Moreover, each + element of :attr:`inputs` can either be a single object as the only argument + to a module, or a collection of positional arguments. + """ + assert len(modules) == len(inputs) + if kwargs_tup is not None: + assert len(modules) == len(kwargs_tup) + else: + kwargs_tup = ({},) * len(modules) + if devices is not None: + assert len(modules) == len(devices) + else: + devices = [None] * len(modules) + devices = list(map(lambda x: _get_device_index(x, True), devices)) + lock = threading.Lock() + results = {} + grad_enabled = torch.is_grad_enabled() + + def _worker(i, module, input, kwargs, device=None): + torch.set_grad_enabled(grad_enabled) + if device is None: + device = get_a_var(input).get_device() + try: + with torch.cuda.device(device): + # this also avoids accidental slicing of `input` if it is a Tensor + if not isinstance(input, (list, tuple)): + input = (input,) + + # --------------- + # CHANGE + if module.training: + output = module.training_step(*input, **kwargs) + + elif module.testing: + output = module.test_step(*input, **kwargs) + + else: + output = module.validation_step(*input, **kwargs) + # --------------- + + with lock: + results[i] = output + except Exception as e: + with lock: + results[i] = e + + # make sure each module knows what training state it's in... + # fixes weird bug where copies are out of sync + root_m = modules[0] + for m in modules[1:]: + m.training = root_m.training + m.testing = root_m.testing + + if len(modules) > 1: + threads = [threading.Thread(target=_worker, + args=(i, module, input, kwargs, device)) + for i, (module, input, kwargs, device) in + enumerate(zip(modules, inputs, kwargs_tup, devices))] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + else: + _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) + + outputs = [] + for i in range(len(inputs)): + output = results[i] + if isinstance(output, Exception): + raise output + outputs.append(output) + return outputs + + +def _find_tensors(obj): # pragma: no cover + r""" + Recursively find all tensors contained in the specified object. 
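+ Returns a list of tensors, or an itertools.chain when the object nests
+ lists, tuples, or dicts.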
+ """ + if isinstance(obj, torch.Tensor): + return [obj] + if isinstance(obj, (list, tuple)): + return itertools.chain(*map(_find_tensors, obj)) + if isinstance(obj, dict): + return itertools.chain(*map(_find_tensors, obj.values())) + return [] + + +class DDP(DistributedDataParallel): + """ + Override the forward call in lightning so it goes to training and validation step respectively + """ + + def parallel_apply(self, replicas, inputs, kwargs): + return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) + + def forward(self, *inputs, **kwargs): # pragma: no cover + self._sync_params() + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + # -------------- + # LIGHTNING MOD + # -------------- + # normal + # output = self.module(*inputs[0], **kwargs[0]) + # lightning + if self.module.training: + output = self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + output = self.module.test_step(*inputs[0], **kwargs[0]) + else: + output = self.module.validation_step(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + else: + # normal + output = self.module(*inputs, **kwargs) + + if torch.is_grad_enabled(): + # We'll return the output object verbatim since it is a freeform + # object. We need to find any tensors in this object, though, + # because we need to figure out which parameters were used during + # this forward pass, to ensure we short circuit reduction for any + # unused parameters. Only if `find_unused_parameters` is set. + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + return output + + +class DP(DataParallel): + """ + Override the forward call in lightning so it goes to training and validation step respectively + """ + + def forward(self, *inputs, **kwargs): + if not self.device_ids: + return self.module(*inputs, **kwargs) + + for t in itertools.chain(self.module.parameters(), self.module.buffers()): + if t.device != self.src_device_obj: + raise RuntimeError("module must have its parameters and buffers " + "on device {} (device_ids[0]) but found one of " + "them on device: {}".format(self.src_device_obj, t.device)) + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + # lightning + if self.module.training: + return self.module.training_step(*inputs[0], **kwargs[0]) + elif self.module.testing: + return self.module.test_step(*inputs[0], **kwargs[0]) + else: + return self.module.validation_step(*inputs[0], **kwargs[0]) + + replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) + outputs = self.parallel_apply(replicas, inputs, kwargs) + return self.gather(outputs, self.output_device) + + def parallel_apply(self, replicas, inputs, kwargs): + return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) + + +class GradientAccumulationScheduler: + def __init__(self, scheduling: dict): + if scheduling == {}: # empty dict error + raise TypeError("Empty dict cannot be interpreted correct") + + for key in scheduling.keys(): + if not isinstance(key, int) or not isinstance(scheduling[key], int): + raise TypeError("All epoches and accumulation factor must be integers") + + minimal_epoch = min(scheduling.keys()) + if minimal_epoch < 1: + msg = f"Epochs indexing from 1, epoch {minimal_epoch} 
cannot be interpreted correct" + raise IndexError(msg) + elif minimal_epoch != 1: # if user didnt define first epoch accumulation factor + scheduling.update({1: 1}) + + self.scheduling = scheduling + self.epochs = sorted(scheduling.keys()) + + def on_epoch_begin(self, epoch, trainer): + epoch += 1 # indexing epochs from 1 + for i in reversed(range(len(self.epochs))): + if epoch >= self.epochs[i]: + trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i]) + break + + +class LatestModelCheckpoint(ModelCheckpoint): + def __init__(self, filepath, monitor='val_loss', verbose=0, num_ckpt_keep=5, + save_weights_only=False, mode='auto', period=1, prefix='model', save_best=True): + super(ModelCheckpoint, self).__init__() + self.monitor = monitor + self.verbose = verbose + self.filepath = filepath + os.makedirs(filepath, exist_ok=True) + self.num_ckpt_keep = num_ckpt_keep + self.save_best = save_best + self.save_weights_only = save_weights_only + self.period = period + self.epochs_since_last_check = 0 + self.prefix = prefix + self.best_k_models = {} + # {filename: monitor} + self.kth_best_model = '' + self.save_top_k = 1 + self.task = None + if mode == 'min': + self.monitor_op = np.less + self.best = np.Inf + self.mode = 'min' + elif mode == 'max': + self.monitor_op = np.greater + self.best = -np.Inf + self.mode = 'max' + else: + if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): + self.monitor_op = np.greater + self.best = -np.Inf + self.mode = 'max' + else: + self.monitor_op = np.less + self.best = np.Inf + self.mode = 'min' + if os.path.exists(f'{self.filepath}/best_valid.npy'): + self.best = np.load(f'{self.filepath}/best_valid.npy')[0] + + def get_all_ckpts(self): + return sorted(glob.glob(f'{self.filepath}/{self.prefix}_ckpt_steps_*.ckpt'), + key=lambda x: -int(re.findall('.*steps\_(\d+)\.ckpt', x)[0])) + + def on_epoch_end(self, epoch, logs=None): + logs = logs or {} + self.epochs_since_last_check += 1 + best_filepath = f'{self.filepath}/{self.prefix}_ckpt_best.pt' + if self.epochs_since_last_check >= self.period: + self.epochs_since_last_check = 0 + filepath = f'{self.filepath}/{self.prefix}_ckpt_steps_{self.task.global_step}.ckpt' + if self.verbose > 0: + logging.info(f'Epoch {epoch:05d}@{self.task.global_step}: saving model to {filepath}') + self._save_model(filepath) + for old_ckpt in self.get_all_ckpts()[self.num_ckpt_keep:]: + subprocess.check_call(f'rm -rf "{old_ckpt}"', shell=True) + if self.verbose > 0: + logging.info(f'Delete ckpt: {os.path.basename(old_ckpt)}') + current = logs.get(self.monitor) + if current is not None and self.save_best: + if self.monitor_op(current, self.best): + self.best = current + if self.verbose > 0: + logging.info( + f'Epoch {epoch:05d}@{self.task.global_step}: {self.monitor} reached' + f' {current:0.5f} (best {self.best:0.5f}), saving model to' + f' {best_filepath} as top 1') + self._save_model(best_filepath) + np.save(f'{self.filepath}/best_valid.npy', [self.best]) + + +class BaseTrainer: + def __init__( + self, + logger=True, + checkpoint_callback=True, + default_save_path=None, + gradient_clip_val=0, + process_position=0, + gpus=-1, + log_gpu_memory=None, + show_progress_bar=True, + track_grad_norm=-1, + check_val_every_n_epoch=1, + accumulate_grad_batches=1, + max_updates=1000, + min_epochs=1, + val_check_interval=1.0, + log_save_interval=100, + row_log_interval=10, + print_nan_grads=False, + weights_summary='full', + num_sanity_val_steps=5, + resume_from_checkpoint=None, + ): + self.log_gpu_memory = log_gpu_memory + 
self.gradient_clip_val = gradient_clip_val + self.check_val_every_n_epoch = check_val_every_n_epoch + self.track_grad_norm = track_grad_norm + self.on_gpu = True if (gpus and torch.cuda.is_available()) else False + self.process_position = process_position + self.weights_summary = weights_summary + self.max_updates = max_updates + self.min_epochs = min_epochs + self.num_sanity_val_steps = num_sanity_val_steps + self.print_nan_grads = print_nan_grads + self.resume_from_checkpoint = resume_from_checkpoint + self.default_save_path = default_save_path + + # training bookeeping + self.total_batch_idx = 0 + self.running_loss = [] + self.avg_loss = 0 + self.batch_idx = 0 + self.tqdm_metrics = {} + self.callback_metrics = {} + self.num_val_batches = 0 + self.num_training_batches = 0 + self.num_test_batches = 0 + self.get_train_dataloader = None + self.get_test_dataloaders = None + self.get_val_dataloaders = None + self.is_iterable_train_dataloader = False + + # training state + self.model = None + self.testing = False + self.disable_validation = False + self.lr_schedulers = [] + self.optimizers = None + self.global_step = 0 + self.current_epoch = 0 + self.total_batches = 0 + + # configure checkpoint callback + self.checkpoint_callback = checkpoint_callback + self.checkpoint_callback.save_function = self.save_checkpoint + self.weights_save_path = self.checkpoint_callback.filepath + + # accumulated grads + self.configure_accumulated_gradients(accumulate_grad_batches) + + # allow int, string and gpu list + self.data_parallel_device_ids = [ + int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != ''] + if len(self.data_parallel_device_ids) == 0: + self.root_gpu = None + self.on_gpu = False + else: + self.root_gpu = self.data_parallel_device_ids[0] + self.on_gpu = True + + # distributed backend choice + self.use_ddp = False + self.use_dp = False + self.single_gpu = False + self.distributed_backend = 'ddp' if self.num_gpus > 0 else 'dp' + self.set_distributed_mode(self.distributed_backend) + + self.proc_rank = 0 + self.world_size = 1 + self.node_rank = 0 + + # can't init progress bar here because starting a new process + # means the progress_bar won't survive pickling + self.show_progress_bar = show_progress_bar + + # logging + self.log_save_interval = log_save_interval + self.val_check_interval = val_check_interval + self.logger = logger + self.logger.rank = 0 + self.row_log_interval = row_log_interval + + @property + def num_gpus(self): + gpus = self.data_parallel_device_ids + if gpus is None: + return 0 + else: + return len(gpus) + + @property + def data_parallel(self): + return self.use_dp or self.use_ddp + + def get_model(self): + is_dp_module = isinstance(self.model, (DDP, DP)) + model = self.model.module if is_dp_module else self.model + return model + + # ----------------------------- + # MODEL TRAINING + # ----------------------------- + def fit(self, model): + if self.use_ddp: + mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,)) + else: + model.model = model.build_model() + if not self.testing: + self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + if self.use_dp: + model.cuda(self.root_gpu) + model = DP(model, device_ids=self.data_parallel_device_ids) + elif self.single_gpu: + model.cuda(self.root_gpu) + self.run_pretrain_routine(model) + return 1 + + def init_optimizers(self, optimizers): + + # single optimizer + if isinstance(optimizers, Optimizer): + return [optimizers], [] + + # two lists + elif len(optimizers) == 2 
and isinstance(optimizers[0], list): + optimizers, lr_schedulers = optimizers + return optimizers, lr_schedulers + + # single list or tuple + elif isinstance(optimizers, list) or isinstance(optimizers, tuple): + return optimizers, [] + + def run_pretrain_routine(self, model): + """Sanity check a few things before starting actual training. + + :param model: + """ + ref_model = model + if self.data_parallel: + ref_model = model.module + + # give model convenience properties + ref_model.trainer = self + + # set local properties on the model + self.copy_trainer_model_properties(ref_model) + + # link up experiment object + if self.logger is not None: + ref_model.logger = self.logger + self.logger.save() + + if self.use_ddp: + dist.barrier() + + # set up checkpoint callback + # self.configure_checkpoint_callback() + + # transfer data loaders from model + self.get_dataloaders(ref_model) + + # track model now. + # if cluster resets state, the model will update with the saved weights + self.model = model + + # restore training and model before hpc call + self.restore_weights(model) + + # when testing requested only run test and return + if self.testing: + self.run_evaluation(test=True) + return + + # check if we should run validation during training + self.disable_validation = self.num_val_batches == 0 + + # run tiny validation (if validation defined) + # to make sure program won't crash during val + ref_model.on_sanity_check_start() + ref_model.on_train_start() + if not self.disable_validation and self.num_sanity_val_steps > 0: + # init progress bars for validation sanity check + pbar = tqdm.tqdm(desc='Validation sanity check', + total=self.num_sanity_val_steps * len(self.get_val_dataloaders()), + leave=False, position=2 * self.process_position, + disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch') + self.main_progress_bar = pbar + # dummy validation progress bar + self.val_progress_bar = tqdm.tqdm(disable=True) + + self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing) + + # close progress bars + self.main_progress_bar.close() + self.val_progress_bar.close() + + # init progress bar + pbar = tqdm.tqdm(leave=True, position=2 * self.process_position, + disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch', + file=sys.stdout) + self.main_progress_bar = pbar + + # clear cache before training + if self.on_gpu: + torch.cuda.empty_cache() + + # CORE TRAINING LOOP + self.train() + + def test(self, model): + self.testing = True + self.fit(model) + + @property + def training_tqdm_dict(self): + tqdm_dict = { + 'step': '{}'.format(self.global_step), + } + tqdm_dict.update(self.tqdm_metrics) + return tqdm_dict + + # -------------------- + # restore ckpt + # -------------------- + def restore_weights(self, model): + """ + To restore weights we have two cases. + First, attempt to restore hpc weights. If successful, don't restore + other weights. 
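+ (In this trainer: `resume_from_checkpoint` wins when given; otherwise the
+ newest `*_ckpt_steps_*.ckpt` in the checkpoint dir is used.)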
+ + Otherwise, try to restore actual weights + :param model: + :return: + """ + # clear cache before restore + if self.on_gpu: + torch.cuda.empty_cache() + + if self.resume_from_checkpoint is not None: + self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu) + else: + # restore weights if same exp version + self.restore_state_if_checkpoint_exists(model) + + # wait for all models to restore weights + if self.use_ddp: + # wait for all processes to catch up + dist.barrier() + + # clear cache after restore + if self.on_gpu: + torch.cuda.empty_cache() + + def restore_state_if_checkpoint_exists(self, model): + did_restore = False + + # do nothing if there's not dir or callback + no_ckpt_callback = (self.checkpoint_callback is None) or (not self.checkpoint_callback) + if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath): + return did_restore + + # restore trainer state and model if there is a weight for this experiment + last_steps = -1 + last_ckpt_name = None + + # find last epoch + checkpoints = os.listdir(self.checkpoint_callback.filepath) + for name in checkpoints: + if '.ckpt' in name and not name.endswith('part'): + if 'steps_' in name: + steps = name.split('steps_')[1] + steps = int(re.sub('[^0-9]', '', steps)) + + if steps > last_steps: + last_steps = steps + last_ckpt_name = name + + # restore last checkpoint + if last_ckpt_name is not None: + last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name) + self.restore(last_ckpt_path, self.on_gpu) + logging.info(f'model and trainer restored from checkpoint: {last_ckpt_path}') + did_restore = True + + return did_restore + + def restore(self, checkpoint_path, on_gpu): + checkpoint = torch.load(checkpoint_path, map_location='cpu') + + # load model state + model = self.get_model() + + # load the state_dict on the model automatically + model.load_state_dict(checkpoint['state_dict'], strict=False) + if on_gpu: + model.cuda(self.root_gpu) + # load training state (affects trainer only) + self.restore_training_state(checkpoint) + model.global_step = self.global_step + del checkpoint + + try: + if dist.is_initialized() and dist.get_rank() > 0: + return + except Exception as e: + print(e) + return + + def restore_training_state(self, checkpoint): + """ + Restore trainer state. + Model will get its change to update + :param checkpoint: + :return: + """ + if self.checkpoint_callback is not None and self.checkpoint_callback is not False: + self.checkpoint_callback.best = checkpoint['checkpoint_callback_best'] + + self.global_step = checkpoint['global_step'] + self.current_epoch = checkpoint['epoch'] + + if self.testing: + return + + # restore the optimizers + optimizer_states = checkpoint['optimizer_states'] + for optimizer, opt_state in zip(self.optimizers, optimizer_states): + if optimizer is None: + return + optimizer.load_state_dict(opt_state) + + # move optimizer to GPU 1 weight at a time + # avoids OOM + if self.root_gpu is not None: + for state in optimizer.state.values(): + for k, v in state.items(): + if isinstance(v, torch.Tensor): + state[k] = v.cuda(self.root_gpu) + + # restore the lr schedulers + lr_schedulers = checkpoint['lr_schedulers'] + for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers): + scheduler.load_state_dict(lrs_state) + + # -------------------- + # MODEL SAVE CHECKPOINT + # -------------------- + def _atomic_save(self, checkpoint, filepath): + """Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. 
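+ (The ``.part`` file is moved into place with ``os.replace``, so readers
+ never observe a partially written checkpoint.)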
+ + This will create a temporary checkpoint with a suffix of ``.part``, then copy it to the final location once + saving is finished. + + Args: + checkpoint (object): The object to save. + Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save`` + accepts. + filepath (str|pathlib.Path): The path to which the checkpoint will be saved. + This points to the file that the checkpoint will be stored in. + """ + tmp_path = str(filepath) + ".part" + torch.save(checkpoint, tmp_path) + os.replace(tmp_path, filepath) + + def save_checkpoint(self, filepath): + checkpoint = self.dump_checkpoint() + self._atomic_save(checkpoint, filepath) + + def dump_checkpoint(self): + + checkpoint = { + 'epoch': self.current_epoch, + 'global_step': self.global_step + } + + if self.checkpoint_callback is not None and self.checkpoint_callback is not False: + checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best + + # save optimizers + optimizer_states = [] + for i, optimizer in enumerate(self.optimizers): + if optimizer is not None: + optimizer_states.append(optimizer.state_dict()) + + checkpoint['optimizer_states'] = optimizer_states + + # save lr schedulers + lr_schedulers = [] + for i, scheduler in enumerate(self.lr_schedulers): + lr_schedulers.append(scheduler.state_dict()) + + checkpoint['lr_schedulers'] = lr_schedulers + + # add the hparams and state_dict from the model + model = self.get_model() + checkpoint['state_dict'] = model.state_dict() + # give the model a chance to add a few things + model.on_save_checkpoint(checkpoint) + + return checkpoint + + def copy_trainer_model_properties(self, model): + if isinstance(model, DP): + ref_model = model.module + elif isinstance(model, DDP): + ref_model = model.module + else: + ref_model = model + + for m in [model, ref_model]: + m.trainer = self + m.on_gpu = self.on_gpu + m.use_dp = self.use_dp + m.use_ddp = self.use_ddp + m.testing = self.testing + m.single_gpu = self.single_gpu + + def transfer_batch_to_gpu(self, batch, gpu_id): + # base case: object can be directly moved using `cuda` or `to` + if callable(getattr(batch, 'cuda', None)): + return batch.cuda(gpu_id, non_blocking=True) + + elif callable(getattr(batch, 'to', None)): + return batch.to(torch.device('cuda', gpu_id), non_blocking=True) + + # when list + elif isinstance(batch, list): + for i, x in enumerate(batch): + batch[i] = self.transfer_batch_to_gpu(x, gpu_id) + return batch + + # when tuple + elif isinstance(batch, tuple): + batch = list(batch) + for i, x in enumerate(batch): + batch[i] = self.transfer_batch_to_gpu(x, gpu_id) + return tuple(batch) + + # when dict + elif isinstance(batch, dict): + for k, v in batch.items(): + batch[k] = self.transfer_batch_to_gpu(v, gpu_id) + + return batch + + # nothing matches, return the value as is without transform + return batch + + def set_distributed_mode(self, distributed_backend): + # skip for CPU + if self.num_gpus == 0: + return + + # single GPU case + # in single gpu case we allow ddp so we can train on multiple + # nodes, 1 gpu per node + elif self.num_gpus == 1: + self.single_gpu = True + self.use_dp = False + self.use_ddp = False + self.root_gpu = 0 + self.data_parallel_device_ids = [0] + else: + if distributed_backend is not None: + self.use_dp = distributed_backend == 'dp' + self.use_ddp = distributed_backend == 'ddp' + elif distributed_backend is None: + self.use_dp = True + self.use_ddp = False + + logging.info(f'gpu available: {torch.cuda.is_available()}, used: {self.on_gpu}') + + def 
ddp_train(self, gpu_idx, model): + """ + Entry point into a DP thread + :param gpu_idx: + :param model: + :param cluster_obj: + :return: + """ + # otherwise default to node rank 0 + self.node_rank = 0 + + # show progressbar only on progress_rank 0 + self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_idx == 0 + + # determine which process we are and world size + if self.use_ddp: + self.proc_rank = self.node_rank * self.num_gpus + gpu_idx + self.world_size = self.num_gpus + + # let the exp know the rank to avoid overwriting logs + if self.logger is not None: + self.logger.rank = self.proc_rank + + # set up server using proc 0's ip address + # try to init for 20 times at max in case ports are taken + # where to store ip_table + model.trainer = self + model.init_ddp_connection(self.proc_rank, self.world_size) + + # CHOOSE OPTIMIZER + # allow for lr schedulers as well + model.model = model.build_model() + if not self.testing: + self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) + + # MODEL + # copy model to each gpu + if self.distributed_backend == 'ddp': + torch.cuda.set_device(gpu_idx) + model.cuda(gpu_idx) + + # set model properties before going into wrapper + self.copy_trainer_model_properties(model) + + # override root GPU + self.root_gpu = gpu_idx + + if self.distributed_backend == 'ddp': + device_ids = [gpu_idx] + else: + device_ids = None + + # allow user to configure ddp + model = model.configure_ddp(model, device_ids) + + # continue training routine + self.run_pretrain_routine(model) + + def resolve_root_node_address(self, root_node): + if '[' in root_node: + name = root_node.split('[')[0] + number = root_node.split(',')[0] + if '-' in number: + number = number.split('-')[0] + + number = re.sub('[^0-9]', '', number) + root_node = name + number + + return root_node + + def log_metrics(self, metrics, grad_norm_dic, step=None): + """Logs the metric dict passed in. + + :param metrics: + :param grad_norm_dic: + """ + # added metrics by Lightning for convenience + metrics['epoch'] = self.current_epoch + + # add norms + metrics.update(grad_norm_dic) + + # turn all tensors to scalars + scalar_metrics = self.metrics_to_scalars(metrics) + + step = step if step is not None else self.global_step + # log actual metrics + if self.proc_rank == 0 and self.logger is not None: + self.logger.log_metrics(scalar_metrics, step=step) + self.logger.save() + + def add_tqdm_metrics(self, metrics): + for k, v in metrics.items(): + if type(v) is torch.Tensor: + v = v.item() + + self.tqdm_metrics[k] = v + + def metrics_to_scalars(self, metrics): + new_metrics = {} + for k, v in metrics.items(): + if isinstance(v, torch.Tensor): + v = v.item() + + if type(v) is dict: + v = self.metrics_to_scalars(v) + + new_metrics[k] = v + + return new_metrics + + def process_output(self, output, train=False): + """Reduces output according to the training mode. 
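+ Returns a tuple of (loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens).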
+ + Separates loss from logging and tqdm metrics + :param output: + :return: + """ + # --------------- + # EXTRACT CALLBACK KEYS + # --------------- + # all keys not progress_bar or log are candidates for callbacks + callback_metrics = {} + for k, v in output.items(): + if k not in ['progress_bar', 'log', 'hiddens']: + callback_metrics[k] = v + + if train and self.use_dp: + num_gpus = self.num_gpus + callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus) + + for k, v in callback_metrics.items(): + if isinstance(v, torch.Tensor): + callback_metrics[k] = v.item() + + # --------------- + # EXTRACT PROGRESS BAR KEYS + # --------------- + try: + progress_output = output['progress_bar'] + + # reduce progress metrics for tqdm when using dp + if train and self.use_dp: + num_gpus = self.num_gpus + progress_output = self.reduce_distributed_output(progress_output, num_gpus) + + progress_bar_metrics = progress_output + except Exception: + progress_bar_metrics = {} + + # --------------- + # EXTRACT LOGGING KEYS + # --------------- + # extract metrics to log to experiment + try: + log_output = output['log'] + + # reduce progress metrics for tqdm when using dp + if train and self.use_dp: + num_gpus = self.num_gpus + log_output = self.reduce_distributed_output(log_output, num_gpus) + + log_metrics = log_output + except Exception: + log_metrics = {} + + # --------------- + # EXTRACT LOSS + # --------------- + # if output dict doesn't have the keyword loss + # then assume the output=loss if scalar + loss = None + if train: + try: + loss = output['loss'] + except Exception: + if type(output) is torch.Tensor: + loss = output + else: + raise RuntimeError( + 'No `loss` value in the dictionary returned from `model.training_step()`.' + ) + + # when using dp need to reduce the loss + if self.use_dp: + loss = self.reduce_distributed_output(loss, self.num_gpus) + + # --------------- + # EXTRACT HIDDEN + # --------------- + hiddens = output.get('hiddens') + + # use every metric passed in as a candidate for callback + callback_metrics.update(progress_bar_metrics) + callback_metrics.update(log_metrics) + + # convert tensors to numpy + for k, v in callback_metrics.items(): + if isinstance(v, torch.Tensor): + callback_metrics[k] = v.item() + + return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens + + def reduce_distributed_output(self, output, num_gpus): + if num_gpus <= 1: + return output + + # when using DP, we get one output per gpu + # average outputs and return + if type(output) is torch.Tensor: + return output.mean() + + for k, v in output.items(): + # recurse on nested dics + if isinstance(output[k], dict): + output[k] = self.reduce_distributed_output(output[k], num_gpus) + + # do nothing when there's a scalar + elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0: + pass + + # reduce only metrics that have the same number of gpus + elif output[k].size(0) == num_gpus: + reduced = torch.mean(output[k]) + output[k] = reduced + return output + + def clip_gradients(self): + if self.gradient_clip_val > 0: + model = self.get_model() + torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val) + + def print_nan_gradients(self): + model = self.get_model() + for param in model.parameters(): + if (param.grad is not None) and torch.isnan(param.grad.float()).any(): + logging.info(param, param.grad) + + def configure_accumulated_gradients(self, accumulate_grad_batches): + self.accumulate_grad_batches = None + + if isinstance(accumulate_grad_batches, dict): 
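+ # e.g. accumulate_grad_batches={5: 2, 10: 4}: 1 batch per step before
+ # epoch 5, then 2, then 4 from epoch 10 on (a hypothetical schedule)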
+ self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches) + elif isinstance(accumulate_grad_batches, int): + schedule = {1: accumulate_grad_batches} + self.accumulation_scheduler = GradientAccumulationScheduler(schedule) + else: + raise TypeError("Gradient accumulation supports only int and dict types") + + def get_dataloaders(self, model): + if not self.testing: + self.init_train_dataloader(model) + self.init_val_dataloader(model) + else: + self.init_test_dataloader(model) + + if self.use_ddp: + dist.barrier() + if not self.testing: + self.get_train_dataloader() + self.get_val_dataloaders() + else: + self.get_test_dataloaders() + + def init_train_dataloader(self, model): + self.fisrt_epoch = True + self.get_train_dataloader = model.train_dataloader + if isinstance(self.get_train_dataloader(), torch.utils.data.DataLoader): + self.num_training_batches = len(self.get_train_dataloader()) + self.num_training_batches = int(self.num_training_batches) + else: + self.num_training_batches = float('inf') + self.is_iterable_train_dataloader = True + if isinstance(self.val_check_interval, int): + self.val_check_batch = self.val_check_interval + else: + self._percent_range_check('val_check_interval') + self.val_check_batch = int(self.num_training_batches * self.val_check_interval) + self.val_check_batch = max(1, self.val_check_batch) + + def init_val_dataloader(self, model): + self.get_val_dataloaders = model.val_dataloader + self.num_val_batches = 0 + if self.get_val_dataloaders() is not None: + if isinstance(self.get_val_dataloaders()[0], torch.utils.data.DataLoader): + self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders()) + self.num_val_batches = int(self.num_val_batches) + else: + self.num_val_batches = float('inf') + + def init_test_dataloader(self, model): + self.get_test_dataloaders = model.test_dataloader + if self.get_test_dataloaders() is not None: + if isinstance(self.get_test_dataloaders()[0], torch.utils.data.DataLoader): + self.num_test_batches = sum(len(dataloader) for dataloader in self.get_test_dataloaders()) + self.num_test_batches = int(self.num_test_batches) + else: + self.num_test_batches = float('inf') + + def evaluate(self, model, dataloaders, max_batches, test=False): + """Run evaluation code. 
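+
+        Iterates over each dataloader, runs validation_step/test_step on every
+        batch, and collates the per-dataloader outputs for validation_end/test_end.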
+
+        :param model: PT model
+        :param dataloaders: list of PT dataloaders
+        :param max_batches: Scalar
+        :param test: boolean
+        :return:
+        """
+        # enable eval mode
+        model.zero_grad()
+        model.eval()
+
+        # copy properties for forward overrides
+        self.copy_trainer_model_properties(model)
+
+        # disable gradients to save memory
+        torch.set_grad_enabled(False)
+
+        if test:
+            self.get_model().test_start()
+        # bookkeeping
+        outputs = []
+
+        # run evaluation over each dataloader
+        for dataloader_idx, dataloader in enumerate(dataloaders):
+            dl_outputs = []
+            for batch_idx, batch in enumerate(dataloader):
+
+                if batch is None:  # pragma: no cover
+                    continue
+
+                # stop short when on fast_dev_run (sets max_batch=1)
+                if batch_idx >= max_batches:
+                    break
+
+                # -----------------
+                # RUN EVALUATION STEP
+                # -----------------
+                output = self.evaluation_forward(model,
+                                                 batch,
+                                                 batch_idx,
+                                                 dataloader_idx,
+                                                 test)
+
+                # track outputs for collation
+                dl_outputs.append(output)
+
+                # batch done
+                if test:
+                    self.test_progress_bar.update(1)
+                else:
+                    self.val_progress_bar.update(1)
+            outputs.append(dl_outputs)
+
+        # with a single dataloader don't pass an array
+        if len(dataloaders) == 1:
+            outputs = outputs[0]
+
+        # give the model a chance to do something with the outputs (if the method is defined)
+        model = self.get_model()
+        if test:
+            eval_results = model.test_end(outputs)
+        else:
+            eval_results = model.validation_end(outputs)
+
+        # enable train mode again
+        model.train()
+
+        # re-enable gradients
+        torch.set_grad_enabled(True)
+
+        return eval_results
+
+    def run_evaluation(self, test=False):
+        # when testing make sure user defined a test step
+        model = self.get_model()
+        model.on_pre_performance_check()
+
+        # select dataloaders
+        if test:
+            dataloaders = self.get_test_dataloaders()
+            max_batches = self.num_test_batches
+        else:
+            # val
+            dataloaders = self.get_val_dataloaders()
+            max_batches = self.num_val_batches
+
+        # init validation or test progress bar
+        # main progress bar will already be closed when testing so initial position is free
+        position = 2 * self.process_position + (not test)
+        desc = 'Testing' if test else 'Validating'
+        pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position,
+                         disable=not self.show_progress_bar, dynamic_ncols=True,
+                         unit='batch', file=sys.stdout)
+        setattr(self, f'{"test" if test else "val"}_progress_bar', pbar)
+
+        # run evaluation
+        eval_results = self.evaluate(self.model,
+                                     dataloaders,
+                                     max_batches,
+                                     test)
+        if eval_results is not None:
+            _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(
+                eval_results)
+
+            # add metrics to prog bar
+            self.add_tqdm_metrics(prog_bar_metrics)
+
+            # log metrics
+            self.log_metrics(log_metrics, {})
+
+            # track metrics for callbacks
+            self.callback_metrics.update(callback_metrics)
+
+        # hook
+        model.on_post_performance_check()
+
+        # add model specific metrics
+        tqdm_metrics = self.training_tqdm_dict
+        if not test:
+            self.main_progress_bar.set_postfix(**tqdm_metrics)
+
+        # close progress bar
+        if test:
+            self.test_progress_bar.close()
+        else:
+            self.val_progress_bar.close()
+
+        # model checkpointing
+        if self.proc_rank == 0 and self.checkpoint_callback is not None and not test:
+            self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch,
+                                                  logs=self.callback_metrics)
+
+    def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False):
+        # make dataloader_idx arg in validation_step optional
+        args = [batch, batch_idx]
+
+        if test and len(self.get_test_dataloaders()) > 1:
+            args.append(dataloader_idx)
+
+        elif not test and len(self.get_val_dataloaders()) > 1:
+            args.append(dataloader_idx)
+
+        # handle DP, DDP forward
+        if self.use_ddp or self.use_dp:
+            output = model(*args)
+            return output
+
+        # single GPU
+        if self.single_gpu:
+            # for single GPU put inputs on gpu manually
+            root_gpu = 0
+            if isinstance(self.data_parallel_device_ids, list):
+                root_gpu = self.data_parallel_device_ids[0]
+            batch = self.transfer_batch_to_gpu(batch, root_gpu)
+            args[0] = batch
+
+        # CPU
+        if test:
+            output = model.test_step(*args)
+        else:
+            output = model.validation_step(*args)
+
+        return output
+
+    def train(self):
+        model = self.get_model()
+        # run all epochs
+        for epoch in range(self.current_epoch, 1000000):
+            # set seed for distributed sampler (enables shuffling for each epoch)
+            if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'):
+                self.get_train_dataloader().sampler.set_epoch(epoch)
+
+            # get model
+            model = self.get_model()
+
+            # update training progress in trainer and model
+            model.current_epoch = epoch
+            self.current_epoch = epoch
+
+            total_val_batches = 0
+            if not self.disable_validation:
+                # val can be checked multiple times in an epoch
+                is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0
+                val_checks_per_epoch = self.num_training_batches // self.val_check_batch
+                val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0
+                total_val_batches = self.num_val_batches * val_checks_per_epoch
+
+            # total batches includes multiple val checks
+            self.total_batches = self.num_training_batches + total_val_batches
+            self.batch_loss_value = 0  # accumulated grads
+
+            if self.is_iterable_train_dataloader:
+                # for an iterable train loader, the progress bar never ends
+                num_iterations = None
+            else:
+                num_iterations = self.total_batches
+
+            # reset progress bar
+            # (tqdm's .reset() doesn't work on a disabled progress bar, so only the description is refreshed)
+            desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else ''
+            self.main_progress_bar.set_description(desc)
+
+            # change the gradient accumulation factor according to the accumulation_scheduler
+            self.accumulation_scheduler.on_epoch_begin(epoch, self)
+
+            # -----------------
+            # RUN TRAINING EPOCH
+            # -----------------
+            self.run_training_epoch()
+
+            # update LR schedulers
+            if self.lr_schedulers is not None:
+                for lr_scheduler in self.lr_schedulers:
+                    lr_scheduler.step(epoch=self.current_epoch)
+
+        self.main_progress_bar.close()
+
+        model.on_train_end()
+
+        if self.logger is not None:
+            self.logger.finalize("success")
+
+    def run_training_epoch(self):
+        # before epoch hook
+        if self.is_function_implemented('on_epoch_start'):
+            model = self.get_model()
+            model.on_epoch_start()
+
+        # run epoch
+        for batch_idx, batch in enumerate(self.get_train_dataloader()):
+            # stop epoch if we limited the number of training batches
+            if batch_idx >= self.num_training_batches:
+                break
+
+            self.batch_idx = batch_idx
+
+            model = self.get_model()
+            model.global_step = self.global_step
+
+            # ---------------
+            # RUN TRAIN STEP
+            # ---------------
+            output = self.run_training_batch(batch, batch_idx)
+            batch_result, grad_norm_dic, batch_step_metrics = output
+
+            # when returning -1 from train_step, we end epoch early
+            early_stop_epoch = batch_result == -1
+
+            # ---------------
+            # RUN VAL STEP
+            # ---------------
+            should_check_val = (
+                not self.disable_validation and self.global_step % self.val_check_batch == 0
+                and not self.fisrt_epoch)
+            self.fisrt_epoch = False
+
+            if should_check_val:
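+                # validation is tied to optimizer steps, not epochs: it fires whenever
+                # global_step is a multiple of val_check_batch; the fisrt_epoch flag
+                # suppresses the spurious check at step 0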
+                self.run_evaluation(test=self.testing)
+
+            # when logs should be saved
+            should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch
+            if should_save_log:
+                if self.proc_rank == 0 and self.logger is not None:
+                    self.logger.save()
+
+            # when metrics should be logged
+            should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch
+            if should_log_metrics:
+                # logs user requested information to logger
+                self.log_metrics(batch_step_metrics, grad_norm_dic)
+
+            self.global_step += 1
+            self.total_batch_idx += 1
+
+            # end epoch early
+            # stop when the flag is changed or we've gone past the amount
+            # requested in the batches
+            if early_stop_epoch:
+                break
+            if self.global_step > self.max_updates:
+                print("| Training end..")
+                exit()
+
+        # epoch end hook
+        if self.is_function_implemented('on_epoch_end'):
+            model = self.get_model()
+            model.on_epoch_end()
+
+    def run_training_batch(self, batch, batch_idx):
+        # track grad norms
+        grad_norm_dic = {}
+
+        # track all metrics for callbacks
+        all_callback_metrics = []
+
+        # track metrics to log
+        all_log_metrics = []
+
+        if batch is None:
+            return 0, grad_norm_dic, {}
+
+        # hook
+        if self.is_function_implemented('on_batch_start'):
+            model_ref = self.get_model()
+            response = model_ref.on_batch_start(batch)
+
+            if response == -1:
+                return -1, grad_norm_dic, {}
+
+        splits = [batch]
+        self.hiddens = None
+        for split_idx, split_batch in enumerate(splits):
+            self.split_idx = split_idx
+
+            # call training_step once per optimizer
+            for opt_idx, optimizer in enumerate(self.optimizers):
+                if optimizer is None:
+                    continue
+                # make sure only the gradients of the current optimizer's parameters are calculated
+                # in the training step to prevent dangling gradients in a multiple-optimizer setup.
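+                # (e.g. in a hypothetical two-optimizer GAN setup [opt_g, opt_d], the
+                # generator's params get requires_grad=False while opt_d's closure runs,
+                # so that backward pass leaves their .grad buffers untouched)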
+ if len(self.optimizers) > 1: + for param in self.get_model().parameters(): + param.requires_grad = False + for group in optimizer.param_groups: + for param in group['params']: + param.requires_grad = True + + # wrap the forward step in a closure so second order methods work + def optimizer_closure(): + # forward pass + output = self.training_forward( + split_batch, batch_idx, opt_idx, self.hiddens) + + closure_loss = output[0] + progress_bar_metrics = output[1] + log_metrics = output[2] + callback_metrics = output[3] + self.hiddens = output[4] + if closure_loss is None: + return None + + # accumulate loss + # (if accumulate_grad_batches = 1 no effect) + closure_loss = closure_loss / self.accumulate_grad_batches + + # backward pass + model_ref = self.get_model() + if closure_loss.requires_grad: + model_ref.backward(closure_loss, optimizer) + + # track metrics for callbacks + all_callback_metrics.append(callback_metrics) + + # track progress bar metrics + self.add_tqdm_metrics(progress_bar_metrics) + all_log_metrics.append(log_metrics) + + # insert after step hook + if self.is_function_implemented('on_after_backward'): + model_ref = self.get_model() + model_ref.on_after_backward() + + return closure_loss + + # calculate loss + loss = optimizer_closure() + if loss is None: + continue + + # nan grads + if self.print_nan_grads: + self.print_nan_gradients() + + # track total loss for logging (avoid mem leaks) + self.batch_loss_value += loss.item() + + # gradient update with accumulated gradients + if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: + + # track gradient norms when requested + if batch_idx % self.row_log_interval == 0: + if self.track_grad_norm > 0: + model = self.get_model() + grad_norm_dic = model.grad_norm( + self.track_grad_norm) + + # clip gradients + self.clip_gradients() + + # calls .step(), .zero_grad() + # override function to modify this behavior + model = self.get_model() + model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx) + + # calculate running loss for display + self.running_loss.append(self.batch_loss_value) + self.batch_loss_value = 0 + self.avg_loss = np.mean(self.running_loss[-100:]) + + # activate batch end hook + if self.is_function_implemented('on_batch_end'): + model = self.get_model() + model.on_batch_end() + + # update progress bar + self.main_progress_bar.update(1) + self.main_progress_bar.set_postfix(**self.training_tqdm_dict) + + # collapse all metrics into one dict + all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} + + # track all metrics for callbacks + self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()}) + + return 0, grad_norm_dic, all_log_metrics + + def training_forward(self, batch, batch_idx, opt_idx, hiddens): + """ + Handle forward for each training case (distributed, single gpu, etc...) 
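+
+        Dispatch order: under DDP/DP the wrapped self.model(*args) is called; on a
+        single GPU the batch is first moved to the root device; otherwise the plain
+        CPU training_step runs.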
+ :param batch: + :param batch_idx: + :return: + """ + # --------------- + # FORWARD + # --------------- + # enable not needing to add opt_idx to training_step + args = [batch, batch_idx, opt_idx] + + # distributed forward + if self.use_ddp or self.use_dp: + output = self.model(*args) + # single GPU forward + elif self.single_gpu: + gpu_id = 0 + if isinstance(self.data_parallel_device_ids, list): + gpu_id = self.data_parallel_device_ids[0] + batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id) + args[0] = batch + output = self.model.training_step(*args) + # CPU forward + else: + output = self.model.training_step(*args) + + # allow any mode to define training_end + model_ref = self.get_model() + output_ = model_ref.training_end(output) + if output_ is not None: + output = output_ + + # format and reduce outputs accordingly + output = self.process_output(output, train=True) + + return output + + # --------------- + # Utils + # --------------- + def is_function_implemented(self, f_name): + model = self.get_model() + f_op = getattr(model, f_name, None) + return callable(f_op) + + def _percent_range_check(self, name): + value = getattr(self, name) + msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}." + if name == "val_check_interval": + msg += " If you want to disable validation set `val_percent_check` to 0.0 instead." + + if not 0. <= value <= 1.: + raise ValueError(msg) diff --git a/utils/plot.py b/utils/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..bdca62a8cd80869c707890cd9febd39966cd3658 --- /dev/null +++ b/utils/plot.py @@ -0,0 +1,56 @@ +import matplotlib.pyplot as plt +import numpy as np +import torch + +LINE_COLORS = ['w', 'r', 'y', 'cyan', 'm', 'b', 'lime'] + + +def spec_to_figure(spec, vmin=None, vmax=None): + if isinstance(spec, torch.Tensor): + spec = spec.cpu().numpy() + fig = plt.figure(figsize=(12, 6)) + plt.pcolor(spec.T, vmin=vmin, vmax=vmax) + return fig + + +def spec_f0_to_figure(spec, f0s, figsize=None): + max_y = spec.shape[1] + if isinstance(spec, torch.Tensor): + spec = spec.detach().cpu().numpy() + f0s = {k: f0.detach().cpu().numpy() for k, f0 in f0s.items()} + f0s = {k: f0 / 10 for k, f0 in f0s.items()} + fig = plt.figure(figsize=(12, 6) if figsize is None else figsize) + plt.pcolor(spec.T) + for i, (k, f0) in enumerate(f0s.items()): + plt.plot(f0.clip(0, max_y), label=k, c=LINE_COLORS[i], linewidth=1, alpha=0.8) + plt.legend() + return fig + + +def dur_to_figure(dur_gt, dur_pred, txt): + dur_gt = dur_gt.long().cpu().numpy() + dur_pred = dur_pred.long().cpu().numpy() + dur_gt = np.cumsum(dur_gt) + dur_pred = np.cumsum(dur_pred) + fig = plt.figure(figsize=(12, 6)) + for i in range(len(dur_gt)): + shift = (i % 8) + 1 + plt.text(dur_gt[i], shift, txt[i]) + plt.text(dur_pred[i], 10 + shift, txt[i]) + plt.vlines(dur_gt[i], 0, 10, colors='b') # blue is gt + plt.vlines(dur_pred[i], 10, 20, colors='r') # red is pred + return fig + + +def f0_to_figure(f0_gt, f0_cwt=None, f0_pred=None): + fig = plt.figure() + f0_gt = f0_gt.cpu().numpy() + plt.plot(f0_gt, color='r', label='gt') + if f0_cwt is not None: + f0_cwt = f0_cwt.cpu().numpy() + plt.plot(f0_cwt, color='b', label='cwt') + if f0_pred is not None: + f0_pred = f0_pred.cpu().numpy() + plt.plot(f0_pred, color='green', label='pred') + plt.legend() + return fig diff --git a/utils/rnnoise.py b/utils/rnnoise.py new file mode 100755 index 0000000000000000000000000000000000000000..47f4eb6471918ca8144f217580a71d1720cd8c36 --- /dev/null +++ b/utils/rnnoise.py @@ -0,0 +1,48 @@ +# 
rnnoise.py, requirements: ffmpeg, sox, rnnoise, python
+import os
+import subprocess
+
+INSTALL_STR = """
+RNNoise library not found. Please install RNNoise (https://github.com/xiph/rnnoise) to $REPO/rnnoise:
+sudo apt-get install -y autoconf automake libtool ffmpeg sox
+git clone https://github.com/xiph/rnnoise.git
+rm -rf rnnoise/.git
+cd rnnoise
+./autogen.sh && ./configure && make
+cd ..
+"""
+
+
+def rnnoise(filename, out_fn=None, verbose=False, out_sample_rate=22050):
+    assert os.path.exists('./rnnoise/examples/rnnoise_demo'), INSTALL_STR
+    if out_fn is None:
+        out_fn = f"{filename[:-4]}.denoised.wav"
+    out_48k_fn = f"{out_fn}.48000.wav"
+    tmp0_fn = f"{out_fn}.0.wav"
+    tmp1_fn = f"{out_fn}.1.wav"
+    tmp2_fn = f"{out_fn}.2.raw"
+    tmp3_fn = f"{out_fn}.3.raw"
+    if verbose:
+        print("Pre-processing audio...")
+    subprocess.check_call(
+        f'sox "{filename}" -G -r48000 "{tmp0_fn}"', shell=True, stdin=subprocess.PIPE)  # resample to 48 kHz
+    subprocess.check_call(
+        f'sox -v 0.95 "{tmp0_fn}" "{tmp1_fn}"', shell=True, stdin=subprocess.PIPE)  # scale volume to avoid clipping
+    subprocess.check_call(
+        f'ffmpeg -y -i "{tmp1_fn}" -loglevel quiet -f s16le -ac 1 -ar 48000 "{tmp2_fn}"',
+        shell=True, stdin=subprocess.PIPE)  # wav to raw 16-bit mono PCM
+    if verbose:
+        print("Applying rnnoise algorithm to audio...")
+    subprocess.check_call(
+        f'./rnnoise/examples/rnnoise_demo "{tmp2_fn}" "{tmp3_fn}"', shell=True)  # denoise the raw PCM
+
+    if verbose:
+        print("Post-processing audio...")
+    if filename == out_fn:
+        subprocess.check_call(f'rm -f "{out_fn}"', shell=True)
+    subprocess.check_call(
+        f'sox -t raw -r 48000 -b 16 -e signed-integer -c 1 "{tmp3_fn}" "{out_48k_fn}"', shell=True)  # raw PCM back to wav
+    subprocess.check_call(f'sox "{out_48k_fn}" -G -r{out_sample_rate} "{out_fn}"', shell=True)  # resample to target rate
+    subprocess.check_call(f'rm -f "{tmp0_fn}" "{tmp1_fn}" "{tmp2_fn}" "{tmp3_fn}" "{out_48k_fn}"', shell=True)
+    if verbose:
+        print("Audio-filtering completed!")
diff --git a/utils/text_encoder.py b/utils/text_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9e0758abc7b4e1f452481cba9715df08ceab543
--- /dev/null
+++ b/utils/text_encoder.py
@@ -0,0 +1,304 @@
+import re
+import six
+from six.moves import range  # pylint: disable=redefined-builtin
+
+PAD = "<pad>"
+EOS = "<EOS>"
+UNK = "<UNK>"
+SEG = "|"
+RESERVED_TOKENS = [PAD, EOS, UNK]
+NUM_RESERVED_TOKENS = len(RESERVED_TOKENS)
+PAD_ID = RESERVED_TOKENS.index(PAD)  # Normally 0
+EOS_ID = RESERVED_TOKENS.index(EOS)  # Normally 1
+UNK_ID = RESERVED_TOKENS.index(UNK)  # Normally 2
+
+if six.PY2:
+    RESERVED_TOKENS_BYTES = RESERVED_TOKENS
+else:
+    RESERVED_TOKENS_BYTES = [bytes(PAD, "ascii"), bytes(EOS, "ascii")]
+
+# Regular expression for unescaping token strings.
+# '\u' is converted to '_'
+# '\\' is converted to '\'
+# '\213;' is converted to unichr(213)
+_UNESCAPE_REGEX = re.compile(r"\\u|\\\\|\\([0-9]+);")
+_ESCAPE_CHARS = set(u"\\_u;0123456789")
+
+
+def strip_ids(ids, ids_to_strip):
+    """Strip ids_to_strip from the end of ids."""
+    ids = list(ids)
+    while ids and ids[-1] in ids_to_strip:
+        ids.pop()
+    return ids
+
+
+class TextEncoder(object):
+    """Base class for converting between sequences of ints and human-readable strings."""
+
+    def __init__(self, num_reserved_ids=NUM_RESERVED_TOKENS):
+        self._num_reserved_ids = num_reserved_ids
+
+    @property
+    def num_reserved_ids(self):
+        return self._num_reserved_ids
+
+    def encode(self, s):
+        """Transform a human-readable string into a sequence of int ids.
+
+        The ids should be in the range [num_reserved_ids, vocab_size).
Ids [0, + num_reserved_ids) are reserved. + + EOS is not appended. + + Args: + s: human-readable string to be converted. + + Returns: + ids: list of integers + """ + return [int(w) + self._num_reserved_ids for w in s.split()] + + def decode(self, ids, strip_extraneous=False): + """Transform a sequence of int ids into a human-readable string. + + EOS is not expected in ids. + + Args: + ids: list of integers to be converted. + strip_extraneous: bool, whether to strip off extraneous tokens + (EOS and PAD). + + Returns: + s: human-readable string. + """ + if strip_extraneous: + ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) + return " ".join(self.decode_list(ids)) + + def decode_list(self, ids): + """Transform a sequence of int ids into a their string versions. + + This method supports transforming individual input/output ids to their + string versions so that sequence to/from text conversions can be visualized + in a human readable format. + + Args: + ids: list of integers to be converted. + + Returns: + strs: list of human-readable string. + """ + decoded_ids = [] + for id_ in ids: + if 0 <= id_ < self._num_reserved_ids: + decoded_ids.append(RESERVED_TOKENS[int(id_)]) + else: + decoded_ids.append(id_ - self._num_reserved_ids) + return [str(d) for d in decoded_ids] + + @property + def vocab_size(self): + raise NotImplementedError() + + +class ByteTextEncoder(TextEncoder): + """Encodes each byte to an id. For 8-bit strings only.""" + + def encode(self, s): + numres = self._num_reserved_ids + if six.PY2: + if isinstance(s, unicode): + s = s.encode("utf-8") + return [ord(c) + numres for c in s] + # Python3: explicitly convert to UTF-8 + return [c + numres for c in s.encode("utf-8")] + + def decode(self, ids, strip_extraneous=False): + if strip_extraneous: + ids = strip_ids(ids, list(range(self._num_reserved_ids or 0))) + numres = self._num_reserved_ids + decoded_ids = [] + int2byte = six.int2byte + for id_ in ids: + if 0 <= id_ < numres: + decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) + else: + decoded_ids.append(int2byte(id_ - numres)) + if six.PY2: + return "".join(decoded_ids) + # Python3: join byte arrays and then decode string + return b"".join(decoded_ids).decode("utf-8", "replace") + + def decode_list(self, ids): + numres = self._num_reserved_ids + decoded_ids = [] + int2byte = six.int2byte + for id_ in ids: + if 0 <= id_ < numres: + decoded_ids.append(RESERVED_TOKENS_BYTES[int(id_)]) + else: + decoded_ids.append(int2byte(id_ - numres)) + # Python3: join byte arrays and then decode string + return decoded_ids + + @property + def vocab_size(self): + return 2**8 + self._num_reserved_ids + + +class ByteTextEncoderWithEos(ByteTextEncoder): + """Encodes each byte to an id and appends the EOS token.""" + + def encode(self, s): + return super(ByteTextEncoderWithEos, self).encode(s) + [EOS_ID] + + +class TokenTextEncoder(TextEncoder): + """Encoder based on a user-supplied vocabulary (file or list).""" + + def __init__(self, + vocab_filename, + reverse=False, + vocab_list=None, + replace_oov=None, + num_reserved_ids=NUM_RESERVED_TOKENS): + """Initialize from a file or list, one token per line. + + Handling of reserved tokens works as follows: + - When initializing from a list, we add reserved tokens to the vocab. + - When initializing from a file, we do not add reserved tokens to the vocab. + - When saving vocab files, we save reserved tokens to the file. + + Args: + vocab_filename: If not None, the full filename to read vocab from. 
If this + is not None, then vocab_list should be None. + reverse: Boolean indicating if tokens should be reversed during encoding + and decoding. + vocab_list: If not None, a list of elements of the vocabulary. If this is + not None, then vocab_filename should be None. + replace_oov: If not None, every out-of-vocabulary token seen when + encoding will be replaced by this string (which must be in vocab). + num_reserved_ids: Number of IDs to save for reserved tokens like . + """ + super(TokenTextEncoder, self).__init__(num_reserved_ids=num_reserved_ids) + self._reverse = reverse + self._replace_oov = replace_oov + if vocab_filename: + self._init_vocab_from_file(vocab_filename) + else: + assert vocab_list is not None + self._init_vocab_from_list(vocab_list) + self.pad_index = self._token_to_id[PAD] + self.eos_index = self._token_to_id[EOS] + self.unk_index = self._token_to_id[UNK] + self.seg_index = self._token_to_id[SEG] if SEG in self._token_to_id else self.eos_index + + def encode(self, s): + """Converts a space-separated string of tokens to a list of ids.""" + sentence = s + tokens = sentence.strip().split() + if self._replace_oov is not None: + tokens = [t if t in self._token_to_id else self._replace_oov + for t in tokens] + ret = [self._token_to_id[tok] for tok in tokens] + return ret[::-1] if self._reverse else ret + + def decode(self, ids, strip_eos=False, strip_padding=False): + if strip_padding and self.pad() in list(ids): + pad_pos = list(ids).index(self.pad()) + ids = ids[:pad_pos] + if strip_eos and self.eos() in list(ids): + eos_pos = list(ids).index(self.eos()) + ids = ids[:eos_pos] + return " ".join(self.decode_list(ids)) + + def decode_list(self, ids): + seq = reversed(ids) if self._reverse else ids + return [self._safe_id_to_token(i) for i in seq] + + @property + def vocab_size(self): + return len(self._id_to_token) + + def __len__(self): + return self.vocab_size + + def _safe_id_to_token(self, idx): + return self._id_to_token.get(idx, "ID_%d" % idx) + + def _init_vocab_from_file(self, filename): + """Load vocab from a file. + + Args: + filename: The file to load vocabulary from. + """ + with open(filename) as f: + tokens = [token.strip() for token in f.readlines()] + + def token_gen(): + for token in tokens: + yield token + + self._init_vocab(token_gen(), add_reserved_tokens=False) + + def _init_vocab_from_list(self, vocab_list): + """Initialize tokens from a list of tokens. + + It is ok if reserved tokens appear in the vocab list. They will be + removed. The set of tokens in vocab_list should be unique. + + Args: + vocab_list: A list of tokens. + """ + def token_gen(): + for token in vocab_list: + if token not in RESERVED_TOKENS: + yield token + + self._init_vocab(token_gen()) + + def _init_vocab(self, token_generator, add_reserved_tokens=True): + """Initialize vocabulary with tokens from token_generator.""" + + self._id_to_token = {} + non_reserved_start_index = 0 + + if add_reserved_tokens: + self._id_to_token.update(enumerate(RESERVED_TOKENS)) + non_reserved_start_index = len(RESERVED_TOKENS) + + self._id_to_token.update( + enumerate(token_generator, start=non_reserved_start_index)) + + # _token_to_id is the reverse of _id_to_token + self._token_to_id = dict((v, k) + for k, v in six.iteritems(self._id_to_token)) + + def pad(self): + return self.pad_index + + def eos(self): + return self.eos_index + + def unk(self): + return self.unk_index + + def seg(self): + return self.seg_index + + def store_to_file(self, filename): + """Write vocab file to disk. 
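+
+        The tokens are written in id order, so the file can be passed back in as
+        `vocab_filename` to rebuild the identical mapping.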
+ + Vocab files have one token per line. The file ends in a newline. Reserved + tokens are written to the vocab file as well. + + Args: + filename: Full path of the file to store the vocab to. + """ + with open(filename, "w") as f: + for i in range(len(self._id_to_token)): + f.write(self._id_to_token[i] + "\n") + + def sil_phonemes(self): + return [p for p in self._id_to_token.values() if not p[0].isalpha()] diff --git a/utils/text_norm.py b/utils/text_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..d0973cebc91e0525aeb6657e70012a1d37b5e6ff --- /dev/null +++ b/utils/text_norm.py @@ -0,0 +1,790 @@ +# coding=utf-8 +# Authors: +# 2019.5 Zhiyang Zhou (https://github.com/Joee1995/chn_text_norm.git) +# 2019.9 Jiayu DU +# +# requirements: +# - python 3.X +# notes: python 2.X WILL fail or produce misleading results + +import sys, os, argparse, codecs, string, re + +# ================================================================================ # +# basic constant +# ================================================================================ # +CHINESE_DIGIS = u'零一二三四五六七八九' +BIG_CHINESE_DIGIS_SIMPLIFIED = u'零壹贰叁肆伍陆柒捌玖' +BIG_CHINESE_DIGIS_TRADITIONAL = u'零壹貳參肆伍陸柒捌玖' +SMALLER_BIG_CHINESE_UNITS_SIMPLIFIED = u'十百千万' +SMALLER_BIG_CHINESE_UNITS_TRADITIONAL = u'拾佰仟萬' +LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'亿兆京垓秭穰沟涧正载' +LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'億兆京垓秭穰溝澗正載' +SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED = u'十百千万' +SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL = u'拾佰仟萬' + +ZERO_ALT = u'〇' +ONE_ALT = u'幺' +TWO_ALTS = [u'两', u'兩'] + +POSITIVE = [u'正', u'正'] +NEGATIVE = [u'负', u'負'] +POINT = [u'点', u'點'] +# PLUS = [u'加', u'加'] +# SIL = [u'杠', u'槓'] + +# 中文数字系统类型 +NUMBERING_TYPES = ['low', 'mid', 'high'] + +CURRENCY_NAMES = '(人民币|美元|日元|英镑|欧元|马克|法郎|加拿大元|澳元|港币|先令|芬兰马克|爱尔兰镑|' \ + '里拉|荷兰盾|埃斯库多|比塞塔|印尼盾|林吉特|新西兰元|比索|卢布|新加坡元|韩元|泰铢)' +CURRENCY_UNITS = '((亿|千万|百万|万|千|百)|(亿|千万|百万|万|千|百|)元|(亿|千万|百万|万|千|百|)块|角|毛|分)' +COM_QUANTIFIERS = '(匹|张|座|回|场|尾|条|个|首|阙|阵|网|炮|顶|丘|棵|只|支|袭|辆|挑|担|颗|壳|窠|曲|墙|群|腔|' \ + '砣|座|客|贯|扎|捆|刀|令|打|手|罗|坡|山|岭|江|溪|钟|队|单|双|对|出|口|头|脚|板|跳|枝|件|贴|' \ + '针|线|管|名|位|身|堂|课|本|页|家|户|层|丝|毫|厘|分|钱|两|斤|担|铢|石|钧|锱|忽|(千|毫|微)克|' \ + '毫|厘|分|寸|尺|丈|里|寻|常|铺|程|(千|分|厘|毫|微)米|撮|勺|合|升|斗|石|盘|碗|碟|叠|桶|笼|盆|' \ + '盒|杯|钟|斛|锅|簋|篮|盘|桶|罐|瓶|壶|卮|盏|箩|箱|煲|啖|袋|钵|年|月|日|季|刻|时|周|天|秒|分|旬|' \ + '纪|岁|世|更|夜|春|夏|秋|冬|代|伏|辈|丸|泡|粒|颗|幢|堆|条|根|支|道|面|片|张|颗|块)' + +# punctuation information are based on Zhon project (https://github.com/tsroten/zhon.git) +CHINESE_PUNC_STOP = '!?。。' +CHINESE_PUNC_NON_STOP = '"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏' +CHINESE_PUNC_LIST = CHINESE_PUNC_STOP + CHINESE_PUNC_NON_STOP + + +# ================================================================================ # +# basic class +# ================================================================================ # +class ChineseChar(object): + """ + 中文字符 + 每个字符对应简体和繁体, + e.g. 简体 = '负', 繁体 = '負' + 转换时可转换为简体或繁体 + """ + + def __init__(self, simplified, traditional): + self.simplified = simplified + self.traditional = traditional + # self.__repr__ = self.__str__ + + def __str__(self): + return self.simplified or self.traditional or None + + def __repr__(self): + return self.__str__() + + +class ChineseNumberUnit(ChineseChar): + """ + 中文数字/数位字符 + 每个字符除繁简体外还有一个额外的大写字符 + e.g. 
'陆' 和 '陸' + """ + + def __init__(self, power, simplified, traditional, big_s, big_t): + super(ChineseNumberUnit, self).__init__(simplified, traditional) + self.power = power + self.big_s = big_s + self.big_t = big_t + + def __str__(self): + return '10^{}'.format(self.power) + + @classmethod + def create(cls, index, value, numbering_type=NUMBERING_TYPES[1], small_unit=False): + + if small_unit: + return ChineseNumberUnit(power=index + 1, + simplified=value[0], traditional=value[1], big_s=value[1], big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[0]: + return ChineseNumberUnit(power=index + 8, + simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[1]: + return ChineseNumberUnit(power=(index + 2) * 4, + simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]) + elif numbering_type == NUMBERING_TYPES[2]: + return ChineseNumberUnit(power=pow(2, index + 3), + simplified=value[0], traditional=value[1], big_s=value[0], big_t=value[1]) + else: + raise ValueError( + 'Counting type should be in {0} ({1} provided).'.format(NUMBERING_TYPES, numbering_type)) + + +class ChineseNumberDigit(ChineseChar): + """ + 中文数字字符 + """ + + def __init__(self, value, simplified, traditional, big_s, big_t, alt_s=None, alt_t=None): + super(ChineseNumberDigit, self).__init__(simplified, traditional) + self.value = value + self.big_s = big_s + self.big_t = big_t + self.alt_s = alt_s + self.alt_t = alt_t + + def __str__(self): + return str(self.value) + + @classmethod + def create(cls, i, v): + return ChineseNumberDigit(i, v[0], v[1], v[2], v[3]) + + +class ChineseMath(ChineseChar): + """ + 中文数位字符 + """ + + def __init__(self, simplified, traditional, symbol, expression=None): + super(ChineseMath, self).__init__(simplified, traditional) + self.symbol = symbol + self.expression = expression + self.big_s = simplified + self.big_t = traditional + + +CC, CNU, CND, CM = ChineseChar, ChineseNumberUnit, ChineseNumberDigit, ChineseMath + + +class NumberSystem(object): + """ + 中文数字系统 + """ + pass + + +class MathSymbol(object): + """ + 用于中文数字系统的数学符号 (繁/简体), e.g. + positive = ['正', '正'] + negative = ['负', '負'] + point = ['点', '點'] + """ + + def __init__(self, positive, negative, point): + self.positive = positive + self.negative = negative + self.point = point + + def __iter__(self): + for v in self.__dict__.values(): + yield v + + +# class OtherSymbol(object): +# """ +# 其他符号 +# """ +# +# def __init__(self, sil): +# self.sil = sil +# +# def __iter__(self): +# for v in self.__dict__.values(): +# yield v + + +# ================================================================================ # +# basic utils +# ================================================================================ # +def create_system(numbering_type=NUMBERING_TYPES[1]): + """ + 根据数字系统类型返回创建相应的数字系统,默认为 mid + NUMBERING_TYPES = ['low', 'mid', 'high']: 中文数字系统类型 + low: '兆' = '亿' * '十' = $10^{9}$, '京' = '兆' * '十', etc. + mid: '兆' = '亿' * '万' = $10^{12}$, '京' = '兆' * '万', etc. + high: '兆' = '亿' * '亿' = $10^{16}$, '京' = '兆' * '兆', etc. 
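+    (i.e. the three systems differ in how units above '亿' grow: 'low' multiplies
+    each successive unit by 10, 'mid' by 10^4, and 'high' squares the previous unit)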
+ 返回对应的数字系统 + """ + + # chinese number units of '亿' and larger + all_larger_units = zip( + LARGER_CHINESE_NUMERING_UNITS_SIMPLIFIED, LARGER_CHINESE_NUMERING_UNITS_TRADITIONAL) + larger_units = [CNU.create(i, v, numbering_type, False) + for i, v in enumerate(all_larger_units)] + # chinese number units of '十, 百, 千, 万' + all_smaller_units = zip( + SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED, SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL) + smaller_units = [CNU.create(i, v, small_unit=True) + for i, v in enumerate(all_smaller_units)] + # digis + chinese_digis = zip(CHINESE_DIGIS, CHINESE_DIGIS, + BIG_CHINESE_DIGIS_SIMPLIFIED, BIG_CHINESE_DIGIS_TRADITIONAL) + digits = [CND.create(i, v) for i, v in enumerate(chinese_digis)] + digits[0].alt_s, digits[0].alt_t = ZERO_ALT, ZERO_ALT + digits[1].alt_s, digits[1].alt_t = ONE_ALT, ONE_ALT + digits[2].alt_s, digits[2].alt_t = TWO_ALTS[0], TWO_ALTS[1] + + # symbols + positive_cn = CM(POSITIVE[0], POSITIVE[1], '+', lambda x: x) + negative_cn = CM(NEGATIVE[0], NEGATIVE[1], '-', lambda x: -x) + point_cn = CM(POINT[0], POINT[1], '.', lambda x, + y: float(str(x) + '.' + str(y))) + # sil_cn = CM(SIL[0], SIL[1], '-', lambda x, y: float(str(x) + '-' + str(y))) + system = NumberSystem() + system.units = smaller_units + larger_units + system.digits = digits + system.math = MathSymbol(positive_cn, negative_cn, point_cn) + # system.symbols = OtherSymbol(sil_cn) + return system + + +def chn2num(chinese_string, numbering_type=NUMBERING_TYPES[1]): + def get_symbol(char, system): + for u in system.units: + if char in [u.traditional, u.simplified, u.big_s, u.big_t]: + return u + for d in system.digits: + if char in [d.traditional, d.simplified, d.big_s, d.big_t, d.alt_s, d.alt_t]: + return d + for m in system.math: + if char in [m.traditional, m.simplified]: + return m + + def string2symbols(chinese_string, system): + int_string, dec_string = chinese_string, '' + for p in [system.math.point.simplified, system.math.point.traditional]: + if p in chinese_string: + int_string, dec_string = chinese_string.split(p) + break + return [get_symbol(c, system) for c in int_string], \ + [get_symbol(c, system) for c in dec_string] + + def correct_symbols(integer_symbols, system): + """ + 一百八 to 一百八十 + 一亿一千三百万 to 一亿 一千万 三百万 + """ + + if integer_symbols and isinstance(integer_symbols[0], CNU): + if integer_symbols[0].power == 1: + integer_symbols = [system.digits[1]] + integer_symbols + + if len(integer_symbols) > 1: + if isinstance(integer_symbols[-1], CND) and isinstance(integer_symbols[-2], CNU): + integer_symbols.append( + CNU(integer_symbols[-2].power - 1, None, None, None, None)) + + result = [] + unit_count = 0 + for s in integer_symbols: + if isinstance(s, CND): + result.append(s) + unit_count = 0 + elif isinstance(s, CNU): + current_unit = CNU(s.power, None, None, None, None) + unit_count += 1 + + if unit_count == 1: + result.append(current_unit) + elif unit_count > 1: + for i in range(len(result)): + if isinstance(result[-i - 1], CNU) and result[-i - 1].power < current_unit.power: + result[-i - 1] = CNU(result[-i - 1].power + + current_unit.power, None, None, None, None) + return result + + def compute_value(integer_symbols): + """ + Compute the value. + When current unit is larger than previous unit, current unit * all previous units will be used as all previous units. + e.g. 
'两千万' = 2000 * 10000 not 2000 + 10000 + """ + value = [0] + last_power = 0 + for s in integer_symbols: + if isinstance(s, CND): + value[-1] = s.value + elif isinstance(s, CNU): + value[-1] *= pow(10, s.power) + if s.power > last_power: + value[:-1] = list(map(lambda v: v * + pow(10, s.power), value[:-1])) + last_power = s.power + value.append(0) + return sum(value) + + system = create_system(numbering_type) + int_part, dec_part = string2symbols(chinese_string, system) + int_part = correct_symbols(int_part, system) + int_str = str(compute_value(int_part)) + dec_str = ''.join([str(d.value) for d in dec_part]) + if dec_part: + return '{0}.{1}'.format(int_str, dec_str) + else: + return int_str + + +def num2chn(number_string, numbering_type=NUMBERING_TYPES[1], big=False, + traditional=False, alt_zero=False, alt_one=False, alt_two=True, + use_zeros=True, use_units=True): + def get_value(value_string, use_zeros=True): + + striped_string = value_string.lstrip('0') + + # record nothing if all zeros + if not striped_string: + return [] + + # record one digits + elif len(striped_string) == 1: + if use_zeros and len(value_string) != len(striped_string): + return [system.digits[0], system.digits[int(striped_string)]] + else: + return [system.digits[int(striped_string)]] + + # recursively record multiple digits + else: + result_unit = next(u for u in reversed( + system.units) if u.power < len(striped_string)) + result_string = value_string[:-result_unit.power] + return get_value(result_string) + [result_unit] + get_value(striped_string[-result_unit.power:]) + + system = create_system(numbering_type) + + int_dec = number_string.split('.') + if len(int_dec) == 1: + int_string = int_dec[0] + dec_string = "" + elif len(int_dec) == 2: + int_string = int_dec[0] + dec_string = int_dec[1] + else: + raise ValueError( + "invalid input num string with more than one dot: {}".format(number_string)) + + if use_units and len(int_string) > 1: + result_symbols = get_value(int_string) + else: + result_symbols = [system.digits[int(c)] for c in int_string] + dec_symbols = [system.digits[int(c)] for c in dec_string] + if dec_string: + result_symbols += [system.math.point] + dec_symbols + + if alt_two: + liang = CND(2, system.digits[2].alt_s, system.digits[2].alt_t, + system.digits[2].big_s, system.digits[2].big_t) + for i, v in enumerate(result_symbols): + if isinstance(v, CND) and v.value == 2: + next_symbol = result_symbols[i + + 1] if i < len(result_symbols) - 1 else None + previous_symbol = result_symbols[i - 1] if i > 0 else None + if isinstance(next_symbol, CNU) and isinstance(previous_symbol, (CNU, type(None))): + if next_symbol.power != 1 and ((previous_symbol is None) or (previous_symbol.power != 1)): + result_symbols[i] = liang + + # if big is True, '两' will not be used and `alt_two` has no impact on output + if big: + attr_name = 'big_' + if traditional: + attr_name += 't' + else: + attr_name += 's' + else: + if traditional: + attr_name = 'traditional' + else: + attr_name = 'simplified' + + result = ''.join([getattr(s, attr_name) for s in result_symbols]) + + # if not use_zeros: + # result = result.strip(getattr(system.digits[0], attr_name)) + + if alt_zero: + result = result.replace( + getattr(system.digits[0], attr_name), system.digits[0].alt_s) + + if alt_one: + result = result.replace( + getattr(system.digits[1], attr_name), system.digits[1].alt_s) + + for i, p in enumerate(POINT): + if result.startswith(p): + return CHINESE_DIGIS[0] + result + + # ^10, 11, .., 19 + if len(result) >= 2 and result[1] in 
+            [SMALLER_CHINESE_NUMERING_UNITS_SIMPLIFIED[0],
+             SMALLER_CHINESE_NUMERING_UNITS_TRADITIONAL[0]] and \
+            result[0] in [CHINESE_DIGIS[1], BIG_CHINESE_DIGIS_SIMPLIFIED[1], BIG_CHINESE_DIGIS_TRADITIONAL[1]]:
+        result = result[1:]
+
+    return result
+
+
+# ================================================================================ #
+#                        different types of rewriters
+# ================================================================================ #
+class Cardinal:
+    """
+    CARDINAL class
+    """
+
+    def __init__(self, cardinal=None, chntext=None):
+        self.cardinal = cardinal
+        self.chntext = chntext
+
+    def chntext2cardinal(self):
+        return chn2num(self.chntext)
+
+    def cardinal2chntext(self):
+        return num2chn(self.cardinal)
+
+
+class Digit:
+    """
+    DIGIT class
+    """
+
+    def __init__(self, digit=None, chntext=None):
+        self.digit = digit
+        self.chntext = chntext
+
+    # def chntext2digit(self):
+    #     return chn2num(self.chntext)
+
+    def digit2chntext(self):
+        return num2chn(self.digit, alt_two=False, use_units=False)
+
+
+class TelePhone:
+    """
+    TELEPHONE class
+    """
+
+    def __init__(self, telephone=None, raw_chntext=None, chntext=None):
+        self.telephone = telephone
+        self.raw_chntext = raw_chntext
+        self.chntext = chntext
+
+    # def chntext2telephone(self):
+    #     sil_parts = self.raw_chntext.split('<SIL>')
+    #     self.telephone = '-'.join([
+    #         str(chn2num(p)) for p in sil_parts
+    #     ])
+    #     return self.telephone
+
+    def telephone2chntext(self, fixed=False):
+
+        if fixed:
+            sil_parts = self.telephone.split('-')
+            self.raw_chntext = '<SIL>'.join([
+                num2chn(part, alt_two=False, use_units=False) for part in sil_parts
+            ])
+            self.chntext = self.raw_chntext.replace('<SIL>', '')
+        else:
+            sp_parts = self.telephone.strip('+').split()
+            self.raw_chntext = '<SP>'.join([
+                num2chn(part, alt_two=False, use_units=False) for part in sp_parts
+            ])
+            self.chntext = self.raw_chntext.replace('<SP>', '')
+        return self.chntext
+
+
+class Fraction:
+    """
+    FRACTION class
+    """
+
+    def __init__(self, fraction=None, chntext=None):
+        self.fraction = fraction
+        self.chntext = chntext
+
+    def chntext2fraction(self):
+        denominator, numerator = self.chntext.split('分之')
+        return chn2num(numerator) + '/' + chn2num(denominator)
+
+    def fraction2chntext(self):
+        numerator, denominator = self.fraction.split('/')
+        return num2chn(denominator) + '分之' + num2chn(numerator)
+
+
+class Date:
+    """
+    DATE class
+    """
+
+    def __init__(self, date=None, chntext=None):
+        self.date = date
+        self.chntext = chntext
+
+    # def chntext2date(self):
+    #     chntext = self.chntext
+    #     try:
+    #         year, other = chntext.strip().split('年', maxsplit=1)
+    #         year = Digit(chntext=year).digit2chntext() + '年'
+    #     except ValueError:
+    #         other = chntext
+    #         year = ''
+    #     if other:
+    #         try:
+    #             month, day = other.strip().split('月', maxsplit=1)
+    #             month = Cardinal(chntext=month).chntext2cardinal() + '月'
+    #         except ValueError:
+    #             day = chntext
+    #             month = ''
+    #         if day:
+    #             day = Cardinal(chntext=day[:-1]).chntext2cardinal() + day[-1]
+    #     else:
+    #         month = ''
+    #         day = ''
+    #     date = year + month + day
+    #     self.date = date
+    #     return self.date
+
+    def date2chntext(self):
+        date = self.date
+        try:
+            year, other = date.strip().split('年', 1)
+            year = Digit(digit=year).digit2chntext() + '年'
+        except ValueError:
+            other = date
+            year = ''
+        if other:
+            try:
+                month, day = other.strip().split('月', 1)
+                month = Cardinal(cardinal=month).cardinal2chntext() + '月'
+            except ValueError:
+                day = date
+                month = ''
+            if day:
+                day = Cardinal(cardinal=day[:-1]).cardinal2chntext() + day[-1]
+        else:
+            month = ''
+            day = ''
+        chntext = year
+ month + day + self.chntext = chntext + return self.chntext + + +class Money: + """ + MONEY类 + """ + + def __init__(self, money=None, chntext=None): + self.money = money + self.chntext = chntext + + # def chntext2money(self): + # return self.money + + def money2chntext(self): + money = self.money + pattern = re.compile(r'(\d+(\.\d+)?)') + matchers = pattern.findall(money) + if matchers: + for matcher in matchers: + money = money.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext()) + self.chntext = money + return self.chntext + + +class Percentage: + """ + PERCENTAGE类 + """ + + def __init__(self, percentage=None, chntext=None): + self.percentage = percentage + self.chntext = chntext + + def chntext2percentage(self): + return chn2num(self.chntext.strip().strip('百分之')) + '%' + + def percentage2chntext(self): + return '百分之' + num2chn(self.percentage.strip().strip('%')) + + +# ================================================================================ # +# NSW Normalizer +# ================================================================================ # +class NSWNormalizer: + def __init__(self, raw_text): + self.raw_text = '^' + raw_text + '$' + self.norm_text = '' + + def _particular(self): + text = self.norm_text + pattern = re.compile(r"(([a-zA-Z]+)二([a-zA-Z]+))") + matchers = pattern.findall(text) + if matchers: + # print('particular') + for matcher in matchers: + text = text.replace(matcher[0], matcher[1] + '2' + matcher[2], 1) + self.norm_text = text + return self.norm_text + + def normalize(self, remove_punc=True): + text = self.raw_text + + # 规范化日期 + pattern = re.compile(r"\D+((([089]\d|(19|20)\d{2})年)?(\d{1,2}月(\d{1,2}[日号])?)?)") + matchers = pattern.findall(text) + if matchers: + # print('date') + for matcher in matchers: + text = text.replace(matcher[0], Date(date=matcher[0]).date2chntext(), 1) + + # 规范化金钱 + pattern = re.compile(r"\D+((\d+(\.\d+)?)[多余几]?" + CURRENCY_UNITS + r"(\d" + CURRENCY_UNITS + r"?)?)") + matchers = pattern.findall(text) + if matchers: + # print('money') + for matcher in matchers: + text = text.replace(matcher[0], Money(money=matcher[0]).money2chntext(), 1) + + # 规范化固话/手机号码 + # 手机 + # http://www.jihaoba.com/news/show/13680 + # 移动:139、138、137、136、135、134、159、158、157、150、151、152、188、187、182、183、184、178、198 + # 联通:130、131、132、156、155、186、185、176 + # 电信:133、153、189、180、181、177 + pattern = re.compile(r"\D((\+?86 ?)?1([38]\d|5[0-35-9]|7[678]|9[89])\d{8})\D") + matchers = pattern.findall(text) + if matchers: + # print('telephone') + for matcher in matchers: + text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(), 1) + # 固话 + pattern = re.compile(r"\D((0(10|2[1-3]|[3-9]\d{2})-?)?[1-9]\d{6,7})\D") + matchers = pattern.findall(text) + if matchers: + # print('fixed telephone') + for matcher in matchers: + text = text.replace(matcher[0], TelePhone(telephone=matcher[0]).telephone2chntext(fixed=True), 1) + + # 规范化分数 + pattern = re.compile(r"(\d+/\d+)") + matchers = pattern.findall(text) + if matchers: + # print('fraction') + for matcher in matchers: + text = text.replace(matcher, Fraction(fraction=matcher).fraction2chntext(), 1) + + # 规范化百分数 + text = text.replace('%', '%') + pattern = re.compile(r"(\d+(\.\d+)?%)") + matchers = pattern.findall(text) + if matchers: + # print('percentage') + for matcher in matchers: + text = text.replace(matcher[0], Percentage(percentage=matcher[0]).percentage2chntext(), 1) + + # 规范化纯数+量词 + pattern = re.compile(r"(\d+(\.\d+)?)[多余几]?" 
+ COM_QUANTIFIERS) + matchers = pattern.findall(text) + if matchers: + # print('cardinal+quantifier') + for matcher in matchers: + text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) + + # 规范化数字编号 + pattern = re.compile(r"(\d{4,32})") + matchers = pattern.findall(text) + if matchers: + # print('digit') + for matcher in matchers: + text = text.replace(matcher, Digit(digit=matcher).digit2chntext(), 1) + + # 规范化纯数 + pattern = re.compile(r"(\d+(\.\d+)?)") + matchers = pattern.findall(text) + if matchers: + # print('cardinal') + for matcher in matchers: + text = text.replace(matcher[0], Cardinal(cardinal=matcher[0]).cardinal2chntext(), 1) + + self.norm_text = text + self._particular() + + text = self.norm_text.lstrip('^').rstrip('$') + if remove_punc: + # Punctuations removal + old_chars = CHINESE_PUNC_LIST + string.punctuation # includes all CN and EN punctuations + new_chars = ' ' * len(old_chars) + del_chars = '' + text = text.translate(str.maketrans(old_chars, new_chars, del_chars)) + return text + + +def nsw_test_case(raw_text): + print('I:' + raw_text) + print('O:' + NSWNormalizer(raw_text).normalize()) + print('') + + +def nsw_test(): + nsw_test_case('固话:0595-23865596或23880880。') + nsw_test_case('固话:0595-23865596或23880880。') + nsw_test_case('手机:+86 19859213959或15659451527。') + nsw_test_case('分数:32477/76391。') + nsw_test_case('百分数:80.03%。') + nsw_test_case('编号:31520181154418。') + nsw_test_case('纯数:2983.07克或12345.60米。') + nsw_test_case('日期:1999年2月20日或09年3月15号。') + nsw_test_case('金钱:12块5,34.5元,20.1万') + nsw_test_case('特殊:O2O或B2C。') + nsw_test_case('3456万吨') + nsw_test_case('2938个') + nsw_test_case('938') + nsw_test_case('今天吃了115个小笼包231个馒头') + nsw_test_case('有62%的概率') + + +if __name__ == '__main__': + # nsw_test() + + p = argparse.ArgumentParser() + p.add_argument('ifile', help='input filename, assume utf-8 encoding') + p.add_argument('ofile', help='output filename') + p.add_argument('--to_upper', action='store_true', help='convert to upper case') + p.add_argument('--to_lower', action='store_true', help='convert to lower case') + p.add_argument('--has_key', action='store_true', help="input text has Kaldi's key as first field.") + p.add_argument('--log_interval', type=int, default=10000, help='log interval in number of processed lines') + args = p.parse_args() + + ifile = codecs.open(args.ifile, 'r', 'utf8') + ofile = codecs.open(args.ofile, 'w+', 'utf8') + + n = 0 + for l in ifile: + key = '' + text = '' + if args.has_key: + cols = l.split(maxsplit=1) + key = cols[0] + if len(cols) == 2: + text = cols[1] + else: + text = '' + else: + text = l + + # cases + if args.to_upper and args.to_lower: + sys.stderr.write('text norm: to_upper OR to_lower?') + exit(1) + if args.to_upper: + text = text.upper() + if args.to_lower: + text = text.lower() + + # NSW(Non-Standard-Word) normalization + text = NSWNormalizer(text).normalize() + + # + if args.has_key: + ofile.write(key + '\t' + text) + else: + ofile.write(text) + + n += 1 + if n % args.log_interval == 0: + sys.stderr.write("text norm: {} lines done.\n".format(n)) + + sys.stderr.write("text norm: {} lines done in total.\n".format(n)) + + ifile.close() + ofile.close() diff --git a/utils/trainer.py b/utils/trainer.py new file mode 100755 index 0000000000000000000000000000000000000000..6821fee1a4a08174bd3f3916dbc368fe89f1ba5b --- /dev/null +++ b/utils/trainer.py @@ -0,0 +1,518 @@ +import random +from torch.cuda.amp import GradScaler, autocast +from utils import move_to_cuda +import subprocess +import numpy as np +import 
torch.optim +import torch.utils.data +import copy +import logging +import os +import re +import sys +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import tqdm + +from utils.ckpt_utils import get_last_checkpoint, get_all_ckpts +from utils.ddp_utils import DDP +from utils.hparams import hparams + + +class Trainer: + def __init__( + self, + work_dir, + default_save_path=None, + accumulate_grad_batches=1, + max_updates=160000, + print_nan_grads=False, + val_check_interval=2000, + num_sanity_val_steps=5, + amp=False, + # tb logger + log_save_interval=100, + tb_log_interval=10, + # checkpoint + monitor_key='val_loss', + monitor_mode='min', + num_ckpt_keep=5, + save_best=True, + resume_from_checkpoint=0, + seed=1234, + debug=False, + ): + os.makedirs(work_dir, exist_ok=True) + self.work_dir = work_dir + self.accumulate_grad_batches = accumulate_grad_batches + self.max_updates = max_updates + self.num_sanity_val_steps = num_sanity_val_steps + self.print_nan_grads = print_nan_grads + self.default_save_path = default_save_path + self.resume_from_checkpoint = resume_from_checkpoint if resume_from_checkpoint > 0 else None + self.seed = seed + self.debug = debug + # model and optm + self.task = None + self.optimizers = [] + + # trainer state + self.testing = False + self.global_step = 0 + self.current_epoch = 0 + self.total_batches = 0 + + # configure checkpoint + self.monitor_key = monitor_key + self.num_ckpt_keep = num_ckpt_keep + self.save_best = save_best + self.monitor_op = np.less if monitor_mode == 'min' else np.greater + self.best_val_results = np.Inf if monitor_mode == 'min' else -np.Inf + self.mode = 'min' + + # allow int, string and gpu list + self.all_gpu_ids = [ + int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != ''] + self.num_gpus = len(self.all_gpu_ids) + self.on_gpu = self.num_gpus > 0 + self.root_gpu = 0 + logging.info(f'GPU available: {torch.cuda.is_available()}, GPU used: {self.all_gpu_ids}') + self.use_ddp = self.num_gpus > 1 + self.proc_rank = 0 + # Tensorboard logging + self.log_save_interval = log_save_interval + self.val_check_interval = val_check_interval + self.tb_log_interval = tb_log_interval + self.amp = amp + self.amp_scalar = GradScaler() + + def test(self, task_cls): + self.testing = True + self.fit(task_cls) + + def fit(self, task_cls): + if len(self.all_gpu_ids) > 1: + mp.spawn(self.ddp_run, nprocs=self.num_gpus, args=(task_cls, copy.deepcopy(hparams))) + else: + self.task = task_cls() + self.task.trainer = self + self.run_single_process(self.task) + return 1 + + def ddp_run(self, gpu_idx, task_cls, hparams_): + hparams.update(hparams_) + task = task_cls() + self.ddp_init(gpu_idx, task) + self.run_single_process(task) + + def run_single_process(self, task): + """Sanity check a few things before starting actual training. 
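+
+        Then builds the model, restores weights and optimizer state from the last
+        checkpoint (if any), configures DDP when several GPUs are visible, sets up
+        TensorBoard logging, and enters train() or, when testing, run_evaluation(test=True).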
+ + :param task: + """ + # build model, optm and load checkpoint + model = task.build_model() + if model is not None: + task.model = model + checkpoint, _ = get_last_checkpoint(self.work_dir, self.resume_from_checkpoint) + if checkpoint is not None: + self.restore_weights(checkpoint) + elif self.on_gpu: + task.cuda(self.root_gpu) + if not self.testing: + self.optimizers = task.configure_optimizers() + self.fisrt_epoch = True + if checkpoint is not None: + self.restore_opt_state(checkpoint) + del checkpoint + # clear cache after restore + if self.on_gpu: + torch.cuda.empty_cache() + + if self.use_ddp: + self.task = self.configure_ddp(self.task) + dist.barrier() + + task_ref = self.get_task_ref() + task_ref.trainer = self + task_ref.testing = self.testing + # link up experiment object + if self.proc_rank == 0: + task_ref.build_tensorboard(save_dir=self.work_dir, name='lightning_logs', version='lastest') + else: + os.makedirs('tmp', exist_ok=True) + task_ref.build_tensorboard(save_dir='tmp', name='tb_tmp', version='lastest') + self.logger = task_ref.logger + try: + if self.testing: + self.run_evaluation(test=True) + else: + self.train() + except KeyboardInterrupt as e: + task_ref.on_keyboard_interrupt() + + #################### + # valid and test + #################### + def run_evaluation(self, test=False): + eval_results = self.evaluate(self.task, test, tqdm_desc='Valid' if not test else 'test') + if eval_results is not None and 'tb_log' in eval_results: + tb_log_output = eval_results['tb_log'] + self.log_metrics_to_tb(tb_log_output) + if self.proc_rank == 0 and not test: + self.save_checkpoint(epoch=self.current_epoch, logs=eval_results) + + def evaluate(self, task, test=False, tqdm_desc='Valid', max_batches=None): + # enable eval mode + task.zero_grad() + task.eval() + torch.set_grad_enabled(False) + + task_ref = self.get_task_ref() + if test: + ret = task_ref.test_start() + if ret == 'EXIT': + return + + outputs = [] + dataloader = task_ref.test_dataloader() if test else task_ref.val_dataloader() + pbar = tqdm.tqdm(dataloader, desc=tqdm_desc, total=max_batches, dynamic_ncols=True, unit='step', + disable=self.root_gpu > 0) + for batch_idx, batch in enumerate(pbar): + if batch is None: # pragma: no cover + continue + # stop short when on fast_dev_run (sets max_batch=1) + if max_batches is not None and batch_idx >= max_batches: + break + + # make dataloader_idx arg in validation_step optional + if self.on_gpu: + batch = move_to_cuda(batch, self.root_gpu) + args = [batch, batch_idx] + if self.use_ddp: + output = task(*args) + else: + if test: + output = task_ref.test_step(*args) + else: + output = task_ref.validation_step(*args) + # track outputs for collation + outputs.append(output) + # give model a chance to do something with the outputs (and method defined) + if test: + eval_results = task_ref.test_end(outputs) + else: + eval_results = task_ref.validation_end(outputs) + # enable train mode again + task.train() + torch.set_grad_enabled(True) + return eval_results + + #################### + # train + #################### + def train(self): + task_ref = self.get_task_ref() + task_ref.on_train_start() + if self.num_sanity_val_steps > 0: + # run tiny validation (if validation defined) to make sure program won't crash during val + self.evaluate(self.task, False, 'Sanity Val', max_batches=self.num_sanity_val_steps) + # clear cache before training + if self.on_gpu: + torch.cuda.empty_cache() + dataloader = task_ref.train_dataloader() + epoch = self.current_epoch + # run all epochs + while True: 
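+            # note: training is budgeted in optimizer steps rather than epochs; the
+            # loop only exits once global_step exceeds max_updates (checked below)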
+            # set seed for the distributed sampler (enables shuffling for each epoch)
+            if self.use_ddp and hasattr(dataloader.sampler, 'set_epoch'):
+                dataloader.sampler.set_epoch(epoch)
+            # update training progress in trainer and model
+            task_ref.current_epoch = epoch
+            self.current_epoch = epoch
+            self.batch_loss_value = 0  # accumulated grads
+            # before epoch hook
+            task_ref.on_epoch_start()
+
+            # run epoch
+            train_pbar = tqdm.tqdm(dataloader, initial=self.global_step, total=float('inf'),
+                                   dynamic_ncols=True, unit='step', disable=self.root_gpu > 0)
+            for batch_idx, batch in enumerate(train_pbar):
+                pbar_metrics, tb_metrics = self.run_training_batch(batch_idx, batch)
+                train_pbar.set_postfix(**pbar_metrics)
+                should_check_val = (self.global_step % self.val_check_interval == 0
+                                    and not self.first_epoch)
+                if should_check_val:
+                    self.run_evaluation()
+                self.first_epoch = False
+                # log metrics at the configured interval
+                if (self.global_step + 1) % self.tb_log_interval == 0:
+                    self.log_metrics_to_tb(tb_metrics)
+
+                self.global_step += 1
+                task_ref.global_step = self.global_step
+                if self.global_step > self.max_updates:
+                    print("| Training ended.")
+                    break
+            # epoch end hook
+            task_ref.on_epoch_end()
+            epoch += 1
+            if self.global_step > self.max_updates:
+                break
+        task_ref.on_train_end()
+
+    def run_training_batch(self, batch_idx, batch):
+        if batch is None:
+            return {}, {}
+        all_progress_bar_metrics = []
+        all_log_metrics = []
+        task_ref = self.get_task_ref()
+        for opt_idx, optimizer in enumerate(self.optimizers):
+            if optimizer is None:
+                continue
+            # make sure only the gradients of the current optimizer's parameters are calculated
+            # in the training step, to prevent dangling gradients in multiple-optimizer setups
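+            # (illustration: with self.optimizers = [opt_gen, opt_disc], as in a GAN setup,
+            # only the parameters owned by the optimizer at opt_idx get gradients in this pass)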
+            if len(self.optimizers) > 1:
+                for param in task_ref.parameters():
+                    param.requires_grad = False
+                for group in optimizer.param_groups:
+                    for param in group['params']:
+                        param.requires_grad = True
+
+            # forward pass
+            with autocast(enabled=self.amp):
+                if self.on_gpu:
+                    batch = move_to_cuda(copy.copy(batch), self.root_gpu)
+                args = [batch, batch_idx, opt_idx]
+                if self.use_ddp:
+                    output = self.task(*args)
+                else:
+                    output = task_ref.training_step(*args)
+                loss = output['loss']
+                if loss is None:
+                    continue
+                progress_bar_metrics = output['progress_bar']
+                log_metrics = output['tb_log']
+                # scale the loss for gradient accumulation
+                loss = loss / self.accumulate_grad_batches
+
+            # backward pass
+            if loss.requires_grad:
+                if self.amp:
+                    self.amp_scalar.scale(loss).backward()
+                else:
+                    loss.backward()
+
+            # track metrics for the progress bar and the tensorboard logger
+            all_log_metrics.append(log_metrics)
+            all_progress_bar_metrics.append(progress_bar_metrics)
+
+            # abort on NaN gradients
+            if self.print_nan_grads:
+                has_nan_grad = False
+                for name, param in task_ref.named_parameters():
+                    if (param.grad is not None) and torch.isnan(param.grad.float()).any():
+                        print("| NaN grads in param: ", name, param, param.grad)
+                        has_nan_grad = True
+                if has_nan_grad:
+                    exit(0)
+
+            # gradient update with accumulated gradients
+            if (self.global_step + 1) % self.accumulate_grad_batches == 0:
+                task_ref.on_before_optimization(opt_idx)
+                if self.amp:
+                    self.amp_scalar.step(optimizer)
+                    self.amp_scalar.update()
+                else:
+                    optimizer.step()
+                optimizer.zero_grad()
+                task_ref.on_after_optimization(self.current_epoch, batch_idx, optimizer, opt_idx)
+
+        # collapse all metrics into one dict
+        all_progress_bar_metrics = {k: v for d in all_progress_bar_metrics for k, v in d.items()}
+        all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()}
+        return all_progress_bar_metrics, all_log_metrics
+
+    ####################
+    # load and save checkpoint
+    ####################
+    def restore_weights(self, checkpoint):
+        # load model state
+        task_ref = self.get_task_ref()
+
+        if len([k for k in checkpoint['state_dict'].keys() if '.' in k]) > 0:
+            task_ref.load_state_dict(checkpoint['state_dict'])
+        else:
+            for k, v in checkpoint['state_dict'].items():
+                getattr(task_ref, k).load_state_dict(v)
+
+        if self.on_gpu:
+            task_ref.cuda(self.root_gpu)
+        # load training state (affects trainer only)
+        self.best_val_results = checkpoint['checkpoint_callback_best']
+        self.global_step = checkpoint['global_step']
+        self.current_epoch = checkpoint['epoch']
+        task_ref.global_step = self.global_step
+
+        # wait for all models to restore weights
+        if self.use_ddp:
+            # wait for all processes to catch up
+            dist.barrier()
+
+    def restore_opt_state(self, checkpoint):
+        if self.testing:
+            return
+        # restore the optimizers
+        optimizer_states = checkpoint['optimizer_states']
+        for optimizer, opt_state in zip(self.optimizers, optimizer_states):
+            if optimizer is None:
+                continue
+            try:
+                optimizer.load_state_dict(opt_state)
+                # move the optimizer state to GPU one tensor at a time
+                if self.on_gpu:
+                    for state in optimizer.state.values():
+                        for k, v in state.items():
+                            if isinstance(v, torch.Tensor):
+                                state[k] = v.cuda(self.root_gpu)
+            except ValueError:
+                print("| WARNING: optimizer parameters do not match; skipped restoring this optimizer.")
+        try:
+            if dist.is_initialized() and dist.get_rank() > 0:
+                return
+        except Exception as e:
+            print(e)
+            return
+        did_restore = True
+        return did_restore
+
+    def save_checkpoint(self, epoch, logs=None):
+        monitor_op = np.less
+        ckpt_path = f'{self.work_dir}/model_ckpt_steps_{self.global_step}.ckpt'
+        logging.info(f'Epoch {epoch:05d}@{self.global_step}: saving model to {ckpt_path}')
+        self._atomic_save(ckpt_path)
+        for old_ckpt in get_all_ckpts(self.work_dir)[self.num_ckpt_keep:]:
+            subprocess.check_call(f'rm -rf "{old_ckpt}"', shell=True)
+            logging.info(f'Delete ckpt: {os.path.basename(old_ckpt)}')
+        current = None
+        if logs is not None and self.monitor_key in logs:
+            current = logs[self.monitor_key]
+        if current is not None and self.save_best:
+            if monitor_op(current, self.best_val_results):
+                best_filepath = f'{self.work_dir}/model_ckpt_best.pt'
+                self.best_val_results = current
+                logging.info(
+                    f'Epoch {epoch:05d}@{self.global_step}: {self.monitor_key} reached {current:0.5f}. 
' + f'Saving model to {best_filepath}') + self._atomic_save(best_filepath) + + def _atomic_save(self, filepath): + checkpoint = self.dump_checkpoint() + tmp_path = str(filepath) + ".part" + torch.save(checkpoint, tmp_path, _use_new_zipfile_serialization=False) + os.replace(tmp_path, filepath) + + def dump_checkpoint(self): + checkpoint = {'epoch': self.current_epoch, 'global_step': self.global_step, + 'checkpoint_callback_best': self.best_val_results} + # save optimizers + optimizer_states = [] + for i, optimizer in enumerate(self.optimizers): + if optimizer is not None: + optimizer_states.append(optimizer.state_dict()) + + checkpoint['optimizer_states'] = optimizer_states + task_ref = self.get_task_ref() + checkpoint['state_dict'] = { + k: v.state_dict() for k, v in task_ref.named_children() if len(list(v.parameters())) > 0} + return checkpoint + + #################### + # DDP + #################### + def ddp_init(self, gpu_idx, task): + # determine which process we are and world size + self.proc_rank = gpu_idx + task.trainer = self + self.init_ddp_connection(self.proc_rank, self.num_gpus) + + # copy model to each gpu + torch.cuda.set_device(gpu_idx) + # override root GPU + self.root_gpu = gpu_idx + self.task = task + + def configure_ddp(self, task): + task = DDP(task, device_ids=[self.root_gpu], find_unused_parameters=True) + if dist.get_rank() != 0 and not self.debug: + sys.stdout = open(os.devnull, "w") + sys.stderr = open(os.devnull, "w") + random.seed(self.seed) + np.random.seed(self.seed) + return task + + def init_ddp_connection(self, proc_rank, world_size): + root_node = '127.0.0.1' + root_node = self.resolve_root_node_address(root_node) + os.environ['MASTER_ADDR'] = root_node + dist.init_process_group('nccl', rank=proc_rank, world_size=world_size) + + def resolve_root_node_address(self, root_node): + if '[' in root_node: + name = root_node.split('[')[0] + number = root_node.split(',')[0] + if '-' in number: + number = number.split('-')[0] + number = re.sub('[^0-9]', '', number) + root_node = name + number + return root_node + + #################### + # utils + #################### + def get_task_ref(self): + from tasks.base_task import BaseTask + task: BaseTask = self.task.module if isinstance(self.task, DDP) else self.task + return task + + def log_metrics_to_tb(self, metrics, step=None): + """Logs the metric dict passed in. 
+ + :param metrics: + """ + # added metrics by Lightning for convenience + metrics['epoch'] = self.current_epoch + + # turn all tensors to scalars + scalar_metrics = self.metrics_to_scalars(metrics) + + step = step if step is not None else self.global_step + # log actual metrics + if self.proc_rank == 0: + self.log_metrics(self.logger, scalar_metrics, step=step) + + @staticmethod + def log_metrics(logger, metrics, step=None): + for k, v in metrics.items(): + if isinstance(v, torch.Tensor): + v = v.item() + logger.add_scalar(k, v, step) + + def metrics_to_scalars(self, metrics): + new_metrics = {} + for k, v in metrics.items(): + if isinstance(v, torch.Tensor): + v = v.item() + + if type(v) is dict: + v = self.metrics_to_scalars(v) + + new_metrics[k] = v + + return new_metrics diff --git a/utils/training_utils.py b/utils/training_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..409b15388790b1aadb24632313bdd1f41b4b06ac --- /dev/null +++ b/utils/training_utils.py @@ -0,0 +1,27 @@ +from utils.hparams import hparams + + +class RSQRTSchedule(object): + def __init__(self, optimizer): + super().__init__() + self.optimizer = optimizer + self.constant_lr = hparams['lr'] + self.warmup_updates = hparams['warmup_updates'] + self.hidden_size = hparams['hidden_size'] + self.lr = hparams['lr'] + for param_group in optimizer.param_groups: + param_group['lr'] = self.lr + self.step(0) + + def step(self, num_updates): + constant_lr = self.constant_lr + warmup = min(num_updates / self.warmup_updates, 1.0) + rsqrt_decay = max(self.warmup_updates, num_updates) ** -0.5 + rsqrt_hidden = self.hidden_size ** -0.5 + self.lr = max(constant_lr * warmup * rsqrt_decay * rsqrt_hidden, 1e-7) + for param_group in self.optimizer.param_groups: + param_group['lr'] = self.lr + return self.lr + + def get_lr(self): + return self.optimizer.param_groups[0]['lr'] diff --git a/utils/tts_utils.py b/utils/tts_utils.py new file mode 100755 index 0000000000000000000000000000000000000000..9da2385ba52ce735a2d3c46ad8743d4a5bb7cd5c --- /dev/null +++ b/utils/tts_utils.py @@ -0,0 +1,371 @@ +from collections import defaultdict +import torch +import torch.nn.functional as F + + +def make_positions(tensor, padding_idx): + """Replace non-padding symbols with their position numbers. + + Position numbers begin at padding_idx+1. Padding symbols are ignored. + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. In particular XLA + # prefers ints, cumsum defaults to output longs, and ONNX doesn't know + # how to handle the dtype kwarg in cumsum. 
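+    # Worked example (illustrative): with padding_idx=0 and tensor = [[5, 6, 0]],
+    # mask = [[1, 1, 0]] and cumsum(mask) * mask = [[1, 2, 0]], so the returned
+    # positions are [[1, 2, 0]]: real positions start at 1, pads stay at padding_idx.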
+    mask = tensor.ne(padding_idx).int()
+    return (
+        torch.cumsum(mask, dim=1).type_as(mask) * mask
+    ).long() + padding_idx
+
+
+def softmax(x, dim):
+    return F.softmax(x, dim=dim, dtype=torch.float32)
+
+
+def sequence_mask(lengths, maxlen, dtype=torch.bool):
+    if maxlen is None:
+        maxlen = lengths.max()
+    mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
+    mask = mask.type(dtype)
+    return mask
+
+
+INCREMENTAL_STATE_INSTANCE_ID = defaultdict(lambda: 0)
+
+
+def _get_full_incremental_state_key(module_instance, key):
+    module_name = module_instance.__class__.__name__
+
+    # assign a unique ID to each module instance, so that incremental state is
+    # not shared across module instances
+    if not hasattr(module_instance, '_instance_id'):
+        INCREMENTAL_STATE_INSTANCE_ID[module_name] += 1
+        module_instance._instance_id = INCREMENTAL_STATE_INSTANCE_ID[module_name]
+
+    return '{}.{}.{}'.format(module_name, module_instance._instance_id, key)
+
+
+def get_incremental_state(module, incremental_state, key):
+    """Helper for getting incremental state for an nn.Module."""
+    full_key = _get_full_incremental_state_key(module, key)
+    if incremental_state is None or full_key not in incremental_state:
+        return None
+    return incremental_state[full_key]
+
+
+def set_incremental_state(module, incremental_state, key, value):
+    """Helper for setting incremental state for an nn.Module."""
+    if incremental_state is not None:
+        full_key = _get_full_incremental_state_key(module, key)
+        incremental_state[full_key] = value
+
+
+def fill_with_neg_inf(t):
+    """FP16-compatible function that fills a tensor with -inf."""
+    return t.float().fill_(float('-inf')).type_as(t)
+
+
+def fill_with_neg_inf2(t):
+    """FP16-compatible function that fills a tensor with -1e8 (a finite stand-in for -inf)."""
+    return t.float().fill_(-1e8).type_as(t)
+
+
+def get_focus_rate(attn, src_padding_mask=None, tgt_padding_mask=None):
+    '''
+    attn: bs x L_t x L_s
+    '''
+    if src_padding_mask is not None:
+        attn = attn * (1 - src_padding_mask.float())[:, None, :]
+
+    if tgt_padding_mask is not None:
+        attn = attn * (1 - tgt_padding_mask.float())[:, :, None]
+
+    focus_rate = attn.max(-1).values.sum(-1)
+    focus_rate = focus_rate / attn.sum(-1).sum(-1)
+    return focus_rate
+
+
+def get_phone_coverage_rate(attn, src_padding_mask=None, src_seg_mask=None, tgt_padding_mask=None):
+    '''
+    attn: bs x L_t x L_s
+    '''
+    src_mask = attn.new(attn.size(0), attn.size(-1)).bool().fill_(False)
+    if src_padding_mask is not None:
+        src_mask |= src_padding_mask
+    if src_seg_mask is not None:
+        src_mask |= src_seg_mask
+
+    attn = attn * (1 - src_mask.float())[:, None, :]
+    if tgt_padding_mask is not None:
+        attn = attn * (1 - tgt_padding_mask.float())[:, :, None]
+
+    phone_coverage_rate = attn.max(1).values.sum(-1)
+    # phone_coverage_rate = phone_coverage_rate / attn.sum(-1).sum(-1)
+    phone_coverage_rate = phone_coverage_rate / (1 - src_mask.float()).sum(-1)
+    return phone_coverage_rate
+
+
+def get_diagonal_focus_rate(attn, attn_ks, target_len, src_padding_mask=None, tgt_padding_mask=None,
+                            band_mask_factor=5, band_width=50):
+    '''
+    attn: bs x L_t x L_s
+    attn_ks: tensor with shape [batch_size], equal to input_lens / output_lens
+
+    diagonal: y = k * x (k = attn_ks, x: output index, y: input index)
+    1 0 0
+    0 1 0
+    0 0 1
+    y >= k * (x - width) and y <= k * (x + width): 1
+    else: 0
+    '''
+    # width = min(target_len / band_mask_factor, band_width)
+    width1 = target_len / band_mask_factor
+    width2 = target_len.new(target_len.size()).fill_(band_width)
+    width = torch.where(width1 < width2, width1, width2).float()
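+    # build a 0/1 band mask around the diagonal y = attn_ks * x: entries with
+    # attn_ks * (x - width) <= y <= attn_ks * (x + width) are kept (mask == 1)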
+    base = torch.ones(attn.size()).to(attn.device)
+    zero = torch.zeros(attn.size()).to(attn.device)
+    x = torch.arange(0, attn.size(1)).to(attn.device)[None, :, None].float() * base
+    y = torch.arange(0, attn.size(2)).to(attn.device)[None, None, :].float() * base
+    cond = (y - attn_ks[:, None, None] * x)
+    cond1 = cond + attn_ks[:, None, None] * width[:, None, None]
+    cond2 = cond - attn_ks[:, None, None] * width[:, None, None]
+    mask1 = torch.where(cond1 < 0, zero, base)
+    mask2 = torch.where(cond2 > 0, zero, base)
+    mask = mask1 * mask2
+
+    if src_padding_mask is not None:
+        attn = attn * (1 - src_padding_mask.float())[:, None, :]
+    if tgt_padding_mask is not None:
+        attn = attn * (1 - tgt_padding_mask.float())[:, :, None]
+
+    diagonal_attn = attn * mask
+    diagonal_focus_rate = diagonal_attn.sum(-1).sum(-1) / attn.sum(-1).sum(-1)
+    return diagonal_focus_rate, mask
+
+
+def select_attn(attn_logits, type='best'):
+    """Select encoder-decoder attention maps for alignment analysis.
+
+    :param attn_logits: [n_layers, B, n_head, T_sp, T_txt]
+    :param type: 'best' picks, per batch item, the head with the highest focus rate; 'mean' averages all heads.
+    :return: [B, T_sp, T_txt]
+    """
+    encdec_attn = torch.stack(attn_logits, 0).transpose(1, 2)
+    # [n_layers * n_head, B, T_sp, T_txt]
+    encdec_attn = (encdec_attn.reshape([-1, *encdec_attn.shape[2:]])).softmax(-1)
+    if type == 'best':
+        indices = encdec_attn.max(-1).values.sum(-1).argmax(0)
+        encdec_attn = encdec_attn.gather(
+            0, indices[None, :, None, None].repeat(1, 1, encdec_attn.size(-2), encdec_attn.size(-1)))[0]
+        return encdec_attn
+    elif type == 'mean':
+        return encdec_attn.mean(0)
+
+
+def make_pad_mask(lengths, xs=None, length_dim=-1):
+    """Make mask tensor containing indices of padded part.
+    Args:
+        lengths (LongTensor or List): Batch of lengths (B,).
+        xs (Tensor, optional): The reference tensor.
+            If set, masks will be the same shape as this tensor.
+        length_dim (int, optional): Dimension indicator of the above tensor.
+            See the example.
+    Returns:
+        Tensor: Mask tensor containing indices of padded part.
+            dtype=torch.uint8 in PyTorch 1.2-
+            dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+    Examples:
+        With only lengths.
+        >>> lengths = [5, 3, 2]
+        >>> make_pad_mask(lengths)
+        masks = [[0, 0, 0, 0, 0],
+                 [0, 0, 0, 1, 1],
+                 [0, 0, 1, 1, 1]]
+        With the reference tensor.
+        >>> xs = torch.zeros((3, 2, 4))
+        >>> make_pad_mask(lengths, xs)
+        tensor([[[0, 0, 0, 0],
+                 [0, 0, 0, 0]],
+                [[0, 0, 0, 1],
+                 [0, 0, 0, 1]],
+                [[0, 0, 1, 1],
+                 [0, 0, 1, 1]]], dtype=torch.uint8)
+        >>> xs = torch.zeros((3, 2, 6))
+        >>> make_pad_mask(lengths, xs)
+        tensor([[[0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1]],
+                [[0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1]],
+                [[0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+        With the reference tensor and dimension indicator.
+        >>> xs = torch.zeros((3, 6, 6))
+        >>> make_pad_mask(lengths, xs, 1)
+        tensor([[[0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [1, 1, 1, 1, 1, 1]],
+                [[0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1]],
+                [[0, 0, 0, 0, 0, 0],
+                 [0, 0, 0, 0, 0, 0],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1],
+                 [1, 1, 1, 1, 1, 1]]], dtype=torch.uint8)
+        >>> make_pad_mask(lengths, xs, 2)
+        tensor([[[0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1],
+                 [0, 0, 0, 0, 0, 1]],
+                [[0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1],
+                 [0, 0, 0, 1, 1, 1]],
+                [[0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1],
+                 [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)
+    """
+    if length_dim == 0:
+        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
+
+    if not isinstance(lengths, list):
+        lengths = lengths.tolist()
+    bs = int(len(lengths))
+    if xs is None:
+        maxlen = int(max(lengths))
+    else:
+        maxlen = xs.size(length_dim)
+
+    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
+    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
+    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
+    mask = seq_range_expand >= seq_length_expand
+
+    if xs is not None:
+        assert xs.size(0) == bs, (xs.size(0), bs)
+
+        if length_dim < 0:
+            length_dim = xs.dim() + length_dim
+        # ind = (:, None, ..., None, :, None, ..., None)
+        ind = tuple(
+            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
+        )
+        mask = mask[ind].expand_as(xs).to(xs.device)
+    return mask
+
+
+def make_non_pad_mask(lengths, xs=None, length_dim=-1):
+    """Make mask tensor containing indices of non-padded part.
+    Args:
+        lengths (LongTensor or List): Batch of lengths (B,).
+        xs (Tensor, optional): The reference tensor.
+            If set, masks will be the same shape as this tensor.
+        length_dim (int, optional): Dimension indicator of the above tensor.
+            See the example.
+    Returns:
+        ByteTensor: mask tensor containing indices of non-padded part.
+            dtype=torch.uint8 in PyTorch 1.2-
+            dtype=torch.bool in PyTorch 1.2+ (including 1.2)
+    Examples:
+        With only lengths.
+        >>> lengths = [5, 3, 2]
+        >>> make_non_pad_mask(lengths)
+        masks = [[1, 1, 1, 1, 1],
+                 [1, 1, 1, 0, 0],
+                 [1, 1, 0, 0, 0]]
+        With the reference tensor.
+        >>> xs = torch.zeros((3, 2, 4))
+        >>> make_non_pad_mask(lengths, xs)
+        tensor([[[1, 1, 1, 1],
+                 [1, 1, 1, 1]],
+                [[1, 1, 1, 0],
+                 [1, 1, 1, 0]],
+                [[1, 1, 0, 0],
+                 [1, 1, 0, 0]]], dtype=torch.uint8)
+        >>> xs = torch.zeros((3, 2, 6))
+        >>> make_non_pad_mask(lengths, xs)
+        tensor([[[1, 1, 1, 1, 1, 0],
+                 [1, 1, 1, 1, 1, 0]],
+                [[1, 1, 1, 0, 0, 0],
+                 [1, 1, 1, 0, 0, 0]],
+                [[1, 1, 0, 0, 0, 0],
+                 [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8)
+        With the reference tensor and dimension indicator.
+ >>> xs = torch.zeros((3, 6, 6)) + >>> make_non_pad_mask(lengths, xs, 1) + tensor([[[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]], + [[1, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8) + >>> make_non_pad_mask(lengths, xs, 2) + tensor([[[1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 0]], + [[1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 0, 0, 0]], + [[1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0]]], dtype=torch.uint8) + """ + return ~make_pad_mask(lengths, xs, length_dim) + + +def get_mask_from_lengths(lengths): + max_len = torch.max(lengths).item() + ids = torch.arange(0, max_len).to(lengths.device) + mask = (ids < lengths.unsqueeze(1)).bool() + return mask + + +def group_hidden_by_segs(h, seg_ids, max_len): + """ + + :param h: [B, T, H] + :param seg_ids: [B, T] + :return: h_ph: [B, T_ph, H] + """ + B, T, H = h.shape + h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h) + all_ones = h.new_ones(h.shape[:2]) + cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous() + h_gby_segs = h_gby_segs[:, 1:] + cnt_gby_segs = cnt_gby_segs[:, 1:] + h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1) + return h_gby_segs, cnt_gby_segs diff --git a/vocoders/__init__.py b/vocoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..50e4abf21d1cd113f65d353f0101e3550de3bac3 --- /dev/null +++ b/vocoders/__init__.py @@ -0,0 +1,2 @@ +from vocoders import hifigan +from vocoders import fastdiff diff --git a/vocoders/base_vocoder.py b/vocoders/base_vocoder.py new file mode 100644 index 0000000000000000000000000000000000000000..fe49a9e4f790ecdc5e76d60a23f96602b59fc48d --- /dev/null +++ b/vocoders/base_vocoder.py @@ -0,0 +1,39 @@ +import importlib +VOCODERS = {} + + +def register_vocoder(cls): + VOCODERS[cls.__name__.lower()] = cls + VOCODERS[cls.__name__] = cls + return cls + + +def get_vocoder_cls(hparams): + if hparams['vocoder'] in VOCODERS: + return VOCODERS[hparams['vocoder']] + else: + vocoder_cls = hparams['vocoder'] + pkg = ".".join(vocoder_cls.split(".")[:-1]) + cls_name = vocoder_cls.split(".")[-1] + vocoder_cls = getattr(importlib.import_module(pkg), cls_name) + return vocoder_cls + + +class BaseVocoder: + def spec2wav(self, mel): + """ + + :param mel: [T, 80] + :return: wav: [T'] + """ + + raise NotImplementedError + + @staticmethod + def wav2spec(wav_fn): + """ + + :param wav_fn: str + :return: wav, mel: [T, 80] + """ + raise NotImplementedError diff --git a/vocoders/fastdiff.py b/vocoders/fastdiff.py new file mode 100644 index 0000000000000000000000000000000000000000..1769085832bfc902eeff0155b788141ae194e85e --- /dev/null +++ b/vocoders/fastdiff.py @@ -0,0 +1,162 @@ +import glob +import re +import librosa +import torch +import yaml +from sklearn.preprocessing import StandardScaler +from torch import nn +from modules.FastDiff.module.FastDiff_model import FastDiff as FastDiff_model +from utils.hparams import hparams +from 
modules.parallel_wavegan.utils import read_hdf5
+from vocoders.base_vocoder import BaseVocoder, register_vocoder
+import numpy as np
+from modules.FastDiff.module.util import theta_timestep_loss, compute_hyperparams_given_schedule, sampling_given_noise_schedule
+
+def load_fastdiff_model(config_path, checkpoint_path):
+    # load config
+    with open(config_path) as f:
+        config = yaml.load(f, Loader=yaml.Loader)
+
+    # setup
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+    model = FastDiff_model(audio_channels=config['audio_channels'],
+                           inner_channels=config['inner_channels'],
+                           cond_channels=config['cond_channels'],
+                           upsample_ratios=config['upsample_ratios'],
+                           lvc_layers_each_block=config['lvc_layers_each_block'],
+                           lvc_kernel_size=config['lvc_kernel_size'],
+                           kpnet_hidden_channels=config['kpnet_hidden_channels'],
+                           kpnet_conv_size=config['kpnet_conv_size'],
+                           dropout=config['dropout'],
+                           diffusion_step_embed_dim_in=config['diffusion_step_embed_dim_in'],
+                           diffusion_step_embed_dim_mid=config['diffusion_step_embed_dim_mid'],
+                           diffusion_step_embed_dim_out=config['diffusion_step_embed_dim_out'],
+                           use_weight_norm=config['use_weight_norm'])
+
+    model.load_state_dict(torch.load(checkpoint_path, map_location="cpu")["state_dict"]["model"], strict=True)
+
+    # init hyperparameters by a linear schedule
+    noise_schedule = torch.linspace(float(config["beta_0"]), float(config["beta_T"]), int(config["T"])).to(device)
+    diffusion_hyperparams = compute_hyperparams_given_schedule(noise_schedule)
+
+    # move diffusion hyperparameters to the target device
+    for key in diffusion_hyperparams:
+        if key in ["beta", "alpha", "sigma"]:
+            diffusion_hyperparams[key] = diffusion_hyperparams[key].to(device)
+
+    if config['noise_schedule'] != '':
+        noise_schedule = config['noise_schedule']
+        if isinstance(noise_schedule, list):
+            noise_schedule = torch.FloatTensor(noise_schedule).to(device)
+    else:
+        # select a reverse noise schedule
+        try:
+            reverse_step = int(hparams.get('N'))
+        except Exception:
+            print('Please specify N (the number of reverse iterations) in the config file. Denoising with 4 iterations for now.')
+            reverse_step = 4
+        if reverse_step == 1000:
+            noise_schedule = torch.linspace(0.000001, 0.01, 1000).to(device)
+        elif reverse_step == 200:
+            noise_schedule = torch.linspace(0.0001, 0.02, 200).to(device)
+
+        # below are schedules derived by the noise predictor
+        elif reverse_step == 8:
+            noise_schedule = [6.689325005027058e-07, 1.0033881153503899e-05, 0.00015496854030061513,
+                              0.002387222135439515, 0.035597629845142365, 0.3681158423423767, 0.4735414385795593, 0.5]
+        elif reverse_step == 6:
+            noise_schedule = [1.7838445955931093e-06, 2.7984189728158526e-05, 0.00043231004383414984,
+                              0.006634317338466644, 0.09357017278671265, 0.6000000238418579]
+        elif reverse_step == 4:
+            noise_schedule = [3.2176e-04, 2.5743e-03, 2.5376e-02, 7.0414e-01]
+        elif reverse_step == 3:
+            noise_schedule = [9.0000e-05, 9.0000e-03, 6.0000e-01]
+        else:
+            raise NotImplementedError
+
+    if isinstance(noise_schedule, list):
+        noise_schedule = torch.FloatTensor(noise_schedule).to(device)
+
+    model.remove_weight_norm()
+    model = model.eval().to(device)
+    print(f"| Loaded model parameters from {checkpoint_path}.")
+    print(f"| FastDiff device: {device}.")
+    return model, diffusion_hyperparams, noise_schedule, config, device
+
+
+@register_vocoder
+class FastDiff(BaseVocoder):
+    def __init__(self):
+        if hparams['vocoder_ckpt'] == '':
+            # fall back to the bundled LJSpeech FastDiff pretrained model
+            base_dir = 'checkpoint/FastDiff'
+        else:
+            base_dir = hparams['vocoder_ckpt']
+        config_path = f'{base_dir}/config.yaml'
+        ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'),
+                      key=lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
+        print('| load FastDiff: ', ckpt)
+        self.scaler = None
+        self.model, self.dh, self.noise_schedule, self.config, self.device = load_fastdiff_model(
+            config_path=config_path,
+            checkpoint_path=ckpt,
+        )
+
+    def spec2wav(self, mel, **kwargs):
+        # start generation
+        device = self.device
+        with torch.no_grad():
+            c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(device)
+            audio_length = c.shape[-1] * hparams["hop_size"]
+            y = sampling_given_noise_schedule(
+                self.model, (1, 1, audio_length), self.dh, self.noise_schedule, condition=c, ddim=False, return_sequence=False)
+        wav_out = y.cpu().numpy()
+        return wav_out
+
+    @staticmethod
+    def wav2spec(wav_fn, return_linear=False):
+        from data_gen.tts.data_gen_utils import process_utterance
+        res = process_utterance(
+            wav_fn, fft_size=hparams['fft_size'],
+            hop_size=hparams['hop_size'],
+            win_length=hparams['win_size'],
+            num_mels=hparams['audio_num_mel_bins'],
+            fmin=hparams['fmin'],
+            fmax=hparams['fmax'],
+            sample_rate=hparams['audio_sample_rate'],
+            loud_norm=hparams['loud_norm'],
+            min_level_db=hparams['min_level_db'],
+            return_linear=return_linear, vocoder='fastdiff', eps=float(hparams.get('wav2spec_eps', 1e-10)))
+        if return_linear:
+            return res[0], res[1].T, res[2].T  # [T, 80], [T, n_fft]
+        else:
+            return res[0], res[1].T
+
+    @staticmethod
+    def wav2mfcc(wav_fn):
+        fft_size = hparams['fft_size']
+        hop_size = hparams['hop_size']
+        win_length = hparams['win_size']
+        sample_rate = hparams['audio_sample_rate']
+        wav, _ = librosa.core.load(wav_fn, sr=sample_rate)
+        mfcc = librosa.feature.mfcc(y=wav, sr=sample_rate, n_mfcc=13,
+                                    n_fft=fft_size, hop_length=hop_size,
+                                    win_length=win_length, pad_mode="constant", power=1.0)
+        mfcc_delta = librosa.feature.delta(mfcc, order=1)
+        mfcc_delta_delta = librosa.feature.delta(mfcc, order=2)
+        mfcc = np.concatenate([mfcc, mfcc_delta, mfcc_delta_delta]).T
+        return mfcc
diff --git a/vocoders/hifigan.py b/vocoders/hifigan.py
new file mode 100644
index 0000000000000000000000000000000000000000..810d3c931b556387f8a2e85537f4964add1e76b0
--- /dev/null
+++ b/vocoders/hifigan.py
@@ -0,0 +1,76 @@
+import glob
+import json
+import os
+import re
+
+import librosa
+import torch
+
+import utils
+from modules.hifigan.hifigan import HifiGanGenerator
+from utils.hparams import hparams, set_hparams
+from vocoders.base_vocoder import register_vocoder
+from vocoders.pwg import PWG
+from vocoders.vocoder_utils import denoise
+
+
+def load_model(config_path, checkpoint_path):
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    ckpt_dict = torch.load(checkpoint_path, map_location="cpu")
+    if '.yaml' in config_path:
+        config = set_hparams(config_path, global_hparams=False)
+        state = ckpt_dict["state_dict"]["model_gen"]
+    elif '.json' in config_path:
+        config = json.load(open(config_path, 'r'))
+        state = ckpt_dict["generator"]
+    else:
+        raise ValueError(f"Unsupported config format (expected .yaml or .json): {config_path}")
+
+    model = HifiGanGenerator(config)
+    model.load_state_dict(state, strict=True)
+    model.remove_weight_norm()
+    model = model.eval().to(device)
+    print(f"| Loaded model parameters from {checkpoint_path}.")
+    print(f"| HifiGAN device: {device}.")
+    return model, config, device
+
+
+total_time = 0
+
+
+@register_vocoder
+class HifiGAN(PWG):
+    def __init__(self):
+        base_dir = hparams['vocoder_ckpt']
+        config_path = f'{base_dir}/config.yaml'
+        if os.path.exists(config_path):
+            ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'),
+                          key=lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
+            print('| load HifiGAN: ', ckpt)
+            self.model, self.config, self.device = load_model(config_path=config_path, checkpoint_path=ckpt)
+        else:
+            config_path = f'{base_dir}/config.json'
+            ckpt = f'{base_dir}/generator_v1'
+            if os.path.exists(config_path):
+                self.model, self.config, self.device = load_model(config_path=config_path, checkpoint_path=ckpt)
+
+    def spec2wav(self, mel, **kwargs):
+        device = self.device
+        with torch.no_grad():
+            c = torch.FloatTensor(mel).unsqueeze(0).transpose(2, 1).to(device)
+            with utils.Timer('hifigan', print_time=hparams['profile_infer']):
+                f0 = kwargs.get('f0')
+                if f0 is not None and hparams.get('use_nsf'):
+                    f0 = torch.FloatTensor(f0[None, :]).to(device)
+                    y = self.model(c, f0).view(-1)
+                else:
+                    y = self.model(c).view(-1)
+        wav_out = y.cpu().numpy()
+        if hparams.get('vocoder_denoise_c', 0.0) > 0:
+            wav_out = denoise(wav_out, v=hparams['vocoder_denoise_c'])
+        return wav_out
+
+    # @staticmethod
+    # def wav2spec(wav_fn, **kwargs):
+    #     wav, _ = librosa.core.load(wav_fn, sr=hparams['audio_sample_rate'])
+    #     wav_torch = torch.FloatTensor(wav)[None, :]
+    #     mel = mel_spectrogram(wav_torch, hparams).numpy()[0]
+    #     return wav, mel.T
diff --git a/vocoders/pwg.py b/vocoders/pwg.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca9b6891ab2ba5cb413eeca97a41534e5db129d5
--- /dev/null
+++ b/vocoders/pwg.py
@@ -0,0 +1,137 @@
+import glob
+import re
+import librosa
+import torch
+import yaml
+from sklearn.preprocessing import StandardScaler
+from torch import nn
+from modules.parallel_wavegan.models import ParallelWaveGANGenerator
+from modules.parallel_wavegan.utils import read_hdf5
+from utils.hparams import hparams
+from utils.pitch_utils import f0_to_coarse
+from vocoders.base_vocoder import BaseVocoder, register_vocoder
+import numpy as np
+
+
+def load_pwg_model(config_path, checkpoint_path, stats_path):
+    # load config
+    with open(config_path) as f:
+        config = yaml.load(f, Loader=yaml.Loader)
+
+    # setup
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+    model = ParallelWaveGANGenerator(**config["generator_params"])
+
+    ckpt_dict = torch.load(checkpoint_path, map_location="cpu")
+    if 'state_dict' not in ckpt_dict:  # official vocoder
+        model.load_state_dict(ckpt_dict["model"]["generator"])
+        scaler = StandardScaler()
+        if config["format"] == "hdf5":
+            scaler.mean_ = read_hdf5(stats_path, "mean")
+            scaler.scale_ = read_hdf5(stats_path, "scale")
+        elif config["format"] == "npy":
+            scaler.mean_ = np.load(stats_path)[0]
+            scaler.scale_ = np.load(stats_path)[1]
+        else:
+            raise ValueError("Only hdf5 and npy stats formats are supported.")
+    else:  # custom PWG vocoder
+        fake_task = nn.Module()
+        fake_task.model_gen = model
+        fake_task.load_state_dict(ckpt_dict["state_dict"], strict=False)
+        scaler = None
+
+    model.remove_weight_norm()
+    model = model.eval().to(device)
+    print(f"| Loaded model parameters from {checkpoint_path}.")
+    print(f"| PWG device: {device}.")
+    return model, scaler, config, device
+
+
+@register_vocoder
+class PWG(BaseVocoder):
+    def __init__(self):
+        if hparams['vocoder_ckpt'] == '':  # load the LJSpeech PWG pretrained model
+            base_dir = 'wavegan_pretrained'
+            ckpts = glob.glob(f'{base_dir}/checkpoint-*steps.pkl')
+            ckpt = sorted(ckpts,
+                          key=lambda x: int(re.findall(f'{base_dir}/checkpoint-(\d+)steps.pkl', x)[0]))[-1]
+            config_path = f'{base_dir}/config.yaml'
+            print('| load PWG: ', ckpt)
+            self.model, self.scaler, self.config, self.device = load_pwg_model(
+                config_path=config_path,
+                checkpoint_path=ckpt,
+                stats_path=f'{base_dir}/stats.h5',
+            )
+        else:
+            base_dir = hparams['vocoder_ckpt']
+            print(base_dir)
+            config_path = f'{base_dir}/config.yaml'
+            ckpt = sorted(glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'),
+                          key=lambda x: int(re.findall(f'{base_dir}/model_ckpt_steps_(\d+).ckpt', x)[0]))[-1]
+            print('| load PWG: ', ckpt)
+            self.scaler = None
+            self.model, _, self.config, self.device = load_pwg_model(
+                config_path=config_path,
+                checkpoint_path=ckpt,
+                stats_path=f'{base_dir}/stats.h5',
+            )
+
+    def spec2wav(self, mel, **kwargs):
+        # start generation
+        config = self.config
+        device = self.device
+        pad_size = (config["generator_params"]["aux_context_window"],
+                    config["generator_params"]["aux_context_window"])
+        c = mel
+        if self.scaler is not None:
+            c = self.scaler.transform(c)
+
+        with torch.no_grad():
+            z = torch.randn(1, 1, c.shape[0] * config["hop_size"]).to(device)
+            c = np.pad(c, (pad_size, (0, 0)), "edge")
+            c = torch.FloatTensor(c).unsqueeze(0).transpose(2, 1).to(device)
+            p = kwargs.get('f0')
+            if p is not None:
+                p = f0_to_coarse(p)
+                p = np.pad(p, (pad_size,), "edge")
+                p = torch.LongTensor(p[None, :]).to(device)
+            y = self.model(z, c, p).view(-1)
+        wav_out = y.cpu().numpy()
+        return wav_out
+
+    @staticmethod
+    def wav2spec(wav_fn, return_linear=False):
+        from data_gen.tts.data_gen_utils import process_utterance
+        res = process_utterance(
+            wav_fn, fft_size=hparams['fft_size'],
+
hop_size=hparams['hop_size'], + win_length=hparams['win_size'], + num_mels=hparams['audio_num_mel_bins'], + fmin=hparams['fmin'], + fmax=hparams['fmax'], + sample_rate=hparams['audio_sample_rate'], + loud_norm=hparams['loud_norm'], + min_level_db=hparams['min_level_db'], + return_linear=return_linear, vocoder='pwg', eps=float(hparams.get('wav2spec_eps', 1e-10))) + if return_linear: + return res[0], res[1].T, res[2].T # [T, 80], [T, n_fft] + else: + return res[0], res[1].T + + @staticmethod + def wav2mfcc(wav_fn): + fft_size = hparams['fft_size'] + hop_size = hparams['hop_size'] + win_length = hparams['win_size'] + sample_rate = hparams['audio_sample_rate'] + wav, _ = librosa.core.load(wav_fn, sr=sample_rate) + mfcc = librosa.feature.mfcc(y=wav, sr=sample_rate, n_mfcc=13, + n_fft=fft_size, hop_length=hop_size, + win_length=win_length, pad_mode="constant", power=1.0) + mfcc_delta = librosa.feature.delta(mfcc, order=1) + mfcc_delta_delta = librosa.feature.delta(mfcc, order=2) + mfcc = np.concatenate([mfcc, mfcc_delta, mfcc_delta_delta]).T + return mfcc diff --git a/vocoders/vocoder_utils.py b/vocoders/vocoder_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..db5d5ca1765928e4b047db04435a8a39b52592ca --- /dev/null +++ b/vocoders/vocoder_utils.py @@ -0,0 +1,15 @@ +import librosa + +from utils.hparams import hparams +import numpy as np + + +def denoise(wav, v=0.1): + spec = librosa.stft(y=wav, n_fft=hparams['fft_size'], hop_length=hparams['hop_size'], + win_length=hparams['win_size'], pad_mode='constant') + spec_m = np.abs(spec) + spec_m = np.clip(spec_m - v, a_min=0, a_max=None) + spec_a = np.angle(spec) + + return librosa.istft(spec_m * np.exp(1j * spec_a), hop_length=hparams['hop_size'], + win_length=hparams['win_size'])
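
For reference, a minimal usage sketch of the vocoder registry defined above. This is an illustration rather than part of the diff: the config path is hypothetical, and the config is assumed to define 'vocoder', 'vocoder_ckpt' and the audio hparams that wav2spec reads.

import vocoders  # importing the package registers HifiGAN, FastDiff (and PWG) via @register_vocoder
from utils.hparams import hparams, set_hparams
from vocoders.base_vocoder import get_vocoder_cls
from vocoders.vocoder_utils import denoise

set_hparams('configs/example.yaml')         # hypothetical config file
vocoder = get_vocoder_cls(hparams)()        # e.g. hparams['vocoder'] == 'HifiGAN'
wav, mel = vocoder.wav2spec('example.wav')  # wav: [T'], mel: [T, 80]
wav_out = vocoder.spec2wav(mel)             # re-synthesize a waveform from the mel
wav_out = denoise(wav_out, v=0.1)           # optional spectral-subtraction denoising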