Spaces:
Runtime error
Runtime error
ddd
commited on
Commit
•
40e984c
1
Parent(s):
c4e83e4
pndm codes
Browse files- .gitattributes +1 -0
- docs/README-SVS-opencpop-cascade.md +3 -3
- docs/README-SVS-opencpop-e2e.md +2 -1
- docs/README-SVS-popcs.md +1 -1
- docs/README-SVS.md +41 -9
- docs/README-TTS.md +7 -1
- inference/svs/base_svs_infer.py +1 -1
- inference/svs/ds_cascade.py +2 -0
- inference/svs/ds_e2e.py +2 -2
- inference/svs/gradio/infer.py +1 -1
- modules/diffsinger_midi/fs2.py +110 -0
- modules/hifigan/hifigan.py +365 -365
- modules/hifigan/mel_utils.py +80 -80
- modules/parallel_wavegan/models/parallel_wavegan.py +434 -434
- usr/configs/midi/cascade/opencs/ds60_rel.yaml +2 -1
- usr/diff/shallow_diffusion_tts.py +324 -273
- utils/hparams.py +36 -44
.gitattributes
CHANGED
@@ -30,3 +30,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
30 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
32 |
model_ckpt_steps* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
30 |
*.zstandard filter=lfs diff=lfs merge=lfs -text
|
31 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
32 |
model_ckpt_steps* filter=lfs diff=lfs merge=lfs -text
|
33 |
+
checkpoints/0831_opencpop_ds1000 filter=lfs diff=lfs merge=lfs -text
|
docs/README-SVS-opencpop-cascade.md
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
[![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
|
4 |
[![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
|
5 |
|
6 |
-
## DiffSinger (MIDI version
|
7 |
### 0. Data Acquirement
|
8 |
For Opencpop dataset: Please strictly follow the instructions of [Opencpop](https://wenet.org.cn/opencpop/). We have no right to give you the access to Opencpop.
|
9 |
|
@@ -67,7 +67,7 @@ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/ope
|
|
67 |
|
68 |
Remember to adjust the "fs2_ckpt" parameter in `usr/configs/midi/cascade/opencs/ds60_rel.yaml` to fit your path.
|
69 |
|
70 |
-
### 3. Inference
|
71 |
```sh
|
72 |
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME --reset --infer
|
73 |
```
|
@@ -82,7 +82,7 @@ Remember to put the pre-trained models in `checkpoints` directory.
|
|
82 |
|
83 |
### 4. Inference from raw inputs
|
84 |
```sh
|
85 |
-
python inference/svs/
|
86 |
```
|
87 |
Raw inputs:
|
88 |
```
|
|
|
3 |
[![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
|
4 |
[![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
|
5 |
|
6 |
+
## DiffSinger (MIDI SVS | A version)
|
7 |
### 0. Data Acquirement
|
8 |
For Opencpop dataset: Please strictly follow the instructions of [Opencpop](https://wenet.org.cn/opencpop/). We have no right to give you the access to Opencpop.
|
9 |
|
|
|
67 |
|
68 |
Remember to adjust the "fs2_ckpt" parameter in `usr/configs/midi/cascade/opencs/ds60_rel.yaml` to fit your path.
|
69 |
|
70 |
+
### 3. Inference from packed test set
|
71 |
```sh
|
72 |
CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME --reset --infer
|
73 |
```
|
|
|
82 |
|
83 |
### 4. Inference from raw inputs
|
84 |
```sh
|
85 |
+
python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME
|
86 |
```
|
87 |
Raw inputs:
|
88 |
```
|
docs/README-SVS-opencpop-e2e.md
CHANGED
@@ -2,13 +2,14 @@
|
|
2 |
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
|
3 |
[![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
|
4 |
[![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
|
|
|
5 |
|
6 |
Substantial update: We 1) **abandon** the explicit prediction of the F0 curve; 2) increase the receptive field of the denoiser; 3) make the linguistic encoder more robust.
|
7 |
**By doing so, 1) the synthesized recordings are more natural in terms of pitch; 2) the pipeline is simpler.**
|
8 |
|
9 |
简而言之,把F0曲线的动态性交给生成式模型去捕捉,而不再是以前那样用MSE约束对数域F0。
|
10 |
|
11 |
-
## DiffSinger (MIDI version
|
12 |
### 0. Data Acquirement
|
13 |
For Opencpop dataset: Please strictly follow the instructions of [Opencpop](https://wenet.org.cn/opencpop/). We have no right to give you the access to Opencpop.
|
14 |
|
|
|
2 |
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
|
3 |
[![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
|
4 |
[![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
|
5 |
+
| [Interactive🤗 SVS](https://huggingface.co/spaces/Silentlin/DiffSinger)
|
6 |
|
7 |
Substantial update: We 1) **abandon** the explicit prediction of the F0 curve; 2) increase the receptive field of the denoiser; 3) make the linguistic encoder more robust.
|
8 |
**By doing so, 1) the synthesized recordings are more natural in terms of pitch; 2) the pipeline is simpler.**
|
9 |
|
10 |
简而言之,把F0曲线的动态性交给生成式模型去捕捉,而不再是以前那样用MSE约束对数域F0。
|
11 |
|
12 |
+
## DiffSinger (MIDI SVS | B version)
|
13 |
### 0. Data Acquirement
|
14 |
For Opencpop dataset: Please strictly follow the instructions of [Opencpop](https://wenet.org.cn/opencpop/). We have no right to give you the access to Opencpop.
|
15 |
|
docs/README-SVS-popcs.md
CHANGED
@@ -54,7 +54,7 @@ Remember to put the pre-trained models in `checkpoints` directory.
|
|
54 |
*Note that:*
|
55 |
|
56 |
- *the original PWG version vocoder in the paper we used has been put into commercial use, so we provide this HifiGAN version vocoder as a substitute.*
|
57 |
-
- *we assume the ground-truth F0 to be given as the pitch information following [1][2][3]. If you want to conduct experiments on MIDI data, you need an external F0 predictor (like [MIDI-
|
58 |
|
59 |
[1] Adversarially trained multi-singer sequence-to-sequence singing synthesizer. Interspeech 2020.
|
60 |
|
|
|
54 |
*Note that:*
|
55 |
|
56 |
- *the original PWG version vocoder in the paper we used has been put into commercial use, so we provide this HifiGAN version vocoder as a substitute.*
|
57 |
+
- *we assume the ground-truth F0 to be given as the pitch information following [1][2][3]. If you want to conduct experiments on MIDI data, you need an external F0 predictor (like [MIDI-A-version](README-SVS-opencpop-cascade.md)) or a joint prediction with spectrograms(like [MIDI-B-version](README-SVS-opencpop-e2e.md)).*
|
58 |
|
59 |
[1] Adversarially trained multi-singer sequence-to-sequence singing synthesizer. Interspeech 2020.
|
60 |
|
docs/README-SVS.md
CHANGED
@@ -1,7 +1,13 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
### PART1. [Run DiffSinger on PopCS](README-SVS-popcs.md)
|
4 |
-
In
|
5 |
|
6 |
Thus, the pipeline of this part can be summarized as:
|
7 |
|
@@ -18,13 +24,16 @@ Thus, the pipeline of this part can be summarized as:
|
|
18 |
|
19 |
[3] DeepSinger : Singing Voice Synthesis with Data Mined From the Web. KDD 2020.
|
20 |
|
|
|
|
|
|
|
21 |
### PART2. [Run DiffSinger on Opencpop](README-SVS-opencpop-cascade.md)
|
22 |
-
Thanks [Opencpop team](https://wenet.org.cn/opencpop/) for releasing their SVS dataset with MIDI label, **Jan.20, 2022
|
23 |
|
24 |
Since there are elaborately annotated MIDI labels, we are able to supplement the pipeline in PART 1 by adding a naive melody frontend.
|
25 |
|
26 |
-
#### 2.
|
27 |
-
Thus, the pipeline of [
|
28 |
|
29 |
```
|
30 |
[lyrics] + [MIDI] -> [linguistic representation (with MIDI information)] + [predicted F0] + [predicted phoneme duration] (Melody frontend)
|
@@ -32,13 +41,36 @@ Thus, the pipeline of [this part](README-SVS-opencpop-cascade.md) can be summari
|
|
32 |
[mel-spectrogram] + [predicted F0] -> [waveform] (Vocoder)
|
33 |
```
|
34 |
|
35 |
-
|
36 |
-
|
|
|
|
|
37 |
|
38 |
-
Thus, the pipeline of [
|
39 |
```
|
40 |
[lyrics] + [MIDI] -> [linguistic representation] + [predicted phoneme duration] (Melody frontend)
|
41 |
[linguistic representation (with MIDI information)] + [predicted phoneme duration] -> [mel-spectrogram] (Acoustic model)
|
42 |
[mel-spectrogram] -> [predicted F0] (Pitch extractor)
|
43 |
[mel-spectrogram] + [predicted F0] -> [waveform] (Vocoder)
|
44 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
|
2 |
+
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
|
3 |
+
[![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
|
4 |
+
[![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
|
5 |
+
| [Interactive🤗 SVS](https://huggingface.co/spaces/Silentlin/DiffSinger)
|
6 |
+
|
7 |
+
## DiffSinger (SVS)
|
8 |
|
9 |
### PART1. [Run DiffSinger on PopCS](README-SVS-popcs.md)
|
10 |
+
In PART1, we only focus on spectrum modeling (acoustic model) and assume the ground-truth (GT) F0 to be given as the pitch information following these papers [1][2][3]. If you want to conduct experiments with F0 prediction, please move to PART2.
|
11 |
|
12 |
Thus, the pipeline of this part can be summarized as:
|
13 |
|
|
|
24 |
|
25 |
[3] DeepSinger : Singing Voice Synthesis with Data Mined From the Web. KDD 2020.
|
26 |
|
27 |
+
Click here for detailed instructions: [link](README-SVS-popcs.md).
|
28 |
+
|
29 |
+
|
30 |
### PART2. [Run DiffSinger on Opencpop](README-SVS-opencpop-cascade.md)
|
31 |
+
Thanks [Opencpop team](https://wenet.org.cn/opencpop/) for releasing their SVS dataset with MIDI label, **Jan.20, 2022** (after we published our paper).
|
32 |
|
33 |
Since there are elaborately annotated MIDI labels, we are able to supplement the pipeline in PART 1 by adding a naive melody frontend.
|
34 |
|
35 |
+
#### 2.A
|
36 |
+
Thus, the pipeline of [2.A](README-SVS-opencpop-cascade.md) can be summarized as:
|
37 |
|
38 |
```
|
39 |
[lyrics] + [MIDI] -> [linguistic representation (with MIDI information)] + [predicted F0] + [predicted phoneme duration] (Melody frontend)
|
|
|
41 |
[mel-spectrogram] + [predicted F0] -> [waveform] (Vocoder)
|
42 |
```
|
43 |
|
44 |
+
Click here for detailed instructions: [link](README-SVS-opencpop-cascade.md).
|
45 |
+
|
46 |
+
#### 2.B
|
47 |
+
In 2.1, we find that if we predict F0 explicitly in the melody frontend, there will be many bad cases of uv/v prediction. Then, we abandon the explicit prediction of the F0 curve in the melody frontend and make a joint prediction with spectrograms.
|
48 |
|
49 |
+
Thus, the pipeline of [2.B](README-SVS-opencpop-e2e.md) can be summarized as:
|
50 |
```
|
51 |
[lyrics] + [MIDI] -> [linguistic representation] + [predicted phoneme duration] (Melody frontend)
|
52 |
[linguistic representation (with MIDI information)] + [predicted phoneme duration] -> [mel-spectrogram] (Acoustic model)
|
53 |
[mel-spectrogram] -> [predicted F0] (Pitch extractor)
|
54 |
[mel-spectrogram] + [predicted F0] -> [waveform] (Vocoder)
|
55 |
+
```
|
56 |
+
|
57 |
+
Click here for detailed instructions: [link](README-SVS-opencpop-e2e.md).
|
58 |
+
|
59 |
+
### FAQ
|
60 |
+
Q1: Why do I need F0 in Vocoders?
|
61 |
+
|
62 |
+
A1: See vocoder parts in HiFiSinger, DiffSinger or SingGAN. This is a common practice now.
|
63 |
+
|
64 |
+
Q2: Why not run MIDI version SVS on PopCS dataset? or Why not release MIDI labels for PopCS dataset?
|
65 |
+
|
66 |
+
A2: Our laboratory has no funds to label PopCS dataset. But there are funds for labeling other singing dataset, which is coming soon.
|
67 |
+
|
68 |
+
Q3: Why " 'HifiGAN' object has no attribute 'model' "?
|
69 |
+
|
70 |
+
A3: Please put the pretrained vocoders in your `checkpoints` dictionary.
|
71 |
+
|
72 |
+
Q4: How to check whether I use GT information or predicted information during inference from packed test set?
|
73 |
+
|
74 |
+
A4: Please see codes [here](https://github.com/MoonInTheRiver/DiffSinger/blob/55e2f46068af6e69940a9f8f02d306c24a940cab/tasks/tts/fs2.py#L343).
|
75 |
+
|
76 |
+
...
|
docs/README-TTS.md
CHANGED
@@ -1,4 +1,10 @@
|
|
1 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
### 1. Preparation
|
3 |
|
4 |
#### Data Preparation
|
|
|
1 |
+
# DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
|
2 |
+
[![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
|
3 |
+
[![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
|
4 |
+
[![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
|
5 |
+
| [Interactive🤗 TTS](https://huggingface.co/spaces/NATSpeech/DiffSpeech)
|
6 |
+
|
7 |
+
## DiffSpeech (TTS)
|
8 |
### 1. Preparation
|
9 |
|
10 |
#### Data Preparation
|
inference/svs/base_svs_infer.py
CHANGED
@@ -142,7 +142,7 @@ class BaseSVSInfer:
|
|
142 |
ph_seq = inp['ph_seq']
|
143 |
note_lst = inp['note_seq'].split()
|
144 |
midi_dur_lst = inp['note_dur_seq'].split()
|
145 |
-
is_slur = inp['is_slur_seq'].split()
|
146 |
print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
|
147 |
if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst):
|
148 |
print('Pass word-notes check.')
|
|
|
142 |
ph_seq = inp['ph_seq']
|
143 |
note_lst = inp['note_seq'].split()
|
144 |
midi_dur_lst = inp['note_dur_seq'].split()
|
145 |
+
is_slur = [float(x) for x in inp['is_slur_seq'].split()]
|
146 |
print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
|
147 |
if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst):
|
148 |
print('Pass word-notes check.')
|
inference/svs/ds_cascade.py
CHANGED
@@ -52,3 +52,5 @@ if __name__ == '__main__':
|
|
52 |
'input_type': 'phoneme'
|
53 |
} # input like Opencpop dataset.
|
54 |
DiffSingerCascadeInfer.example_run(inp)
|
|
|
|
|
|
52 |
'input_type': 'phoneme'
|
53 |
} # input like Opencpop dataset.
|
54 |
DiffSingerCascadeInfer.example_run(inp)
|
55 |
+
|
56 |
+
# # CUDA_VISIBLE_DEVICES=1 python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
|
inference/svs/ds_e2e.py
CHANGED
@@ -53,7 +53,7 @@ if __name__ == '__main__':
|
|
53 |
'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
|
54 |
'input_type': 'word'
|
55 |
} # user input: Chinese characters
|
56 |
-
|
57 |
'text': '小酒窝长睫毛AP是你最美的记号',
|
58 |
'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
|
59 |
'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
|
@@ -64,4 +64,4 @@ if __name__ == '__main__':
|
|
64 |
DiffSingerE2EInfer.example_run(inp)
|
65 |
|
66 |
|
67 |
-
# python inference/svs/ds_e2e.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
|
|
|
53 |
'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
|
54 |
'input_type': 'word'
|
55 |
} # user input: Chinese characters
|
56 |
+
inp = {
|
57 |
'text': '小酒窝长睫毛AP是你最美的记号',
|
58 |
'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
|
59 |
'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
|
|
|
64 |
DiffSingerE2EInfer.example_run(inp)
|
65 |
|
66 |
|
67 |
+
# CUDA_VISIBLE_DEVICES=3 python inference/svs/ds_e2e.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
|
inference/svs/gradio/infer.py
CHANGED
@@ -88,4 +88,4 @@ if __name__ == '__main__':
|
|
88 |
|
89 |
# python inference/svs/gradio/infer.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
|
90 |
# python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
|
91 |
-
# CUDA_VISIBLE_DEVICES=3 python inference/svs/gradio/infer.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
|
|
|
88 |
|
89 |
# python inference/svs/gradio/infer.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
|
90 |
# python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
|
91 |
+
# CUDA_VISIBLE_DEVICES=3 python inference/svs/gradio/infer.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
|
modules/diffsinger_midi/fs2.py
CHANGED
@@ -116,3 +116,113 @@ class FastSpeech2MIDI(FastSpeech2):
|
|
116 |
ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
|
117 |
|
118 |
return ret
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
|
117 |
|
118 |
return ret
|
119 |
+
|
120 |
+
def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
|
121 |
+
decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
|
122 |
+
pitch_padding = mel2ph == 0
|
123 |
+
if hparams['pitch_ar']:
|
124 |
+
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None)
|
125 |
+
if f0 is None:
|
126 |
+
f0 = pitch_pred[:, :, 0]
|
127 |
+
else:
|
128 |
+
ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
|
129 |
+
if f0 is None:
|
130 |
+
f0 = pitch_pred[:, :, 0]
|
131 |
+
if hparams['use_uv'] and uv is None:
|
132 |
+
uv = pitch_pred[:, :, 1] > 0
|
133 |
+
|
134 |
+
# here f0_denorm for pitch prediction
|
135 |
+
ret['f0_denorm'] = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
|
136 |
+
|
137 |
+
# here f0_denorm for mel prediction
|
138 |
+
if self.training:
|
139 |
+
mask = torch.full(uv.shape, hparams.get('mask_uv_prob', 0.)).to(f0.device)
|
140 |
+
masked_uv = torch.bernoulli(mask).bool().to(f0.device) # prob 的概率吐出一个随机uv.
|
141 |
+
uv_masked = uv.bool() | masked_uv
|
142 |
+
# print((uv.float()-uv_masked.float()).mean(dim=1))
|
143 |
+
f0_denorm = denorm_f0(f0, uv_masked, hparams, pitch_padding=pitch_padding)
|
144 |
+
else:
|
145 |
+
f0_denorm = ret['f0_denorm']
|
146 |
+
|
147 |
+
if pitch_padding is not None:
|
148 |
+
f0[pitch_padding] = 0
|
149 |
+
|
150 |
+
pitch = f0_to_coarse(f0_denorm) # start from 0
|
151 |
+
pitch_embed = self.pitch_embed(pitch)
|
152 |
+
return pitch_embed
|
153 |
+
|
154 |
+
|
155 |
+
class FastSpeech2MIDIMasked(FastSpeech2MIDI):
|
156 |
+
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
157 |
+
ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
|
158 |
+
spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
|
159 |
+
ret = {}
|
160 |
+
|
161 |
+
midi_dur_embedding, slur_embedding = 0, 0
|
162 |
+
if kwargs.get('midi_dur') is not None:
|
163 |
+
midi_dur_embedding = self.midi_dur_layer(kwargs['midi_dur'][:, :, None]) # [B, T, 1] -> [B, T, H]
|
164 |
+
if kwargs.get('is_slur') is not None:
|
165 |
+
slur_embedding = self.is_slur_embed(kwargs['is_slur'])
|
166 |
+
encoder_out = self.encoder(txt_tokens, 0, midi_dur_embedding, slur_embedding) # [B, T, C]
|
167 |
+
src_nonpadding = (txt_tokens > 0).float()[:, :, None]
|
168 |
+
|
169 |
+
# add ref style embed
|
170 |
+
# Not implemented
|
171 |
+
# variance encoder
|
172 |
+
var_embed = 0
|
173 |
+
|
174 |
+
# encoder_out_dur denotes encoder outputs for duration predictor
|
175 |
+
# in speech adaptation, duration predictor use old speaker embedding
|
176 |
+
if hparams['use_spk_embed']:
|
177 |
+
spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
|
178 |
+
elif hparams['use_spk_id']:
|
179 |
+
spk_embed_id = spk_embed
|
180 |
+
if spk_embed_dur_id is None:
|
181 |
+
spk_embed_dur_id = spk_embed_id
|
182 |
+
if spk_embed_f0_id is None:
|
183 |
+
spk_embed_f0_id = spk_embed_id
|
184 |
+
spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
|
185 |
+
spk_embed_dur = spk_embed_f0 = spk_embed
|
186 |
+
if hparams['use_split_spk_id']:
|
187 |
+
spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
|
188 |
+
spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
|
189 |
+
else:
|
190 |
+
spk_embed_dur = spk_embed_f0 = spk_embed = 0
|
191 |
+
|
192 |
+
# add dur
|
193 |
+
dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
|
194 |
+
|
195 |
+
mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
|
196 |
+
|
197 |
+
decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
|
198 |
+
|
199 |
+
mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
|
200 |
+
decoder_inp = torch.gather(decoder_inp, 1, mel2ph_) # [B, T, H]
|
201 |
+
|
202 |
+
# expanded midi
|
203 |
+
midi_embedding = self.midi_embed(kwargs['pitch_midi'])
|
204 |
+
midi_embedding = F.pad(midi_embedding, [0, 0, 1, 0])
|
205 |
+
midi_embedding = torch.gather(midi_embedding, 1, mel2ph_)
|
206 |
+
print(midi_embedding.shape, decoder_inp.shape)
|
207 |
+
midi_mask = torch.full(midi_embedding.shape, hparams.get('mask_uv_prob', 0.)).to(midi_embedding.device)
|
208 |
+
midi_mask = 1 - torch.bernoulli(midi_mask).bool().to(midi_embedding.device) # prob 的概率吐出一个随机uv.
|
209 |
+
|
210 |
+
tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
|
211 |
+
|
212 |
+
decoder_inp += midi_embedding
|
213 |
+
decoder_inp_origin = decoder_inp
|
214 |
+
# add pitch and energy embed
|
215 |
+
pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
|
216 |
+
if hparams['use_pitch_embed']:
|
217 |
+
pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
|
218 |
+
decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
|
219 |
+
if hparams['use_energy_embed']:
|
220 |
+
decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
|
221 |
+
|
222 |
+
ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
|
223 |
+
|
224 |
+
if skip_decoder:
|
225 |
+
return ret
|
226 |
+
ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
|
227 |
+
|
228 |
+
return ret
|
modules/hifigan/hifigan.py
CHANGED
@@ -1,365 +1,365 @@
|
|
1 |
-
import torch
|
2 |
-
import torch.nn.functional as F
|
3 |
-
import torch.nn as nn
|
4 |
-
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
5 |
-
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
6 |
-
|
7 |
-
from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork
|
8 |
-
from modules.parallel_wavegan.models.source import SourceModuleHnNSF
|
9 |
-
import numpy as np
|
10 |
-
|
11 |
-
LRELU_SLOPE = 0.1
|
12 |
-
|
13 |
-
|
14 |
-
def init_weights(m, mean=0.0, std=0.01):
|
15 |
-
classname = m.__class__.__name__
|
16 |
-
if classname.find("Conv") != -1:
|
17 |
-
m.weight.data.normal_(mean, std)
|
18 |
-
|
19 |
-
|
20 |
-
def apply_weight_norm(m):
|
21 |
-
classname = m.__class__.__name__
|
22 |
-
if classname.find("Conv") != -1:
|
23 |
-
weight_norm(m)
|
24 |
-
|
25 |
-
|
26 |
-
def get_padding(kernel_size, dilation=1):
|
27 |
-
return int((kernel_size * dilation - dilation) / 2)
|
28 |
-
|
29 |
-
|
30 |
-
class ResBlock1(torch.nn.Module):
|
31 |
-
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
32 |
-
super(ResBlock1, self).__init__()
|
33 |
-
self.h = h
|
34 |
-
self.convs1 = nn.ModuleList([
|
35 |
-
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
36 |
-
padding=get_padding(kernel_size, dilation[0]))),
|
37 |
-
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
38 |
-
padding=get_padding(kernel_size, dilation[1]))),
|
39 |
-
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
40 |
-
padding=get_padding(kernel_size, dilation[2])))
|
41 |
-
])
|
42 |
-
self.convs1.apply(init_weights)
|
43 |
-
|
44 |
-
self.convs2 = nn.ModuleList([
|
45 |
-
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
46 |
-
padding=get_padding(kernel_size, 1))),
|
47 |
-
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
48 |
-
padding=get_padding(kernel_size, 1))),
|
49 |
-
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
50 |
-
padding=get_padding(kernel_size, 1)))
|
51 |
-
])
|
52 |
-
self.convs2.apply(init_weights)
|
53 |
-
|
54 |
-
def forward(self, x):
|
55 |
-
for c1, c2 in zip(self.convs1, self.convs2):
|
56 |
-
xt = F.leaky_relu(x, LRELU_SLOPE)
|
57 |
-
xt = c1(xt)
|
58 |
-
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
59 |
-
xt = c2(xt)
|
60 |
-
x = xt + x
|
61 |
-
return x
|
62 |
-
|
63 |
-
def remove_weight_norm(self):
|
64 |
-
for l in self.convs1:
|
65 |
-
remove_weight_norm(l)
|
66 |
-
for l in self.convs2:
|
67 |
-
remove_weight_norm(l)
|
68 |
-
|
69 |
-
|
70 |
-
class ResBlock2(torch.nn.Module):
|
71 |
-
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
72 |
-
super(ResBlock2, self).__init__()
|
73 |
-
self.h = h
|
74 |
-
self.convs = nn.ModuleList([
|
75 |
-
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
76 |
-
padding=get_padding(kernel_size, dilation[0]))),
|
77 |
-
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
78 |
-
padding=get_padding(kernel_size, dilation[1])))
|
79 |
-
])
|
80 |
-
self.convs.apply(init_weights)
|
81 |
-
|
82 |
-
def forward(self, x):
|
83 |
-
for c in self.convs:
|
84 |
-
xt = F.leaky_relu(x, LRELU_SLOPE)
|
85 |
-
xt = c(xt)
|
86 |
-
x = xt + x
|
87 |
-
return x
|
88 |
-
|
89 |
-
def remove_weight_norm(self):
|
90 |
-
for l in self.convs:
|
91 |
-
remove_weight_norm(l)
|
92 |
-
|
93 |
-
|
94 |
-
class Conv1d1x1(Conv1d):
|
95 |
-
"""1x1 Conv1d with customized initialization."""
|
96 |
-
|
97 |
-
def __init__(self, in_channels, out_channels, bias):
|
98 |
-
"""Initialize 1x1 Conv1d module."""
|
99 |
-
super(Conv1d1x1, self).__init__(in_channels, out_channels,
|
100 |
-
kernel_size=1, padding=0,
|
101 |
-
dilation=1, bias=bias)
|
102 |
-
|
103 |
-
|
104 |
-
class HifiGanGenerator(torch.nn.Module):
|
105 |
-
def __init__(self, h, c_out=1):
|
106 |
-
super(HifiGanGenerator, self).__init__()
|
107 |
-
self.h = h
|
108 |
-
self.num_kernels = len(h['resblock_kernel_sizes'])
|
109 |
-
self.num_upsamples = len(h['upsample_rates'])
|
110 |
-
|
111 |
-
if h['use_pitch_embed']:
|
112 |
-
self.harmonic_num = 8
|
113 |
-
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates']))
|
114 |
-
self.m_source = SourceModuleHnNSF(
|
115 |
-
sampling_rate=h['audio_sample_rate'],
|
116 |
-
harmonic_num=self.harmonic_num)
|
117 |
-
self.noise_convs = nn.ModuleList()
|
118 |
-
self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
|
119 |
-
resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2
|
120 |
-
|
121 |
-
self.ups = nn.ModuleList()
|
122 |
-
for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])):
|
123 |
-
c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
|
124 |
-
self.ups.append(weight_norm(
|
125 |
-
ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
|
126 |
-
if h['use_pitch_embed']:
|
127 |
-
if i + 1 < len(h['upsample_rates']):
|
128 |
-
stride_f0 = np.prod(h['upsample_rates'][i + 1:])
|
129 |
-
self.noise_convs.append(Conv1d(
|
130 |
-
1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
|
131 |
-
else:
|
132 |
-
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
|
133 |
-
|
134 |
-
self.resblocks = nn.ModuleList()
|
135 |
-
for i in range(len(self.ups)):
|
136 |
-
ch = h['upsample_initial_channel'] // (2 ** (i + 1))
|
137 |
-
for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])):
|
138 |
-
self.resblocks.append(resblock(h, ch, k, d))
|
139 |
-
|
140 |
-
self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
|
141 |
-
self.ups.apply(init_weights)
|
142 |
-
self.conv_post.apply(init_weights)
|
143 |
-
|
144 |
-
def forward(self, x, f0=None):
|
145 |
-
if f0 is not None:
|
146 |
-
# harmonic-source signal, noise-source signal, uv flag
|
147 |
-
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
|
148 |
-
har_source, noi_source, uv = self.m_source(f0)
|
149 |
-
har_source = har_source.transpose(1, 2)
|
150 |
-
|
151 |
-
x = self.conv_pre(x)
|
152 |
-
for i in range(self.num_upsamples):
|
153 |
-
x = F.leaky_relu(x, LRELU_SLOPE)
|
154 |
-
x = self.ups[i](x)
|
155 |
-
if f0 is not None:
|
156 |
-
x_source = self.noise_convs[i](har_source)
|
157 |
-
x = x + x_source
|
158 |
-
xs = None
|
159 |
-
for j in range(self.num_kernels):
|
160 |
-
if xs is None:
|
161 |
-
xs = self.resblocks[i * self.num_kernels + j](x)
|
162 |
-
else:
|
163 |
-
xs += self.resblocks[i * self.num_kernels + j](x)
|
164 |
-
x = xs / self.num_kernels
|
165 |
-
x = F.leaky_relu(x)
|
166 |
-
x = self.conv_post(x)
|
167 |
-
x = torch.tanh(x)
|
168 |
-
|
169 |
-
return x
|
170 |
-
|
171 |
-
def remove_weight_norm(self):
|
172 |
-
print('Removing weight norm...')
|
173 |
-
for l in self.ups:
|
174 |
-
remove_weight_norm(l)
|
175 |
-
for l in self.resblocks:
|
176 |
-
l.remove_weight_norm()
|
177 |
-
remove_weight_norm(self.conv_pre)
|
178 |
-
remove_weight_norm(self.conv_post)
|
179 |
-
|
180 |
-
|
181 |
-
class DiscriminatorP(torch.nn.Module):
|
182 |
-
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
|
183 |
-
super(DiscriminatorP, self).__init__()
|
184 |
-
self.use_cond = use_cond
|
185 |
-
if use_cond:
|
186 |
-
from utils.hparams import hparams
|
187 |
-
t = hparams['hop_size']
|
188 |
-
self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
|
189 |
-
c_in = 2
|
190 |
-
|
191 |
-
self.period = period
|
192 |
-
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
193 |
-
self.convs = nn.ModuleList([
|
194 |
-
norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
195 |
-
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
196 |
-
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
197 |
-
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
198 |
-
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
199 |
-
])
|
200 |
-
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
201 |
-
|
202 |
-
def forward(self, x, mel):
|
203 |
-
fmap = []
|
204 |
-
if self.use_cond:
|
205 |
-
x_mel = self.cond_net(mel)
|
206 |
-
x = torch.cat([x_mel, x], 1)
|
207 |
-
# 1d to 2d
|
208 |
-
b, c, t = x.shape
|
209 |
-
if t % self.period != 0: # pad first
|
210 |
-
n_pad = self.period - (t % self.period)
|
211 |
-
x = F.pad(x, (0, n_pad), "reflect")
|
212 |
-
t = t + n_pad
|
213 |
-
x = x.view(b, c, t // self.period, self.period)
|
214 |
-
|
215 |
-
for l in self.convs:
|
216 |
-
x = l(x)
|
217 |
-
x = F.leaky_relu(x, LRELU_SLOPE)
|
218 |
-
fmap.append(x)
|
219 |
-
x = self.conv_post(x)
|
220 |
-
fmap.append(x)
|
221 |
-
x = torch.flatten(x, 1, -1)
|
222 |
-
|
223 |
-
return x, fmap
|
224 |
-
|
225 |
-
|
226 |
-
class MultiPeriodDiscriminator(torch.nn.Module):
|
227 |
-
def __init__(self, use_cond=False, c_in=1):
|
228 |
-
super(MultiPeriodDiscriminator, self).__init__()
|
229 |
-
self.discriminators = nn.ModuleList([
|
230 |
-
DiscriminatorP(2, use_cond=use_cond, c_in=c_in),
|
231 |
-
DiscriminatorP(3, use_cond=use_cond, c_in=c_in),
|
232 |
-
DiscriminatorP(5, use_cond=use_cond, c_in=c_in),
|
233 |
-
DiscriminatorP(7, use_cond=use_cond, c_in=c_in),
|
234 |
-
DiscriminatorP(11, use_cond=use_cond, c_in=c_in),
|
235 |
-
])
|
236 |
-
|
237 |
-
def forward(self, y, y_hat, mel=None):
|
238 |
-
y_d_rs = []
|
239 |
-
y_d_gs = []
|
240 |
-
fmap_rs = []
|
241 |
-
fmap_gs = []
|
242 |
-
for i, d in enumerate(self.discriminators):
|
243 |
-
y_d_r, fmap_r = d(y, mel)
|
244 |
-
y_d_g, fmap_g = d(y_hat, mel)
|
245 |
-
y_d_rs.append(y_d_r)
|
246 |
-
fmap_rs.append(fmap_r)
|
247 |
-
y_d_gs.append(y_d_g)
|
248 |
-
fmap_gs.append(fmap_g)
|
249 |
-
|
250 |
-
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
251 |
-
|
252 |
-
|
253 |
-
class DiscriminatorS(torch.nn.Module):
|
254 |
-
def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
|
255 |
-
super(DiscriminatorS, self).__init__()
|
256 |
-
self.use_cond = use_cond
|
257 |
-
if use_cond:
|
258 |
-
t = np.prod(upsample_rates)
|
259 |
-
self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
|
260 |
-
c_in = 2
|
261 |
-
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
262 |
-
self.convs = nn.ModuleList([
|
263 |
-
norm_f(Conv1d(c_in, 128, 15, 1, padding=7)),
|
264 |
-
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
265 |
-
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
266 |
-
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
267 |
-
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
268 |
-
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
269 |
-
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
270 |
-
])
|
271 |
-
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
272 |
-
|
273 |
-
def forward(self, x, mel):
|
274 |
-
if self.use_cond:
|
275 |
-
x_mel = self.cond_net(mel)
|
276 |
-
x = torch.cat([x_mel, x], 1)
|
277 |
-
fmap = []
|
278 |
-
for l in self.convs:
|
279 |
-
x = l(x)
|
280 |
-
x = F.leaky_relu(x, LRELU_SLOPE)
|
281 |
-
fmap.append(x)
|
282 |
-
x = self.conv_post(x)
|
283 |
-
fmap.append(x)
|
284 |
-
x = torch.flatten(x, 1, -1)
|
285 |
-
|
286 |
-
return x, fmap
|
287 |
-
|
288 |
-
|
289 |
-
class MultiScaleDiscriminator(torch.nn.Module):
|
290 |
-
def __init__(self, use_cond=False, c_in=1):
|
291 |
-
super(MultiScaleDiscriminator, self).__init__()
|
292 |
-
from utils.hparams import hparams
|
293 |
-
self.discriminators = nn.ModuleList([
|
294 |
-
DiscriminatorS(use_spectral_norm=True, use_cond=use_cond,
|
295 |
-
upsample_rates=[4, 4, hparams['hop_size'] // 16],
|
296 |
-
c_in=c_in),
|
297 |
-
DiscriminatorS(use_cond=use_cond,
|
298 |
-
upsample_rates=[4, 4, hparams['hop_size'] // 32],
|
299 |
-
c_in=c_in),
|
300 |
-
DiscriminatorS(use_cond=use_cond,
|
301 |
-
upsample_rates=[4, 4, hparams['hop_size'] // 64],
|
302 |
-
c_in=c_in),
|
303 |
-
])
|
304 |
-
self.meanpools = nn.ModuleList([
|
305 |
-
AvgPool1d(4, 2, padding=1),
|
306 |
-
AvgPool1d(4, 2, padding=1)
|
307 |
-
])
|
308 |
-
|
309 |
-
def forward(self, y, y_hat, mel=None):
|
310 |
-
y_d_rs = []
|
311 |
-
y_d_gs = []
|
312 |
-
fmap_rs = []
|
313 |
-
fmap_gs = []
|
314 |
-
for i, d in enumerate(self.discriminators):
|
315 |
-
if i != 0:
|
316 |
-
y = self.meanpools[i - 1](y)
|
317 |
-
y_hat = self.meanpools[i - 1](y_hat)
|
318 |
-
y_d_r, fmap_r = d(y, mel)
|
319 |
-
y_d_g, fmap_g = d(y_hat, mel)
|
320 |
-
y_d_rs.append(y_d_r)
|
321 |
-
fmap_rs.append(fmap_r)
|
322 |
-
y_d_gs.append(y_d_g)
|
323 |
-
fmap_gs.append(fmap_g)
|
324 |
-
|
325 |
-
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
326 |
-
|
327 |
-
|
328 |
-
def feature_loss(fmap_r, fmap_g):
|
329 |
-
loss = 0
|
330 |
-
for dr, dg in zip(fmap_r, fmap_g):
|
331 |
-
for rl, gl in zip(dr, dg):
|
332 |
-
loss += torch.mean(torch.abs(rl - gl))
|
333 |
-
|
334 |
-
return loss * 2
|
335 |
-
|
336 |
-
|
337 |
-
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
338 |
-
r_losses = 0
|
339 |
-
g_losses = 0
|
340 |
-
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
341 |
-
r_loss = torch.mean((1 - dr) ** 2)
|
342 |
-
g_loss = torch.mean(dg ** 2)
|
343 |
-
r_losses += r_loss
|
344 |
-
g_losses += g_loss
|
345 |
-
r_losses = r_losses / len(disc_real_outputs)
|
346 |
-
g_losses = g_losses / len(disc_real_outputs)
|
347 |
-
return r_losses, g_losses
|
348 |
-
|
349 |
-
|
350 |
-
def cond_discriminator_loss(outputs):
|
351 |
-
loss = 0
|
352 |
-
for dg in outputs:
|
353 |
-
g_loss = torch.mean(dg ** 2)
|
354 |
-
loss += g_loss
|
355 |
-
loss = loss / len(outputs)
|
356 |
-
return loss
|
357 |
-
|
358 |
-
|
359 |
-
def generator_loss(disc_outputs):
|
360 |
-
loss = 0
|
361 |
-
for dg in disc_outputs:
|
362 |
-
l = torch.mean((1 - dg) ** 2)
|
363 |
-
loss += l
|
364 |
-
loss = loss / len(disc_outputs)
|
365 |
-
return loss
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
5 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
6 |
+
|
7 |
+
from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork
|
8 |
+
from modules.parallel_wavegan.models.source import SourceModuleHnNSF
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
LRELU_SLOPE = 0.1
|
12 |
+
|
13 |
+
|
14 |
+
def init_weights(m, mean=0.0, std=0.01):
|
15 |
+
classname = m.__class__.__name__
|
16 |
+
if classname.find("Conv") != -1:
|
17 |
+
m.weight.data.normal_(mean, std)
|
18 |
+
|
19 |
+
|
20 |
+
def apply_weight_norm(m):
|
21 |
+
classname = m.__class__.__name__
|
22 |
+
if classname.find("Conv") != -1:
|
23 |
+
weight_norm(m)
|
24 |
+
|
25 |
+
|
26 |
+
def get_padding(kernel_size, dilation=1):
|
27 |
+
return int((kernel_size * dilation - dilation) / 2)
|
28 |
+
|
29 |
+
|
30 |
+
class ResBlock1(torch.nn.Module):
|
31 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
32 |
+
super(ResBlock1, self).__init__()
|
33 |
+
self.h = h
|
34 |
+
self.convs1 = nn.ModuleList([
|
35 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
36 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
37 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
38 |
+
padding=get_padding(kernel_size, dilation[1]))),
|
39 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
|
40 |
+
padding=get_padding(kernel_size, dilation[2])))
|
41 |
+
])
|
42 |
+
self.convs1.apply(init_weights)
|
43 |
+
|
44 |
+
self.convs2 = nn.ModuleList([
|
45 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
46 |
+
padding=get_padding(kernel_size, 1))),
|
47 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
48 |
+
padding=get_padding(kernel_size, 1))),
|
49 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
|
50 |
+
padding=get_padding(kernel_size, 1)))
|
51 |
+
])
|
52 |
+
self.convs2.apply(init_weights)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
for c1, c2 in zip(self.convs1, self.convs2):
|
56 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
57 |
+
xt = c1(xt)
|
58 |
+
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
59 |
+
xt = c2(xt)
|
60 |
+
x = xt + x
|
61 |
+
return x
|
62 |
+
|
63 |
+
def remove_weight_norm(self):
|
64 |
+
for l in self.convs1:
|
65 |
+
remove_weight_norm(l)
|
66 |
+
for l in self.convs2:
|
67 |
+
remove_weight_norm(l)
|
68 |
+
|
69 |
+
|
70 |
+
class ResBlock2(torch.nn.Module):
|
71 |
+
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
|
72 |
+
super(ResBlock2, self).__init__()
|
73 |
+
self.h = h
|
74 |
+
self.convs = nn.ModuleList([
|
75 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
|
76 |
+
padding=get_padding(kernel_size, dilation[0]))),
|
77 |
+
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
|
78 |
+
padding=get_padding(kernel_size, dilation[1])))
|
79 |
+
])
|
80 |
+
self.convs.apply(init_weights)
|
81 |
+
|
82 |
+
def forward(self, x):
|
83 |
+
for c in self.convs:
|
84 |
+
xt = F.leaky_relu(x, LRELU_SLOPE)
|
85 |
+
xt = c(xt)
|
86 |
+
x = xt + x
|
87 |
+
return x
|
88 |
+
|
89 |
+
def remove_weight_norm(self):
|
90 |
+
for l in self.convs:
|
91 |
+
remove_weight_norm(l)
|
92 |
+
|
93 |
+
|
94 |
+
class Conv1d1x1(Conv1d):
|
95 |
+
"""1x1 Conv1d with customized initialization."""
|
96 |
+
|
97 |
+
def __init__(self, in_channels, out_channels, bias):
|
98 |
+
"""Initialize 1x1 Conv1d module."""
|
99 |
+
super(Conv1d1x1, self).__init__(in_channels, out_channels,
|
100 |
+
kernel_size=1, padding=0,
|
101 |
+
dilation=1, bias=bias)
|
102 |
+
|
103 |
+
|
104 |
+
class HifiGanGenerator(torch.nn.Module):
|
105 |
+
def __init__(self, h, c_out=1):
|
106 |
+
super(HifiGanGenerator, self).__init__()
|
107 |
+
self.h = h
|
108 |
+
self.num_kernels = len(h['resblock_kernel_sizes'])
|
109 |
+
self.num_upsamples = len(h['upsample_rates'])
|
110 |
+
|
111 |
+
if h['use_pitch_embed']:
|
112 |
+
self.harmonic_num = 8
|
113 |
+
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates']))
|
114 |
+
self.m_source = SourceModuleHnNSF(
|
115 |
+
sampling_rate=h['audio_sample_rate'],
|
116 |
+
harmonic_num=self.harmonic_num)
|
117 |
+
self.noise_convs = nn.ModuleList()
|
118 |
+
self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
|
119 |
+
resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2
|
120 |
+
|
121 |
+
self.ups = nn.ModuleList()
|
122 |
+
for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])):
|
123 |
+
c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
|
124 |
+
self.ups.append(weight_norm(
|
125 |
+
ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
|
126 |
+
if h['use_pitch_embed']:
|
127 |
+
if i + 1 < len(h['upsample_rates']):
|
128 |
+
stride_f0 = np.prod(h['upsample_rates'][i + 1:])
|
129 |
+
self.noise_convs.append(Conv1d(
|
130 |
+
1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
|
131 |
+
else:
|
132 |
+
self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
|
133 |
+
|
134 |
+
self.resblocks = nn.ModuleList()
|
135 |
+
for i in range(len(self.ups)):
|
136 |
+
ch = h['upsample_initial_channel'] // (2 ** (i + 1))
|
137 |
+
for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])):
|
138 |
+
self.resblocks.append(resblock(h, ch, k, d))
|
139 |
+
|
140 |
+
self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
|
141 |
+
self.ups.apply(init_weights)
|
142 |
+
self.conv_post.apply(init_weights)
|
143 |
+
|
144 |
+
def forward(self, x, f0=None):
|
145 |
+
if f0 is not None:
|
146 |
+
# harmonic-source signal, noise-source signal, uv flag
|
147 |
+
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
|
148 |
+
har_source, noi_source, uv = self.m_source(f0)
|
149 |
+
har_source = har_source.transpose(1, 2)
|
150 |
+
|
151 |
+
x = self.conv_pre(x)
|
152 |
+
for i in range(self.num_upsamples):
|
153 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
154 |
+
x = self.ups[i](x)
|
155 |
+
if f0 is not None:
|
156 |
+
x_source = self.noise_convs[i](har_source)
|
157 |
+
x = x + x_source
|
158 |
+
xs = None
|
159 |
+
for j in range(self.num_kernels):
|
160 |
+
if xs is None:
|
161 |
+
xs = self.resblocks[i * self.num_kernels + j](x)
|
162 |
+
else:
|
163 |
+
xs += self.resblocks[i * self.num_kernels + j](x)
|
164 |
+
x = xs / self.num_kernels
|
165 |
+
x = F.leaky_relu(x)
|
166 |
+
x = self.conv_post(x)
|
167 |
+
x = torch.tanh(x)
|
168 |
+
|
169 |
+
return x
|
170 |
+
|
171 |
+
def remove_weight_norm(self):
|
172 |
+
print('Removing weight norm...')
|
173 |
+
for l in self.ups:
|
174 |
+
remove_weight_norm(l)
|
175 |
+
for l in self.resblocks:
|
176 |
+
l.remove_weight_norm()
|
177 |
+
remove_weight_norm(self.conv_pre)
|
178 |
+
remove_weight_norm(self.conv_post)
|
179 |
+
|
180 |
+
|
181 |
+
class DiscriminatorP(torch.nn.Module):
|
182 |
+
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
|
183 |
+
super(DiscriminatorP, self).__init__()
|
184 |
+
self.use_cond = use_cond
|
185 |
+
if use_cond:
|
186 |
+
from utils.hparams import hparams
|
187 |
+
t = hparams['hop_size']
|
188 |
+
self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
|
189 |
+
c_in = 2
|
190 |
+
|
191 |
+
self.period = period
|
192 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
193 |
+
self.convs = nn.ModuleList([
|
194 |
+
norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
195 |
+
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
196 |
+
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
197 |
+
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
|
198 |
+
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
|
199 |
+
])
|
200 |
+
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
|
201 |
+
|
202 |
+
def forward(self, x, mel):
|
203 |
+
fmap = []
|
204 |
+
if self.use_cond:
|
205 |
+
x_mel = self.cond_net(mel)
|
206 |
+
x = torch.cat([x_mel, x], 1)
|
207 |
+
# 1d to 2d
|
208 |
+
b, c, t = x.shape
|
209 |
+
if t % self.period != 0: # pad first
|
210 |
+
n_pad = self.period - (t % self.period)
|
211 |
+
x = F.pad(x, (0, n_pad), "reflect")
|
212 |
+
t = t + n_pad
|
213 |
+
x = x.view(b, c, t // self.period, self.period)
|
214 |
+
|
215 |
+
for l in self.convs:
|
216 |
+
x = l(x)
|
217 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
218 |
+
fmap.append(x)
|
219 |
+
x = self.conv_post(x)
|
220 |
+
fmap.append(x)
|
221 |
+
x = torch.flatten(x, 1, -1)
|
222 |
+
|
223 |
+
return x, fmap
|
224 |
+
|
225 |
+
|
226 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
|
227 |
+
def __init__(self, use_cond=False, c_in=1):
|
228 |
+
super(MultiPeriodDiscriminator, self).__init__()
|
229 |
+
self.discriminators = nn.ModuleList([
|
230 |
+
DiscriminatorP(2, use_cond=use_cond, c_in=c_in),
|
231 |
+
DiscriminatorP(3, use_cond=use_cond, c_in=c_in),
|
232 |
+
DiscriminatorP(5, use_cond=use_cond, c_in=c_in),
|
233 |
+
DiscriminatorP(7, use_cond=use_cond, c_in=c_in),
|
234 |
+
DiscriminatorP(11, use_cond=use_cond, c_in=c_in),
|
235 |
+
])
|
236 |
+
|
237 |
+
def forward(self, y, y_hat, mel=None):
|
238 |
+
y_d_rs = []
|
239 |
+
y_d_gs = []
|
240 |
+
fmap_rs = []
|
241 |
+
fmap_gs = []
|
242 |
+
for i, d in enumerate(self.discriminators):
|
243 |
+
y_d_r, fmap_r = d(y, mel)
|
244 |
+
y_d_g, fmap_g = d(y_hat, mel)
|
245 |
+
y_d_rs.append(y_d_r)
|
246 |
+
fmap_rs.append(fmap_r)
|
247 |
+
y_d_gs.append(y_d_g)
|
248 |
+
fmap_gs.append(fmap_g)
|
249 |
+
|
250 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
251 |
+
|
252 |
+
|
253 |
+
class DiscriminatorS(torch.nn.Module):
|
254 |
+
def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
|
255 |
+
super(DiscriminatorS, self).__init__()
|
256 |
+
self.use_cond = use_cond
|
257 |
+
if use_cond:
|
258 |
+
t = np.prod(upsample_rates)
|
259 |
+
self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
|
260 |
+
c_in = 2
|
261 |
+
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
|
262 |
+
self.convs = nn.ModuleList([
|
263 |
+
norm_f(Conv1d(c_in, 128, 15, 1, padding=7)),
|
264 |
+
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
|
265 |
+
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
|
266 |
+
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
|
267 |
+
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
|
268 |
+
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
|
269 |
+
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
|
270 |
+
])
|
271 |
+
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
|
272 |
+
|
273 |
+
def forward(self, x, mel):
|
274 |
+
if self.use_cond:
|
275 |
+
x_mel = self.cond_net(mel)
|
276 |
+
x = torch.cat([x_mel, x], 1)
|
277 |
+
fmap = []
|
278 |
+
for l in self.convs:
|
279 |
+
x = l(x)
|
280 |
+
x = F.leaky_relu(x, LRELU_SLOPE)
|
281 |
+
fmap.append(x)
|
282 |
+
x = self.conv_post(x)
|
283 |
+
fmap.append(x)
|
284 |
+
x = torch.flatten(x, 1, -1)
|
285 |
+
|
286 |
+
return x, fmap
|
287 |
+
|
288 |
+
|
289 |
+
class MultiScaleDiscriminator(torch.nn.Module):
|
290 |
+
def __init__(self, use_cond=False, c_in=1):
|
291 |
+
super(MultiScaleDiscriminator, self).__init__()
|
292 |
+
from utils.hparams import hparams
|
293 |
+
self.discriminators = nn.ModuleList([
|
294 |
+
DiscriminatorS(use_spectral_norm=True, use_cond=use_cond,
|
295 |
+
upsample_rates=[4, 4, hparams['hop_size'] // 16],
|
296 |
+
c_in=c_in),
|
297 |
+
DiscriminatorS(use_cond=use_cond,
|
298 |
+
upsample_rates=[4, 4, hparams['hop_size'] // 32],
|
299 |
+
c_in=c_in),
|
300 |
+
DiscriminatorS(use_cond=use_cond,
|
301 |
+
upsample_rates=[4, 4, hparams['hop_size'] // 64],
|
302 |
+
c_in=c_in),
|
303 |
+
])
|
304 |
+
self.meanpools = nn.ModuleList([
|
305 |
+
AvgPool1d(4, 2, padding=1),
|
306 |
+
AvgPool1d(4, 2, padding=1)
|
307 |
+
])
|
308 |
+
|
309 |
+
def forward(self, y, y_hat, mel=None):
|
310 |
+
y_d_rs = []
|
311 |
+
y_d_gs = []
|
312 |
+
fmap_rs = []
|
313 |
+
fmap_gs = []
|
314 |
+
for i, d in enumerate(self.discriminators):
|
315 |
+
if i != 0:
|
316 |
+
y = self.meanpools[i - 1](y)
|
317 |
+
y_hat = self.meanpools[i - 1](y_hat)
|
318 |
+
y_d_r, fmap_r = d(y, mel)
|
319 |
+
y_d_g, fmap_g = d(y_hat, mel)
|
320 |
+
y_d_rs.append(y_d_r)
|
321 |
+
fmap_rs.append(fmap_r)
|
322 |
+
y_d_gs.append(y_d_g)
|
323 |
+
fmap_gs.append(fmap_g)
|
324 |
+
|
325 |
+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
326 |
+
|
327 |
+
|
328 |
+
def feature_loss(fmap_r, fmap_g):
|
329 |
+
loss = 0
|
330 |
+
for dr, dg in zip(fmap_r, fmap_g):
|
331 |
+
for rl, gl in zip(dr, dg):
|
332 |
+
loss += torch.mean(torch.abs(rl - gl))
|
333 |
+
|
334 |
+
return loss * 2
|
335 |
+
|
336 |
+
|
337 |
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
338 |
+
r_losses = 0
|
339 |
+
g_losses = 0
|
340 |
+
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
341 |
+
r_loss = torch.mean((1 - dr) ** 2)
|
342 |
+
g_loss = torch.mean(dg ** 2)
|
343 |
+
r_losses += r_loss
|
344 |
+
g_losses += g_loss
|
345 |
+
r_losses = r_losses / len(disc_real_outputs)
|
346 |
+
g_losses = g_losses / len(disc_real_outputs)
|
347 |
+
return r_losses, g_losses
|
348 |
+
|
349 |
+
|
350 |
+
def cond_discriminator_loss(outputs):
|
351 |
+
loss = 0
|
352 |
+
for dg in outputs:
|
353 |
+
g_loss = torch.mean(dg ** 2)
|
354 |
+
loss += g_loss
|
355 |
+
loss = loss / len(outputs)
|
356 |
+
return loss
|
357 |
+
|
358 |
+
|
359 |
+
def generator_loss(disc_outputs):
|
360 |
+
loss = 0
|
361 |
+
for dg in disc_outputs:
|
362 |
+
l = torch.mean((1 - dg) ** 2)
|
363 |
+
loss += l
|
364 |
+
loss = loss / len(disc_outputs)
|
365 |
+
return loss
|
modules/hifigan/mel_utils.py
CHANGED
@@ -1,80 +1,80 @@
|
|
1 |
-
import numpy as np
|
2 |
-
import torch
|
3 |
-
import torch.utils.data
|
4 |
-
from librosa.filters import mel as librosa_mel_fn
|
5 |
-
from scipy.io.wavfile import read
|
6 |
-
|
7 |
-
MAX_WAV_VALUE = 32768.0
|
8 |
-
|
9 |
-
|
10 |
-
def load_wav(full_path):
|
11 |
-
sampling_rate, data = read(full_path)
|
12 |
-
return data, sampling_rate
|
13 |
-
|
14 |
-
|
15 |
-
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
16 |
-
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
17 |
-
|
18 |
-
|
19 |
-
def dynamic_range_decompression(x, C=1):
|
20 |
-
return np.exp(x) / C
|
21 |
-
|
22 |
-
|
23 |
-
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
24 |
-
return torch.log(torch.clamp(x, min=clip_val) * C)
|
25 |
-
|
26 |
-
|
27 |
-
def dynamic_range_decompression_torch(x, C=1):
|
28 |
-
return torch.exp(x) / C
|
29 |
-
|
30 |
-
|
31 |
-
def spectral_normalize_torch(magnitudes):
|
32 |
-
output = dynamic_range_compression_torch(magnitudes)
|
33 |
-
return output
|
34 |
-
|
35 |
-
|
36 |
-
def spectral_de_normalize_torch(magnitudes):
|
37 |
-
output = dynamic_range_decompression_torch(magnitudes)
|
38 |
-
return output
|
39 |
-
|
40 |
-
|
41 |
-
mel_basis = {}
|
42 |
-
hann_window = {}
|
43 |
-
|
44 |
-
|
45 |
-
def mel_spectrogram(y, hparams, center=False, complex=False):
|
46 |
-
# hop_size: 512 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
|
47 |
-
# win_size: 2048 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
|
48 |
-
# fmin: 55 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
|
49 |
-
# fmax: 10000 # To be increased/reduced depending on data.
|
50 |
-
# fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter
|
51 |
-
# n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax,
|
52 |
-
n_fft = hparams['fft_size']
|
53 |
-
num_mels = hparams['audio_num_mel_bins']
|
54 |
-
sampling_rate = hparams['audio_sample_rate']
|
55 |
-
hop_size = hparams['hop_size']
|
56 |
-
win_size = hparams['win_size']
|
57 |
-
fmin = hparams['fmin']
|
58 |
-
fmax = hparams['fmax']
|
59 |
-
y = y.clamp(min=-1., max=1.)
|
60 |
-
global mel_basis, hann_window
|
61 |
-
if fmax not in mel_basis:
|
62 |
-
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
63 |
-
mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
64 |
-
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
65 |
-
|
66 |
-
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
67 |
-
mode='reflect')
|
68 |
-
y = y.squeeze(1)
|
69 |
-
|
70 |
-
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
|
71 |
-
center=center, pad_mode='reflect', normalized=False, onesided=True)
|
72 |
-
|
73 |
-
if not complex:
|
74 |
-
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
75 |
-
spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
|
76 |
-
spec = spectral_normalize_torch(spec)
|
77 |
-
else:
|
78 |
-
B, C, T, _ = spec.shape
|
79 |
-
spec = spec.transpose(1, 2) # [B, T, n_fft, 2]
|
80 |
-
return spec
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import torch.utils.data
|
4 |
+
from librosa.filters import mel as librosa_mel_fn
|
5 |
+
from scipy.io.wavfile import read
|
6 |
+
|
7 |
+
MAX_WAV_VALUE = 32768.0
|
8 |
+
|
9 |
+
|
10 |
+
def load_wav(full_path):
|
11 |
+
sampling_rate, data = read(full_path)
|
12 |
+
return data, sampling_rate
|
13 |
+
|
14 |
+
|
15 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
|
16 |
+
return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
|
17 |
+
|
18 |
+
|
19 |
+
def dynamic_range_decompression(x, C=1):
|
20 |
+
return np.exp(x) / C
|
21 |
+
|
22 |
+
|
23 |
+
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
24 |
+
return torch.log(torch.clamp(x, min=clip_val) * C)
|
25 |
+
|
26 |
+
|
27 |
+
def dynamic_range_decompression_torch(x, C=1):
|
28 |
+
return torch.exp(x) / C
|
29 |
+
|
30 |
+
|
31 |
+
def spectral_normalize_torch(magnitudes):
|
32 |
+
output = dynamic_range_compression_torch(magnitudes)
|
33 |
+
return output
|
34 |
+
|
35 |
+
|
36 |
+
def spectral_de_normalize_torch(magnitudes):
|
37 |
+
output = dynamic_range_decompression_torch(magnitudes)
|
38 |
+
return output
|
39 |
+
|
40 |
+
|
41 |
+
mel_basis = {}
|
42 |
+
hann_window = {}
|
43 |
+
|
44 |
+
|
45 |
+
def mel_spectrogram(y, hparams, center=False, complex=False):
|
46 |
+
# hop_size: 512 # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
|
47 |
+
# win_size: 2048 # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
|
48 |
+
# fmin: 55 # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
|
49 |
+
# fmax: 10000 # To be increased/reduced depending on data.
|
50 |
+
# fft_size: 2048 # Extra window size is filled with 0 paddings to match this parameter
|
51 |
+
# n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax,
|
52 |
+
n_fft = hparams['fft_size']
|
53 |
+
num_mels = hparams['audio_num_mel_bins']
|
54 |
+
sampling_rate = hparams['audio_sample_rate']
|
55 |
+
hop_size = hparams['hop_size']
|
56 |
+
win_size = hparams['win_size']
|
57 |
+
fmin = hparams['fmin']
|
58 |
+
fmax = hparams['fmax']
|
59 |
+
y = y.clamp(min=-1., max=1.)
|
60 |
+
global mel_basis, hann_window
|
61 |
+
if fmax not in mel_basis:
|
62 |
+
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
|
63 |
+
mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
64 |
+
hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
|
65 |
+
|
66 |
+
y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
67 |
+
mode='reflect')
|
68 |
+
y = y.squeeze(1)
|
69 |
+
|
70 |
+
spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
|
71 |
+
center=center, pad_mode='reflect', normalized=False, onesided=True)
|
72 |
+
|
73 |
+
if not complex:
|
74 |
+
spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
|
75 |
+
spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
|
76 |
+
spec = spectral_normalize_torch(spec)
|
77 |
+
else:
|
78 |
+
B, C, T, _ = spec.shape
|
79 |
+
spec = spec.transpose(1, 2) # [B, T, n_fft, 2]
|
80 |
+
return spec
|
modules/parallel_wavegan/models/parallel_wavegan.py
CHANGED
@@ -1,434 +1,434 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
|
3 |
-
# Copyright 2019 Tomoki Hayashi
|
4 |
-
# MIT License (https://opensource.org/licenses/MIT)
|
5 |
-
|
6 |
-
"""Parallel WaveGAN Modules."""
|
7 |
-
|
8 |
-
import logging
|
9 |
-
import math
|
10 |
-
|
11 |
-
import torch
|
12 |
-
from torch import nn
|
13 |
-
|
14 |
-
from modules.parallel_wavegan.layers import Conv1d
|
15 |
-
from modules.parallel_wavegan.layers import Conv1d1x1
|
16 |
-
from modules.parallel_wavegan.layers import ResidualBlock
|
17 |
-
from modules.parallel_wavegan.layers import upsample
|
18 |
-
from modules.parallel_wavegan import models
|
19 |
-
|
20 |
-
|
21 |
-
class ParallelWaveGANGenerator(torch.nn.Module):
|
22 |
-
"""Parallel WaveGAN Generator module."""
|
23 |
-
|
24 |
-
def __init__(self,
|
25 |
-
in_channels=1,
|
26 |
-
out_channels=1,
|
27 |
-
kernel_size=3,
|
28 |
-
layers=30,
|
29 |
-
stacks=3,
|
30 |
-
residual_channels=64,
|
31 |
-
gate_channels=128,
|
32 |
-
skip_channels=64,
|
33 |
-
aux_channels=80,
|
34 |
-
aux_context_window=2,
|
35 |
-
dropout=0.0,
|
36 |
-
bias=True,
|
37 |
-
use_weight_norm=True,
|
38 |
-
use_causal_conv=False,
|
39 |
-
upsample_conditional_features=True,
|
40 |
-
upsample_net="ConvInUpsampleNetwork",
|
41 |
-
upsample_params={"upsample_scales": [4, 4, 4, 4]},
|
42 |
-
use_pitch_embed=False,
|
43 |
-
):
|
44 |
-
"""Initialize Parallel WaveGAN Generator module.
|
45 |
-
|
46 |
-
Args:
|
47 |
-
in_channels (int): Number of input channels.
|
48 |
-
out_channels (int): Number of output channels.
|
49 |
-
kernel_size (int): Kernel size of dilated convolution.
|
50 |
-
layers (int): Number of residual block layers.
|
51 |
-
stacks (int): Number of stacks i.e., dilation cycles.
|
52 |
-
residual_channels (int): Number of channels in residual conv.
|
53 |
-
gate_channels (int): Number of channels in gated conv.
|
54 |
-
skip_channels (int): Number of channels in skip conv.
|
55 |
-
aux_channels (int): Number of channels for auxiliary feature conv.
|
56 |
-
aux_context_window (int): Context window size for auxiliary feature.
|
57 |
-
dropout (float): Dropout rate. 0.0 means no dropout applied.
|
58 |
-
bias (bool): Whether to use bias parameter in conv layer.
|
59 |
-
use_weight_norm (bool): Whether to use weight norm.
|
60 |
-
If set to true, it will be applied to all of the conv layers.
|
61 |
-
use_causal_conv (bool): Whether to use causal structure.
|
62 |
-
upsample_conditional_features (bool): Whether to use upsampling network.
|
63 |
-
upsample_net (str): Upsampling network architecture.
|
64 |
-
upsample_params (dict): Upsampling network parameters.
|
65 |
-
|
66 |
-
"""
|
67 |
-
super(ParallelWaveGANGenerator, self).__init__()
|
68 |
-
self.in_channels = in_channels
|
69 |
-
self.out_channels = out_channels
|
70 |
-
self.aux_channels = aux_channels
|
71 |
-
self.layers = layers
|
72 |
-
self.stacks = stacks
|
73 |
-
self.kernel_size = kernel_size
|
74 |
-
|
75 |
-
# check the number of layers and stacks
|
76 |
-
assert layers % stacks == 0
|
77 |
-
layers_per_stack = layers // stacks
|
78 |
-
|
79 |
-
# define first convolution
|
80 |
-
self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
|
81 |
-
|
82 |
-
# define conv + upsampling network
|
83 |
-
if upsample_conditional_features:
|
84 |
-
upsample_params.update({
|
85 |
-
"use_causal_conv": use_causal_conv,
|
86 |
-
})
|
87 |
-
if upsample_net == "MelGANGenerator":
|
88 |
-
assert aux_context_window == 0
|
89 |
-
upsample_params.update({
|
90 |
-
"use_weight_norm": False, # not to apply twice
|
91 |
-
"use_final_nonlinear_activation": False,
|
92 |
-
})
|
93 |
-
self.upsample_net = getattr(models, upsample_net)(**upsample_params)
|
94 |
-
else:
|
95 |
-
if upsample_net == "ConvInUpsampleNetwork":
|
96 |
-
upsample_params.update({
|
97 |
-
"aux_channels": aux_channels,
|
98 |
-
"aux_context_window": aux_context_window,
|
99 |
-
})
|
100 |
-
self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
|
101 |
-
else:
|
102 |
-
self.upsample_net = None
|
103 |
-
|
104 |
-
# define residual blocks
|
105 |
-
self.conv_layers = torch.nn.ModuleList()
|
106 |
-
for layer in range(layers):
|
107 |
-
dilation = 2 ** (layer % layers_per_stack)
|
108 |
-
conv = ResidualBlock(
|
109 |
-
kernel_size=kernel_size,
|
110 |
-
residual_channels=residual_channels,
|
111 |
-
gate_channels=gate_channels,
|
112 |
-
skip_channels=skip_channels,
|
113 |
-
aux_channels=aux_channels,
|
114 |
-
dilation=dilation,
|
115 |
-
dropout=dropout,
|
116 |
-
bias=bias,
|
117 |
-
use_causal_conv=use_causal_conv,
|
118 |
-
)
|
119 |
-
self.conv_layers += [conv]
|
120 |
-
|
121 |
-
# define output layers
|
122 |
-
self.last_conv_layers = torch.nn.ModuleList([
|
123 |
-
torch.nn.ReLU(inplace=True),
|
124 |
-
Conv1d1x1(skip_channels, skip_channels, bias=True),
|
125 |
-
torch.nn.ReLU(inplace=True),
|
126 |
-
Conv1d1x1(skip_channels, out_channels, bias=True),
|
127 |
-
])
|
128 |
-
|
129 |
-
self.use_pitch_embed = use_pitch_embed
|
130 |
-
if use_pitch_embed:
|
131 |
-
self.pitch_embed = nn.Embedding(300, aux_channels, 0)
|
132 |
-
self.c_proj = nn.Linear(2 * aux_channels, aux_channels)
|
133 |
-
|
134 |
-
# apply weight norm
|
135 |
-
if use_weight_norm:
|
136 |
-
self.apply_weight_norm()
|
137 |
-
|
138 |
-
def forward(self, x, c=None, pitch=None, **kwargs):
|
139 |
-
"""Calculate forward propagation.
|
140 |
-
|
141 |
-
Args:
|
142 |
-
x (Tensor): Input noise signal (B, C_in, T).
|
143 |
-
c (Tensor): Local conditioning auxiliary features (B, C ,T').
|
144 |
-
pitch (Tensor): Local conditioning pitch (B, T').
|
145 |
-
|
146 |
-
Returns:
|
147 |
-
Tensor: Output tensor (B, C_out, T)
|
148 |
-
|
149 |
-
"""
|
150 |
-
# perform upsampling
|
151 |
-
if c is not None and self.upsample_net is not None:
|
152 |
-
if self.use_pitch_embed:
|
153 |
-
p = self.pitch_embed(pitch)
|
154 |
-
c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
|
155 |
-
c = self.upsample_net(c)
|
156 |
-
assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))
|
157 |
-
|
158 |
-
# encode to hidden representation
|
159 |
-
x = self.first_conv(x)
|
160 |
-
skips = 0
|
161 |
-
for f in self.conv_layers:
|
162 |
-
x, h = f(x, c)
|
163 |
-
skips += h
|
164 |
-
skips *= math.sqrt(1.0 / len(self.conv_layers))
|
165 |
-
|
166 |
-
# apply final layers
|
167 |
-
x = skips
|
168 |
-
for f in self.last_conv_layers:
|
169 |
-
x = f(x)
|
170 |
-
|
171 |
-
return x
|
172 |
-
|
173 |
-
def remove_weight_norm(self):
|
174 |
-
"""Remove weight normalization module from all of the layers."""
|
175 |
-
def _remove_weight_norm(m):
|
176 |
-
try:
|
177 |
-
logging.debug(f"Weight norm is removed from {m}.")
|
178 |
-
torch.nn.utils.remove_weight_norm(m)
|
179 |
-
except ValueError: # this module didn't have weight norm
|
180 |
-
return
|
181 |
-
|
182 |
-
self.apply(_remove_weight_norm)
|
183 |
-
|
184 |
-
def apply_weight_norm(self):
|
185 |
-
"""Apply weight normalization module from all of the layers."""
|
186 |
-
def _apply_weight_norm(m):
|
187 |
-
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
|
188 |
-
torch.nn.utils.weight_norm(m)
|
189 |
-
logging.debug(f"Weight norm is applied to {m}.")
|
190 |
-
|
191 |
-
self.apply(_apply_weight_norm)
|
192 |
-
|
193 |
-
@staticmethod
|
194 |
-
def _get_receptive_field_size(layers, stacks, kernel_size,
|
195 |
-
dilation=lambda x: 2 ** x):
|
196 |
-
assert layers % stacks == 0
|
197 |
-
layers_per_cycle = layers // stacks
|
198 |
-
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
199 |
-
return (kernel_size - 1) * sum(dilations) + 1
|
200 |
-
|
201 |
-
@property
|
202 |
-
def receptive_field_size(self):
|
203 |
-
"""Return receptive field size."""
|
204 |
-
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
205 |
-
|
206 |
-
|
207 |
-
class ParallelWaveGANDiscriminator(torch.nn.Module):
|
208 |
-
"""Parallel WaveGAN Discriminator module."""
|
209 |
-
|
210 |
-
def __init__(self,
|
211 |
-
in_channels=1,
|
212 |
-
out_channels=1,
|
213 |
-
kernel_size=3,
|
214 |
-
layers=10,
|
215 |
-
conv_channels=64,
|
216 |
-
dilation_factor=1,
|
217 |
-
nonlinear_activation="LeakyReLU",
|
218 |
-
nonlinear_activation_params={"negative_slope": 0.2},
|
219 |
-
bias=True,
|
220 |
-
use_weight_norm=True,
|
221 |
-
):
|
222 |
-
"""Initialize Parallel WaveGAN Discriminator module.
|
223 |
-
|
224 |
-
Args:
|
225 |
-
in_channels (int): Number of input channels.
|
226 |
-
out_channels (int): Number of output channels.
|
227 |
-
kernel_size (int): Number of output channels.
|
228 |
-
layers (int): Number of conv layers.
|
229 |
-
conv_channels (int): Number of chnn layers.
|
230 |
-
dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
|
231 |
-
the dilation will be 2, 4, 8, ..., and so on.
|
232 |
-
nonlinear_activation (str): Nonlinear function after each conv.
|
233 |
-
nonlinear_activation_params (dict): Nonlinear function parameters
|
234 |
-
bias (bool): Whether to use bias parameter in conv.
|
235 |
-
use_weight_norm (bool) Whether to use weight norm.
|
236 |
-
If set to true, it will be applied to all of the conv layers.
|
237 |
-
|
238 |
-
"""
|
239 |
-
super(ParallelWaveGANDiscriminator, self).__init__()
|
240 |
-
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
|
241 |
-
assert dilation_factor > 0, "Dilation factor must be > 0."
|
242 |
-
self.conv_layers = torch.nn.ModuleList()
|
243 |
-
conv_in_channels = in_channels
|
244 |
-
for i in range(layers - 1):
|
245 |
-
if i == 0:
|
246 |
-
dilation = 1
|
247 |
-
else:
|
248 |
-
dilation = i if dilation_factor == 1 else dilation_factor ** i
|
249 |
-
conv_in_channels = conv_channels
|
250 |
-
padding = (kernel_size - 1) // 2 * dilation
|
251 |
-
conv_layer = [
|
252 |
-
Conv1d(conv_in_channels, conv_channels,
|
253 |
-
kernel_size=kernel_size, padding=padding,
|
254 |
-
dilation=dilation, bias=bias),
|
255 |
-
getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
|
256 |
-
]
|
257 |
-
self.conv_layers += conv_layer
|
258 |
-
padding = (kernel_size - 1) // 2
|
259 |
-
last_conv_layer = Conv1d(
|
260 |
-
conv_in_channels, out_channels,
|
261 |
-
kernel_size=kernel_size, padding=padding, bias=bias)
|
262 |
-
self.conv_layers += [last_conv_layer]
|
263 |
-
|
264 |
-
# apply weight norm
|
265 |
-
if use_weight_norm:
|
266 |
-
self.apply_weight_norm()
|
267 |
-
|
268 |
-
def forward(self, x):
|
269 |
-
"""Calculate forward propagation.
|
270 |
-
|
271 |
-
Args:
|
272 |
-
x (Tensor): Input noise signal (B, 1, T).
|
273 |
-
|
274 |
-
Returns:
|
275 |
-
Tensor: Output tensor (B, 1, T)
|
276 |
-
|
277 |
-
"""
|
278 |
-
for f in self.conv_layers:
|
279 |
-
x = f(x)
|
280 |
-
return x
|
281 |
-
|
282 |
-
def apply_weight_norm(self):
|
283 |
-
"""Apply weight normalization module from all of the layers."""
|
284 |
-
def _apply_weight_norm(m):
|
285 |
-
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
|
286 |
-
torch.nn.utils.weight_norm(m)
|
287 |
-
logging.debug(f"Weight norm is applied to {m}.")
|
288 |
-
|
289 |
-
self.apply(_apply_weight_norm)
|
290 |
-
|
291 |
-
def remove_weight_norm(self):
|
292 |
-
"""Remove weight normalization module from all of the layers."""
|
293 |
-
def _remove_weight_norm(m):
|
294 |
-
try:
|
295 |
-
logging.debug(f"Weight norm is removed from {m}.")
|
296 |
-
torch.nn.utils.remove_weight_norm(m)
|
297 |
-
except ValueError: # this module didn't have weight norm
|
298 |
-
return
|
299 |
-
|
300 |
-
self.apply(_remove_weight_norm)
|
301 |
-
|
302 |
-
|
303 |
-
class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
|
304 |
-
"""Parallel WaveGAN Discriminator module."""
|
305 |
-
|
306 |
-
def __init__(self,
|
307 |
-
in_channels=1,
|
308 |
-
out_channels=1,
|
309 |
-
kernel_size=3,
|
310 |
-
layers=30,
|
311 |
-
stacks=3,
|
312 |
-
residual_channels=64,
|
313 |
-
gate_channels=128,
|
314 |
-
skip_channels=64,
|
315 |
-
dropout=0.0,
|
316 |
-
bias=True,
|
317 |
-
use_weight_norm=True,
|
318 |
-
use_causal_conv=False,
|
319 |
-
nonlinear_activation="LeakyReLU",
|
320 |
-
nonlinear_activation_params={"negative_slope": 0.2},
|
321 |
-
):
|
322 |
-
"""Initialize Parallel WaveGAN Discriminator module.
|
323 |
-
|
324 |
-
Args:
|
325 |
-
in_channels (int): Number of input channels.
|
326 |
-
out_channels (int): Number of output channels.
|
327 |
-
kernel_size (int): Kernel size of dilated convolution.
|
328 |
-
layers (int): Number of residual block layers.
|
329 |
-
stacks (int): Number of stacks i.e., dilation cycles.
|
330 |
-
residual_channels (int): Number of channels in residual conv.
|
331 |
-
gate_channels (int): Number of channels in gated conv.
|
332 |
-
skip_channels (int): Number of channels in skip conv.
|
333 |
-
dropout (float): Dropout rate. 0.0 means no dropout applied.
|
334 |
-
bias (bool): Whether to use bias parameter in conv.
|
335 |
-
use_weight_norm (bool): Whether to use weight norm.
|
336 |
-
If set to true, it will be applied to all of the conv layers.
|
337 |
-
use_causal_conv (bool): Whether to use causal structure.
|
338 |
-
nonlinear_activation_params (dict): Nonlinear function parameters
|
339 |
-
|
340 |
-
"""
|
341 |
-
super(ResidualParallelWaveGANDiscriminator, self).__init__()
|
342 |
-
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
|
343 |
-
|
344 |
-
self.in_channels = in_channels
|
345 |
-
self.out_channels = out_channels
|
346 |
-
self.layers = layers
|
347 |
-
self.stacks = stacks
|
348 |
-
self.kernel_size = kernel_size
|
349 |
-
|
350 |
-
# check the number of layers and stacks
|
351 |
-
assert layers % stacks == 0
|
352 |
-
layers_per_stack = layers // stacks
|
353 |
-
|
354 |
-
# define first convolution
|
355 |
-
self.first_conv = torch.nn.Sequential(
|
356 |
-
Conv1d1x1(in_channels, residual_channels, bias=True),
|
357 |
-
getattr(torch.nn, nonlinear_activation)(
|
358 |
-
inplace=True, **nonlinear_activation_params),
|
359 |
-
)
|
360 |
-
|
361 |
-
# define residual blocks
|
362 |
-
self.conv_layers = torch.nn.ModuleList()
|
363 |
-
for layer in range(layers):
|
364 |
-
dilation = 2 ** (layer % layers_per_stack)
|
365 |
-
conv = ResidualBlock(
|
366 |
-
kernel_size=kernel_size,
|
367 |
-
residual_channels=residual_channels,
|
368 |
-
gate_channels=gate_channels,
|
369 |
-
skip_channels=skip_channels,
|
370 |
-
aux_channels=-1,
|
371 |
-
dilation=dilation,
|
372 |
-
dropout=dropout,
|
373 |
-
bias=bias,
|
374 |
-
use_causal_conv=use_causal_conv,
|
375 |
-
)
|
376 |
-
self.conv_layers += [conv]
|
377 |
-
|
378 |
-
# define output layers
|
379 |
-
self.last_conv_layers = torch.nn.ModuleList([
|
380 |
-
getattr(torch.nn, nonlinear_activation)(
|
381 |
-
inplace=True, **nonlinear_activation_params),
|
382 |
-
Conv1d1x1(skip_channels, skip_channels, bias=True),
|
383 |
-
getattr(torch.nn, nonlinear_activation)(
|
384 |
-
inplace=True, **nonlinear_activation_params),
|
385 |
-
Conv1d1x1(skip_channels, out_channels, bias=True),
|
386 |
-
])
|
387 |
-
|
388 |
-
# apply weight norm
|
389 |
-
if use_weight_norm:
|
390 |
-
self.apply_weight_norm()
|
391 |
-
|
392 |
-
def forward(self, x):
|
393 |
-
"""Calculate forward propagation.
|
394 |
-
|
395 |
-
Args:
|
396 |
-
x (Tensor): Input noise signal (B, 1, T).
|
397 |
-
|
398 |
-
Returns:
|
399 |
-
Tensor: Output tensor (B, 1, T)
|
400 |
-
|
401 |
-
"""
|
402 |
-
x = self.first_conv(x)
|
403 |
-
|
404 |
-
skips = 0
|
405 |
-
for f in self.conv_layers:
|
406 |
-
x, h = f(x, None)
|
407 |
-
skips += h
|
408 |
-
skips *= math.sqrt(1.0 / len(self.conv_layers))
|
409 |
-
|
410 |
-
# apply final layers
|
411 |
-
x = skips
|
412 |
-
for f in self.last_conv_layers:
|
413 |
-
x = f(x)
|
414 |
-
return x
|
415 |
-
|
416 |
-
def apply_weight_norm(self):
|
417 |
-
"""Apply weight normalization module from all of the layers."""
|
418 |
-
def _apply_weight_norm(m):
|
419 |
-
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
|
420 |
-
torch.nn.utils.weight_norm(m)
|
421 |
-
logging.debug(f"Weight norm is applied to {m}.")
|
422 |
-
|
423 |
-
self.apply(_apply_weight_norm)
|
424 |
-
|
425 |
-
def remove_weight_norm(self):
|
426 |
-
"""Remove weight normalization module from all of the layers."""
|
427 |
-
def _remove_weight_norm(m):
|
428 |
-
try:
|
429 |
-
logging.debug(f"Weight norm is removed from {m}.")
|
430 |
-
torch.nn.utils.remove_weight_norm(m)
|
431 |
-
except ValueError: # this module didn't have weight norm
|
432 |
-
return
|
433 |
-
|
434 |
-
self.apply(_remove_weight_norm)
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Copyright 2019 Tomoki Hayashi
|
4 |
+
# MIT License (https://opensource.org/licenses/MIT)
|
5 |
+
|
6 |
+
"""Parallel WaveGAN Modules."""
|
7 |
+
|
8 |
+
import logging
|
9 |
+
import math
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
from modules.parallel_wavegan.layers import Conv1d
|
15 |
+
from modules.parallel_wavegan.layers import Conv1d1x1
|
16 |
+
from modules.parallel_wavegan.layers import ResidualBlock
|
17 |
+
from modules.parallel_wavegan.layers import upsample
|
18 |
+
from modules.parallel_wavegan import models
|
19 |
+
|
20 |
+
|
21 |
+
class ParallelWaveGANGenerator(torch.nn.Module):
|
22 |
+
"""Parallel WaveGAN Generator module."""
|
23 |
+
|
24 |
+
def __init__(self,
|
25 |
+
in_channels=1,
|
26 |
+
out_channels=1,
|
27 |
+
kernel_size=3,
|
28 |
+
layers=30,
|
29 |
+
stacks=3,
|
30 |
+
residual_channels=64,
|
31 |
+
gate_channels=128,
|
32 |
+
skip_channels=64,
|
33 |
+
aux_channels=80,
|
34 |
+
aux_context_window=2,
|
35 |
+
dropout=0.0,
|
36 |
+
bias=True,
|
37 |
+
use_weight_norm=True,
|
38 |
+
use_causal_conv=False,
|
39 |
+
upsample_conditional_features=True,
|
40 |
+
upsample_net="ConvInUpsampleNetwork",
|
41 |
+
upsample_params={"upsample_scales": [4, 4, 4, 4]},
|
42 |
+
use_pitch_embed=False,
|
43 |
+
):
|
44 |
+
"""Initialize Parallel WaveGAN Generator module.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
in_channels (int): Number of input channels.
|
48 |
+
out_channels (int): Number of output channels.
|
49 |
+
kernel_size (int): Kernel size of dilated convolution.
|
50 |
+
layers (int): Number of residual block layers.
|
51 |
+
stacks (int): Number of stacks i.e., dilation cycles.
|
52 |
+
residual_channels (int): Number of channels in residual conv.
|
53 |
+
gate_channels (int): Number of channels in gated conv.
|
54 |
+
skip_channels (int): Number of channels in skip conv.
|
55 |
+
aux_channels (int): Number of channels for auxiliary feature conv.
|
56 |
+
aux_context_window (int): Context window size for auxiliary feature.
|
57 |
+
dropout (float): Dropout rate. 0.0 means no dropout applied.
|
58 |
+
bias (bool): Whether to use bias parameter in conv layer.
|
59 |
+
use_weight_norm (bool): Whether to use weight norm.
|
60 |
+
If set to true, it will be applied to all of the conv layers.
|
61 |
+
use_causal_conv (bool): Whether to use causal structure.
|
62 |
+
upsample_conditional_features (bool): Whether to use upsampling network.
|
63 |
+
upsample_net (str): Upsampling network architecture.
|
64 |
+
upsample_params (dict): Upsampling network parameters.
|
65 |
+
|
66 |
+
"""
|
67 |
+
super(ParallelWaveGANGenerator, self).__init__()
|
68 |
+
self.in_channels = in_channels
|
69 |
+
self.out_channels = out_channels
|
70 |
+
self.aux_channels = aux_channels
|
71 |
+
self.layers = layers
|
72 |
+
self.stacks = stacks
|
73 |
+
self.kernel_size = kernel_size
|
74 |
+
|
75 |
+
# check the number of layers and stacks
|
76 |
+
assert layers % stacks == 0
|
77 |
+
layers_per_stack = layers // stacks
|
78 |
+
|
79 |
+
# define first convolution
|
80 |
+
self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
|
81 |
+
|
82 |
+
# define conv + upsampling network
|
83 |
+
if upsample_conditional_features:
|
84 |
+
upsample_params.update({
|
85 |
+
"use_causal_conv": use_causal_conv,
|
86 |
+
})
|
87 |
+
if upsample_net == "MelGANGenerator":
|
88 |
+
assert aux_context_window == 0
|
89 |
+
upsample_params.update({
|
90 |
+
"use_weight_norm": False, # not to apply twice
|
91 |
+
"use_final_nonlinear_activation": False,
|
92 |
+
})
|
93 |
+
self.upsample_net = getattr(models, upsample_net)(**upsample_params)
|
94 |
+
else:
|
95 |
+
if upsample_net == "ConvInUpsampleNetwork":
|
96 |
+
upsample_params.update({
|
97 |
+
"aux_channels": aux_channels,
|
98 |
+
"aux_context_window": aux_context_window,
|
99 |
+
})
|
100 |
+
self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
|
101 |
+
else:
|
102 |
+
self.upsample_net = None
|
103 |
+
|
104 |
+
# define residual blocks
|
105 |
+
self.conv_layers = torch.nn.ModuleList()
|
106 |
+
for layer in range(layers):
|
107 |
+
dilation = 2 ** (layer % layers_per_stack)
|
108 |
+
conv = ResidualBlock(
|
109 |
+
kernel_size=kernel_size,
|
110 |
+
residual_channels=residual_channels,
|
111 |
+
gate_channels=gate_channels,
|
112 |
+
skip_channels=skip_channels,
|
113 |
+
aux_channels=aux_channels,
|
114 |
+
dilation=dilation,
|
115 |
+
dropout=dropout,
|
116 |
+
bias=bias,
|
117 |
+
use_causal_conv=use_causal_conv,
|
118 |
+
)
|
119 |
+
self.conv_layers += [conv]
|
120 |
+
|
121 |
+
# define output layers
|
122 |
+
self.last_conv_layers = torch.nn.ModuleList([
|
123 |
+
torch.nn.ReLU(inplace=True),
|
124 |
+
Conv1d1x1(skip_channels, skip_channels, bias=True),
|
125 |
+
torch.nn.ReLU(inplace=True),
|
126 |
+
Conv1d1x1(skip_channels, out_channels, bias=True),
|
127 |
+
])
|
128 |
+
|
129 |
+
self.use_pitch_embed = use_pitch_embed
|
130 |
+
if use_pitch_embed:
|
131 |
+
self.pitch_embed = nn.Embedding(300, aux_channels, 0)
|
132 |
+
self.c_proj = nn.Linear(2 * aux_channels, aux_channels)
|
133 |
+
|
134 |
+
# apply weight norm
|
135 |
+
if use_weight_norm:
|
136 |
+
self.apply_weight_norm()
|
137 |
+
|
138 |
+
def forward(self, x, c=None, pitch=None, **kwargs):
|
139 |
+
"""Calculate forward propagation.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
x (Tensor): Input noise signal (B, C_in, T).
|
143 |
+
c (Tensor): Local conditioning auxiliary features (B, C ,T').
|
144 |
+
pitch (Tensor): Local conditioning pitch (B, T').
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
Tensor: Output tensor (B, C_out, T)
|
148 |
+
|
149 |
+
"""
|
150 |
+
# perform upsampling
|
151 |
+
if c is not None and self.upsample_net is not None:
|
152 |
+
if self.use_pitch_embed:
|
153 |
+
p = self.pitch_embed(pitch)
|
154 |
+
c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
|
155 |
+
c = self.upsample_net(c)
|
156 |
+
assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))
|
157 |
+
|
158 |
+
# encode to hidden representation
|
159 |
+
x = self.first_conv(x)
|
160 |
+
skips = 0
|
161 |
+
for f in self.conv_layers:
|
162 |
+
x, h = f(x, c)
|
163 |
+
skips += h
|
164 |
+
skips *= math.sqrt(1.0 / len(self.conv_layers))
|
165 |
+
|
166 |
+
# apply final layers
|
167 |
+
x = skips
|
168 |
+
for f in self.last_conv_layers:
|
169 |
+
x = f(x)
|
170 |
+
|
171 |
+
return x
|
172 |
+
|
173 |
+
def remove_weight_norm(self):
|
174 |
+
"""Remove weight normalization module from all of the layers."""
|
175 |
+
def _remove_weight_norm(m):
|
176 |
+
try:
|
177 |
+
logging.debug(f"Weight norm is removed from {m}.")
|
178 |
+
torch.nn.utils.remove_weight_norm(m)
|
179 |
+
except ValueError: # this module didn't have weight norm
|
180 |
+
return
|
181 |
+
|
182 |
+
self.apply(_remove_weight_norm)
|
183 |
+
|
184 |
+
def apply_weight_norm(self):
|
185 |
+
"""Apply weight normalization module from all of the layers."""
|
186 |
+
def _apply_weight_norm(m):
|
187 |
+
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
|
188 |
+
torch.nn.utils.weight_norm(m)
|
189 |
+
logging.debug(f"Weight norm is applied to {m}.")
|
190 |
+
|
191 |
+
self.apply(_apply_weight_norm)
|
192 |
+
|
193 |
+
@staticmethod
|
194 |
+
def _get_receptive_field_size(layers, stacks, kernel_size,
|
195 |
+
dilation=lambda x: 2 ** x):
|
196 |
+
assert layers % stacks == 0
|
197 |
+
layers_per_cycle = layers // stacks
|
198 |
+
dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
|
199 |
+
return (kernel_size - 1) * sum(dilations) + 1
|
200 |
+
|
201 |
+
@property
|
202 |
+
def receptive_field_size(self):
|
203 |
+
"""Return receptive field size."""
|
204 |
+
return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
|
205 |
+
|
206 |
+
|
207 |
+
class ParallelWaveGANDiscriminator(torch.nn.Module):
|
208 |
+
"""Parallel WaveGAN Discriminator module."""
|
209 |
+
|
210 |
+
def __init__(self,
|
211 |
+
in_channels=1,
|
212 |
+
out_channels=1,
|
213 |
+
kernel_size=3,
|
214 |
+
layers=10,
|
215 |
+
conv_channels=64,
|
216 |
+
dilation_factor=1,
|
217 |
+
nonlinear_activation="LeakyReLU",
|
218 |
+
nonlinear_activation_params={"negative_slope": 0.2},
|
219 |
+
bias=True,
|
220 |
+
use_weight_norm=True,
|
221 |
+
):
|
222 |
+
"""Initialize Parallel WaveGAN Discriminator module.
|
223 |
+
|
224 |
+
Args:
|
225 |
+
in_channels (int): Number of input channels.
|
226 |
+
out_channels (int): Number of output channels.
|
227 |
+
kernel_size (int): Number of output channels.
|
228 |
+
layers (int): Number of conv layers.
|
229 |
+
conv_channels (int): Number of chnn layers.
|
230 |
+
dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
|
231 |
+
the dilation will be 2, 4, 8, ..., and so on.
|
232 |
+
nonlinear_activation (str): Nonlinear function after each conv.
|
233 |
+
nonlinear_activation_params (dict): Nonlinear function parameters
|
234 |
+
bias (bool): Whether to use bias parameter in conv.
|
235 |
+
use_weight_norm (bool) Whether to use weight norm.
|
236 |
+
If set to true, it will be applied to all of the conv layers.
|
237 |
+
|
238 |
+
"""
|
239 |
+
super(ParallelWaveGANDiscriminator, self).__init__()
|
240 |
+
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
|
241 |
+
assert dilation_factor > 0, "Dilation factor must be > 0."
|
242 |
+
self.conv_layers = torch.nn.ModuleList()
|
243 |
+
conv_in_channels = in_channels
|
244 |
+
for i in range(layers - 1):
|
245 |
+
if i == 0:
|
246 |
+
dilation = 1
|
247 |
+
else:
|
248 |
+
dilation = i if dilation_factor == 1 else dilation_factor ** i
|
249 |
+
conv_in_channels = conv_channels
|
250 |
+
padding = (kernel_size - 1) // 2 * dilation
|
251 |
+
conv_layer = [
|
252 |
+
Conv1d(conv_in_channels, conv_channels,
|
253 |
+
kernel_size=kernel_size, padding=padding,
|
254 |
+
dilation=dilation, bias=bias),
|
255 |
+
getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
|
256 |
+
]
|
257 |
+
self.conv_layers += conv_layer
|
258 |
+
padding = (kernel_size - 1) // 2
|
259 |
+
last_conv_layer = Conv1d(
|
260 |
+
conv_in_channels, out_channels,
|
261 |
+
kernel_size=kernel_size, padding=padding, bias=bias)
|
262 |
+
self.conv_layers += [last_conv_layer]
|
263 |
+
|
264 |
+
# apply weight norm
|
265 |
+
if use_weight_norm:
|
266 |
+
self.apply_weight_norm()
|
267 |
+
|
268 |
+
def forward(self, x):
|
269 |
+
"""Calculate forward propagation.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
x (Tensor): Input noise signal (B, 1, T).
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
Tensor: Output tensor (B, 1, T)
|
276 |
+
|
277 |
+
"""
|
278 |
+
for f in self.conv_layers:
|
279 |
+
x = f(x)
|
280 |
+
return x
|
281 |
+
|
282 |
+
def apply_weight_norm(self):
|
283 |
+
"""Apply weight normalization module from all of the layers."""
|
284 |
+
def _apply_weight_norm(m):
|
285 |
+
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
|
286 |
+
torch.nn.utils.weight_norm(m)
|
287 |
+
logging.debug(f"Weight norm is applied to {m}.")
|
288 |
+
|
289 |
+
self.apply(_apply_weight_norm)
|
290 |
+
|
291 |
+
def remove_weight_norm(self):
|
292 |
+
"""Remove weight normalization module from all of the layers."""
|
293 |
+
def _remove_weight_norm(m):
|
294 |
+
try:
|
295 |
+
logging.debug(f"Weight norm is removed from {m}.")
|
296 |
+
torch.nn.utils.remove_weight_norm(m)
|
297 |
+
except ValueError: # this module didn't have weight norm
|
298 |
+
return
|
299 |
+
|
300 |
+
self.apply(_remove_weight_norm)
|
301 |
+
|
302 |
+
|
303 |
+
class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
|
304 |
+
"""Parallel WaveGAN Discriminator module."""
|
305 |
+
|
306 |
+
def __init__(self,
|
307 |
+
in_channels=1,
|
308 |
+
out_channels=1,
|
309 |
+
kernel_size=3,
|
310 |
+
layers=30,
|
311 |
+
stacks=3,
|
312 |
+
residual_channels=64,
|
313 |
+
gate_channels=128,
|
314 |
+
skip_channels=64,
|
315 |
+
dropout=0.0,
|
316 |
+
bias=True,
|
317 |
+
use_weight_norm=True,
|
318 |
+
use_causal_conv=False,
|
319 |
+
nonlinear_activation="LeakyReLU",
|
320 |
+
nonlinear_activation_params={"negative_slope": 0.2},
|
321 |
+
):
|
322 |
+
"""Initialize Parallel WaveGAN Discriminator module.
|
323 |
+
|
324 |
+
Args:
|
325 |
+
in_channels (int): Number of input channels.
|
326 |
+
out_channels (int): Number of output channels.
|
327 |
+
kernel_size (int): Kernel size of dilated convolution.
|
328 |
+
layers (int): Number of residual block layers.
|
329 |
+
stacks (int): Number of stacks i.e., dilation cycles.
|
330 |
+
residual_channels (int): Number of channels in residual conv.
|
331 |
+
gate_channels (int): Number of channels in gated conv.
|
332 |
+
skip_channels (int): Number of channels in skip conv.
|
333 |
+
dropout (float): Dropout rate. 0.0 means no dropout applied.
|
334 |
+
bias (bool): Whether to use bias parameter in conv.
|
335 |
+
use_weight_norm (bool): Whether to use weight norm.
|
336 |
+
If set to true, it will be applied to all of the conv layers.
|
337 |
+
use_causal_conv (bool): Whether to use causal structure.
|
338 |
+
nonlinear_activation_params (dict): Nonlinear function parameters
|
339 |
+
|
340 |
+
"""
|
341 |
+
super(ResidualParallelWaveGANDiscriminator, self).__init__()
|
342 |
+
assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
|
343 |
+
|
344 |
+
self.in_channels = in_channels
|
345 |
+
self.out_channels = out_channels
|
346 |
+
self.layers = layers
|
347 |
+
self.stacks = stacks
|
348 |
+
self.kernel_size = kernel_size
|
349 |
+
|
350 |
+
# check the number of layers and stacks
|
351 |
+
assert layers % stacks == 0
|
352 |
+
layers_per_stack = layers // stacks
|
353 |
+
|
354 |
+
# define first convolution
|
355 |
+
self.first_conv = torch.nn.Sequential(
|
356 |
+
Conv1d1x1(in_channels, residual_channels, bias=True),
|
357 |
+
getattr(torch.nn, nonlinear_activation)(
|
358 |
+
inplace=True, **nonlinear_activation_params),
|
359 |
+
)
|
360 |
+
|
361 |
+
# define residual blocks
|
362 |
+
self.conv_layers = torch.nn.ModuleList()
|
363 |
+
for layer in range(layers):
|
364 |
+
dilation = 2 ** (layer % layers_per_stack)
|
365 |
+
conv = ResidualBlock(
|
366 |
+
kernel_size=kernel_size,
|
367 |
+
residual_channels=residual_channels,
|
368 |
+
gate_channels=gate_channels,
|
369 |
+
skip_channels=skip_channels,
|
370 |
+
aux_channels=-1,
|
371 |
+
dilation=dilation,
|
372 |
+
dropout=dropout,
|
373 |
+
bias=bias,
|
374 |
+
use_causal_conv=use_causal_conv,
|
375 |
+
)
|
376 |
+
self.conv_layers += [conv]
|
377 |
+
|
378 |
+
# define output layers
|
379 |
+
self.last_conv_layers = torch.nn.ModuleList([
|
380 |
+
getattr(torch.nn, nonlinear_activation)(
|
381 |
+
inplace=True, **nonlinear_activation_params),
|
382 |
+
Conv1d1x1(skip_channels, skip_channels, bias=True),
|
383 |
+
getattr(torch.nn, nonlinear_activation)(
|
384 |
+
inplace=True, **nonlinear_activation_params),
|
385 |
+
Conv1d1x1(skip_channels, out_channels, bias=True),
|
386 |
+
])
|
387 |
+
|
388 |
+
# apply weight norm
|
389 |
+
if use_weight_norm:
|
390 |
+
self.apply_weight_norm()
|
391 |
+
|
392 |
+
def forward(self, x):
|
393 |
+
"""Calculate forward propagation.
|
394 |
+
|
395 |
+
Args:
|
396 |
+
x (Tensor): Input noise signal (B, 1, T).
|
397 |
+
|
398 |
+
Returns:
|
399 |
+
Tensor: Output tensor (B, 1, T)
|
400 |
+
|
401 |
+
"""
|
402 |
+
x = self.first_conv(x)
|
403 |
+
|
404 |
+
skips = 0
|
405 |
+
for f in self.conv_layers:
|
406 |
+
x, h = f(x, None)
|
407 |
+
skips += h
|
408 |
+
skips *= math.sqrt(1.0 / len(self.conv_layers))
|
409 |
+
|
410 |
+
# apply final layers
|
411 |
+
x = skips
|
412 |
+
for f in self.last_conv_layers:
|
413 |
+
x = f(x)
|
414 |
+
return x
|
415 |
+
|
416 |
+
def apply_weight_norm(self):
|
417 |
+
"""Apply weight normalization module from all of the layers."""
|
418 |
+
def _apply_weight_norm(m):
|
419 |
+
if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
|
420 |
+
torch.nn.utils.weight_norm(m)
|
421 |
+
logging.debug(f"Weight norm is applied to {m}.")
|
422 |
+
|
423 |
+
self.apply(_apply_weight_norm)
|
424 |
+
|
425 |
+
def remove_weight_norm(self):
|
426 |
+
"""Remove weight normalization module from all of the layers."""
|
427 |
+
def _remove_weight_norm(m):
|
428 |
+
try:
|
429 |
+
logging.debug(f"Weight norm is removed from {m}.")
|
430 |
+
torch.nn.utils.remove_weight_norm(m)
|
431 |
+
except ValueError: # this module didn't have weight norm
|
432 |
+
return
|
433 |
+
|
434 |
+
self.apply(_remove_weight_norm)
|
usr/configs/midi/cascade/opencs/ds60_rel.yaml
CHANGED
@@ -24,10 +24,11 @@ fs2_ckpt: 'checkpoints/0302_opencpop_fs_midi/model_ckpt_steps_160000.ckpt' #
|
|
24 |
task_cls: usr.diffsinger_task.DiffSingerMIDITask
|
25 |
|
26 |
K_step: 60
|
27 |
-
max_tokens:
|
28 |
predictor_layers: 5
|
29 |
dilation_cycle_length: 4 # *
|
30 |
rel_pos: true
|
31 |
dur_predictor_layers: 5 # *
|
32 |
max_updates: 160000
|
33 |
gaussian_start: false
|
|
|
|
24 |
task_cls: usr.diffsinger_task.DiffSingerMIDITask
|
25 |
|
26 |
K_step: 60
|
27 |
+
max_tokens: 36000
|
28 |
predictor_layers: 5
|
29 |
dilation_cycle_length: 4 # *
|
30 |
rel_pos: true
|
31 |
dur_predictor_layers: 5 # *
|
32 |
max_updates: 160000
|
33 |
gaussian_start: false
|
34 |
+
mask_uv_prob: 0.15
|
usr/diff/shallow_diffusion_tts.py
CHANGED
@@ -1,273 +1,324 @@
|
|
1 |
-
import math
|
2 |
-
import random
|
3 |
-
from
|
4 |
-
from
|
5 |
-
from
|
6 |
-
|
7 |
-
import
|
8 |
-
import torch
|
9 |
-
|
10 |
-
from
|
11 |
-
from
|
12 |
-
|
13 |
-
|
14 |
-
from modules.
|
15 |
-
from
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
alphas_cumprod =
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
"
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
self.
|
96 |
-
self.
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
self.register_buffer('
|
106 |
-
|
107 |
-
|
108 |
-
self.register_buffer('
|
109 |
-
self.register_buffer('
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
self.register_buffer('
|
118 |
-
|
119 |
-
self.register_buffer('
|
120 |
-
|
121 |
-
|
122 |
-
self.register_buffer('
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
return
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
|
198 |
-
|
199 |
-
|
200 |
-
|
201 |
-
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
-
|
245 |
-
|
246 |
-
|
247 |
-
|
248 |
-
|
249 |
-
|
250 |
-
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
x = self.
|
256 |
-
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
265 |
-
|
266 |
-
|
267 |
-
|
268 |
-
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
273 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import random
|
3 |
+
from collections import deque
|
4 |
+
from functools import partial
|
5 |
+
from inspect import isfunction
|
6 |
+
from pathlib import Path
|
7 |
+
import numpy as np
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch import nn
|
11 |
+
from tqdm import tqdm
|
12 |
+
from einops import rearrange
|
13 |
+
|
14 |
+
from modules.fastspeech.fs2 import FastSpeech2
|
15 |
+
from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
|
16 |
+
from utils.hparams import hparams
|
17 |
+
|
18 |
+
|
19 |
+
|
20 |
+
def exists(x):
|
21 |
+
return x is not None
|
22 |
+
|
23 |
+
|
24 |
+
def default(val, d):
|
25 |
+
if exists(val):
|
26 |
+
return val
|
27 |
+
return d() if isfunction(d) else d
|
28 |
+
|
29 |
+
|
30 |
+
# gaussian diffusion trainer class
|
31 |
+
|
32 |
+
def extract(a, t, x_shape):
|
33 |
+
b, *_ = t.shape
|
34 |
+
out = a.gather(-1, t)
|
35 |
+
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
36 |
+
|
37 |
+
|
38 |
+
def noise_like(shape, device, repeat=False):
|
39 |
+
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
40 |
+
noise = lambda: torch.randn(shape, device=device)
|
41 |
+
return repeat_noise() if repeat else noise()
|
42 |
+
|
43 |
+
|
44 |
+
def linear_beta_schedule(timesteps, max_beta=hparams.get('max_beta', 0.01)):
|
45 |
+
"""
|
46 |
+
linear schedule
|
47 |
+
"""
|
48 |
+
betas = np.linspace(1e-4, max_beta, timesteps)
|
49 |
+
return betas
|
50 |
+
|
51 |
+
|
52 |
+
def cosine_beta_schedule(timesteps, s=0.008):
|
53 |
+
"""
|
54 |
+
cosine schedule
|
55 |
+
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
56 |
+
"""
|
57 |
+
steps = timesteps + 1
|
58 |
+
x = np.linspace(0, steps, steps)
|
59 |
+
alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
|
60 |
+
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
61 |
+
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
62 |
+
return np.clip(betas, a_min=0, a_max=0.999)
|
63 |
+
|
64 |
+
|
65 |
+
beta_schedule = {
|
66 |
+
"cosine": cosine_beta_schedule,
|
67 |
+
"linear": linear_beta_schedule,
|
68 |
+
}
|
69 |
+
|
70 |
+
|
71 |
+
class GaussianDiffusion(nn.Module):
|
72 |
+
def __init__(self, phone_encoder, out_dims, denoise_fn,
|
73 |
+
timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None, spec_min=None, spec_max=None):
|
74 |
+
super().__init__()
|
75 |
+
self.denoise_fn = denoise_fn
|
76 |
+
if hparams.get('use_midi') is not None and hparams['use_midi']:
|
77 |
+
self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims)
|
78 |
+
else:
|
79 |
+
self.fs2 = FastSpeech2(phone_encoder, out_dims)
|
80 |
+
self.mel_bins = out_dims
|
81 |
+
|
82 |
+
if exists(betas):
|
83 |
+
betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
|
84 |
+
else:
|
85 |
+
if 'schedule_type' in hparams.keys():
|
86 |
+
betas = beta_schedule[hparams['schedule_type']](timesteps)
|
87 |
+
else:
|
88 |
+
betas = cosine_beta_schedule(timesteps)
|
89 |
+
|
90 |
+
alphas = 1. - betas
|
91 |
+
alphas_cumprod = np.cumprod(alphas, axis=0)
|
92 |
+
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
93 |
+
|
94 |
+
timesteps, = betas.shape
|
95 |
+
self.num_timesteps = int(timesteps)
|
96 |
+
self.K_step = K_step
|
97 |
+
self.loss_type = loss_type
|
98 |
+
|
99 |
+
self.noise_list = deque(maxlen=4)
|
100 |
+
|
101 |
+
to_torch = partial(torch.tensor, dtype=torch.float32)
|
102 |
+
|
103 |
+
self.register_buffer('betas', to_torch(betas))
|
104 |
+
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
105 |
+
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
106 |
+
|
107 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
108 |
+
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
109 |
+
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
110 |
+
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
111 |
+
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
112 |
+
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
113 |
+
|
114 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
115 |
+
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
|
116 |
+
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
117 |
+
self.register_buffer('posterior_variance', to_torch(posterior_variance))
|
118 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
119 |
+
self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
|
120 |
+
self.register_buffer('posterior_mean_coef1', to_torch(
|
121 |
+
betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
|
122 |
+
self.register_buffer('posterior_mean_coef2', to_torch(
|
123 |
+
(1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
|
124 |
+
|
125 |
+
self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
|
126 |
+
self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])
|
127 |
+
|
128 |
+
def q_mean_variance(self, x_start, t):
|
129 |
+
mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
130 |
+
variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
|
131 |
+
log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
132 |
+
return mean, variance, log_variance
|
133 |
+
|
134 |
+
def predict_start_from_noise(self, x_t, t, noise):
|
135 |
+
return (
|
136 |
+
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
|
137 |
+
extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
138 |
+
)
|
139 |
+
|
140 |
+
def q_posterior(self, x_start, x_t, t):
|
141 |
+
posterior_mean = (
|
142 |
+
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
|
143 |
+
extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
144 |
+
)
|
145 |
+
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
146 |
+
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
147 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
148 |
+
|
149 |
+
def p_mean_variance(self, x, t, cond, clip_denoised: bool):
|
150 |
+
noise_pred = self.denoise_fn(x, t, cond=cond)
|
151 |
+
x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
|
152 |
+
|
153 |
+
if clip_denoised:
|
154 |
+
x_recon.clamp_(-1., 1.)
|
155 |
+
|
156 |
+
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
|
157 |
+
return model_mean, posterior_variance, posterior_log_variance
|
158 |
+
|
159 |
+
@torch.no_grad()
|
160 |
+
def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
|
161 |
+
b, *_, device = *x.shape, x.device
|
162 |
+
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
|
163 |
+
noise = noise_like(x.shape, device, repeat_noise)
|
164 |
+
# no noise when t == 0
|
165 |
+
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
166 |
+
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
167 |
+
|
168 |
+
@torch.no_grad()
|
169 |
+
def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
|
170 |
+
"""
|
171 |
+
Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
|
172 |
+
"""
|
173 |
+
|
174 |
+
def get_x_pred(x, noise_t, t):
|
175 |
+
a_t = extract(self.alphas_cumprod, t, x.shape)
|
176 |
+
if t[0] < interval:
|
177 |
+
a_prev = torch.ones_like(a_t)
|
178 |
+
else:
|
179 |
+
a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape)
|
180 |
+
a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
|
181 |
+
|
182 |
+
x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
|
183 |
+
x_pred = x + x_delta
|
184 |
+
|
185 |
+
return x_pred
|
186 |
+
|
187 |
+
noise_list = self.noise_list
|
188 |
+
noise_pred = self.denoise_fn(x, t, cond=cond)
|
189 |
+
|
190 |
+
if len(noise_list) == 0:
|
191 |
+
x_pred = get_x_pred(x, noise_pred, t)
|
192 |
+
noise_pred_prev = self.denoise_fn(x_pred, max(t-interval, 0), cond=cond)
|
193 |
+
noise_pred_prime = (noise_pred + noise_pred_prev) / 2
|
194 |
+
elif len(noise_list) == 1:
|
195 |
+
noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
|
196 |
+
elif len(noise_list) == 2:
|
197 |
+
noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
|
198 |
+
elif len(noise_list) >= 3:
|
199 |
+
noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24
|
200 |
+
|
201 |
+
x_prev = get_x_pred(x, noise_pred_prime, t)
|
202 |
+
noise_list.append(noise_pred)
|
203 |
+
|
204 |
+
return x_prev
|
205 |
+
|
206 |
+
def q_sample(self, x_start, t, noise=None):
|
207 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
208 |
+
return (
|
209 |
+
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
210 |
+
extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
211 |
+
)
|
212 |
+
|
213 |
+
def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
|
214 |
+
noise = default(noise, lambda: torch.randn_like(x_start))
|
215 |
+
|
216 |
+
x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
|
217 |
+
x_recon = self.denoise_fn(x_noisy, t, cond)
|
218 |
+
|
219 |
+
if self.loss_type == 'l1':
|
220 |
+
if nonpadding is not None:
|
221 |
+
loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
|
222 |
+
else:
|
223 |
+
# print('are you sure w/o nonpadding?')
|
224 |
+
loss = (noise - x_recon).abs().mean()
|
225 |
+
|
226 |
+
elif self.loss_type == 'l2':
|
227 |
+
loss = F.mse_loss(noise, x_recon)
|
228 |
+
else:
|
229 |
+
raise NotImplementedError()
|
230 |
+
|
231 |
+
return loss
|
232 |
+
|
233 |
+
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
234 |
+
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
235 |
+
b, *_, device = *txt_tokens.shape, txt_tokens.device
|
236 |
+
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
237 |
+
skip_decoder=(not infer), infer=infer, **kwargs)
|
238 |
+
cond = ret['decoder_inp'].transpose(1, 2)
|
239 |
+
|
240 |
+
if not infer:
|
241 |
+
t = torch.randint(0, self.K_step, (b,), device=device).long()
|
242 |
+
x = ref_mels
|
243 |
+
x = self.norm_spec(x)
|
244 |
+
x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
|
245 |
+
ret['diff_loss'] = self.p_losses(x, t, cond)
|
246 |
+
# nonpadding = (mel2ph != 0).float()
|
247 |
+
# ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding)
|
248 |
+
else:
|
249 |
+
ret['fs2_mel'] = ret['mel_out']
|
250 |
+
fs2_mels = ret['mel_out']
|
251 |
+
t = self.K_step
|
252 |
+
fs2_mels = self.norm_spec(fs2_mels)
|
253 |
+
fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
|
254 |
+
|
255 |
+
x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
|
256 |
+
if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
|
257 |
+
print('===> gaussion start.')
|
258 |
+
shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
|
259 |
+
x = torch.randn(shape, device=device)
|
260 |
+
|
261 |
+
if hparams.get('pndm_speedup'):
|
262 |
+
print('===> pndm speed:', hparams['pndm_speedup'])
|
263 |
+
self.noise_list = deque(maxlen=4)
|
264 |
+
iteration_interval = hparams['pndm_speedup']
|
265 |
+
for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step',
|
266 |
+
total=t // iteration_interval):
|
267 |
+
x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), iteration_interval,
|
268 |
+
cond)
|
269 |
+
else:
|
270 |
+
for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
|
271 |
+
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
272 |
+
x = x[:, 0].transpose(1, 2)
|
273 |
+
if mel2ph is not None: # for singing
|
274 |
+
ret['mel_out'] = self.denorm_spec(x) * ((mel2ph > 0).float()[:, :, None])
|
275 |
+
else:
|
276 |
+
ret['mel_out'] = self.denorm_spec(x)
|
277 |
+
return ret
|
278 |
+
|
279 |
+
def norm_spec(self, x):
|
280 |
+
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
|
281 |
+
|
282 |
+
def denorm_spec(self, x):
|
283 |
+
return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
|
284 |
+
|
285 |
+
def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
|
286 |
+
return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
|
287 |
+
|
288 |
+
def out2mel(self, x):
|
289 |
+
return x
|
290 |
+
|
291 |
+
|
292 |
+
class OfflineGaussianDiffusion(GaussianDiffusion):
|
293 |
+
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
294 |
+
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
295 |
+
b, *_, device = *txt_tokens.shape, txt_tokens.device
|
296 |
+
|
297 |
+
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
298 |
+
skip_decoder=True, infer=True, **kwargs)
|
299 |
+
cond = ret['decoder_inp'].transpose(1, 2)
|
300 |
+
fs2_mels = ref_mels[1]
|
301 |
+
ref_mels = ref_mels[0]
|
302 |
+
|
303 |
+
if not infer:
|
304 |
+
t = torch.randint(0, self.K_step, (b,), device=device).long()
|
305 |
+
x = ref_mels
|
306 |
+
x = self.norm_spec(x)
|
307 |
+
x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
|
308 |
+
ret['diff_loss'] = self.p_losses(x, t, cond)
|
309 |
+
else:
|
310 |
+
t = self.K_step
|
311 |
+
fs2_mels = self.norm_spec(fs2_mels)
|
312 |
+
fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
|
313 |
+
|
314 |
+
x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
|
315 |
+
|
316 |
+
if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
|
317 |
+
print('===> gaussion start.')
|
318 |
+
shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
|
319 |
+
x = torch.randn(shape, device=device)
|
320 |
+
for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
|
321 |
+
x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
|
322 |
+
x = x[:, 0].transpose(1, 2)
|
323 |
+
ret['mel_out'] = self.denorm_spec(x)
|
324 |
+
return ret
|
utils/hparams.py
CHANGED
@@ -21,35 +21,30 @@ def override_config(old_config: dict, new_config: dict):
|
|
21 |
|
22 |
|
23 |
def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
|
24 |
-
if config == ''
|
25 |
-
parser = argparse.ArgumentParser(description='')
|
26 |
parser.add_argument('--config', type=str, default='',
|
27 |
help='location of the data corpus')
|
28 |
parser.add_argument('--exp_name', type=str, default='', help='exp_name')
|
29 |
-
parser.add_argument('
|
30 |
help='location of the data corpus')
|
31 |
parser.add_argument('--infer', action='store_true', help='infer')
|
32 |
parser.add_argument('--validate', action='store_true', help='validate')
|
33 |
parser.add_argument('--reset', action='store_true', help='reset hparams')
|
34 |
-
parser.add_argument('--remove', action='store_true', help='remove old ckpt')
|
35 |
parser.add_argument('--debug', action='store_true', help='debug')
|
36 |
args, unknown = parser.parse_known_args()
|
37 |
-
print("| Unknow hparams: ", unknown)
|
38 |
else:
|
39 |
args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
|
40 |
-
infer=False, validate=False, reset=False, debug=False
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
|
46 |
config_chains = []
|
47 |
loaded_config = set()
|
48 |
|
49 |
-
def load_config(config_fn):
|
50 |
-
# deep first inheritance and avoid the second visit of one node
|
51 |
-
if not os.path.exists(config_fn):
|
52 |
-
return {}
|
53 |
with open(config_fn) as f:
|
54 |
hparams_ = yaml.safe_load(f)
|
55 |
loaded_config.add(config_fn)
|
@@ -58,10 +53,10 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
|
|
58 |
if not isinstance(hparams_['base_config'], list):
|
59 |
hparams_['base_config'] = [hparams_['base_config']]
|
60 |
for c in hparams_['base_config']:
|
61 |
-
if c.startswith('.'):
|
62 |
-
c = f'{os.path.dirname(config_fn)}/{c}'
|
63 |
-
c = os.path.normpath(c)
|
64 |
if c not in loaded_config:
|
|
|
|
|
|
|
65 |
override_config(ret_hparams, load_config(c))
|
66 |
override_config(ret_hparams, hparams_)
|
67 |
else:
|
@@ -69,43 +64,36 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
|
|
69 |
config_chains.append(config_fn)
|
70 |
return ret_hparams
|
71 |
|
|
|
|
|
72 |
saved_hparams = {}
|
73 |
-
args_work_dir
|
74 |
-
if args.exp_name != '':
|
75 |
-
args_work_dir = f'checkpoints/{args.exp_name}'
|
76 |
ckpt_config_path = f'{args_work_dir}/config.yaml'
|
77 |
if os.path.exists(ckpt_config_path):
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
|
|
|
|
|
|
|
|
82 |
hparams_ = {}
|
83 |
-
|
84 |
-
|
|
|
85 |
if not args.reset:
|
86 |
hparams_.update(saved_hparams)
|
87 |
hparams_['work_dir'] = args_work_dir
|
88 |
|
89 |
-
# Support config overriding in command line. Support list type config overriding.
|
90 |
-
# Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
|
91 |
if args.hparams != "":
|
92 |
for new_hparam in args.hparams.split(","):
|
93 |
k, v = new_hparam.split("=")
|
94 |
-
v
|
95 |
-
|
96 |
-
for k_ in k.split(".")[:-1]:
|
97 |
-
config_node = config_node[k_]
|
98 |
-
k = k.split(".")[-1]
|
99 |
-
if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
|
100 |
-
if type(config_node[k]) == list:
|
101 |
-
v = v.replace(" ", ",")
|
102 |
-
config_node[k] = eval(v)
|
103 |
else:
|
104 |
-
|
105 |
-
|
106 |
-
answer = input("REMOVE old checkpoint? Y/N [Default: N]: ")
|
107 |
-
if answer.lower() == "y":
|
108 |
-
remove_file(args_work_dir)
|
109 |
if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
|
110 |
os.makedirs(hparams_['work_dir'], exist_ok=True)
|
111 |
with open(ckpt_config_path, 'w') as f:
|
@@ -114,11 +102,11 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
|
|
114 |
hparams_['infer'] = args.infer
|
115 |
hparams_['debug'] = args.debug
|
116 |
hparams_['validate'] = args.validate
|
117 |
-
hparams_['exp_name'] = args.exp_name
|
118 |
global global_print_hparams
|
119 |
if global_hparams:
|
120 |
hparams.clear()
|
121 |
hparams.update(hparams_)
|
|
|
122 |
if print_hparams and global_print_hparams and global_hparams:
|
123 |
print('| Hparams chains: ', config_chains)
|
124 |
print('| Hparams: ')
|
@@ -126,5 +114,9 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
|
|
126 |
print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
|
127 |
print("")
|
128 |
global_print_hparams = False
|
|
|
|
|
|
|
|
|
|
|
129 |
return hparams_
|
130 |
-
|
|
|
21 |
|
22 |
|
23 |
def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
|
24 |
+
if config == '':
|
25 |
+
parser = argparse.ArgumentParser(description='neural music')
|
26 |
parser.add_argument('--config', type=str, default='',
|
27 |
help='location of the data corpus')
|
28 |
parser.add_argument('--exp_name', type=str, default='', help='exp_name')
|
29 |
+
parser.add_argument('--hparams', type=str, default='',
|
30 |
help='location of the data corpus')
|
31 |
parser.add_argument('--infer', action='store_true', help='infer')
|
32 |
parser.add_argument('--validate', action='store_true', help='validate')
|
33 |
parser.add_argument('--reset', action='store_true', help='reset hparams')
|
|
|
34 |
parser.add_argument('--debug', action='store_true', help='debug')
|
35 |
args, unknown = parser.parse_known_args()
|
|
|
36 |
else:
|
37 |
args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
|
38 |
+
infer=False, validate=False, reset=False, debug=False)
|
39 |
+
args_work_dir = ''
|
40 |
+
if args.exp_name != '':
|
41 |
+
args.work_dir = args.exp_name
|
42 |
+
args_work_dir = f'checkpoints/{args.work_dir}'
|
43 |
|
44 |
config_chains = []
|
45 |
loaded_config = set()
|
46 |
|
47 |
+
def load_config(config_fn): # deep first
|
|
|
|
|
|
|
48 |
with open(config_fn) as f:
|
49 |
hparams_ = yaml.safe_load(f)
|
50 |
loaded_config.add(config_fn)
|
|
|
53 |
if not isinstance(hparams_['base_config'], list):
|
54 |
hparams_['base_config'] = [hparams_['base_config']]
|
55 |
for c in hparams_['base_config']:
|
|
|
|
|
|
|
56 |
if c not in loaded_config:
|
57 |
+
if c.startswith('.'):
|
58 |
+
c = f'{os.path.dirname(config_fn)}/{c}'
|
59 |
+
c = os.path.normpath(c)
|
60 |
override_config(ret_hparams, load_config(c))
|
61 |
override_config(ret_hparams, hparams_)
|
62 |
else:
|
|
|
64 |
config_chains.append(config_fn)
|
65 |
return ret_hparams
|
66 |
|
67 |
+
global hparams
|
68 |
+
assert args.config != '' or args_work_dir != ''
|
69 |
saved_hparams = {}
|
70 |
+
if args_work_dir != 'checkpoints/':
|
|
|
|
|
71 |
ckpt_config_path = f'{args_work_dir}/config.yaml'
|
72 |
if os.path.exists(ckpt_config_path):
|
73 |
+
try:
|
74 |
+
with open(ckpt_config_path) as f:
|
75 |
+
saved_hparams.update(yaml.safe_load(f))
|
76 |
+
except:
|
77 |
+
pass
|
78 |
+
if args.config == '':
|
79 |
+
args.config = ckpt_config_path
|
80 |
+
|
81 |
hparams_ = {}
|
82 |
+
|
83 |
+
hparams_.update(load_config(args.config))
|
84 |
+
|
85 |
if not args.reset:
|
86 |
hparams_.update(saved_hparams)
|
87 |
hparams_['work_dir'] = args_work_dir
|
88 |
|
|
|
|
|
89 |
if args.hparams != "":
|
90 |
for new_hparam in args.hparams.split(","):
|
91 |
k, v = new_hparam.split("=")
|
92 |
+
if v in ['True', 'False'] or type(hparams_[k]) == bool:
|
93 |
+
hparams_[k] = eval(v)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
else:
|
95 |
+
hparams_[k] = type(hparams_[k])(v)
|
96 |
+
|
|
|
|
|
|
|
97 |
if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
|
98 |
os.makedirs(hparams_['work_dir'], exist_ok=True)
|
99 |
with open(ckpt_config_path, 'w') as f:
|
|
|
102 |
hparams_['infer'] = args.infer
|
103 |
hparams_['debug'] = args.debug
|
104 |
hparams_['validate'] = args.validate
|
|
|
105 |
global global_print_hparams
|
106 |
if global_hparams:
|
107 |
hparams.clear()
|
108 |
hparams.update(hparams_)
|
109 |
+
|
110 |
if print_hparams and global_print_hparams and global_hparams:
|
111 |
print('| Hparams chains: ', config_chains)
|
112 |
print('| Hparams: ')
|
|
|
114 |
print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
|
115 |
print("")
|
116 |
global_print_hparams = False
|
117 |
+
# print(hparams_.keys())
|
118 |
+
if hparams.get('exp_name') is None:
|
119 |
+
hparams['exp_name'] = args.exp_name
|
120 |
+
if hparams_.get('exp_name') is None:
|
121 |
+
hparams_['exp_name'] = args.exp_name
|
122 |
return hparams_
|
|