mrfakename committed on
Commit
635f007
0 Parent(s):

Initial Commit

.gitattributes ADDED
@@ -0,0 +1,37 @@
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.t7 filter=lfs diff=lfs merge=lfs -text
25
+ OOD_texts.txt filter=lfs diff=lfs merge=lfs -text
26
+ *.rar filter=lfs diff=lfs merge=lfs -text
27
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
28
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
29
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
30
+ *.tar filter=lfs diff=lfs merge=lfs -text
31
+ *.tflite filter=lfs diff=lfs merge=lfs -text
32
+ *.tgz filter=lfs diff=lfs merge=lfs -text
33
+ *.wasm filter=lfs diff=lfs merge=lfs -text
34
+ *.xz filter=lfs diff=lfs merge=lfs -text
35
+ *.zip filter=lfs diff=lfs merge=lfs -text
36
+ *.zst filter=lfs diff=lfs merge=lfs -text
37
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,163 @@
1
+ .DS_Store
2
+ # Byte-compiled / optimized / DLL files
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+
7
+ # C extensions
8
+ *.so
9
+
10
+ # Distribution / packaging
11
+ .Python
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ wheels/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+ cover/
54
+
55
+ # Translations
56
+ *.mo
57
+ *.pot
58
+
59
+ # Django stuff:
60
+ *.log
61
+ local_settings.py
62
+ db.sqlite3
63
+ db.sqlite3-journal
64
+
65
+ # Flask stuff:
66
+ instance/
67
+ .webassets-cache
68
+
69
+ # Scrapy stuff:
70
+ .scrapy
71
+
72
+ # Sphinx documentation
73
+ docs/_build/
74
+
75
+ # PyBuilder
76
+ .pybuilder/
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ # For a library or package, you might want to ignore these files since the code is
88
+ # intended to run in multiple environments; otherwise, check them in:
89
+ # .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # poetry
99
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
101
+ # commonly ignored for libraries.
102
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103
+ #poetry.lock
104
+
105
+ # pdm
106
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
107
+ #pdm.lock
108
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
109
+ # in version control.
110
+ # https://pdm.fming.dev/#use-with-ide
111
+ .pdm.toml
112
+
113
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
114
+ __pypackages__/
115
+
116
+ # Celery stuff
117
+ celerybeat-schedule
118
+ celerybeat.pid
119
+
120
+ # SageMath parsed files
121
+ *.sage.py
122
+
123
+ # Environments
124
+ .env
125
+ .venv
126
+ env/
127
+ venv/
128
+ ENV/
129
+ env.bak/
130
+ venv.bak/
131
+
132
+ # Spyder project settings
133
+ .spyderproject
134
+ .spyproject
135
+
136
+ # Rope project settings
137
+ .ropeproject
138
+
139
+ # mkdocs documentation
140
+ /site
141
+
142
+ # mypy
143
+ .mypy_cache/
144
+ .dmypy.json
145
+ dmypy.json
146
+
147
+ # Pyre type checker
148
+ .pyre/
149
+
150
+ # pytype static type analyzer
151
+ .pytype/
152
+
153
+ # Cython debug symbols
154
+ cython_debug/
155
+
156
+ # PyCharm
157
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
158
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
159
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
160
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
161
+ #.idea/
162
+
163
+ voice
Configs/config.yml ADDED
@@ -0,0 +1,116 @@
1
+ log_dir: "Models/LJSpeech"
2
+ first_stage_path: "first_stage.pth"
3
+ save_freq: 2
4
+ log_interval: 10
5
+ device: "cuda"
6
+ epochs_1st: 200 # number of epochs for first stage training (pre-training)
7
+ epochs_2nd: 100 # number of epochs for second stage training (joint training)
8
+ batch_size: 16
9
+ max_len: 400 # maximum number of frames
10
+ pretrained_model: ""
11
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
12
+ load_only_params: false # set to true if you do not want to load epoch numbers and optimizer parameters
13
+
14
+ F0_path: "Utils/JDC/bst.t7"
15
+ ASR_config: "Utils/ASR/config.yml"
16
+ ASR_path: "Utils/ASR/epoch_00080.pth"
17
+ PLBERT_dir: 'Utils/PLBERT/'
18
+
19
+ data_params:
20
+ train_data: "Data/train_list.txt"
21
+ val_data: "Data/val_list.txt"
22
+ root_path: "/local/LJSpeech-1.1/wavs"
23
+ OOD_data: "Data/OOD_texts.txt"
24
+ min_length: 50 # resample OOD texts until one of at least this length is obtained
25
+
26
+ preprocess_params:
27
+ sr: 24000
28
+ spect_params:
29
+ n_fft: 2048
30
+ win_length: 1200
31
+ hop_length: 300
32
+
33
+ model_params:
34
+ multispeaker: false
35
+
36
+ dim_in: 64
37
+ hidden_dim: 512
38
+ max_conv_dim: 512
39
+ n_layer: 3
40
+ n_mels: 80
41
+
42
+ n_token: 178 # number of phoneme tokens
43
+ max_dur: 50 # maximum duration of a single phoneme
44
+ style_dim: 128 # style vector size
45
+
46
+ dropout: 0.2
47
+
48
+ # config for decoder
49
+ decoder:
50
+ type: 'istftnet' # either hifigan or istftnet
51
+ resblock_kernel_sizes: [3,7,11]
52
+ upsample_rates : [10, 6]
53
+ upsample_initial_channel: 512
54
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
55
+ upsample_kernel_sizes: [20, 12]
56
+ gen_istft_n_fft: 20
57
+ gen_istft_hop_size: 5
58
+
59
+ # speech language model config
60
+ slm:
61
+ model: 'microsoft/wavlm-base-plus'
62
+ sr: 16000 # sampling rate of SLM
63
+ hidden: 768 # hidden size of SLM
64
+ nlayers: 13 # number of layers of SLM
65
+ initial_channel: 64 # initial channels of SLM discriminator head
66
+
67
+ # style diffusion model config
68
+ diffusion:
69
+ embedding_mask_proba: 0.1
70
+ # transformer config
71
+ transformer:
72
+ num_layers: 3
73
+ num_heads: 8
74
+ head_features: 64
75
+ multiplier: 2
76
+
77
+ # diffusion distribution config
78
+ dist:
79
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
80
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
81
+ mean: -3.0
82
+ std: 1.0
83
+
84
+ loss_params:
85
+ lambda_mel: 5. # mel reconstruction loss
86
+ lambda_gen: 1. # generator loss
87
+ lambda_slm: 1. # slm feature matching loss
88
+
89
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
90
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
91
+ TMA_epoch: 50 # TMA starting epoch (1st stage)
92
+
93
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
94
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
95
+ lambda_dur: 1. # duration loss (2nd stage)
96
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
97
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
98
+ lambda_diff: 1. # score matching loss (2nd stage)
99
+
100
+ diff_epoch: 20 # style diffusion starting epoch (2nd stage)
101
+ joint_epoch: 50 # joint training starting epoch (2nd stage)
102
+
103
+ optimizer_params:
104
+ lr: 0.0001 # general learning rate
105
+ bert_lr: 0.00001 # learning rate for PLBERT
106
+ ft_lr: 0.00001 # learning rate for acoustic modules
107
+
108
+ slmadv_params:
109
+ min_len: 400 # minimum length of samples
110
+ max_len: 500 # maximum length of samples
111
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
112
+ iter: 10 # update the discriminator every this many generator iterations
113
+ thresh: 5 # gradient norm above which the gradient is scaled
114
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
115
+ sig: 1.5 # sigma for differentiable duration modeling
116
+
Configs/config_ft.yml ADDED
@@ -0,0 +1,111 @@
1
+ log_dir: "Models/LJSpeech"
2
+ save_freq: 5
3
+ log_interval: 10
4
+ device: "cuda"
5
+ epochs: 50 # number of finetuning epochs (1 hour of data)
6
+ batch_size: 8
7
+ max_len: 400 # maximum number of frames
8
+ pretrained_model: "Models/LibriTTS/epochs_2nd_00020.pth"
9
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
10
+ load_only_params: true # set to true if you do not want to load epoch numbers and optimizer parameters
11
+
12
+ F0_path: "Utils/JDC/bst.t7"
13
+ ASR_config: "Utils/ASR/config.yml"
14
+ ASR_path: "Utils/ASR/epoch_00080.pth"
15
+ PLBERT_dir: 'Utils/PLBERT/'
16
+
17
+ data_params:
18
+ train_data: "Data/train_list.txt"
19
+ val_data: "Data/val_list.txt"
20
+ root_path: "/local/LJSpeech-1.1/wavs"
21
+ OOD_data: "Data/OOD_texts.txt"
22
+ min_length: 50 # resample OOD texts until one of at least this length is obtained
23
+
24
+ preprocess_params:
25
+ sr: 24000
26
+ spect_params:
27
+ n_fft: 2048
28
+ win_length: 1200
29
+ hop_length: 300
30
+
31
+ model_params:
32
+ multispeaker: true
33
+
34
+ dim_in: 64
35
+ hidden_dim: 512
36
+ max_conv_dim: 512
37
+ n_layer: 3
38
+ n_mels: 80
39
+
40
+ n_token: 178 # number of phoneme tokens
41
+ max_dur: 50 # maximum duration of a single phoneme
42
+ style_dim: 128 # style vector size
43
+
44
+ dropout: 0.2
45
+
46
+ # config for decoder
47
+ decoder:
48
+ type: 'hifigan' # either hifigan or istftnet
49
+ resblock_kernel_sizes: [3,7,11]
50
+ upsample_rates : [10,5,3,2]
51
+ upsample_initial_channel: 512
52
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
53
+ upsample_kernel_sizes: [20,10,6,4]
54
+
55
+ # speech language model config
56
+ slm:
57
+ model: 'microsoft/wavlm-base-plus'
58
+ sr: 16000 # sampling rate of SLM
59
+ hidden: 768 # hidden size of SLM
60
+ nlayers: 13 # number of layers of SLM
61
+ initial_channel: 64 # initial channels of SLM discriminator head
62
+
63
+ # style diffusion model config
64
+ diffusion:
65
+ embedding_mask_proba: 0.1
66
+ # transformer config
67
+ transformer:
68
+ num_layers: 3
69
+ num_heads: 8
70
+ head_features: 64
71
+ multiplier: 2
72
+
73
+ # diffusion distribution config
74
+ dist:
75
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
76
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
77
+ mean: -3.0
78
+ std: 1.0
79
+
80
+ loss_params:
81
+ lambda_mel: 5. # mel reconstruction loss
82
+ lambda_gen: 1. # generator loss
83
+ lambda_slm: 1. # slm feature matching loss
84
+
85
+ lambda_mono: 1. # monotonic alignment loss (TMA)
86
+ lambda_s2s: 1. # sequence-to-sequence loss (TMA)
87
+
88
+ lambda_F0: 1. # F0 reconstruction loss
89
+ lambda_norm: 1. # norm reconstruction loss
90
+ lambda_dur: 1. # duration loss
91
+ lambda_ce: 20. # duration predictor probability output CE loss
92
+ lambda_sty: 1. # style reconstruction loss
93
+ lambda_diff: 1. # score matching loss
94
+
95
+ diff_epoch: 10 # style diffusion starting epoch
96
+ joint_epoch: 30 # joint training starting epoch
97
+
98
+ optimizer_params:
99
+ lr: 0.0001 # general learning rate
100
+ bert_lr: 0.00001 # learning rate for PLBERT
101
+ ft_lr: 0.0001 # learning rate for acoustic modules
102
+
103
+ slmadv_params:
104
+ min_len: 400 # minimum length of samples
105
+ max_len: 500 # maximum length of samples
106
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
107
+ iter: 10 # update the discriminator every this many generator iterations
108
+ thresh: 5 # gradient norm above which the gradient is scaled
109
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
110
+ sig: 1.5 # sigma for differentiable duration modeling
111
+
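As a rough worked example of what max_len means in this config (assuming, as the preprocess_params suggest, that one frame spans hop_length audio samples): 400 frames at hop_length 300 and sr 24000 is 5 seconds of audio per training sample.

# Illustrative arithmetic only; assumes max_len counts spectrogram frames spaced hop_length samples apart.
sr = 24000
hop_length = 300
max_len = 400
seconds = max_len * hop_length / sr
print(seconds)  # 5.0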
Configs/config_libritts.yml ADDED
@@ -0,0 +1,113 @@
1
+ log_dir: "Models/LibriTTS"
2
+ first_stage_path: "first_stage.pth"
3
+ save_freq: 1
4
+ log_interval: 10
5
+ device: "cuda"
6
+ epochs_1st: 50 # number of epochs for first stage training (pre-training)
7
+ epochs_2nd: 30 # number of epochs for second stage training (joint training)
8
+ batch_size: 16
9
+ max_len: 300 # maximum number of frames
10
+ pretrained_model: ""
11
+ second_stage_load_pretrained: true # set to true if the pre-trained model is for 2nd stage
12
+ load_only_params: false # set to true if you do not want to load epoch numbers and optimizer parameters
13
+
14
+ F0_path: "Utils/JDC/bst.t7"
15
+ ASR_config: "Utils/ASR/config.yml"
16
+ ASR_path: "Utils/ASR/epoch_00080.pth"
17
+ PLBERT_dir: 'Utils/PLBERT/'
18
+
19
+ data_params:
20
+ train_data: "Data/train_list.txt"
21
+ val_data: "Data/val_list.txt"
22
+ root_path: ""
23
+ OOD_data: "Data/OOD_texts.txt"
24
+ min_length: 50 # resample OOD texts until one of at least this length is obtained
25
+
26
+ preprocess_params:
27
+ sr: 24000
28
+ spect_params:
29
+ n_fft: 2048
30
+ win_length: 1200
31
+ hop_length: 300
32
+
33
+ model_params:
34
+ multispeaker: true
35
+
36
+ dim_in: 64
37
+ hidden_dim: 512
38
+ max_conv_dim: 512
39
+ n_layer: 3
40
+ n_mels: 80
41
+
42
+ n_token: 178 # number of phoneme tokens
43
+ max_dur: 50 # maximum duration of a single phoneme
44
+ style_dim: 128 # style vector size
45
+
46
+ dropout: 0.2
47
+
48
+ # config for decoder
49
+ decoder:
50
+ type: 'hifigan' # either hifigan or istftnet
51
+ resblock_kernel_sizes: [3,7,11]
52
+ upsample_rates : [10,5,3,2]
53
+ upsample_initial_channel: 512
54
+ resblock_dilation_sizes: [[1,3,5], [1,3,5], [1,3,5]]
55
+ upsample_kernel_sizes: [20,10,6,4]
56
+
57
+ # speech language model config
58
+ slm:
59
+ model: 'microsoft/wavlm-base-plus'
60
+ sr: 16000 # sampling rate of SLM
61
+ hidden: 768 # hidden size of SLM
62
+ nlayers: 13 # number of layers of SLM
63
+ initial_channel: 64 # initial channels of SLM discriminator head
64
+
65
+ # style diffusion model config
66
+ diffusion:
67
+ embedding_mask_proba: 0.1
68
+ # transformer config
69
+ transformer:
70
+ num_layers: 3
71
+ num_heads: 8
72
+ head_features: 64
73
+ multiplier: 2
74
+
75
+ # diffusion distribution config
76
+ dist:
77
+ sigma_data: 0.2 # placeholder for estimate_sigma_data set to false
78
+ estimate_sigma_data: true # estimate sigma_data from the current batch if set to true
79
+ mean: -3.0
80
+ std: 1.0
81
+
82
+ loss_params:
83
+ lambda_mel: 5. # mel reconstruction loss
84
+ lambda_gen: 1. # generator loss
85
+ lambda_slm: 1. # slm feature matching loss
86
+
87
+ lambda_mono: 1. # monotonic alignment loss (1st stage, TMA)
88
+ lambda_s2s: 1. # sequence-to-sequence loss (1st stage, TMA)
89
+ TMA_epoch: 5 # TMA starting epoch (1st stage)
90
+
91
+ lambda_F0: 1. # F0 reconstruction loss (2nd stage)
92
+ lambda_norm: 1. # norm reconstruction loss (2nd stage)
93
+ lambda_dur: 1. # duration loss (2nd stage)
94
+ lambda_ce: 20. # duration predictor probability output CE loss (2nd stage)
95
+ lambda_sty: 1. # style reconstruction loss (2nd stage)
96
+ lambda_diff: 1. # score matching loss (2nd stage)
97
+
98
+ diff_epoch: 10 # style diffusion starting epoch (2nd stage)
99
+ joint_epoch: 15 # joint training starting epoch (2nd stage)
100
+
101
+ optimizer_params:
102
+ lr: 0.0001 # general learning rate
103
+ bert_lr: 0.00001 # learning rate for PLBERT
104
+ ft_lr: 0.00001 # learning rate for acoustic modules
105
+
106
+ slmadv_params:
107
+ min_len: 400 # minimum length of samples
108
+ max_len: 500 # maximum length of samples
109
+ batch_percentage: 0.5 # to prevent out of memory, only use half of the original batch size
110
+ iter: 20 # update the discriminator every this many generator iterations
111
+ thresh: 5 # gradient norm above which the gradient is scaled
112
+ scale: 0.01 # gradient scaling factor for predictors from SLM discriminators
113
+ sig: 1.5 # sigma for differentiable duration modeling
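One property worth noting across these decoder configs (an observation, not something stated in the files): the product of upsample_rates, times gen_istft_hop_size for the istftnet variant, equals the spectrogram hop_length of 300, so the decoder's output length lines up with the input mel frames at 24 kHz. A small illustrative check:

from math import prod

hop_length = 300
hifigan_rates = [10, 5, 3, 2]         # config_ft.yml / config_libritts.yml
istft_rates, istft_hop = [10, 6], 5   # config.yml (istftnet decoder)

assert prod(hifigan_rates) == hop_length
assert prod(istft_rates) * istft_hop == hop_length
print("upsampling factors match hop_length")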
Data/OOD_texts.txt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0989ef6a9873b711befefcbe60660ced7a65532359277f766f4db504c558a72
3
+ size 31758898
Data/train_list.txt ADDED
The diff for this file is too large to render. See raw diff
 
Data/val_list.txt ADDED
@@ -0,0 +1,100 @@
1
+ LJ022-0023.wav|ðɪ ˌoʊvɚwˈɛlmɪŋ mədʒˈɔːɹᵻɾi ʌv pˈiːpəl ɪn ðɪs kˈʌntɹi nˈoʊ hˌaʊ tə sˈɪft ðə wˈiːt fɹʌmðə tʃˈæf ɪn wʌt ðeɪ hˈɪɹ ænd wʌt ðeɪ ɹˈiːd .|0
2
+ LJ043-0030.wav|ɪf sˈʌmbɑːdi dˈɪd ðˈæt tə mˌiː , ɐ lˈaʊsi tɹˈɪk lˈaɪk ðˈæt , tə tˈeɪk maɪ wˈaɪf ɐwˈeɪ , ænd ˈɔːl ðə fˈɜːnɪtʃɚ , aɪ wʊd biː mˈæd æz hˈɛl , tˈuː .|0
3
+ LJ005-0201.wav|ˌæzˌɪz ʃˈoʊn baɪ ðə ɹᵻpˈoːɹt ʌvðə kəmˈɪʃənɚz tʊ ɪŋkwˈaɪɚɹ ˌɪntʊ ðə stˈeɪt ʌvðə mjuːnˈɪsɪpəl kˌɔːɹpɚɹˈeɪʃənz ɪn ˈeɪtiːn θˈɜːɾi fˈaɪv .|0
4
+ LJ001-0110.wav|ˈiːvən ðə kˈæslɑːn tˈaɪp wɛn ɛnlˈɑːɹdʒd ʃˈoʊz ɡɹˈeɪt ʃˈɔːɹtkʌmɪŋz ɪn ðɪs ɹᵻspˈɛkt :|0
5
+ LJ003-0345.wav|ˈɔːl ðə kəmˈɪɾi kʊd dˈuː ɪn ðɪs ɹᵻspˈɛkt wʌz tə θɹˈoʊ ðə ɹᵻspˌɑːnsəbˈɪlɪɾi ˌɔn ˈʌðɚz .|0
6
+ LJ007-0154.wav|ðiːz pˈʌndʒənt ænd wˈɛl ɡɹˈaʊndᵻd stɹˈɪktʃɚz ɐplˈaɪd wɪð stˈɪl ɡɹˈeɪɾɚ fˈoːɹs tə ðɪ ʌŋkənvˈɪktᵻd pɹˈɪzənɚ , ðə mˈæn hˌuː kˈeɪm tə ðə pɹˈɪzən ˈɪnəsənt , ænd stˈɪl ʌŋkəntˈæmᵻnˌeɪɾᵻd ,|0
7
+ LJ018-0098.wav|ænd ɹˈɛkəɡnˌaɪzd æz wˈʌn ʌvðə fɹˈiːkwɛntɚz ʌvðə bˈoʊɡəs lˈɔː stˈeɪʃənɚz . hɪz ɚɹˈɛst lˈɛd tə ðæt ʌv ˈʌðɚz .|0
8
+ LJ047-0044.wav|ˈɑːswəld wʌz , haʊˈɛvɚ , wˈɪlɪŋ tə dɪskˈʌs hɪz kˈɑːntækts wɪð sˈoʊviət ɐθˈɔːɹɪɾiz . hiː dᵻnˈaɪd hˌævɪŋ ˌɛni ɪnvˈɑːlvmənt wɪð sˈoʊviət ɪntˈɛlɪdʒəns ˈeɪdʒənsiz|0
9
+ LJ031-0038.wav|ðə fˈɜːst fɪzˈɪʃən tə sˈiː ðə pɹˈɛzɪdənt æt pˈɑːɹklənd hˈɑːspɪɾəl wʌz dˈɑːktɚ . tʃˈɑːɹlz dʒˈeɪ . kˈæɹɪkˌoʊ , ɐ ɹˈɛzᵻdənt ɪn dʒˈɛnɚɹəl sˈɜːdʒɚɹi .|0
10
+ LJ048-0194.wav|dˈʊɹɹɪŋ ðə mˈɔːɹnɪŋ ʌv noʊvˈɛmbɚ twˈɛnti tˈuː pɹˈaɪɚ tə ðə mˈoʊɾɚkˌeɪd .|0
11
+ LJ049-0026.wav|ˌɔn əkˈeɪʒən ðə sˈiːkɹᵻt sˈɜːvɪs hɐzbɪn pɚmˈɪɾᵻd tə hæv ɐn ˈeɪdʒənt ɹˈaɪdɪŋ ɪnðə pˈæsɪndʒɚ kəmpˈɑːɹtmənt wɪððə pɹˈɛzɪdənt .|0
12
+ LJ004-0152.wav|ɔːlðˈoʊ æt mˈɪstɚ . bˈʌkstənz vˈɪzɪt ɐ nˈuː dʒˈeɪl wʌz ɪn pɹˈɑːsɛs ʌv ɪɹˈɛkʃən , ðə fˈɜːst stˈɛp təwˈɔːɹdz ɹᵻfˈɔːɹm sˈɪns hˈaʊɚdz vˌɪzɪtˈeɪʃən ɪn sˈɛvəntˌiːn sˈɛvənti fˈoːɹ .|0
13
+ LJ008-0278.wav|ɔːɹ ðˈɛɹz mˌaɪt biː wˈʌn ʌv mˈɛni , ænd ɪt mˌaɪt biː kənsˈɪdɚd nˈɛsᵻsɚɹi tə dˈɑːlɚ mˌeɪk ɐn ɛɡzˈæmpəl.dˈɑːlɚ|0
14
+ LJ043-0002.wav|ðə wˈɔːɹəŋ kəmˈɪʃən ɹᵻpˈoːɹt . baɪ ðə pɹˈɛzɪdənts kəmˈɪʃən ɔnðɪ ɐsˌæsᵻnˈeɪʃən ʌv pɹˈɛzɪdənt kˈɛnədi . tʃˈæptɚ sˈɛvən . lˈiː hˈɑːɹvi ˈɑːswəld :|0
15
+ LJ009-0114.wav|mˈɪstɚ . wˈeɪkfiːld wˈaɪndz ˈʌp hɪz ɡɹˈæfɪk bˌʌt sˈʌmwʌt sɛnsˈeɪʃənəl ɐkˈaʊnt baɪ dᵻskɹˈaɪbɪŋ ɐnˈʌðɚ ɹᵻlˈɪdʒəs sˈɜːvɪs , wˌɪtʃ mˈeɪ ɐpɹˈoʊpɹɪˌeɪtli biː ɪnsˈɜːɾᵻd hˈɪɹ .|0
16
+ LJ028-0506.wav|ɐ mˈɑːdɚn ˈɑːɹɾɪst wʊdhɐv dˈɪfɪkˌʌlti ɪn dˌuːɪŋ sˈʌtʃ ˈækjʊɹət wˈɜːk .|0
17
+ LJ050-0168.wav|wɪððə pɚtˈɪkjʊlɚ pˈɜːpəsᵻz ʌvðɪ ˈeɪdʒənsi ɪnvˈɑːlvd . ðə kəmˈɪʃən ɹˈɛkəɡnˌaɪzᵻz ðæt ðɪs ɪz ɐ kˌɑːntɹəvˈɜːʃəl ˈɛɹiə|0
18
+ LJ039-0223.wav|ˈɑːswəldz mɚɹˈiːn tɹˈeɪnɪŋ ɪn mˈɑːɹksmənʃˌɪp , hɪz ˈʌðɚ ɹˈaɪfəl ɛkspˈiəɹɪəns ænd hɪz ɪstˈæblɪʃt fəmˌɪliˈæɹɪɾi wɪð ðɪs pɚtˈɪkjʊlɚ wˈɛpən|0
19
+ LJ029-0032.wav|ɐkˈoːɹdɪŋ tʊ oʊdˈɑːnəl , kwˈoʊt , wiː hæd ɐ mˈoʊɾɚkˌeɪd wɛɹˈɛvɚ kplˈʌsplʌs wˌɪtʃ hɐdbɪn bˌɪn hˈeɪstili sˈʌmənd fɚðə ðə pˈɜːpəs wiː wˈɛnt , ˈɛnd kwˈoʊt .|0
20
+ LJ031-0070.wav|dˈɑːktɚ . klˈɑːɹk , hˌuː mˈoʊst klˈoʊsli əbzˈɜːvd ðə hˈɛd wˈuːnd ,|0
21
+ LJ034-0198.wav|jˈuːɪnz , hˌuː wʌz ɔnðə saʊθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən stɹˈiːts tˈɛstᵻfˌaɪd ðæt hiː kʊd nˌɑːt dᵻskɹˈaɪb ðə mˈæn hiː sˈɔː ɪnðə wˈɪndoʊ .|0
22
+ LJ026-0068.wav|ˈɛnɚdʒi ˈɛntɚz ðə plˈænt , tʊ ɐ smˈɔːl ɛkstˈɛnt ,|0
23
+ LJ039-0075.wav|wˈʌns juː nˈoʊ ðæt juː mˈʌst pˌʊt ðə kɹˈɔshɛɹz ɔnðə tˈɑːɹɡɪt ænd ðæt ɪz ˈɔːl ðæt ɪz nˈɛsᵻsɚɹi .|0
24
+ LJ004-0096.wav|ðə fˈeɪɾəl kˈɑːnsɪkwənsᵻz wˈɛɹɑːf mˌaɪt biː pɹɪvˈɛntᵻd ɪf ðə dʒˈʌstɪsᵻz ʌvðə pˈiːs wɜː djˈuːli ˈɔːθɚɹˌaɪzd|0
25
+ LJ005-0014.wav|spˈiːkɪŋ ˌɔn ɐ dᵻbˈeɪt ˌɔn pɹˈɪzən mˈæɾɚz , hiː dᵻklˈɛɹd ðˈæt|0
26
+ LJ012-0161.wav|hiː wʌz ɹᵻpˈoːɹɾᵻd tə hæv fˈɔːlən ɐwˈeɪ tʊ ɐ ʃˈædoʊ .|0
27
+ LJ018-0239.wav|hɪz dˌɪsɐpˈɪɹəns ɡˈeɪv kˈʌlɚ ænd sˈʌbstəns tʊ ˈiːvəl ɹᵻpˈoːɹts ɔːlɹˌɛdi ɪn sˌɜːkjʊlˈeɪʃən ðætðə wɪl ænd kənvˈeɪəns əbˌʌv ɹᵻfˈɜːd tuː|0
28
+ LJ019-0257.wav|hˈɪɹ ðə tɹˈɛd wˈiːl wʌz ɪn jˈuːs , ðɛɹ sˈɛljʊlɚ kɹˈæŋks , ɔːɹ hˈɑːɹd lˈeɪbɚ məʃˈiːnz .|0
29
+ LJ028-0008.wav|juː tˈæp dʒˈɛntli wɪð jʊɹ hˈiːl əpˌɑːn ðə ʃˈoʊldɚɹ ʌvðə dɹˈoʊmdɚɹi tʊ ˈɜːdʒ hɜːɹ ˈɔn .|0
30
+ LJ024-0083.wav|ðɪs plˈæn ʌv mˈaɪn ɪz nˈoʊ ɐtˈæk ɔnðə kˈoːɹt ;|0
31
+ LJ042-0129.wav|nˈoʊ nˈaɪt klˈʌbz ɔːɹ bˈoʊlɪŋ ˈælɪz , nˈoʊ plˈeɪsᵻz ʌv ɹˌɛkɹiːˈeɪʃən ɛksˈɛpt ðə tɹˈeɪd jˈuːniən dˈænsᵻz . aɪ hæv hæd ɪnˈʌf .|0
32
+ LJ036-0103.wav|ðə pəlˈiːs ˈæskt hˌɪm wˈɛðɚ hiː kʊd pˈɪk ˈaʊt hɪz pˈæsɪndʒɚ fɹʌmðə lˈaɪnʌp .|0
33
+ LJ046-0058.wav|dˈʊɹɹɪŋ hɪz pɹˈɛzɪdənsi , fɹˈæŋklɪn dˈiː . ɹˈoʊzəvˌɛlt mˌeɪd ˈɔːlmoʊst fˈoːɹ hˈʌndɹɪd dʒˈɜːniz ænd tɹˈævəld mˈoːɹ ðɐn θɹˈiː hˈʌndɹɪd fˈɪfti θˈaʊzənd mˈaɪlz .|0
34
+ LJ014-0076.wav|hiː wʌz sˈiːn ˈæftɚwɚdz smˈoʊkɪŋ ænd tˈɔːkɪŋ wɪð hɪz hˈoʊsts ɪn ðɛɹ bˈæk pˈɑːɹlɚ , ænd nˈɛvɚ sˈiːn ɐɡˈɛn ɐlˈaɪv .|0
35
+ LJ002-0043.wav|lˈɔŋ nˈæɹoʊ ɹˈuːmz wˈʌn θˈɜːɾi sˈɪks fˈiːt , sˈɪks twˈɛnti θɹˈiː fˈiːt , ænd ðɪ ˈeɪtθ ˈeɪtiːn ,|0
36
+ LJ009-0076.wav|wiː kˈʌm tə ðə sˈɜːmən .|0
37
+ LJ017-0131.wav|ˈiːvən wɛn ðə hˈaɪ ʃˈɛɹɪf hæd tˈoʊld hˌɪm ðɛɹwˌʌz nˈoʊ pˌɑːsəbˈɪlɪɾi əvɚ ɹᵻpɹˈiːv , ænd wɪðˌɪn ɐ fjˈuː ˈaʊɚz ʌv ˌɛksɪkjˈuːʃən .|0
38
+ LJ046-0184.wav|bˌʌt ðɛɹ ɪz ɐ sˈɪstəm fɚðɪ ɪmˈiːdɪət nˌoʊɾɪfɪkˈeɪʃən ʌvðə sˈiːkɹᵻt sˈɜːvɪs baɪ ðə kənfˈaɪnɪŋ ˌɪnstɪtˈuːʃən wɛn ɐ sˈʌbdʒɛkt ɪz ɹᵻlˈiːst ɔːɹ ɛskˈeɪps .|0
39
+ LJ014-0263.wav|wˌɛn ˈʌðɚ plˈɛʒɚz pˈɔːld hiː tˈʊk ɐ θˈiəɾɚ , ænd pˈoʊzd æz ɐ mjuːnˈɪfɪsənt pˈeɪtɹən ʌvðə dɹəmˈæɾɪk ˈɑːɹt .|0
40
+ LJ042-0096.wav|ˈoʊld ɛkstʃˈeɪndʒ ɹˈeɪt ɪn ɐdˈɪʃən tə hɪz fˈæktɚɹi sˈælɚɹi ʌv ɐpɹˈɑːksɪmətli ˈiːkwəl ɐmˈaʊnt|0
41
+ LJ049-0050.wav|hˈɪl hæd bˈoʊθ fˈiːt ɔnðə kˈɑːɹ ænd wʌz klˈaɪmɪŋ ɐbˈoːɹd tʊ ɐsˈɪst pɹˈɛzɪdənt ænd mˈɪsɪz . kˈɛnədi .|0
42
+ LJ019-0186.wav|sˈiːɪŋ ðæt sˈɪns ðɪ ɪstˈæblɪʃmənt ʌvðə sˈɛntɹəl kɹˈɪmɪnəl kˈoːɹt , nˈuːɡeɪt ɹᵻsˈiːvd pɹˈɪzənɚz fɔːɹ tɹˈaɪəl fɹʌm sˈɛvɹəl kˈaʊntiz ,|0
43
+ LJ028-0307.wav|ðˈɛn lˈɛt twˈɛnti dˈeɪz pˈæs , ænd æt ðɪ ˈɛnd ʌv ðæt tˈaɪm stˈeɪʃən nˌɪɹ ðə tʃˈældæsəŋ ɡˈeɪts ɐ bˈɑːdi ʌv fˈoːɹ θˈaʊzənd .|0
44
+ LJ012-0235.wav|wˌaɪl ðeɪ wɜːɹ ɪn ɐ stˈeɪt ʌv ɪnsˌɛnsəbˈɪlɪɾi ðə mˈɜːdɚ wʌz kəmˈɪɾᵻd .|0
45
+ LJ034-0053.wav|ɹˈiːtʃt ðə sˈeɪm kəŋklˈuːʒən æz lætˈoʊnə ðætðə pɹˈɪnts fˈaʊnd ɔnðə kˈɑːɹtənz wɜː ðoʊz ʌv lˈiː hˈɑːɹvi ˈɑːswəld .|0
46
+ LJ014-0030.wav|ðiːz wɜː dˈæmnətˌoːɹi fˈækts wˌɪtʃ wˈɛl səpˈoːɹɾᵻd ðə pɹˌɑːsɪkjˈuːʃən .|0
47
+ LJ015-0203.wav|bˌʌt wɜː ðə pɹɪkˈɔːʃənz tˈuː mˈɪnɪt , ðə vˈɪdʒɪləns tˈuː klˈoʊs təbi ᵻlˈuːdᵻd ɔːɹ ˌoʊvɚkˈʌm ?|0
48
+ LJ028-0093.wav|bˌʌt hɪz skɹˈaɪb ɹˈoʊt ɪɾ ɪnðə mˈænɚ kˈʌstəmˌɛɹi fɚðə skɹˈaɪbz ʌv ðoʊz dˈeɪz tə ɹˈaɪt ʌv ðɛɹ ɹˈɔɪəl mˈæstɚz .|0
49
+ LJ002-0018.wav|ðɪ ɪnˈædɪkwəsi ʌvðə dʒˈeɪl wʌz nˈoʊɾɪst ænd ɹᵻpˈoːɹɾᵻd əpˌɑːn ɐɡˈɛn ænd ɐɡˈɛn baɪ ðə ɡɹˈænd dʒˈʊɹɹiz ʌvðə sˈɪɾi ʌv lˈʌndən ,|0
50
+ LJ028-0275.wav|æt lˈæst , ɪnðə twˈɛntiəθ mˈʌnθ ,|0
51
+ LJ012-0042.wav|wˌɪtʃ hiː kˈɛpt kənsˈiːld ɪn ɐ hˈaɪdɪŋ plˈeɪs wɪð ɐ tɹˈæp dˈoːɹ dʒˈʌst ˌʌndɚ hɪz bˈɛd .|0
52
+ LJ011-0096.wav|hiː mˈæɹid ɐ lˈeɪdi ˈɔːlsoʊ bᵻlˈɔŋɪŋ tə ðə səsˈaɪəɾi ʌv fɹˈɛndz , hˌuː bɹˈɔːt hˌɪm ɐ lˈɑːɹdʒ fˈɔːɹtʃʊn , wˈɪtʃ , ænd hɪz ˈoʊn mˈʌni , hiː pˌʊt ˌɪntʊ ɐ sˈɪɾi fˈɜːm ,|0
53
+ LJ036-0077.wav|ɹˈɑːdʒɚ dˈiː . kɹˈeɪɡ , ɐ dˈɛpjuːɾi ʃˈɛɹɪf ʌv dˈæləs kˈaʊnti ,|0
54
+ LJ016-0318.wav|ˈʌðɚɹ əfˈɪʃəlz , ɡɹˈeɪt lˈɔɪɚz , ɡˈʌvɚnɚz ʌv pɹˈɪzənz , ænd tʃˈæplɪnz səpˈoːɹɾᵻd ðɪs vjˈuː .|0
55
+ LJ013-0164.wav|hˌuː kˈeɪm fɹʌm hɪz ɹˈuːm ɹˈɛdi dɹˈɛst , ɐ səspˈɪʃəs sˈɜːkəmstˌæns , æz hiː wʌz ˈɔːlweɪz lˈeɪt ɪnðə mˈɔːɹnɪŋ .|0
56
+ LJ027-0141.wav|ɪz klˈoʊsli ɹᵻpɹədˈuːst ɪnðə lˈaɪf hˈɪstɚɹi ʌv ɛɡzˈɪstɪŋ dˈɪɹ . ɔːɹ , ɪn ˈʌðɚ wˈɜːdz ,|0
57
+ LJ028-0335.wav|ɐkˈoːɹdɪŋli ðeɪ kəmˈɪɾᵻd tə hˌɪm ðə kəmˈænd ʌv ðɛɹ hˈoʊl ˈɑːɹmi , ænd pˌʊt ðə kˈiːz ʌv ðɛɹ sˈɪɾi ˌɪntʊ hɪz hˈændz .|0
58
+ LJ031-0202.wav|mˈɪsɪz . kˈɛnədi tʃˈoʊz ðə hˈɑːspɪɾəl ɪn bəθˈɛzdə fɚðɪ ˈɔːtɑːpsi bɪkˈʌz ðə pɹˈɛzɪdənt hæd sˈɜːvd ɪnðə nˈeɪvi .|0
59
+ LJ021-0145.wav|fɹʌm ðoʊz wˈɪlɪŋ tə dʒˈɔɪn ɪn ɪstˈæblɪʃɪŋ ðɪs hˈoʊpt fɔːɹ pˈiəɹɪəd ʌv pˈiːs ,|0
60
+ LJ016-0288.wav|dˈɑːlɚ mˈuːlɚ , mˈuːlɚ , hiːz ðə mˈæn , dˈɑːlɚ tˈɪl ɐ daɪvˈɜːʒən wʌz kɹiːˈeɪɾᵻd baɪ ðɪ ɐpˈɪɹəns ʌvðə ɡˈæloʊz , wˌɪtʃ wʌz ɹᵻsˈiːvd wɪð kəntˈɪnjuːəs jˈɛlz .|0
61
+ LJ028-0081.wav|jˈɪɹz lˈeɪɾɚ , wˌɛn ðɪ ˌɑːɹkiːˈɑːlədʒˌɪsts kʊd ɹˈɛdili dɪstˈɪŋɡwɪʃ ðə fˈɔls fɹʌmðə tɹˈuː ,|0
62
+ LJ018-0081.wav|hɪz dᵻfˈɛns bˌiːɪŋ ðæt hiː hæd ɪntˈɛndᵻd tə kəmˈɪt sˈuːɪsˌaɪd , bˌʌt ðˈæt , ɔnðɪ ɐpˈɪɹəns ʌv ðɪs ˈɑːfɪsɚ hˌuː hæd ɹˈɔŋd hˌɪm ,|0
63
+ LJ021-0066.wav|təɡˌɛðɚ wɪð ɐ ɡɹˈeɪt ˈɪŋkɹiːs ɪnðə pˈeɪɹoʊlz , ðɛɹ hɐz kˈʌm ɐ səbstˈænʃəl ɹˈaɪz ɪnðə tˈoʊɾəl ʌv ɪndˈʌstɹɪəl pɹˈɑːfɪts|0
64
+ LJ009-0238.wav|ˈæftɚ ðɪs ðə ʃˈɛɹɪfs sˈɛnt fɔːɹ ɐnˈʌðɚ ɹˈoʊp , bˌʌt ðə spɛktˈeɪɾɚz ˌɪntəfˈɪɹd , ænd ðə mˈæn wʌz kˈæɹid bˈæk tə dʒˈeɪl .|0
65
+ LJ005-0079.wav|ænd ɪmpɹˈuːv ðə mˈɔːɹəlz ʌvðə pɹˈɪzənɚz , ænd ʃˌæl ɪnʃˈʊɹ ðə pɹˈɑːpɚ mˈɛʒɚɹ ʌv pˈʌnɪʃmənt tə kənvˈɪktᵻd əfˈɛndɚz .|0
66
+ LJ035-0019.wav|dɹˈoʊv tə ðə nɔːɹθwˈɛst kˈɔːɹnɚɹ ʌv ˈɛlm ænd hjˈuːstən , ænd pˈɑːɹkt ɐpɹˈɑːksɪmətli tˈɛn fˈiːt fɹʌmðə tɹˈæfɪk sˈɪɡnəl .|0
67
+ LJ036-0174.wav|ðɪs ɪz ðɪ ɐpɹˈɑːksɪmət tˈaɪm hiː ˈɛntɚd ðə ɹˈuːmɪŋhˌaʊs , ɐkˈoːɹdɪŋ tʊ ˈɜːliːn ɹˈɑːbɚts , ðə hˈaʊskiːpɚ ðˈɛɹ .|0
68
+ LJ046-0146.wav|ðə kɹaɪtˈiəɹɪə ɪn ɪfˈɛkt pɹˈaɪɚ tə noʊvˈɛmbɚ twˈɛnti tˈuː , nˈaɪntiːn sˈɪksti θɹˈiː , fɔːɹ dɪtˈɜːmɪnɪŋ wˈɛðɚ tʊ ɐksˈɛpt mətˈɪɹiəl fɚðə pˌiːˌɑːɹɹˈɛs dʒˈɛnɚɹəl fˈaɪlz|0
69
+ LJ017-0044.wav|ænd ðə dˈiːpɪst æŋzˈaɪəɾi wʌz fˈɛlt ðætðə kɹˈaɪm , ɪf kɹˈaɪm ðˈɛɹ hɐdbɪn , ʃˌʊd biː bɹˈɔːt hˈoʊm tʊ ɪts pˈɜːpɪtɹˌeɪɾɚ .|0
70
+ LJ017-0070.wav|bˌʌt hɪz spˈoːɹɾɪŋ ˌɑːpɚɹˈeɪʃənz dɪdnˌɑːt pɹˈɑːspɚ , ænd hiː bɪkˌeɪm ɐ nˈiːdi mˈæn , ˈɔːlweɪz dɹˈɪvən tə dˈɛspɚɹət stɹˈeɪts fɔːɹ kˈæʃ .|0
71
+ LJ014-0020.wav|hiː wʌz sˈuːn ˈæftɚwɚdz ɚɹˈɛstᵻd ˌɔn səspˈɪʃən , ænd ɐ sˈɜːtʃ ʌv hɪz lˈɑːdʒɪŋz bɹˈɔːt tə lˈaɪt sˈɛvɹəl ɡˈɑːɹmənts sˈætʃɚɹˌeɪɾᵻd wɪð blˈʌd ;|0
72
+ LJ016-0020.wav|hiː nˈɛvɚ ɹˈiːtʃt ðə sˈɪstɚn , bˌʌt fˈɛl bˈæk ˌɪntʊ ðə jˈɑːɹd , ˈɪndʒɚɹɪŋ hɪz lˈɛɡz sᵻvˈɪɹli .|0
73
+ LJ045-0230.wav|wˌɛn hiː wʌz fˈaɪnəli ˌæpɹihˈɛndᵻd ɪnðə tˈɛksəs θˈiəɾɚ . ɔːlðˈoʊ ɪɾ ɪz nˌɑːt fˈʊli kɚɹˈɑːbɚɹˌeɪɾᵻd baɪ ˈʌðɚz hˌuː wɜː pɹˈɛzənt ,|0
74
+ LJ035-0129.wav|ænd ʃiː mˈʌstɐv ɹˈʌn dˌaʊn ðə stˈɛɹz ɐhˈɛd ʌv ˈɑːswəld ænd wʊd pɹˈɑːbəbli hæv sˈiːn ɔːɹ hˈɜːd hˌɪm .|0
75
+ LJ008-0307.wav|ˈæftɚwɚdz ɛkspɹˈɛs ɐ wˈɪʃ tə mˈɜːdɚ ðə ɹᵻkˈoːɹdɚ fɔːɹ hˌævɪŋ kˈɛpt ðˌɛm sˌoʊ lˈɔŋ ɪn səspˈɛns .|0
76
+ LJ008-0294.wav|nˌɪɹli ɪndˈɛfɪnətli dᵻfˈɜːd .|0
77
+ LJ047-0148.wav|ˌɔn ɑːktˈoʊbɚ twˈɛnti fˈaɪv ,|0
78
+ LJ008-0111.wav|ðeɪ ˈɛntɚd ɐ dˈɑːlɚ stˈoʊŋ kˈoʊld ɹˈuːm , dˈɑːlɚɹ ænd wɜː pɹˈɛzəntli dʒˈɔɪnd baɪ ðə pɹˈɪzənɚ .|0
79
+ LJ034-0042.wav|ðæt hiː kʊd ˈoʊnli tˈɛstᵻfˌaɪ wɪð sˈɜːtənti ðætðə pɹˈɪnt wʌz lˈɛs ðɐn θɹˈiː dˈeɪz ˈoʊld .|0
80
+ LJ037-0234.wav|mˈɪsɪz . mˈɛɹi bɹˈɑːk , ðə wˈaɪf əvə mɪkˈænɪk hˌuː wˈɜːkt æt ðə stˈeɪʃən , wʌz ðɛɹ æt ðə tˈaɪm ænd ʃiː sˈɔː ɐ wˈaɪt mˈeɪl ,|0
81
+ LJ040-0002.wav|tʃˈæptɚ sˈɛvən . lˈiː hˈɑːɹvi ˈɑːswəld : bˈækɡɹaʊnd ænd pˈɑːsᵻbəl mˈoʊɾɪvz , pˈɑːɹt wˌʌn .|0
82
+ LJ045-0140.wav|ðɪ ˈɑːɹɡjuːmənts hiː jˈuːzd tə dʒˈʌstᵻfˌaɪ hɪz jˈuːs ʌvðɪ ˈeɪliəs sədʒˈɛst ðæt ˈɑːswəld mˌeɪhɐv kˈʌm tə θˈɪŋk ðætðə hˈoʊl wˈɜːld wʌz bᵻkˈʌmɪŋ ɪnvˈɑːlvd|0
83
+ LJ012-0035.wav|ðə nˈʌmbɚ ænd nˈeɪmz ˌɔn wˈɑːtʃᵻz , wɜː kˈɛɹfəli ɹᵻmˈuːvd ɔːɹ əblˈɪɾɚɹˌeɪɾᵻd ˈæftɚ ðə ɡˈʊdz pˈæst ˌaʊɾəv hɪz hˈændz .|0
84
+ LJ012-0250.wav|ɔnðə sˈɛvənθ dʒuːlˈaɪ , ˈeɪtiːn θˈɜːɾi sˈɛvən ,|0
85
+ LJ016-0179.wav|kəntɹˈæktᵻd wɪð ʃˈɛɹɪfs ænd kənvˈiːnɚz tə wˈɜːk baɪ ðə dʒˈɑːb .|0
86
+ LJ016-0138.wav|æɾə dˈɪstəns fɹʌmðə pɹˈɪzən .|0
87
+ LJ027-0052.wav|ðiːz pɹˈɪnsɪpəlz ʌv həmˈɑːlədʒi ɑːɹ ᵻsˈɛnʃəl tʊ ɐ kɚɹˈɛkt ɪntˌɜːpɹɪtˈeɪʃən ʌvðə fˈækts ʌv mɔːɹfˈɑːlədʒi .|0
88
+ LJ031-0134.wav|ˌɔn wˈʌn əkˈeɪʒən mˈɪsɪz . dʒˈɑːnsən , ɐkˈʌmpənid baɪ tˈuː sˈiːkɹᵻt sˈɜːvɪs ˈeɪdʒənts , lˈɛft ðə ɹˈuːm tə sˈiː mˈɪsɪz . kˈɛnədi ænd mˈɪsɪz . kˈɑːnæli .|0
89
+ LJ019-0273.wav|wˌɪtʃ sˌɜː dʒˈɑːʃjuːə dʒˈɛb tˈoʊld ðə kəmˈɪɾi hiː kənsˈɪdɚd ðə pɹˈɑːpɚɹ ˈɛlɪmənts ʌv pˈiːnəl dˈɪsɪplˌɪn .|0
90
+ LJ014-0110.wav|æt ðə fˈɜːst ðə bˈɑːksᵻz wɜːɹ ɪmpˈaʊndᵻd , ˈoʊpənd , ænd fˈaʊnd tə kəntˈeɪn mˈɛnɪəv oʊkˈɑːnɚz ɪfˈɛkts .|0
91
+ LJ034-0160.wav|ˌɔn bɹˈɛnənz sˈʌbsᵻkwənt sˈɜːʔn̩ aɪdˈɛntɪfɪkˈeɪʃən ʌv lˈiː hˈɑːɹvi ˈɑːswəld æz ðə mˈæn hiː sˈɔː fˈaɪɚ ðə ɹˈaɪfəl .|0
92
+ LJ038-0199.wav|ᵻlˈɛvən . ɪf aɪɐm ɐlˈaɪv ænd tˈeɪkən pɹˈɪzənɚ ,|0
93
+ LJ014-0010.wav|jˈɛt hiː kʊd nˌɑːt ˌoʊvɚkˈʌm ðə stɹˈeɪndʒ fˌæsᵻnˈeɪʃən ɪt hˈæd fɔːɹ hˌɪm , ænd ɹᵻmˈeɪnd baɪ ðə sˈaɪd ʌvðə kˈɔːɹps tˈɪl ðə stɹˈɛtʃɚ kˈeɪm .|0
94
+ LJ033-0047.wav|aɪ nˈoʊɾɪst wɛn aɪ wɛnt ˈaʊt ðætðə lˈaɪt wʌz ˈɔn , ˈɛnd kwˈoʊt ,|0
95
+ LJ040-0027.wav|hiː wʌz nˈɛvɚ sˈæɾɪsfˌaɪd wɪð ˈɛnɪθˌɪŋ .|0
96
+ LJ048-0228.wav|ænd ˈʌðɚz hˌuː wɜː pɹˈɛzənt sˈeɪ ðæt nˈoʊ ˈeɪdʒənt wʌz ɪnˈiːbɹɪˌeɪɾᵻd ɔːɹ ˈæktᵻd ɪmpɹˈɑːpɚli .|0
97
+ LJ003-0111.wav|hiː wʌz ɪŋ kˈɑːnsɪkwəns pˌʊt ˌaʊɾəv ðə pɹətˈɛkʃən ʌv ðɛɹ ɪntˈɜːnəl lˈɔː , ˈɛnd kwˈoʊt . ðɛɹ kˈoʊd wʌzɐ sˈʌbdʒɛkt ʌv sˌʌm kjˌʊɹɹɪˈɔsɪɾi .|0
98
+ LJ008-0258.wav|lˈɛt mˌiː ɹᵻtɹˈeɪs maɪ stˈɛps , ænd spˈiːk mˈoːɹ ɪn diːtˈeɪl ʌvðə tɹˈiːtmənt ʌvðə kəndˈɛmd ɪn ðoʊz blˈʌdθɜːsti ænd bɹˈuːɾəli ɪndˈɪfɹənt dˈeɪz ,|0
99
+ LJ029-0022.wav|ðɪ ɚɹˈɪdʒɪnəl plˈæŋ kˈɔːld fɚðə pɹˈɛzɪdənt tə spˈɛnd ˈoʊnli wˈʌn dˈeɪ ɪnðə stˈeɪt , mˌeɪkɪŋ wˈɜːlwɪnd vˈɪzɪts tə dˈæləs , fˈɔːɹt wˈɜːθ , sˌæn æntˈoʊnɪˌoʊ , ænd hjˈuːstən .|0
100
+ LJ004-0045.wav|mˈɪstɚ . stˈɜːdʒᵻz bˈoːɹn , sˌɜː dʒˈeɪmz mˈækɪntˌɑːʃ , sˌɜː dʒˈeɪmz skˈɑːɹlɪt , ænd wˈɪljəm wˈɪlbɚfˌoːɹs .|0
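Each line of val_list.txt (and train_list.txt) uses a pipe-separated wav_filename|phonemized_text|speaker_id layout, with speaker index 0 throughout this single-speaker validation list. A minimal parsing sketch; the variable names are illustrative, not taken from the repository's dataloader.

# Split one validation entry into its three fields (example line abbreviated).
line = "LJ022-0023.wav|ðɪ ˌoʊvɚwˈɛlmɪŋ mədʒˈɔːɹᵻɾi ...|0"
wav_name, phonemes, speaker_id = line.strip().split("|")
speaker_id = int(speaker_id)
print(wav_name, speaker_id)  # LJ022-0023.wav 0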
LICENSE ADDED
@@ -0,0 +1,28 @@
1
+ LICENSE FOR STYLETTS2:
2
+
3
+ MIT License
4
+
5
+ Copyright (c) 2023 Aaron (Yinghao) Li
6
+
7
+ Permission is hereby granted, free of charge, to any person obtaining a copy
8
+ of this software and associated documentation files (the "Software"), to deal
9
+ in the Software without restriction, including without limitation the rights
10
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11
+ copies of the Software, and to permit persons to whom the Software is
12
+ furnished to do so, subject to the following conditions:
13
+
14
+ The above copyright notice and this permission notice shall be included in all
15
+ copies or substantial portions of the Software.
16
+
17
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23
+ SOFTWARE.
24
+
25
+
26
+ LICENSE FOR DEMO PAGE:
27
+
28
+ COPYRIGHT 2023 MRFAKENAME. ALL RIGHTS RESERVED.
Modules/__init__.py ADDED
@@ -0,0 +1 @@
1
+
Modules/diffusion/__init__.py ADDED
@@ -0,0 +1 @@
1
+
Modules/diffusion/diffusion.py ADDED
@@ -0,0 +1,92 @@
1
+ from math import pi
2
+ from random import randint
3
+ from typing import Any, Optional, Sequence, Tuple, Union
4
+
5
+ import torch
6
+ from einops import rearrange
7
+ from torch import Tensor, nn
8
+ from tqdm import tqdm
9
+
10
+ from .utils import *
11
+ from .sampler import *
12
+
13
+ """
14
+ Diffusion Classes (generic for 1d data)
15
+ """
16
+
17
+
18
+ class Model1d(nn.Module):
19
+ def __init__(self, unet_type: str = "base", **kwargs):
20
+ super().__init__()
21
+ diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
22
+ self.unet = None
23
+ self.diffusion = None
24
+
25
+ def forward(self, x: Tensor, **kwargs) -> Tensor:
26
+ return self.diffusion(x, **kwargs)
27
+
28
+ def sample(self, *args, **kwargs) -> Tensor:
29
+ return self.diffusion.sample(*args, **kwargs)
30
+
31
+
32
+ """
33
+ Audio Diffusion Classes (specific for 1d audio data)
34
+ """
35
+
36
+
37
+ def get_default_model_kwargs():
38
+ return dict(
39
+ channels=128,
40
+ patch_size=16,
41
+ multipliers=[1, 2, 4, 4, 4, 4, 4],
42
+ factors=[4, 4, 4, 2, 2, 2],
43
+ num_blocks=[2, 2, 2, 2, 2, 2],
44
+ attentions=[0, 0, 0, 1, 1, 1, 1],
45
+ attention_heads=8,
46
+ attention_features=64,
47
+ attention_multiplier=2,
48
+ attention_use_rel_pos=False,
49
+ diffusion_type="v",
50
+ diffusion_sigma_distribution=UniformDistribution(),
51
+ )
52
+
53
+
54
+ def get_default_sampling_kwargs():
55
+ return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
56
+
57
+
58
+ class AudioDiffusionModel(Model1d):
59
+ def __init__(self, **kwargs):
60
+ super().__init__(**{**get_default_model_kwargs(), **kwargs})
61
+
62
+ def sample(self, *args, **kwargs):
63
+ return super().sample(*args, **{**get_default_sampling_kwargs(), **kwargs})
64
+
65
+
66
+ class AudioDiffusionConditional(Model1d):
67
+ def __init__(
68
+ self,
69
+ embedding_features: int,
70
+ embedding_max_length: int,
71
+ embedding_mask_proba: float = 0.1,
72
+ **kwargs,
73
+ ):
74
+ self.embedding_mask_proba = embedding_mask_proba
75
+ default_kwargs = dict(
76
+ **get_default_model_kwargs(),
77
+ unet_type="cfg",
78
+ context_embedding_features=embedding_features,
79
+ context_embedding_max_length=embedding_max_length,
80
+ )
81
+ super().__init__(**{**default_kwargs, **kwargs})
82
+
83
+ def forward(self, *args, **kwargs):
84
+ default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
85
+ return super().forward(*args, **{**default_kwargs, **kwargs})
86
+
87
+ def sample(self, *args, **kwargs):
88
+ default_kwargs = dict(
89
+ **get_default_sampling_kwargs(),
90
+ embedding_scale=5.0,
91
+ )
92
+ return super().sample(*args, **{**default_kwargs, **kwargs})
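AudioDiffusionModel and AudioDiffusionConditional both rely on the same dictionary-merge idiom, {**defaults, **kwargs}, so caller-supplied keyword arguments override the defaults returned by get_default_model_kwargs / get_default_sampling_kwargs. A standalone illustration of that pattern (the names and values here are made up for the example):

def get_defaults():
    return dict(channels=128, patch_size=16, embedding_scale=5.0)

def build(**kwargs):
    # Later entries win, so explicit kwargs override the defaults.
    return {**get_defaults(), **kwargs}

print(build(patch_size=8))  # {'channels': 128, 'patch_size': 8, 'embedding_scale': 5.0}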
Modules/diffusion/modules.py ADDED
@@ -0,0 +1,700 @@
1
+ from math import floor, log, pi
2
+ from typing import Any, List, Optional, Sequence, Tuple, Union
3
+
4
+ from .utils import *
5
+
6
+ import torch
7
+ import torch.nn as nn
+ import torch.nn.functional as F # used below as F.layer_norm in AdaLayerNorm
8
+ from einops import rearrange, reduce, repeat
9
+ from einops.layers.torch import Rearrange
10
+ from einops_exts import rearrange_many
11
+ from torch import Tensor, einsum
12
+
13
+
14
+ """
15
+ Utils
16
+ """
17
+
18
+
19
+ class AdaLayerNorm(nn.Module):
20
+ def __init__(self, style_dim, channels, eps=1e-5):
21
+ super().__init__()
22
+ self.channels = channels
23
+ self.eps = eps
24
+
25
+ self.fc = nn.Linear(style_dim, channels * 2)
26
+
27
+ def forward(self, x, s):
28
+ x = x.transpose(-1, -2)
29
+ x = x.transpose(1, -1)
30
+
31
+ h = self.fc(s)
32
+ h = h.view(h.size(0), h.size(1), 1)
33
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
34
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
35
+
36
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
37
+ x = (1 + gamma) * x + beta
38
+ return x.transpose(1, -1).transpose(-1, -2)
39
+
40
+
41
+ class StyleTransformer1d(nn.Module):
42
+ def __init__(
43
+ self,
44
+ num_layers: int,
45
+ channels: int,
46
+ num_heads: int,
47
+ head_features: int,
48
+ multiplier: int,
49
+ use_context_time: bool = True,
50
+ use_rel_pos: bool = False,
51
+ context_features_multiplier: int = 1,
52
+ rel_pos_num_buckets: Optional[int] = None,
53
+ rel_pos_max_distance: Optional[int] = None,
54
+ context_features: Optional[int] = None,
55
+ context_embedding_features: Optional[int] = None,
56
+ embedding_max_length: int = 512,
57
+ ):
58
+ super().__init__()
59
+
60
+ self.blocks = nn.ModuleList(
61
+ [
62
+ StyleTransformerBlock(
63
+ features=channels + context_embedding_features,
64
+ head_features=head_features,
65
+ num_heads=num_heads,
66
+ multiplier=multiplier,
67
+ style_dim=context_features,
68
+ use_rel_pos=use_rel_pos,
69
+ rel_pos_num_buckets=rel_pos_num_buckets,
70
+ rel_pos_max_distance=rel_pos_max_distance,
71
+ )
72
+ for i in range(num_layers)
73
+ ]
74
+ )
75
+
76
+ self.to_out = nn.Sequential(
77
+ Rearrange("b t c -> b c t"),
78
+ nn.Conv1d(
79
+ in_channels=channels + context_embedding_features,
80
+ out_channels=channels,
81
+ kernel_size=1,
82
+ ),
83
+ )
84
+
85
+ use_context_features = exists(context_features)
86
+ self.use_context_features = use_context_features
87
+ self.use_context_time = use_context_time
88
+
89
+ if use_context_time or use_context_features:
90
+ context_mapping_features = channels + context_embedding_features
91
+
92
+ self.to_mapping = nn.Sequential(
93
+ nn.Linear(context_mapping_features, context_mapping_features),
94
+ nn.GELU(),
95
+ nn.Linear(context_mapping_features, context_mapping_features),
96
+ nn.GELU(),
97
+ )
98
+
99
+ if use_context_time:
100
+ assert exists(context_mapping_features)
101
+ self.to_time = nn.Sequential(
102
+ TimePositionalEmbedding(
103
+ dim=channels, out_features=context_mapping_features
104
+ ),
105
+ nn.GELU(),
106
+ )
107
+
108
+ if use_context_features:
109
+ assert exists(context_features) and exists(context_mapping_features)
110
+ self.to_features = nn.Sequential(
111
+ nn.Linear(
112
+ in_features=context_features, out_features=context_mapping_features
113
+ ),
114
+ nn.GELU(),
115
+ )
116
+
117
+ self.fixed_embedding = FixedEmbedding(
118
+ max_length=embedding_max_length, features=context_embedding_features
119
+ )
120
+
121
+ def get_mapping(
122
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
123
+ ) -> Optional[Tensor]:
124
+ """Combines context time features and features into mapping"""
125
+ items, mapping = [], None
126
+ # Compute time features
127
+ if self.use_context_time:
128
+ assert_message = "use_context_time=True but no time features provided"
129
+ assert exists(time), assert_message
130
+ items += [self.to_time(time)]
131
+ # Compute features
132
+ if self.use_context_features:
133
+ assert_message = "context_features exists but no features provided"
134
+ assert exists(features), assert_message
135
+ items += [self.to_features(features)]
136
+
137
+ # Compute joint mapping
138
+ if self.use_context_time or self.use_context_features:
139
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
140
+ mapping = self.to_mapping(mapping)
141
+
142
+ return mapping
143
+
144
+ def run(self, x, time, embedding, features):
145
+ mapping = self.get_mapping(time, features)
146
+ x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
147
+ mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
148
+
149
+ for block in self.blocks:
150
+ x = x + mapping
151
+ x = block(x, features)
152
+
153
+ x = x.mean(axis=1).unsqueeze(1)
154
+ x = self.to_out(x)
155
+ x = x.transpose(-1, -2)
156
+
157
+ return x
158
+
159
+ def forward(
160
+ self,
161
+ x: Tensor,
162
+ time: Tensor,
163
+ embedding_mask_proba: float = 0.0,
164
+ embedding: Optional[Tensor] = None,
165
+ features: Optional[Tensor] = None,
166
+ embedding_scale: float = 1.0,
167
+ ) -> Tensor:
168
+ b, device = embedding.shape[0], embedding.device
169
+ fixed_embedding = self.fixed_embedding(embedding)
170
+ if embedding_mask_proba > 0.0:
171
+ # Randomly mask embedding
172
+ batch_mask = rand_bool(
173
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
174
+ )
175
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
176
+
177
+ if embedding_scale != 1.0:
178
+ # Compute both normal and fixed embedding outputs
179
+ out = self.run(x, time, embedding=embedding, features=features)
180
+ out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
181
+ # Scale conditional output using classifier-free guidance
182
+ return out_masked + (out - out_masked) * embedding_scale
183
+ else:
184
+ return self.run(x, time, embedding=embedding, features=features)
185
+
186
+ return x
187
+
188
+
189
+ class StyleTransformerBlock(nn.Module):
190
+ def __init__(
191
+ self,
192
+ features: int,
193
+ num_heads: int,
194
+ head_features: int,
195
+ style_dim: int,
196
+ multiplier: int,
197
+ use_rel_pos: bool,
198
+ rel_pos_num_buckets: Optional[int] = None,
199
+ rel_pos_max_distance: Optional[int] = None,
200
+ context_features: Optional[int] = None,
201
+ ):
202
+ super().__init__()
203
+
204
+ self.use_cross_attention = exists(context_features) and context_features > 0
205
+
206
+ self.attention = StyleAttention(
207
+ features=features,
208
+ style_dim=style_dim,
209
+ num_heads=num_heads,
210
+ head_features=head_features,
211
+ use_rel_pos=use_rel_pos,
212
+ rel_pos_num_buckets=rel_pos_num_buckets,
213
+ rel_pos_max_distance=rel_pos_max_distance,
214
+ )
215
+
216
+ if self.use_cross_attention:
217
+ self.cross_attention = StyleAttention(
218
+ features=features,
219
+ style_dim=style_dim,
220
+ num_heads=num_heads,
221
+ head_features=head_features,
222
+ context_features=context_features,
223
+ use_rel_pos=use_rel_pos,
224
+ rel_pos_num_buckets=rel_pos_num_buckets,
225
+ rel_pos_max_distance=rel_pos_max_distance,
226
+ )
227
+
228
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
229
+
230
+ def forward(
231
+ self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None
232
+ ) -> Tensor:
233
+ x = self.attention(x, s) + x
234
+ if self.use_cross_attention:
235
+ x = self.cross_attention(x, s, context=context) + x
236
+ x = self.feed_forward(x) + x
237
+ return x
238
+
239
+
240
+ class StyleAttention(nn.Module):
241
+ def __init__(
242
+ self,
243
+ features: int,
244
+ *,
245
+ style_dim: int,
246
+ head_features: int,
247
+ num_heads: int,
248
+ context_features: Optional[int] = None,
249
+ use_rel_pos: bool,
250
+ rel_pos_num_buckets: Optional[int] = None,
251
+ rel_pos_max_distance: Optional[int] = None,
252
+ ):
253
+ super().__init__()
254
+ self.context_features = context_features
255
+ mid_features = head_features * num_heads
256
+ context_features = default(context_features, features)
257
+
258
+ self.norm = AdaLayerNorm(style_dim, features)
259
+ self.norm_context = AdaLayerNorm(style_dim, context_features)
260
+ self.to_q = nn.Linear(
261
+ in_features=features, out_features=mid_features, bias=False
262
+ )
263
+ self.to_kv = nn.Linear(
264
+ in_features=context_features, out_features=mid_features * 2, bias=False
265
+ )
266
+ self.attention = AttentionBase(
267
+ features,
268
+ num_heads=num_heads,
269
+ head_features=head_features,
270
+ use_rel_pos=use_rel_pos,
271
+ rel_pos_num_buckets=rel_pos_num_buckets,
272
+ rel_pos_max_distance=rel_pos_max_distance,
273
+ )
274
+
275
+ def forward(
276
+ self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None
277
+ ) -> Tensor:
278
+ assert_message = "You must provide a context when using context_features"
279
+ assert not self.context_features or exists(context), assert_message
280
+ # Use context if provided
281
+ context = default(context, x)
282
+ # Normalize then compute q from input and k,v from context
283
+ x, context = self.norm(x, s), self.norm_context(context, s)
284
+
285
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
286
+ # Compute and return attention
287
+ return self.attention(q, k, v)
288
+
289
+
290
+ class Transformer1d(nn.Module):
291
+ def __init__(
292
+ self,
293
+ num_layers: int,
294
+ channels: int,
295
+ num_heads: int,
296
+ head_features: int,
297
+ multiplier: int,
298
+ use_context_time: bool = True,
299
+ use_rel_pos: bool = False,
300
+ context_features_multiplier: int = 1,
301
+ rel_pos_num_buckets: Optional[int] = None,
302
+ rel_pos_max_distance: Optional[int] = None,
303
+ context_features: Optional[int] = None,
304
+ context_embedding_features: Optional[int] = None,
305
+ embedding_max_length: int = 512,
306
+ ):
307
+ super().__init__()
308
+
309
+ self.blocks = nn.ModuleList(
310
+ [
311
+ TransformerBlock(
312
+ features=channels + context_embedding_features,
313
+ head_features=head_features,
314
+ num_heads=num_heads,
315
+ multiplier=multiplier,
316
+ use_rel_pos=use_rel_pos,
317
+ rel_pos_num_buckets=rel_pos_num_buckets,
318
+ rel_pos_max_distance=rel_pos_max_distance,
319
+ )
320
+ for i in range(num_layers)
321
+ ]
322
+ )
323
+
324
+ self.to_out = nn.Sequential(
325
+ Rearrange("b t c -> b c t"),
326
+ nn.Conv1d(
327
+ in_channels=channels + context_embedding_features,
328
+ out_channels=channels,
329
+ kernel_size=1,
330
+ ),
331
+ )
332
+
333
+ use_context_features = exists(context_features)
334
+ self.use_context_features = use_context_features
335
+ self.use_context_time = use_context_time
336
+
337
+ if use_context_time or use_context_features:
338
+ context_mapping_features = channels + context_embedding_features
339
+
340
+ self.to_mapping = nn.Sequential(
341
+ nn.Linear(context_mapping_features, context_mapping_features),
342
+ nn.GELU(),
343
+ nn.Linear(context_mapping_features, context_mapping_features),
344
+ nn.GELU(),
345
+ )
346
+
347
+ if use_context_time:
348
+ assert exists(context_mapping_features)
349
+ self.to_time = nn.Sequential(
350
+ TimePositionalEmbedding(
351
+ dim=channels, out_features=context_mapping_features
352
+ ),
353
+ nn.GELU(),
354
+ )
355
+
356
+ if use_context_features:
357
+ assert exists(context_features) and exists(context_mapping_features)
358
+ self.to_features = nn.Sequential(
359
+ nn.Linear(
360
+ in_features=context_features, out_features=context_mapping_features
361
+ ),
362
+ nn.GELU(),
363
+ )
364
+
365
+ self.fixed_embedding = FixedEmbedding(
366
+ max_length=embedding_max_length, features=context_embedding_features
367
+ )
368
+
369
+ def get_mapping(
370
+ self, time: Optional[Tensor] = None, features: Optional[Tensor] = None
371
+ ) -> Optional[Tensor]:
372
+ """Combines context time features and features into mapping"""
373
+ items, mapping = [], None
374
+ # Compute time features
375
+ if self.use_context_time:
376
+ assert_message = "use_context_time=True but no time features provided"
377
+ assert exists(time), assert_message
378
+ items += [self.to_time(time)]
379
+ # Compute features
380
+ if self.use_context_features:
381
+ assert_message = "context_features exists but no features provided"
382
+ assert exists(features), assert_message
383
+ items += [self.to_features(features)]
384
+
385
+ # Compute joint mapping
386
+ if self.use_context_time or self.use_context_features:
387
+ mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
388
+ mapping = self.to_mapping(mapping)
389
+
390
+ return mapping
391
+
392
+ def run(self, x, time, embedding, features):
393
+ mapping = self.get_mapping(time, features)
394
+ x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
395
+ mapping = mapping.unsqueeze(1).expand(-1, embedding.size(1), -1)
396
+
397
+ for block in self.blocks:
398
+ x = x + mapping
399
+ x = block(x)
400
+
401
+ x = x.mean(axis=1).unsqueeze(1)
402
+ x = self.to_out(x)
403
+ x = x.transpose(-1, -2)
404
+
405
+ return x
406
+
407
+ def forward(
408
+ self,
409
+ x: Tensor,
410
+ time: Tensor,
411
+ embedding_mask_proba: float = 0.0,
412
+ embedding: Optional[Tensor] = None,
413
+ features: Optional[Tensor] = None,
414
+ embedding_scale: float = 1.0,
415
+ ) -> Tensor:
416
+ b, device = embedding.shape[0], embedding.device
417
+ fixed_embedding = self.fixed_embedding(embedding)
418
+ if embedding_mask_proba > 0.0:
419
+ # Randomly mask embedding
420
+ batch_mask = rand_bool(
421
+ shape=(b, 1, 1), proba=embedding_mask_proba, device=device
422
+ )
423
+ embedding = torch.where(batch_mask, fixed_embedding, embedding)
424
+
425
+ if embedding_scale != 1.0:
426
+ # Compute both normal and fixed embedding outputs
427
+ out = self.run(x, time, embedding=embedding, features=features)
428
+ out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
429
+ # Scale conditional output using classifier-free guidance
430
+ return out_masked + (out - out_masked) * embedding_scale
431
+ else:
432
+ return self.run(x, time, embedding=embedding, features=features)
433
+
434
+ return x
435
+
436
+
437
+ """
438
+ Attention Components
439
+ """
440
+
441
+
442
+ class RelativePositionBias(nn.Module):
443
+ def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
444
+ super().__init__()
445
+ self.num_buckets = num_buckets
446
+ self.max_distance = max_distance
447
+ self.num_heads = num_heads
448
+ self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)
449
+
450
+ @staticmethod
451
+ def _relative_position_bucket(
452
+ relative_position: Tensor, num_buckets: int, max_distance: int
453
+ ):
454
+ num_buckets //= 2
455
+ ret = (relative_position >= 0).to(torch.long) * num_buckets
456
+ n = torch.abs(relative_position)
457
+
458
+ max_exact = num_buckets // 2
459
+ is_small = n < max_exact
460
+
461
+ val_if_large = (
462
+ max_exact
463
+ + (
464
+ torch.log(n.float() / max_exact)
465
+ / log(max_distance / max_exact)
466
+ * (num_buckets - max_exact)
467
+ ).long()
468
+ )
469
+ val_if_large = torch.min(
470
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1)
471
+ )
472
+
473
+ ret += torch.where(is_small, n, val_if_large)
474
+ return ret
475
+
476
+ def forward(self, num_queries: int, num_keys: int) -> Tensor:
477
+ i, j, device = num_queries, num_keys, self.relative_attention_bias.weight.device
478
+ q_pos = torch.arange(j - i, j, dtype=torch.long, device=device)
479
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
480
+ rel_pos = rearrange(k_pos, "j -> 1 j") - rearrange(q_pos, "i -> i 1")
481
+
482
+ relative_position_bucket = self._relative_position_bucket(
483
+ rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
484
+ )
485
+
486
+ bias = self.relative_attention_bias(relative_position_bucket)
487
+ bias = rearrange(bias, "m n h -> 1 h m n")
488
+ return bias
489
+
490
+
491
+ def FeedForward(features: int, multiplier: int) -> nn.Module:
492
+ mid_features = features * multiplier
493
+ return nn.Sequential(
494
+ nn.Linear(in_features=features, out_features=mid_features),
495
+ nn.GELU(),
496
+ nn.Linear(in_features=mid_features, out_features=features),
497
+ )
498
+
499
+
500
+ class AttentionBase(nn.Module):
501
+ def __init__(
502
+ self,
503
+ features: int,
504
+ *,
505
+ head_features: int,
506
+ num_heads: int,
507
+ use_rel_pos: bool,
508
+ out_features: Optional[int] = None,
509
+ rel_pos_num_buckets: Optional[int] = None,
510
+ rel_pos_max_distance: Optional[int] = None,
511
+ ):
512
+ super().__init__()
513
+ self.scale = head_features**-0.5
514
+ self.num_heads = num_heads
515
+ self.use_rel_pos = use_rel_pos
516
+ mid_features = head_features * num_heads
517
+
518
+ if use_rel_pos:
519
+ assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
520
+ self.rel_pos = RelativePositionBias(
521
+ num_buckets=rel_pos_num_buckets,
522
+ max_distance=rel_pos_max_distance,
523
+ num_heads=num_heads,
524
+ )
525
+ if out_features is None:
526
+ out_features = features
527
+
528
+ self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
529
+
530
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
531
+ # Split heads
532
+ q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
533
+ # Compute similarity matrix
534
+ sim = einsum("... n d, ... m d -> ... n m", q, k)
535
+ sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
536
+ sim = sim * self.scale
537
+ # Get attention matrix with softmax
538
+ attn = sim.softmax(dim=-1)
539
+ # Compute values
540
+ out = einsum("... n m, ... m d -> ... n d", attn, v)
541
+ out = rearrange(out, "b h n d -> b n (h d)")
542
+ return self.to_out(out)
543
+
544
+
545
+ class Attention(nn.Module):
546
+ def __init__(
547
+ self,
548
+ features: int,
549
+ *,
550
+ head_features: int,
551
+ num_heads: int,
552
+ out_features: Optional[int] = None,
553
+ context_features: Optional[int] = None,
554
+ use_rel_pos: bool,
555
+ rel_pos_num_buckets: Optional[int] = None,
556
+ rel_pos_max_distance: Optional[int] = None,
557
+ ):
558
+ super().__init__()
559
+ self.context_features = context_features
560
+ mid_features = head_features * num_heads
561
+ context_features = default(context_features, features)
562
+
563
+ self.norm = nn.LayerNorm(features)
564
+ self.norm_context = nn.LayerNorm(context_features)
565
+ self.to_q = nn.Linear(
566
+ in_features=features, out_features=mid_features, bias=False
567
+ )
568
+ self.to_kv = nn.Linear(
569
+ in_features=context_features, out_features=mid_features * 2, bias=False
570
+ )
571
+
572
+ self.attention = AttentionBase(
573
+ features,
574
+ out_features=out_features,
575
+ num_heads=num_heads,
576
+ head_features=head_features,
577
+ use_rel_pos=use_rel_pos,
578
+ rel_pos_num_buckets=rel_pos_num_buckets,
579
+ rel_pos_max_distance=rel_pos_max_distance,
580
+ )
581
+
582
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
583
+ assert_message = "You must provide a context when using context_features"
584
+ assert not self.context_features or exists(context), assert_message
585
+ # Use context if provided
586
+ context = default(context, x)
587
+ # Normalize then compute q from input and k,v from context
588
+ x, context = self.norm(x), self.norm_context(context)
589
+ q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
590
+ # Compute and return attention
591
+ return self.attention(q, k, v)
592
+
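+ # Usage sketch (feature sizes are illustrative only, not taken from this file):
+ #   attn = Attention(features=512, head_features=64, num_heads=8, use_rel_pos=False)
+ #   y = attn(x)                # self-attention over x of shape (b, n, 512)
+ #   y = attn(x, context=ctx)   # cross-attention when context_features is configured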
593
+
594
+ """
595
+ Transformer Blocks
596
+ """
597
+
598
+
599
+ class TransformerBlock(nn.Module):
600
+ def __init__(
601
+ self,
602
+ features: int,
603
+ num_heads: int,
604
+ head_features: int,
605
+ multiplier: int,
606
+ use_rel_pos: bool,
607
+ rel_pos_num_buckets: Optional[int] = None,
608
+ rel_pos_max_distance: Optional[int] = None,
609
+ context_features: Optional[int] = None,
610
+ ):
611
+ super().__init__()
612
+
613
+ self.use_cross_attention = exists(context_features) and context_features > 0
614
+
615
+ self.attention = Attention(
616
+ features=features,
617
+ num_heads=num_heads,
618
+ head_features=head_features,
619
+ use_rel_pos=use_rel_pos,
620
+ rel_pos_num_buckets=rel_pos_num_buckets,
621
+ rel_pos_max_distance=rel_pos_max_distance,
622
+ )
623
+
624
+ if self.use_cross_attention:
625
+ self.cross_attention = Attention(
626
+ features=features,
627
+ num_heads=num_heads,
628
+ head_features=head_features,
629
+ context_features=context_features,
630
+ use_rel_pos=use_rel_pos,
631
+ rel_pos_num_buckets=rel_pos_num_buckets,
632
+ rel_pos_max_distance=rel_pos_max_distance,
633
+ )
634
+
635
+ self.feed_forward = FeedForward(features=features, multiplier=multiplier)
636
+
637
+ def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
638
+ x = self.attention(x) + x
639
+ if self.use_cross_attention:
640
+ x = self.cross_attention(x, context=context) + x
641
+ x = self.feed_forward(x) + x
642
+ return x
643
+
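+ # The block applies residual connections around self-attention, optional
+ # cross-attention, and the feed-forward; the attention sub-blocks normalize their
+ # inputs internally (pre-norm), while FeedForward is applied without normalization.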
644
+
645
+ """
646
+ Time Embeddings
647
+ """
648
+
649
+
650
+ class SinusoidalEmbedding(nn.Module):
651
+ def __init__(self, dim: int):
652
+ super().__init__()
653
+ self.dim = dim
654
+
655
+ def forward(self, x: Tensor) -> Tensor:
656
+ device, half_dim = x.device, self.dim // 2
657
+ emb = torch.tensor(log(10000) / (half_dim - 1), device=device)
658
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
659
+ emb = rearrange(x, "i -> i 1") * rearrange(emb, "j -> 1 j")
660
+ return torch.cat((emb.sin(), emb.cos()), dim=-1)
661
+
662
+
663
+ class LearnedPositionalEmbedding(nn.Module):
664
+ """Used for continuous time"""
665
+
666
+ def __init__(self, dim: int):
667
+ super().__init__()
668
+ assert (dim % 2) == 0
669
+ half_dim = dim // 2
670
+ self.weights = nn.Parameter(torch.randn(half_dim))
671
+
672
+ def forward(self, x: Tensor) -> Tensor:
673
+ x = rearrange(x, "b -> b 1")
674
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
675
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
676
+ fouriered = torch.cat((x, fouriered), dim=-1)
677
+ return fouriered
678
+
679
+
680
+ def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
681
+ return nn.Sequential(
682
+ LearnedPositionalEmbedding(dim),
683
+ nn.Linear(in_features=dim + 1, out_features=out_features),
684
+ )
685
+
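+ # The Linear layer takes dim + 1 inputs because LearnedPositionalEmbedding
+ # concatenates the raw time value to its dim Fourier features.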
686
+
687
+ class FixedEmbedding(nn.Module):
688
+ def __init__(self, max_length: int, features: int):
689
+ super().__init__()
690
+ self.max_length = max_length
691
+ self.embedding = nn.Embedding(max_length, features)
692
+
693
+ def forward(self, x: Tensor) -> Tensor:
694
+ batch_size, length, device = *x.shape[0:2], x.device
695
+ assert_message = "Input sequence length must be <= max_length"
696
+ assert length <= self.max_length, assert_message
697
+ position = torch.arange(length, device=device)
698
+ fixed_embedding = self.embedding(position)
699
+ fixed_embedding = repeat(fixed_embedding, "n d -> b n d", b=batch_size)
700
+ return fixed_embedding
Modules/diffusion/sampler.py ADDED
@@ -0,0 +1,685 @@
1
+ from math import atan, cos, pi, sin, sqrt
2
+ from typing import Any, Callable, List, Optional, Tuple, Type
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from einops import rearrange, reduce
8
+ from torch import Tensor
9
+
10
+ from .utils import *
11
+
12
+ """
13
+ Diffusion Training
14
+ """
15
+
16
+ """ Distributions """
17
+
18
+
19
+ class Distribution:
20
+ def __call__(self, num_samples: int, device: torch.device):
21
+ raise NotImplementedError()
22
+
23
+
24
+ class LogNormalDistribution(Distribution):
25
+ def __init__(self, mean: float, std: float):
26
+ self.mean = mean
27
+ self.std = std
28
+
29
+ def __call__(
30
+ self, num_samples: int, device: torch.device = torch.device("cpu")
31
+ ) -> Tensor:
32
+ normal = self.mean + self.std * torch.randn((num_samples,), device=device)
33
+ return normal.exp()
34
+
35
+
36
+ class UniformDistribution(Distribution):
37
+ def __call__(self, num_samples: int, device: torch.device = torch.device("cpu")):
38
+ return torch.rand(num_samples, device=device)
39
+
40
+
41
+ class VKDistribution(Distribution):
42
+ def __init__(
43
+ self,
44
+ min_value: float = 0.0,
45
+ max_value: float = float("inf"),
46
+ sigma_data: float = 1.0,
47
+ ):
48
+ self.min_value = min_value
49
+ self.max_value = max_value
50
+ self.sigma_data = sigma_data
51
+
52
+ def __call__(
53
+ self, num_samples: int, device: torch.device = torch.device("cpu")
54
+ ) -> Tensor:
55
+ sigma_data = self.sigma_data
56
+ min_cdf = atan(self.min_value / sigma_data) * 2 / pi
57
+ max_cdf = atan(self.max_value / sigma_data) * 2 / pi
58
+ u = (max_cdf - min_cdf) * torch.randn((num_samples,), device=device) + min_cdf
59
+ return torch.tan(u * pi / 2) * sigma_data
60
+
61
+
62
+ """ Diffusion Classes """
63
+
64
+
65
+ def pad_dims(x: Tensor, ndim: int) -> Tensor:
66
+ # Pads additional ndims to the right of the tensor
67
+ return x.view(*x.shape, *((1,) * ndim))
68
+
69
+
70
+ def clip(x: Tensor, dynamic_threshold: float = 0.0):
71
+ if dynamic_threshold == 0.0:
72
+ return x.clamp(-1.0, 1.0)
73
+ else:
74
+ # Dynamic thresholding
75
+ # Find dynamic threshold quantile for each batch
76
+ x_flat = rearrange(x, "b ... -> b (...)")
77
+ scale = torch.quantile(x_flat.abs(), dynamic_threshold, dim=-1)
78
+ # Clamp to a min of 1.0
79
+ scale.clamp_(min=1.0)
80
+ # Clamp all values and scale
81
+ scale = pad_dims(scale, ndim=x.ndim - scale.ndim)
82
+ x = x.clamp(-scale, scale) / scale
83
+ return x
84
+
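+ # Example: with dynamic_threshold=0.95 each batch item is clamped to its own 95th
+ # percentile of absolute values (never below 1.0) and then rescaled back into [-1, 1].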
85
+
86
+ def to_batch(
87
+ batch_size: int,
88
+ device: torch.device,
89
+ x: Optional[float] = None,
90
+ xs: Optional[Tensor] = None,
91
+ ) -> Tensor:
92
+ assert exists(x) ^ exists(xs), "Either x or xs must be provided"
93
+ # If x provided use the same for all batch items
94
+ if exists(x):
95
+ xs = torch.full(size=(batch_size,), fill_value=x).to(device)
96
+ assert exists(xs)
97
+ return xs
98
+
99
+
100
+ class Diffusion(nn.Module):
101
+ alias: str = ""
102
+
103
+ """Base diffusion class"""
104
+
105
+ def denoise_fn(
106
+ self,
107
+ x_noisy: Tensor,
108
+ sigmas: Optional[Tensor] = None,
109
+ sigma: Optional[float] = None,
110
+ **kwargs,
111
+ ) -> Tensor:
112
+ raise NotImplementedError("Diffusion class missing denoise_fn")
113
+
114
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
115
+ raise NotImplementedError("Diffusion class missing forward function")
116
+
117
+
118
+ class VDiffusion(Diffusion):
119
+ alias = "v"
120
+
121
+ def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
122
+ super().__init__()
123
+ self.net = net
124
+ self.sigma_distribution = sigma_distribution
125
+
126
+ def get_alpha_beta(self, sigmas: Tensor) -> Tuple[Tensor, Tensor]:
127
+ angle = sigmas * pi / 2
128
+ alpha = torch.cos(angle)
129
+ beta = torch.sin(angle)
130
+ return alpha, beta
131
+
132
+ def denoise_fn(
133
+ self,
134
+ x_noisy: Tensor,
135
+ sigmas: Optional[Tensor] = None,
136
+ sigma: Optional[float] = None,
137
+ **kwargs,
138
+ ) -> Tensor:
139
+ batch_size, device = x_noisy.shape[0], x_noisy.device
140
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
141
+ return self.net(x_noisy, sigmas, **kwargs)
142
+
143
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
144
+ batch_size, device = x.shape[0], x.device
145
+
146
+ # Sample amount of noise to add for each batch element
147
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
148
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
149
+
150
+ # Get noise
151
+ noise = default(noise, lambda: torch.randn_like(x))
152
+
153
+ # Combine input and noise weighted by half-circle
154
+ alpha, beta = self.get_alpha_beta(sigmas_padded)
155
+ x_noisy = x * alpha + noise * beta
156
+ x_target = noise * alpha - x * beta
157
+
158
+ # Denoise and return loss
159
+ x_denoised = self.denoise_fn(x_noisy, sigmas, **kwargs)
160
+ return F.mse_loss(x_denoised, x_target)
161
+
162
+
163
+ class KDiffusion(Diffusion):
164
+ """Elucidated Diffusion (Karras et al. 2022): https://arxiv.org/abs/2206.00364"""
165
+
166
+ alias = "k"
167
+
168
+ def __init__(
169
+ self,
170
+ net: nn.Module,
171
+ *,
172
+ sigma_distribution: Distribution,
173
+ sigma_data: float, # data distribution standard deviation
174
+ dynamic_threshold: float = 0.0,
175
+ ):
176
+ super().__init__()
177
+ self.net = net
178
+ self.sigma_data = sigma_data
179
+ self.sigma_distribution = sigma_distribution
180
+ self.dynamic_threshold = dynamic_threshold
181
+
182
+ def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
183
+ sigma_data = self.sigma_data
184
+ c_noise = torch.log(sigmas) * 0.25
185
+ sigmas = rearrange(sigmas, "b -> b 1 1")
186
+ c_skip = (sigma_data**2) / (sigmas**2 + sigma_data**2)
187
+ c_out = sigmas * sigma_data * (sigma_data**2 + sigmas**2) ** -0.5
188
+ c_in = (sigmas**2 + sigma_data**2) ** -0.5
189
+ return c_skip, c_out, c_in, c_noise
190
+
191
+ def denoise_fn(
192
+ self,
193
+ x_noisy: Tensor,
194
+ sigmas: Optional[Tensor] = None,
195
+ sigma: Optional[float] = None,
196
+ **kwargs,
197
+ ) -> Tensor:
198
+ batch_size, device = x_noisy.shape[0], x_noisy.device
199
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
200
+
201
+ # Predict network output and add skip connection
202
+ c_skip, c_out, c_in, c_noise = self.get_scale_weights(sigmas)
203
+ x_pred = self.net(c_in * x_noisy, c_noise, **kwargs)
204
+ x_denoised = c_skip * x_noisy + c_out * x_pred
205
+
206
+ return x_denoised
207
+
208
+ def loss_weight(self, sigmas: Tensor) -> Tensor:
209
+ # Computes weight depending on data distribution
210
+ return (sigmas**2 + self.sigma_data**2) * (sigmas * self.sigma_data) ** -2
211
+
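+ # This is the EDM weighting lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2,
+ # which roughly equalizes the loss contribution across noise levels.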
212
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
213
+ batch_size, device = x.shape[0], x.device
215
+
216
+ # Sample amount of noise to add for each batch element
217
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
218
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
219
+
220
+ # Add noise to input
221
+ noise = default(noise, lambda: torch.randn_like(x))
222
+ x_noisy = x + sigmas_padded * noise
223
+
224
+ # Compute denoised values
225
+ x_denoised = self.denoise_fn(x_noisy, sigmas=sigmas, **kwargs)
226
+
227
+ # Compute weighted loss
228
+ losses = F.mse_loss(x_denoised, x, reduction="none")
229
+ losses = reduce(losses, "b ... -> b", "mean")
230
+ losses = losses * self.loss_weight(sigmas)
231
+ loss = losses.mean()
232
+ return loss
233
+
234
+
235
+ class VKDiffusion(Diffusion):
236
+ alias = "vk"
237
+
238
+ def __init__(self, net: nn.Module, *, sigma_distribution: Distribution):
239
+ super().__init__()
240
+ self.net = net
241
+ self.sigma_distribution = sigma_distribution
242
+
243
+ def get_scale_weights(self, sigmas: Tensor) -> Tuple[Tensor, ...]:
244
+ sigma_data = 1.0
245
+ sigmas = rearrange(sigmas, "b -> b 1 1")
246
+ c_skip = (sigma_data**2) / (sigmas**2 + sigma_data**2)
247
+ c_out = -sigmas * sigma_data * (sigma_data**2 + sigmas**2) ** -0.5
248
+ c_in = (sigmas**2 + sigma_data**2) ** -0.5
249
+ return c_skip, c_out, c_in
250
+
251
+ def sigma_to_t(self, sigmas: Tensor) -> Tensor:
252
+ return sigmas.atan() / pi * 2
253
+
254
+ def t_to_sigma(self, t: Tensor) -> Tensor:
255
+ return (t * pi / 2).tan()
256
+
257
+ def denoise_fn(
258
+ self,
259
+ x_noisy: Tensor,
260
+ sigmas: Optional[Tensor] = None,
261
+ sigma: Optional[float] = None,
262
+ **kwargs,
263
+ ) -> Tensor:
264
+ batch_size, device = x_noisy.shape[0], x_noisy.device
265
+ sigmas = to_batch(x=sigma, xs=sigmas, batch_size=batch_size, device=device)
266
+
267
+ # Predict network output and add skip connection
268
+ c_skip, c_out, c_in = self.get_scale_weights(sigmas)
269
+ x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
270
+ x_denoised = c_skip * x_noisy + c_out * x_pred
271
+ return x_denoised
272
+
273
+ def forward(self, x: Tensor, noise: Tensor = None, **kwargs) -> Tensor:
274
+ batch_size, device = x.shape[0], x.device
275
+
276
+ # Sample amount of noise to add for each batch element
277
+ sigmas = self.sigma_distribution(num_samples=batch_size, device=device)
278
+ sigmas_padded = rearrange(sigmas, "b -> b 1 1")
279
+
280
+ # Add noise to input
281
+ noise = default(noise, lambda: torch.randn_like(x))
282
+ x_noisy = x + sigmas_padded * noise
283
+
284
+ # Compute model output
285
+ c_skip, c_out, c_in = self.get_scale_weights(sigmas)
286
+ x_pred = self.net(c_in * x_noisy, self.sigma_to_t(sigmas), **kwargs)
287
+
288
+ # Compute v-objective target
289
+ v_target = (x - c_skip * x_noisy) / (c_out + 1e-7)
290
+
291
+ # Compute loss
292
+ loss = F.mse_loss(x_pred, v_target)
293
+ return loss
294
+
295
+
296
+ """
297
+ Diffusion Sampling
298
+ """
299
+
300
+ """ Schedules """
301
+
302
+
303
+ class Schedule(nn.Module):
304
+ """Interface used by different sampling schedules"""
305
+
306
+ def forward(self, num_steps: int, device: torch.device) -> Tensor:
307
+ raise NotImplementedError()
308
+
309
+
310
+ class LinearSchedule(Schedule):
311
+ def forward(self, num_steps: int, device: Any) -> Tensor:
312
+ sigmas = torch.linspace(1, 0, num_steps + 1)[:-1]
313
+ return sigmas
314
+
315
+
316
+ class KarrasSchedule(Schedule):
317
+ """https://arxiv.org/abs/2206.00364 equation 5"""
318
+
319
+ def __init__(self, sigma_min: float, sigma_max: float, rho: float = 7.0):
320
+ super().__init__()
321
+ self.sigma_min = sigma_min
322
+ self.sigma_max = sigma_max
323
+ self.rho = rho
324
+
325
+ def forward(self, num_steps: int, device: Any) -> Tensor:
326
+ rho_inv = 1.0 / self.rho
327
+ steps = torch.arange(num_steps, device=device, dtype=torch.float32)
328
+ sigmas = (
329
+ self.sigma_max**rho_inv
330
+ + (steps / (num_steps - 1))
331
+ * (self.sigma_min**rho_inv - self.sigma_max**rho_inv)
332
+ ) ** self.rho
333
+ sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
334
+ return sigmas
335
+
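+ # Sigmas decay from sigma_max to sigma_min in sigma**(1 / rho) space and a final
+ # sigma = 0 entry is appended, so the schedule has num_steps + 1 values.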
336
+
337
+ """ Samplers """
338
+
339
+
340
+ class Sampler(nn.Module):
341
+ diffusion_types: List[Type[Diffusion]] = []
342
+
343
+ def forward(
344
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
345
+ ) -> Tensor:
346
+ raise NotImplementedError()
347
+
348
+ def inpaint(
349
+ self,
350
+ source: Tensor,
351
+ mask: Tensor,
352
+ fn: Callable,
353
+ sigmas: Tensor,
354
+ num_steps: int,
355
+ num_resamples: int,
356
+ ) -> Tensor:
357
+ raise NotImplementedError("Inpainting not available with current sampler")
358
+
359
+
360
+ class VSampler(Sampler):
361
+ diffusion_types = [VDiffusion]
362
+
363
+ def get_alpha_beta(self, sigma: float) -> Tuple[float, float]:
364
+ angle = sigma * pi / 2
365
+ alpha = cos(angle)
366
+ beta = sin(angle)
367
+ return alpha, beta
368
+
369
+ def forward(
370
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
371
+ ) -> Tensor:
372
+ x = sigmas[0] * noise
373
+ alpha, beta = self.get_alpha_beta(sigmas[0].item())
374
+
375
+ for i in range(num_steps - 1):
376
+ is_last = i == num_steps - 1
377
+
378
+ x_denoised = fn(x, sigma=sigmas[i])
379
+ x_pred = x * alpha - x_denoised * beta
380
+ x_eps = x * beta + x_denoised * alpha
381
+
382
+ if not is_last:
383
+ alpha, beta = self.get_alpha_beta(sigmas[i + 1].item())
384
+ x = x_pred * alpha + x_eps * beta
385
+
386
+ return x_pred
387
+
388
+
389
+ class KarrasSampler(Sampler):
390
+ """https://arxiv.org/abs/2206.00364 algorithm 1"""
391
+
392
+ diffusion_types = [KDiffusion, VKDiffusion]
393
+
394
+ def __init__(
395
+ self,
396
+ s_tmin: float = 0,
397
+ s_tmax: float = float("inf"),
398
+ s_churn: float = 0.0,
399
+ s_noise: float = 1.0,
400
+ ):
401
+ super().__init__()
402
+ self.s_tmin = s_tmin
403
+ self.s_tmax = s_tmax
404
+ self.s_noise = s_noise
405
+ self.s_churn = s_churn
406
+
407
+ def step(
408
+ self, x: Tensor, fn: Callable, sigma: float, sigma_next: float, gamma: float
409
+ ) -> Tensor:
410
+ """Algorithm 2 (step)"""
411
+ # Select temporarily increased noise level
412
+ sigma_hat = sigma + gamma * sigma
413
+ # Add noise to move from sigma to sigma_hat
414
+ epsilon = self.s_noise * torch.randn_like(x)
415
+ x_hat = x + sqrt(sigma_hat**2 - sigma**2) * epsilon
416
+ # Evaluate ∂x/∂sigma at sigma_hat
417
+ d = (x_hat - fn(x_hat, sigma=sigma_hat)) / sigma_hat
418
+ # Take euler step from sigma_hat to sigma_next
419
+ x_next = x_hat + (sigma_next - sigma_hat) * d
420
+ # Second order correction
421
+ if sigma_next != 0:
422
+ model_out_next = fn(x_next, sigma=sigma_next)
423
+ d_prime = (x_next - model_out_next) / sigma_next
424
+ x_next = x_hat + 0.5 * (sigma_next - sigma_hat) * (d + d_prime)
425
+ return x_next
426
+
427
+ def forward(
428
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
429
+ ) -> Tensor:
430
+ x = sigmas[0] * noise
431
+ # Compute gammas
432
+ gammas = torch.where(
433
+ (sigmas >= self.s_tmin) & (sigmas <= self.s_tmax),
434
+ min(self.s_churn / num_steps, sqrt(2) - 1),
435
+ 0.0,
436
+ )
437
+ # Denoise to sample
438
+ for i in range(num_steps - 1):
439
+ x = self.step(
440
+ x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1], gamma=gammas[i] # type: ignore # noqa
441
+ )
442
+
443
+ return x
444
+
445
+
446
+ class AEulerSampler(Sampler):
447
+ diffusion_types = [KDiffusion, VKDiffusion]
448
+
449
+ def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float]:
450
+ sigma_up = sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2)
451
+ sigma_down = sqrt(sigma_next**2 - sigma_up**2)
452
+ return sigma_up, sigma_down
453
+
454
+ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
455
+ # Sigma steps
456
+ sigma_up, sigma_down = self.get_sigmas(sigma, sigma_next)
457
+ # Derivative at sigma (∂x/∂sigma)
458
+ d = (x - fn(x, sigma=sigma)) / sigma
459
+ # Euler method
460
+ x_next = x + d * (sigma_down - sigma)
461
+ # Add randomness
462
+ x_next = x_next + torch.randn_like(x) * sigma_up
463
+ return x_next
464
+
465
+ def forward(
466
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
467
+ ) -> Tensor:
468
+ x = sigmas[0] * noise
469
+ # Denoise to sample
470
+ for i in range(num_steps - 1):
471
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
472
+ return x
473
+
474
+
475
+ class ADPM2Sampler(Sampler):
476
+ """https://www.desmos.com/calculator/jbxjlqd9mb"""
477
+
478
+ diffusion_types = [KDiffusion, VKDiffusion]
479
+
480
+ def __init__(self, rho: float = 1.0):
481
+ super().__init__()
482
+ self.rho = rho
483
+
484
+ def get_sigmas(self, sigma: float, sigma_next: float) -> Tuple[float, float, float]:
485
+ r = self.rho
486
+ sigma_up = sqrt(sigma_next**2 * (sigma**2 - sigma_next**2) / sigma**2)
487
+ sigma_down = sqrt(sigma_next**2 - sigma_up**2)
488
+ sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
489
+ return sigma_up, sigma_down, sigma_mid
490
+
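+ # sigma_mid interpolates between sigma and sigma_down in sigma**(1 / rho) space;
+ # with the default rho=1.0 it is simply their arithmetic mean.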
491
+ def step(self, x: Tensor, fn: Callable, sigma: float, sigma_next: float) -> Tensor:
492
+ # Sigma steps
493
+ sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
494
+ # Derivative at sigma (∂x/∂sigma)
495
+ d = (x - fn(x, sigma=sigma)) / sigma
496
+ # Denoise to midpoint
497
+ x_mid = x + d * (sigma_mid - sigma)
498
+ # Derivative at sigma_mid (∂x_mid/∂sigma_mid)
499
+ d_mid = (x_mid - fn(x_mid, sigma=sigma_mid)) / sigma_mid
500
+ # Denoise to next
501
+ x = x + d_mid * (sigma_down - sigma)
502
+ # Add randomness
503
+ x_next = x + torch.randn_like(x) * sigma_up
504
+ return x_next
505
+
506
+ def forward(
507
+ self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
508
+ ) -> Tensor:
509
+ x = sigmas[0] * noise
510
+ # Denoise to sample
511
+ for i in range(num_steps - 1):
512
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
513
+ return x
514
+
515
+ def inpaint(
516
+ self,
517
+ source: Tensor,
518
+ mask: Tensor,
519
+ fn: Callable,
520
+ sigmas: Tensor,
521
+ num_steps: int,
522
+ num_resamples: int,
523
+ ) -> Tensor:
524
+ x = sigmas[0] * torch.randn_like(source)
525
+
526
+ for i in range(num_steps - 1):
527
+ # Noise source to current noise level
528
+ source_noisy = source + sigmas[i] * torch.randn_like(source)
529
+ for r in range(num_resamples):
530
+ # Merge noisy source and current then denoise
531
+ x = source_noisy * mask + x * ~mask
532
+ x = self.step(x, fn=fn, sigma=sigmas[i], sigma_next=sigmas[i + 1]) # type: ignore # noqa
533
+ # Renoise if not last resample step
534
+ if r < num_resamples - 1:
535
+ sigma = sqrt(sigmas[i] ** 2 - sigmas[i + 1] ** 2)
536
+ x = x + sigma * torch.randn_like(x)
537
+
538
+ return source * mask + x * ~mask
539
+
540
+
541
+ """ Main Classes """
542
+
543
+
544
+ class DiffusionSampler(nn.Module):
545
+ def __init__(
546
+ self,
547
+ diffusion: Diffusion,
548
+ *,
549
+ sampler: Sampler,
550
+ sigma_schedule: Schedule,
551
+ num_steps: Optional[int] = None,
552
+ clamp: bool = True,
553
+ ):
554
+ super().__init__()
555
+ self.denoise_fn = diffusion.denoise_fn
556
+ self.sampler = sampler
557
+ self.sigma_schedule = sigma_schedule
558
+ self.num_steps = num_steps
559
+ self.clamp = clamp
560
+
561
+ # Check sampler is compatible with diffusion type
562
+ sampler_class = sampler.__class__.__name__
563
+ diffusion_class = diffusion.__class__.__name__
564
+ message = f"{sampler_class} incompatible with {diffusion_class}"
565
+ assert diffusion.alias in [t.alias for t in sampler.diffusion_types], message
566
+
567
+ def forward(
568
+ self, noise: Tensor, num_steps: Optional[int] = None, **kwargs
569
+ ) -> Tensor:
570
+ device = noise.device
571
+ num_steps = default(num_steps, self.num_steps) # type: ignore
572
+ assert exists(num_steps), "Parameter `num_steps` must be provided"
573
+ # Compute sigmas using schedule
574
+ sigmas = self.sigma_schedule(num_steps, device)
575
+ # Append additional kwargs to denoise function (used e.g. for conditional unet)
576
+ fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
577
+ # Sample using sampler
578
+ x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
579
+ x = x.clamp(-1.0, 1.0) if self.clamp else x
580
+ return x
581
+
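+ # Usage sketch (assuming `diffusion` is a KDiffusion or VKDiffusion instance;
+ # schedule bounds and step count are illustrative, not prescribed here):
+ #   sampler = DiffusionSampler(
+ #       diffusion, sampler=ADPM2Sampler(), sigma_schedule=KarrasSchedule(1e-4, 3.0), num_steps=50
+ #   )
+ #   x = sampler(torch.randn(batch, channels, length))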
582
+
583
+ class DiffusionInpainter(nn.Module):
584
+ def __init__(
585
+ self,
586
+ diffusion: Diffusion,
587
+ *,
588
+ num_steps: int,
589
+ num_resamples: int,
590
+ sampler: Sampler,
591
+ sigma_schedule: Schedule,
592
+ ):
593
+ super().__init__()
594
+ self.denoise_fn = diffusion.denoise_fn
595
+ self.num_steps = num_steps
596
+ self.num_resamples = num_resamples
597
+ self.inpaint_fn = sampler.inpaint
598
+ self.sigma_schedule = sigma_schedule
599
+
600
+ @torch.no_grad()
601
+ def forward(self, inpaint: Tensor, inpaint_mask: Tensor) -> Tensor:
602
+ x = self.inpaint_fn(
603
+ source=inpaint,
604
+ mask=inpaint_mask,
605
+ fn=self.denoise_fn,
606
+ sigmas=self.sigma_schedule(self.num_steps, inpaint.device),
607
+ num_steps=self.num_steps,
608
+ num_resamples=self.num_resamples,
609
+ )
610
+ return x
611
+
612
+
613
+ def sequential_mask(like: Tensor, start: int) -> Tensor:
614
+ length, device = like.shape[2], like.device
615
+ mask = torch.ones_like(like, dtype=torch.bool)
616
+ mask[:, :, start:] = torch.zeros((length - start,), device=device)
617
+ return mask
618
+
619
+
620
+ class SpanBySpanComposer(nn.Module):
621
+ def __init__(
622
+ self,
623
+ inpainter: DiffusionInpainter,
624
+ *,
625
+ num_spans: int,
626
+ ):
627
+ super().__init__()
628
+ self.inpainter = inpainter
629
+ self.num_spans = num_spans
630
+
631
+ def forward(self, start: Tensor, keep_start: bool = False) -> Tensor:
632
+ half_length = start.shape[2] // 2
633
+
634
+ spans = list(start.chunk(chunks=2, dim=-1)) if keep_start else []
635
+ # Inpaint second half from first half
636
+ inpaint = torch.zeros_like(start)
637
+ inpaint[:, :, :half_length] = start[:, :, half_length:]
638
+ inpaint_mask = sequential_mask(like=start, start=half_length)
639
+
640
+ for i in range(self.num_spans):
641
+ # Inpaint second half
642
+ span = self.inpainter(inpaint=inpaint, inpaint_mask=inpaint_mask)
643
+ # Replace first half with generated second half
644
+ second_half = span[:, :, half_length:]
645
+ inpaint[:, :, :half_length] = second_half
646
+ # Save generated span
647
+ spans.append(second_half)
648
+
649
+ return torch.cat(spans, dim=2)
650
+
651
+
652
+ class XDiffusion(nn.Module):
653
+ def __init__(self, type: str, net: nn.Module, **kwargs):
654
+ super().__init__()
655
+
656
+ diffusion_classes = [VDiffusion, KDiffusion, VKDiffusion]
657
+ aliases = [t.alias for t in diffusion_classes] # type: ignore
658
+ message = f"type='{type}' must be one of {*aliases,}"
659
+ assert type in aliases, message
660
+ self.net = net
661
+
662
+ for diffusion_class in diffusion_classes:
663
+ if diffusion_class.alias == type: # type: ignore
664
+ self.diffusion = diffusion_class(net=net, **kwargs)
665
+
666
+ def forward(self, *args, **kwargs) -> Tensor:
667
+ return self.diffusion(*args, **kwargs)
668
+
669
+ def sample(
670
+ self,
671
+ noise: Tensor,
672
+ num_steps: int,
673
+ sigma_schedule: Schedule,
674
+ sampler: Sampler,
675
+ clamp: bool,
676
+ **kwargs,
677
+ ) -> Tensor:
678
+ diffusion_sampler = DiffusionSampler(
679
+ diffusion=self.diffusion,
680
+ sampler=sampler,
681
+ sigma_schedule=sigma_schedule,
682
+ num_steps=num_steps,
683
+ clamp=clamp,
684
+ )
685
+ return diffusion_sampler(noise, **kwargs)
Modules/diffusion/utils.py ADDED
@@ -0,0 +1,83 @@
1
+ from functools import reduce
2
+ from inspect import isfunction
3
+ from math import ceil, floor, log2, pi
4
+ from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from einops import rearrange
9
+ from torch import Generator, Tensor
10
+ from typing_extensions import TypeGuard
11
+
12
+ T = TypeVar("T")
13
+
14
+
15
+ def exists(val: Optional[T]) -> TypeGuard[T]:
16
+ return val is not None
17
+
18
+
19
+ def iff(condition: bool, value: T) -> Optional[T]:
20
+ return value if condition else None
21
+
22
+
23
+ def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
24
+ return isinstance(obj, list) or isinstance(obj, tuple)
25
+
26
+
27
+ def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
28
+ if exists(val):
29
+ return val
30
+ return d() if isfunction(d) else d
31
+
32
+
33
+ def to_list(val: Union[T, Sequence[T]]) -> List[T]:
34
+ if isinstance(val, tuple):
35
+ return list(val)
36
+ if isinstance(val, list):
37
+ return val
38
+ return [val] # type: ignore
39
+
40
+
41
+ def prod(vals: Sequence[int]) -> int:
42
+ return reduce(lambda x, y: x * y, vals)
43
+
44
+
45
+ def closest_power_2(x: float) -> int:
46
+ exponent = log2(x)
47
+ distance_fn = lambda z: abs(x - 2**z) # noqa
48
+ exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
49
+ return 2 ** int(exponent_closest)
50
+
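+ # e.g. closest_power_2(5.0) -> 4 and closest_power_2(7.0) -> 8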
51
+
52
+ def rand_bool(shape, proba, device=None):
53
+ if proba == 1:
54
+ return torch.ones(shape, device=device, dtype=torch.bool)
55
+ elif proba == 0:
56
+ return torch.zeros(shape, device=device, dtype=torch.bool)
57
+ else:
58
+ return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
59
+
60
+
61
+ """
62
+ Kwargs Utils
63
+ """
64
+
65
+
66
+ def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
67
+ return_dicts: Tuple[Dict, Dict] = ({}, {})
68
+ for key in d.keys():
69
+ no_prefix = int(not key.startswith(prefix))
70
+ return_dicts[no_prefix][key] = d[key]
71
+ return return_dicts
72
+
73
+
74
+ def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
75
+ kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
76
+ if keep_prefix:
77
+ return kwargs_with_prefix, kwargs
78
+ kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
79
+ return kwargs_no_prefix, kwargs
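+ # Example: groupby("attention_", {"attention_heads": 8, "features": 64})
+ # returns ({"heads": 8}, {"features": 64})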
80
+
81
+
82
+ def prefix_dict(prefix: str, d: Dict) -> Dict:
83
+ return {prefix + str(k): v for k, v in d.items()}
Modules/discriminators.py ADDED
@@ -0,0 +1,267 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, spectral_norm
6
+
7
+ from .utils import get_padding
8
+
9
+ LRELU_SLOPE = 0.1
10
+
11
+
12
+ def stft(x, fft_size, hop_size, win_length, window):
13
+ """Perform STFT and convert to magnitude spectrogram.
14
+ Args:
15
+ x (Tensor): Input signal tensor (B, T).
16
+ fft_size (int): FFT size.
17
+ hop_size (int): Hop size.
18
+ win_length (int): Window length.
19
+ window (Tensor): Window tensor of length win_length (e.g. a Hann window).
20
+ Returns:
21
+ Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
22
+ """
23
+ x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
24
+ # with return_complex=True, x_stft is a complex tensor; its magnitude is the spectrogram
26
+
27
+ return torch.abs(x_stft).transpose(2, 1)
28
+
29
+
30
+ class SpecDiscriminator(nn.Module):
31
+ """Sub-discriminator operating on STFT magnitude spectrograms of the waveform."""
32
+
33
+ def __init__(
34
+ self,
35
+ fft_size=1024,
36
+ shift_size=120,
37
+ win_length=600,
38
+ window="hann_window",
39
+ use_spectral_norm=False,
40
+ ):
41
+ super(SpecDiscriminator, self).__init__()
42
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
43
+ self.fft_size = fft_size
44
+ self.shift_size = shift_size
45
+ self.win_length = win_length
46
+ self.window = getattr(torch, window)(win_length)
47
+ self.discriminators = nn.ModuleList(
48
+ [
49
+ norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
50
+ norm_f(
51
+ nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))
52
+ ),
53
+ norm_f(
54
+ nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))
55
+ ),
56
+ norm_f(
57
+ nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))
58
+ ),
59
+ norm_f(
60
+ nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
61
+ ),
62
+ ]
63
+ )
64
+
65
+ self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
66
+
67
+ def forward(self, y):
68
+ fmap = []
69
+ y = y.squeeze(1)
70
+ y = stft(
71
+ y,
72
+ self.fft_size,
73
+ self.shift_size,
74
+ self.win_length,
75
+ self.window.to(y.device),
76
+ )
77
+ y = y.unsqueeze(1)
78
+ for i, d in enumerate(self.discriminators):
79
+ y = d(y)
80
+ y = F.leaky_relu(y, LRELU_SLOPE)
81
+ fmap.append(y)
82
+
83
+ y = self.out(y)
84
+ fmap.append(y)
85
+
86
+ return torch.flatten(y, 1, -1), fmap
87
+
88
+
89
+ class MultiResSpecDiscriminator(torch.nn.Module):
90
+ def __init__(
91
+ self,
92
+ fft_sizes=[1024, 2048, 512],
93
+ hop_sizes=[120, 240, 50],
94
+ win_lengths=[600, 1200, 240],
95
+ window="hann_window",
96
+ ):
97
+ super(MultiResSpecDiscriminator, self).__init__()
98
+ self.discriminators = nn.ModuleList(
99
+ [
100
+ SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
101
+ SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
102
+ SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window),
103
+ ]
104
+ )
105
+
106
+ def forward(self, y, y_hat):
107
+ y_d_rs = []
108
+ y_d_gs = []
109
+ fmap_rs = []
110
+ fmap_gs = []
111
+ for i, d in enumerate(self.discriminators):
112
+ y_d_r, fmap_r = d(y)
113
+ y_d_g, fmap_g = d(y_hat)
114
+ y_d_rs.append(y_d_r)
115
+ fmap_rs.append(fmap_r)
116
+ y_d_gs.append(y_d_g)
117
+ fmap_gs.append(fmap_g)
118
+
119
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
120
+
121
+
122
+ class DiscriminatorP(torch.nn.Module):
123
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
124
+ super(DiscriminatorP, self).__init__()
125
+ self.period = period
126
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
127
+ self.convs = nn.ModuleList(
128
+ [
129
+ norm_f(
130
+ Conv2d(
131
+ 1,
132
+ 32,
133
+ (kernel_size, 1),
134
+ (stride, 1),
135
+ padding=(get_padding(5, 1), 0),
136
+ )
137
+ ),
138
+ norm_f(
139
+ Conv2d(
140
+ 32,
141
+ 128,
142
+ (kernel_size, 1),
143
+ (stride, 1),
144
+ padding=(get_padding(5, 1), 0),
145
+ )
146
+ ),
147
+ norm_f(
148
+ Conv2d(
149
+ 128,
150
+ 512,
151
+ (kernel_size, 1),
152
+ (stride, 1),
153
+ padding=(get_padding(5, 1), 0),
154
+ )
155
+ ),
156
+ norm_f(
157
+ Conv2d(
158
+ 512,
159
+ 1024,
160
+ (kernel_size, 1),
161
+ (stride, 1),
162
+ padding=(get_padding(5, 1), 0),
163
+ )
164
+ ),
165
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
166
+ ]
167
+ )
168
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
169
+
170
+ def forward(self, x):
171
+ fmap = []
172
+
173
+ # 1d to 2d
174
+ b, c, t = x.shape
175
+ if t % self.period != 0: # pad first
176
+ n_pad = self.period - (t % self.period)
177
+ x = F.pad(x, (0, n_pad), "reflect")
178
+ t = t + n_pad
179
+ x = x.view(b, c, t // self.period, self.period)
180
+
181
+ for l in self.convs:
182
+ x = l(x)
183
+ x = F.leaky_relu(x, LRELU_SLOPE)
184
+ fmap.append(x)
185
+ x = self.conv_post(x)
186
+ fmap.append(x)
187
+ x = torch.flatten(x, 1, -1)
188
+
189
+ return x, fmap
190
+
191
+
192
+ class MultiPeriodDiscriminator(torch.nn.Module):
193
+ def __init__(self):
194
+ super(MultiPeriodDiscriminator, self).__init__()
195
+ self.discriminators = nn.ModuleList(
196
+ [
197
+ DiscriminatorP(2),
198
+ DiscriminatorP(3),
199
+ DiscriminatorP(5),
200
+ DiscriminatorP(7),
201
+ DiscriminatorP(11),
202
+ ]
203
+ )
204
+
205
+ def forward(self, y, y_hat):
206
+ y_d_rs = []
207
+ y_d_gs = []
208
+ fmap_rs = []
209
+ fmap_gs = []
210
+ for i, d in enumerate(self.discriminators):
211
+ y_d_r, fmap_r = d(y)
212
+ y_d_g, fmap_g = d(y_hat)
213
+ y_d_rs.append(y_d_r)
214
+ fmap_rs.append(fmap_r)
215
+ y_d_gs.append(y_d_g)
216
+ fmap_gs.append(fmap_g)
217
+
218
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
219
+
220
+
221
+ class WavLMDiscriminator(nn.Module):
222
+ """Discriminator over stacked WavLM hidden-layer features (slm_layers * slm_hidden channels)."""
223
+
224
+ def __init__(
225
+ self, slm_hidden=768, slm_layers=13, initial_channel=64, use_spectral_norm=False
226
+ ):
227
+ super(WavLMDiscriminator, self).__init__()
228
+ norm_f = spectral_norm if use_spectral_norm else weight_norm
229
+ self.pre = norm_f(
230
+ Conv1d(slm_hidden * slm_layers, initial_channel, 1, 1, padding=0)
231
+ )
232
+
233
+ self.convs = nn.ModuleList(
234
+ [
235
+ norm_f(
236
+ nn.Conv1d(
237
+ initial_channel, initial_channel * 2, kernel_size=5, padding=2
238
+ )
239
+ ),
240
+ norm_f(
241
+ nn.Conv1d(
242
+ initial_channel * 2,
243
+ initial_channel * 4,
244
+ kernel_size=5,
245
+ padding=2,
246
+ )
247
+ ),
248
+ norm_f(
249
+ nn.Conv1d(initial_channel * 4, initial_channel * 4, 5, 1, padding=2)
250
+ ),
251
+ ]
252
+ )
253
+
254
+ self.conv_post = norm_f(Conv1d(initial_channel * 4, 1, 3, 1, padding=1))
255
+
256
+ def forward(self, x):
257
+ x = self.pre(x)
258
+
259
+ fmap = []
260
+ for l in self.convs:
261
+ x = l(x)
262
+ x = F.leaky_relu(x, LRELU_SLOPE)
263
+ fmap.append(x)
264
+ x = self.conv_post(x)
265
+ x = torch.flatten(x, 1, -1)
266
+
267
+ return x
Modules/hifigan.py ADDED
@@ -0,0 +1,643 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ import math
9
+ import random
10
+ import numpy as np
11
+
12
+ LRELU_SLOPE = 0.1
13
+
14
+
15
+ class AdaIN1d(nn.Module):
16
+ def __init__(self, style_dim, num_features):
17
+ super().__init__()
18
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
19
+ self.fc = nn.Linear(style_dim, num_features * 2)
20
+
21
+ def forward(self, x, s):
22
+ h = self.fc(s)
23
+ h = h.view(h.size(0), h.size(1), 1)
24
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
25
+ return (1 + gamma) * self.norm(x) + beta
26
+
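+ # Adaptive instance norm: the style vector s is projected to per-channel
+ # (gamma, beta) parameters that modulate the instance-normalized activations.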
27
+
28
+ class AdaINResBlock1(torch.nn.Module):
29
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
30
+ super(AdaINResBlock1, self).__init__()
31
+ self.convs1 = nn.ModuleList(
32
+ [
33
+ weight_norm(
34
+ Conv1d(
35
+ channels,
36
+ channels,
37
+ kernel_size,
38
+ 1,
39
+ dilation=dilation[0],
40
+ padding=get_padding(kernel_size, dilation[0]),
41
+ )
42
+ ),
43
+ weight_norm(
44
+ Conv1d(
45
+ channels,
46
+ channels,
47
+ kernel_size,
48
+ 1,
49
+ dilation=dilation[1],
50
+ padding=get_padding(kernel_size, dilation[1]),
51
+ )
52
+ ),
53
+ weight_norm(
54
+ Conv1d(
55
+ channels,
56
+ channels,
57
+ kernel_size,
58
+ 1,
59
+ dilation=dilation[2],
60
+ padding=get_padding(kernel_size, dilation[2]),
61
+ )
62
+ ),
63
+ ]
64
+ )
65
+ self.convs1.apply(init_weights)
66
+
67
+ self.convs2 = nn.ModuleList(
68
+ [
69
+ weight_norm(
70
+ Conv1d(
71
+ channels,
72
+ channels,
73
+ kernel_size,
74
+ 1,
75
+ dilation=1,
76
+ padding=get_padding(kernel_size, 1),
77
+ )
78
+ ),
79
+ weight_norm(
80
+ Conv1d(
81
+ channels,
82
+ channels,
83
+ kernel_size,
84
+ 1,
85
+ dilation=1,
86
+ padding=get_padding(kernel_size, 1),
87
+ )
88
+ ),
89
+ weight_norm(
90
+ Conv1d(
91
+ channels,
92
+ channels,
93
+ kernel_size,
94
+ 1,
95
+ dilation=1,
96
+ padding=get_padding(kernel_size, 1),
97
+ )
98
+ ),
99
+ ]
100
+ )
101
+ self.convs2.apply(init_weights)
102
+
103
+ self.adain1 = nn.ModuleList(
104
+ [
105
+ AdaIN1d(style_dim, channels),
106
+ AdaIN1d(style_dim, channels),
107
+ AdaIN1d(style_dim, channels),
108
+ ]
109
+ )
110
+
111
+ self.adain2 = nn.ModuleList(
112
+ [
113
+ AdaIN1d(style_dim, channels),
114
+ AdaIN1d(style_dim, channels),
115
+ AdaIN1d(style_dim, channels),
116
+ ]
117
+ )
118
+
119
+ self.alpha1 = nn.ParameterList(
120
+ [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))]
121
+ )
122
+ self.alpha2 = nn.ParameterList(
123
+ [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))]
124
+ )
125
+
126
+ def forward(self, x, s):
127
+ for c1, c2, n1, n2, a1, a2 in zip(
128
+ self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2
129
+ ):
130
+ xt = n1(x, s)
131
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
132
+ xt = c1(xt)
133
+ xt = n2(xt, s)
134
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
135
+ xt = c2(xt)
136
+ x = xt + x
137
+ return x
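+ # x + (1 / a) * sin(a * x)**2 is the Snake activation with a learnable
+ # per-channel frequency a (the alpha1/alpha2 parameters above).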
138
+
139
+ def remove_weight_norm(self):
140
+ for l in self.convs1:
141
+ remove_weight_norm(l)
142
+ for l in self.convs2:
143
+ remove_weight_norm(l)
144
+
145
+
146
+ class SineGen(torch.nn.Module):
147
+ """Definition of sine generator
148
+ SineGen(samp_rate, upsample_scale, harmonic_num = 0,
149
+ sine_amp = 0.1, noise_std = 0.003,
150
+ voiced_threshold = 0,
151
+ flag_for_pulse=False)
152
+ samp_rate: sampling rate in Hz
+ upsample_scale: overall upsampling factor (used when interpolating the instantaneous phase)
153
+ harmonic_num: number of harmonic overtones (default 0)
154
+ sine_amp: amplitude of sine-waveform (default 0.1)
155
+ noise_std: std of Gaussian noise (default 0.003)
156
+ voiced_threshold: F0 threshold for U/V classification (default 0)
157
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
158
+ Note: when flag_for_pulse is True, the first time step of a voiced
159
+ segment is always sin(np.pi) or cos(0)
160
+ """
161
+
162
+ def __init__(
163
+ self,
164
+ samp_rate,
165
+ upsample_scale,
166
+ harmonic_num=0,
167
+ sine_amp=0.1,
168
+ noise_std=0.003,
169
+ voiced_threshold=0,
170
+ flag_for_pulse=False,
171
+ ):
172
+ super(SineGen, self).__init__()
173
+ self.sine_amp = sine_amp
174
+ self.noise_std = noise_std
175
+ self.harmonic_num = harmonic_num
176
+ self.dim = self.harmonic_num + 1
177
+ self.sampling_rate = samp_rate
178
+ self.voiced_threshold = voiced_threshold
179
+ self.flag_for_pulse = flag_for_pulse
180
+ self.upsample_scale = upsample_scale
181
+
182
+ def _f02uv(self, f0):
183
+ # generate uv signal
184
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
185
+ return uv
186
+
187
+ def _f02sine(self, f0_values):
188
+ """f0_values: (batchsize, length, dim)
189
+ where dim indicates fundamental tone and overtones
190
+ """
191
+ # convert to F0 in rad; the integer part n can be ignored
192
+ # because 2 * np.pi * n doesn't affect phase
193
+ rad_values = (f0_values / self.sampling_rate) % 1
194
+
195
+ # initial phase noise (no noise for fundamental component)
196
+ rand_ini = torch.rand(
197
+ f0_values.shape[0], f0_values.shape[2], device=f0_values.device
198
+ )
199
+ rand_ini[:, 0] = 0
200
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
201
+
202
+ # instantaneous phase: sine[t] = sin(2 * pi * \sum_{i=1}^{t} rad)
203
+ if not self.flag_for_pulse:
204
+ # # for normal case
205
+
206
+ # # To prevent torch.cumsum numerical overflow,
207
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
208
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
209
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
210
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
211
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
212
+ # cumsum_shift = torch.zeros_like(rad_values)
213
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
214
+
215
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
216
+ rad_values = torch.nn.functional.interpolate(
217
+ rad_values.transpose(1, 2),
218
+ scale_factor=1 / self.upsample_scale,
219
+ mode="linear",
220
+ ).transpose(1, 2)
221
+
222
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
223
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
224
+ # cumsum_shift = torch.zeros_like(rad_values)
225
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
226
+
227
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
228
+ phase = torch.nn.functional.interpolate(
229
+ phase.transpose(1, 2) * self.upsample_scale,
230
+ scale_factor=self.upsample_scale,
231
+ mode="linear",
232
+ ).transpose(1, 2)
233
+ sines = torch.sin(phase)
234
+
235
+ else:
236
+ # If necessary, make sure that the first time step of every
237
+ # voiced segments is sin(pi) or cos(0)
238
+ # This is used for pulse-train generation
239
+
240
+ # identify the last time step in unvoiced segments
241
+ uv = self._f02uv(f0_values)
242
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
243
+ uv_1[:, -1, :] = 1
244
+ u_loc = (uv < 1) * (uv_1 > 0)
245
+
246
+ # get the instantaneous phase
247
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
248
+ # different batch needs to be processed differently
249
+ for idx in range(f0_values.shape[0]):
250
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
251
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
252
+ # stores the accumulation of i.phase within
253
+ # each voiced segments
254
+ tmp_cumsum[idx, :, :] = 0
255
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
256
+
257
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
258
+ # within the previous voiced segment.
259
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
260
+
261
+ # get the sines
262
+ sines = torch.cos(i_phase * 2 * np.pi)
263
+ return sines
264
+
265
+ def forward(self, f0):
266
+ """sine_tensor, uv = forward(f0)
267
+ input F0: tensor(batchsize=1, length, dim=1)
268
+ f0 for unvoiced steps should be 0
269
+ output sine_tensor: tensor(batchsize=1, length, dim)
270
+ output uv: tensor(batchsize=1, length, 1)
271
+ """
272
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
273
+ # fundamental component
274
+ fn = torch.multiply(
275
+ f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
276
+ )
277
+
278
+ # generate sine waveforms
279
+ sine_waves = self._f02sine(fn) * self.sine_amp
280
+
281
+ # generate uv signal
282
+ # uv = torch.ones(f0.shape)
283
+ # uv = uv * (f0 > self.voiced_threshold)
284
+ uv = self._f02uv(f0)
285
+
286
+ # noise: for unvoiced frames the std should be comparable to sine_amp
288
+ # (std = self.sine_amp / 3 gives max values of roughly self.sine_amp);
289
+ # for voiced frames the std is self.noise_std
289
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
290
+ noise = noise_amp * torch.randn_like(sine_waves)
291
+
292
+ # first: set the unvoiced part to 0 by uv
293
+ # then: additive noise
294
+ sine_waves = sine_waves * uv + noise
295
+ return sine_waves, uv, noise
296
+
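+ # SineGen returns harmonic sine stacks (fundamental plus harmonic_num overtones),
+ # gated by the voiced/unvoiced mask and perturbed with level-dependent Gaussian noise.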
297
+
298
+ class SourceModuleHnNSF(torch.nn.Module):
299
+ """SourceModule for hn-nsf
300
+ SourceModule(sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
301
+ add_noise_std=0.003, voiced_threshod=0)
302
+ sampling_rate: sampling_rate in Hz
303
+ harmonic_num: number of harmonic above F0 (default: 0)
304
+ sine_amp: amplitude of sine source signal (default: 0.1)
305
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
306
+ note that amplitude of noise in unvoiced is decided
307
+ by sine_amp
308
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
309
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
310
+ F0_sampled (batchsize, length, 1)
311
+ Sine_source (batchsize, length, 1)
312
+ noise_source (batchsize, length 1)
313
+ uv (batchsize, length, 1)
314
+ """
315
+
316
+ def __init__(
317
+ self,
318
+ sampling_rate,
319
+ upsample_scale,
320
+ harmonic_num=0,
321
+ sine_amp=0.1,
322
+ add_noise_std=0.003,
323
+ voiced_threshod=0,
324
+ ):
325
+ super(SourceModuleHnNSF, self).__init__()
326
+
327
+ self.sine_amp = sine_amp
328
+ self.noise_std = add_noise_std
329
+
330
+ # to produce sine waveforms
331
+ self.l_sin_gen = SineGen(
332
+ sampling_rate,
333
+ upsample_scale,
334
+ harmonic_num,
335
+ sine_amp,
336
+ add_noise_std,
337
+ voiced_threshod,
338
+ )
339
+
340
+ # to merge source harmonics into a single excitation
341
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
342
+ self.l_tanh = torch.nn.Tanh()
343
+
344
+ def forward(self, x):
345
+ """
346
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
347
+ F0_sampled (batchsize, length, 1)
348
+ Sine_source (batchsize, length, 1)
349
+ noise_source (batchsize, length 1)
350
+ """
351
+ # source for harmonic branch
352
+ with torch.no_grad():
353
+ sine_wavs, uv, _ = self.l_sin_gen(x)
354
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
355
+
356
+ # source for noise branch, in the same shape as uv
357
+ noise = torch.randn_like(uv) * self.sine_amp / 3
358
+ return sine_merge, noise, uv
359
+
360
+
361
+ def padDiff(x):
362
+ return F.pad(
363
+ F.pad(x, (0, 0, -1, 1), "constant", 0) - x, (0, 0, 0, -1), "constant", 0
364
+ )
365
+
366
+
367
+ class Generator(torch.nn.Module):
368
+ def __init__(
369
+ self,
370
+ style_dim,
371
+ resblock_kernel_sizes,
372
+ upsample_rates,
373
+ upsample_initial_channel,
374
+ resblock_dilation_sizes,
375
+ upsample_kernel_sizes,
376
+ ):
377
+ super(Generator, self).__init__()
378
+ self.num_kernels = len(resblock_kernel_sizes)
379
+ self.num_upsamples = len(upsample_rates)
380
+ resblock = AdaINResBlock1
381
+
382
+ self.m_source = SourceModuleHnNSF(
383
+ sampling_rate=24000,
384
+ upsample_scale=np.prod(upsample_rates),
385
+ harmonic_num=8,
386
+ voiced_threshod=10,
387
+ )
388
+
389
+ self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates))
390
+ self.noise_convs = nn.ModuleList()
391
+ self.ups = nn.ModuleList()
392
+ self.noise_res = nn.ModuleList()
393
+
394
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
395
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
396
+
397
+ self.ups.append(
398
+ weight_norm(
399
+ ConvTranspose1d(
400
+ upsample_initial_channel // (2**i),
401
+ upsample_initial_channel // (2 ** (i + 1)),
402
+ k,
403
+ u,
404
+ padding=(u // 2 + u % 2),
405
+ output_padding=u % 2,
406
+ )
407
+ )
408
+ )
409
+
410
+ if i + 1 < len(upsample_rates): #
411
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
412
+ self.noise_convs.append(
413
+ Conv1d(
414
+ 1,
415
+ c_cur,
416
+ kernel_size=stride_f0 * 2,
417
+ stride=stride_f0,
418
+ padding=(stride_f0 + 1) // 2,
419
+ )
420
+ )
421
+ self.noise_res.append(resblock(c_cur, 7, [1, 3, 5], style_dim))
422
+ else:
423
+ self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
424
+ self.noise_res.append(resblock(c_cur, 11, [1, 3, 5], style_dim))
425
+
426
+ self.resblocks = nn.ModuleList()
427
+
428
+ self.alphas = nn.ParameterList()
429
+ self.alphas.append(nn.Parameter(torch.ones(1, upsample_initial_channel, 1)))
430
+
431
+ for i in range(len(self.ups)):
432
+ ch = upsample_initial_channel // (2 ** (i + 1))
433
+ self.alphas.append(nn.Parameter(torch.ones(1, ch, 1)))
434
+
435
+ for j, (k, d) in enumerate(
436
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
437
+ ):
438
+ self.resblocks.append(resblock(ch, k, d, style_dim))
439
+
440
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
441
+ self.ups.apply(init_weights)
442
+ self.conv_post.apply(init_weights)
443
+
444
+ def forward(self, x, s, f0):
445
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
446
+
447
+ har_source, noi_source, uv = self.m_source(f0)
448
+ har_source = har_source.transpose(1, 2)
449
+
450
+ for i in range(self.num_upsamples):
451
+ x = x + (1 / self.alphas[i]) * (torch.sin(self.alphas[i] * x) ** 2)
452
+ x_source = self.noise_convs[i](har_source)
453
+ x_source = self.noise_res[i](x_source, s)
454
+
455
+ x = self.ups[i](x)
456
+ x = x + x_source
457
+
458
+ xs = None
459
+ for j in range(self.num_kernels):
460
+ if xs is None:
461
+ xs = self.resblocks[i * self.num_kernels + j](x, s)
462
+ else:
463
+ xs += self.resblocks[i * self.num_kernels + j](x, s)
464
+ x = xs / self.num_kernels
465
+ x = x + (1 / self.alphas[i + 1]) * (torch.sin(self.alphas[i + 1] * x) ** 2)
466
+ x = self.conv_post(x)
467
+ x = torch.tanh(x)
468
+
469
+ return x
470
+
471
+ def remove_weight_norm(self):
472
+ print("Removing weight norm...")
473
+ for l in self.ups:
474
+ remove_weight_norm(l)
475
+ for l in self.resblocks:
476
+ l.remove_weight_norm()
477
+ # this Generator defines no conv_pre, so there is nothing to unnormalize here
478
+ remove_weight_norm(self.conv_post)
479
+
480
+
481
+ class AdainResBlk1d(nn.Module):
482
+ def __init__(
483
+ self,
484
+ dim_in,
485
+ dim_out,
486
+ style_dim=64,
487
+ actv=nn.LeakyReLU(0.2),
488
+ upsample="none",
489
+ dropout_p=0.0,
490
+ ):
491
+ super().__init__()
492
+ self.actv = actv
493
+ self.upsample_type = upsample
494
+ self.upsample = UpSample1d(upsample)
495
+ self.learned_sc = dim_in != dim_out
496
+ self._build_weights(dim_in, dim_out, style_dim)
497
+ self.dropout = nn.Dropout(dropout_p)
498
+
499
+ if upsample == "none":
500
+ self.pool = nn.Identity()
501
+ else:
502
+ self.pool = weight_norm(
503
+ nn.ConvTranspose1d(
504
+ dim_in,
505
+ dim_in,
506
+ kernel_size=3,
507
+ stride=2,
508
+ groups=dim_in,
509
+ padding=1,
510
+ output_padding=1,
511
+ )
512
+ )
513
+
514
+ def _build_weights(self, dim_in, dim_out, style_dim):
515
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
516
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
517
+ self.norm1 = AdaIN1d(style_dim, dim_in)
518
+ self.norm2 = AdaIN1d(style_dim, dim_out)
519
+ if self.learned_sc:
520
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
521
+
522
+ def _shortcut(self, x):
523
+ x = self.upsample(x)
524
+ if self.learned_sc:
525
+ x = self.conv1x1(x)
526
+ return x
527
+
528
+ def _residual(self, x, s):
529
+ x = self.norm1(x, s)
530
+ x = self.actv(x)
531
+ x = self.pool(x)
532
+ x = self.conv1(self.dropout(x))
533
+ x = self.norm2(x, s)
534
+ x = self.actv(x)
535
+ x = self.conv2(self.dropout(x))
536
+ return x
537
+
538
+ def forward(self, x, s):
539
+ out = self._residual(x, s)
540
+ out = (out + self._shortcut(x)) / math.sqrt(2)
541
+ return out
542
+
543
+
544
+ class UpSample1d(nn.Module):
545
+ def __init__(self, layer_type):
546
+ super().__init__()
547
+ self.layer_type = layer_type
548
+
549
+ def forward(self, x):
550
+ if self.layer_type == "none":
551
+ return x
552
+ else:
553
+ return F.interpolate(x, scale_factor=2, mode="nearest")
554
+
555
+
556
+ class Decoder(nn.Module):
557
+ def __init__(
558
+ self,
559
+ dim_in=512,
560
+ F0_channel=512,
561
+ style_dim=64,
562
+ dim_out=80,
563
+ resblock_kernel_sizes=[3, 7, 11],
564
+ upsample_rates=[10, 5, 3, 2],
565
+ upsample_initial_channel=512,
566
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
567
+ upsample_kernel_sizes=[20, 10, 6, 4],
568
+ ):
569
+ super().__init__()
570
+
571
+ self.decode = nn.ModuleList()
572
+
573
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
574
+
575
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
576
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
577
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
578
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
579
+
580
+ self.F0_conv = weight_norm(
581
+ nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
582
+ )
583
+
584
+ self.N_conv = weight_norm(
585
+ nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
586
+ )
587
+
588
+ self.asr_res = nn.Sequential(
589
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
590
+ )
591
+
592
+ self.generator = Generator(
593
+ style_dim,
594
+ resblock_kernel_sizes,
595
+ upsample_rates,
596
+ upsample_initial_channel,
597
+ resblock_dilation_sizes,
598
+ upsample_kernel_sizes,
599
+ )
600
+
601
+ def forward(self, asr, F0_curve, N, s):
602
+ if self.training:
603
+ downlist = [0, 3, 7]
604
+ F0_down = downlist[random.randint(0, 2)]
605
+ downlist = [0, 3, 7, 15]
606
+ N_down = downlist[random.randint(0, 3)]
607
+ if F0_down:
608
+ F0_curve = (
609
+ nn.functional.conv1d(
610
+ F0_curve.unsqueeze(1),
611
+ torch.ones(1, 1, F0_down).to("cuda"),
612
+ padding=F0_down // 2,
613
+ ).squeeze(1)
614
+ / F0_down
615
+ )
616
+ if N_down:
617
+ N = (
618
+ nn.functional.conv1d(
619
+ N.unsqueeze(1),
620
+ torch.ones(1, 1, N_down).to("cuda"),
621
+ padding=N_down // 2,
622
+ ).squeeze(1)
623
+ / N_down
624
+ )
625
+
626
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
627
+ N = self.N_conv(N.unsqueeze(1))
628
+
629
+ x = torch.cat([asr, F0, N], axis=1)
630
+ x = self.encode(x, s)
631
+
632
+ asr_res = self.asr_res(asr)
633
+
634
+ res = True
635
+ for block in self.decode:
636
+ if res:
637
+ x = torch.cat([x, asr_res, F0, N], axis=1)
638
+ x = block(x, s)
639
+ if block.upsample_type != "none":
640
+ res = False
641
+
642
+ x = self.generator(x, s, F0_curve)
643
+ return x
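Note: the `x + (1 / alpha) * (torch.sin(alpha * x) ** 2)` expressions in the Generator above are the Snake periodic activation with a learnable per-channel alpha. Below is a minimal standalone sketch of that activation; the `Snake1d` module name and the toy shapes are illustrative only, not part of this repository.

import torch
import torch.nn as nn

class Snake1d(nn.Module):
    # Snake activation: x + (1/alpha) * sin(alpha * x)^2, with a learnable
    # per-channel alpha, applied element-wise as in the Generator above.
    def __init__(self, channels):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(1, channels, 1))

    def forward(self, x):
        return x + (1.0 / self.alpha) * torch.sin(self.alpha * x) ** 2

x = torch.randn(2, 8, 100)      # (batch, channels, time) -- demo values
print(Snake1d(8)(x).shape)      # torch.Size([2, 8, 100])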
Modules/istftnet.py ADDED
@@ -0,0 +1,720 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
5
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
6
+ from .utils import init_weights, get_padding
7
+
8
+ import math
9
+ import random
10
+ import numpy as np
11
+ from scipy.signal import get_window
12
+
13
+ LRELU_SLOPE = 0.1
14
+
15
+
16
+ class AdaIN1d(nn.Module):
17
+ def __init__(self, style_dim, num_features):
18
+ super().__init__()
19
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
20
+ self.fc = nn.Linear(style_dim, num_features * 2)
21
+
22
+ def forward(self, x, s):
23
+ h = self.fc(s)
24
+ h = h.view(h.size(0), h.size(1), 1)
25
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
26
+ return (1 + gamma) * self.norm(x) + beta
27
+
28
+
29
+ class AdaINResBlock1(torch.nn.Module):
30
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
31
+ super(AdaINResBlock1, self).__init__()
32
+ self.convs1 = nn.ModuleList(
33
+ [
34
+ weight_norm(
35
+ Conv1d(
36
+ channels,
37
+ channels,
38
+ kernel_size,
39
+ 1,
40
+ dilation=dilation[0],
41
+ padding=get_padding(kernel_size, dilation[0]),
42
+ )
43
+ ),
44
+ weight_norm(
45
+ Conv1d(
46
+ channels,
47
+ channels,
48
+ kernel_size,
49
+ 1,
50
+ dilation=dilation[1],
51
+ padding=get_padding(kernel_size, dilation[1]),
52
+ )
53
+ ),
54
+ weight_norm(
55
+ Conv1d(
56
+ channels,
57
+ channels,
58
+ kernel_size,
59
+ 1,
60
+ dilation=dilation[2],
61
+ padding=get_padding(kernel_size, dilation[2]),
62
+ )
63
+ ),
64
+ ]
65
+ )
66
+ self.convs1.apply(init_weights)
67
+
68
+ self.convs2 = nn.ModuleList(
69
+ [
70
+ weight_norm(
71
+ Conv1d(
72
+ channels,
73
+ channels,
74
+ kernel_size,
75
+ 1,
76
+ dilation=1,
77
+ padding=get_padding(kernel_size, 1),
78
+ )
79
+ ),
80
+ weight_norm(
81
+ Conv1d(
82
+ channels,
83
+ channels,
84
+ kernel_size,
85
+ 1,
86
+ dilation=1,
87
+ padding=get_padding(kernel_size, 1),
88
+ )
89
+ ),
90
+ weight_norm(
91
+ Conv1d(
92
+ channels,
93
+ channels,
94
+ kernel_size,
95
+ 1,
96
+ dilation=1,
97
+ padding=get_padding(kernel_size, 1),
98
+ )
99
+ ),
100
+ ]
101
+ )
102
+ self.convs2.apply(init_weights)
103
+
104
+ self.adain1 = nn.ModuleList(
105
+ [
106
+ AdaIN1d(style_dim, channels),
107
+ AdaIN1d(style_dim, channels),
108
+ AdaIN1d(style_dim, channels),
109
+ ]
110
+ )
111
+
112
+ self.adain2 = nn.ModuleList(
113
+ [
114
+ AdaIN1d(style_dim, channels),
115
+ AdaIN1d(style_dim, channels),
116
+ AdaIN1d(style_dim, channels),
117
+ ]
118
+ )
119
+
120
+ self.alpha1 = nn.ParameterList(
121
+ [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))]
122
+ )
123
+ self.alpha2 = nn.ParameterList(
124
+ [nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))]
125
+ )
126
+
127
+ def forward(self, x, s):
128
+ for c1, c2, n1, n2, a1, a2 in zip(
129
+ self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2
130
+ ):
131
+ xt = n1(x, s)
132
+ xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2) # Snake1D
133
+ xt = c1(xt)
134
+ xt = n2(xt, s)
135
+ xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2) # Snake1D
136
+ xt = c2(xt)
137
+ x = xt + x
138
+ return x
139
+
140
+ def remove_weight_norm(self):
141
+ for l in self.convs1:
142
+ remove_weight_norm(l)
143
+ for l in self.convs2:
144
+ remove_weight_norm(l)
145
+
146
+
147
+ class TorchSTFT(torch.nn.Module):
148
+ def __init__(
149
+ self, filter_length=800, hop_length=200, win_length=800, window="hann"
150
+ ):
151
+ super().__init__()
152
+ self.filter_length = filter_length
153
+ self.hop_length = hop_length
154
+ self.win_length = win_length
155
+ self.window = torch.from_numpy(
156
+ get_window(window, win_length, fftbins=True).astype(np.float32)
157
+ )
158
+
159
+ def transform(self, input_data):
160
+ forward_transform = torch.stft(
161
+ input_data,
162
+ self.filter_length,
163
+ self.hop_length,
164
+ self.win_length,
165
+ window=self.window.to(input_data.device),
166
+ return_complex=True,
167
+ )
168
+
169
+ return torch.abs(forward_transform), torch.angle(forward_transform)
170
+
171
+ def inverse(self, magnitude, phase):
172
+ inverse_transform = torch.istft(
173
+ magnitude * torch.exp(phase * 1j),
174
+ self.filter_length,
175
+ self.hop_length,
176
+ self.win_length,
177
+ window=self.window.to(magnitude.device),
178
+ )
179
+
180
+ return inverse_transform.unsqueeze(
181
+ -2
182
+ ) # unsqueeze to stay consistent with conv_transpose1d implementation
183
+
184
+ def forward(self, input_data):
185
+ self.magnitude, self.phase = self.transform(input_data)
186
+ reconstruction = self.inverse(self.magnitude, self.phase)
187
+ return reconstruction
188
+
189
+
190
+ class SineGen(torch.nn.Module):
191
+ """Definition of sine generator
192
+ SineGen(samp_rate, harmonic_num = 0,
193
+ sine_amp = 0.1, noise_std = 0.003,
194
+ voiced_threshold = 0,
195
+ flag_for_pulse=False)
196
+ samp_rate: sampling rate in Hz
197
+ harmonic_num: number of harmonic overtones (default 0)
198
+ sine_amp: amplitude of sine-waveform (default 0.1)
199
+ noise_std: std of Gaussian noise (default 0.003)
200
+ voiced_threshold: F0 threshold for U/V classification (default 0)
201
+ flag_for_pulse: this SineGen is used inside PulseGen (default False)
202
+ Note: when flag_for_pulse is True, the first time step of a voiced
203
+ segment is always sin(np.pi) or cos(0)
204
+ """
205
+
206
+ def __init__(
207
+ self,
208
+ samp_rate,
209
+ upsample_scale,
210
+ harmonic_num=0,
211
+ sine_amp=0.1,
212
+ noise_std=0.003,
213
+ voiced_threshold=0,
214
+ flag_for_pulse=False,
215
+ ):
216
+ super(SineGen, self).__init__()
217
+ self.sine_amp = sine_amp
218
+ self.noise_std = noise_std
219
+ self.harmonic_num = harmonic_num
220
+ self.dim = self.harmonic_num + 1
221
+ self.sampling_rate = samp_rate
222
+ self.voiced_threshold = voiced_threshold
223
+ self.flag_for_pulse = flag_for_pulse
224
+ self.upsample_scale = upsample_scale
225
+
226
+ def _f02uv(self, f0):
227
+ # generate uv signal
228
+ uv = (f0 > self.voiced_threshold).type(torch.float32)
229
+ return uv
230
+
231
+ def _f02sine(self, f0_values):
232
+ """f0_values: (batchsize, length, dim)
233
+ where dim indicates fundamental tone and overtones
234
+ """
235
+ # convert to F0 in rad. The integer part n can be ignored
236
+ # because 2 * np.pi * n doesn't affect phase
237
+ rad_values = (f0_values / self.sampling_rate) % 1
238
+
239
+ # initial phase noise (no noise for fundamental component)
240
+ rand_ini = torch.rand(
241
+ f0_values.shape[0], f0_values.shape[2], device=f0_values.device
242
+ )
243
+ rand_ini[:, 0] = 0
244
+ rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
245
+
246
+ # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
247
+ if not self.flag_for_pulse:
248
+ # # for normal case
249
+
250
+ # # To prevent torch.cumsum numerical overflow,
251
+ # # it is necessary to add -1 whenever \sum_k=1^n rad_value_k > 1.
252
+ # # Buffer tmp_over_one_idx indicates the time step to add -1.
253
+ # # This will not change F0 of sine because (x-1) * 2*pi = x * 2*pi
254
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
255
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
256
+ # cumsum_shift = torch.zeros_like(rad_values)
257
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
258
+
259
+ # phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
260
+ rad_values = torch.nn.functional.interpolate(
261
+ rad_values.transpose(1, 2),
262
+ scale_factor=1 / self.upsample_scale,
263
+ mode="linear",
264
+ ).transpose(1, 2)
265
+
266
+ # tmp_over_one = torch.cumsum(rad_values, 1) % 1
267
+ # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
268
+ # cumsum_shift = torch.zeros_like(rad_values)
269
+ # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
270
+
271
+ phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
272
+ phase = torch.nn.functional.interpolate(
273
+ phase.transpose(1, 2) * self.upsample_scale,
274
+ scale_factor=self.upsample_scale,
275
+ mode="linear",
276
+ ).transpose(1, 2)
277
+ sines = torch.sin(phase)
278
+
279
+ else:
280
+ # If necessary, make sure that the first time step of every
281
+ # voiced segments is sin(pi) or cos(0)
282
+ # This is used for pulse-train generation
283
+
284
+ # identify the last time step in unvoiced segments
285
+ uv = self._f02uv(f0_values)
286
+ uv_1 = torch.roll(uv, shifts=-1, dims=1)
287
+ uv_1[:, -1, :] = 1
288
+ u_loc = (uv < 1) * (uv_1 > 0)
289
+
290
+ # get the instantaneous phase
291
+ tmp_cumsum = torch.cumsum(rad_values, dim=1)
292
+ # different batch needs to be processed differently
293
+ for idx in range(f0_values.shape[0]):
294
+ temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
295
+ temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
296
+ # stores the accumulation of i.phase within
297
+ # each voiced segments
298
+ tmp_cumsum[idx, :, :] = 0
299
+ tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
300
+
301
+ # rad_values - tmp_cumsum: remove the accumulation of i.phase
302
+ # within the previous voiced segment.
303
+ i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
304
+
305
+ # get the sines
306
+ sines = torch.cos(i_phase * 2 * np.pi)
307
+ return sines
308
+
309
+ def forward(self, f0):
310
+ """sine_tensor, uv = forward(f0)
311
+ input F0: tensor(batchsize=1, length, dim=1)
312
+ f0 for unvoiced steps should be 0
313
+ output sine_tensor: tensor(batchsize=1, length, dim)
314
+ output uv: tensor(batchsize=1, length, 1)
315
+ """
316
+ f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
317
+ # fundamental component
318
+ fn = torch.multiply(
319
+ f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device)
320
+ )
321
+
322
+ # generate sine waveforms
323
+ sine_waves = self._f02sine(fn) * self.sine_amp
324
+
325
+ # generate uv signal
326
+ # uv = torch.ones(f0.shape)
327
+ # uv = uv * (f0 > self.voiced_threshold)
328
+ uv = self._f02uv(f0)
329
+
330
+ # noise: for unvoiced should be similar to sine_amp
331
+ # std = self.sine_amp/3 -> max value ~ self.sine_amp
332
+ # for voiced regions the std is self.noise_std
333
+ noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
334
+ noise = noise_amp * torch.randn_like(sine_waves)
335
+
336
+ # first: set the unvoiced part to 0 by uv
337
+ # then: additive noise
338
+ sine_waves = sine_waves * uv + noise
339
+ return sine_waves, uv, noise
340
+
341
+
342
+ class SourceModuleHnNSF(torch.nn.Module):
343
+ """SourceModule for hn-nsf
344
+ SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
345
+ add_noise_std=0.003, voiced_threshod=0)
346
+ sampling_rate: sampling_rate in Hz
347
+ harmonic_num: number of harmonic above F0 (default: 0)
348
+ sine_amp: amplitude of sine source signal (default: 0.1)
349
+ add_noise_std: std of additive Gaussian noise (default: 0.003)
350
+ note that amplitude of noise in unvoiced is decided
351
+ by sine_amp
352
+ voiced_threshold: threshold to set U/V given F0 (default: 0)
353
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
354
+ F0_sampled (batchsize, length, 1)
355
+ Sine_source (batchsize, length, 1)
356
+ noise_source (batchsize, length 1)
357
+ uv (batchsize, length, 1)
358
+ """
359
+
360
+ def __init__(
361
+ self,
362
+ sampling_rate,
363
+ upsample_scale,
364
+ harmonic_num=0,
365
+ sine_amp=0.1,
366
+ add_noise_std=0.003,
367
+ voiced_threshod=0,
368
+ ):
369
+ super(SourceModuleHnNSF, self).__init__()
370
+
371
+ self.sine_amp = sine_amp
372
+ self.noise_std = add_noise_std
373
+
374
+ # to produce sine waveforms
375
+ self.l_sin_gen = SineGen(
376
+ sampling_rate,
377
+ upsample_scale,
378
+ harmonic_num,
379
+ sine_amp,
380
+ add_noise_std,
381
+ voiced_threshod,
382
+ )
383
+
384
+ # to merge source harmonics into a single excitation
385
+ self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
386
+ self.l_tanh = torch.nn.Tanh()
387
+
388
+ def forward(self, x):
389
+ """
390
+ Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
391
+ F0_sampled (batchsize, length, 1)
392
+ Sine_source (batchsize, length, 1)
393
+ noise_source (batchsize, length 1)
394
+ """
395
+ # source for harmonic branch
396
+ with torch.no_grad():
397
+ sine_wavs, uv, _ = self.l_sin_gen(x)
398
+ sine_merge = self.l_tanh(self.l_linear(sine_wavs))
399
+
400
+ # source for noise branch, in the same shape as uv
401
+ noise = torch.randn_like(uv) * self.sine_amp / 3
402
+ return sine_merge, noise, uv
403
+
404
+
405
+ def padDiff(x):
406
+ return F.pad(
407
+ F.pad(x, (0, 0, -1, 1), "constant", 0) - x, (0, 0, 0, -1), "constant", 0
408
+ )
409
+
410
+
411
+ class Generator(torch.nn.Module):
412
+ def __init__(
413
+ self,
414
+ style_dim,
415
+ resblock_kernel_sizes,
416
+ upsample_rates,
417
+ upsample_initial_channel,
418
+ resblock_dilation_sizes,
419
+ upsample_kernel_sizes,
420
+ gen_istft_n_fft,
421
+ gen_istft_hop_size,
422
+ ):
423
+ super(Generator, self).__init__()
424
+
425
+ self.num_kernels = len(resblock_kernel_sizes)
426
+ self.num_upsamples = len(upsample_rates)
427
+ resblock = AdaINResBlock1
428
+
429
+ self.m_source = SourceModuleHnNSF(
430
+ sampling_rate=24000,
431
+ upsample_scale=np.prod(upsample_rates) * gen_istft_hop_size,
432
+ harmonic_num=8,
433
+ voiced_threshod=10,
434
+ )
435
+ self.f0_upsamp = torch.nn.Upsample(
436
+ scale_factor=np.prod(upsample_rates) * gen_istft_hop_size
437
+ )
438
+ self.noise_convs = nn.ModuleList()
439
+ self.noise_res = nn.ModuleList()
440
+
441
+ self.ups = nn.ModuleList()
442
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
443
+ self.ups.append(
444
+ weight_norm(
445
+ ConvTranspose1d(
446
+ upsample_initial_channel // (2**i),
447
+ upsample_initial_channel // (2 ** (i + 1)),
448
+ k,
449
+ u,
450
+ padding=(k - u) // 2,
451
+ )
452
+ )
453
+ )
454
+
455
+ self.resblocks = nn.ModuleList()
456
+ for i in range(len(self.ups)):
457
+ ch = upsample_initial_channel // (2 ** (i + 1))
458
+ for j, (k, d) in enumerate(
459
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
460
+ ):
461
+ self.resblocks.append(resblock(ch, k, d, style_dim))
462
+
463
+ c_cur = upsample_initial_channel // (2 ** (i + 1))
464
+
465
+ if i + 1 < len(upsample_rates): #
466
+ stride_f0 = np.prod(upsample_rates[i + 1 :])
467
+ self.noise_convs.append(
468
+ Conv1d(
469
+ gen_istft_n_fft + 2,
470
+ c_cur,
471
+ kernel_size=stride_f0 * 2,
472
+ stride=stride_f0,
473
+ padding=(stride_f0 + 1) // 2,
474
+ )
475
+ )
476
+ self.noise_res.append(resblock(c_cur, 7, [1, 3, 5], style_dim))
477
+ else:
478
+ self.noise_convs.append(
479
+ Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1)
480
+ )
481
+ self.noise_res.append(resblock(c_cur, 11, [1, 3, 5], style_dim))
482
+
483
+ self.post_n_fft = gen_istft_n_fft
484
+ self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
485
+ self.ups.apply(init_weights)
486
+ self.conv_post.apply(init_weights)
487
+ self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
488
+ self.stft = TorchSTFT(
489
+ filter_length=gen_istft_n_fft,
490
+ hop_length=gen_istft_hop_size,
491
+ win_length=gen_istft_n_fft,
492
+ )
493
+
494
+ def forward(self, x, s, f0):
495
+ with torch.no_grad():
496
+ f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
497
+
498
+ har_source, noi_source, uv = self.m_source(f0)
499
+ har_source = har_source.transpose(1, 2).squeeze(1)
500
+ har_spec, har_phase = self.stft.transform(har_source)
501
+ har = torch.cat([har_spec, har_phase], dim=1)
502
+
503
+ for i in range(self.num_upsamples):
504
+ x = F.leaky_relu(x, LRELU_SLOPE)
505
+ x_source = self.noise_convs[i](har)
506
+ x_source = self.noise_res[i](x_source, s)
507
+
508
+ x = self.ups[i](x)
509
+ if i == self.num_upsamples - 1:
510
+ x = self.reflection_pad(x)
511
+
512
+ x = x + x_source
513
+ xs = None
514
+ for j in range(self.num_kernels):
515
+ if xs is None:
516
+ xs = self.resblocks[i * self.num_kernels + j](x, s)
517
+ else:
518
+ xs += self.resblocks[i * self.num_kernels + j](x, s)
519
+ x = xs / self.num_kernels
520
+ x = F.leaky_relu(x)
521
+ x = self.conv_post(x)
522
+ spec = torch.exp(x[:, : self.post_n_fft // 2 + 1, :])
523
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1 :, :])
524
+ return self.stft.inverse(spec, phase)
525
+
526
+ def fw_phase(self, x, s):
527
+ for i in range(self.num_upsamples):
528
+ x = F.leaky_relu(x, LRELU_SLOPE)
529
+ x = self.ups[i](x)
530
+ xs = None
531
+ for j in range(self.num_kernels):
532
+ if xs is None:
533
+ xs = self.resblocks[i * self.num_kernels + j](x, s)
534
+ else:
535
+ xs += self.resblocks[i * self.num_kernels + j](x, s)
536
+ x = xs / self.num_kernels
537
+ x = F.leaky_relu(x)
538
+ x = self.reflection_pad(x)
539
+ x = self.conv_post(x)
540
+ spec = torch.exp(x[:, : self.post_n_fft // 2 + 1, :])
541
+ phase = torch.sin(x[:, self.post_n_fft // 2 + 1 :, :])
542
+ return spec, phase
543
+
544
+ def remove_weight_norm(self):
545
+ print("Removing weight norm...")
546
+ for l in self.ups:
547
+ remove_weight_norm(l)
548
+ for l in self.resblocks:
549
+ l.remove_weight_norm()
550
+ remove_weight_norm(self.conv_pre)
551
+ remove_weight_norm(self.conv_post)
552
+
553
+
554
+ class AdainResBlk1d(nn.Module):
555
+ def __init__(
556
+ self,
557
+ dim_in,
558
+ dim_out,
559
+ style_dim=64,
560
+ actv=nn.LeakyReLU(0.2),
561
+ upsample="none",
562
+ dropout_p=0.0,
563
+ ):
564
+ super().__init__()
565
+ self.actv = actv
566
+ self.upsample_type = upsample
567
+ self.upsample = UpSample1d(upsample)
568
+ self.learned_sc = dim_in != dim_out
569
+ self._build_weights(dim_in, dim_out, style_dim)
570
+ self.dropout = nn.Dropout(dropout_p)
571
+
572
+ if upsample == "none":
573
+ self.pool = nn.Identity()
574
+ else:
575
+ self.pool = weight_norm(
576
+ nn.ConvTranspose1d(
577
+ dim_in,
578
+ dim_in,
579
+ kernel_size=3,
580
+ stride=2,
581
+ groups=dim_in,
582
+ padding=1,
583
+ output_padding=1,
584
+ )
585
+ )
586
+
587
+ def _build_weights(self, dim_in, dim_out, style_dim):
588
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
589
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
590
+ self.norm1 = AdaIN1d(style_dim, dim_in)
591
+ self.norm2 = AdaIN1d(style_dim, dim_out)
592
+ if self.learned_sc:
593
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
594
+
595
+ def _shortcut(self, x):
596
+ x = self.upsample(x)
597
+ if self.learned_sc:
598
+ x = self.conv1x1(x)
599
+ return x
600
+
601
+ def _residual(self, x, s):
602
+ x = self.norm1(x, s)
603
+ x = self.actv(x)
604
+ x = self.pool(x)
605
+ x = self.conv1(self.dropout(x))
606
+ x = self.norm2(x, s)
607
+ x = self.actv(x)
608
+ x = self.conv2(self.dropout(x))
609
+ return x
610
+
611
+ def forward(self, x, s):
612
+ out = self._residual(x, s)
613
+ out = (out + self._shortcut(x)) / math.sqrt(2)
614
+ return out
615
+
616
+
617
+ class UpSample1d(nn.Module):
618
+ def __init__(self, layer_type):
619
+ super().__init__()
620
+ self.layer_type = layer_type
621
+
622
+ def forward(self, x):
623
+ if self.layer_type == "none":
624
+ return x
625
+ else:
626
+ return F.interpolate(x, scale_factor=2, mode="nearest")
627
+
628
+
629
+ class Decoder(nn.Module):
630
+ def __init__(
631
+ self,
632
+ dim_in=512,
633
+ F0_channel=512,
634
+ style_dim=64,
635
+ dim_out=80,
636
+ resblock_kernel_sizes=[3, 7, 11],
637
+ upsample_rates=[10, 6],
638
+ upsample_initial_channel=512,
639
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
640
+ upsample_kernel_sizes=[20, 12],
641
+ gen_istft_n_fft=20,
642
+ gen_istft_hop_size=5,
643
+ ):
644
+ super().__init__()
645
+
646
+ self.decode = nn.ModuleList()
647
+
648
+ self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
649
+
650
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
651
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
652
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
653
+ self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
654
+
655
+ self.F0_conv = weight_norm(
656
+ nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
657
+ )
658
+
659
+ self.N_conv = weight_norm(
660
+ nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1)
661
+ )
662
+
663
+ self.asr_res = nn.Sequential(
664
+ weight_norm(nn.Conv1d(512, 64, kernel_size=1)),
665
+ )
666
+
667
+ self.generator = Generator(
668
+ style_dim,
669
+ resblock_kernel_sizes,
670
+ upsample_rates,
671
+ upsample_initial_channel,
672
+ resblock_dilation_sizes,
673
+ upsample_kernel_sizes,
674
+ gen_istft_n_fft,
675
+ gen_istft_hop_size,
676
+ )
677
+
678
+ def forward(self, asr, F0_curve, N, s):
679
+ if self.training:
680
+ downlist = [0, 3, 7]
681
+ F0_down = downlist[random.randint(0, 2)]
682
+ downlist = [0, 3, 7, 15]
683
+ N_down = downlist[random.randint(0, 3)]
684
+ if F0_down:
685
+ F0_curve = (
686
+ nn.functional.conv1d(
687
+ F0_curve.unsqueeze(1),
688
+ torch.ones(1, 1, F0_down).to("cuda"),
689
+ padding=F0_down // 2,
690
+ ).squeeze(1)
691
+ / F0_down
692
+ )
693
+ if N_down:
694
+ N = (
695
+ nn.functional.conv1d(
696
+ N.unsqueeze(1),
697
+ torch.ones(1, 1, N_down).to("cuda"),
698
+ padding=N_down // 2,
699
+ ).squeeze(1)
700
+ / N_down
701
+ )
702
+
703
+ F0 = self.F0_conv(F0_curve.unsqueeze(1))
704
+ N = self.N_conv(N.unsqueeze(1))
705
+
706
+ x = torch.cat([asr, F0, N], axis=1)
707
+ x = self.encode(x, s)
708
+
709
+ asr_res = self.asr_res(asr)
710
+
711
+ res = True
712
+ for block in self.decode:
713
+ if res:
714
+ x = torch.cat([x, asr_res, F0, N], axis=1)
715
+ x = block(x, s)
716
+ if block.upsample_type != "none":
717
+ res = False
718
+
719
+ x = self.generator(x, s, F0_curve)
720
+ return x
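Note: the iSTFT-based Generator above predicts `gen_istft_n_fft + 2` channels, splits them into a magnitude (through exp) and a phase (through sin), and reconstructs audio with torch.istft. Below is a rough sketch of the magnitude/phase round trip that TorchSTFT performs; the Hann window and the toy sizes are illustrative assumptions (the module itself builds its window with scipy.signal.get_window).

import torch

n_fft, hop, win = 20, 5, 20            # demo values, in the spirit of the defaults above
window = torch.hann_window(win)

x = torch.randn(1, 4800)               # (batch, samples)
spec = torch.stft(x, n_fft, hop, win, window=window, return_complex=True)
mag, phase = torch.abs(spec), torch.angle(spec)

# the generator head predicts a magnitude and a phase and inverts them the same way:
y = torch.istft(mag * torch.exp(1j * phase), n_fft, hop, win, window=window)
print(x.shape, y.shape)                # reconstruction matches the original length here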
Modules/slmadv.py ADDED
@@ -0,0 +1,256 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class SLMAdversarialLoss(torch.nn.Module):
7
+ def __init__(
8
+ self,
9
+ model,
10
+ wl,
11
+ sampler,
12
+ min_len,
13
+ max_len,
14
+ batch_percentage=0.5,
15
+ skip_update=10,
16
+ sig=1.5,
17
+ ):
18
+ super(SLMAdversarialLoss, self).__init__()
19
+ self.model = model
20
+ self.wl = wl
21
+ self.sampler = sampler
22
+
23
+ self.min_len = min_len
24
+ self.max_len = max_len
25
+ self.batch_percentage = batch_percentage
26
+
27
+ self.sig = sig
28
+ self.skip_update = skip_update
29
+
30
+ def forward(
31
+ self,
32
+ iters,
33
+ y_rec_gt,
34
+ y_rec_gt_pred,
35
+ waves,
36
+ mel_input_length,
37
+ ref_text,
38
+ ref_lengths,
39
+ use_ind,
40
+ s_trg,
41
+ ref_s=None,
42
+ ):
43
+ text_mask = length_to_mask(ref_lengths).to(ref_text.device)
44
+ bert_dur = self.model.bert(ref_text, attention_mask=(~text_mask).int())
45
+ d_en = self.model.bert_encoder(bert_dur).transpose(-1, -2)
46
+
47
+ if use_ind and np.random.rand() < 0.5:
48
+ s_preds = s_trg
49
+ else:
50
+ num_steps = np.random.randint(3, 5)
51
+ if ref_s is not None:
52
+ s_preds = self.sampler(
53
+ noise=torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
54
+ embedding=bert_dur,
55
+ embedding_scale=1,
56
+ features=ref_s, # reference from the same speaker as the embedding
57
+ embedding_mask_proba=0.1,
58
+ num_steps=num_steps,
59
+ ).squeeze(1)
60
+ else:
61
+ s_preds = self.sampler(
62
+ noise=torch.randn_like(s_trg).unsqueeze(1).to(ref_text.device),
63
+ embedding=bert_dur,
64
+ embedding_scale=1,
65
+ embedding_mask_proba=0.1,
66
+ num_steps=num_steps,
67
+ ).squeeze(1)
68
+
69
+ s_dur = s_preds[:, 128:]
70
+ s = s_preds[:, :128]
71
+
72
+ d, _ = self.model.predictor(
73
+ d_en,
74
+ s_dur,
75
+ ref_lengths,
76
+ torch.randn(ref_lengths.shape[0], ref_lengths.max(), 2).to(ref_text.device),
77
+ text_mask,
78
+ )
79
+
80
+ bib = 0
81
+
82
+ output_lengths = []
83
+ attn_preds = []
84
+
85
+ # differentiable duration modeling
86
+ for _s2s_pred, _text_length in zip(d, ref_lengths):
87
+ _s2s_pred_org = _s2s_pred[:_text_length, :]
88
+
89
+ _s2s_pred = torch.sigmoid(_s2s_pred_org)
90
+ _dur_pred = _s2s_pred.sum(axis=-1)
91
+
92
+ l = int(torch.round(_s2s_pred.sum()).item())
93
+ t = torch.arange(0, l).expand(l)
94
+
95
+ t = (
96
+ torch.arange(0, l)
97
+ .unsqueeze(0)
98
+ .expand((len(_s2s_pred), l))
99
+ .to(ref_text.device)
100
+ )
101
+ loc = torch.cumsum(_dur_pred, dim=0) - _dur_pred / 2
102
+
103
+ h = torch.exp(
104
+ -0.5 * torch.square(t - (l - loc.unsqueeze(-1))) / (self.sig) ** 2
105
+ )
106
+
107
+ out = torch.nn.functional.conv1d(
108
+ _s2s_pred_org.unsqueeze(0),
109
+ h.unsqueeze(1),
110
+ padding=h.shape[-1] - 1,
111
+ groups=int(_text_length),
112
+ )[..., :l]
113
+ attn_preds.append(F.softmax(out.squeeze(), dim=0))
114
+
115
+ output_lengths.append(l)
116
+
117
+ max_len = max(output_lengths)
118
+
119
+ with torch.no_grad():
120
+ t_en = self.model.text_encoder(ref_text, ref_lengths, text_mask)
121
+
122
+ s2s_attn = torch.zeros(len(ref_lengths), int(ref_lengths.max()), max_len).to(
123
+ ref_text.device
124
+ )
125
+ for bib in range(len(output_lengths)):
126
+ s2s_attn[bib, : ref_lengths[bib], : output_lengths[bib]] = attn_preds[bib]
127
+
128
+ asr_pred = t_en @ s2s_attn
129
+
130
+ _, p_pred = self.model.predictor(d_en, s_dur, ref_lengths, s2s_attn, text_mask)
131
+
132
+ mel_len = max(int(min(output_lengths) / 2 - 1), self.min_len // 2)
133
+ mel_len = min(mel_len, self.max_len // 2)
134
+
135
+ # get clips
136
+
137
+ en = []
138
+ p_en = []
139
+ sp = []
140
+
141
+ F0_fakes = []
142
+ N_fakes = []
143
+
144
+ wav = []
145
+
146
+ for bib in range(len(output_lengths)):
147
+ mel_length_pred = output_lengths[bib]
148
+ mel_length_gt = int(mel_input_length[bib].item() / 2)
149
+ if mel_length_gt <= mel_len or mel_length_pred <= mel_len:
150
+ continue
151
+
152
+ sp.append(s_preds[bib])
153
+
154
+ random_start = np.random.randint(0, mel_length_pred - mel_len)
155
+ en.append(asr_pred[bib, :, random_start : random_start + mel_len])
156
+ p_en.append(p_pred[bib, :, random_start : random_start + mel_len])
157
+
158
+ # get ground truth clips
159
+ random_start = np.random.randint(0, mel_length_gt - mel_len)
160
+ y = waves[bib][
161
+ (random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300
162
+ ]
163
+ wav.append(torch.from_numpy(y).to(ref_text.device))
164
+
165
+ if len(wav) >= self.batch_percentage * len(
166
+ waves
167
+ ): # prevent OOM due to longer lengths
168
+ break
169
+
170
+ if len(sp) <= 1:
171
+ return None
172
+
173
+ sp = torch.stack(sp)
174
+ wav = torch.stack(wav).float()
175
+ en = torch.stack(en)
176
+ p_en = torch.stack(p_en)
177
+
178
+ F0_fake, N_fake = self.model.predictor.F0Ntrain(p_en, sp[:, 128:])
179
+ y_pred = self.model.decoder(en, F0_fake, N_fake, sp[:, :128])
180
+
181
+ # discriminator loss
182
+ if (iters + 1) % self.skip_update == 0:
183
+ if np.random.randint(0, 2) == 0:
184
+ wav = y_rec_gt_pred
185
+ use_rec = True
186
+ else:
187
+ use_rec = False
188
+
189
+ crop_size = min(wav.size(-1), y_pred.size(-1))
190
+ if (
191
+ use_rec
192
+ ): # use reconstructed (shorter lengths), do length invariant regularization
193
+ if wav.size(-1) > y_pred.size(-1):
194
+ real_GP = wav[:, :, :crop_size]
195
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
196
+ out_org = self.wl.discriminator_forward(wav.detach().squeeze())
197
+ loss_reg = F.l1_loss(out_crop, out_org[..., : out_crop.size(-1)])
198
+
199
+ if np.random.randint(0, 2) == 0:
200
+ d_loss = self.wl.discriminator(
201
+ real_GP.detach().squeeze(), y_pred.detach().squeeze()
202
+ ).mean()
203
+ else:
204
+ d_loss = self.wl.discriminator(
205
+ wav.detach().squeeze(), y_pred.detach().squeeze()
206
+ ).mean()
207
+ else:
208
+ real_GP = y_pred[:, :, :crop_size]
209
+ out_crop = self.wl.discriminator_forward(real_GP.detach().squeeze())
210
+ out_org = self.wl.discriminator_forward(y_pred.detach().squeeze())
211
+ loss_reg = F.l1_loss(out_crop, out_org[..., : out_crop.size(-1)])
212
+
213
+ if np.random.randint(0, 2) == 0:
214
+ d_loss = self.wl.discriminator(
215
+ wav.detach().squeeze(), real_GP.detach().squeeze()
216
+ ).mean()
217
+ else:
218
+ d_loss = self.wl.discriminator(
219
+ wav.detach().squeeze(), y_pred.detach().squeeze()
220
+ ).mean()
221
+
222
+ # regularization (ignore length variation)
223
+ d_loss += loss_reg
224
+
225
+ out_gt = self.wl.discriminator_forward(y_rec_gt.detach().squeeze())
226
+ out_rec = self.wl.discriminator_forward(
227
+ y_rec_gt_pred.detach().squeeze()
228
+ )
229
+
230
+ # regularization (ignore reconstruction artifacts)
231
+ d_loss += F.l1_loss(out_gt, out_rec)
232
+
233
+ else:
234
+ d_loss = self.wl.discriminator(
235
+ wav.detach().squeeze(), y_pred.detach().squeeze()
236
+ ).mean()
237
+ else:
238
+ d_loss = 0
239
+
240
+ # generator loss
241
+ gen_loss = self.wl.generator(y_pred.squeeze())
242
+
243
+ gen_loss = gen_loss.mean()
244
+
245
+ return d_loss, gen_loss, y_pred.detach().cpu().numpy()
246
+
247
+
248
+ def length_to_mask(lengths):
249
+ mask = (
250
+ torch.arange(lengths.max())
251
+ .unsqueeze(0)
252
+ .expand(lengths.shape[0], -1)
253
+ .type_as(lengths)
254
+ )
255
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
256
+ return mask
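Note: length_to_mask above marks padded positions with True, and the model passes `(~text_mask)` as the attention mask. A quick self-contained illustration (the example lengths are arbitrary):

import torch

def length_to_mask(lengths):
    # True where a position lies beyond the sequence length, i.e. padding
    mask = (
        torch.arange(lengths.max())
        .unsqueeze(0)
        .expand(lengths.shape[0], -1)
        .type_as(lengths)
    )
    return torch.gt(mask + 1, lengths.unsqueeze(1))

print(length_to_mask(torch.tensor([2, 4])))
# tensor([[False, False,  True,  True],
#         [False, False, False, False]])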
Modules/utils.py ADDED
@@ -0,0 +1,14 @@
1
+ def init_weights(m, mean=0.0, std=0.01):
2
+ classname = m.__class__.__name__
3
+ if classname.find("Conv") != -1:
4
+ m.weight.data.normal_(mean, std)
5
+
6
+
7
+ def apply_weight_norm(m):
8
+ classname = m.__class__.__name__
9
+ if classname.find("Conv") != -1:
10
+ weight_norm(m)
11
+
12
+
13
+ def get_padding(kernel_size, dilation=1):
14
+ return int((kernel_size * dilation - dilation) / 2)
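Note: get_padding above returns the "same" padding for a stride-1 dilated convolution, i.e. (kernel_size - 1) * dilation / 2, so odd kernel sizes keep the time dimension unchanged. A small check with arbitrary demo sizes:

import torch
import torch.nn as nn

def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)

x = torch.randn(1, 4, 50)
conv = nn.Conv1d(4, 4, kernel_size=3, dilation=5, padding=get_padding(3, 5))
print(conv(x).shape)  # torch.Size([1, 4, 50]) -- time length preserved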
README.md ADDED
@@ -0,0 +1,15 @@
1
+ ---
2
+ title: StyleTTS 2
3
+ emoji: 🗣️
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 4.5.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: other
11
+ ---
12
+
13
+ LICENSE FOR STYLETTS2: MIT LICENSE
14
+
15
+ LICENSE FOR STYLETTS2 DEMO PAGE: © 2023 MRFAKENAME. ALL RIGHTS RESERVED.
Utils/ASR/__init__.py ADDED
@@ -0,0 +1 @@
1
+
Utils/ASR/config.yml ADDED
@@ -0,0 +1,29 @@
1
+ log_dir: "logs/20201006"
2
+ save_freq: 5
3
+ device: "cuda"
4
+ epochs: 180
5
+ batch_size: 64
6
+ pretrained_model: ""
7
+ train_data: "ASRDataset/train_list.txt"
8
+ val_data: "ASRDataset/val_list.txt"
9
+
10
+ dataset_params:
11
+ data_augmentation: false
12
+
13
+ preprocess_parasm:
14
+ sr: 24000
15
+ spect_params:
16
+ n_fft: 2048
17
+ win_length: 1200
18
+ hop_length: 300
19
+ mel_params:
20
+ n_mels: 80
21
+
22
+ model_params:
23
+ input_dim: 80
24
+ hidden_dim: 256
25
+ n_token: 178
26
+ token_embedding_dim: 512
27
+
28
+ optimizer_params:
29
+ lr: 0.0005
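Note: this ASR config is plain YAML. A minimal sketch of reading it, assuming PyYAML is installed and the file sits at the path used in this repo:

import yaml  # PyYAML, assumed available

with open("Utils/ASR/config.yml") as f:
    config = yaml.safe_load(f)

# e.g. the model hyperparameters declared above
print(config["model_params"]["input_dim"], config["model_params"]["n_token"])  # 80 178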
Utils/ASR/epoch_00080.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fedd55a1234b0c56e1e8b509c74edf3a5e2f27106a66038a4a946047a775bd6c
3
+ size 94552811
Utils/ASR/layers.py ADDED
@@ -0,0 +1,455 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from typing import Optional, Any
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import torchaudio.functional as audio_F
9
+
10
+ import random
11
+
12
+ random.seed(0)
13
+
14
+
15
+ def _get_activation_fn(activ):
16
+ if activ == "relu":
17
+ return nn.ReLU()
18
+ elif activ == "lrelu":
19
+ return nn.LeakyReLU(0.2)
20
+ elif activ == "swish":
21
+ return lambda x: x * torch.sigmoid(x)
22
+ else:
23
+ raise RuntimeError(
24
+ "Unexpected activ type %s, expected [relu, lrelu, swish]" % activ
25
+ )
26
+
27
+
28
+ class LinearNorm(torch.nn.Module):
29
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
30
+ super(LinearNorm, self).__init__()
31
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
32
+
33
+ torch.nn.init.xavier_uniform_(
34
+ self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
35
+ )
36
+
37
+ def forward(self, x):
38
+ return self.linear_layer(x)
39
+
40
+
41
+ class ConvNorm(torch.nn.Module):
42
+ def __init__(
43
+ self,
44
+ in_channels,
45
+ out_channels,
46
+ kernel_size=1,
47
+ stride=1,
48
+ padding=None,
49
+ dilation=1,
50
+ bias=True,
51
+ w_init_gain="linear",
52
+ param=None,
53
+ ):
54
+ super(ConvNorm, self).__init__()
55
+ if padding is None:
56
+ assert kernel_size % 2 == 1
57
+ padding = int(dilation * (kernel_size - 1) / 2)
58
+
59
+ self.conv = torch.nn.Conv1d(
60
+ in_channels,
61
+ out_channels,
62
+ kernel_size=kernel_size,
63
+ stride=stride,
64
+ padding=padding,
65
+ dilation=dilation,
66
+ bias=bias,
67
+ )
68
+
69
+ torch.nn.init.xavier_uniform_(
70
+ self.conv.weight,
71
+ gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
72
+ )
73
+
74
+ def forward(self, signal):
75
+ conv_signal = self.conv(signal)
76
+ return conv_signal
77
+
78
+
79
+ class CausualConv(nn.Module):
80
+ def __init__(
81
+ self,
82
+ in_channels,
83
+ out_channels,
84
+ kernel_size=1,
85
+ stride=1,
86
+ padding=1,
87
+ dilation=1,
88
+ bias=True,
89
+ w_init_gain="linear",
90
+ param=None,
91
+ ):
92
+ super(CausualConv, self).__init__()
93
+ if padding is None:
94
+ assert kernel_size % 2 == 1
95
+ padding = int(dilation * (kernel_size - 1) / 2) * 2
96
+ else:
97
+ self.padding = padding * 2
98
+ self.conv = nn.Conv1d(
99
+ in_channels,
100
+ out_channels,
101
+ kernel_size=kernel_size,
102
+ stride=stride,
103
+ padding=self.padding,
104
+ dilation=dilation,
105
+ bias=bias,
106
+ )
107
+
108
+ torch.nn.init.xavier_uniform_(
109
+ self.conv.weight,
110
+ gain=torch.nn.init.calculate_gain(w_init_gain, param=param),
111
+ )
112
+
113
+ def forward(self, x):
114
+ x = self.conv(x)
115
+ x = x[:, :, : -self.padding]
116
+ return x
117
+
118
+
119
+ class CausualBlock(nn.Module):
120
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="lrelu"):
121
+ super(CausualBlock, self).__init__()
122
+ self.blocks = nn.ModuleList(
123
+ [
124
+ self._get_conv(
125
+ hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
126
+ )
127
+ for i in range(n_conv)
128
+ ]
129
+ )
130
+
131
+ def forward(self, x):
132
+ for block in self.blocks:
133
+ res = x
134
+ x = block(x)
135
+ x += res
136
+ return x
137
+
138
+ def _get_conv(self, hidden_dim, dilation, activ="lrelu", dropout_p=0.2):
139
+ layers = [
140
+ CausualConv(
141
+ hidden_dim,
142
+ hidden_dim,
143
+ kernel_size=3,
144
+ padding=dilation,
145
+ dilation=dilation,
146
+ ),
147
+ _get_activation_fn(activ),
148
+ nn.BatchNorm1d(hidden_dim),
149
+ nn.Dropout(p=dropout_p),
150
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
151
+ _get_activation_fn(activ),
152
+ nn.Dropout(p=dropout_p),
153
+ ]
154
+ return nn.Sequential(*layers)
155
+
156
+
157
+ class ConvBlock(nn.Module):
158
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ="relu"):
159
+ super().__init__()
160
+ self._n_groups = 8
161
+ self.blocks = nn.ModuleList(
162
+ [
163
+ self._get_conv(
164
+ hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p
165
+ )
166
+ for i in range(n_conv)
167
+ ]
168
+ )
169
+
170
+ def forward(self, x):
171
+ for block in self.blocks:
172
+ res = x
173
+ x = block(x)
174
+ x += res
175
+ return x
176
+
177
+ def _get_conv(self, hidden_dim, dilation, activ="relu", dropout_p=0.2):
178
+ layers = [
179
+ ConvNorm(
180
+ hidden_dim,
181
+ hidden_dim,
182
+ kernel_size=3,
183
+ padding=dilation,
184
+ dilation=dilation,
185
+ ),
186
+ _get_activation_fn(activ),
187
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
188
+ nn.Dropout(p=dropout_p),
189
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
190
+ _get_activation_fn(activ),
191
+ nn.Dropout(p=dropout_p),
192
+ ]
193
+ return nn.Sequential(*layers)
194
+
195
+
196
+ class LocationLayer(nn.Module):
197
+ def __init__(self, attention_n_filters, attention_kernel_size, attention_dim):
198
+ super(LocationLayer, self).__init__()
199
+ padding = int((attention_kernel_size - 1) / 2)
200
+ self.location_conv = ConvNorm(
201
+ 2,
202
+ attention_n_filters,
203
+ kernel_size=attention_kernel_size,
204
+ padding=padding,
205
+ bias=False,
206
+ stride=1,
207
+ dilation=1,
208
+ )
209
+ self.location_dense = LinearNorm(
210
+ attention_n_filters, attention_dim, bias=False, w_init_gain="tanh"
211
+ )
212
+
213
+ def forward(self, attention_weights_cat):
214
+ processed_attention = self.location_conv(attention_weights_cat)
215
+ processed_attention = processed_attention.transpose(1, 2)
216
+ processed_attention = self.location_dense(processed_attention)
217
+ return processed_attention
218
+
219
+
220
+ class Attention(nn.Module):
221
+ def __init__(
222
+ self,
223
+ attention_rnn_dim,
224
+ embedding_dim,
225
+ attention_dim,
226
+ attention_location_n_filters,
227
+ attention_location_kernel_size,
228
+ ):
229
+ super(Attention, self).__init__()
230
+ self.query_layer = LinearNorm(
231
+ attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
232
+ )
233
+ self.memory_layer = LinearNorm(
234
+ embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
235
+ )
236
+ self.v = LinearNorm(attention_dim, 1, bias=False)
237
+ self.location_layer = LocationLayer(
238
+ attention_location_n_filters, attention_location_kernel_size, attention_dim
239
+ )
240
+ self.score_mask_value = -float("inf")
241
+
242
+ def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
243
+ """
244
+ PARAMS
245
+ ------
246
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
247
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
248
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
249
+ RETURNS
250
+ -------
251
+ alignment (batch, max_time)
252
+ """
253
+
254
+ processed_query = self.query_layer(query.unsqueeze(1))
255
+ processed_attention_weights = self.location_layer(attention_weights_cat)
256
+ energies = self.v(
257
+ torch.tanh(processed_query + processed_attention_weights + processed_memory)
258
+ )
259
+
260
+ energies = energies.squeeze(-1)
261
+ return energies
262
+
263
+ def forward(
264
+ self,
265
+ attention_hidden_state,
266
+ memory,
267
+ processed_memory,
268
+ attention_weights_cat,
269
+ mask,
270
+ ):
271
+ """
272
+ PARAMS
273
+ ------
274
+ attention_hidden_state: attention rnn last output
275
+ memory: encoder outputs
276
+ processed_memory: processed encoder outputs
277
+ attention_weights_cat: previous and cumulative attention weights
278
+ mask: binary mask for padded data
279
+ """
280
+ alignment = self.get_alignment_energies(
281
+ attention_hidden_state, processed_memory, attention_weights_cat
282
+ )
283
+
284
+ if mask is not None:
285
+ alignment.data.masked_fill_(mask, self.score_mask_value)
286
+
287
+ attention_weights = F.softmax(alignment, dim=1)
288
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
289
+ attention_context = attention_context.squeeze(1)
290
+
291
+ return attention_context, attention_weights
292
+
293
+
294
+ class ForwardAttentionV2(nn.Module):
295
+ def __init__(
296
+ self,
297
+ attention_rnn_dim,
298
+ embedding_dim,
299
+ attention_dim,
300
+ attention_location_n_filters,
301
+ attention_location_kernel_size,
302
+ ):
303
+ super(ForwardAttentionV2, self).__init__()
304
+ self.query_layer = LinearNorm(
305
+ attention_rnn_dim, attention_dim, bias=False, w_init_gain="tanh"
306
+ )
307
+ self.memory_layer = LinearNorm(
308
+ embedding_dim, attention_dim, bias=False, w_init_gain="tanh"
309
+ )
310
+ self.v = LinearNorm(attention_dim, 1, bias=False)
311
+ self.location_layer = LocationLayer(
312
+ attention_location_n_filters, attention_location_kernel_size, attention_dim
313
+ )
314
+ self.score_mask_value = -float(1e20)
315
+
316
+ def get_alignment_energies(self, query, processed_memory, attention_weights_cat):
317
+ """
318
+ PARAMS
319
+ ------
320
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
321
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
322
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
323
+ RETURNS
324
+ -------
325
+ alignment (batch, max_time)
326
+ """
327
+
328
+ processed_query = self.query_layer(query.unsqueeze(1))
329
+ processed_attention_weights = self.location_layer(attention_weights_cat)
330
+ energies = self.v(
331
+ torch.tanh(processed_query + processed_attention_weights + processed_memory)
332
+ )
333
+
334
+ energies = energies.squeeze(-1)
335
+ return energies
336
+
337
+ def forward(
338
+ self,
339
+ attention_hidden_state,
340
+ memory,
341
+ processed_memory,
342
+ attention_weights_cat,
343
+ mask,
344
+ log_alpha,
345
+ ):
346
+ """
347
+ PARAMS
348
+ ------
349
+ attention_hidden_state: attention rnn last output
350
+ memory: encoder outputs
351
+ processed_memory: processed encoder outputs
352
+ attention_weights_cat: previous and cumulative attention weights
353
+ mask: binary mask for padded data
354
+ """
355
+ log_energy = self.get_alignment_energies(
356
+ attention_hidden_state, processed_memory, attention_weights_cat
357
+ )
358
+
359
+ # log_energy =
360
+
361
+ if mask is not None:
362
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
363
+
364
+ # attention_weights = F.softmax(alignment, dim=1)
365
+
366
+ # content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
367
+ # log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
368
+
369
+ # log_total_score = log_alpha + content_score
370
+
371
+ # previous_attention_weights = attention_weights_cat[:,0,:]
372
+
373
+ log_alpha_shift_padded = []
374
+ max_time = log_energy.size(1)
375
+ for sft in range(2):
376
+ shifted = log_alpha[:, : max_time - sft]
377
+ shift_padded = F.pad(shifted, (sft, 0), "constant", self.score_mask_value)
378
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
379
+
380
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded, 2), 2)
381
+
382
+ log_alpha_new = biased + log_energy
383
+
384
+ attention_weights = F.softmax(log_alpha_new, dim=1)
385
+
386
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
387
+ attention_context = attention_context.squeeze(1)
388
+
389
+ return attention_context, attention_weights, log_alpha_new
390
+
391
+
392
+ class PhaseShuffle2d(nn.Module):
393
+ def __init__(self, n=2):
394
+ super(PhaseShuffle2d, self).__init__()
395
+ self.n = n
396
+ self.random = random.Random(1)
397
+
398
+ def forward(self, x, move=None):
399
+ # x.size = (B, C, M, L)
400
+ if move is None:
401
+ move = self.random.randint(-self.n, self.n)
402
+
403
+ if move == 0:
404
+ return x
405
+ else:
406
+ left = x[:, :, :, :move]
407
+ right = x[:, :, :, move:]
408
+ shuffled = torch.cat([right, left], dim=3)
409
+ return shuffled
410
+
411
+
412
+ class PhaseShuffle1d(nn.Module):
413
+ def __init__(self, n=2):
414
+ super(PhaseShuffle1d, self).__init__()
415
+ self.n = n
416
+ self.random = random.Random(1)
417
+
418
+ def forward(self, x, move=None):
419
+ # x.size = (B, C, M, L)
420
+ if move is None:
421
+ move = self.random.randint(-self.n, self.n)
422
+
423
+ if move == 0:
424
+ return x
425
+ else:
426
+ left = x[:, :, :move]
427
+ right = x[:, :, move:]
428
+ shuffled = torch.cat([right, left], dim=2)
429
+
430
+ return shuffled
431
+
432
+
433
+ class MFCC(nn.Module):
434
+ def __init__(self, n_mfcc=40, n_mels=80):
435
+ super(MFCC, self).__init__()
436
+ self.n_mfcc = n_mfcc
437
+ self.n_mels = n_mels
438
+ self.norm = "ortho"
439
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
440
+ self.register_buffer("dct_mat", dct_mat)
441
+
442
+ def forward(self, mel_specgram):
443
+ if len(mel_specgram.shape) == 2:
444
+ mel_specgram = mel_specgram.unsqueeze(0)
445
+ unsqueezed = True
446
+ else:
447
+ unsqueezed = False
448
+ # (channel, n_mels, time).transpose(...) dot (n_mels, n_mfcc)
449
+ # -> (channel, time, n_mfcc).transpose(...)
450
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
451
+
452
+ # unpack batch
453
+ if unsqueezed:
454
+ mfcc = mfcc.squeeze(0)
455
+ return mfcc
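Note: the MFCC module above is just a matrix product between the mel spectrogram and an orthonormal DCT matrix from torchaudio.functional.create_dct. A short sketch of the same computation; the batch size and time length below are illustrative:

import torch
import torchaudio.functional as audio_F

n_mfcc, n_mels = 40, 80
dct_mat = audio_F.create_dct(n_mfcc, n_mels, norm="ortho")   # (n_mels, n_mfcc)

mel = torch.randn(2, n_mels, 120)                            # (batch, n_mels, time)
mfcc = torch.matmul(mel.transpose(1, 2), dct_mat).transpose(1, 2)
print(mfcc.shape)                                            # torch.Size([2, 40, 120])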
Utils/ASR/models.py ADDED
@@ -0,0 +1,217 @@
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import TransformerEncoder
5
+ import torch.nn.functional as F
6
+ from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock
7
+
8
+
9
+ class ASRCNN(nn.Module):
10
+ def __init__(
11
+ self,
12
+ input_dim=80,
13
+ hidden_dim=256,
14
+ n_token=35,
15
+ n_layers=6,
16
+ token_embedding_dim=256,
17
+ ):
18
+ super().__init__()
19
+ self.n_token = n_token
20
+ self.n_down = 1
21
+ self.to_mfcc = MFCC()
22
+ self.init_cnn = ConvNorm(
23
+ input_dim // 2, hidden_dim, kernel_size=7, padding=3, stride=2
24
+ )
25
+ self.cnns = nn.Sequential(
26
+ *[
27
+ nn.Sequential(
28
+ ConvBlock(hidden_dim),
29
+ nn.GroupNorm(num_groups=1, num_channels=hidden_dim),
30
+ )
31
+ for n in range(n_layers)
32
+ ]
33
+ )
34
+ self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
35
+ self.ctc_linear = nn.Sequential(
36
+ LinearNorm(hidden_dim // 2, hidden_dim),
37
+ nn.ReLU(),
38
+ LinearNorm(hidden_dim, n_token),
39
+ )
40
+ self.asr_s2s = ASRS2S(
41
+ embedding_dim=token_embedding_dim,
42
+ hidden_dim=hidden_dim // 2,
43
+ n_token=n_token,
44
+ )
45
+
46
+ def forward(self, x, src_key_padding_mask=None, text_input=None):
47
+ x = self.to_mfcc(x)
48
+ x = self.init_cnn(x)
49
+ x = self.cnns(x)
50
+ x = self.projection(x)
51
+ x = x.transpose(1, 2)
52
+ ctc_logit = self.ctc_linear(x)
53
+ if text_input is not None:
54
+ _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
55
+ return ctc_logit, s2s_logit, s2s_attn
56
+ else:
57
+ return ctc_logit
58
+
59
+ def get_feature(self, x):
60
+ x = self.to_mfcc(x.squeeze(1))
61
+ x = self.init_cnn(x)
62
+ x = self.cnns(x)
63
+ x = self.projection(x)
64
+ return x
65
+
66
+ def length_to_mask(self, lengths):
67
+ mask = (
68
+ torch.arange(lengths.max())
69
+ .unsqueeze(0)
70
+ .expand(lengths.shape[0], -1)
71
+ .type_as(lengths)
72
+ )
73
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1)).to(lengths.device)
74
+ return mask
75
+
76
+ def get_future_mask(self, out_length, unmask_future_steps=0):
77
+ """
78
+ Args:
79
+ out_length (int): returned mask shape is (out_length, out_length).
80
+ unmask_future_steps (int): number of future steps to leave unmasked.
81
+ Return:
82
+ mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
83
+ """
84
+ index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
85
+ mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
86
+ return mask
87
+
88
+
89
+ class ASRS2S(nn.Module):
90
+ def __init__(
91
+ self,
92
+ embedding_dim=256,
93
+ hidden_dim=512,
94
+ n_location_filters=32,
95
+ location_kernel_size=63,
96
+ n_token=40,
97
+ ):
98
+ super(ASRS2S, self).__init__()
99
+ self.embedding = nn.Embedding(n_token, embedding_dim)
100
+ val_range = math.sqrt(6 / hidden_dim)
101
+ self.embedding.weight.data.uniform_(-val_range, val_range)
102
+
103
+ self.decoder_rnn_dim = hidden_dim
104
+ self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
105
+ self.attention_layer = Attention(
106
+ self.decoder_rnn_dim,
107
+ hidden_dim,
108
+ hidden_dim,
109
+ n_location_filters,
110
+ location_kernel_size,
111
+ )
112
+ self.decoder_rnn = nn.LSTMCell(
113
+ self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim
114
+ )
115
+ self.project_to_hidden = nn.Sequential(
116
+ LinearNorm(self.decoder_rnn_dim * 2, hidden_dim), nn.Tanh()
117
+ )
118
+ self.sos = 1
119
+ self.eos = 2
120
+
121
+ def initialize_decoder_states(self, memory, mask):
122
+ """
123
+ memory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
124
+ """
125
+ B, L, H = memory.shape
126
+ self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
127
+ self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
128
+ self.attention_weights = torch.zeros((B, L)).type_as(memory)
129
+ self.attention_weights_cum = torch.zeros((B, L)).type_as(memory)
130
+ self.attention_context = torch.zeros((B, H)).type_as(memory)
131
+ self.memory = memory
132
+ self.processed_memory = self.attention_layer.memory_layer(memory)
133
+ self.mask = mask
134
+ self.unk_index = 3
135
+ self.random_mask = 0.1
136
+
137
+ def forward(self, memory, memory_mask, text_input):
138
+ """
139
+ memory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
140
+ memory_mask.shape = (B, L, )
141
+ text_input.shape = (B, T)
142
+ """
143
+ self.initialize_decoder_states(memory, memory_mask)
144
+ # text random mask
145
+ random_mask = (torch.rand(text_input.shape) < self.random_mask).to(
146
+ text_input.device
147
+ )
148
+ _text_input = text_input.clone()
149
+ _text_input.masked_fill_(random_mask, self.unk_index)
150
+ decoder_inputs = self.embedding(_text_input).transpose(
151
+ 0, 1
152
+ ) # -> [T, B, channel]
153
+ start_embedding = self.embedding(
154
+ torch.LongTensor([self.sos] * decoder_inputs.size(1)).to(
155
+ decoder_inputs.device
156
+ )
157
+ )
158
+ decoder_inputs = torch.cat(
159
+ (start_embedding.unsqueeze(0), decoder_inputs), dim=0
160
+ )
161
+
162
+ hidden_outputs, logit_outputs, alignments = [], [], []
163
+ while len(hidden_outputs) < decoder_inputs.size(0):
164
+ decoder_input = decoder_inputs[len(hidden_outputs)]
165
+ hidden, logit, attention_weights = self.decode(decoder_input)
166
+ hidden_outputs += [hidden]
167
+ logit_outputs += [logit]
168
+ alignments += [attention_weights]
169
+
170
+ hidden_outputs, logit_outputs, alignments = self.parse_decoder_outputs(
171
+ hidden_outputs, logit_outputs, alignments
172
+ )
173
+
174
+ return hidden_outputs, logit_outputs, alignments
175
+
176
+ def decode(self, decoder_input):
177
+ cell_input = torch.cat((decoder_input, self.attention_context), -1)
178
+ self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
179
+ cell_input, (self.decoder_hidden, self.decoder_cell)
180
+ )
181
+
182
+ attention_weights_cat = torch.cat(
183
+ (
184
+ self.attention_weights.unsqueeze(1),
185
+ self.attention_weights_cum.unsqueeze(1),
186
+ ),
187
+ dim=1,
188
+ )
189
+
190
+ self.attention_context, self.attention_weights = self.attention_layer(
191
+ self.decoder_hidden,
192
+ self.memory,
193
+ self.processed_memory,
194
+ attention_weights_cat,
195
+ self.mask,
196
+ )
197
+
198
+ self.attention_weights_cum += self.attention_weights
199
+
200
+ hidden_and_context = torch.cat(
201
+ (self.decoder_hidden, self.attention_context), -1
202
+ )
203
+ hidden = self.project_to_hidden(hidden_and_context)
204
+
205
+ # dropout to increase generalization
206
+ logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
207
+
208
+ return hidden, logit, self.attention_weights
209
+
210
+ def parse_decoder_outputs(self, hidden, logit, alignments):
211
+ # -> [B, T_out + 1, max_time]
212
+ alignments = torch.stack(alignments).transpose(0, 1)
213
+ # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
214
+ logit = torch.stack(logit).transpose(0, 1).contiguous()
215
+ hidden = torch.stack(hidden).transpose(0, 1).contiguous()
216
+
217
+ return hidden, logit, alignments
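Note: ASRCNN.get_future_mask above builds a boolean mask over decoder time steps; as written, mask[i, j] is True when position j lies more than unmask_future_steps ahead of position i. A standalone illustration of the same logic (the printed values follow directly from the code, not from extra assumptions):

import torch

def get_future_mask(out_length, unmask_future_steps=0):
    # True where position j is more than `unmask_future_steps` ahead of position i
    index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
    return torch.gt(index_tensor, index_tensor.T + unmask_future_steps)

print(get_future_mask(4, unmask_future_steps=1))
# tensor([[False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False],
#         [False, False, False, False]])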
Utils/JDC/__init__.py ADDED
@@ -0,0 +1 @@
1
+
Utils/JDC/bst.t7 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:54dc94364b97e18ac1dfa6287714ed121248cfaac4cfd39d061c6e0a089ef169
3
+ size 21029926
Utils/JDC/model.py ADDED
@@ -0,0 +1,212 @@
1
+ """
2
+ Implementation of model from:
3
+ Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
4
+ Convolutional Recurrent Neural Networks" (2019)
5
+ Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
6
+ """
7
+ import torch
8
+ from torch import nn
9
+
10
+
11
+ class JDCNet(nn.Module):
12
+ """
13
+ Joint Detection and Classification Network model for singing voice melody.
14
+ """
15
+
16
+ def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
17
+ super().__init__()
18
+ self.num_class = num_class
19
+
20
+ # input = (b, 1, 31, 513), b = batch size
21
+ self.conv_block = nn.Sequential(
22
+ nn.Conv2d(
23
+ in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False
24
+ ), # out: (b, 64, 31, 513)
25
+ nn.BatchNorm2d(num_features=64),
26
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
27
+ nn.Conv2d(64, 64, 3, padding=1, bias=False), # (b, 64, 31, 513)
28
+ )
29
+
30
+ # res blocks
31
+ self.res_block1 = ResBlock(
32
+ in_channels=64, out_channels=128
33
+ ) # (b, 128, 31, 128)
34
+ self.res_block2 = ResBlock(
35
+ in_channels=128, out_channels=192
36
+ ) # (b, 192, 31, 32)
37
+ self.res_block3 = ResBlock(in_channels=192, out_channels=256) # (b, 256, 31, 8)
38
+
39
+ # pool block
40
+ self.pool_block = nn.Sequential(
41
+ nn.BatchNorm2d(num_features=256),
42
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
43
+ nn.MaxPool2d(kernel_size=(1, 4)), # (b, 256, 31, 2)
44
+ nn.Dropout(p=0.2),
45
+ )
46
+
47
+ # maxpool layers (for auxiliary network inputs)
48
+ # in = (b, 128, 31, 513) from conv_block, out = (b, 128, 31, 2)
49
+ self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
50
+ # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
51
+ self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
52
+ # in = (b, 128, 31, 32) from res_block2, out = (b, 128, 31, 2)
53
+ self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
54
+
55
+ # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
56
+ self.detector_conv = nn.Sequential(
57
+ nn.Conv2d(640, 256, 1, bias=False),
58
+ nn.BatchNorm2d(256),
59
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
60
+ nn.Dropout(p=0.2),
61
+ )
62
+
63
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
64
+ self.bilstm_classifier = nn.LSTM(
65
+ input_size=512, hidden_size=256, batch_first=True, bidirectional=True
66
+ ) # (b, 31, 512)
67
+
68
+ # input: (b, 31, 512) - resized from (b, 256, 31, 2)
69
+ self.bilstm_detector = nn.LSTM(
70
+ input_size=512, hidden_size=256, batch_first=True, bidirectional=True
71
+ ) # (b, 31, 512)
72
+
73
+ # input: (b * 31, 512)
74
+ self.classifier = nn.Linear(
75
+ in_features=512, out_features=self.num_class
76
+ ) # (b * 31, num_class)
77
+
78
+ # input: (b * 31, 512)
79
+ self.detector = nn.Linear(
80
+ in_features=512, out_features=2
81
+ ) # (b * 31, 2) - binary classifier
82
+
83
+ # initialize weights
84
+ self.apply(self.init_weights)
85
+
86
+ def get_feature_GAN(self, x):
87
+ seq_len = x.shape[-2]
88
+ x = x.float().transpose(-1, -2)
89
+
90
+ convblock_out = self.conv_block(x)
91
+
92
+ resblock1_out = self.res_block1(convblock_out)
93
+ resblock2_out = self.res_block2(resblock1_out)
94
+ resblock3_out = self.res_block3(resblock2_out)
95
+ poolblock_out = self.pool_block[0](resblock3_out)
96
+ poolblock_out = self.pool_block[1](poolblock_out)
97
+
98
+ return poolblock_out.transpose(-1, -2)
99
+
100
+ def get_feature(self, x):
101
+ seq_len = x.shape[-2]
102
+ x = x.float().transpose(-1, -2)
103
+
104
+ convblock_out = self.conv_block(x)
105
+
106
+ resblock1_out = self.res_block1(convblock_out)
107
+ resblock2_out = self.res_block2(resblock1_out)
108
+ resblock3_out = self.res_block3(resblock2_out)
109
+ poolblock_out = self.pool_block[0](resblock3_out)
110
+ poolblock_out = self.pool_block[1](poolblock_out)
111
+
112
+ return self.pool_block[2](poolblock_out)
113
+
114
+ def forward(self, x):
115
+ """
116
+ Returns:
117
+ classification_prediction, detection_prediction
118
+ sizes: (b, 31, 722), (b, 31, 2)
119
+ """
120
+ ###############################
121
+ # forward pass for classifier #
122
+ ###############################
123
+ seq_len = x.shape[-1]
124
+ x = x.float().transpose(-1, -2)
125
+
126
+ convblock_out = self.conv_block(x)
127
+
128
+ resblock1_out = self.res_block1(convblock_out)
129
+ resblock2_out = self.res_block2(resblock1_out)
130
+ resblock3_out = self.res_block3(resblock2_out)
131
+
132
+ poolblock_out = self.pool_block[0](resblock3_out)
133
+ poolblock_out = self.pool_block[1](poolblock_out)
134
+ GAN_feature = poolblock_out.transpose(-1, -2)
135
+ poolblock_out = self.pool_block[2](poolblock_out)
136
+
137
+ # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
138
+ classifier_out = (
139
+ poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
140
+ )
141
+ classifier_out, _ = self.bilstm_classifier(
142
+ classifier_out
143
+ ) # ignore the hidden states
144
+
145
+ classifier_out = classifier_out.contiguous().view((-1, 512)) # (b * 31, 512)
146
+ classifier_out = self.classifier(classifier_out)
147
+ classifier_out = classifier_out.view(
148
+ (-1, seq_len, self.num_class)
149
+ ) # (b, 31, num_class)
150
+
151
+ # classifier output: predicted pitch-class logits per frame, size (b, seq_len, num_class)
152
+ # the voice/non-voice detector head defined above is not used in this forward pass
154
+ return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
155
+
156
+ @staticmethod
157
+ def init_weights(m):
158
+ if isinstance(m, nn.Linear):
159
+ nn.init.kaiming_uniform_(m.weight)
160
+ if m.bias is not None:
161
+ nn.init.constant_(m.bias, 0)
162
+ elif isinstance(m, nn.Conv2d):
163
+ nn.init.xavier_normal_(m.weight)
164
+ elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
165
+ for p in m.parameters():
166
+ if p.data is None:
167
+ continue
168
+
169
+ if len(p.shape) >= 2:
170
+ nn.init.orthogonal_(p.data)
171
+ else:
172
+ nn.init.normal_(p.data)
173
+
174
+
175
+ class ResBlock(nn.Module):
176
+ def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
177
+ super().__init__()
178
+ self.downsample = in_channels != out_channels
179
+
180
+ # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
181
+ self.pre_conv = nn.Sequential(
182
+ nn.BatchNorm2d(num_features=in_channels),
183
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
184
+ nn.MaxPool2d(kernel_size=(1, 2)), # apply downsampling on the y axis only
185
+ )
186
+
187
+ # conv layers
188
+ self.conv = nn.Sequential(
189
+ nn.Conv2d(
190
+ in_channels=in_channels,
191
+ out_channels=out_channels,
192
+ kernel_size=3,
193
+ padding=1,
194
+ bias=False,
195
+ ),
196
+ nn.BatchNorm2d(out_channels),
197
+ nn.LeakyReLU(leaky_relu_slope, inplace=True),
198
+ nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
199
+ )
200
+
201
+ # 1 x 1 convolution layer to match the feature dimensions
202
+ self.conv1by1 = None
203
+ if self.downsample:
204
+ self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
205
+
206
+ def forward(self, x):
207
+ x = self.pre_conv(x)
208
+ if self.downsample:
209
+ x = self.conv(x) + self.conv1by1(x)
210
+ else:
211
+ x = self.conv(x) + x
212
+ return x
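
A quick smoke test of the JDCNet defined above, fed a dummy 80-bin mel spectrogram (the shape the rest of this repo produces); this sketch is not part of the commit and the batch size and frame count are arbitrary:

import torch

net = JDCNet(num_class=722).eval()
mel = torch.randn(2, 1, 80, 192)           # (batch, 1, n_mels, frames), hypothetical sizes
with torch.no_grad():
    f0_logits, gan_feature, pooled = net(mel)
print(f0_logits.shape)                     # (2, 192, 722): per-frame pitch-class outputs
print(gan_feature.shape, pooled.shape)     # intermediate feature maps also exposed by forward()
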
Utils/PLBERT/config.yml ADDED
@@ -0,0 +1,30 @@
1
+ log_dir: "Checkpoint"
2
+ mixed_precision: "fp16"
3
+ data_folder: "wikipedia_20220301.en.processed"
4
+ batch_size: 192
5
+ save_interval: 5000
6
+ log_interval: 10
7
+ num_process: 1 # number of GPUs
8
+ num_steps: 1000000
9
+
10
+ dataset_params:
11
+ tokenizer: "transfo-xl-wt103"
12
+ token_separator: " " # token used for phoneme separator (space)
13
+ token_mask: "M" # token used for phoneme mask (M)
14
+ word_separator: 3039 # token used for word separator (<formula>)
15
+ token_maps: "token_maps.pkl" # token map path
16
+
17
+ max_mel_length: 512 # max phoneme length
18
+
19
+ word_mask_prob: 0.15 # probability to mask the entire word
20
+ phoneme_mask_prob: 0.1 # probability to mask each phoneme
21
+ replace_prob: 0.2 # probability to replace phonemes
22
+
23
+ model_params:
24
+ vocab_size: 178
25
+ hidden_size: 768
26
+ num_attention_heads: 12
27
+ intermediate_size: 2048
28
+ max_position_embeddings: 512
29
+ num_hidden_layers: 12
30
+ dropout: 0.1
Utils/PLBERT/step_1000000.t7 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0714ff85804db43e06b3b0ac5749bf90cf206257c6c5916e8a98c5933b4c21e0
3
+ size 25185187
Utils/PLBERT/util.py ADDED
@@ -0,0 +1,49 @@
1
+ import os
2
+ import yaml
3
+ import torch
4
+ from transformers import AlbertConfig, AlbertModel
5
+
6
+
7
+ class CustomAlbert(AlbertModel):
8
+ def forward(self, *args, **kwargs):
9
+ # Call the original forward method
10
+ outputs = super().forward(*args, **kwargs)
11
+
12
+ # Only return the last_hidden_state
13
+ return outputs.last_hidden_state
14
+
15
+
16
+ def load_plbert(log_dir):
17
+ config_path = os.path.join(log_dir, "config.yml")
18
+ plbert_config = yaml.safe_load(open(config_path))
19
+
20
+ albert_base_configuration = AlbertConfig(**plbert_config["model_params"])
21
+ bert = CustomAlbert(albert_base_configuration)
22
+
23
+ files = os.listdir(log_dir)
24
+ ckpts = []
25
+ for f in os.listdir(log_dir):
26
+ if f.startswith("step_"):
27
+ ckpts.append(f)
28
+
29
+ iters = [
30
+ int(f.split("_")[-1].split(".")[0])
31
+ for f in ckpts
32
+ if os.path.isfile(os.path.join(log_dir, f))
33
+ ]
34
+ iters = sorted(iters)[-1]
35
+
36
+ checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location="cpu")
37
+ state_dict = checkpoint["net"]
38
+ from collections import OrderedDict
39
+
40
+ new_state_dict = OrderedDict()
41
+ for k, v in state_dict.items():
42
+ name = k[7:] # remove `module.`
43
+ if name.startswith("encoder."):
44
+ name = name[8:] # remove `encoder.`
45
+ new_state_dict[name] = v
46
+ del new_state_dict["embeddings.position_ids"]
47
+ bert.load_state_dict(new_state_dict, strict=False)
48
+
49
+ return bert
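
A minimal usage sketch for the helper above, assuming the config.yml and step_1000000.t7 added in this commit live under Utils/PLBERT:

import torch
from Utils.PLBERT.util import load_plbert

plbert = load_plbert("Utils/PLBERT")       # loads the highest-numbered step_*.t7 checkpoint
plbert.eval()
with torch.no_grad():
    tokens = torch.randint(0, 178, (1, 32))                       # hypothetical phoneme ids (vocab_size=178)
    hidden = plbert(tokens, attention_mask=torch.ones_like(tokens))
print(hidden.shape)                        # (1, 32, 768): CustomAlbert returns only last_hidden_state
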
Utils/__init__.py ADDED
@@ -0,0 +1 @@
1
+
_run.py ADDED
@@ -0,0 +1,371 @@
1
+ from cached_path import cached_path
2
+
3
+ from dp.phonemizer import Phonemizer
4
+ print("NLTK")
5
+ import nltk
6
+ nltk.download('punkt')
7
+ print("SCIPY")
8
+ from scipy.io.wavfile import write
9
+ print("TORCH STUFF")
10
+ import torch
11
+ print("START")
12
+ torch.manual_seed(0)
13
+ torch.backends.cudnn.benchmark = False
14
+ torch.backends.cudnn.deterministic = True
15
+
16
+ import random
17
+ random.seed(0)
18
+
19
+ import numpy as np
20
+ np.random.seed(0)
21
+
22
+ # load packages
23
+ import time
24
+ import random
25
+ import yaml
26
+ from munch import Munch
27
+ import numpy as np
28
+ import torch
29
+ from torch import nn
30
+ import torch.nn.functional as F
31
+ import torchaudio
32
+ import librosa
33
+ from nltk.tokenize import word_tokenize
34
+
35
+ from models import *
36
+ from utils import *
37
+ from text_utils import TextCleaner
38
+ textcleaner = TextCleaner()
39
+
40
+
41
+ to_mel = torchaudio.transforms.MelSpectrogram(
42
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
43
+ mean, std = -4, 4
44
+
45
+ def length_to_mask(lengths):
46
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
47
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
48
+ return mask
49
+
50
+ def preprocess(wave):
51
+ wave_tensor = torch.from_numpy(wave).float()
52
+ mel_tensor = to_mel(wave_tensor)
53
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
54
+ return mel_tensor
55
+
56
+ def compute_style(path):
57
+ wave, sr = librosa.load(path, sr=24000)
58
+ audio, index = librosa.effects.trim(wave, top_db=30)
59
+ if sr != 24000:
60
+ audio = librosa.resample(audio, sr, 24000)
61
+ mel_tensor = preprocess(audio).to(device)
62
+
63
+ with torch.no_grad():
64
+ ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
65
+ ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
66
+
67
+ return torch.cat([ref_s, ref_p], dim=1)
68
+
69
+ device = 'cpu'
70
+ if torch.cuda.is_available():
71
+ device = 'cuda'
72
+ elif torch.backends.mps.is_available():
73
+ print("MPS would be available but cannot be used rn")
74
+ # device = 'mps'
75
+
76
+
77
+ # global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
78
+ phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
79
+
80
+
81
+ config = yaml.safe_load(open("Models/LibriTTS/config.yml"))
82
+
83
+ # load pretrained ASR model
84
+ ASR_config = config.get('ASR_config', False)
85
+ ASR_path = config.get('ASR_path', False)
86
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
87
+
88
+ # load pretrained F0 model
89
+ F0_path = config.get('F0_path', False)
90
+ pitch_extractor = load_F0_models(F0_path)
91
+
92
+ # load BERT model
93
+ from Utils.PLBERT.util import load_plbert
94
+ BERT_path = config.get('PLBERT_dir', False)
95
+ plbert = load_plbert(BERT_path)
96
+
97
+ model_params = recursive_munch(config['model_params'])
98
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
99
+ _ = [model[key].eval() for key in model]
100
+ _ = [model[key].to(device) for key in model]
101
+
102
+ params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
103
+ params = params_whole['net']
104
+
105
+ for key in model:
106
+ if key in params:
107
+ print('%s loaded' % key)
108
+ try:
109
+ model[key].load_state_dict(params[key])
110
+ except:
111
+ from collections import OrderedDict
112
+ state_dict = params[key]
113
+ new_state_dict = OrderedDict()
114
+ for k, v in state_dict.items():
115
+ name = k[7:] # remove `module.`
116
+ new_state_dict[name] = v
117
+ # load params
118
+ model[key].load_state_dict(new_state_dict, strict=False)
119
+ # except:
120
+ # _load(params[key], model[key])
121
+ _ = [model[key].eval() for key in model]
122
+
123
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
124
+
125
+ sampler = DiffusionSampler(
126
+ model.diffusion.diffusion,
127
+ sampler=ADPM2Sampler(),
128
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
129
+ clamp=False
130
+ )
131
+
132
+ def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
133
+ text = text.strip()
134
+ ps = phonemizer([text], lang='en_us')
135
+ ps = word_tokenize(ps[0])
136
+ ps = ' '.join(ps)
137
+ tokens = textcleaner(ps)
138
+ tokens.insert(0, 0)
139
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
140
+
141
+ with torch.no_grad():
142
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
143
+ text_mask = length_to_mask(input_lengths).to(device)
144
+
145
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
146
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
147
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
148
+
149
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
150
+ embedding=bert_dur,
151
+ embedding_scale=embedding_scale,
152
+ features=ref_s, # reference from the same speaker as the embedding
153
+ num_steps=diffusion_steps).squeeze(1)
154
+
155
+
156
+ s = s_pred[:, 128:]
157
+ ref = s_pred[:, :128]
158
+
159
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
160
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
161
+
162
+ d = model.predictor.text_encoder(d_en,
163
+ s, input_lengths, text_mask)
164
+
165
+ x, _ = model.predictor.lstm(d)
166
+ duration = model.predictor.duration_proj(x)
167
+
168
+ duration = torch.sigmoid(duration).sum(axis=-1)
169
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
170
+
171
+
172
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
173
+ c_frame = 0
174
+ for i in range(pred_aln_trg.size(0)):
175
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
176
+ c_frame += int(pred_dur[i].data)
177
+
178
+ # encode prosody
179
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
180
+ if model_params.decoder.type == "hifigan":
181
+ asr_new = torch.zeros_like(en)
182
+ asr_new[:, :, 0] = en[:, :, 0]
183
+ asr_new[:, :, 1:] = en[:, :, 0:-1]
184
+ en = asr_new
185
+
186
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
187
+
188
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
189
+ if model_params.decoder.type == "hifigan":
190
+ asr_new = torch.zeros_like(asr)
191
+ asr_new[:, :, 0] = asr[:, :, 0]
192
+ asr_new[:, :, 1:] = asr[:, :, 0:-1]
193
+ asr = asr_new
194
+
195
+ out = model.decoder(asr,
196
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
197
+
198
+
199
+ return out.squeeze().cpu().numpy()[..., :-50] # trim the weird pulse at the end of the output; needs a proper fix later
200
+
201
+ def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):
202
+ text = text.strip()
203
+ ps = phonemizer([text], lang='en_us')
204
+ ps = word_tokenize(ps[0])
205
+ ps = ' '.join(ps)
206
+ ps = ps.replace('``', '"')
207
+ ps = ps.replace("''", '"')
208
+
209
+ tokens = textcleaner(ps)
210
+ tokens.insert(0, 0)
211
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
212
+
213
+ with torch.no_grad():
214
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
215
+ text_mask = length_to_mask(input_lengths).to(device)
216
+
217
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
218
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
219
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
220
+
221
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
222
+ embedding=bert_dur,
223
+ embedding_scale=embedding_scale,
224
+ features=ref_s, # reference from the same speaker as the embedding
225
+ num_steps=diffusion_steps).squeeze(1)
226
+
227
+ if s_prev is not None:
228
+ # convex combination of previous and current style
229
+ s_pred = t * s_prev + (1 - t) * s_pred
230
+
231
+ s = s_pred[:, 128:]
232
+ ref = s_pred[:, :128]
233
+
234
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
235
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
236
+
237
+ s_pred = torch.cat([ref, s], dim=-1)
238
+
239
+ d = model.predictor.text_encoder(d_en,
240
+ s, input_lengths, text_mask)
241
+
242
+ x, _ = model.predictor.lstm(d)
243
+ duration = model.predictor.duration_proj(x)
244
+
245
+ duration = torch.sigmoid(duration).sum(axis=-1)
246
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
247
+
248
+
249
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
250
+ c_frame = 0
251
+ for i in range(pred_aln_trg.size(0)):
252
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
253
+ c_frame += int(pred_dur[i].data)
254
+
255
+ # encode prosody
256
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
257
+ if model_params.decoder.type == "hifigan":
258
+ asr_new = torch.zeros_like(en)
259
+ asr_new[:, :, 0] = en[:, :, 0]
260
+ asr_new[:, :, 1:] = en[:, :, 0:-1]
261
+ en = asr_new
262
+
263
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
264
+
265
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
266
+ if model_params.decoder.type == "hifigan":
267
+ asr_new = torch.zeros_like(asr)
268
+ asr_new[:, :, 0] = asr[:, :, 0]
269
+ asr_new[:, :, 1:] = asr[:, :, 0:-1]
270
+ asr = asr_new
271
+
272
+ out = model.decoder(asr,
273
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
274
+
275
+
276
+ return out.squeeze().cpu().numpy()[..., :-100], s_pred # trim the weird pulse at the end of the output; needs a proper fix later
277
+
278
+ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
279
+ text = text.strip()
280
+ ps = phonemizer([text], lang='en_us')
281
+ ps = word_tokenize(ps[0])
282
+ ps = ' '.join(ps)
283
+
284
+ tokens = textcleaner(ps)
285
+ tokens.insert(0, 0)
286
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
287
+
288
+ ref_text = ref_text.strip()
289
+ ps = phonemizer([ref_text], lang='en_us')
290
+ ps = word_tokenize(ps[0])
291
+ ps = ' '.join(ps)
292
+
293
+ ref_tokens = textcleaner(ps)
294
+ ref_tokens.insert(0, 0)
295
+ ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)
296
+
297
+
298
+ with torch.no_grad():
299
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
300
+ text_mask = length_to_mask(input_lengths).to(device)
301
+
302
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
303
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
304
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
305
+
306
+ ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)
307
+ ref_text_mask = length_to_mask(ref_input_lengths).to(device)
308
+ ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())
309
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
310
+ embedding=bert_dur,
311
+ embedding_scale=embedding_scale,
312
+ features=ref_s, # reference from the same speaker as the embedding
313
+ num_steps=diffusion_steps).squeeze(1)
314
+
315
+
316
+ s = s_pred[:, 128:]
317
+ ref = s_pred[:, :128]
318
+
319
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
320
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
321
+
322
+ d = model.predictor.text_encoder(d_en,
323
+ s, input_lengths, text_mask)
324
+
325
+ x, _ = model.predictor.lstm(d)
326
+ duration = model.predictor.duration_proj(x)
327
+
328
+ duration = torch.sigmoid(duration).sum(axis=-1)
329
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
330
+
331
+
332
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
333
+ c_frame = 0
334
+ for i in range(pred_aln_trg.size(0)):
335
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
336
+ c_frame += int(pred_dur[i].data)
337
+
338
+ # encode prosody
339
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
340
+ if model_params.decoder.type == "hifigan":
341
+ asr_new = torch.zeros_like(en)
342
+ asr_new[:, :, 0] = en[:, :, 0]
343
+ asr_new[:, :, 1:] = en[:, :, 0:-1]
344
+ en = asr_new
345
+
346
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
347
+
348
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
349
+ if model_params.decoder.type == "hifigan":
350
+ asr_new = torch.zeros_like(asr)
351
+ asr_new[:, :, 0] = asr[:, :, 0]
352
+ asr_new[:, :, 1:] = asr[:, :, 0:-1]
353
+ asr = asr_new
354
+
355
+ out = model.decoder(asr,
356
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
357
+
358
+
359
+ return out.squeeze().cpu().numpy()[..., :-50] # trim the weird pulse at the end of the output; needs a proper fix later
360
+ print("Time to synthesize!")
361
+ ref_s = compute_style('./voice/voice.wav')
362
+ while True:
363
+ text = input("What to say? > ")
364
+ start = time.time()
365
+ wav = inference(text, ref_s, alpha=0.3, beta=0.7, diffusion_steps=15, embedding_scale=1)
366
+ rtf = (time.time() - start) / (len(wav) / 24000)
367
+ print(f"RTF = {rtf:5f}")
368
+ print('Synthesized:')
369
+ # display(ipd.Audio(wav, rate=24000, normalize=False))
370
+ write('result.wav', 24000, wav)
371
+ print("Saved to result.wav")
app.py ADDED
@@ -0,0 +1,46 @@
1
+ import gradio as gr
2
+ import styletts2importable
3
+ theme = gr.themes.Base(
4
+ font=[gr.themes.GoogleFont('Libre Franklin'), gr.themes.GoogleFont('Public Sans'), 'system-ui', 'sans-serif'],
5
+ )
6
+ voices = {
7
+ 'angie': styletts2importable.compute_style('voices/angie.wav'),
8
+ 'daniel': styletts2importable.compute_style('voices/daniel.wav'),
9
+ 'dotrice': styletts2importable.compute_style('voices/dotrice.wav'),
10
+ 'lj': styletts2importable.compute_style('voices/lj.wav'),
11
+ 'mouse': styletts2importable.compute_style('voices/mouse.wav'),
12
+ 'pat': styletts2importable.compute_style('voices/pat.wav'),
13
+ 'tom': styletts2importable.compute_style('voices/tom.wav'),
14
+ 'william': styletts2importable.compute_style('voices/william.wav'),
15
+ }
16
+ def synthesize(text, voice):
17
+ if text.strip() == "":
18
+ raise gr.Error("You must enter some text")
19
+ v = voice.lower()
20
+ return (24000, styletts2importable.inference(text, voices[v], alpha=0.3, beta=0.7, diffusion_steps=15, embedding_scale=1))
21
+
22
+ with gr.Blocks(title="StyleTTS 2", css="footer{display:none !important}", theme=theme) as demo:
23
+ gr.Markdown("""# StyleTTS 2
24
+
25
+ [Paper](https://arxiv.org/abs/2306.07691) - [Samples](https://styletts2.github.io/) - [Code](https://github.com/yl4579/StyleTTS2)
26
+
27
+ A free demo of StyleTTS 2. Not affiliated with the StyleTTS 2 Authors.
28
+
29
+ **By using this demo, you agree to inform listeners that the speech samples are synthesized by pre-trained models unless you have permission to use the voice being synthesized. In other words, before publishing synthesized speech, only use voices whose speakers have granted permission (directly or by license) for their voice to be cloned; if you do not have such permission, you must publicly state that the voices are synthetic.**
30
+
31
+ This space does NOT allow voice cloning. We use some default voices from Tortoise TTS instead.
32
+
33
+ Is there a long queue on this space? Duplicate it and add a GPU to skip the wait!""")
34
+ gr.DuplicateButton("Duplicate Space")
35
+ with gr.Row():
36
+ with gr.Column(scale=1):
37
+ inp = gr.Textbox(label="Text", info="What would you like StyleTTS 2 to read? It works better on full sentences.", interactive=True)
38
+ voice = gr.Dropdown(['Angie', 'Daniel', 'Dotrice', 'LJ', 'Mouse', 'Pat', 'Tom', 'William'], label="Voice", info="Select a voice. We use some voices from Tortoise TTS.", value='Tom', interactive=True)
39
+ with gr.Column(scale=1):
40
+ btn = gr.Button("Synthesize")
41
+ audio = gr.Audio(interactive=False, label="Synthesized Audio")
42
+ btn.click(synthesize, inputs=[inp, voice], outputs=[audio])
43
+
44
+ if __name__ == "__main__":
45
+ demo.launch(show_api=False)
46
+
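
synthesize above returns the (sample_rate, waveform) pair that gr.Audio expects, so the same call can be made outside Gradio and written straight to disk. A sketch assuming the styletts2importable module and the bundled voices load exactly as they do for the app (the output path is hypothetical):

from scipy.io.wavfile import write
import styletts2importable

# mirrors what the Synthesize button does
ref = styletts2importable.compute_style("voices/tom.wav")
wav = styletts2importable.inference(
    "Hello from StyleTTS 2.", ref,
    alpha=0.3, beta=0.7, diffusion_steps=15, embedding_scale=1,
)
write("out.wav", 24000, wav)
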
losses.py ADDED
@@ -0,0 +1,303 @@
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ import torchaudio
5
+ from transformers import AutoModel
6
+
7
+
8
+ class SpectralConvergengeLoss(torch.nn.Module):
9
+ """Spectral convergence loss module."""
10
+
11
+ def __init__(self):
12
+ """Initilize spectral convergence loss module."""
13
+ super(SpectralConvergengeLoss, self).__init__()
14
+
15
+ def forward(self, x_mag, y_mag):
16
+ """Calculate forward propagation.
17
+ Args:
18
+ x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
19
+ y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
20
+ Returns:
21
+ Tensor: Spectral convergence loss value.
22
+ """
23
+ return torch.norm(y_mag - x_mag, p=1) / torch.norm(y_mag, p=1)
24
+
25
+
26
+ class STFTLoss(torch.nn.Module):
27
+ """STFT loss module."""
28
+
29
+ def __init__(
30
+ self, fft_size=1024, shift_size=120, win_length=600, window=torch.hann_window
31
+ ):
32
+ """Initialize STFT loss module."""
33
+ super(STFTLoss, self).__init__()
34
+ self.fft_size = fft_size
35
+ self.shift_size = shift_size
36
+ self.win_length = win_length
37
+ self.to_mel = torchaudio.transforms.MelSpectrogram(
38
+ sample_rate=24000,
39
+ n_fft=fft_size,
40
+ win_length=win_length,
41
+ hop_length=shift_size,
42
+ window_fn=window,
43
+ )
44
+
45
+ self.spectral_convergenge_loss = SpectralConvergengeLoss()
46
+
47
+ def forward(self, x, y):
48
+ """Calculate forward propagation.
49
+ Args:
50
+ x (Tensor): Predicted signal (B, T).
51
+ y (Tensor): Groundtruth signal (B, T).
52
+ Returns:
53
+ Tensor: Spectral convergence loss value.
54
+ Tensor: Log STFT magnitude loss value.
55
+ """
56
+ x_mag = self.to_mel(x)
57
+ mean, std = -4, 4
58
+ x_mag = (torch.log(1e-5 + x_mag) - mean) / std
59
+
60
+ y_mag = self.to_mel(y)
61
+ mean, std = -4, 4
62
+ y_mag = (torch.log(1e-5 + y_mag) - mean) / std
63
+
64
+ sc_loss = self.spectral_convergenge_loss(x_mag, y_mag)
65
+ return sc_loss
66
+
67
+
68
+ class MultiResolutionSTFTLoss(torch.nn.Module):
69
+ """Multi resolution STFT loss module."""
70
+
71
+ def __init__(
72
+ self,
73
+ fft_sizes=[1024, 2048, 512],
74
+ hop_sizes=[120, 240, 50],
75
+ win_lengths=[600, 1200, 240],
76
+ window=torch.hann_window,
77
+ ):
78
+ """Initialize Multi resolution STFT loss module.
79
+ Args:
80
+ fft_sizes (list): List of FFT sizes.
81
+ hop_sizes (list): List of hop sizes.
82
+ win_lengths (list): List of window lengths.
83
+ window (str): Window function type.
84
+ """
85
+ super(MultiResolutionSTFTLoss, self).__init__()
86
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
87
+ self.stft_losses = torch.nn.ModuleList()
88
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
89
+ self.stft_losses += [STFTLoss(fs, ss, wl, window)]
90
+
91
+ def forward(self, x, y):
92
+ """Calculate forward propagation.
93
+ Args:
94
+ x (Tensor): Predicted signal (B, T).
95
+ y (Tensor): Groundtruth signal (B, T).
96
+ Returns:
97
+ Tensor: Multi resolution spectral convergence loss value.
98
+ Tensor: Multi resolution log STFT magnitude loss value.
99
+ """
100
+ sc_loss = 0.0
101
+ for f in self.stft_losses:
102
+ sc_l = f(x, y)
103
+ sc_loss += sc_l
104
+ sc_loss /= len(self.stft_losses)
105
+
106
+ return sc_loss
107
+
108
+
109
+ def feature_loss(fmap_r, fmap_g):
110
+ loss = 0
111
+ for dr, dg in zip(fmap_r, fmap_g):
112
+ for rl, gl in zip(dr, dg):
113
+ loss += torch.mean(torch.abs(rl - gl))
114
+
115
+ return loss * 2
116
+
117
+
118
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
119
+ loss = 0
120
+ r_losses = []
121
+ g_losses = []
122
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
123
+ r_loss = torch.mean((1 - dr) ** 2)
124
+ g_loss = torch.mean(dg**2)
125
+ loss += r_loss + g_loss
126
+ r_losses.append(r_loss.item())
127
+ g_losses.append(g_loss.item())
128
+
129
+ return loss, r_losses, g_losses
130
+
131
+
132
+ def generator_loss(disc_outputs):
133
+ loss = 0
134
+ gen_losses = []
135
+ for dg in disc_outputs:
136
+ l = torch.mean((1 - dg) ** 2)
137
+ gen_losses.append(l)
138
+ loss += l
139
+
140
+ return loss, gen_losses
141
+
142
+
143
+ """ https://dl.acm.org/doi/abs/10.1145/3573834.3574506 """
144
+
145
+
146
+ def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
147
+ loss = 0
148
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
149
+ tau = 0.04
150
+ m_DG = torch.median((dr - dg))
151
+ L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
152
+ loss += tau - F.relu(tau - L_rel)
153
+ return loss
154
+
155
+
156
+ def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
157
+ loss = 0
158
+ for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
159
+ tau = 0.04
160
+ m_DG = torch.median((dr - dg))
161
+ L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
162
+ loss += tau - F.relu(tau - L_rel)
163
+ return loss
164
+
165
+
166
+ class GeneratorLoss(torch.nn.Module):
167
+ def __init__(self, mpd, msd):
168
+ super(GeneratorLoss, self).__init__()
169
+ self.mpd = mpd
170
+ self.msd = msd
171
+
172
+ def forward(self, y, y_hat):
173
+ y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = self.mpd(y, y_hat)
174
+ y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = self.msd(y, y_hat)
175
+ loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
176
+ loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
177
+ loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
178
+ loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
179
+
180
+ loss_rel = generator_TPRLS_loss(y_df_hat_r, y_df_hat_g) + generator_TPRLS_loss(
181
+ y_ds_hat_r, y_ds_hat_g
182
+ )
183
+
184
+ loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_rel
185
+
186
+ return loss_gen_all.mean()
187
+
188
+
189
+ class DiscriminatorLoss(torch.nn.Module):
190
+ def __init__(self, mpd, msd):
191
+ super(DiscriminatorLoss, self).__init__()
192
+ self.mpd = mpd
193
+ self.msd = msd
194
+
195
+ def forward(self, y, y_hat):
196
+ # MPD
197
+ y_df_hat_r, y_df_hat_g, _, _ = self.mpd(y, y_hat)
198
+ loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
199
+ y_df_hat_r, y_df_hat_g
200
+ )
201
+ # MSD
202
+ y_ds_hat_r, y_ds_hat_g, _, _ = self.msd(y, y_hat)
203
+ loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
204
+ y_ds_hat_r, y_ds_hat_g
205
+ )
206
+
207
+ loss_rel = discriminator_TPRLS_loss(
208
+ y_df_hat_r, y_df_hat_g
209
+ ) + discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
210
+
211
+ d_loss = loss_disc_s + loss_disc_f + loss_rel
212
+
213
+ return d_loss.mean()
214
+
215
+
216
+ class WavLMLoss(torch.nn.Module):
217
+ def __init__(self, model, wd, model_sr, slm_sr=16000):
218
+ super(WavLMLoss, self).__init__()
219
+ self.wavlm = AutoModel.from_pretrained(model)
220
+ self.wd = wd
221
+ self.resample = torchaudio.transforms.Resample(model_sr, slm_sr)
222
+
223
+ def forward(self, wav, y_rec):
224
+ with torch.no_grad():
225
+ wav_16 = self.resample(wav)
226
+ wav_embeddings = self.wavlm(
227
+ input_values=wav_16, output_hidden_states=True
228
+ ).hidden_states
229
+ y_rec_16 = self.resample(y_rec)
230
+ y_rec_embeddings = self.wavlm(
231
+ input_values=y_rec_16.squeeze(), output_hidden_states=True
232
+ ).hidden_states
233
+
234
+ floss = 0
235
+ for er, eg in zip(wav_embeddings, y_rec_embeddings):
236
+ floss += torch.mean(torch.abs(er - eg))
237
+
238
+ return floss.mean()
239
+
240
+ def generator(self, y_rec):
241
+ y_rec_16 = self.resample(y_rec)
242
+ y_rec_embeddings = self.wavlm(
243
+ input_values=y_rec_16, output_hidden_states=True
244
+ ).hidden_states
245
+ y_rec_embeddings = (
246
+ torch.stack(y_rec_embeddings, dim=1)
247
+ .transpose(-1, -2)
248
+ .flatten(start_dim=1, end_dim=2)
249
+ )
250
+ y_df_hat_g = self.wd(y_rec_embeddings)
251
+ loss_gen = torch.mean((1 - y_df_hat_g) ** 2)
252
+
253
+ return loss_gen
254
+
255
+ def discriminator(self, wav, y_rec):
256
+ with torch.no_grad():
257
+ wav_16 = self.resample(wav)
258
+ wav_embeddings = self.wavlm(
259
+ input_values=wav_16, output_hidden_states=True
260
+ ).hidden_states
261
+ y_rec_16 = self.resample(y_rec)
262
+ y_rec_embeddings = self.wavlm(
263
+ input_values=y_rec_16, output_hidden_states=True
264
+ ).hidden_states
265
+
266
+ y_embeddings = (
267
+ torch.stack(wav_embeddings, dim=1)
268
+ .transpose(-1, -2)
269
+ .flatten(start_dim=1, end_dim=2)
270
+ )
271
+ y_rec_embeddings = (
272
+ torch.stack(y_rec_embeddings, dim=1)
273
+ .transpose(-1, -2)
274
+ .flatten(start_dim=1, end_dim=2)
275
+ )
276
+
277
+ y_d_rs = self.wd(y_embeddings)
278
+ y_d_gs = self.wd(y_rec_embeddings)
279
+
280
+ y_df_hat_r, y_df_hat_g = y_d_rs, y_d_gs
281
+
282
+ r_loss = torch.mean((1 - y_df_hat_r) ** 2)
283
+ g_loss = torch.mean((y_df_hat_g) ** 2)
284
+
285
+ loss_disc_f = r_loss + g_loss
286
+
287
+ return loss_disc_f.mean()
288
+
289
+ def discriminator_forward(self, wav):
290
+ with torch.no_grad():
291
+ wav_16 = self.resample(wav)
292
+ wav_embeddings = self.wavlm(
293
+ input_values=wav_16, output_hidden_states=True
294
+ ).hidden_states
295
+ y_embeddings = (
296
+ torch.stack(wav_embeddings, dim=1)
297
+ .transpose(-1, -2)
298
+ .flatten(start_dim=1, end_dim=2)
299
+ )
300
+
301
+ y_d_rs = self.wd(y_embeddings)
302
+
303
+ return y_d_rs
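
The spectral convergence term above is the L1 ratio ||y_mag - x_mag||_1 / ||y_mag||_1 computed on log-mel magnitudes and averaged over the three STFT resolutions. A quick smoke test on random waveforms, assuming the classes in this file are in scope (batch size and length are arbitrary):

import torch

stft_loss = MultiResolutionSTFTLoss()      # resolutions: fft 1024/2048/512, hop 120/240/50
x = torch.randn(2, 24000)                  # one second of fake audio at 24 kHz
y = torch.randn(2, 24000)
print(stft_loss(x, y))                     # scalar, averaged over the three resolutions
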
meldataset.py ADDED
@@ -0,0 +1,294 @@
1
+ # coding: utf-8
2
+ import os
3
+ import os.path as osp
4
+ import time
5
+ import random
6
+ import numpy as np
7
+ import random
8
+ import soundfile as sf
9
+ import librosa
10
+
11
+ import torch
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+ import torchaudio
15
+ from torch.utils.data import DataLoader
16
+
17
+ import logging
18
+
19
+ logger = logging.getLogger(__name__)
20
+ logger.setLevel(logging.DEBUG)
21
+
22
+ import pandas as pd
23
+
24
+ _pad = "$"
25
+ _punctuation = ';:,.!?¡¿—…"«»“” '
26
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
27
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
28
+
29
+ # Export all symbols:
30
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
31
+
32
+ dicts = {}
33
+ for i in range(len((symbols))):
34
+ dicts[symbols[i]] = i
35
+
36
+
37
+ class TextCleaner:
38
+ def __init__(self, dummy=None):
39
+ self.word_index_dictionary = dicts
40
+
41
+ def __call__(self, text):
42
+ indexes = []
43
+ for char in text:
44
+ try:
45
+ indexes.append(self.word_index_dictionary[char])
46
+ except KeyError:
47
+ print(text)
48
+ return indexes
49
+
50
+
51
+ np.random.seed(1)
52
+ random.seed(1)
53
+ SPECT_PARAMS = {"n_fft": 2048, "win_length": 1200, "hop_length": 300}
54
+ MEL_PARAMS = {
55
+ "n_mels": 80,
56
+ }
57
+
58
+ to_mel = torchaudio.transforms.MelSpectrogram(
59
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300
60
+ )
61
+ mean, std = -4, 4
62
+
63
+
64
+ def preprocess(wave):
65
+ wave_tensor = torch.from_numpy(wave).float()
66
+ mel_tensor = to_mel(wave_tensor)
67
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
68
+ return mel_tensor
69
+
70
+
71
+ class FilePathDataset(torch.utils.data.Dataset):
72
+ def __init__(
73
+ self,
74
+ data_list,
75
+ root_path,
76
+ sr=24000,
77
+ data_augmentation=False,
78
+ validation=False,
79
+ OOD_data="Data/OOD_texts.txt",
80
+ min_length=50,
81
+ ):
82
+ spect_params = SPECT_PARAMS
83
+ mel_params = MEL_PARAMS
84
+
85
+ _data_list = [l[:-1].split("|") for l in data_list]
86
+ self.data_list = [data if len(data) == 3 else (*data, 0) for data in _data_list]
87
+ self.text_cleaner = TextCleaner()
88
+ self.sr = sr
89
+
90
+ self.df = pd.DataFrame(self.data_list)
91
+
92
+ self.to_melspec = torchaudio.transforms.MelSpectrogram(**MEL_PARAMS)
93
+
94
+ self.mean, self.std = -4, 4
95
+ self.data_augmentation = data_augmentation and (not validation)
96
+ self.max_mel_length = 192
97
+
98
+ self.min_length = min_length
99
+ with open(OOD_data, "r") as f:
100
+ tl = f.readlines()
101
+ idx = 1 if ".wav" in tl[0].split("|")[0] else 0
102
+ self.ptexts = [t.split("|")[idx] for t in tl]
103
+
104
+ self.root_path = root_path
105
+
106
+ def __len__(self):
107
+ return len(self.data_list)
108
+
109
+ def __getitem__(self, idx):
110
+ data = self.data_list[idx]
111
+ path = data[0]
112
+
113
+ wave, text_tensor, speaker_id = self._load_tensor(data)
114
+
115
+ mel_tensor = preprocess(wave).squeeze()
116
+
117
+ acoustic_feature = mel_tensor.squeeze()
118
+ length_feature = acoustic_feature.size(1)
119
+ acoustic_feature = acoustic_feature[:, : (length_feature - length_feature % 2)]
120
+
121
+ # get reference sample
122
+ ref_data = (self.df[self.df[2] == str(speaker_id)]).sample(n=1).iloc[0].tolist()
123
+ ref_mel_tensor, ref_label = self._load_data(ref_data[:3])
124
+
125
+ # get OOD text
126
+
127
+ ps = ""
128
+
129
+ while len(ps) < self.min_length:
130
+ rand_idx = np.random.randint(0, len(self.ptexts) - 1)
131
+ ps = self.ptexts[rand_idx]
132
+
133
+ text = self.text_cleaner(ps)
134
+ text.insert(0, 0)
135
+ text.append(0)
136
+
137
+ ref_text = torch.LongTensor(text)
138
+
139
+ return (
140
+ speaker_id,
141
+ acoustic_feature,
142
+ text_tensor,
143
+ ref_text,
144
+ ref_mel_tensor,
145
+ ref_label,
146
+ path,
147
+ wave,
148
+ )
149
+
150
+ def _load_tensor(self, data):
151
+ wave_path, text, speaker_id = data
152
+ speaker_id = int(speaker_id)
153
+ wave, sr = sf.read(osp.join(self.root_path, wave_path))
154
+ if wave.shape[-1] == 2:
155
+ wave = wave[:, 0].squeeze()
156
+ if sr != 24000:
157
+ wave = librosa.resample(wave, orig_sr=sr, target_sr=24000)
158
+ print(wave_path, sr)
159
+
160
+ wave = np.concatenate([np.zeros([5000]), wave, np.zeros([5000])], axis=0)
161
+
162
+ text = self.text_cleaner(text)
163
+
164
+ text.insert(0, 0)
165
+ text.append(0)
166
+
167
+ text = torch.LongTensor(text)
168
+
169
+ return wave, text, speaker_id
170
+
171
+ def _load_data(self, data):
172
+ wave, text_tensor, speaker_id = self._load_tensor(data)
173
+ mel_tensor = preprocess(wave).squeeze()
174
+
175
+ mel_length = mel_tensor.size(1)
176
+ if mel_length > self.max_mel_length:
177
+ random_start = np.random.randint(0, mel_length - self.max_mel_length)
178
+ mel_tensor = mel_tensor[
179
+ :, random_start : random_start + self.max_mel_length
180
+ ]
181
+
182
+ return mel_tensor, speaker_id
183
+
184
+
185
+ class Collater(object):
186
+ """
187
+ Pads and collates variable-length samples into fixed-size batch tensors, sorted by mel length.
188
+ Returns: (waves, texts, input_lengths, ref_texts, ref_lengths, mels, output_lengths, ref_mels)
189
+ """
190
+
191
+ def __init__(self, return_wave=False):
192
+ self.text_pad_index = 0
193
+ self.min_mel_length = 192
194
+ self.max_mel_length = 192
195
+ self.return_wave = return_wave
196
+
197
+ def __call__(self, batch):
198
+ # batch[0] = wave, mel, text, f0, speakerid
199
+ batch_size = len(batch)
200
+
201
+ # sort by mel length
202
+ lengths = [b[1].shape[1] for b in batch]
203
+ batch_indexes = np.argsort(lengths)[::-1]
204
+ batch = [batch[bid] for bid in batch_indexes]
205
+
206
+ nmels = batch[0][1].size(0)
207
+ max_mel_length = max([b[1].shape[1] for b in batch])
208
+ max_text_length = max([b[2].shape[0] for b in batch])
209
+ max_rtext_length = max([b[3].shape[0] for b in batch])
210
+
211
+ labels = torch.zeros((batch_size)).long()
212
+ mels = torch.zeros((batch_size, nmels, max_mel_length)).float()
213
+ texts = torch.zeros((batch_size, max_text_length)).long()
214
+ ref_texts = torch.zeros((batch_size, max_rtext_length)).long()
215
+
216
+ input_lengths = torch.zeros(batch_size).long()
217
+ ref_lengths = torch.zeros(batch_size).long()
218
+ output_lengths = torch.zeros(batch_size).long()
219
+ ref_mels = torch.zeros((batch_size, nmels, self.max_mel_length)).float()
220
+ ref_labels = torch.zeros((batch_size)).long()
221
+ paths = ["" for _ in range(batch_size)]
222
+ waves = [None for _ in range(batch_size)]
223
+
224
+ for bid, (
225
+ label,
226
+ mel,
227
+ text,
228
+ ref_text,
229
+ ref_mel,
230
+ ref_label,
231
+ path,
232
+ wave,
233
+ ) in enumerate(batch):
234
+ mel_size = mel.size(1)
235
+ text_size = text.size(0)
236
+ rtext_size = ref_text.size(0)
237
+ labels[bid] = label
238
+ mels[bid, :, :mel_size] = mel
239
+ texts[bid, :text_size] = text
240
+ ref_texts[bid, :rtext_size] = ref_text
241
+ input_lengths[bid] = text_size
242
+ ref_lengths[bid] = rtext_size
243
+ output_lengths[bid] = mel_size
244
+ paths[bid] = path
245
+ ref_mel_size = ref_mel.size(1)
246
+ ref_mels[bid, :, :ref_mel_size] = ref_mel
247
+
248
+ ref_labels[bid] = ref_label
249
+ waves[bid] = wave
250
+
251
+ return (
252
+ waves,
253
+ texts,
254
+ input_lengths,
255
+ ref_texts,
256
+ ref_lengths,
257
+ mels,
258
+ output_lengths,
259
+ ref_mels,
260
+ )
261
+
262
+
263
+ def build_dataloader(
264
+ path_list,
265
+ root_path,
266
+ validation=False,
267
+ OOD_data="Data/OOD_texts.txt",
268
+ min_length=50,
269
+ batch_size=4,
270
+ num_workers=1,
271
+ device="cpu",
272
+ collate_config={},
273
+ dataset_config={},
274
+ ):
275
+ dataset = FilePathDataset(
276
+ path_list,
277
+ root_path,
278
+ OOD_data=OOD_data,
279
+ min_length=min_length,
280
+ validation=validation,
281
+ **dataset_config
282
+ )
283
+ collate_fn = Collater(**collate_config)
284
+ data_loader = DataLoader(
285
+ dataset,
286
+ batch_size=batch_size,
287
+ shuffle=(not validation),
288
+ num_workers=num_workers,
289
+ drop_last=(not validation),
290
+ collate_fn=collate_fn,
291
+ pin_memory=(device != "cpu"),
292
+ )
293
+
294
+ return data_loader
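
A hedged sketch of wiring up build_dataloader; the list format is one `path|text|speaker_id` entry per line, and the paths used here are placeholders rather than files shipped in this commit:

# hypothetical data locations; the real ones come from the training config
with open("Data/train_list.txt") as f:
    train_list = f.readlines()

train_loader = build_dataloader(
    train_list,
    root_path="Data/wavs",
    OOD_data="Data/OOD_texts.txt",
    batch_size=4,
    num_workers=2,
    device="cuda",
)

# each batch follows the Collater return order above
waves, texts, input_lengths, ref_texts, ref_lengths, mels, output_lengths, ref_mels = next(iter(train_loader))
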
models.py ADDED
@@ -0,0 +1,881 @@
1
+ # coding:utf-8
2
+
3
+ import os
4
+ import os.path as osp
5
+
6
+ import copy
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
14
+
15
+ from Utils.ASR.models import ASRCNN
16
+ from Utils.JDC.model import JDCNet
17
+
18
+ from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
19
+ from Modules.diffusion.modules import Transformer1d, StyleTransformer1d
20
+ from Modules.diffusion.diffusion import AudioDiffusionConditional
21
+
22
+ from Modules.discriminators import (
23
+ MultiPeriodDiscriminator,
24
+ MultiResSpecDiscriminator,
25
+ WavLMDiscriminator,
26
+ )
27
+
28
+ from munch import Munch
29
+ import yaml
30
+
31
+
32
+ class LearnedDownSample(nn.Module):
33
+ def __init__(self, layer_type, dim_in):
34
+ super().__init__()
35
+ self.layer_type = layer_type
36
+
37
+ if self.layer_type == "none":
38
+ self.conv = nn.Identity()
39
+ elif self.layer_type == "timepreserve":
40
+ self.conv = spectral_norm(
41
+ nn.Conv2d(
42
+ dim_in,
43
+ dim_in,
44
+ kernel_size=(3, 1),
45
+ stride=(2, 1),
46
+ groups=dim_in,
47
+ padding=(1, 0),
48
+ )
49
+ )
50
+ elif self.layer_type == "half":
51
+ self.conv = spectral_norm(
52
+ nn.Conv2d(
53
+ dim_in,
54
+ dim_in,
55
+ kernel_size=(3, 3),
56
+ stride=(2, 2),
57
+ groups=dim_in,
58
+ padding=1,
59
+ )
60
+ )
61
+ else:
62
+ raise RuntimeError(
63
+ "Got unexpected donwsampletype %s, expected is [none, timepreserve, half]"
64
+ % self.layer_type
65
+ )
66
+
67
+ def forward(self, x):
68
+ return self.conv(x)
69
+
70
+
71
+ class LearnedUpSample(nn.Module):
72
+ def __init__(self, layer_type, dim_in):
73
+ super().__init__()
74
+ self.layer_type = layer_type
75
+
76
+ if self.layer_type == "none":
77
+ self.conv = nn.Identity()
78
+ elif self.layer_type == "timepreserve":
79
+ self.conv = nn.ConvTranspose2d(
80
+ dim_in,
81
+ dim_in,
82
+ kernel_size=(3, 1),
83
+ stride=(2, 1),
84
+ groups=dim_in,
85
+ output_padding=(1, 0),
86
+ padding=(1, 0),
87
+ )
88
+ elif self.layer_type == "half":
89
+ self.conv = nn.ConvTranspose2d(
90
+ dim_in,
91
+ dim_in,
92
+ kernel_size=(3, 3),
93
+ stride=(2, 2),
94
+ groups=dim_in,
95
+ output_padding=1,
96
+ padding=1,
97
+ )
98
+ else:
99
+ raise RuntimeError(
100
+ "Got unexpected upsampletype %s, expected is [none, timepreserve, half]"
101
+ % self.layer_type
102
+ )
103
+
104
+ def forward(self, x):
105
+ return self.conv(x)
106
+
107
+
108
+ class DownSample(nn.Module):
109
+ def __init__(self, layer_type):
110
+ super().__init__()
111
+ self.layer_type = layer_type
112
+
113
+ def forward(self, x):
114
+ if self.layer_type == "none":
115
+ return x
116
+ elif self.layer_type == "timepreserve":
117
+ return F.avg_pool2d(x, (2, 1))
118
+ elif self.layer_type == "half":
119
+ if x.shape[-1] % 2 != 0:
120
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
121
+ return F.avg_pool2d(x, 2)
122
+ else:
123
+ raise RuntimeError(
124
+ "Got unexpected donwsampletype %s, expected is [none, timepreserve, half]"
125
+ % self.layer_type
126
+ )
127
+
128
+
129
+ class UpSample(nn.Module):
130
+ def __init__(self, layer_type):
131
+ super().__init__()
132
+ self.layer_type = layer_type
133
+
134
+ def forward(self, x):
135
+ if self.layer_type == "none":
136
+ return x
137
+ elif self.layer_type == "timepreserve":
138
+ return F.interpolate(x, scale_factor=(2, 1), mode="nearest")
139
+ elif self.layer_type == "half":
140
+ return F.interpolate(x, scale_factor=2, mode="nearest")
141
+ else:
142
+ raise RuntimeError(
143
+ "Got unexpected upsampletype %s, expected is [none, timepreserve, half]"
144
+ % self.layer_type
145
+ )
146
+
147
+
148
+ class ResBlk(nn.Module):
149
+ def __init__(
150
+ self,
151
+ dim_in,
152
+ dim_out,
153
+ actv=nn.LeakyReLU(0.2),
154
+ normalize=False,
155
+ downsample="none",
156
+ ):
157
+ super().__init__()
158
+ self.actv = actv
159
+ self.normalize = normalize
160
+ self.downsample = DownSample(downsample)
161
+ self.downsample_res = LearnedDownSample(downsample, dim_in)
162
+ self.learned_sc = dim_in != dim_out
163
+ self._build_weights(dim_in, dim_out)
164
+
165
+ def _build_weights(self, dim_in, dim_out):
166
+ self.conv1 = spectral_norm(nn.Conv2d(dim_in, dim_in, 3, 1, 1))
167
+ self.conv2 = spectral_norm(nn.Conv2d(dim_in, dim_out, 3, 1, 1))
168
+ if self.normalize:
169
+ self.norm1 = nn.InstanceNorm2d(dim_in, affine=True)
170
+ self.norm2 = nn.InstanceNorm2d(dim_in, affine=True)
171
+ if self.learned_sc:
172
+ self.conv1x1 = spectral_norm(
173
+ nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=False)
174
+ )
175
+
176
+ def _shortcut(self, x):
177
+ if self.learned_sc:
178
+ x = self.conv1x1(x)
179
+ if self.downsample:
180
+ x = self.downsample(x)
181
+ return x
182
+
183
+ def _residual(self, x):
184
+ if self.normalize:
185
+ x = self.norm1(x)
186
+ x = self.actv(x)
187
+ x = self.conv1(x)
188
+ x = self.downsample_res(x)
189
+ if self.normalize:
190
+ x = self.norm2(x)
191
+ x = self.actv(x)
192
+ x = self.conv2(x)
193
+ return x
194
+
195
+ def forward(self, x):
196
+ x = self._shortcut(x) + self._residual(x)
197
+ return x / math.sqrt(2) # unit variance
198
+
199
+
200
+ class StyleEncoder(nn.Module):
201
+ def __init__(self, dim_in=48, style_dim=48, max_conv_dim=384):
202
+ super().__init__()
203
+ blocks = []
204
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
205
+
206
+ repeat_num = 4
207
+ for _ in range(repeat_num):
208
+ dim_out = min(dim_in * 2, max_conv_dim)
209
+ blocks += [ResBlk(dim_in, dim_out, downsample="half")]
210
+ dim_in = dim_out
211
+
212
+ blocks += [nn.LeakyReLU(0.2)]
213
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
214
+ blocks += [nn.AdaptiveAvgPool2d(1)]
215
+ blocks += [nn.LeakyReLU(0.2)]
216
+ self.shared = nn.Sequential(*blocks)
217
+
218
+ self.unshared = nn.Linear(dim_out, style_dim)
219
+
220
+ def forward(self, x):
221
+ h = self.shared(x)
222
+ h = h.view(h.size(0), -1)
223
+ s = self.unshared(h)
224
+
225
+ return s
226
+
227
+
228
+ class LinearNorm(torch.nn.Module):
229
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain="linear"):
230
+ super(LinearNorm, self).__init__()
231
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
232
+
233
+ torch.nn.init.xavier_uniform_(
234
+ self.linear_layer.weight, gain=torch.nn.init.calculate_gain(w_init_gain)
235
+ )
236
+
237
+ def forward(self, x):
238
+ return self.linear_layer(x)
239
+
240
+
241
+ class Discriminator2d(nn.Module):
242
+ def __init__(self, dim_in=48, num_domains=1, max_conv_dim=384, repeat_num=4):
243
+ super().__init__()
244
+ blocks = []
245
+ blocks += [spectral_norm(nn.Conv2d(1, dim_in, 3, 1, 1))]
246
+
247
+ for lid in range(repeat_num):
248
+ dim_out = min(dim_in * 2, max_conv_dim)
249
+ blocks += [ResBlk(dim_in, dim_out, downsample="half")]
250
+ dim_in = dim_out
251
+
252
+ blocks += [nn.LeakyReLU(0.2)]
253
+ blocks += [spectral_norm(nn.Conv2d(dim_out, dim_out, 5, 1, 0))]
254
+ blocks += [nn.LeakyReLU(0.2)]
255
+ blocks += [nn.AdaptiveAvgPool2d(1)]
256
+ blocks += [spectral_norm(nn.Conv2d(dim_out, num_domains, 1, 1, 0))]
257
+ self.main = nn.Sequential(*blocks)
258
+
259
+ def get_feature(self, x):
260
+ features = []
261
+ for l in self.main:
262
+ x = l(x)
263
+ features.append(x)
264
+ out = features[-1]
265
+ out = out.view(out.size(0), -1) # (batch, num_domains)
266
+ return out, features
267
+
268
+ def forward(self, x):
269
+ out, features = self.get_feature(x)
270
+ out = out.squeeze() # (batch)
271
+ return out, features
272
+
273
+
274
+ class ResBlk1d(nn.Module):
275
+ def __init__(
276
+ self,
277
+ dim_in,
278
+ dim_out,
279
+ actv=nn.LeakyReLU(0.2),
280
+ normalize=False,
281
+ downsample="none",
282
+ dropout_p=0.2,
283
+ ):
284
+ super().__init__()
285
+ self.actv = actv
286
+ self.normalize = normalize
287
+ self.downsample_type = downsample
288
+ self.learned_sc = dim_in != dim_out
289
+ self._build_weights(dim_in, dim_out)
290
+ self.dropout_p = dropout_p
291
+
292
+ if self.downsample_type == "none":
293
+ self.pool = nn.Identity()
294
+ else:
295
+ self.pool = weight_norm(
296
+ nn.Conv1d(
297
+ dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1
298
+ )
299
+ )
300
+
301
+ def _build_weights(self, dim_in, dim_out):
302
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_in, 3, 1, 1))
303
+ self.conv2 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
304
+ if self.normalize:
305
+ self.norm1 = nn.InstanceNorm1d(dim_in, affine=True)
306
+ self.norm2 = nn.InstanceNorm1d(dim_in, affine=True)
307
+ if self.learned_sc:
308
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
309
+
310
+ def downsample(self, x):
311
+ if self.downsample_type == "none":
312
+ return x
313
+ else:
314
+ if x.shape[-1] % 2 != 0:
315
+ x = torch.cat([x, x[..., -1].unsqueeze(-1)], dim=-1)
316
+ return F.avg_pool1d(x, 2)
317
+
318
+ def _shortcut(self, x):
319
+ if self.learned_sc:
320
+ x = self.conv1x1(x)
321
+ x = self.downsample(x)
322
+ return x
323
+
324
+ def _residual(self, x):
325
+ if self.normalize:
326
+ x = self.norm1(x)
327
+ x = self.actv(x)
328
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
329
+
330
+ x = self.conv1(x)
331
+ x = self.pool(x)
332
+ if self.normalize:
333
+ x = self.norm2(x)
334
+
335
+ x = self.actv(x)
336
+ x = F.dropout(x, p=self.dropout_p, training=self.training)
337
+
338
+ x = self.conv2(x)
339
+ return x
340
+
341
+ def forward(self, x):
342
+ x = self._shortcut(x) + self._residual(x)
343
+ return x / math.sqrt(2) # unit variance
344
+
345
+
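The division by math.sqrt(2) at the end of ResBlk1d.forward keeps the block output near unit variance: the shortcut and residual branches are each roughly unit-variance, so their sum has variance close to 2 when the branches are uncorrelated. An illustrative check (values are random, the script is only a sketch):

import math
import torch

a = torch.randn(100_000)               # stand-in for the shortcut branch
b = torch.randn(100_000)               # stand-in for the residual branch
print((a + b).var())                   # ~2.0 for independent unit-variance inputs
print(((a + b) / math.sqrt(2)).var())  # ~1.0, which is why the block rescales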
346
+ class LayerNorm(nn.Module):
347
+ def __init__(self, channels, eps=1e-5):
348
+ super().__init__()
349
+ self.channels = channels
350
+ self.eps = eps
351
+
352
+ self.gamma = nn.Parameter(torch.ones(channels))
353
+ self.beta = nn.Parameter(torch.zeros(channels))
354
+
355
+ def forward(self, x):
356
+ x = x.transpose(1, -1)
357
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
358
+ return x.transpose(1, -1)
359
+
360
+
361
+ class TextEncoder(nn.Module):
362
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
363
+ super().__init__()
364
+ self.embedding = nn.Embedding(n_symbols, channels)
365
+
366
+ padding = (kernel_size - 1) // 2
367
+ self.cnn = nn.ModuleList()
368
+ for _ in range(depth):
369
+ self.cnn.append(
370
+ nn.Sequential(
371
+ weight_norm(
372
+ nn.Conv1d(
373
+ channels, channels, kernel_size=kernel_size, padding=padding
374
+ )
375
+ ),
376
+ LayerNorm(channels),
377
+ actv,
378
+ nn.Dropout(0.2),
379
+ )
380
+ )
381
+ # self.cnn = nn.Sequential(*self.cnn)
382
+
383
+ self.lstm = nn.LSTM(
384
+ channels, channels // 2, 1, batch_first=True, bidirectional=True
385
+ )
386
+
387
+ def forward(self, x, input_lengths, m):
388
+ x = self.embedding(x) # [B, T, emb]
389
+ x = x.transpose(1, 2) # [B, emb, T]
390
+ m = m.to(input_lengths.device).unsqueeze(1)
391
+ x.masked_fill_(m, 0.0)
392
+
393
+ for c in self.cnn:
394
+ x = c(x)
395
+ x.masked_fill_(m, 0.0)
396
+
397
+ x = x.transpose(1, 2) # [B, T, chn]
398
+
399
+ input_lengths = input_lengths.cpu().numpy()
400
+ x = nn.utils.rnn.pack_padded_sequence(
401
+ x, input_lengths, batch_first=True, enforce_sorted=False
402
+ )
403
+
404
+ self.lstm.flatten_parameters()
405
+ x, _ = self.lstm(x)
406
+ x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
407
+
408
+ x = x.transpose(-1, -2)
409
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
410
+
411
+ x_pad[:, :, : x.shape[-1]] = x
412
+ x = x_pad.to(x.device)
413
+
414
+ x.masked_fill_(m, 0.0)
415
+
416
+ return x
417
+
418
+ def inference(self, x):
419
+ x = self.embedding(x)
420
+ x = x.transpose(1, 2)
421
+ x = self.cnn(x)
422
+ x = x.transpose(1, 2)
423
+ self.lstm.flatten_parameters()
424
+ x, _ = self.lstm(x)
425
+ return x
426
+
427
+ def length_to_mask(self, lengths):
428
+ mask = (
429
+ torch.arange(lengths.max())
430
+ .unsqueeze(0)
431
+ .expand(lengths.shape[0], -1)
432
+ .type_as(lengths)
433
+ )
434
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
435
+ return mask
436
+
437
+
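length_to_mask (defined here and repeated in several classes below) returns a boolean mask that is True exactly on padded positions; because of the mask + 1 comparison, a length of 3 keeps the first three steps. A small illustrative example reproducing that logic:

import torch

lengths = torch.tensor([3, 5])
mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
mask = torch.gt(mask + 1, lengths.unsqueeze(1))
print(mask)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])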
438
+ class AdaIN1d(nn.Module):
439
+ def __init__(self, style_dim, num_features):
440
+ super().__init__()
441
+ self.norm = nn.InstanceNorm1d(num_features, affine=False)
442
+ self.fc = nn.Linear(style_dim, num_features * 2)
443
+
444
+ def forward(self, x, s):
445
+ h = self.fc(s)
446
+ h = h.view(h.size(0), h.size(1), 1)
447
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
448
+ return (1 + gamma) * self.norm(x) + beta
449
+
450
+
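AdaIN1d above implements adaptive instance normalization: a linear layer maps the style vector to a per-channel (gamma, beta) pair, which then scales and shifts the instance-normalized features. A shape sketch with illustrative sizes (the dimensions are examples, not values taken from any config):

import torch

# assumes AdaIN1d from this file; dimensions below are illustrative only
adain = AdaIN1d(style_dim=128, num_features=512)
x = torch.randn(4, 512, 100)   # (batch, channels, frames)
s = torch.randn(4, 128)        # one style vector per batch item
y = adain(x, s)                # fc(s) -> 2 * 512 values -> gamma, beta of shape (4, 512, 1)
print(y.shape)                 # torch.Size([4, 512, 100])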
451
+ class UpSample1d(nn.Module):
452
+ def __init__(self, layer_type):
453
+ super().__init__()
454
+ self.layer_type = layer_type
455
+
456
+ def forward(self, x):
457
+ if self.layer_type == "none":
458
+ return x
459
+ else:
460
+ return F.interpolate(x, scale_factor=2, mode="nearest")
461
+
462
+
463
+ class AdainResBlk1d(nn.Module):
464
+ def __init__(
465
+ self,
466
+ dim_in,
467
+ dim_out,
468
+ style_dim=64,
469
+ actv=nn.LeakyReLU(0.2),
470
+ upsample="none",
471
+ dropout_p=0.0,
472
+ ):
473
+ super().__init__()
474
+ self.actv = actv
475
+ self.upsample_type = upsample
476
+ self.upsample = UpSample1d(upsample)
477
+ self.learned_sc = dim_in != dim_out
478
+ self._build_weights(dim_in, dim_out, style_dim)
479
+ self.dropout = nn.Dropout(dropout_p)
480
+
481
+ if upsample == "none":
482
+ self.pool = nn.Identity()
483
+ else:
484
+ self.pool = weight_norm(
485
+ nn.ConvTranspose1d(
486
+ dim_in,
487
+ dim_in,
488
+ kernel_size=3,
489
+ stride=2,
490
+ groups=dim_in,
491
+ padding=1,
492
+ output_padding=1,
493
+ )
494
+ )
495
+
496
+ def _build_weights(self, dim_in, dim_out, style_dim):
497
+ self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
498
+ self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
499
+ self.norm1 = AdaIN1d(style_dim, dim_in)
500
+ self.norm2 = AdaIN1d(style_dim, dim_out)
501
+ if self.learned_sc:
502
+ self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))
503
+
504
+ def _shortcut(self, x):
505
+ x = self.upsample(x)
506
+ if self.learned_sc:
507
+ x = self.conv1x1(x)
508
+ return x
509
+
510
+ def _residual(self, x, s):
511
+ x = self.norm1(x, s)
512
+ x = self.actv(x)
513
+ x = self.pool(x)
514
+ x = self.conv1(self.dropout(x))
515
+ x = self.norm2(x, s)
516
+ x = self.actv(x)
517
+ x = self.conv2(self.dropout(x))
518
+ return x
519
+
520
+ def forward(self, x, s):
521
+ out = self._residual(x, s)
522
+ out = (out + self._shortcut(x)) / math.sqrt(2)
523
+ return out
524
+
525
+
526
+ class AdaLayerNorm(nn.Module):
527
+ def __init__(self, style_dim, channels, eps=1e-5):
528
+ super().__init__()
529
+ self.channels = channels
530
+ self.eps = eps
531
+
532
+ self.fc = nn.Linear(style_dim, channels * 2)
533
+
534
+ def forward(self, x, s):
535
+ x = x.transpose(-1, -2)
536
+ x = x.transpose(1, -1)
537
+
538
+ h = self.fc(s)
539
+ h = h.view(h.size(0), h.size(1), 1)
540
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
541
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
542
+
543
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
544
+ x = (1 + gamma) * x + beta
545
+ return x.transpose(1, -1).transpose(-1, -2)
546
+
547
+
548
+ class ProsodyPredictor(nn.Module):
549
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
550
+ super().__init__()
551
+
552
+ self.text_encoder = DurationEncoder(
553
+ sty_dim=style_dim, d_model=d_hid, nlayers=nlayers, dropout=dropout
554
+ )
555
+
556
+ self.lstm = nn.LSTM(
557
+ d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True
558
+ )
559
+ self.duration_proj = LinearNorm(d_hid, max_dur)
560
+
561
+ self.shared = nn.LSTM(
562
+ d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True
563
+ )
564
+ self.F0 = nn.ModuleList()
565
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
566
+ self.F0.append(
567
+ AdainResBlk1d(
568
+ d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout
569
+ )
570
+ )
571
+ self.F0.append(
572
+ AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)
573
+ )
574
+
575
+ self.N = nn.ModuleList()
576
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
577
+ self.N.append(
578
+ AdainResBlk1d(
579
+ d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout
580
+ )
581
+ )
582
+ self.N.append(
583
+ AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout)
584
+ )
585
+
586
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
587
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
588
+
589
+ def forward(self, texts, style, text_lengths, alignment, m):
590
+ d = self.text_encoder(texts, style, text_lengths, m)
591
+
592
+ batch_size = d.shape[0]
593
+ text_size = d.shape[1]
594
+
595
+ # predict duration
596
+ input_lengths = text_lengths.cpu().numpy()
597
+ x = nn.utils.rnn.pack_padded_sequence(
598
+ d, input_lengths, batch_first=True, enforce_sorted=False
599
+ )
600
+
601
+ m = m.to(text_lengths.device).unsqueeze(1)
602
+
603
+ self.lstm.flatten_parameters()
604
+ x, _ = self.lstm(x)
605
+ x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
606
+
607
+ x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]])
608
+
609
+ x_pad[:, : x.shape[1], :] = x
610
+ x = x_pad.to(x.device)
611
+
612
+ duration = self.duration_proj(
613
+ nn.functional.dropout(x, 0.5, training=self.training)
614
+ )
615
+
616
+ en = d.transpose(-1, -2) @ alignment
617
+
618
+ return duration.squeeze(-1), en
619
+
620
+ def F0Ntrain(self, x, s):
621
+ x, _ = self.shared(x.transpose(-1, -2))
622
+
623
+ F0 = x.transpose(-1, -2)
624
+ for block in self.F0:
625
+ F0 = block(F0, s)
626
+ F0 = self.F0_proj(F0)
627
+
628
+ N = x.transpose(-1, -2)
629
+ for block in self.N:
630
+ N = block(N, s)
631
+ N = self.N_proj(N)
632
+
633
+ return F0.squeeze(1), N.squeeze(1)
634
+
635
+ def length_to_mask(self, lengths):
636
+ mask = (
637
+ torch.arange(lengths.max())
638
+ .unsqueeze(0)
639
+ .expand(lengths.shape[0], -1)
640
+ .type_as(lengths)
641
+ )
642
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
643
+ return mask
644
+
645
+
646
+ class DurationEncoder(nn.Module):
647
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
648
+ super().__init__()
649
+ self.lstms = nn.ModuleList()
650
+ for _ in range(nlayers):
651
+ self.lstms.append(
652
+ nn.LSTM(
653
+ d_model + sty_dim,
654
+ d_model // 2,
655
+ num_layers=1,
656
+ batch_first=True,
657
+ bidirectional=True,
658
+ dropout=dropout,
659
+ )
660
+ )
661
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
662
+
663
+ self.dropout = dropout
664
+ self.d_model = d_model
665
+ self.sty_dim = sty_dim
666
+
667
+ def forward(self, x, style, text_lengths, m):
668
+ masks = m.to(text_lengths.device)
669
+
670
+ x = x.permute(2, 0, 1)
671
+ s = style.expand(x.shape[0], x.shape[1], -1)
672
+ x = torch.cat([x, s], axis=-1)
673
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
674
+
675
+ x = x.transpose(0, 1)
676
+ input_lengths = text_lengths.cpu().numpy()
677
+ x = x.transpose(-1, -2)
678
+
679
+ for block in self.lstms:
680
+ if isinstance(block, AdaLayerNorm):
681
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
682
+ x = torch.cat([x, s.permute(1, -1, 0)], axis=1)
683
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
684
+ else:
685
+ x = x.transpose(-1, -2)
686
+ x = nn.utils.rnn.pack_padded_sequence(
687
+ x, input_lengths, batch_first=True, enforce_sorted=False
688
+ )
689
+ block.flatten_parameters()
690
+ x, _ = block(x)
691
+ x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
692
+ x = F.dropout(x, p=self.dropout, training=self.training)
693
+ x = x.transpose(-1, -2)
694
+
695
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]])
696
+
697
+ x_pad[:, :, : x.shape[-1]] = x
698
+ x = x_pad.to(x.device)
699
+
700
+ return x.transpose(-1, -2)
701
+
702
+ def inference(self, x, style):
703
+ x = self.embedding(x.transpose(-1, -2)) * math.sqrt(self.d_model)
704
+ style = style.expand(x.shape[0], x.shape[1], -1)
705
+ x = torch.cat([x, style], axis=-1)
706
+ src = self.pos_encoder(x)
707
+ output = self.transformer_encoder(src).transpose(0, 1)
708
+ return output
709
+
710
+ def length_to_mask(self, lengths):
711
+ mask = (
712
+ torch.arange(lengths.max())
713
+ .unsqueeze(0)
714
+ .expand(lengths.shape[0], -1)
715
+ .type_as(lengths)
716
+ )
717
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
718
+ return mask
719
+
720
+
721
+ def load_F0_models(path):
722
+ # load F0 model
723
+
724
+ F0_model = JDCNet(num_class=1, seq_len=192)
725
+ params = torch.load(path, map_location="cpu")["net"]
726
+ F0_model.load_state_dict(params)
727
+ _ = F0_model.train()
728
+
729
+ return F0_model
730
+
731
+
732
+ def load_ASR_models(ASR_MODEL_PATH, ASR_MODEL_CONFIG):
733
+ # load ASR model
734
+ def _load_config(path):
735
+ with open(path) as f:
736
+ config = yaml.safe_load(f)
737
+ model_config = config["model_params"]
738
+ return model_config
739
+
740
+ def _load_model(model_config, model_path):
741
+ model = ASRCNN(**model_config)
742
+ params = torch.load(model_path, map_location="cpu")["model"]
743
+ model.load_state_dict(params)
744
+ return model
745
+
746
+ asr_model_config = _load_config(ASR_MODEL_CONFIG)
747
+ asr_model = _load_model(asr_model_config, ASR_MODEL_PATH)
748
+ _ = asr_model.train()
749
+
750
+ return asr_model
751
+
752
+
753
+ def build_model(args, text_aligner, pitch_extractor, bert):
754
+ assert args.decoder.type in ["istftnet", "hifigan"], "Decoder type unknown"
755
+
756
+ if args.decoder.type == "istftnet":
757
+ from Modules.istftnet import Decoder
758
+
759
+ decoder = Decoder(
760
+ dim_in=args.hidden_dim,
761
+ style_dim=args.style_dim,
762
+ dim_out=args.n_mels,
763
+ resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
764
+ upsample_rates=args.decoder.upsample_rates,
765
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
766
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
767
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
768
+ gen_istft_n_fft=args.decoder.gen_istft_n_fft,
769
+ gen_istft_hop_size=args.decoder.gen_istft_hop_size,
770
+ )
771
+ else:
772
+ from Modules.hifigan import Decoder
773
+
774
+ decoder = Decoder(
775
+ dim_in=args.hidden_dim,
776
+ style_dim=args.style_dim,
777
+ dim_out=args.n_mels,
778
+ resblock_kernel_sizes=args.decoder.resblock_kernel_sizes,
779
+ upsample_rates=args.decoder.upsample_rates,
780
+ upsample_initial_channel=args.decoder.upsample_initial_channel,
781
+ resblock_dilation_sizes=args.decoder.resblock_dilation_sizes,
782
+ upsample_kernel_sizes=args.decoder.upsample_kernel_sizes,
783
+ )
784
+
785
+ text_encoder = TextEncoder(
786
+ channels=args.hidden_dim,
787
+ kernel_size=5,
788
+ depth=args.n_layer,
789
+ n_symbols=args.n_token,
790
+ )
791
+
792
+ predictor = ProsodyPredictor(
793
+ style_dim=args.style_dim,
794
+ d_hid=args.hidden_dim,
795
+ nlayers=args.n_layer,
796
+ max_dur=args.max_dur,
797
+ dropout=args.dropout,
798
+ )
799
+
800
+ style_encoder = StyleEncoder(
801
+ dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim
802
+ ) # acoustic style encoder
803
+ predictor_encoder = StyleEncoder(
804
+ dim_in=args.dim_in, style_dim=args.style_dim, max_conv_dim=args.hidden_dim
805
+ ) # prosodic style encoder
806
+
807
+ # define diffusion model
808
+ if args.multispeaker:
809
+ transformer = StyleTransformer1d(
810
+ channels=args.style_dim * 2,
811
+ context_embedding_features=bert.config.hidden_size,
812
+ context_features=args.style_dim * 2,
813
+ **args.diffusion.transformer
814
+ )
815
+ else:
816
+ transformer = Transformer1d(
817
+ channels=args.style_dim * 2,
818
+ context_embedding_features=bert.config.hidden_size,
819
+ **args.diffusion.transformer
820
+ )
821
+
822
+ diffusion = AudioDiffusionConditional(
823
+ in_channels=1,
824
+ embedding_max_length=bert.config.max_position_embeddings,
825
+ embedding_features=bert.config.hidden_size,
826
+ embedding_mask_proba=args.diffusion.embedding_mask_proba, # Conditional dropout of batch elements,
827
+ channels=args.style_dim * 2,
828
+ context_features=args.style_dim * 2,
829
+ )
830
+
831
+ diffusion.diffusion = KDiffusion(
832
+ net=diffusion.unet,
833
+ sigma_distribution=LogNormalDistribution(
834
+ mean=args.diffusion.dist.mean, std=args.diffusion.dist.std
835
+ ),
836
+ sigma_data=args.diffusion.dist.sigma_data, # a placeholder; updated dynamically once diffusion-model training starts
837
+ dynamic_threshold=0.0,
838
+ )
839
+ diffusion.diffusion.net = transformer
840
+ diffusion.unet = transformer
841
+
842
+ nets = Munch(
843
+ bert=bert,
844
+ bert_encoder=nn.Linear(bert.config.hidden_size, args.hidden_dim),
845
+ predictor=predictor,
846
+ decoder=decoder,
847
+ text_encoder=text_encoder,
848
+ predictor_encoder=predictor_encoder,
849
+ style_encoder=style_encoder,
850
+ diffusion=diffusion,
851
+ text_aligner=text_aligner,
852
+ pitch_extractor=pitch_extractor,
853
+ mpd=MultiPeriodDiscriminator(),
854
+ msd=MultiResSpecDiscriminator(),
855
+ # slm discriminator head
856
+ wd=WavLMDiscriminator(
857
+ args.slm.hidden, args.slm.nlayers, args.slm.initial_channel
858
+ ),
859
+ )
860
+
861
+ return nets
862
+
863
+
864
+ def load_checkpoint(model, optimizer, path, load_only_params=True, ignore_modules=[]):
865
+ state = torch.load(path, map_location="cpu")
866
+ params = state["net"]
867
+ for key in model:
868
+ if key in params and key not in ignore_modules:
869
+ print("%s loaded" % key)
870
+ model[key].load_state_dict(params[key], strict=False)
871
+ _ = [model[key].eval() for key in model]
872
+
873
+ if not load_only_params:
874
+ epoch = state["epoch"]
875
+ iters = state["iters"]
876
+ optimizer.load_state_dict(state["optimizer"])
877
+ else:
878
+ epoch = 0
879
+ iters = 0
880
+
881
+ return model, optimizer, epoch, iters
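For orientation, build_model returns a Munch of named sub-modules and load_checkpoint restores them from the "net" entry of a saved state dict, optionally skipping modules and restoring the optimizer. A minimal sketch of how the two are typically combined, assuming this file is importable as models and that utils exposes recursive_munch as the scripts below do; the config and checkpoint paths are placeholders:

import yaml
from models import *
from utils import *            # provides recursive_munch, as used elsewhere in this repo
from Utils.PLBERT.util import load_plbert

config = yaml.safe_load(open("Configs/config_ft.yml"))
text_aligner = load_ASR_models(config["ASR_path"], config["ASR_config"])
pitch_extractor = load_F0_models(config["F0_path"])
plbert = load_plbert(config["PLBERT_dir"])

model = build_model(recursive_munch(config["model_params"]),
                    text_aligner, pitch_extractor, plbert)

# optimizer may be None when only parameters are loaded
model, _, epoch, iters = load_checkpoint(model, None, "some_checkpoint.pth",
                                         load_only_params=True)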
optimizers.py ADDED
@@ -0,0 +1,86 @@
1
+ # coding:utf-8
2
+ import os, sys
3
+ import os.path as osp
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ from torch.optim import Optimizer
8
+ from functools import reduce
9
+ from torch.optim import AdamW
10
+
11
+
12
+ class MultiOptimizer:
13
+ def __init__(self, optimizers={}, schedulers={}):
14
+ self.optimizers = optimizers
15
+ self.schedulers = schedulers
16
+ self.keys = list(optimizers.keys())
17
+ self.param_groups = reduce(
18
+ lambda x, y: x + y, [v.param_groups for v in self.optimizers.values()]
19
+ )
20
+
21
+ def state_dict(self):
22
+ state_dicts = [(key, self.optimizers[key].state_dict()) for key in self.keys]
23
+ return state_dicts
24
+
25
+ def load_state_dict(self, state_dict):
26
+ for key, val in state_dict:
27
+ try:
28
+ self.optimizers[key].load_state_dict(val)
29
+ except Exception:
30
+ print("Could not load optimizer state for %s" % key)
31
+
32
+ def step(self, key=None, scaler=None):
33
+ keys = [key] if key is not None else self.keys
34
+ _ = [self._step(key, scaler) for key in keys]
35
+
36
+ def _step(self, key, scaler=None):
37
+ if scaler is not None:
38
+ scaler.step(self.optimizers[key])
39
+ scaler.update()
40
+ else:
41
+ self.optimizers[key].step()
42
+
43
+ def zero_grad(self, key=None):
44
+ if key is not None:
45
+ self.optimizers[key].zero_grad()
46
+ else:
47
+ _ = [self.optimizers[key].zero_grad() for key in self.keys]
48
+
49
+ def scheduler(self, *args, key=None):
50
+ if key is not None:
51
+ self.schedulers[key].step(*args)
52
+ else:
53
+ _ = [self.schedulers[key].step(*args) for key in self.keys]
54
+
55
+
56
+ def define_scheduler(optimizer, params):
57
+ scheduler = torch.optim.lr_scheduler.OneCycleLR(
58
+ optimizer,
59
+ max_lr=params.get("max_lr", 2e-4),
60
+ epochs=params.get("epochs", 200),
61
+ steps_per_epoch=params.get("steps_per_epoch", 1000),
62
+ pct_start=params.get("pct_start", 0.0),
63
+ div_factor=1,
64
+ final_div_factor=1,
65
+ )
66
+
67
+ return scheduler
68
+
69
+
70
+ def build_optimizer(parameters_dict, scheduler_params_dict, lr):
71
+ optim = dict(
72
+ [
73
+ (key, AdamW(params, lr=lr, weight_decay=1e-4, betas=(0.0, 0.99), eps=1e-9))
74
+ for key, params in parameters_dict.items()
75
+ ]
76
+ )
77
+
78
+ schedulers = dict(
79
+ [
80
+ (key, define_scheduler(opt, scheduler_params_dict[key]))
81
+ for key, opt in optim.items()
82
+ ]
83
+ )
84
+
85
+ multi_optim = MultiOptimizer(optim, schedulers)
86
+ return multi_optim
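A hedged usage sketch of build_optimizer and MultiOptimizer with toy modules (the module names and sizes are arbitrary examples): each key gets its own AdamW plus OneCycleLR schedule, and training code can step either one module by key or all of them at once.

import torch
from torch import nn
from optimizers import build_optimizer

modules = {"decoder": nn.Linear(8, 8), "predictor": nn.Linear(8, 8)}
sched = {"max_lr": 1e-4, "pct_start": 0.0, "epochs": 2, "steps_per_epoch": 10}
optimizer = build_optimizer(
    {k: m.parameters() for k, m in modules.items()},
    scheduler_params_dict={k: sched.copy() for k in modules},
    lr=1e-4,
)

loss = sum(m(torch.randn(4, 8)).sum() for m in modules.values())
loss.backward()
optimizer.step("decoder")   # step a single module by key...
optimizer.step()            # ...or every module at once
optimizer.zero_grad()
optimizer.scheduler()       # advance all OneCycleLR schedules by one step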
reference_audio.zip ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d25b4950ec39cec5a00f5061491ad0b3606edc6618a54adc59663bfd6e6ab55e
3
+ size 2917622
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ SoundFile
2
+ torchaudio
3
+ munch
4
+ torch
5
+ pydub
6
+ pyyaml
7
+ librosa
8
+ nltk
9
+ matplotlib
10
+ accelerate
11
+ transformers
12
+ einops
13
+ einops-exts
14
+ tqdm
15
+ typing
16
+ typing-extensions
17
+ git+https://github.com/resemble-ai/monotonic_align.git
18
+ scipy
19
+ deep-phonemizer
20
+ cached-path
21
+ gradio
styletts2importable.py ADDED
@@ -0,0 +1,361 @@
1
+ from cached_path import cached_path
2
+
3
+ from dp.phonemizer import Phonemizer
4
+ print("NLTK")
5
+ import nltk
6
+ nltk.download('punkt')
7
+ print("SCIPY")
8
+ from scipy.io.wavfile import write
9
+ print("TORCH STUFF")
10
+ import torch
11
+ print("START")
12
+ torch.manual_seed(0)
13
+ torch.backends.cudnn.benchmark = False
14
+ torch.backends.cudnn.deterministic = True
15
+
16
+ import random
17
+ random.seed(0)
18
+
19
+ import numpy as np
20
+ np.random.seed(0)
21
+
22
+ # load packages
23
+ import time
24
+ import random
25
+ import yaml
26
+ from munch import Munch
27
+ import numpy as np
28
+ import torch
29
+ from torch import nn
30
+ import torch.nn.functional as F
31
+ import torchaudio
32
+ import librosa
33
+ from nltk.tokenize import word_tokenize
34
+
35
+ from models import *
36
+ from utils import *
37
+ from text_utils import TextCleaner
38
+ textclenaer = TextCleaner()
39
+
40
+
41
+ to_mel = torchaudio.transforms.MelSpectrogram(
42
+ n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
43
+ mean, std = -4, 4
44
+
45
+ def length_to_mask(lengths):
46
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
47
+ mask = torch.gt(mask+1, lengths.unsqueeze(1))
48
+ return mask
49
+
50
+ def preprocess(wave):
51
+ wave_tensor = torch.from_numpy(wave).float()
52
+ mel_tensor = to_mel(wave_tensor)
53
+ mel_tensor = (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std
54
+ return mel_tensor
55
+
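preprocess turns a 24 kHz waveform into the normalized 80-band log-mel features the style encoders expect; with hop_length=300 that is roughly 80 frames per second, and (x - mean) / std with mean=-4, std=4 roughly whitens the log-mel values. An illustrative shape check only:

import torch

one_second = torch.randn(24000).numpy()    # stand-in for 1 s of 24 kHz audio
mel = preprocess(one_second)               # uses the to_mel transform defined above
print(mel.shape)                           # torch.Size([1, 80, 81]) with the default centered STFT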
56
+ def compute_style(path):
57
+ wave, sr = librosa.load(path, sr=24000)
58
+ audio, index = librosa.effects.trim(wave, top_db=30)
59
+ if sr != 24000:
60
+ audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)  # keyword arguments required by librosa >= 0.10
61
+ mel_tensor = preprocess(audio).to(device)
62
+
63
+ with torch.no_grad():
64
+ ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
65
+ ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
66
+
67
+ return torch.cat([ref_s, ref_p], dim=1)
68
+
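compute_style returns the acoustic and prosodic reference styles concatenated along the feature axis; the inference functions below undo this with ref_s[:, :128] and ref_s[:, 128:], which assumes style_dim = 128 as in the LibriTTS config loaded here. Illustrative only:

import torch

ref_s = torch.randn(1, 256)     # stand-in for compute_style("reference.wav")
acoustic_ref = ref_s[:, :128]   # decoder path ("ref" in the inference functions below)
prosody_ref = ref_s[:, 128:]    # prosody predictor path ("s" below)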
69
+ device = 'cpu'
70
+ if torch.cuda.is_available():
71
+ device = 'cuda'
72
+ elif torch.backends.mps.is_available():
73
+ print("MPS would be available but cannot be used rn")
74
+ # device = 'mps'
75
+
76
+
77
+ # global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
78
+ phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
79
+
80
+
81
+ # config = yaml.safe_load(open("Models/LibriTTS/config.yml"))
82
+ config = yaml.safe_load(open(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/config.yml"))))
83
+
84
+ # load pretrained ASR model
85
+ ASR_config = config.get('ASR_config', False)
86
+ ASR_path = config.get('ASR_path', False)
87
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
88
+
89
+ # load pretrained F0 model
90
+ F0_path = config.get('F0_path', False)
91
+ pitch_extractor = load_F0_models(F0_path)
92
+
93
+ # load BERT model
94
+ from Utils.PLBERT.util import load_plbert
95
+ BERT_path = config.get('PLBERT_dir', False)
96
+ plbert = load_plbert(BERT_path)
97
+
98
+ model_params = recursive_munch(config['model_params'])
99
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
100
+ _ = [model[key].eval() for key in model]
101
+ _ = [model[key].to(device) for key in model]
102
+
103
+ # params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
104
+ params_whole = torch.load(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/epochs_2nd_00020.pth")), map_location='cpu')
105
+ params = params_whole['net']
106
+
107
+ for key in model:
108
+ if key in params:
109
+ print('%s loaded' % key)
110
+ try:
111
+ model[key].load_state_dict(params[key])
112
+ except:
113
+ from collections import OrderedDict
114
+ state_dict = params[key]
115
+ new_state_dict = OrderedDict()
116
+ for k, v in state_dict.items():
117
+ name = k[7:] # remove `module.`
118
+ new_state_dict[name] = v
119
+ # load params
120
+ model[key].load_state_dict(new_state_dict, strict=False)
121
+ # except:
122
+ # _load(params[key], model[key])
123
+ _ = [model[key].eval() for key in model]
124
+
125
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
126
+
127
+ sampler = DiffusionSampler(
128
+ model.diffusion.diffusion,
129
+ sampler=ADPM2Sampler(),
130
+ sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
131
+ clamp=False
132
+ )
133
+
134
+ def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
135
+ text = text.strip()
136
+ ps = phonemizer([text], lang='en_us')
137
+ ps = word_tokenize(ps[0])
138
+ ps = ' '.join(ps)
139
+ tokens = textclenaer(ps)
140
+ tokens.insert(0, 0)
141
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
142
+
143
+ with torch.no_grad():
144
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
145
+ text_mask = length_to_mask(input_lengths).to(device)
146
+
147
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
148
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
149
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
150
+
151
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
152
+ embedding=bert_dur,
153
+ embedding_scale=embedding_scale,
154
+ features=ref_s, # reference from the same speaker as the embedding
155
+ num_steps=diffusion_steps).squeeze(1)
156
+
157
+
158
+ s = s_pred[:, 128:]
159
+ ref = s_pred[:, :128]
160
+
161
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
162
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
163
+
164
+ d = model.predictor.text_encoder(d_en,
165
+ s, input_lengths, text_mask)
166
+
167
+ x, _ = model.predictor.lstm(d)
168
+ duration = model.predictor.duration_proj(x)
169
+
170
+ duration = torch.sigmoid(duration).sum(axis=-1)
171
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
172
+
173
+
174
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
175
+ c_frame = 0
176
+ for i in range(pred_aln_trg.size(0)):
177
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
178
+ c_frame += int(pred_dur[i].data)
179
+
180
+ # encode prosody
181
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
182
+ if model_params.decoder.type == "hifigan":
183
+ asr_new = torch.zeros_like(en)
184
+ asr_new[:, :, 0] = en[:, :, 0]
185
+ asr_new[:, :, 1:] = en[:, :, 0:-1]
186
+ en = asr_new
187
+
188
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
189
+
190
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
191
+ if model_params.decoder.type == "hifigan":
192
+ asr_new = torch.zeros_like(asr)
193
+ asr_new[:, :, 0] = asr[:, :, 0]
194
+ asr_new[:, :, 1:] = asr[:, :, 0:-1]
195
+ asr = asr_new
196
+
197
+ out = model.decoder(asr,
198
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
199
+
200
+
201
+ return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model output, needs to be fixed later
202
+
203
+ def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1):
204
+ text = text.strip()
205
+ ps = phonemizer([text], lang='en_us')
206
+ ps = word_tokenize(ps[0])
207
+ ps = ' '.join(ps)
208
+ ps = ps.replace('``', '"')
209
+ ps = ps.replace("''", '"')
210
+
211
+ tokens = textclenaer(ps)
212
+ tokens.insert(0, 0)
213
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
214
+
215
+ with torch.no_grad():
216
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
217
+ text_mask = length_to_mask(input_lengths).to(device)
218
+
219
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
220
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
221
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
222
+
223
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
224
+ embedding=bert_dur,
225
+ embedding_scale=embedding_scale,
226
+ features=ref_s, # reference from the same speaker as the embedding
227
+ num_steps=diffusion_steps).squeeze(1)
228
+
229
+ if s_prev is not None:
230
+ # convex combination of previous and current style
231
+ s_pred = t * s_prev + (1 - t) * s_pred
232
+
233
+ s = s_pred[:, 128:]
234
+ ref = s_pred[:, :128]
235
+
236
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
237
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
238
+
239
+ s_pred = torch.cat([ref, s], dim=-1)
240
+
241
+ d = model.predictor.text_encoder(d_en,
242
+ s, input_lengths, text_mask)
243
+
244
+ x, _ = model.predictor.lstm(d)
245
+ duration = model.predictor.duration_proj(x)
246
+
247
+ duration = torch.sigmoid(duration).sum(axis=-1)
248
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
249
+
250
+
251
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
252
+ c_frame = 0
253
+ for i in range(pred_aln_trg.size(0)):
254
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
255
+ c_frame += int(pred_dur[i].data)
256
+
257
+ # encode prosody
258
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
259
+ if model_params.decoder.type == "hifigan":
260
+ asr_new = torch.zeros_like(en)
261
+ asr_new[:, :, 0] = en[:, :, 0]
262
+ asr_new[:, :, 1:] = en[:, :, 0:-1]
263
+ en = asr_new
264
+
265
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
266
+
267
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
268
+ if model_params.decoder.type == "hifigan":
269
+ asr_new = torch.zeros_like(asr)
270
+ asr_new[:, :, 0] = asr[:, :, 0]
271
+ asr_new[:, :, 1:] = asr[:, :, 0:-1]
272
+ asr = asr_new
273
+
274
+ out = model.decoder(asr,
275
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
276
+
277
+
278
+ return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model output, needs to be fixed later
279
+
280
+ def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1):
281
+ text = text.strip()
282
+ ps = phonemizer([text], lang='en_us')
283
+ ps = word_tokenize(ps[0])
284
+ ps = ' '.join(ps)
285
+
286
+ tokens = textclenaer(ps)
287
+ tokens.insert(0, 0)
288
+ tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
289
+
290
+ ref_text = ref_text.strip()
291
+ ps = phonemizer([ref_text], lang='en_us')
292
+ ps = word_tokenize(ps[0])
293
+ ps = ' '.join(ps)
294
+
295
+ ref_tokens = textclenaer(ps)
296
+ ref_tokens.insert(0, 0)
297
+ ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)
298
+
299
+
300
+ with torch.no_grad():
301
+ input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
302
+ text_mask = length_to_mask(input_lengths).to(device)
303
+
304
+ t_en = model.text_encoder(tokens, input_lengths, text_mask)
305
+ bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
306
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
307
+
308
+ ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)
309
+ ref_text_mask = length_to_mask(ref_input_lengths).to(device)
310
+ ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())
311
+ s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
312
+ embedding=bert_dur,
313
+ embedding_scale=embedding_scale,
314
+ features=ref_s, # reference from the same speaker as the embedding
315
+ num_steps=diffusion_steps).squeeze(1)
316
+
317
+
318
+ s = s_pred[:, 128:]
319
+ ref = s_pred[:, :128]
320
+
321
+ ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
322
+ s = beta * s + (1 - beta) * ref_s[:, 128:]
323
+
324
+ d = model.predictor.text_encoder(d_en,
325
+ s, input_lengths, text_mask)
326
+
327
+ x, _ = model.predictor.lstm(d)
328
+ duration = model.predictor.duration_proj(x)
329
+
330
+ duration = torch.sigmoid(duration).sum(axis=-1)
331
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
332
+
333
+
334
+ pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
335
+ c_frame = 0
336
+ for i in range(pred_aln_trg.size(0)):
337
+ pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
338
+ c_frame += int(pred_dur[i].data)
339
+
340
+ # encode prosody
341
+ en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
342
+ if model_params.decoder.type == "hifigan":
343
+ asr_new = torch.zeros_like(en)
344
+ asr_new[:, :, 0] = en[:, :, 0]
345
+ asr_new[:, :, 1:] = en[:, :, 0:-1]
346
+ en = asr_new
347
+
348
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
349
+
350
+ asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
351
+ if model_params.decoder.type == "hifigan":
352
+ asr_new = torch.zeros_like(asr)
353
+ asr_new[:, :, 0] = asr[:, :, 0]
354
+ asr_new[:, :, 1:] = asr[:, :, 0:-1]
355
+ asr = asr_new
356
+
357
+ out = model.decoder(asr,
358
+ F0_pred, N_pred, ref.squeeze().unsqueeze(0))
359
+
360
+
361
+ return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model output, needs to be fixed later
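Putting this module together: importing it downloads the LibriTTS config and checkpoint and builds the model at import time, after which synthesis is just a reference style plus one of the inference functions. A hedged usage sketch; the file names are placeholders:

import numpy as np
from scipy.io.wavfile import write
import styletts2importable as tts   # slow on first import: weights are downloaded and loaded here

ref_s = tts.compute_style("reference.wav")   # any speaker clip; resampled to 24 kHz internally
wav = tts.inference("A short sentence to synthesize.", ref_s,
                    alpha=0.3, beta=0.7, diffusion_steps=5, embedding_scale=1)
write("output.wav", 24000, (wav * 32767).astype(np.int16))  # model output is 24 kHz float audio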
text_utils.py ADDED
@@ -0,0 +1,28 @@
1
+ # IPA Phonemizer: https://github.com/bootphon/phonemizer
2
+
3
+ _pad = "$"
4
+ _punctuation = ';:,.!?¡¿—…"«»“” '
5
+ _letters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
6
+ _letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
7
+
8
+ # Export all symbols:
9
+ symbols = [_pad] + list(_punctuation) + list(_letters) + list(_letters_ipa)
10
+
11
+ dicts = {}
12
+ for i in range(len(symbols)):
13
+ dicts[symbols[i]] = i
14
+
15
+
16
+ class TextCleaner:
17
+ def __init__(self, dummy=None):
18
+ self.word_index_dictionary = dicts
19
+ print(len(dicts))
20
+
21
+ def __call__(self, text):
22
+ indexes = []
23
+ for char in text:
24
+ try:
25
+ indexes.append(self.word_index_dictionary[char])
26
+ except KeyError:
27
+ print(text)
28
+ return indexes
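TextCleaner maps each character of a phonemized string to its index in the symbol table above ($ at index 0 doubles as the pad token, which is why the inference code prepends a 0); characters outside the table are skipped after the offending text is printed. A small illustrative example:

from text_utils import TextCleaner

cleaner = TextCleaner()        # prints the symbol-table size on construction
print(cleaner("Hi there?"))    # list of symbol indices, e.g. starting [24, 51, 16, ...]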
train_finetune.py ADDED
@@ -0,0 +1,839 @@
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import warnings
15
+
16
+ warnings.simplefilter("ignore")
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ from meldataset import build_dataloader
20
+
21
+ from Utils.ASR.models import ASRCNN
22
+ from Utils.JDC.model import JDCNet
23
+ from Utils.PLBERT.util import load_plbert
24
+
25
+ from models import *
26
+ from losses import *
27
+ from utils import *
28
+
29
+ from Modules.slmadv import SLMAdversarialLoss
30
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
31
+
32
+ from optimizers import build_optimizer
33
+
34
+
35
+ # simple fix for dataparallel that allows access to class attributes
36
+ class MyDataParallel(torch.nn.DataParallel):
37
+ def __getattr__(self, name):
38
+ try:
39
+ return super().__getattr__(name)
40
+ except AttributeError:
41
+ return getattr(self.module, name)
42
+
43
+
44
+ import logging
45
+ from logging import StreamHandler
46
+
47
+ logger = logging.getLogger(__name__)
48
+ logger.setLevel(logging.DEBUG)
49
+ handler = StreamHandler()
50
+ handler.setLevel(logging.DEBUG)
51
+ logger.addHandler(handler)
52
+
53
+
54
+ @click.command()
55
+ @click.option("-p", "--config_path", default="Configs/config_ft.yml", type=str)
56
+ def main(config_path):
57
+ config = yaml.safe_load(open(config_path))
58
+
59
+ log_dir = config["log_dir"]
60
+ if not osp.exists(log_dir):
61
+ os.makedirs(log_dir, exist_ok=True)
62
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
63
+ writer = SummaryWriter(log_dir + "/tensorboard")
64
+
65
+ # write logs
66
+ file_handler = logging.FileHandler(osp.join(log_dir, "train.log"))
67
+ file_handler.setLevel(logging.DEBUG)
68
+ file_handler.setFormatter(
69
+ logging.Formatter("%(levelname)s:%(asctime)s: %(message)s")
70
+ )
71
+ logger.addHandler(file_handler)
72
+
73
+ batch_size = config.get("batch_size", 10)
74
+
75
+ epochs = config.get("epochs", 200)
76
+ save_freq = config.get("save_freq", 2)
77
+ log_interval = config.get("log_interval", 10)
78
+ saving_epoch = config.get("save_freq", 2)
79
+
80
+ data_params = config.get("data_params", None)
81
+ sr = config["preprocess_params"].get("sr", 24000)
82
+ train_path = data_params["train_data"]
83
+ val_path = data_params["val_data"]
84
+ root_path = data_params["root_path"]
85
+ min_length = data_params["min_length"]
86
+ OOD_data = data_params["OOD_data"]
87
+
88
+ max_len = config.get("max_len", 200)
89
+
90
+ loss_params = Munch(config["loss_params"])
91
+ diff_epoch = loss_params.diff_epoch
92
+ joint_epoch = loss_params.joint_epoch
93
+
94
+ optimizer_params = Munch(config["optimizer_params"])
95
+
96
+ train_list, val_list = get_data_path_list(train_path, val_path)
97
+ device = "cuda"
98
+
99
+ train_dataloader = build_dataloader(
100
+ train_list,
101
+ root_path,
102
+ OOD_data=OOD_data,
103
+ min_length=min_length,
104
+ batch_size=batch_size,
105
+ num_workers=2,
106
+ dataset_config={},
107
+ device=device,
108
+ )
109
+
110
+ val_dataloader = build_dataloader(
111
+ val_list,
112
+ root_path,
113
+ OOD_data=OOD_data,
114
+ min_length=min_length,
115
+ batch_size=batch_size,
116
+ validation=True,
117
+ num_workers=0,
118
+ device=device,
119
+ dataset_config={},
120
+ )
121
+
122
+ # load pretrained ASR model
123
+ ASR_config = config.get("ASR_config", False)
124
+ ASR_path = config.get("ASR_path", False)
125
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
126
+
127
+ # load pretrained F0 model
128
+ F0_path = config.get("F0_path", False)
129
+ pitch_extractor = load_F0_models(F0_path)
130
+
131
+ # load PL-BERT model
132
+ BERT_path = config.get("PLBERT_dir", False)
133
+ plbert = load_plbert(BERT_path)
134
+
135
+ # build model
136
+ model_params = recursive_munch(config["model_params"])
137
+ multispeaker = model_params.multispeaker
138
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
139
+ _ = [model[key].to(device) for key in model]
140
+
141
+ # DP
142
+ for key in model:
143
+ if key != "mpd" and key != "msd" and key != "wd":
144
+ model[key] = MyDataParallel(model[key])
145
+
146
+ start_epoch = 0
147
+ iters = 0
148
+
149
+ load_pretrained = config.get("pretrained_model", "") != "" and config.get(
150
+ "second_stage_load_pretrained", False
151
+ )
152
+
153
+ if not load_pretrained:
154
+ if config.get("first_stage_path", "") != "":
155
+ first_stage_path = osp.join(
156
+ log_dir, config.get("first_stage_path", "first_stage.pth")
157
+ )
158
+ print("Loading the first stage model at %s ..." % first_stage_path)
159
+ model, _, start_epoch, iters = load_checkpoint(
160
+ model,
161
+ None,
162
+ first_stage_path,
163
+ load_only_params=True,
164
+ ignore_modules=[
165
+ "bert",
166
+ "bert_encoder",
167
+ "predictor",
168
+ "predictor_encoder",
169
+ "msd",
170
+ "mpd",
171
+ "wd",
172
+ "diffusion",
173
+ ],
174
+ ) # keep starting epoch for tensorboard log
175
+
176
+ # these epochs should be counted from the start epoch
177
+ diff_epoch += start_epoch
178
+ joint_epoch += start_epoch
179
+ epochs += start_epoch
180
+
181
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
182
+ else:
183
+ raise ValueError("You need to specify the path to the first stage model.")
184
+
185
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
186
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
187
+ wl = WavLMLoss(model_params.slm.model, model.wd, sr, model_params.slm.sr).to(device)
188
+
189
+ gl = MyDataParallel(gl)
190
+ dl = MyDataParallel(dl)
191
+ wl = MyDataParallel(wl)
192
+
193
+ sampler = DiffusionSampler(
194
+ model.diffusion.diffusion,
195
+ sampler=ADPM2Sampler(),
196
+ sigma_schedule=KarrasSchedule(
197
+ sigma_min=0.0001, sigma_max=3.0, rho=9.0
198
+ ), # empirical parameters
199
+ clamp=False,
200
+ )
201
+
202
+ scheduler_params = {
203
+ "max_lr": optimizer_params.lr,
204
+ "pct_start": float(0),
205
+ "epochs": epochs,
206
+ "steps_per_epoch": len(train_dataloader),
207
+ }
208
+ scheduler_params_dict = {key: scheduler_params.copy() for key in model}
209
+ scheduler_params_dict["bert"]["max_lr"] = optimizer_params.bert_lr * 2
210
+ scheduler_params_dict["decoder"]["max_lr"] = optimizer_params.ft_lr * 2
211
+ scheduler_params_dict["style_encoder"]["max_lr"] = optimizer_params.ft_lr * 2
212
+
213
+ optimizer = build_optimizer(
214
+ {key: model[key].parameters() for key in model},
215
+ scheduler_params_dict=scheduler_params_dict,
216
+ lr=optimizer_params.lr,
217
+ )
218
+
219
+ # adjust BERT learning rate
220
+ for g in optimizer.optimizers["bert"].param_groups:
221
+ g["betas"] = (0.9, 0.99)
222
+ g["lr"] = optimizer_params.bert_lr
223
+ g["initial_lr"] = optimizer_params.bert_lr
224
+ g["min_lr"] = 0
225
+ g["weight_decay"] = 0.01
226
+
227
+ # adjust acoustic module learning rate
228
+ for module in ["decoder", "style_encoder"]:
229
+ for g in optimizer.optimizers[module].param_groups:
230
+ g["betas"] = (0.0, 0.99)
231
+ g["lr"] = optimizer_params.ft_lr
232
+ g["initial_lr"] = optimizer_params.ft_lr
233
+ g["min_lr"] = 0
234
+ g["weight_decay"] = 1e-4
235
+
236
+ # load models if there is a model
237
+ if load_pretrained:
238
+ model, optimizer, start_epoch, iters = load_checkpoint(
239
+ model,
240
+ optimizer,
241
+ config["pretrained_model"],
242
+ load_only_params=config.get("load_only_params", True),
243
+ )
244
+
245
+ n_down = model.text_aligner.n_down
246
+
247
+ best_loss = float("inf") # best test loss
248
+ loss_train_record = list([])
249
+ loss_test_record = list([])
250
+ iters = 0
251
+
252
+ criterion = nn.L1Loss() # F0 loss (regression)
253
+ torch.cuda.empty_cache()
254
+
255
+ stft_loss = MultiResolutionSTFTLoss().to(device)
256
+
257
+ print("BERT", optimizer.optimizers["bert"])
258
+ print("decoder", optimizer.optimizers["decoder"])
259
+
260
+ start_ds = False
261
+
262
+ running_std = []
263
+
264
+ slmadv_params = Munch(config["slmadv_params"])
265
+ slmadv = SLMAdversarialLoss(
266
+ model,
267
+ wl,
268
+ sampler,
269
+ slmadv_params.min_len,
270
+ slmadv_params.max_len,
271
+ batch_percentage=slmadv_params.batch_percentage,
272
+ skip_update=slmadv_params.iter,
273
+ sig=slmadv_params.sig,
274
+ )
275
+
276
+ for epoch in range(start_epoch, epochs):
277
+ running_loss = 0
278
+ start_time = time.time()
279
+
280
+ _ = [model[key].eval() for key in model]
281
+
282
+ model.text_aligner.train()
283
+ model.text_encoder.train()
284
+
285
+ model.predictor.train()
286
+ model.bert_encoder.train()
287
+ model.bert.train()
288
+ model.msd.train()
289
+ model.mpd.train()
290
+
291
+ for i, batch in enumerate(train_dataloader):
292
+ waves = batch[0]
293
+ batch = [b.to(device) for b in batch[1:]]
294
+ (
295
+ texts,
296
+ input_lengths,
297
+ ref_texts,
298
+ ref_lengths,
299
+ mels,
300
+ mel_input_length,
301
+ ref_mels,
302
+ ) = batch
303
+ with torch.no_grad():
304
+ mask = length_to_mask(mel_input_length // (2**n_down)).to(device)
305
+ mel_mask = length_to_mask(mel_input_length).to(device)
306
+ text_mask = length_to_mask(input_lengths).to(texts.device)
307
+
308
+ # compute reference styles
309
+ if multispeaker and epoch >= diff_epoch:
310
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
311
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
312
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
313
+
314
+ try:
315
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
316
+ s2s_attn = s2s_attn.transpose(-1, -2)
317
+ s2s_attn = s2s_attn[..., 1:]
318
+ s2s_attn = s2s_attn.transpose(-1, -2)
319
+ except:
320
+ continue
321
+
322
+ mask_ST = mask_from_lens(
323
+ s2s_attn, input_lengths, mel_input_length // (2**n_down)
324
+ )
325
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
326
+
327
+ # encode
328
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
329
+
330
+ # 50% of chance of using monotonic version
331
+ if bool(random.getrandbits(1)):
332
+ asr = t_en @ s2s_attn
333
+ else:
334
+ asr = t_en @ s2s_attn_mono
335
+
336
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
337
+
338
+ # compute the style of the entire utterance
339
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
340
+ ss = []
341
+ gs = []
342
+ for bib in range(len(mel_input_length)):
343
+ mel_length = int(mel_input_length[bib].item())
344
+ mel = mels[bib, :, : mel_input_length[bib]]
345
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
346
+ ss.append(s)
347
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
348
+ gs.append(s)
349
+
350
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
351
+ gs = torch.stack(gs).squeeze() # global acoustic styles
352
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
353
+
354
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
355
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
356
+
357
+ # denoiser training
358
+ if epoch >= diff_epoch:
359
+ num_steps = np.random.randint(3, 5)
360
+
361
+ if model_params.diffusion.dist.estimate_sigma_data:
362
+ model.diffusion.module.diffusion.sigma_data = (
363
+ s_trg.std(axis=-1).mean().item()
364
+ ) # batch-wise std estimation
365
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
366
+
367
+ if multispeaker:
368
+ s_preds = sampler(
369
+ noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
370
+ embedding=bert_dur,
371
+ embedding_scale=1,
372
+ features=ref, # reference from the same speaker as the embedding
373
+ embedding_mask_proba=0.1,
374
+ num_steps=num_steps,
375
+ ).squeeze(1)
376
+ loss_diff = model.diffusion(
377
+ s_trg.unsqueeze(1), embedding=bert_dur, features=ref
378
+ ).mean() # EDM loss
379
+ loss_sty = F.l1_loss(
380
+ s_preds, s_trg.detach()
381
+ ) # style reconstruction loss
382
+ else:
383
+ s_preds = sampler(
384
+ noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
385
+ embedding=bert_dur,
386
+ embedding_scale=1,
387
+ embedding_mask_proba=0.1,
388
+ num_steps=num_steps,
389
+ ).squeeze(1)
390
+ loss_diff = model.diffusion.module.diffusion(
391
+ s_trg.unsqueeze(1), embedding=bert_dur
392
+ ).mean() # EDM loss
393
+ loss_sty = F.l1_loss(
394
+ s_preds, s_trg.detach()
395
+ ) # style reconstruction loss
396
+ else:
397
+ loss_sty = 0
398
+ loss_diff = 0
399
+
400
+ s_loss = 0
401
+
402
+ d, p = model.predictor(d_en, s_dur, input_lengths, s2s_attn_mono, text_mask)
403
+
404
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
405
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
406
+ en = []
407
+ gt = []
408
+ p_en = []
409
+ wav = []
410
+ st = []
411
+
412
+ for bib in range(len(mel_input_length)):
413
+ mel_length = int(mel_input_length[bib].item() / 2)
414
+
415
+ random_start = np.random.randint(0, mel_length - mel_len)
416
+ en.append(asr[bib, :, random_start : random_start + mel_len])
417
+ p_en.append(p[bib, :, random_start : random_start + mel_len])
418
+ gt.append(
419
+ mels[bib, :, (random_start * 2) : ((random_start + mel_len) * 2)]
420
+ )
421
+
422
+ y = waves[bib][
423
+ (random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300
424
+ ]
425
+ wav.append(torch.from_numpy(y).to(device))
426
+
427
+ # style reference (better to be different from the GT)
428
+ random_start = np.random.randint(0, mel_length - mel_len_st)
429
+ st.append(
430
+ mels[bib, :, (random_start * 2) : ((random_start + mel_len_st) * 2)]
431
+ )
432
+
433
+ wav = torch.stack(wav).float().detach()
434
+
435
+ en = torch.stack(en)
436
+ p_en = torch.stack(p_en)
437
+ gt = torch.stack(gt).detach()
438
+ st = torch.stack(st).detach()
439
+
440
+ if gt.size(-1) < 80:
441
+ continue
442
+
443
+ s = model.style_encoder(gt.unsqueeze(1))
444
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
445
+
446
+ with torch.no_grad():
447
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
448
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
449
+
450
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
451
+
452
+ y_rec_gt = wav.unsqueeze(1)
453
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
454
+
455
+ wav = y_rec_gt
456
+
457
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
458
+
459
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
460
+
461
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
462
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
463
+
464
+ optimizer.zero_grad()
465
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
466
+ d_loss.backward()
467
+ optimizer.step("msd")
468
+ optimizer.step("mpd")
469
+
470
+ # generator loss
471
+ optimizer.zero_grad()
472
+
473
+ loss_mel = stft_loss(y_rec, wav)
474
+ loss_gen_all = gl(wav, y_rec).mean()
475
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
476
+
477
+ loss_ce = 0
478
+ loss_dur = 0
479
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
480
+ _s2s_pred = _s2s_pred[:_text_length, :]
481
+ _text_input = _text_input[:_text_length].long()
482
+ _s2s_trg = torch.zeros_like(_s2s_pred)
483
+ for p in range(_s2s_trg.shape[0]):
484
+ _s2s_trg[p, : _text_input[p]] = 1
485
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
486
+
487
+ loss_dur += F.l1_loss(
488
+ _dur_pred[1 : _text_length - 1], _text_input[1 : _text_length - 1]
489
+ )
490
+ loss_ce += F.binary_cross_entropy_with_logits(
491
+ _s2s_pred.flatten(), _s2s_trg.flatten()
492
+ )
493
+
494
+ loss_ce /= texts.size(0)
495
+ loss_dur /= texts.size(0)
496
+
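The duration and CE losses above treat each of the max_dur logits per phoneme as an independent "this frame still belongs to the phoneme" decision: the target row is d_gt ones followed by zeros, the predicted duration is the sum of sigmoids over the row, and BCE supervises every position. A small worked example with illustrative numbers:

import torch
import torch.nn.functional as F

max_dur = 8                                    # stands in for the model's max_dur (50 by default)
logits = torch.tensor([4.0, 3.0, 2.0, -3.0, -4.0, -4.0, -5.0, -5.0])  # one _s2s_pred row
target = torch.zeros(max_dur)
target[:3] = 1                                 # ground-truth duration of 3 frames, as in _s2s_trg above
dur_pred = torch.sigmoid(logits).sum()         # ~2.9, compared against 3 by the L1 duration loss
ce = F.binary_cross_entropy_with_logits(logits, target)
print(float(dur_pred), float(ce))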
497
+ loss_s2s = 0
498
+ for _s2s_pred, _text_input, _text_length in zip(
499
+ s2s_pred, texts, input_lengths
500
+ ):
501
+ loss_s2s += F.cross_entropy(
502
+ _s2s_pred[:_text_length], _text_input[:_text_length]
503
+ )
504
+ loss_s2s /= texts.size(0)
505
+
506
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
507
+
508
+ g_loss = (
509
+ loss_params.lambda_mel * loss_mel
510
+ + loss_params.lambda_F0 * loss_F0_rec
511
+ + loss_params.lambda_ce * loss_ce
512
+ + loss_params.lambda_norm * loss_norm_rec
513
+ + loss_params.lambda_dur * loss_dur
514
+ + loss_params.lambda_gen * loss_gen_all
515
+ + loss_params.lambda_slm * loss_lm
516
+ + loss_params.lambda_sty * loss_sty
517
+ + loss_params.lambda_diff * loss_diff
518
+ + loss_params.lambda_mono * loss_mono
519
+ + loss_params.lambda_s2s * loss_s2s
520
+ )
521
+
522
+ running_loss += loss_mel.item()
523
+ g_loss.backward()
524
+ if torch.isnan(g_loss):
525
+ from IPython.core.debugger import set_trace
526
+
527
+ set_trace()
528
+
529
+ optimizer.step("bert_encoder")
530
+ optimizer.step("bert")
531
+ optimizer.step("predictor")
532
+ optimizer.step("predictor_encoder")
533
+ optimizer.step("style_encoder")
534
+ optimizer.step("decoder")
535
+
536
+ optimizer.step("text_encoder")
537
+ optimizer.step("text_aligner")
538
+
539
+ if epoch >= diff_epoch:
540
+ optimizer.step("diffusion")
541
+
542
+ if epoch >= joint_epoch:
543
+ # randomly pick whether to use in-distribution text
544
+ if np.random.rand() < 0.5:
545
+ use_ind = True
546
+ else:
547
+ use_ind = False
548
+
549
+ if use_ind:
550
+ ref_lengths = input_lengths
551
+ ref_texts = texts
552
+
553
+ slm_out = slmadv(
554
+ i,
555
+ y_rec_gt,
556
+ y_rec_gt_pred,
557
+ waves,
558
+ mel_input_length,
559
+ ref_texts,
560
+ ref_lengths,
561
+ use_ind,
562
+ s_trg.detach(),
563
+ ref if multispeaker else None,
564
+ )
565
+
566
+ if slm_out is None:
567
+ continue
568
+
569
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
570
+
571
+ # SLM discriminator loss
572
+ if d_loss_slm != 0:
573
+ optimizer.zero_grad()
574
+ d_loss_slm.backward()
575
+ optimizer.step("wd")
576
+
577
+ # SLM generator loss
578
+ optimizer.zero_grad()
579
+ loss_gen_lm.backward()
580
+
581
+ # compute the gradient norm
582
+ total_norm = {}
583
+ for key in model.keys():
584
+ total_norm[key] = 0
585
+ parameters = [
586
+ p
587
+ for p in model[key].parameters()
588
+ if p.grad is not None and p.requires_grad
589
+ ]
590
+ for p in parameters:
591
+ param_norm = p.grad.detach().data.norm(2)
592
+ total_norm[key] += param_norm.item() ** 2
593
+ total_norm[key] = total_norm[key] ** 0.5
594
+
595
+ # gradient scaling
596
+ if total_norm["predictor"] > slmadv_params.thresh:
597
+ for key in model.keys():
598
+ for p in model[key].parameters():
599
+ if p.grad is not None:
600
+ p.grad *= 1 / total_norm["predictor"]
601
+
602
+ for p in model.predictor.duration_proj.parameters():
603
+ if p.grad is not None:
604
+ p.grad *= slmadv_params.scale
605
+
606
+ for p in model.predictor.lstm.parameters():
607
+ if p.grad is not None:
608
+ p.grad *= slmadv_params.scale
609
+
610
+ for p in model.diffusion.parameters():
611
+ if p.grad is not None:
612
+ p.grad *= slmadv_params.scale
613
+
614
+ optimizer.step("bert_encoder")
615
+ optimizer.step("bert")
616
+ optimizer.step("predictor")
617
+ optimizer.step("diffusion")
618
+
619
+ else:
620
+ d_loss_slm, loss_gen_lm = 0, 0
621
+
622
+ iters = iters + 1
623
+
624
+ if (i + 1) % log_interval == 0:
625
+ logger.info(
626
+ "Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f, SLoss: %.5f, S2S Loss: %.5f, Mono Loss: %.5f"
627
+ % (
628
+ epoch + 1,
629
+ epochs,
630
+ i + 1,
631
+ len(train_list) // batch_size,
632
+ running_loss / log_interval,
633
+ d_loss,
634
+ loss_dur,
635
+ loss_ce,
636
+ loss_norm_rec,
637
+ loss_F0_rec,
638
+ loss_lm,
639
+ loss_gen_all,
640
+ loss_sty,
641
+ loss_diff,
642
+ d_loss_slm,
643
+ loss_gen_lm,
644
+ s_loss,
645
+ loss_s2s,
646
+ loss_mono,
647
+ )
648
+ )
649
+
650
+ writer.add_scalar("train/mel_loss", running_loss / log_interval, iters)
651
+ writer.add_scalar("train/gen_loss", loss_gen_all, iters)
652
+ writer.add_scalar("train/d_loss", d_loss, iters)
653
+ writer.add_scalar("train/ce_loss", loss_ce, iters)
654
+ writer.add_scalar("train/dur_loss", loss_dur, iters)
655
+ writer.add_scalar("train/slm_loss", loss_lm, iters)
656
+ writer.add_scalar("train/norm_loss", loss_norm_rec, iters)
657
+ writer.add_scalar("train/F0_loss", loss_F0_rec, iters)
658
+ writer.add_scalar("train/sty_loss", loss_sty, iters)
659
+ writer.add_scalar("train/diff_loss", loss_diff, iters)
660
+ writer.add_scalar("train/d_loss_slm", d_loss_slm, iters)
661
+ writer.add_scalar("train/gen_loss_slm", loss_gen_lm, iters)
662
+
663
+ running_loss = 0
664
+
665
+ print("Time elapsed:", time.time() - start_time)
666
+
667
+ loss_test = 0
668
+ loss_align = 0
669
+ loss_f = 0
670
+ _ = [model[key].eval() for key in model]
671
+
672
+ with torch.no_grad():
673
+ iters_test = 0
674
+ for batch_idx, batch in enumerate(val_dataloader):
675
+ optimizer.zero_grad()
676
+
677
+ try:
678
+ waves = batch[0]
679
+ batch = [b.to(device) for b in batch[1:]]
680
+ (
681
+ texts,
682
+ input_lengths,
683
+ ref_texts,
684
+ ref_lengths,
685
+ mels,
686
+ mel_input_length,
687
+ ref_mels,
688
+ ) = batch
689
+ with torch.no_grad():
690
+ mask = length_to_mask(mel_input_length // (2**n_down)).to(
691
+ "cuda"
692
+ )
693
+ text_mask = length_to_mask(input_lengths).to(texts.device)
694
+
695
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
696
+ s2s_attn = s2s_attn.transpose(-1, -2)
697
+ s2s_attn = s2s_attn[..., 1:]
698
+ s2s_attn = s2s_attn.transpose(-1, -2)
699
+
700
+ mask_ST = mask_from_lens(
701
+ s2s_attn, input_lengths, mel_input_length // (2**n_down)
702
+ )
703
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
704
+
705
+ # encode
706
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
707
+ asr = t_en @ s2s_attn_mono
708
+
709
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
710
+
711
+ ss = []
712
+ gs = []
713
+
714
+ for bib in range(len(mel_input_length)):
715
+ mel_length = int(mel_input_length[bib].item())
716
+ mel = mels[bib, :, : mel_input_length[bib]]
717
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
718
+ ss.append(s)
719
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
720
+ gs.append(s)
721
+
722
+ s = torch.stack(ss).squeeze()
723
+ gs = torch.stack(gs).squeeze()
724
+ s_trg = torch.cat([s, gs], dim=-1).detach()
725
+
726
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
727
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
728
+ d, p = model.predictor(
729
+ d_en, s, input_lengths, s2s_attn_mono, text_mask
730
+ )
731
+ # get clips
732
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
733
+ en = []
734
+ gt = []
735
+
736
+ p_en = []
737
+ wav = []
738
+
739
+ for bib in range(len(mel_input_length)):
740
+ mel_length = int(mel_input_length[bib].item() / 2)
741
+
742
+ random_start = np.random.randint(0, mel_length - mel_len)
743
+ en.append(asr[bib, :, random_start : random_start + mel_len])
744
+ p_en.append(p[bib, :, random_start : random_start + mel_len])
745
+
746
+ gt.append(
747
+ mels[
748
+ bib,
749
+ :,
750
+ (random_start * 2) : ((random_start + mel_len) * 2),
751
+ ]
752
+ )
753
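+ # cut the matching waveform segment; the factor of 300 is assumed to be the mel hop size in samples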
+ y = waves[bib][
754
+ (random_start * 2)
755
+ * 300 : ((random_start + mel_len) * 2)
756
+ * 300
757
+ ]
758
+ wav.append(torch.from_numpy(y).to(device))
759
+
760
+ wav = torch.stack(wav).float().detach()
761
+
762
+ en = torch.stack(en)
763
+ p_en = torch.stack(p_en)
764
+ gt = torch.stack(gt).detach()
765
+ s = model.predictor_encoder(gt.unsqueeze(1))
766
+
767
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
768
+
769
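+ # duration loss: L1 between the sigmoid-summed frame predictions and the aligner's ground-truth token durations (d_gt), skipping the first and last tokens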
+ loss_dur = 0
770
+ for _s2s_pred, _text_input, _text_length in zip(
771
+ d, (d_gt), input_lengths
772
+ ):
773
+ _s2s_pred = _s2s_pred[:_text_length, :]
774
+ _text_input = _text_input[:_text_length].long()
775
+ _s2s_trg = torch.zeros_like(_s2s_pred)
776
+ for bib in range(_s2s_trg.shape[0]):
777
+ _s2s_trg[bib, : _text_input[bib]] = 1
778
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
779
+ loss_dur += F.l1_loss(
780
+ _dur_pred[1 : _text_length - 1],
781
+ _text_input[1 : _text_length - 1],
782
+ )
783
+
784
+ loss_dur /= texts.size(0)
785
+
786
+ s = model.style_encoder(gt.unsqueeze(1))
787
+
788
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
789
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
790
+
791
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
792
+
793
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
794
+
795
+ loss_test += (loss_mel).mean()
796
+ loss_align += (loss_dur).mean()
797
+ loss_f += (loss_F0).mean()
798
+
799
+ iters_test += 1
800
+ except Exception:
801
+ continue
802
+
803
+ print("Epochs:", epoch + 1)
804
+ logger.info(
805
+ "Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f"
806
+ % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test)
807
+ + "\n\n\n"
808
+ )
809
+ print("\n\n\n")
810
+ writer.add_scalar("eval/mel_loss", loss_test / iters_test, epoch + 1)
811
+ writer.add_scalar("eval/dur_loss", loss_align / iters_test, epoch + 1)
812
+ writer.add_scalar("eval/F0_loss", loss_f / iters_test, epoch + 1)
813
+
814
+ if (epoch + 1) % save_freq == 0:
815
+ if (loss_test / iters_test) < best_loss:
816
+ best_loss = loss_test / iters_test
817
+ print("Saving..")
818
+ state = {
819
+ "net": {key: model[key].state_dict() for key in model},
820
+ "optimizer": optimizer.state_dict(),
821
+ "iters": iters,
822
+ "val_loss": loss_test / iters_test,
823
+ "epoch": epoch,
824
+ }
825
+ save_path = osp.join(log_dir, "epoch_2nd_%05d.pth" % epoch)
826
+ torch.save(state, save_path)
827
+
828
+ # if sigma is estimated from data, save the estimated sigma back into the config
829
+ if model_params.diffusion.dist.estimate_sigma_data:
830
+ config["model_params"]["diffusion"]["dist"]["sigma_data"] = float(
831
+ np.mean(running_std)
832
+ )
833
+
834
+ with open(osp.join(log_dir, osp.basename(config_path)), "w") as outfile:
835
+ yaml.dump(config, outfile, default_flow_style=True)
836
+
837
+
838
+ if __name__ == "__main__":
839
+ main()
train_first.py ADDED
@@ -0,0 +1,540 @@
1
+ import os
2
+ import os.path as osp
3
+ import re
4
+ import sys
5
+ import yaml
6
+ import shutil
7
+ import numpy as np
8
+ import torch
9
+ import click
10
+ import warnings
11
+
12
+ warnings.simplefilter("ignore")
13
+
14
+ # load packages
15
+ import random
16
+ import yaml
17
+ from munch import Munch
18
+ import numpy as np
19
+ import torch
20
+ from torch import nn
21
+ import torch.nn.functional as F
22
+ import torchaudio
23
+ import librosa
24
+
25
+ from models import *
26
+ from meldataset import build_dataloader
27
+ from utils import *
28
+ from losses import *
29
+ from optimizers import build_optimizer
30
+ import time
31
+
32
+ from accelerate import Accelerator
33
+ from accelerate.utils import LoggerType
34
+ from accelerate import DistributedDataParallelKwargs
35
+
36
+ from torch.utils.tensorboard import SummaryWriter
37
+
38
+ import logging
39
+ from accelerate.logging import get_logger
40
+
41
+ logger = get_logger(__name__, log_level="DEBUG")
42
+
43
+
44
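+ # entry point; run e.g. as `python train_first.py -p Configs/config.yml` (or via `accelerate launch` for multi-GPU)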
+ @click.command()
45
+ @click.option("-p", "--config_path", default="Configs/config.yml", type=str)
46
+ def main(config_path):
47
+ config = yaml.safe_load(open(config_path))
48
+
49
+ log_dir = config["log_dir"]
50
+ if not osp.exists(log_dir):
51
+ os.makedirs(log_dir, exist_ok=True)
52
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
53
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
54
+ accelerator = Accelerator(
55
+ project_dir=log_dir, split_batches=True, kwargs_handlers=[ddp_kwargs]
56
+ )
57
+ if accelerator.is_main_process:
58
+ writer = SummaryWriter(log_dir + "/tensorboard")
59
+
60
+ # write logs
61
+ file_handler = logging.FileHandler(osp.join(log_dir, "train.log"))
62
+ file_handler.setLevel(logging.DEBUG)
63
+ file_handler.setFormatter(
64
+ logging.Formatter("%(levelname)s:%(asctime)s: %(message)s")
65
+ )
66
+ logger.logger.addHandler(file_handler)
67
+
68
+ batch_size = config.get("batch_size", 10)
69
+ device = accelerator.device
70
+
71
+ epochs = config.get("epochs_1st", 200)
72
+ save_freq = config.get("save_freq", 2)
73
+ log_interval = config.get("log_interval", 10)
74
+ saving_epoch = config.get("save_freq", 2)
75
+
76
+ data_params = config.get("data_params", None)
77
+ sr = config["preprocess_params"].get("sr", 24000)
78
+ train_path = data_params["train_data"]
79
+ val_path = data_params["val_data"]
80
+ root_path = data_params["root_path"]
81
+ min_length = data_params["min_length"]
82
+ OOD_data = data_params["OOD_data"]
83
+
84
+ max_len = config.get("max_len", 200)
85
+
86
+ # load data
87
+ train_list, val_list = get_data_path_list(train_path, val_path)
88
+
89
+ train_dataloader = build_dataloader(
90
+ train_list,
91
+ root_path,
92
+ OOD_data=OOD_data,
93
+ min_length=min_length,
94
+ batch_size=batch_size,
95
+ num_workers=2,
96
+ dataset_config={},
97
+ device=device,
98
+ )
99
+
100
+ val_dataloader = build_dataloader(
101
+ val_list,
102
+ root_path,
103
+ OOD_data=OOD_data,
104
+ min_length=min_length,
105
+ batch_size=batch_size,
106
+ validation=True,
107
+ num_workers=0,
108
+ device=device,
109
+ dataset_config={},
110
+ )
111
+
112
+ with accelerator.main_process_first():
113
+ # load pretrained ASR model
114
+ ASR_config = config.get("ASR_config", False)
115
+ ASR_path = config.get("ASR_path", False)
116
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
117
+
118
+ # load pretrained F0 model
119
+ F0_path = config.get("F0_path", False)
120
+ pitch_extractor = load_F0_models(F0_path)
121
+
122
+ # load BERT model
123
+ from Utils.PLBERT.util import load_plbert
124
+
125
+ BERT_path = config.get("PLBERT_dir", False)
126
+ plbert = load_plbert(BERT_path)
127
+
128
+ scheduler_params = {
129
+ "max_lr": float(config["optimizer_params"].get("lr", 1e-4)),
130
+ "pct_start": float(config["optimizer_params"].get("pct_start", 0.0)),
131
+ "epochs": epochs,
132
+ "steps_per_epoch": len(train_dataloader),
133
+ }
134
+
135
+ model_params = recursive_munch(config["model_params"])
136
+ multispeaker = model_params.multispeaker
137
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
138
+
139
+ best_loss = float("inf") # best test loss
140
+ loss_train_record = list([])
141
+ loss_test_record = list([])
142
+
143
+ loss_params = Munch(config["loss_params"])
144
+ TMA_epoch = loss_params.TMA_epoch
145
+
146
+ for k in model:
147
+ model[k] = accelerator.prepare(model[k])
148
+
149
+ train_dataloader, val_dataloader = accelerator.prepare(
150
+ train_dataloader, val_dataloader
151
+ )
152
+
153
+ _ = [model[key].to(device) for key in model]
154
+
155
+ # initialize optimizers after preparing models for compatibility with FSDP
156
+ optimizer = build_optimizer(
157
+ {key: model[key].parameters() for key in model},
158
+ scheduler_params_dict={key: scheduler_params.copy() for key in model},
159
+ lr=float(config["optimizer_params"].get("lr", 1e-4)),
160
+ )
161
+
162
+ for k, v in optimizer.optimizers.items():
163
+ optimizer.optimizers[k] = accelerator.prepare(optimizer.optimizers[k])
164
+ optimizer.schedulers[k] = accelerator.prepare(optimizer.schedulers[k])
165
+
166
+ with accelerator.main_process_first():
167
+ if config.get("pretrained_model", "") != "":
168
+ model, optimizer, start_epoch, iters = load_checkpoint(
169
+ model,
170
+ optimizer,
171
+ config["pretrained_model"],
172
+ load_only_params=config.get("load_only_params", True),
173
+ )
174
+ else:
175
+ start_epoch = 0
176
+ iters = 0
177
+
178
+ # in case not distributed
179
+ try:
180
+ n_down = model.text_aligner.module.n_down
181
+ except:
182
+ n_down = model.text_aligner.n_down
183
+
184
+ # wrapped losses for compatibility with mixed precision
185
+ stft_loss = MultiResolutionSTFTLoss().to(device)
186
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
187
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
188
+ wl = WavLMLoss(model_params.slm.model, model.wd, sr, model_params.slm.sr).to(device)
189
+
190
+ for epoch in range(start_epoch, epochs):
191
+ running_loss = 0
192
+ start_time = time.time()
193
+
194
+ _ = [model[key].train() for key in model]
195
+
196
+ for i, batch in enumerate(train_dataloader):
197
+ waves = batch[0]
198
+ batch = [b.to(device) for b in batch[1:]]
199
+ texts, input_lengths, _, _, mels, mel_input_length, _ = batch
200
+
201
+ with torch.no_grad():
202
+ mask = length_to_mask(mel_input_length // (2**n_down)).to("cuda")
203
+ text_mask = length_to_mask(input_lengths).to(texts.device)
204
+
205
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
206
+
207
+ s2s_attn = s2s_attn.transpose(-1, -2)
208
+ s2s_attn = s2s_attn[..., 1:]
209
+ s2s_attn = s2s_attn.transpose(-1, -2)
210
+
211
+ with torch.no_grad():
212
+ attn_mask = (
213
+ (~mask)
214
+ .unsqueeze(-1)
215
+ .expand(mask.shape[0], mask.shape[1], text_mask.shape[-1])
216
+ .float()
217
+ .transpose(-1, -2)
218
+ )
219
+ attn_mask = (
220
+ attn_mask.float()
221
+ * (~text_mask)
222
+ .unsqueeze(-1)
223
+ .expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1])
224
+ .float()
225
+ )
226
+ attn_mask = attn_mask < 1
227
+
228
+ s2s_attn.masked_fill_(attn_mask, 0.0)
229
+
230
+ with torch.no_grad():
231
+ mask_ST = mask_from_lens(
232
+ s2s_attn, input_lengths, mel_input_length // (2**n_down)
233
+ )
234
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
235
+
236
+ # encode
237
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
238
+
239
+ # 50% of chance of using monotonic version
240
+ if bool(random.getrandbits(1)):
241
+ asr = t_en @ s2s_attn
242
+ else:
243
+ asr = t_en @ s2s_attn_mono
244
+
245
+ # get clips
246
+ mel_input_length_all = accelerator.gather(
247
+ mel_input_length
248
+ ) # for balanced load
249
+ mel_len = min(
250
+ [int(mel_input_length_all.min().item() / 2 - 1), max_len // 2]
251
+ )
252
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
253
+
254
+ en = []
255
+ gt = []
256
+ wav = []
257
+ st = []
258
+
259
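+ # randomly crop matched windows of the aligned text features, mel spectrogram, and waveform for each item, plus a separate window used as the style reference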
+ for bib in range(len(mel_input_length)):
260
+ mel_length = int(mel_input_length[bib].item() / 2)
261
+
262
+ random_start = np.random.randint(0, mel_length - mel_len)
263
+ en.append(asr[bib, :, random_start : random_start + mel_len])
264
+ gt.append(
265
+ mels[bib, :, (random_start * 2) : ((random_start + mel_len) * 2)]
266
+ )
267
+
268
+ y = waves[bib][
269
+ (random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300
270
+ ]
271
+ wav.append(torch.from_numpy(y).to(device))
272
+
273
+ # style reference (better to be different from the GT)
274
+ random_start = np.random.randint(0, mel_length - mel_len_st)
275
+ st.append(
276
+ mels[bib, :, (random_start * 2) : ((random_start + mel_len_st) * 2)]
277
+ )
278
+
279
+ en = torch.stack(en)
280
+ gt = torch.stack(gt).detach()
281
+ st = torch.stack(st).detach()
282
+
283
+ wav = torch.stack(wav).float().detach()
284
+
285
+ # clip too short to be used by the style encoder
286
+ if gt.shape[-1] < 80:
287
+ continue
288
+
289
+ with torch.no_grad():
290
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1).detach()
291
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
292
+
293
+ s = model.style_encoder(
294
+ st.unsqueeze(1) if multispeaker else gt.unsqueeze(1)
295
+ )
296
+
297
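+ # reconstruct the waveform from the aligned text features, the ground-truth F0 and log-norm energy, and the style vector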
+ y_rec = model.decoder(en, F0_real, real_norm, s)
298
+
299
+ # discriminator loss
300
+
301
+ if epoch >= TMA_epoch:
302
+ optimizer.zero_grad()
303
+ d_loss = dl(wav.detach().unsqueeze(1).float(), y_rec.detach()).mean()
304
+ accelerator.backward(d_loss)
305
+ optimizer.step("msd")
306
+ optimizer.step("mpd")
307
+ else:
308
+ d_loss = 0
309
+
310
+ # generator loss
311
+ optimizer.zero_grad()
312
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
313
+
314
+ if epoch >= TMA_epoch: # start TMA training
315
+ loss_s2s = 0
316
+ for _s2s_pred, _text_input, _text_length in zip(
317
+ s2s_pred, texts, input_lengths
318
+ ):
319
+ loss_s2s += F.cross_entropy(
320
+ _s2s_pred[:_text_length], _text_input[:_text_length]
321
+ )
322
+ loss_s2s /= texts.size(0)
323
+
324
+ loss_mono = F.l1_loss(s2s_attn, s2s_attn_mono) * 10
325
+
326
+ loss_gen_all = gl(wav.detach().unsqueeze(1).float(), y_rec).mean()
327
+ loss_slm = wl(wav.detach(), y_rec).mean()
328
+
329
+ g_loss = (
330
+ loss_params.lambda_mel * loss_mel
331
+ + loss_params.lambda_mono * loss_mono
332
+ + loss_params.lambda_s2s * loss_s2s
333
+ + loss_params.lambda_gen * loss_gen_all
334
+ + loss_params.lambda_slm * loss_slm
335
+ )
336
+
337
+ else:
338
+ loss_s2s = 0
339
+ loss_mono = 0
340
+ loss_gen_all = 0
341
+ loss_slm = 0
342
+ g_loss = loss_mel
343
+
344
+ running_loss += accelerator.gather(loss_mel).mean().item()
345
+
346
+ accelerator.backward(g_loss)
347
+
348
+ optimizer.step("text_encoder")
349
+ optimizer.step("style_encoder")
350
+ optimizer.step("decoder")
351
+
352
+ if epoch >= TMA_epoch:
353
+ optimizer.step("text_aligner")
354
+ optimizer.step("pitch_extractor")
355
+
356
+ iters = iters + 1
357
+
358
+ if (i + 1) % log_interval == 0 and accelerator.is_main_process:
359
+ log_print(
360
+ "Epoch [%d/%d], Step [%d/%d], Mel Loss: %.5f, Gen Loss: %.5f, Disc Loss: %.5f, Mono Loss: %.5f, S2S Loss: %.5f, SLM Loss: %.5f"
361
+ % (
362
+ epoch + 1,
363
+ epochs,
364
+ i + 1,
365
+ len(train_list) // batch_size,
366
+ running_loss / log_interval,
367
+ loss_gen_all,
368
+ d_loss,
369
+ loss_mono,
370
+ loss_s2s,
371
+ loss_slm,
372
+ ),
373
+ logger,
374
+ )
375
+
376
+ writer.add_scalar("train/mel_loss", running_loss / log_interval, iters)
377
+ writer.add_scalar("train/gen_loss", loss_gen_all, iters)
378
+ writer.add_scalar("train/d_loss", d_loss, iters)
379
+ writer.add_scalar("train/mono_loss", loss_mono, iters)
380
+ writer.add_scalar("train/s2s_loss", loss_s2s, iters)
381
+ writer.add_scalar("train/slm_loss", loss_slm, iters)
382
+
383
+ running_loss = 0
384
+
385
+ print("Time elapsed:", time.time() - start_time)
386
+
387
+ loss_test = 0
388
+
389
+ _ = [model[key].eval() for key in model]
390
+
391
+ with torch.no_grad():
392
+ iters_test = 0
393
+ for batch_idx, batch in enumerate(val_dataloader):
394
+ optimizer.zero_grad()
395
+
396
+ waves = batch[0]
397
+ batch = [b.to(device) for b in batch[1:]]
398
+ texts, input_lengths, _, _, mels, mel_input_length, _ = batch
399
+
400
+ with torch.no_grad():
401
+ mask = length_to_mask(mel_input_length // (2**n_down)).to("cuda")
402
+ ppgs, s2s_pred, s2s_attn = model.text_aligner(mels, mask, texts)
403
+
404
+ s2s_attn = s2s_attn.transpose(-1, -2)
405
+ s2s_attn = s2s_attn[..., 1:]
406
+ s2s_attn = s2s_attn.transpose(-1, -2)
407
+
408
+ text_mask = length_to_mask(input_lengths).to(texts.device)
409
+ attn_mask = (
410
+ (~mask)
411
+ .unsqueeze(-1)
412
+ .expand(mask.shape[0], mask.shape[1], text_mask.shape[-1])
413
+ .float()
414
+ .transpose(-1, -2)
415
+ )
416
+ attn_mask = (
417
+ attn_mask.float()
418
+ * (~text_mask)
419
+ .unsqueeze(-1)
420
+ .expand(text_mask.shape[0], text_mask.shape[1], mask.shape[-1])
421
+ .float()
422
+ )
423
+ attn_mask = attn_mask < 1
424
+ s2s_attn.masked_fill_(attn_mask, 0.0)
425
+
426
+ # encode
427
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
428
+
429
+ asr = t_en @ s2s_attn
430
+
431
+ # get clips
432
+ mel_input_length_all = accelerator.gather(
433
+ mel_input_length
434
+ ) # for balanced load
435
+ mel_len = min(
436
+ [int(mel_input_length.min().item() / 2 - 1), max_len // 2]
437
+ )
438
+
439
+ en = []
440
+ gt = []
441
+ wav = []
442
+ for bib in range(len(mel_input_length)):
443
+ mel_length = int(mel_input_length[bib].item() / 2)
444
+
445
+ random_start = np.random.randint(0, mel_length - mel_len)
446
+ en.append(asr[bib, :, random_start : random_start + mel_len])
447
+ gt.append(
448
+ mels[
449
+ bib, :, (random_start * 2) : ((random_start + mel_len) * 2)
450
+ ]
451
+ )
452
+ y = waves[bib][
453
+ (random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300
454
+ ]
455
+ wav.append(torch.from_numpy(y).to("cuda"))
456
+
457
+ wav = torch.stack(wav).float().detach()
458
+
459
+ en = torch.stack(en)
460
+ gt = torch.stack(gt).detach()
461
+
462
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
463
+ s = model.style_encoder(gt.unsqueeze(1))
464
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
465
+ y_rec = model.decoder(en, F0_real, real_norm, s)
466
+
467
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
468
+
469
+ loss_test += accelerator.gather(loss_mel).mean().item()
470
+ iters_test += 1
471
+
472
+ if accelerator.is_main_process:
473
+ print("Epochs:", epoch + 1)
474
+ log_print(
475
+ "Validation loss: %.3f" % (loss_test / iters_test) + "\n\n\n\n", logger
476
+ )
477
+ print("\n\n\n")
478
+ writer.add_scalar("eval/mel_loss", loss_test / iters_test, epoch + 1)
479
+ attn_image = get_image(s2s_attn[0].cpu().numpy().squeeze())
480
+ writer.add_figure("eval/attn", attn_image, epoch)
481
+
482
+ with torch.no_grad():
483
+ for bib in range(len(asr)):
484
+ mel_length = int(mel_input_length[bib].item())
485
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
486
+ en = asr[bib, :, : mel_length // 2].unsqueeze(0)
487
+
488
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
489
+ F0_real = F0_real.unsqueeze(0)
490
+ s = model.style_encoder(gt.unsqueeze(1))
491
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
492
+
493
+ y_rec = model.decoder(en, F0_real, real_norm, s)
494
+
495
+ writer.add_audio(
496
+ "eval/y" + str(bib),
497
+ y_rec.cpu().numpy().squeeze(),
498
+ epoch,
499
+ sample_rate=sr,
500
+ )
501
+ if epoch == 0:
502
+ writer.add_audio(
503
+ "gt/y" + str(bib),
504
+ waves[bib].squeeze(),
505
+ epoch,
506
+ sample_rate=sr,
507
+ )
508
+
509
+ if bib >= 6:
510
+ break
511
+
512
+ if epoch % saving_epoch == 0:
513
+ if (loss_test / iters_test) < best_loss:
514
+ best_loss = loss_test / iters_test
515
+ print("Saving..")
516
+ state = {
517
+ "net": {key: model[key].state_dict() for key in model},
518
+ "optimizer": optimizer.state_dict(),
519
+ "iters": iters,
520
+ "val_loss": loss_test / iters_test,
521
+ "epoch": epoch,
522
+ }
523
+ save_path = osp.join(log_dir, "epoch_1st_%05d.pth" % epoch)
524
+ torch.save(state, save_path)
525
+
526
+ if accelerator.is_main_process:
527
+ print("Saving..")
528
+ state = {
529
+ "net": {key: model[key].state_dict() for key in model},
530
+ "optimizer": optimizer.state_dict(),
531
+ "iters": iters,
532
+ "val_loss": loss_test / iters_test,
533
+ "epoch": epoch,
534
+ }
535
+ save_path = osp.join(log_dir, config.get("first_stage_path", "first_stage.pth"))
536
+ torch.save(state, save_path)
537
+
538
+
539
+ if __name__ == "__main__":
540
+ main()
train_second.py ADDED
@@ -0,0 +1,958 @@
1
+ # load packages
2
+ import random
3
+ import yaml
4
+ import time
5
+ from munch import Munch
6
+ import numpy as np
7
+ import torch
8
+ from torch import nn
9
+ import torch.nn.functional as F
10
+ import torchaudio
11
+ import librosa
12
+ import click
13
+ import shutil
14
+ import warnings
15
+
16
+ warnings.simplefilter("ignore")
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ from meldataset import build_dataloader
20
+
21
+ from Utils.ASR.models import ASRCNN
22
+ from Utils.JDC.model import JDCNet
23
+ from Utils.PLBERT.util import load_plbert
24
+
25
+ from models import *
26
+ from losses import *
27
+ from utils import *
28
+
29
+ from Modules.slmadv import SLMAdversarialLoss
30
+ from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
31
+
32
+ from optimizers import build_optimizer
33
+
34
+
35
+ # simple fix for dataparallel that allows access to class attributes
36
+ class MyDataParallel(torch.nn.DataParallel):
37
+ def __getattr__(self, name):
38
+ try:
39
+ return super().__getattr__(name)
40
+ except AttributeError:
41
+ return getattr(self.module, name)
42
+
43
+
44
+ import logging
45
+ from logging import StreamHandler
46
+
47
+ logger = logging.getLogger(__name__)
48
+ logger.setLevel(logging.DEBUG)
49
+ handler = StreamHandler()
50
+ handler.setLevel(logging.DEBUG)
51
+ logger.addHandler(handler)
52
+
53
+
54
+ @click.command()
55
+ @click.option("-p", "--config_path", default="Configs/config.yml", type=str)
56
+ def main(config_path):
57
+ config = yaml.safe_load(open(config_path))
58
+
59
+ log_dir = config["log_dir"]
60
+ if not osp.exists(log_dir):
61
+ os.makedirs(log_dir, exist_ok=True)
62
+ shutil.copy(config_path, osp.join(log_dir, osp.basename(config_path)))
63
+ writer = SummaryWriter(log_dir + "/tensorboard")
64
+
65
+ # write logs
66
+ file_handler = logging.FileHandler(osp.join(log_dir, "train.log"))
67
+ file_handler.setLevel(logging.DEBUG)
68
+ file_handler.setFormatter(
69
+ logging.Formatter("%(levelname)s:%(asctime)s: %(message)s")
70
+ )
71
+ logger.addHandler(file_handler)
72
+
73
+ batch_size = config.get("batch_size", 10)
74
+
75
+ epochs = config.get("epochs_2nd", 200)
76
+ save_freq = config.get("save_freq", 2)
77
+ log_interval = config.get("log_interval", 10)
78
+ saving_epoch = config.get("save_freq", 2)
79
+
80
+ data_params = config.get("data_params", None)
81
+ sr = config["preprocess_params"].get("sr", 24000)
82
+ train_path = data_params["train_data"]
83
+ val_path = data_params["val_data"]
84
+ root_path = data_params["root_path"]
85
+ min_length = data_params["min_length"]
86
+ OOD_data = data_params["OOD_data"]
87
+
88
+ max_len = config.get("max_len", 200)
89
+
90
+ loss_params = Munch(config["loss_params"])
91
+ diff_epoch = loss_params.diff_epoch
92
+ joint_epoch = loss_params.joint_epoch
93
+
94
+ optimizer_params = Munch(config["optimizer_params"])
95
+
96
+ train_list, val_list = get_data_path_list(train_path, val_path)
97
+ device = "cuda"
98
+
99
+ train_dataloader = build_dataloader(
100
+ train_list,
101
+ root_path,
102
+ OOD_data=OOD_data,
103
+ min_length=min_length,
104
+ batch_size=batch_size,
105
+ num_workers=2,
106
+ dataset_config={},
107
+ device=device,
108
+ )
109
+
110
+ val_dataloader = build_dataloader(
111
+ val_list,
112
+ root_path,
113
+ OOD_data=OOD_data,
114
+ min_length=min_length,
115
+ batch_size=batch_size,
116
+ validation=True,
117
+ num_workers=0,
118
+ device=device,
119
+ dataset_config={},
120
+ )
121
+
122
+ # load pretrained ASR model
123
+ ASR_config = config.get("ASR_config", False)
124
+ ASR_path = config.get("ASR_path", False)
125
+ text_aligner = load_ASR_models(ASR_path, ASR_config)
126
+
127
+ # load pretrained F0 model
128
+ F0_path = config.get("F0_path", False)
129
+ pitch_extractor = load_F0_models(F0_path)
130
+
131
+ # load PL-BERT model
132
+ BERT_path = config.get("PLBERT_dir", False)
133
+ plbert = load_plbert(BERT_path)
134
+
135
+ # build model
136
+ model_params = recursive_munch(config["model_params"])
137
+ multispeaker = model_params.multispeaker
138
+ model = build_model(model_params, text_aligner, pitch_extractor, plbert)
139
+ _ = [model[key].to(device) for key in model]
140
+
141
+ # DP
142
+ for key in model:
143
+ if key != "mpd" and key != "msd" and key != "wd":
144
+ model[key] = MyDataParallel(model[key])
145
+
146
+ start_epoch = 0
147
+ iters = 0
148
+
149
+ load_pretrained = config.get("pretrained_model", "") != "" and config.get(
150
+ "second_stage_load_pretrained", False
151
+ )
152
+
153
+ if not load_pretrained:
154
+ if config.get("first_stage_path", "") != "":
155
+ first_stage_path = osp.join(
156
+ log_dir, config.get("first_stage_path", "first_stage.pth")
157
+ )
158
+ print("Loading the first stage model at %s ..." % first_stage_path)
159
+ model, _, start_epoch, iters = load_checkpoint(
160
+ model,
161
+ None,
162
+ first_stage_path,
163
+ load_only_params=True,
164
+ ignore_modules=[
165
+ "bert",
166
+ "bert_encoder",
167
+ "predictor",
168
+ "predictor_encoder",
169
+ "msd",
170
+ "mpd",
171
+ "wd",
172
+ "diffusion",
173
+ ],
174
+ ) # keep starting epoch for tensorboard log
175
+
176
+ # these epochs should be counted from the start epoch
177
+ diff_epoch += start_epoch
178
+ joint_epoch += start_epoch
179
+ epochs += start_epoch
180
+
181
+ model.predictor_encoder = copy.deepcopy(model.style_encoder)
182
+ else:
183
+ raise ValueError("You need to specify the path to the first stage model.")
184
+
185
+ gl = GeneratorLoss(model.mpd, model.msd).to(device)
186
+ dl = DiscriminatorLoss(model.mpd, model.msd).to(device)
187
+ wl = WavLMLoss(model_params.slm.model, model.wd, sr, model_params.slm.sr).to(device)
188
+
189
+ gl = MyDataParallel(gl)
190
+ dl = MyDataParallel(dl)
191
+ wl = MyDataParallel(wl)
192
+
193
+ sampler = DiffusionSampler(
194
+ model.diffusion.diffusion,
195
+ sampler=ADPM2Sampler(),
196
+ sigma_schedule=KarrasSchedule(
197
+ sigma_min=0.0001, sigma_max=3.0, rho=9.0
198
+ ), # empirical parameters
199
+ clamp=False,
200
+ )
201
+
202
+ scheduler_params = {
203
+ "max_lr": optimizer_params.lr,
204
+ "pct_start": float(0),
205
+ "epochs": epochs,
206
+ "steps_per_epoch": len(train_dataloader),
207
+ }
208
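+ # per-module scheduler settings: PL-BERT and the fine-tuned acoustic modules (decoder, style encoder) get their own peak learning rates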
+ scheduler_params_dict = {key: scheduler_params.copy() for key in model}
209
+ scheduler_params_dict["bert"]["max_lr"] = optimizer_params.bert_lr * 2
210
+ scheduler_params_dict["decoder"]["max_lr"] = optimizer_params.ft_lr * 2
211
+ scheduler_params_dict["style_encoder"]["max_lr"] = optimizer_params.ft_lr * 2
212
+
213
+ optimizer = build_optimizer(
214
+ {key: model[key].parameters() for key in model},
215
+ scheduler_params_dict=scheduler_params_dict,
216
+ lr=optimizer_params.lr,
217
+ )
218
+
219
+ # adjust BERT learning rate
220
+ for g in optimizer.optimizers["bert"].param_groups:
221
+ g["betas"] = (0.9, 0.99)
222
+ g["lr"] = optimizer_params.bert_lr
223
+ g["initial_lr"] = optimizer_params.bert_lr
224
+ g["min_lr"] = 0
225
+ g["weight_decay"] = 0.01
226
+
227
+ # adjust acoustic module learning rate
228
+ for module in ["decoder", "style_encoder"]:
229
+ for g in optimizer.optimizers[module].param_groups:
230
+ g["betas"] = (0.0, 0.99)
231
+ g["lr"] = optimizer_params.ft_lr
232
+ g["initial_lr"] = optimizer_params.ft_lr
233
+ g["min_lr"] = 0
234
+ g["weight_decay"] = 1e-4
235
+
236
+ # load models if there is a model
237
+ if load_pretrained:
238
+ model, optimizer, start_epoch, iters = load_checkpoint(
239
+ model,
240
+ optimizer,
241
+ config["pretrained_model"],
242
+ load_only_params=config.get("load_only_params", True),
243
+ )
244
+
245
+ n_down = model.text_aligner.n_down
246
+
247
+ best_loss = float("inf") # best test loss
248
+ loss_train_record = list([])
249
+ loss_test_record = list([])
250
+ iters = 0
251
+
252
+ criterion = nn.L1Loss() # F0 loss (regression)
253
+ torch.cuda.empty_cache()
254
+
255
+ stft_loss = MultiResolutionSTFTLoss().to(device)
256
+
257
+ print("BERT", optimizer.optimizers["bert"])
258
+ print("decoder", optimizer.optimizers["decoder"])
259
+
260
+ start_ds = False
261
+
262
+ running_std = []
263
+
264
+ slmadv_params = Munch(config["slmadv_params"])
265
+ slmadv = SLMAdversarialLoss(
266
+ model,
267
+ wl,
268
+ sampler,
269
+ slmadv_params.min_len,
270
+ slmadv_params.max_len,
271
+ batch_percentage=slmadv_params.batch_percentage,
272
+ skip_update=slmadv_params.iter,
273
+ sig=slmadv_params.sig,
274
+ )
275
+
276
+ for epoch in range(start_epoch, epochs):
277
+ running_loss = 0
278
+ start_time = time.time()
279
+
280
+ _ = [model[key].eval() for key in model]
281
+
282
+ model.predictor.train()
283
+ model.bert_encoder.train()
284
+ model.bert.train()
285
+ model.msd.train()
286
+ model.mpd.train()
287
+
288
+ if epoch >= diff_epoch:
289
+ start_ds = True
290
+
291
+ for i, batch in enumerate(train_dataloader):
292
+ waves = batch[0]
293
+ batch = [b.to(device) for b in batch[1:]]
294
+ (
295
+ texts,
296
+ input_lengths,
297
+ ref_texts,
298
+ ref_lengths,
299
+ mels,
300
+ mel_input_length,
301
+ ref_mels,
302
+ ) = batch
303
+
304
+ with torch.no_grad():
305
+ mask = length_to_mask(mel_input_length // (2**n_down)).to(device)
306
+ mel_mask = length_to_mask(mel_input_length).to(device)
307
+ text_mask = length_to_mask(input_lengths).to(texts.device)
308
+
309
+ try:
310
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
311
+ s2s_attn = s2s_attn.transpose(-1, -2)
312
+ s2s_attn = s2s_attn[..., 1:]
313
+ s2s_attn = s2s_attn.transpose(-1, -2)
314
+ except Exception:
315
+ continue
316
+
317
+ mask_ST = mask_from_lens(
318
+ s2s_attn, input_lengths, mel_input_length // (2**n_down)
319
+ )
320
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
321
+
322
+ # encode
323
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
324
+ asr = t_en @ s2s_attn_mono
325
+
326
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
327
+
328
+ # compute reference styles
329
+ if multispeaker and epoch >= diff_epoch:
330
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
331
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
332
+ ref = torch.cat([ref_ss, ref_sp], dim=1)
333
+
334
+ # compute the style of the entire utterance
335
+ # this operation cannot be done in batch because of the avgpool layer (may need to work on masked avgpool)
336
+ ss = []
337
+ gs = []
338
+ for bib in range(len(mel_input_length)):
339
+ mel_length = int(mel_input_length[bib].item())
340
+ mel = mels[bib, :, : mel_input_length[bib]]
341
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
342
+ ss.append(s)
343
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
344
+ gs.append(s)
345
+
346
+ s_dur = torch.stack(ss).squeeze() # global prosodic styles
347
+ gs = torch.stack(gs).squeeze() # global acoustic styles
348
+ s_trg = torch.cat([gs, s_dur], dim=-1).detach() # ground truth for denoiser
349
+
350
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
351
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
352
+
353
+ # denoiser training
354
+ if epoch >= diff_epoch:
355
+ num_steps = np.random.randint(3, 5)
356
+
357
+ if model_params.diffusion.dist.estimate_sigma_data:
358
+ model.diffusion.module.diffusion.sigma_data = (
359
+ s_trg.std(axis=-1).mean().item()
360
+ ) # batch-wise std estimation
361
+ running_std.append(model.diffusion.module.diffusion.sigma_data)
362
+
363
+ if multispeaker:
364
+ s_preds = sampler(
365
+ noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
366
+ embedding=bert_dur,
367
+ embedding_scale=1,
368
+ features=ref, # reference from the same speaker as the embedding
369
+ embedding_mask_proba=0.1,
370
+ num_steps=num_steps,
371
+ ).squeeze(1)
372
+ loss_diff = model.diffusion(
373
+ s_trg.unsqueeze(1), embedding=bert_dur, features=ref
374
+ ).mean() # EDM loss
375
+ loss_sty = F.l1_loss(
376
+ s_preds, s_trg.detach()
377
+ ) # style reconstruction loss
378
+ else:
379
+ s_preds = sampler(
380
+ noise=torch.randn_like(s_trg).unsqueeze(1).to(device),
381
+ embedding=bert_dur,
382
+ embedding_scale=1,
383
+ embedding_mask_proba=0.1,
384
+ num_steps=num_steps,
385
+ ).squeeze(1)
386
+ loss_diff = model.diffusion.module.diffusion(
387
+ s_trg.unsqueeze(1), embedding=bert_dur
388
+ ).mean() # EDM loss
389
+ loss_sty = F.l1_loss(
390
+ s_preds, s_trg.detach()
391
+ ) # style reconstruction loss
392
+ else:
393
+ loss_sty = 0
394
+ loss_diff = 0
395
+
396
+ d, p = model.predictor(d_en, s_dur, input_lengths, s2s_attn_mono, text_mask)
397
+
398
+ mel_len = min(int(mel_input_length.min().item() / 2 - 1), max_len // 2)
399
+ mel_len_st = int(mel_input_length.min().item() / 2 - 1)
400
+ en = []
401
+ gt = []
402
+ st = []
403
+ p_en = []
404
+ wav = []
405
+
406
+ for bib in range(len(mel_input_length)):
407
+ mel_length = int(mel_input_length[bib].item() / 2)
408
+
409
+ random_start = np.random.randint(0, mel_length - mel_len)
410
+ en.append(asr[bib, :, random_start : random_start + mel_len])
411
+ p_en.append(p[bib, :, random_start : random_start + mel_len])
412
+ gt.append(
413
+ mels[bib, :, (random_start * 2) : ((random_start + mel_len) * 2)]
414
+ )
415
+
416
+ y = waves[bib][
417
+ (random_start * 2) * 300 : ((random_start + mel_len) * 2) * 300
418
+ ]
419
+ wav.append(torch.from_numpy(y).to(device))
420
+
421
+ # style reference (better to be different from the GT)
422
+ random_start = np.random.randint(0, mel_length - mel_len_st)
423
+ st.append(
424
+ mels[bib, :, (random_start * 2) : ((random_start + mel_len_st) * 2)]
425
+ )
426
+
427
+ wav = torch.stack(wav).float().detach()
428
+
429
+ en = torch.stack(en)
430
+ p_en = torch.stack(p_en)
431
+ gt = torch.stack(gt).detach()
432
+ st = torch.stack(st).detach()
433
+
434
+ if gt.size(-1) < 80:
435
+ continue
436
+
437
+ s_dur = model.predictor_encoder(
438
+ st.unsqueeze(1) if multispeaker else gt.unsqueeze(1)
439
+ )
440
+ s = model.style_encoder(
441
+ st.unsqueeze(1) if multispeaker else gt.unsqueeze(1)
442
+ )
443
+
444
+ with torch.no_grad():
445
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
446
+ F0 = F0.reshape(F0.shape[0], F0.shape[1] * 2, F0.shape[2], 1).squeeze()
447
+
448
+ asr_real = model.text_aligner.get_feature(gt)
449
+
450
+ N_real = log_norm(gt.unsqueeze(1)).squeeze(1)
451
+
452
+ y_rec_gt = wav.unsqueeze(1)
453
+ y_rec_gt_pred = model.decoder(en, F0_real, N_real, s)
454
+
455
+ if epoch >= joint_epoch:
456
+ # ground truth from recording
457
+ wav = y_rec_gt # use recording since decoder is tuned
458
+ else:
459
+ # ground truth from reconstruction
460
+ wav = y_rec_gt_pred # use reconstruction since decoder is fixed
461
+
462
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
463
+
464
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
465
+
466
+ loss_F0_rec = (F.smooth_l1_loss(F0_real, F0_fake)) / 10
467
+ loss_norm_rec = F.smooth_l1_loss(N_real, N_fake)
468
+
469
+ if start_ds:
470
+ optimizer.zero_grad()
471
+ d_loss = dl(wav.detach(), y_rec.detach()).mean()
472
+ d_loss.backward()
473
+ optimizer.step("msd")
474
+ optimizer.step("mpd")
475
+ else:
476
+ d_loss = 0
477
+
478
+ # generator loss
479
+ optimizer.zero_grad()
480
+
481
+ loss_mel = stft_loss(y_rec, wav)
482
+ if start_ds:
483
+ loss_gen_all = gl(wav, y_rec).mean()
484
+ else:
485
+ loss_gen_all = 0
486
+ loss_lm = wl(wav.detach().squeeze(), y_rec.squeeze()).mean()
487
+
488
+ loss_ce = 0
489
+ loss_dur = 0
490
+ for _s2s_pred, _text_input, _text_length in zip(d, (d_gt), input_lengths):
491
+ _s2s_pred = _s2s_pred[:_text_length, :]
492
+ _text_input = _text_input[:_text_length].long()
493
+ _s2s_trg = torch.zeros_like(_s2s_pred)
494
+ for p in range(_s2s_trg.shape[0]):
495
+ _s2s_trg[p, : _text_input[p]] = 1
496
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
497
+
498
+ loss_dur += F.l1_loss(
499
+ _dur_pred[1 : _text_length - 1], _text_input[1 : _text_length - 1]
500
+ )
501
+ loss_ce += F.binary_cross_entropy_with_logits(
502
+ _s2s_pred.flatten(), _s2s_trg.flatten()
503
+ )
504
+
505
+ loss_ce /= texts.size(0)
506
+ loss_dur /= texts.size(0)
507
+
508
+ g_loss = (
509
+ loss_params.lambda_mel * loss_mel
510
+ + loss_params.lambda_F0 * loss_F0_rec
511
+ + loss_params.lambda_ce * loss_ce
512
+ + loss_params.lambda_norm * loss_norm_rec
513
+ + loss_params.lambda_dur * loss_dur
514
+ + loss_params.lambda_gen * loss_gen_all
515
+ + loss_params.lambda_slm * loss_lm
516
+ + loss_params.lambda_sty * loss_sty
517
+ + loss_params.lambda_diff * loss_diff
518
+ )
519
+
520
+ running_loss += loss_mel.item()
521
+ g_loss.backward()
522
+ if torch.isnan(g_loss):
523
+ from IPython.core.debugger import set_trace
524
+
525
+ set_trace()
526
+
527
+ optimizer.step("bert_encoder")
528
+ optimizer.step("bert")
529
+ optimizer.step("predictor")
530
+ optimizer.step("predictor_encoder")
531
+
532
+ if epoch >= diff_epoch:
533
+ optimizer.step("diffusion")
534
+
535
+ if epoch >= joint_epoch:
536
+ optimizer.step("style_encoder")
537
+ optimizer.step("decoder")
538
+
539
+ # randomly pick whether to use in-distribution text
540
+ if np.random.rand() < 0.5:
541
+ use_ind = True
542
+ else:
543
+ use_ind = False
544
+
545
+ if use_ind:
546
+ ref_lengths = input_lengths
547
+ ref_texts = texts
548
+
549
+ slm_out = slmadv(
550
+ i,
551
+ y_rec_gt,
552
+ y_rec_gt_pred,
553
+ waves,
554
+ mel_input_length,
555
+ ref_texts,
556
+ ref_lengths,
557
+ use_ind,
558
+ s_trg.detach(),
559
+ ref if multispeaker else None,
560
+ )
561
+
562
+ if slm_out is None:
563
+ continue
564
+
565
+ d_loss_slm, loss_gen_lm, y_pred = slm_out
566
+
567
+ # SLM generator loss
568
+ optimizer.zero_grad()
569
+ loss_gen_lm.backward()
570
+
571
+ # SLM discriminator loss
572
+ if d_loss_slm != 0:
573
+ optimizer.zero_grad()
574
+ d_loss_slm.backward(retain_graph=True)
575
+ optimizer.step("wd")
576
+
577
+ # compute the gradient norm
578
+ total_norm = {}
579
+ for key in model.keys():
580
+ total_norm[key] = 0
581
+ parameters = [
582
+ p
583
+ for p in model[key].parameters()
584
+ if p.grad is not None and p.requires_grad
585
+ ]
586
+ for p in parameters:
587
+ param_norm = p.grad.detach().data.norm(2)
588
+ total_norm[key] += param_norm.item() ** 2
589
+ total_norm[key] = total_norm[key] ** 0.5
590
+
591
+ # gradient scaling
592
+ if total_norm["predictor"] > slmadv_params.thresh:
593
+ for key in model.keys():
594
+ for p in model[key].parameters():
595
+ if p.grad is not None:
596
+ p.grad *= 1 / total_norm["predictor"]
597
+
598
+ for p in model.predictor.duration_proj.parameters():
599
+ if p.grad is not None:
600
+ p.grad *= slmadv_params.scale
601
+
602
+ for p in model.predictor.lstm.parameters():
603
+ if p.grad is not None:
604
+ p.grad *= slmadv_params.scale
605
+
606
+ for p in model.diffusion.parameters():
607
+ if p.grad is not None:
608
+ p.grad *= slmadv_params.scale
609
+
610
+ optimizer.step("bert_encoder")
611
+ optimizer.step("bert")
612
+ optimizer.step("predictor")
613
+ optimizer.step("diffusion")
614
+ else:
615
+ d_loss_slm, loss_gen_lm = 0, 0
616
+
617
+ iters = iters + 1
618
+
619
+ if (i + 1) % log_interval == 0:
620
+ logger.info(
621
+ "Epoch [%d/%d], Step [%d/%d], Loss: %.5f, Disc Loss: %.5f, Dur Loss: %.5f, CE Loss: %.5f, Norm Loss: %.5f, F0 Loss: %.5f, LM Loss: %.5f, Gen Loss: %.5f, Sty Loss: %.5f, Diff Loss: %.5f, DiscLM Loss: %.5f, GenLM Loss: %.5f"
622
+ % (
623
+ epoch + 1,
624
+ epochs,
625
+ i + 1,
626
+ len(train_list) // batch_size,
627
+ running_loss / log_interval,
628
+ d_loss,
629
+ loss_dur,
630
+ loss_ce,
631
+ loss_norm_rec,
632
+ loss_F0_rec,
633
+ loss_lm,
634
+ loss_gen_all,
635
+ loss_sty,
636
+ loss_diff,
637
+ d_loss_slm,
638
+ loss_gen_lm,
639
+ )
640
+ )
641
+
642
+ writer.add_scalar("train/mel_loss", running_loss / log_interval, iters)
643
+ writer.add_scalar("train/gen_loss", loss_gen_all, iters)
644
+ writer.add_scalar("train/d_loss", d_loss, iters)
645
+ writer.add_scalar("train/ce_loss", loss_ce, iters)
646
+ writer.add_scalar("train/dur_loss", loss_dur, iters)
647
+ writer.add_scalar("train/slm_loss", loss_lm, iters)
648
+ writer.add_scalar("train/norm_loss", loss_norm_rec, iters)
649
+ writer.add_scalar("train/F0_loss", loss_F0_rec, iters)
650
+ writer.add_scalar("train/sty_loss", loss_sty, iters)
651
+ writer.add_scalar("train/diff_loss", loss_diff, iters)
652
+ writer.add_scalar("train/d_loss_slm", d_loss_slm, iters)
653
+ writer.add_scalar("train/gen_loss_slm", loss_gen_lm, iters)
654
+
655
+ running_loss = 0
656
+
657
+ print("Time elapsed:", time.time() - start_time)
658
+
659
+ loss_test = 0
660
+ loss_align = 0
661
+ loss_f = 0
662
+ _ = [model[key].eval() for key in model]
663
+
664
+ with torch.no_grad():
665
+ iters_test = 0
666
+ for batch_idx, batch in enumerate(val_dataloader):
667
+ optimizer.zero_grad()
668
+
669
+ try:
670
+ waves = batch[0]
671
+ batch = [b.to(device) for b in batch[1:]]
672
+ (
673
+ texts,
674
+ input_lengths,
675
+ ref_texts,
676
+ ref_lengths,
677
+ mels,
678
+ mel_input_length,
679
+ ref_mels,
680
+ ) = batch
681
+ with torch.no_grad():
682
+ mask = length_to_mask(mel_input_length // (2**n_down)).to(
683
+ "cuda"
684
+ )
685
+ text_mask = length_to_mask(input_lengths).to(texts.device)
686
+
687
+ _, _, s2s_attn = model.text_aligner(mels, mask, texts)
688
+ s2s_attn = s2s_attn.transpose(-1, -2)
689
+ s2s_attn = s2s_attn[..., 1:]
690
+ s2s_attn = s2s_attn.transpose(-1, -2)
691
+
692
+ mask_ST = mask_from_lens(
693
+ s2s_attn, input_lengths, mel_input_length // (2**n_down)
694
+ )
695
+ s2s_attn_mono = maximum_path(s2s_attn, mask_ST)
696
+
697
+ # encode
698
+ t_en = model.text_encoder(texts, input_lengths, text_mask)
699
+ asr = t_en @ s2s_attn_mono
700
+
701
+ d_gt = s2s_attn_mono.sum(axis=-1).detach()
702
+
703
+ ss = []
704
+ gs = []
705
+
706
+ for bib in range(len(mel_input_length)):
707
+ mel_length = int(mel_input_length[bib].item())
708
+ mel = mels[bib, :, : mel_input_length[bib]]
709
+ s = model.predictor_encoder(mel.unsqueeze(0).unsqueeze(1))
710
+ ss.append(s)
711
+ s = model.style_encoder(mel.unsqueeze(0).unsqueeze(1))
712
+ gs.append(s)
713
+
714
+ s = torch.stack(ss).squeeze()
715
+ gs = torch.stack(gs).squeeze()
716
+ s_trg = torch.cat([s, gs], dim=-1).detach()
717
+
718
+ bert_dur = model.bert(texts, attention_mask=(~text_mask).int())
719
+ d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
720
+ d, p = model.predictor(
721
+ d_en, s, input_lengths, s2s_attn_mono, text_mask
722
+ )
723
+ # get clips
724
+ mel_len = int(mel_input_length.min().item() / 2 - 1)
725
+ en = []
726
+ gt = []
727
+ p_en = []
728
+ wav = []
729
+
730
+ for bib in range(len(mel_input_length)):
731
+ mel_length = int(mel_input_length[bib].item() / 2)
732
+
733
+ random_start = np.random.randint(0, mel_length - mel_len)
734
+ en.append(asr[bib, :, random_start : random_start + mel_len])
735
+ p_en.append(p[bib, :, random_start : random_start + mel_len])
736
+
737
+ gt.append(
738
+ mels[
739
+ bib,
740
+ :,
741
+ (random_start * 2) : ((random_start + mel_len) * 2),
742
+ ]
743
+ )
744
+
745
+ y = waves[bib][
746
+ (random_start * 2)
747
+ * 300 : ((random_start + mel_len) * 2)
748
+ * 300
749
+ ]
750
+ wav.append(torch.from_numpy(y).to(device))
751
+
752
+ wav = torch.stack(wav).float().detach()
753
+
754
+ en = torch.stack(en)
755
+ p_en = torch.stack(p_en)
756
+ gt = torch.stack(gt).detach()
757
+
758
+ s = model.predictor_encoder(gt.unsqueeze(1))
759
+
760
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s)
761
+
762
+ loss_dur = 0
763
+ for _s2s_pred, _text_input, _text_length in zip(
764
+ d, (d_gt), input_lengths
765
+ ):
766
+ _s2s_pred = _s2s_pred[:_text_length, :]
767
+ _text_input = _text_input[:_text_length].long()
768
+ _s2s_trg = torch.zeros_like(_s2s_pred)
769
+ for bib in range(_s2s_trg.shape[0]):
770
+ _s2s_trg[bib, : _text_input[bib]] = 1
771
+ _dur_pred = torch.sigmoid(_s2s_pred).sum(axis=1)
772
+ loss_dur += F.l1_loss(
773
+ _dur_pred[1 : _text_length - 1],
774
+ _text_input[1 : _text_length - 1],
775
+ )
776
+
777
+ loss_dur /= texts.size(0)
778
+
779
+ s = model.style_encoder(gt.unsqueeze(1))
780
+
781
+ y_rec = model.decoder(en, F0_fake, N_fake, s)
782
+ loss_mel = stft_loss(y_rec.squeeze(), wav.detach())
783
+
784
+ F0_real, _, F0 = model.pitch_extractor(gt.unsqueeze(1))
785
+
786
+ loss_F0 = F.l1_loss(F0_real, F0_fake) / 10
787
+
788
+ loss_test += (loss_mel).mean()
789
+ loss_align += (loss_dur).mean()
790
+ loss_f += (loss_F0).mean()
791
+
792
+ iters_test += 1
793
+ except Exception:
794
+ continue
795
+
796
+ print("Epochs:", epoch + 1)
797
+ logger.info(
798
+ "Validation loss: %.3f, Dur loss: %.3f, F0 loss: %.3f"
799
+ % (loss_test / iters_test, loss_align / iters_test, loss_f / iters_test)
800
+ + "\n\n\n"
801
+ )
802
+ print("\n\n\n")
803
+ writer.add_scalar("eval/mel_loss", loss_test / iters_test, epoch + 1)
804
+ writer.add_scalar("eval/dur_loss", loss_align / iters_test, epoch + 1)
805
+ writer.add_scalar("eval/F0_loss", loss_f / iters_test, epoch + 1)
806
+
807
+ if epoch < joint_epoch:
808
+ # generating reconstruction examples with GT duration
809
+
810
+ with torch.no_grad():
811
+ for bib in range(len(asr)):
812
+ mel_length = int(mel_input_length[bib].item())
813
+ gt = mels[bib, :, :mel_length].unsqueeze(0)
814
+ en = asr[bib, :, : mel_length // 2].unsqueeze(0)
815
+
816
+ F0_real, _, _ = model.pitch_extractor(gt.unsqueeze(1))
817
+ F0_real = F0_real.unsqueeze(0)
818
+ s = model.style_encoder(gt.unsqueeze(1))
819
+ real_norm = log_norm(gt.unsqueeze(1)).squeeze(1)
820
+
821
+ y_rec = model.decoder(en, F0_real, real_norm, s)
822
+
823
+ writer.add_audio(
824
+ "eval/y" + str(bib),
825
+ y_rec.cpu().numpy().squeeze(),
826
+ epoch,
827
+ sample_rate=sr,
828
+ )
829
+
830
+ s_dur = model.predictor_encoder(gt.unsqueeze(1))
831
+ p_en = p[bib, :, : mel_length // 2].unsqueeze(0)
832
+
833
+ F0_fake, N_fake = model.predictor.F0Ntrain(p_en, s_dur)
834
+
835
+ y_pred = model.decoder(en, F0_fake, N_fake, s)
836
+
837
+ writer.add_audio(
838
+ "pred/y" + str(bib),
839
+ y_pred.cpu().numpy().squeeze(),
840
+ epoch,
841
+ sample_rate=sr,
842
+ )
843
+
844
+ if epoch == 0:
845
+ writer.add_audio(
846
+ "gt/y" + str(bib),
847
+ waves[bib].squeeze(),
848
+ epoch,
849
+ sample_rate=sr,
850
+ )
851
+
852
+ if bib >= 5:
853
+ break
854
+ else:
855
+ # generating sampled speech from text directly
856
+ with torch.no_grad():
857
+ # compute reference styles
858
+ if multispeaker and epoch >= diff_epoch:
859
+ ref_ss = model.style_encoder(ref_mels.unsqueeze(1))
860
+ ref_sp = model.predictor_encoder(ref_mels.unsqueeze(1))
861
+ ref_s = torch.cat([ref_ss, ref_sp], dim=1)
862
+
863
+ for bib in range(len(d_en)):
864
+ if multispeaker:
865
+ s_pred = sampler(
866
+ noise=torch.randn((1, 256)).unsqueeze(1).to(texts.device),
867
+ embedding=bert_dur[bib].unsqueeze(0),
868
+ embedding_scale=1,
869
+ features=ref_s[bib].unsqueeze(
870
+ 0
871
+ ), # reference from the same speaker as the embedding
872
+ num_steps=5,
873
+ ).squeeze(1)
874
+ else:
875
+ s_pred = sampler(
876
+ noise=torch.randn((1, 256)).unsqueeze(1).to(texts.device),
877
+ embedding=bert_dur[bib].unsqueeze(0),
878
+ embedding_scale=1,
879
+ num_steps=5,
880
+ ).squeeze(1)
881
+
882
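+ # split the sampled style vector: the first 128 dims are the acoustic style for the decoder, the last 128 the prosodic style for the predictor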
+ s = s_pred[:, 128:]
883
+ ref = s_pred[:, :128]
884
+
885
+ d = model.predictor.text_encoder(
886
+ d_en[bib, :, : input_lengths[bib]].unsqueeze(0),
887
+ s,
888
+ input_lengths[bib, ...].unsqueeze(0),
889
+ text_mask[bib, : input_lengths[bib]].unsqueeze(0),
890
+ )
891
+
892
+ x, _ = model.predictor.lstm(d)
893
+ duration = model.predictor.duration_proj(x)
894
+
895
+ duration = torch.sigmoid(duration).sum(axis=-1)
896
+ pred_dur = torch.round(duration.squeeze()).clamp(min=1)
897
+
898
+ pred_dur[-1] += 5
899
+
900
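+ # expand the predicted per-token durations into a hard alignment matrix (one block of ones per token)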
+ pred_aln_trg = torch.zeros(
901
+ input_lengths[bib], int(pred_dur.sum().data)
902
+ )
903
+ c_frame = 0
904
+ for i in range(pred_aln_trg.size(0)):
905
+ pred_aln_trg[i, c_frame : c_frame + int(pred_dur[i].data)] = 1
906
+ c_frame += int(pred_dur[i].data)
907
+
908
+ # encode prosody
909
+ en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(
910
+ texts.device
911
+ )
912
+ F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
913
+ out = model.decoder(
914
+ (
915
+ t_en[bib, :, : input_lengths[bib]].unsqueeze(0)
916
+ @ pred_aln_trg.unsqueeze(0).to(texts.device)
917
+ ),
918
+ F0_pred,
919
+ N_pred,
920
+ ref.squeeze().unsqueeze(0),
921
+ )
922
+
923
+ writer.add_audio(
924
+ "pred/y" + str(bib),
925
+ out.cpu().numpy().squeeze(),
926
+ epoch,
927
+ sample_rate=sr,
928
+ )
929
+
930
+ if bib >= 5:
931
+ break
932
+
933
+ if epoch % saving_epoch == 0:
934
+ if (loss_test / iters_test) < best_loss:
935
+ best_loss = loss_test / iters_test
936
+ print("Saving..")
937
+ state = {
938
+ "net": {key: model[key].state_dict() for key in model},
939
+ "optimizer": optimizer.state_dict(),
940
+ "iters": iters,
941
+ "val_loss": loss_test / iters_test,
942
+ "epoch": epoch,
943
+ }
944
+ save_path = osp.join(log_dir, "epoch_2nd_%05d.pth" % epoch)
945
+ torch.save(state, save_path)
946
+
947
+ # if sigma is estimated from data, save the estimated sigma back into the config
948
+ if model_params.diffusion.dist.estimate_sigma_data:
949
+ config["model_params"]["diffusion"]["dist"]["sigma_data"] = float(
950
+ np.mean(running_std)
951
+ )
952
+
953
+ with open(osp.join(log_dir, osp.basename(config_path)), "w") as outfile:
954
+ yaml.dump(config, outfile, default_flow_style=True)
955
+
956
+
957
+ if __name__ == "__main__":
958
+ main()
utils.py ADDED
@@ -0,0 +1,89 @@
1
+ from monotonic_align import maximum_path
2
+ from monotonic_align import mask_from_lens
3
+ from monotonic_align.core import maximum_path_c
4
+ import numpy as np
5
+ import torch
6
+ import copy
7
+ from torch import nn
8
+ import torch.nn.functional as F
9
+ import torchaudio
10
+ import librosa
11
+ import matplotlib.pyplot as plt
12
+ from munch import Munch
13
+
14
+
15
+ def maximum_path(neg_cent, mask):
16
+ """Cython optimized version.
17
+ neg_cent: [b, t_t, t_s]
18
+ mask: [b, t_t, t_s]
19
+ """
20
+ device = neg_cent.device
21
+ dtype = neg_cent.dtype
22
+ neg_cent = np.ascontiguousarray(neg_cent.data.cpu().numpy().astype(np.float32))
23
+ path = np.ascontiguousarray(np.zeros(neg_cent.shape, dtype=np.int32))
24
+
25
+ t_t_max = np.ascontiguousarray(
26
+ mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
27
+ )
28
+ t_s_max = np.ascontiguousarray(
29
+ mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
30
+ )
31
+ maximum_path_c(path, neg_cent, t_t_max, t_s_max)
32
+ return torch.from_numpy(path).to(device=device, dtype=dtype)
33
+
34
+
35
+ def get_data_path_list(train_path=None, val_path=None):
36
+ if train_path is None:
37
+ train_path = "Data/train_list.txt"
38
+ if val_path is None:
39
+ val_path = "Data/val_list.txt"
40
+
41
+ with open(train_path, "r", encoding="utf-8", errors="ignore") as f:
42
+ train_list = f.readlines()
43
+ with open(val_path, "r", encoding="utf-8", errors="ignore") as f:
44
+ val_list = f.readlines()
45
+
46
+ return train_list, val_list
47
+
48
+
49
+ def length_to_mask(lengths):
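+ # boolean mask that is True at padded positions (index >= length)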
50
+ mask = (
51
+ torch.arange(lengths.max())
52
+ .unsqueeze(0)
53
+ .expand(lengths.shape[0], -1)
54
+ .type_as(lengths)
55
+ )
56
+ mask = torch.gt(mask + 1, lengths.unsqueeze(1))
57
+ return mask
58
+
59
+
60
+ # for norm consistency loss
61
+ def log_norm(x, mean=-4, std=4, dim=2):
62
+ """
63
+ normalized log mel -> mel -> norm -> log(norm)
64
+ """
65
+ x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
66
+ return x
67
+
68
+
69
+ def get_image(arrs):
70
+ plt.switch_backend("agg")
71
+ fig = plt.figure()
72
+ ax = plt.gca()
73
+ ax.imshow(arrs)
74
+
75
+ return fig
76
+
77
+
78
+ def recursive_munch(d):
79
+ if isinstance(d, dict):
80
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
81
+ elif isinstance(d, list):
82
+ return [recursive_munch(v) for v in d]
83
+ else:
84
+ return d
85
+
86
+
87
+ def log_print(message, logger):
88
+ logger.info(message)
89
+ print(message)