ddd committed on
Commit 40e984c
1 Parent(s): c4e83e4

pndm codes

.gitattributes CHANGED
@@ -30,3 +30,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 model_ckpt_steps* filter=lfs diff=lfs merge=lfs -text
+ checkpoints/0831_opencpop_ds1000 filter=lfs diff=lfs merge=lfs -text
docs/README-SVS-opencpop-cascade.md CHANGED
@@ -3,7 +3,7 @@
 [![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
 [![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)

- ## DiffSinger (MIDI version SVS)
+ ## DiffSinger (MIDI SVS | A version)
 ### 0. Data Acquirement
 For Opencpop dataset: Please strictly follow the instructions of [Opencpop](https://wenet.org.cn/opencpop/). We have no right to give you access to Opencpop.

@@ -67,7 +67,7 @@ CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/ope

 Remember to adjust the "fs2_ckpt" parameter in `usr/configs/midi/cascade/opencs/ds60_rel.yaml` to fit your path.

- ### 3. Inference Example
+ ### 3. Inference from packed test set
 ```sh
 CUDA_VISIBLE_DEVICES=0 python tasks/run.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME --reset --infer
 ```
@@ -82,7 +82,7 @@ Remember to put the pre-trained models in `checkpoints` directory.

 ### 4. Inference from raw inputs
 ```sh
- python inference/svs/ds_e2e.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME
+ python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name $MY_DS_EXP_NAME
 ```
 Raw inputs:
 ```
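The raw-input entry point above now points at `ds_cascade.py`, which (like `ds_e2e.py` later in this commit) takes a dictionary of aligned sequences and hands it to `example_run`. A minimal sketch of a phoneme-type input; the keys follow `base_svs_infer.py` and `ds_cascade.py` in this commit, while the note names, durations, and slur flags below are illustrative placeholders, not real annotations:

```python
# Hypothetical input for inference/svs/ds_cascade.py; the keys mirror
# base_svs_infer.py and ds_e2e.py in this commit, the values are made up.
inp = {
    'text': '小酒窝',
    'ph_seq': 'x iao j iu w o',                                      # space-separated phonemes
    'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4',   # one note per phoneme
    'note_dur_seq': '0.4 0.4 0.37 0.37 0.24 0.24',                   # note durations in seconds
    'is_slur_seq': '0 0 0 0 0 0',                                    # 1 marks a slurred phoneme
    'input_type': 'phoneme',                                         # as in ds_cascade.py's example
}
# DiffSingerCascadeInfer.example_run(inp)  # requires the repo, config and checkpoints
```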
docs/README-SVS-opencpop-e2e.md CHANGED
@@ -2,13 +2,14 @@
 [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
 [![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
 [![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
+ | [Interactive🤗 SVS](https://huggingface.co/spaces/Silentlin/DiffSinger)

 Substantial update: We 1) **abandon** the explicit prediction of the F0 curve; 2) increase the receptive field of the denoiser; 3) make the linguistic encoder more robust.
 **By doing so, 1) the synthesized recordings are more natural in terms of pitch; 2) the pipeline is simpler.**

 In short: we hand the dynamics of the F0 curve over to the generative model to capture, instead of constraining log-domain F0 with an MSE loss as before.

- ## DiffSinger (MIDI version SVS)
+ ## DiffSinger (MIDI SVS | B version)
 ### 0. Data Acquirement
 For Opencpop dataset: Please strictly follow the instructions of [Opencpop](https://wenet.org.cn/opencpop/). We have no right to give you access to Opencpop.
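To make "MSE-constrained log-domain F0" concrete: the abandoned baseline regresses log-F0 with a mean-squared error on voiced frames, which is exactly where bad uv/v decisions bite. A toy sketch with random placeholder tensors (none of this is the repo's code):

```python
import torch
import torch.nn.functional as F

# Toy illustration of the abandoned objective: regressing log-F0 with MSE.
# Random placeholders stand in for model outputs and ground truth.
f0_pred = 100 + 200 * torch.rand(2, 100)   # predicted F0 in Hz, [B, T]
f0_gt = 100 + 200 * torch.rand(2, 100)     # ground-truth F0 in Hz, [B, T]
voiced = torch.rand(2, 100) > 0.2          # uv/v mask; mispredicting it is the pain point
loss = F.mse_loss(torch.log(f0_pred[voiced]), torch.log(f0_gt[voiced]))
```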
 
docs/README-SVS-popcs.md CHANGED
@@ -54,7 +54,7 @@ Remember to put the pre-trained models in `checkpoints` directory.
 *Note that:*

 - *the original PWG version vocoder we used in the paper has been put into commercial use, so we provide this HifiGAN version vocoder as a substitute.*
- - *we assume the ground-truth F0 to be given as the pitch information following [1][2][3]. If you want to conduct experiments on MIDI data, you need an external F0 predictor (like [MIDI-old-version](README-SVS-opencpop-cascade.md)) or a joint prediction with spectrograms (like [MIDI-new-version](README-SVS-opencpop-e2e.md)).*
+ - *we assume the ground-truth F0 to be given as the pitch information following [1][2][3]. If you want to conduct experiments on MIDI data, you need an external F0 predictor (like [MIDI-A-version](README-SVS-opencpop-cascade.md)) or a joint prediction with spectrograms (like [MIDI-B-version](README-SVS-opencpop-e2e.md)).*

 [1] Adversarially trained multi-singer sequence-to-sequence singing synthesizer. Interspeech 2020.

docs/README-SVS.md CHANGED
@@ -1,7 +1,13 @@
- ## DiffSinger (SVS version)
+ # DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
+ [![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
+ [![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
+ | [Interactive🤗 SVS](https://huggingface.co/spaces/Silentlin/DiffSinger)
+
+ ## DiffSinger (SVS)

 ### PART1. [Run DiffSinger on PopCS](README-SVS-popcs.md)
- In this part, we only focus on spectrum modeling (acoustic model) and assume the ground-truth (GT) F0 to be given as the pitch information following these papers [1][2][3].
+ In PART1, we only focus on spectrum modeling (acoustic model) and assume the ground-truth (GT) F0 to be given as the pitch information following these papers [1][2][3]. If you want to conduct experiments with F0 prediction, please move to PART2.

 Thus, the pipeline of this part can be summarized as:

@@ -18,13 +24,16 @@ Thus, the pipeline of this part can be summarized as:

 [3] DeepSinger: Singing Voice Synthesis with Data Mined From the Web. KDD 2020.

+ Click here for detailed instructions: [link](README-SVS-popcs.md).
+
+
 ### PART2. [Run DiffSinger on Opencpop](README-SVS-opencpop-cascade.md)
- Thanks to the [Opencpop team](https://wenet.org.cn/opencpop/) for releasing their SVS dataset with MIDI labels, **Jan. 20, 2022**. (Also thanks to my co-author [Yi Ren](https://github.com/RayeRen), who applied for the dataset and did some preprocessing work for this part.)
+ Thanks to the [Opencpop team](https://wenet.org.cn/opencpop/) for releasing their SVS dataset with MIDI labels on **Jan. 20, 2022** (after we published our paper).

 Since there are elaborately annotated MIDI labels, we are able to supplement the pipeline in PART 1 by adding a naive melody frontend.

- #### 2.1
- Thus, the pipeline of [this part](README-SVS-opencpop-cascade.md) can be summarized as:
+ #### 2.A
+ Thus, the pipeline of [2.A](README-SVS-opencpop-cascade.md) can be summarized as:

 ```
 [lyrics] + [MIDI] -> [linguistic representation (with MIDI information)] + [predicted F0] + [predicted phoneme duration] (Melody frontend)
@@ -32,13 +41,36 @@ Thus, the pipeline of [this part](README-SVS-opencpop-cascade.md) can be summari
 [mel-spectrogram] + [predicted F0] -> [waveform] (Vocoder)
 ```

- #### 2.2
- In 2.1, we find that if we predict F0 explicitly in the melody frontend, there will be many bad cases of uv/v prediction. Then, we abandon the explicit prediction of the F0 curve in the melody frontend but make a joint prediction with spectrograms.
+ Click here for detailed instructions: [link](README-SVS-opencpop-cascade.md).
+
+ #### 2.B
+ In 2.A, we find that if we predict F0 explicitly in the melody frontend, there will be many bad cases of uv/v prediction. Therefore, we abandon the explicit prediction of the F0 curve in the melody frontend and make a joint prediction with the spectrogram.

- Thus, the pipeline of [this part](README-SVS-opencpop-e2e.md) can be summarized as:
+ Thus, the pipeline of [2.B](README-SVS-opencpop-e2e.md) can be summarized as:
 ```
 [lyrics] + [MIDI] -> [linguistic representation] + [predicted phoneme duration] (Melody frontend)
 [linguistic representation (with MIDI information)] + [predicted phoneme duration] -> [mel-spectrogram] (Acoustic model)
 [mel-spectrogram] -> [predicted F0] (Pitch extractor)
 [mel-spectrogram] + [predicted F0] -> [waveform] (Vocoder)
- ```
+ ```
+
+ Click here for detailed instructions: [link](README-SVS-opencpop-e2e.md).
+
+ ### FAQ
+ Q1: Why do I need F0 in the vocoder?
+
+ A1: See the vocoder parts of HiFiSinger, DiffSinger, or SingGAN; this is common practice now.
+
+ Q2: Why not run the MIDI version of SVS on the PopCS dataset? Or why not release MIDI labels for the PopCS dataset?
+
+ A2: Our laboratory has no funds to label the PopCS dataset, but there are funds for labeling another singing dataset, which is coming soon.
+
+ Q3: Why do I get "'HifiGAN' object has no attribute 'model'"?
+
+ A3: Please put the pretrained vocoders in your `checkpoints` directory.
+
+ Q4: How can I check whether GT or predicted information is used during inference from the packed test set?
+
+ A4: Please see the code [here](https://github.com/MoonInTheRiver/DiffSinger/blob/55e2f46068af6e69940a9f8f02d306c24a940cab/tasks/tts/fs2.py#L343).
+
+ ...
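The 2.B pipeline summary above reads naturally as a straight composition of four stages. A schematic sketch with stub functions (this is not the repo's API; every name, rate, and shape here is an illustrative assumption):

```python
import numpy as np

# Stage stubs only: in the repo each stage is a trained neural network.
def melody_frontend(lyrics, midi_notes):
    """[lyrics] + [MIDI] -> linguistic representation + predicted phoneme durations."""
    ling = list(zip(lyrics.split(), midi_notes))
    return ling, [0.3] * len(ling)                    # placeholder durations (seconds)

def acoustic_model(ling, durs, hop_s=512 / 24000):
    """Linguistic representation (with MIDI) + durations -> mel-spectrogram."""
    return np.zeros((int(sum(durs) / hop_s), 80))     # [T_frames, n_mels]

def pitch_extractor(mel):
    """Mel-spectrogram -> predicted F0 (jointly, not by the frontend)."""
    return np.zeros(len(mel))

def vocoder(mel, f0, hop=512):
    """Mel-spectrogram + F0 -> waveform."""
    return np.zeros(len(mel) * hop)

ling, durs = melody_frontend('x iao', ['C#4/Db4', 'C#4/Db4'])
mel = acoustic_model(ling, durs)
wav = vocoder(mel, pitch_extractor(mel))
print(mel.shape, wav.shape)
```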
docs/README-TTS.md CHANGED
@@ -1,4 +1,10 @@
- ## DiffSpeech (TTS version)
+ # DiffSinger: Singing Voice Synthesis via Shallow Diffusion Mechanism
+ [![arXiv](https://img.shields.io/badge/arXiv-Paper-<COLOR>.svg)](https://arxiv.org/abs/2105.02446)
+ [![GitHub Stars](https://img.shields.io/github/stars/MoonInTheRiver/DiffSinger?style=social)](https://github.com/MoonInTheRiver/DiffSinger)
+ [![downloads](https://img.shields.io/github/downloads/MoonInTheRiver/DiffSinger/total.svg)](https://github.com/MoonInTheRiver/DiffSinger/releases)
+ | [Interactive🤗 TTS](https://huggingface.co/spaces/NATSpeech/DiffSpeech)
+
+ ## DiffSpeech (TTS)
 ### 1. Preparation

 #### Data Preparation
inference/svs/base_svs_infer.py CHANGED
@@ -142,7 +142,7 @@ class BaseSVSInfer:
         ph_seq = inp['ph_seq']
         note_lst = inp['note_seq'].split()
         midi_dur_lst = inp['note_dur_seq'].split()
-         is_slur = inp['is_slur_seq'].split()
+         is_slur = [float(x) for x in inp['is_slur_seq'].split()]
         print(len(note_lst), len(ph_seq.split()), len(midi_dur_lst))
         if len(note_lst) == len(ph_seq.split()) == len(midi_dur_lst):
             print('Pass word-notes check.')
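The one-line change above converts the slur flags from strings to numbers at parse time, presumably so they can feed numeric ops and embeddings downstream without a later cast. Its effect in isolation:

```python
# Before vs. after the change above, on a toy slur sequence.
inp = {'is_slur_seq': '0 0 1 0'}
old = inp['is_slur_seq'].split()                      # ['0', '0', '1', '0'] -- strings
new = [float(x) for x in inp['is_slur_seq'].split()]  # [0.0, 0.0, 1.0, 0.0] -- numbers
assert all(isinstance(x, float) for x in new)
```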
inference/svs/ds_cascade.py CHANGED
@@ -52,3 +52,5 @@ if __name__ == '__main__':
         'input_type': 'phoneme'
     }  # input like Opencpop dataset.
     DiffSingerCascadeInfer.example_run(inp)
+
+ # CUDA_VISIBLE_DEVICES=1 python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
inference/svs/ds_e2e.py CHANGED
@@ -53,7 +53,7 @@ if __name__ == '__main__':
         'notes_duration': '0.407140 | 0.376190 | 0.242180 | 0.509550 0.183420 | 0.315400 0.235020 | 0.361660 | 0.223070 | 0.377270 | 0.340550 | 0.299620 | 0.344510 | 0.283770 | 0.323390 | 0.360340',
         'input_type': 'word'
     }  # user input: Chinese characters
-     c = {
+     inp = {
         'text': '小酒窝长睫毛AP是你最美的记号',
         'ph_seq': 'x iao j iu w o ch ang ang j ie ie m ao AP sh i n i z ui m ei d e j i h ao',
         'note_seq': 'C#4/Db4 C#4/Db4 F#4/Gb4 F#4/Gb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 F#4/Gb4 F#4/Gb4 F#4/Gb4 C#4/Db4 C#4/Db4 C#4/Db4 rest C#4/Db4 C#4/Db4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 A#4/Bb4 A#4/Bb4 G#4/Ab4 G#4/Ab4 F4 F4 C#4/Db4 C#4/Db4',
@@ -64,4 +64,4 @@ if __name__ == '__main__':
     DiffSingerE2EInfer.example_run(inp)


- # python inference/svs/ds_e2e.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
+ # CUDA_VISIBLE_DEVICES=3 python inference/svs/ds_e2e.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
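The `c = {` to `inp = {` rename above matters because of Python name binding: the script builds two example inputs, and `example_run(inp)` only sees whatever `inp` last pointed to. Before the fix, the second (phoneme-level) example was bound to an unused name `c` and never ran:

```python
# Minimal illustration of the rebinding fixed above.
inp = {'input_type': 'word'}      # first example input
inp = {'input_type': 'phoneme'}   # rebinding: this is what example_run(inp) receives
assert inp['input_type'] == 'phoneme'
```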
inference/svs/gradio/infer.py CHANGED
@@ -88,4 +88,4 @@ if __name__ == '__main__':

 # python inference/svs/gradio/infer.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
 # python inference/svs/ds_cascade.py --config usr/configs/midi/cascade/opencs/ds60_rel.yaml --exp_name 0303_opencpop_ds58_midi
- # CUDA_VISIBLE_DEVICES=3 python inference/svs/gradio/infer.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
+ # CUDA_VISIBLE_DEVICES=3 python inference/svs/gradio/infer.py --config usr/configs/midi/e2e/opencpop/ds100_adj_rel.yaml --exp_name 0228_opencpop_ds100_rel
modules/diffsinger_midi/fs2.py CHANGED
@@ -116,3 +116,113 @@ class FastSpeech2MIDI(FastSpeech2):
         ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)

         return ret
+
+     def add_pitch(self, decoder_inp, f0, uv, mel2ph, ret, encoder_out=None):
+         decoder_inp = decoder_inp.detach() + hparams['predictor_grad'] * (decoder_inp - decoder_inp.detach())
+         pitch_padding = mel2ph == 0
+         if hparams['pitch_ar']:
+             ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp, f0 if self.training else None)
+             if f0 is None:
+                 f0 = pitch_pred[:, :, 0]
+         else:
+             ret['pitch_pred'] = pitch_pred = self.pitch_predictor(decoder_inp)
+             if f0 is None:
+                 f0 = pitch_pred[:, :, 0]
+             if hparams['use_uv'] and uv is None:
+                 uv = pitch_pred[:, :, 1] > 0
+
+         # here f0_denorm for pitch prediction
+         ret['f0_denorm'] = denorm_f0(f0, uv, hparams, pitch_padding=pitch_padding)
+
+         # here f0_denorm for mel prediction
+         if self.training:
+             mask = torch.full(uv.shape, hparams.get('mask_uv_prob', 0.)).to(f0.device)
+             masked_uv = torch.bernoulli(mask).bool().to(f0.device)  # emit a random uv flag with probability `mask_uv_prob`
+             uv_masked = uv.bool() | masked_uv
+             # print((uv.float() - uv_masked.float()).mean(dim=1))
+             f0_denorm = denorm_f0(f0, uv_masked, hparams, pitch_padding=pitch_padding)
+         else:
+             f0_denorm = ret['f0_denorm']
+
+         if pitch_padding is not None:
+             f0[pitch_padding] = 0
+
+         pitch = f0_to_coarse(f0_denorm)  # start from 0
+         pitch_embed = self.pitch_embed(pitch)
+         return pitch_embed
+
+
+ class FastSpeech2MIDIMasked(FastSpeech2MIDI):
+     def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
+                 ref_mels=None, f0=None, uv=None, energy=None, skip_decoder=False,
+                 spk_embed_dur_id=None, spk_embed_f0_id=None, infer=False, **kwargs):
+         ret = {}
+
+         midi_dur_embedding, slur_embedding = 0, 0
+         if kwargs.get('midi_dur') is not None:
+             midi_dur_embedding = self.midi_dur_layer(kwargs['midi_dur'][:, :, None])  # [B, T, 1] -> [B, T, H]
+         if kwargs.get('is_slur') is not None:
+             slur_embedding = self.is_slur_embed(kwargs['is_slur'])
+         encoder_out = self.encoder(txt_tokens, 0, midi_dur_embedding, slur_embedding)  # [B, T, C]
+         src_nonpadding = (txt_tokens > 0).float()[:, :, None]
+
+         # add ref style embed
+         # Not implemented
+         # variance encoder
+         var_embed = 0
+
+         # encoder_out_dur denotes encoder outputs for the duration predictor
+         # in speech adaptation, the duration predictor uses the old speaker embedding
+         if hparams['use_spk_embed']:
+             spk_embed_dur = spk_embed_f0 = spk_embed = self.spk_embed_proj(spk_embed)[:, None, :]
+         elif hparams['use_spk_id']:
+             spk_embed_id = spk_embed
+             if spk_embed_dur_id is None:
+                 spk_embed_dur_id = spk_embed_id
+             if spk_embed_f0_id is None:
+                 spk_embed_f0_id = spk_embed_id
+             spk_embed = self.spk_embed_proj(spk_embed_id)[:, None, :]
+             spk_embed_dur = spk_embed_f0 = spk_embed
+             if hparams['use_split_spk_id']:
+                 spk_embed_dur = self.spk_embed_dur(spk_embed_dur_id)[:, None, :]
+                 spk_embed_f0 = self.spk_embed_f0(spk_embed_f0_id)[:, None, :]
+         else:
+             spk_embed_dur = spk_embed_f0 = spk_embed = 0
+
+         # add dur
+         dur_inp = (encoder_out + var_embed + spk_embed_dur) * src_nonpadding
+
+         mel2ph = self.add_dur(dur_inp, mel2ph, txt_tokens, ret)
+
+         decoder_inp = F.pad(encoder_out, [0, 0, 1, 0])
+
+         mel2ph_ = mel2ph[..., None].repeat([1, 1, encoder_out.shape[-1]])
+         decoder_inp = torch.gather(decoder_inp, 1, mel2ph_)  # [B, T, H]
+
+         # expanded midi
+         midi_embedding = self.midi_embed(kwargs['pitch_midi'])
+         midi_embedding = F.pad(midi_embedding, [0, 0, 1, 0])
+         midi_embedding = torch.gather(midi_embedding, 1, mel2ph_)
+         print(midi_embedding.shape, decoder_inp.shape)
+         midi_mask = torch.full(midi_embedding.shape, hparams.get('mask_uv_prob', 0.)).to(midi_embedding.device)
+         midi_mask = 1 - torch.bernoulli(midi_mask).bool().to(midi_embedding.device)  # emit a random mask with probability `mask_uv_prob`
+
+         tgt_nonpadding = (mel2ph > 0).float()[:, :, None]
+
+         decoder_inp += midi_embedding
+         decoder_inp_origin = decoder_inp
+         # add pitch and energy embed
+         pitch_inp = (decoder_inp_origin + var_embed + spk_embed_f0) * tgt_nonpadding
+         if hparams['use_pitch_embed']:
+             pitch_inp_ph = (encoder_out + var_embed + spk_embed_f0) * src_nonpadding
+             decoder_inp = decoder_inp + self.add_pitch(pitch_inp, f0, uv, mel2ph, ret, encoder_out=pitch_inp_ph)
+         if hparams['use_energy_embed']:
+             decoder_inp = decoder_inp + self.add_energy(pitch_inp, energy, ret)
+
+         ret['decoder_inp'] = decoder_inp = (decoder_inp + spk_embed) * tgt_nonpadding
+
+         if skip_decoder:
+             return ret
+         ret['mel_out'] = self.run_decoder(decoder_inp, tgt_nonpadding, ret, infer=infer, **kwargs)
+
+         return ret
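The `torch.bernoulli` trick in `add_pitch` above is worth isolating: during training, each frame's unvoiced flag is ORed with a random flag drawn with probability `mask_uv_prob`, so the mel decoder occasionally sees voiced frames rendered as unvoiced. A self-contained sketch (toy shapes, `p` stands in for the hparam):

```python
import torch

p = 0.1                                    # stands in for hparams.get('mask_uv_prob', 0.)
uv = torch.zeros(2, 8)                     # 0 = voiced, 1 = unvoiced; toy batch [B, T]
mask = torch.full(uv.shape, p)             # per-frame Bernoulli probability
masked_uv = torch.bernoulli(mask).bool()   # True with probability p
uv_masked = uv.bool() | masked_uv          # OR in the random unvoiced flags
print(uv_masked.float().mean().item())     # ~p on average
```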
modules/hifigan/hifigan.py CHANGED
@@ -1,365 +1,365 @@
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+
+ from modules.parallel_wavegan.layers import UpsampleNetwork, ConvInUpsampleNetwork
+ from modules.parallel_wavegan.models.source import SourceModuleHnNSF
+ import numpy as np
+
+ LRELU_SLOPE = 0.1
+
+
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+
+
+ def apply_weight_norm(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         weight_norm(m)
+
+
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size * dilation - dilation) / 2)
+
+
+ class ResBlock1(torch.nn.Module):
+     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+         super(ResBlock1, self).__init__()
+         self.h = h
+         self.convs1 = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                                padding=get_padding(kernel_size, dilation[0]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                                padding=get_padding(kernel_size, dilation[1]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                                padding=get_padding(kernel_size, dilation[2])))
+         ])
+         self.convs1.apply(init_weights)
+
+         self.convs2 = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1)))
+         ])
+         self.convs2.apply(init_weights)
+
+     def forward(self, x):
+         for c1, c2 in zip(self.convs1, self.convs2):
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             xt = c1(xt)
+             xt = F.leaky_relu(xt, LRELU_SLOPE)
+             xt = c2(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+
+
+ class ResBlock2(torch.nn.Module):
+     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+         super(ResBlock2, self).__init__()
+         self.h = h
+         self.convs = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                                padding=get_padding(kernel_size, dilation[0]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                                padding=get_padding(kernel_size, dilation[1])))
+         ])
+         self.convs.apply(init_weights)
+
+     def forward(self, x):
+         for c in self.convs:
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             xt = c(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs:
+             remove_weight_norm(l)
+
+
+ class Conv1d1x1(Conv1d):
+     """1x1 Conv1d with customized initialization."""
+
+     def __init__(self, in_channels, out_channels, bias):
+         """Initialize 1x1 Conv1d module."""
+         super(Conv1d1x1, self).__init__(in_channels, out_channels,
+                                         kernel_size=1, padding=0,
+                                         dilation=1, bias=bias)
+
+
+ class HifiGanGenerator(torch.nn.Module):
+     def __init__(self, h, c_out=1):
+         super(HifiGanGenerator, self).__init__()
+         self.h = h
+         self.num_kernels = len(h['resblock_kernel_sizes'])
+         self.num_upsamples = len(h['upsample_rates'])
+
+         if h['use_pitch_embed']:
+             self.harmonic_num = 8
+             self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h['upsample_rates']))
+             self.m_source = SourceModuleHnNSF(
+                 sampling_rate=h['audio_sample_rate'],
+                 harmonic_num=self.harmonic_num)
+             self.noise_convs = nn.ModuleList()
+         self.conv_pre = weight_norm(Conv1d(80, h['upsample_initial_channel'], 7, 1, padding=3))
+         resblock = ResBlock1 if h['resblock'] == '1' else ResBlock2
+
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(h['upsample_rates'], h['upsample_kernel_sizes'])):
+             c_cur = h['upsample_initial_channel'] // (2 ** (i + 1))
+             self.ups.append(weight_norm(
+                 ConvTranspose1d(c_cur * 2, c_cur, k, u, padding=(k - u) // 2)))
+             if h['use_pitch_embed']:
+                 if i + 1 < len(h['upsample_rates']):
+                     stride_f0 = np.prod(h['upsample_rates'][i + 1:])
+                     self.noise_convs.append(Conv1d(
+                         1, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=stride_f0 // 2))
+                 else:
+                     self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
+
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = h['upsample_initial_channel'] // (2 ** (i + 1))
+             for j, (k, d) in enumerate(zip(h['resblock_kernel_sizes'], h['resblock_dilation_sizes'])):
+                 self.resblocks.append(resblock(h, ch, k, d))
+
+         self.conv_post = weight_norm(Conv1d(ch, c_out, 7, 1, padding=3))
+         self.ups.apply(init_weights)
+         self.conv_post.apply(init_weights)
+
+     def forward(self, x, f0=None):
+         if f0 is not None:
+             # harmonic-source signal, noise-source signal, uv flag
+             f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)
+             har_source, noi_source, uv = self.m_source(f0)
+             har_source = har_source.transpose(1, 2)
+
+         x = self.conv_pre(x)
+         for i in range(self.num_upsamples):
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             x = self.ups[i](x)
+             if f0 is not None:
+                 x_source = self.noise_convs[i](har_source)
+                 x = x + x_source
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i * self.num_kernels + j](x)
+                 else:
+                     xs += self.resblocks[i * self.num_kernels + j](x)
+             x = xs / self.num_kernels
+         x = F.leaky_relu(x)
+         x = self.conv_post(x)
+         x = torch.tanh(x)
+
+         return x
+
+     def remove_weight_norm(self):
+         print('Removing weight norm...')
+         for l in self.ups:
+             remove_weight_norm(l)
+         for l in self.resblocks:
+             l.remove_weight_norm()
+         remove_weight_norm(self.conv_pre)
+         remove_weight_norm(self.conv_post)
+
+
+ class DiscriminatorP(torch.nn.Module):
+     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False, use_cond=False, c_in=1):
+         super(DiscriminatorP, self).__init__()
+         self.use_cond = use_cond
+         if use_cond:
+             from utils.hparams import hparams
+             t = hparams['hop_size']
+             self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
+             c_in = 2
+
+         self.period = period
+         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+         self.convs = nn.ModuleList([
+             norm_f(Conv2d(c_in, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+         ])
+         self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+     def forward(self, x, mel):
+         fmap = []
+         if self.use_cond:
+             x_mel = self.cond_net(mel)
+             x = torch.cat([x_mel, x], 1)
+         # 1d to 2d
+         b, c, t = x.shape
+         if t % self.period != 0:  # pad first
+             n_pad = self.period - (t % self.period)
+             x = F.pad(x, (0, n_pad), "reflect")
+             t = t + n_pad
+         x = x.view(b, c, t // self.period, self.period)
+
+         for l in self.convs:
+             x = l(x)
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiPeriodDiscriminator(torch.nn.Module):
+     def __init__(self, use_cond=False, c_in=1):
+         super(MultiPeriodDiscriminator, self).__init__()
+         self.discriminators = nn.ModuleList([
+             DiscriminatorP(2, use_cond=use_cond, c_in=c_in),
+             DiscriminatorP(3, use_cond=use_cond, c_in=c_in),
+             DiscriminatorP(5, use_cond=use_cond, c_in=c_in),
+             DiscriminatorP(7, use_cond=use_cond, c_in=c_in),
+             DiscriminatorP(11, use_cond=use_cond, c_in=c_in),
+         ])
+
+     def forward(self, y, y_hat, mel=None):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for i, d in enumerate(self.discriminators):
+             y_d_r, fmap_r = d(y, mel)
+             y_d_g, fmap_g = d(y_hat, mel)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class DiscriminatorS(torch.nn.Module):
+     def __init__(self, use_spectral_norm=False, use_cond=False, upsample_rates=None, c_in=1):
+         super(DiscriminatorS, self).__init__()
+         self.use_cond = use_cond
+         if use_cond:
+             t = np.prod(upsample_rates)
+             self.cond_net = torch.nn.ConvTranspose1d(80, 1, t * 2, stride=t, padding=t // 2)
+             c_in = 2
+         norm_f = weight_norm if use_spectral_norm == False else spectral_norm
+         self.convs = nn.ModuleList([
+             norm_f(Conv1d(c_in, 128, 15, 1, padding=7)),
+             norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+             norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+             norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+             norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+             norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+             norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+         ])
+         self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+     def forward(self, x, mel):
+         if self.use_cond:
+             x_mel = self.cond_net(mel)
+             x = torch.cat([x_mel, x], 1)
+         fmap = []
+         for l in self.convs:
+             x = l(x)
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiScaleDiscriminator(torch.nn.Module):
+     def __init__(self, use_cond=False, c_in=1):
+         super(MultiScaleDiscriminator, self).__init__()
+         from utils.hparams import hparams
+         self.discriminators = nn.ModuleList([
+             DiscriminatorS(use_spectral_norm=True, use_cond=use_cond,
+                            upsample_rates=[4, 4, hparams['hop_size'] // 16],
+                            c_in=c_in),
+             DiscriminatorS(use_cond=use_cond,
+                            upsample_rates=[4, 4, hparams['hop_size'] // 32],
+                            c_in=c_in),
+             DiscriminatorS(use_cond=use_cond,
+                            upsample_rates=[4, 4, hparams['hop_size'] // 64],
+                            c_in=c_in),
+         ])
+         self.meanpools = nn.ModuleList([
+             AvgPool1d(4, 2, padding=1),
+             AvgPool1d(4, 2, padding=1)
+         ])
+
+     def forward(self, y, y_hat, mel=None):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for i, d in enumerate(self.discriminators):
+             if i != 0:
+                 y = self.meanpools[i - 1](y)
+                 y_hat = self.meanpools[i - 1](y_hat)
+             y_d_r, fmap_r = d(y, mel)
+             y_d_g, fmap_g = d(y_hat, mel)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ def feature_loss(fmap_r, fmap_g):
+     loss = 0
+     for dr, dg in zip(fmap_r, fmap_g):
+         for rl, gl in zip(dr, dg):
+             loss += torch.mean(torch.abs(rl - gl))
+
+     return loss * 2
+
+
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+     r_losses = 0
+     g_losses = 0
+     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+         r_loss = torch.mean((1 - dr) ** 2)
+         g_loss = torch.mean(dg ** 2)
+         r_losses += r_loss
+         g_losses += g_loss
+     r_losses = r_losses / len(disc_real_outputs)
+     g_losses = g_losses / len(disc_real_outputs)
+     return r_losses, g_losses
+
+
+ def cond_discriminator_loss(outputs):
+     loss = 0
+     for dg in outputs:
+         g_loss = torch.mean(dg ** 2)
+         loss += g_loss
+     loss = loss / len(outputs)
+     return loss
+
+
+ def generator_loss(disc_outputs):
+     loss = 0
+     for dg in disc_outputs:
+         l = torch.mean((1 - dg) ** 2)
+         loss += l
+     loss = loss / len(disc_outputs)
+     return loss
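For orientation, a minimal smoke test of `HifiGanGenerator` under an assumed config: the key names are exactly the ones the class reads above, while the values (standard HiFi-GAN V1-style settings) are illustrative and should be replaced by the repo's real hparams.

```python
import torch

# Assumed config; only the key names are taken from the class above.
h = {
    'resblock': '1',                                   # selects ResBlock1
    'resblock_kernel_sizes': [3, 7, 11],
    'resblock_dilation_sizes': [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    'upsample_rates': [8, 8, 2, 2],                    # product = 256 samples per frame
    'upsample_kernel_sizes': [16, 16, 4, 4],
    'upsample_initial_channel': 512,
    'use_pitch_embed': False,                          # skip the NSF source module here
    'audio_sample_rate': 22050,                        # only read when use_pitch_embed=True
}
net = HifiGanGenerator(h)
mel = torch.randn(1, 80, 100)                          # [B, n_mels, T_frames]
wav = net(mel)                                         # [B, 1, T_frames * prod(upsample_rates)]
print(wav.shape)                                       # torch.Size([1, 1, 25600])
```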
modules/hifigan/mel_utils.py CHANGED
@@ -1,80 +1,80 @@
+ import numpy as np
+ import torch
+ import torch.utils.data
+ from librosa.filters import mel as librosa_mel_fn
+ from scipy.io.wavfile import read
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def load_wav(full_path):
+     sampling_rate, data = read(full_path)
+     return data, sampling_rate
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ mel_basis = {}
+ hann_window = {}
+
+
+ def mel_spectrogram(y, hparams, center=False, complex=False):
+     # hop_size: 512  # For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate)
+     # win_size: 2048  # For 22050Hz, 1100 ~= 50 ms (If None, win_size: fft_size) (0.05 * sample_rate)
+     # fmin: 55  # Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+     # fmax: 10000  # To be increased/reduced depending on data.
+     # fft_size: 2048  # Extra window size is filled with 0 paddings to match this parameter
+     # n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax,
+     n_fft = hparams['fft_size']
+     num_mels = hparams['audio_num_mel_bins']
+     sampling_rate = hparams['audio_sample_rate']
+     hop_size = hparams['hop_size']
+     win_size = hparams['win_size']
+     fmin = hparams['fmin']
+     fmax = hparams['fmax']
+     y = y.clamp(min=-1., max=1.)
+     global mel_basis, hann_window
+     if fmax not in mel_basis:
+         mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
+         mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+     y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
+                                 mode='reflect')
+     y = y.squeeze(1)
+
+     spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+                       center=center, pad_mode='reflect', normalized=False, onesided=True)
+
+     if not complex:
+         spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+         spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
+         spec = spectral_normalize_torch(spec)
+     else:
+         B, C, T, _ = spec.shape
+         spec = spec.transpose(1, 2)  # [B, T, n_fft, 2]
+     return spec
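A usage sketch for `mel_spectrogram` with a hypothetical `hparams` dict; the keys are exactly the ones the function reads, and the values follow the comments inside it. (Note in passing that the cache test `if fmax not in mel_basis` never matches the `f'{fmax}_{device}'` keys actually stored, so the mel basis is rebuilt on every call.)

```python
import torch

# Hypothetical hparams; key names taken from mel_spectrogram above,
# values from the comments inside the function.
hparams = {
    'fft_size': 2048,
    'audio_num_mel_bins': 80,
    'audio_sample_rate': 22050,
    'hop_size': 512,
    'win_size': 2048,
    'fmin': 55,
    'fmax': 10000,
}
y = torch.randn(1, 22050).clamp(-1., 1.)  # one second of placeholder audio, [B, T]
mel = mel_spectrogram(y, hparams)         # [B, audio_num_mel_bins, T_frames]
print(mel.shape)
```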
modules/parallel_wavegan/models/parallel_wavegan.py CHANGED
@@ -1,434 +1,434 @@
- # -*- coding: utf-8 -*-
-
- # Copyright 2019 Tomoki Hayashi
- # MIT License (https://opensource.org/licenses/MIT)
-
- """Parallel WaveGAN Modules."""
-
- import logging
- import math
-
- import torch
- from torch import nn
-
- from modules.parallel_wavegan.layers import Conv1d
- from modules.parallel_wavegan.layers import Conv1d1x1
- from modules.parallel_wavegan.layers import ResidualBlock
- from modules.parallel_wavegan.layers import upsample
- from modules.parallel_wavegan import models
-
-
- class ParallelWaveGANGenerator(torch.nn.Module):
-     """Parallel WaveGAN Generator module."""
-
-     def __init__(self,
-                  in_channels=1,
-                  out_channels=1,
-                  kernel_size=3,
-                  layers=30,
-                  stacks=3,
-                  residual_channels=64,
-                  gate_channels=128,
-                  skip_channels=64,
-                  aux_channels=80,
-                  aux_context_window=2,
-                  dropout=0.0,
-                  bias=True,
-                  use_weight_norm=True,
-                  use_causal_conv=False,
-                  upsample_conditional_features=True,
-                  upsample_net="ConvInUpsampleNetwork",
-                  upsample_params={"upsample_scales": [4, 4, 4, 4]},
-                  use_pitch_embed=False,
-                  ):
-         """Initialize Parallel WaveGAN Generator module.
-
-         Args:
-             in_channels (int): Number of input channels.
-             out_channels (int): Number of output channels.
-             kernel_size (int): Kernel size of dilated convolution.
-             layers (int): Number of residual block layers.
-             stacks (int): Number of stacks i.e., dilation cycles.
-             residual_channels (int): Number of channels in residual conv.
-             gate_channels (int): Number of channels in gated conv.
-             skip_channels (int): Number of channels in skip conv.
-             aux_channels (int): Number of channels for auxiliary feature conv.
-             aux_context_window (int): Context window size for auxiliary feature.
-             dropout (float): Dropout rate. 0.0 means no dropout applied.
-             bias (bool): Whether to use bias parameter in conv layer.
-             use_weight_norm (bool): Whether to use weight norm.
-                 If set to true, it will be applied to all of the conv layers.
-             use_causal_conv (bool): Whether to use causal structure.
-             upsample_conditional_features (bool): Whether to use upsampling network.
-             upsample_net (str): Upsampling network architecture.
-             upsample_params (dict): Upsampling network parameters.
-
-         """
-         super(ParallelWaveGANGenerator, self).__init__()
-         self.in_channels = in_channels
-         self.out_channels = out_channels
-         self.aux_channels = aux_channels
-         self.layers = layers
-         self.stacks = stacks
-         self.kernel_size = kernel_size
-
-         # check the number of layers and stacks
-         assert layers % stacks == 0
-         layers_per_stack = layers // stacks
-
-         # define first convolution
-         self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
-
-         # define conv + upsampling network
-         if upsample_conditional_features:
-             upsample_params.update({
-                 "use_causal_conv": use_causal_conv,
-             })
-             if upsample_net == "MelGANGenerator":
-                 assert aux_context_window == 0
-                 upsample_params.update({
-                     "use_weight_norm": False,  # not to apply twice
-                     "use_final_nonlinear_activation": False,
-                 })
-                 self.upsample_net = getattr(models, upsample_net)(**upsample_params)
-             else:
-                 if upsample_net == "ConvInUpsampleNetwork":
-                     upsample_params.update({
-                         "aux_channels": aux_channels,
-                         "aux_context_window": aux_context_window,
-                     })
-                 self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
-         else:
-             self.upsample_net = None
-
-         # define residual blocks
-         self.conv_layers = torch.nn.ModuleList()
-         for layer in range(layers):
-             dilation = 2 ** (layer % layers_per_stack)
-             conv = ResidualBlock(
-                 kernel_size=kernel_size,
-                 residual_channels=residual_channels,
-                 gate_channels=gate_channels,
-                 skip_channels=skip_channels,
-                 aux_channels=aux_channels,
-                 dilation=dilation,
-                 dropout=dropout,
-                 bias=bias,
-                 use_causal_conv=use_causal_conv,
-             )
-             self.conv_layers += [conv]
-
-         # define output layers
-         self.last_conv_layers = torch.nn.ModuleList([
-             torch.nn.ReLU(inplace=True),
-             Conv1d1x1(skip_channels, skip_channels, bias=True),
-             torch.nn.ReLU(inplace=True),
-             Conv1d1x1(skip_channels, out_channels, bias=True),
-         ])
-
-         self.use_pitch_embed = use_pitch_embed
-         if use_pitch_embed:
-             self.pitch_embed = nn.Embedding(300, aux_channels, 0)
-             self.c_proj = nn.Linear(2 * aux_channels, aux_channels)
-
-         # apply weight norm
-         if use_weight_norm:
-             self.apply_weight_norm()
-
-     def forward(self, x, c=None, pitch=None, **kwargs):
-         """Calculate forward propagation.
-
-         Args:
-             x (Tensor): Input noise signal (B, C_in, T).
-             c (Tensor): Local conditioning auxiliary features (B, C, T').
-             pitch (Tensor): Local conditioning pitch (B, T').
-
-         Returns:
-             Tensor: Output tensor (B, C_out, T)
-
-         """
-         # perform upsampling
-         if c is not None and self.upsample_net is not None:
-             if self.use_pitch_embed:
-                 p = self.pitch_embed(pitch)
-                 c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
-             c = self.upsample_net(c)
-             assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))
-
-         # encode to hidden representation
-         x = self.first_conv(x)
-         skips = 0
-         for f in self.conv_layers:
-             x, h = f(x, c)
-             skips += h
-         skips *= math.sqrt(1.0 / len(self.conv_layers))
-
-         # apply final layers
-         x = skips
-         for f in self.last_conv_layers:
-             x = f(x)
-
-         return x
-
-     def remove_weight_norm(self):
-         """Remove weight normalization module from all of the layers."""
-         def _remove_weight_norm(m):
-             try:
-                 logging.debug(f"Weight norm is removed from {m}.")
-                 torch.nn.utils.remove_weight_norm(m)
-             except ValueError:  # this module didn't have weight norm
-                 return
-
-         self.apply(_remove_weight_norm)
-
-     def apply_weight_norm(self):
-         """Apply weight normalization module from all of the layers."""
-         def _apply_weight_norm(m):
-             if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
-                 torch.nn.utils.weight_norm(m)
-                 logging.debug(f"Weight norm is applied to {m}.")
-
-         self.apply(_apply_weight_norm)
-
-     @staticmethod
-     def _get_receptive_field_size(layers, stacks, kernel_size,
-                                   dilation=lambda x: 2 ** x):
196
- assert layers % stacks == 0
197
- layers_per_cycle = layers // stacks
198
- dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
199
- return (kernel_size - 1) * sum(dilations) + 1
200
-
201
- @property
202
- def receptive_field_size(self):
203
- """Return receptive field size."""
204
- return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
205
-
206
-
207
- class ParallelWaveGANDiscriminator(torch.nn.Module):
208
- """Parallel WaveGAN Discriminator module."""
209
-
210
- def __init__(self,
211
- in_channels=1,
212
- out_channels=1,
213
- kernel_size=3,
214
- layers=10,
215
- conv_channels=64,
216
- dilation_factor=1,
217
- nonlinear_activation="LeakyReLU",
218
- nonlinear_activation_params={"negative_slope": 0.2},
219
- bias=True,
220
- use_weight_norm=True,
221
- ):
222
- """Initialize Parallel WaveGAN Discriminator module.
223
-
224
- Args:
225
- in_channels (int): Number of input channels.
226
- out_channels (int): Number of output channels.
227
- kernel_size (int): Kernel size of conv layers.
228
- layers (int): Number of conv layers.
229
- conv_channels (int): Number of channels in conv layers.
230
- dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
231
- the dilation will be 2, 4, 8, ..., and so on.
232
- nonlinear_activation (str): Nonlinear function after each conv.
233
- nonlinear_activation_params (dict): Nonlinear function parameters
234
- bias (bool): Whether to use bias parameter in conv.
235
- use_weight_norm (bool): Whether to use weight norm.
236
- If set to true, it will be applied to all of the conv layers.
237
-
238
- """
239
- super(ParallelWaveGANDiscriminator, self).__init__()
240
- assert (kernel_size - 1) % 2 == 0, "Not support even number kernel size."
241
- assert dilation_factor > 0, "Dilation factor must be > 0."
242
- self.conv_layers = torch.nn.ModuleList()
243
- conv_in_channels = in_channels
244
- for i in range(layers - 1):
245
- if i == 0:
246
- dilation = 1
247
- else:
248
- dilation = i if dilation_factor == 1 else dilation_factor ** i
249
- conv_in_channels = conv_channels
250
- padding = (kernel_size - 1) // 2 * dilation
251
- conv_layer = [
252
- Conv1d(conv_in_channels, conv_channels,
253
- kernel_size=kernel_size, padding=padding,
254
- dilation=dilation, bias=bias),
255
- getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
256
- ]
257
- self.conv_layers += conv_layer
258
- padding = (kernel_size - 1) // 2
259
- last_conv_layer = Conv1d(
260
- conv_in_channels, out_channels,
261
- kernel_size=kernel_size, padding=padding, bias=bias)
262
- self.conv_layers += [last_conv_layer]
263
-
264
- # apply weight norm
265
- if use_weight_norm:
266
- self.apply_weight_norm()
267
-
268
- def forward(self, x):
269
- """Calculate forward propagation.
270
-
271
- Args:
272
- x (Tensor): Input noise signal (B, 1, T).
273
-
274
- Returns:
275
- Tensor: Output tensor (B, 1, T)
276
-
277
- """
278
- for f in self.conv_layers:
279
- x = f(x)
280
- return x
281
-
282
- def apply_weight_norm(self):
283
- """Apply weight normalization module from all of the layers."""
284
- def _apply_weight_norm(m):
285
- if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
286
- torch.nn.utils.weight_norm(m)
287
- logging.debug(f"Weight norm is applied to {m}.")
288
-
289
- self.apply(_apply_weight_norm)
290
-
291
- def remove_weight_norm(self):
292
- """Remove weight normalization module from all of the layers."""
293
- def _remove_weight_norm(m):
294
- try:
295
- logging.debug(f"Weight norm is removed from {m}.")
296
- torch.nn.utils.remove_weight_norm(m)
297
- except ValueError: # this module didn't have weight norm
298
- return
299
-
300
- self.apply(_remove_weight_norm)
301
-
302
-
303
- class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
304
- """Parallel WaveGAN Discriminator module."""
305
-
306
- def __init__(self,
307
- in_channels=1,
308
- out_channels=1,
309
- kernel_size=3,
310
- layers=30,
311
- stacks=3,
312
- residual_channels=64,
313
- gate_channels=128,
314
- skip_channels=64,
315
- dropout=0.0,
316
- bias=True,
317
- use_weight_norm=True,
318
- use_causal_conv=False,
319
- nonlinear_activation="LeakyReLU",
320
- nonlinear_activation_params={"negative_slope": 0.2},
321
- ):
322
- """Initialize Parallel WaveGAN Discriminator module.
323
-
324
- Args:
325
- in_channels (int): Number of input channels.
326
- out_channels (int): Number of output channels.
327
- kernel_size (int): Kernel size of dilated convolution.
328
- layers (int): Number of residual block layers.
329
- stacks (int): Number of stacks i.e., dilation cycles.
330
- residual_channels (int): Number of channels in residual conv.
331
- gate_channels (int): Number of channels in gated conv.
332
- skip_channels (int): Number of channels in skip conv.
333
- dropout (float): Dropout rate. 0.0 means no dropout applied.
334
- bias (bool): Whether to use bias parameter in conv.
335
- use_weight_norm (bool): Whether to use weight norm.
336
- If set to true, it will be applied to all of the conv layers.
337
- use_causal_conv (bool): Whether to use causal structure.
338
- nonlinear_activation_params (dict): Nonlinear function parameters
339
-
340
- """
341
- super(ResidualParallelWaveGANDiscriminator, self).__init__()
342
- assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
343
-
344
- self.in_channels = in_channels
345
- self.out_channels = out_channels
346
- self.layers = layers
347
- self.stacks = stacks
348
- self.kernel_size = kernel_size
349
-
350
- # check the number of layers and stacks
351
- assert layers % stacks == 0
352
- layers_per_stack = layers // stacks
353
-
354
- # define first convolution
355
- self.first_conv = torch.nn.Sequential(
356
- Conv1d1x1(in_channels, residual_channels, bias=True),
357
- getattr(torch.nn, nonlinear_activation)(
358
- inplace=True, **nonlinear_activation_params),
359
- )
360
-
361
- # define residual blocks
362
- self.conv_layers = torch.nn.ModuleList()
363
- for layer in range(layers):
364
- dilation = 2 ** (layer % layers_per_stack)
365
- conv = ResidualBlock(
366
- kernel_size=kernel_size,
367
- residual_channels=residual_channels,
368
- gate_channels=gate_channels,
369
- skip_channels=skip_channels,
370
- aux_channels=-1,
371
- dilation=dilation,
372
- dropout=dropout,
373
- bias=bias,
374
- use_causal_conv=use_causal_conv,
375
- )
376
- self.conv_layers += [conv]
377
-
378
- # define output layers
379
- self.last_conv_layers = torch.nn.ModuleList([
380
- getattr(torch.nn, nonlinear_activation)(
381
- inplace=True, **nonlinear_activation_params),
382
- Conv1d1x1(skip_channels, skip_channels, bias=True),
383
- getattr(torch.nn, nonlinear_activation)(
384
- inplace=True, **nonlinear_activation_params),
385
- Conv1d1x1(skip_channels, out_channels, bias=True),
386
- ])
387
-
388
- # apply weight norm
389
- if use_weight_norm:
390
- self.apply_weight_norm()
391
-
392
- def forward(self, x):
393
- """Calculate forward propagation.
394
-
395
- Args:
396
- x (Tensor): Input noise signal (B, 1, T).
397
-
398
- Returns:
399
- Tensor: Output tensor (B, 1, T)
400
-
401
- """
402
- x = self.first_conv(x)
403
-
404
- skips = 0
405
- for f in self.conv_layers:
406
- x, h = f(x, None)
407
- skips += h
408
- skips *= math.sqrt(1.0 / len(self.conv_layers))
409
-
410
- # apply final layers
411
- x = skips
412
- for f in self.last_conv_layers:
413
- x = f(x)
414
- return x
415
-
416
- def apply_weight_norm(self):
417
- """Apply weight normalization module from all of the layers."""
418
- def _apply_weight_norm(m):
419
- if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
420
- torch.nn.utils.weight_norm(m)
421
- logging.debug(f"Weight norm is applied to {m}.")
422
-
423
- self.apply(_apply_weight_norm)
424
-
425
- def remove_weight_norm(self):
426
- """Remove weight normalization module from all of the layers."""
427
- def _remove_weight_norm(m):
428
- try:
429
- logging.debug(f"Weight norm is removed from {m}.")
430
- torch.nn.utils.remove_weight_norm(m)
431
- except ValueError: # this module didn't have weight norm
432
- return
433
-
434
- self.apply(_remove_weight_norm)
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Copyright 2019 Tomoki Hayashi
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ """Parallel WaveGAN Modules."""
7
+
8
+ import logging
9
+ import math
10
+
11
+ import torch
12
+ from torch import nn
13
+
14
+ from modules.parallel_wavegan.layers import Conv1d
15
+ from modules.parallel_wavegan.layers import Conv1d1x1
16
+ from modules.parallel_wavegan.layers import ResidualBlock
17
+ from modules.parallel_wavegan.layers import upsample
18
+ from modules.parallel_wavegan import models
19
+
20
+
21
+ class ParallelWaveGANGenerator(torch.nn.Module):
22
+ """Parallel WaveGAN Generator module."""
23
+
24
+ def __init__(self,
25
+ in_channels=1,
26
+ out_channels=1,
27
+ kernel_size=3,
28
+ layers=30,
29
+ stacks=3,
30
+ residual_channels=64,
31
+ gate_channels=128,
32
+ skip_channels=64,
33
+ aux_channels=80,
34
+ aux_context_window=2,
35
+ dropout=0.0,
36
+ bias=True,
37
+ use_weight_norm=True,
38
+ use_causal_conv=False,
39
+ upsample_conditional_features=True,
40
+ upsample_net="ConvInUpsampleNetwork",
41
+ upsample_params={"upsample_scales": [4, 4, 4, 4]},
42
+ use_pitch_embed=False,
43
+ ):
44
+ """Initialize Parallel WaveGAN Generator module.
45
+
46
+ Args:
47
+ in_channels (int): Number of input channels.
48
+ out_channels (int): Number of output channels.
49
+ kernel_size (int): Kernel size of dilated convolution.
50
+ layers (int): Number of residual block layers.
51
+ stacks (int): Number of stacks i.e., dilation cycles.
52
+ residual_channels (int): Number of channels in residual conv.
53
+ gate_channels (int): Number of channels in gated conv.
54
+ skip_channels (int): Number of channels in skip conv.
55
+ aux_channels (int): Number of channels for auxiliary feature conv.
56
+ aux_context_window (int): Context window size for auxiliary feature.
57
+ dropout (float): Dropout rate. 0.0 means no dropout applied.
58
+ bias (bool): Whether to use bias parameter in conv layer.
59
+ use_weight_norm (bool): Whether to use weight norm.
60
+ If set to true, it will be applied to all of the conv layers.
61
+ use_causal_conv (bool): Whether to use causal structure.
62
+ upsample_conditional_features (bool): Whether to use upsampling network.
63
+ upsample_net (str): Upsampling network architecture.
64
+ upsample_params (dict): Upsampling network parameters.
65
+
66
+ """
67
+ super(ParallelWaveGANGenerator, self).__init__()
68
+ self.in_channels = in_channels
69
+ self.out_channels = out_channels
70
+ self.aux_channels = aux_channels
71
+ self.layers = layers
72
+ self.stacks = stacks
73
+ self.kernel_size = kernel_size
74
+
75
+ # check the number of layers and stacks
76
+ assert layers % stacks == 0
77
+ layers_per_stack = layers // stacks
78
+
79
+ # define first convolution
80
+ self.first_conv = Conv1d1x1(in_channels, residual_channels, bias=True)
81
+
82
+ # define conv + upsampling network
83
+ if upsample_conditional_features:
84
+ upsample_params.update({
85
+ "use_causal_conv": use_causal_conv,
86
+ })
87
+ if upsample_net == "MelGANGenerator":
88
+ assert aux_context_window == 0
89
+ upsample_params.update({
90
+ "use_weight_norm": False, # not to apply twice
91
+ "use_final_nonlinear_activation": False,
92
+ })
93
+ self.upsample_net = getattr(models, upsample_net)(**upsample_params)
94
+ else:
95
+ if upsample_net == "ConvInUpsampleNetwork":
96
+ upsample_params.update({
97
+ "aux_channels": aux_channels,
98
+ "aux_context_window": aux_context_window,
99
+ })
100
+ self.upsample_net = getattr(upsample, upsample_net)(**upsample_params)
101
+ else:
102
+ self.upsample_net = None
103
+
104
+ # define residual blocks
105
+ self.conv_layers = torch.nn.ModuleList()
106
+ for layer in range(layers):
107
+ dilation = 2 ** (layer % layers_per_stack)
108
+ conv = ResidualBlock(
109
+ kernel_size=kernel_size,
110
+ residual_channels=residual_channels,
111
+ gate_channels=gate_channels,
112
+ skip_channels=skip_channels,
113
+ aux_channels=aux_channels,
114
+ dilation=dilation,
115
+ dropout=dropout,
116
+ bias=bias,
117
+ use_causal_conv=use_causal_conv,
118
+ )
119
+ self.conv_layers += [conv]
120
+
121
+ # define output layers
122
+ self.last_conv_layers = torch.nn.ModuleList([
123
+ torch.nn.ReLU(inplace=True),
124
+ Conv1d1x1(skip_channels, skip_channels, bias=True),
125
+ torch.nn.ReLU(inplace=True),
126
+ Conv1d1x1(skip_channels, out_channels, bias=True),
127
+ ])
128
+
129
+ self.use_pitch_embed = use_pitch_embed
130
+ if use_pitch_embed:
131
+ self.pitch_embed = nn.Embedding(300, aux_channels, 0)
132
+ self.c_proj = nn.Linear(2 * aux_channels, aux_channels)
133
+
134
+ # apply weight norm
135
+ if use_weight_norm:
136
+ self.apply_weight_norm()
137
+
138
+ def forward(self, x, c=None, pitch=None, **kwargs):
139
+ """Calculate forward propagation.
140
+
141
+ Args:
142
+ x (Tensor): Input noise signal (B, C_in, T).
143
+ c (Tensor): Local conditioning auxiliary features (B, C ,T').
144
+ pitch (Tensor): Local conditioning pitch (B, T').
145
+
146
+ Returns:
147
+ Tensor: Output tensor (B, C_out, T)
148
+
149
+ """
150
+ # perform upsampling
151
+ if c is not None and self.upsample_net is not None:
152
+ if self.use_pitch_embed:
153
+ p = self.pitch_embed(pitch)
154
+ c = self.c_proj(torch.cat([c.transpose(1, 2), p], -1)).transpose(1, 2)
155
+ c = self.upsample_net(c)
156
+ assert c.size(-1) == x.size(-1), (c.size(-1), x.size(-1))
157
+
158
+ # encode to hidden representation
159
+ x = self.first_conv(x)
160
+ skips = 0
161
+ for f in self.conv_layers:
162
+ x, h = f(x, c)
163
+ skips += h
164
+ skips *= math.sqrt(1.0 / len(self.conv_layers))
165
+
166
+ # apply final layers
167
+ x = skips
168
+ for f in self.last_conv_layers:
169
+ x = f(x)
170
+
171
+ return x
172
+
173
+ def remove_weight_norm(self):
174
+ """Remove weight normalization module from all of the layers."""
175
+ def _remove_weight_norm(m):
176
+ try:
177
+ logging.debug(f"Weight norm is removed from {m}.")
178
+ torch.nn.utils.remove_weight_norm(m)
179
+ except ValueError: # this module didn't have weight norm
180
+ return
181
+
182
+ self.apply(_remove_weight_norm)
183
+
184
+ def apply_weight_norm(self):
185
+ """Apply weight normalization module from all of the layers."""
186
+ def _apply_weight_norm(m):
187
+ if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
188
+ torch.nn.utils.weight_norm(m)
189
+ logging.debug(f"Weight norm is applied to {m}.")
190
+
191
+ self.apply(_apply_weight_norm)
192
+
193
+ @staticmethod
194
+ def _get_receptive_field_size(layers, stacks, kernel_size,
195
+ dilation=lambda x: 2 ** x):
196
+ assert layers % stacks == 0
197
+ layers_per_cycle = layers // stacks
198
+ dilations = [dilation(i % layers_per_cycle) for i in range(layers)]
199
+ return (kernel_size - 1) * sum(dilations) + 1
200
+
201
+ @property
202
+ def receptive_field_size(self):
203
+ """Return receptive field size."""
204
+ return self._get_receptive_field_size(self.layers, self.stacks, self.kernel_size)
205
+
206
+
207
+ class ParallelWaveGANDiscriminator(torch.nn.Module):
208
+ """Parallel WaveGAN Discriminator module."""
209
+
210
+ def __init__(self,
211
+ in_channels=1,
212
+ out_channels=1,
213
+ kernel_size=3,
214
+ layers=10,
215
+ conv_channels=64,
216
+ dilation_factor=1,
217
+ nonlinear_activation="LeakyReLU",
218
+ nonlinear_activation_params={"negative_slope": 0.2},
219
+ bias=True,
220
+ use_weight_norm=True,
221
+ ):
222
+ """Initialize Parallel WaveGAN Discriminator module.
223
+
224
+ Args:
225
+ in_channels (int): Number of input channels.
226
+ out_channels (int): Number of output channels.
227
+ kernel_size (int): Kernel size of conv layers.
228
+ layers (int): Number of conv layers.
229
+ conv_channels (int): Number of channels in conv layers.
230
+ dilation_factor (int): Dilation factor. For example, if dilation_factor = 2,
231
+ the dilation will be 2, 4, 8, ..., and so on.
232
+ nonlinear_activation (str): Nonlinear function after each conv.
233
+ nonlinear_activation_params (dict): Nonlinear function parameters
234
+ bias (bool): Whether to use bias parameter in conv.
235
+ use_weight_norm (bool): Whether to use weight norm.
236
+ If set to true, it will be applied to all of the conv layers.
237
+
238
+ """
239
+ super(ParallelWaveGANDiscriminator, self).__init__()
240
+ assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
241
+ assert dilation_factor > 0, "Dilation factor must be > 0."
242
+ self.conv_layers = torch.nn.ModuleList()
243
+ conv_in_channels = in_channels
244
+ for i in range(layers - 1):
245
+ if i == 0:
246
+ dilation = 1
247
+ else:
248
+ dilation = i if dilation_factor == 1 else dilation_factor ** i
249
+ conv_in_channels = conv_channels
250
+ padding = (kernel_size - 1) // 2 * dilation
251
+ conv_layer = [
252
+ Conv1d(conv_in_channels, conv_channels,
253
+ kernel_size=kernel_size, padding=padding,
254
+ dilation=dilation, bias=bias),
255
+ getattr(torch.nn, nonlinear_activation)(inplace=True, **nonlinear_activation_params)
256
+ ]
257
+ self.conv_layers += conv_layer
258
+ padding = (kernel_size - 1) // 2
259
+ last_conv_layer = Conv1d(
260
+ conv_in_channels, out_channels,
261
+ kernel_size=kernel_size, padding=padding, bias=bias)
262
+ self.conv_layers += [last_conv_layer]
263
+
264
+ # apply weight norm
265
+ if use_weight_norm:
266
+ self.apply_weight_norm()
267
+
268
+ def forward(self, x):
269
+ """Calculate forward propagation.
270
+
271
+ Args:
272
+ x (Tensor): Input noise signal (B, 1, T).
273
+
274
+ Returns:
275
+ Tensor: Output tensor (B, 1, T)
276
+
277
+ """
278
+ for f in self.conv_layers:
279
+ x = f(x)
280
+ return x
281
+
282
+ def apply_weight_norm(self):
283
+ """Apply weight normalization module from all of the layers."""
284
+ def _apply_weight_norm(m):
285
+ if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
286
+ torch.nn.utils.weight_norm(m)
287
+ logging.debug(f"Weight norm is applied to {m}.")
288
+
289
+ self.apply(_apply_weight_norm)
290
+
291
+ def remove_weight_norm(self):
292
+ """Remove weight normalization module from all of the layers."""
293
+ def _remove_weight_norm(m):
294
+ try:
295
+ logging.debug(f"Weight norm is removed from {m}.")
296
+ torch.nn.utils.remove_weight_norm(m)
297
+ except ValueError: # this module didn't have weight norm
298
+ return
299
+
300
+ self.apply(_remove_weight_norm)
301
+
302
+
303
+ class ResidualParallelWaveGANDiscriminator(torch.nn.Module):
304
+ """Parallel WaveGAN Discriminator module."""
305
+
306
+ def __init__(self,
307
+ in_channels=1,
308
+ out_channels=1,
309
+ kernel_size=3,
310
+ layers=30,
311
+ stacks=3,
312
+ residual_channels=64,
313
+ gate_channels=128,
314
+ skip_channels=64,
315
+ dropout=0.0,
316
+ bias=True,
317
+ use_weight_norm=True,
318
+ use_causal_conv=False,
319
+ nonlinear_activation="LeakyReLU",
320
+ nonlinear_activation_params={"negative_slope": 0.2},
321
+ ):
322
+ """Initialize Parallel WaveGAN Discriminator module.
323
+
324
+ Args:
325
+ in_channels (int): Number of input channels.
326
+ out_channels (int): Number of output channels.
327
+ kernel_size (int): Kernel size of dilated convolution.
328
+ layers (int): Number of residual block layers.
329
+ stacks (int): Number of stacks i.e., dilation cycles.
330
+ residual_channels (int): Number of channels in residual conv.
331
+ gate_channels (int): Number of channels in gated conv.
332
+ skip_channels (int): Number of channels in skip conv.
333
+ dropout (float): Dropout rate. 0.0 means no dropout applied.
334
+ bias (bool): Whether to use bias parameter in conv.
335
+ use_weight_norm (bool): Whether to use weight norm.
336
+ If set to true, it will be applied to all of the conv layers.
337
+ use_causal_conv (bool): Whether to use causal structure.
338
+ nonlinear_activation_params (dict): Nonlinear function parameters
339
+
340
+ """
341
+ super(ResidualParallelWaveGANDiscriminator, self).__init__()
342
+ assert (kernel_size - 1) % 2 == 0, "Even kernel sizes are not supported."
343
+
344
+ self.in_channels = in_channels
345
+ self.out_channels = out_channels
346
+ self.layers = layers
347
+ self.stacks = stacks
348
+ self.kernel_size = kernel_size
349
+
350
+ # check the number of layers and stacks
351
+ assert layers % stacks == 0
352
+ layers_per_stack = layers // stacks
353
+
354
+ # define first convolution
355
+ self.first_conv = torch.nn.Sequential(
356
+ Conv1d1x1(in_channels, residual_channels, bias=True),
357
+ getattr(torch.nn, nonlinear_activation)(
358
+ inplace=True, **nonlinear_activation_params),
359
+ )
360
+
361
+ # define residual blocks
362
+ self.conv_layers = torch.nn.ModuleList()
363
+ for layer in range(layers):
364
+ dilation = 2 ** (layer % layers_per_stack)
365
+ conv = ResidualBlock(
366
+ kernel_size=kernel_size,
367
+ residual_channels=residual_channels,
368
+ gate_channels=gate_channels,
369
+ skip_channels=skip_channels,
370
+ aux_channels=-1,
371
+ dilation=dilation,
372
+ dropout=dropout,
373
+ bias=bias,
374
+ use_causal_conv=use_causal_conv,
375
+ )
376
+ self.conv_layers += [conv]
377
+
378
+ # define output layers
379
+ self.last_conv_layers = torch.nn.ModuleList([
380
+ getattr(torch.nn, nonlinear_activation)(
381
+ inplace=True, **nonlinear_activation_params),
382
+ Conv1d1x1(skip_channels, skip_channels, bias=True),
383
+ getattr(torch.nn, nonlinear_activation)(
384
+ inplace=True, **nonlinear_activation_params),
385
+ Conv1d1x1(skip_channels, out_channels, bias=True),
386
+ ])
387
+
388
+ # apply weight norm
389
+ if use_weight_norm:
390
+ self.apply_weight_norm()
391
+
392
+ def forward(self, x):
393
+ """Calculate forward propagation.
394
+
395
+ Args:
396
+ x (Tensor): Input noise signal (B, 1, T).
397
+
398
+ Returns:
399
+ Tensor: Output tensor (B, 1, T)
400
+
401
+ """
402
+ x = self.first_conv(x)
403
+
404
+ skips = 0
405
+ for f in self.conv_layers:
406
+ x, h = f(x, None)
407
+ skips += h
408
+ skips *= math.sqrt(1.0 / len(self.conv_layers))
409
+
410
+ # apply final layers
411
+ x = skips
412
+ for f in self.last_conv_layers:
413
+ x = f(x)
414
+ return x
415
+
416
+ def apply_weight_norm(self):
417
+ """Apply weight normalization module from all of the layers."""
418
+ def _apply_weight_norm(m):
419
+ if isinstance(m, torch.nn.Conv1d) or isinstance(m, torch.nn.Conv2d):
420
+ torch.nn.utils.weight_norm(m)
421
+ logging.debug(f"Weight norm is applied to {m}.")
422
+
423
+ self.apply(_apply_weight_norm)
424
+
425
+ def remove_weight_norm(self):
426
+ """Remove weight normalization module from all of the layers."""
427
+ def _remove_weight_norm(m):
428
+ try:
429
+ logging.debug(f"Weight norm is removed from {m}.")
430
+ torch.nn.utils.remove_weight_norm(m)
431
+ except ValueError: # this module didn't have weight norm
432
+ return
433
+
434
+ self.apply(_remove_weight_norm)
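The two sides of this file are content-identical, so the change is most likely whitespace or line-ending normalization. For orientation, a hypothetical smoke test of the generator's interface; it assumes the reference `ConvInUpsampleNetwork` behavior (an unpadded context conv followed by 4x4x4x4 upsampling, i.e. a hop of 256), so the exact length bookkeeping may differ in this repo:

```python
import torch
from modules.parallel_wavegan.models.parallel_wavegan import ParallelWaveGANGenerator

gen = ParallelWaveGANGenerator()        # defaults: aux_channels=80, aux_context_window=2

B, frames, hop, ctx = 2, 50, 256, 2     # hop = prod([4, 4, 4, 4])
mels = torch.randn(B, 80, frames)       # local conditioning c, [B, C, T']
# The context conv consumes ctx frames on each side before upsampling, so the
# noise input has to match the trimmed, upsampled conditioning length.
noise = torch.randn(B, 1, (frames - 2 * ctx) * hop)

with torch.no_grad():
    wav = gen(noise, c=mels)            # [B, 1, (frames - 4) * 256]
print(wav.shape, gen.receptive_field_size)
```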
usr/configs/midi/cascade/opencs/ds60_rel.yaml CHANGED
@@ -24,10 +24,11 @@ fs2_ckpt: 'checkpoints/0302_opencpop_fs_midi/model_ckpt_steps_160000.ckpt' #
24
  task_cls: usr.diffsinger_task.DiffSingerMIDITask
25
 
26
  K_step: 60
27
- max_tokens: 40000
28
  predictor_layers: 5
29
  dilation_cycle_length: 4 # *
30
  rel_pos: true
31
  dur_predictor_layers: 5 # *
32
  max_updates: 160000
33
  gaussian_start: false
 
 
24
  task_cls: usr.diffsinger_task.DiffSingerMIDITask
25
 
26
  K_step: 60
27
+ max_tokens: 36000
28
  predictor_layers: 5
29
  dilation_cycle_length: 4 # *
30
  rel_pos: true
31
  dur_predictor_layers: 5 # *
32
  max_updates: 160000
33
  gaussian_start: false
34
+ mask_uv_prob: 0.15
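Two training knobs change here: `max_tokens` (the per-batch frame budget) drops from 40000 to 36000, and `mask_uv_prob: 0.15` is introduced. The name suggests randomly dropping the unvoiced (UV) flag during training so the model does not over-rely on it; the consuming code is not part of this diff, so the helper below is only a plausible sketch of how such a probability is typically applied, not the repo's actual logic:

```python
import torch

def maybe_mask_uv(uv: torch.Tensor, mask_uv_prob: float, training: bool) -> torch.Tensor:
    """Hypothetical helper: zero the UV flag for a random subset of samples.

    uv: [B, T] binary unvoiced mask; whole rows are dropped with prob mask_uv_prob.
    """
    if training and mask_uv_prob > 0:
        keep = (torch.rand(uv.size(0), 1, device=uv.device) >= mask_uv_prob).float()
        uv = uv * keep
    return uv

uv = torch.randint(0, 2, (4, 100)).float()
print(maybe_mask_uv(uv, mask_uv_prob=0.15, training=True).shape)  # torch.Size([4, 100])
```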
usr/diff/shallow_diffusion_tts.py CHANGED
@@ -1,273 +1,324 @@
1
- import math
2
- import random
3
- from functools import partial
4
- from inspect import isfunction
5
- from pathlib import Path
6
- import numpy as np
7
- import torch
8
- import torch.nn.functional as F
9
- from torch import nn
10
- from tqdm import tqdm
11
- from einops import rearrange
12
-
13
- from modules.fastspeech.fs2 import FastSpeech2
14
- from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
15
- from utils.hparams import hparams
16
-
17
-
18
-
19
- def exists(x):
20
- return x is not None
21
-
22
-
23
- def default(val, d):
24
- if exists(val):
25
- return val
26
- return d() if isfunction(d) else d
27
-
28
-
29
- # gaussian diffusion trainer class
30
-
31
- def extract(a, t, x_shape):
32
- b, *_ = t.shape
33
- out = a.gather(-1, t)
34
- return out.reshape(b, *((1,) * (len(x_shape) - 1)))
35
-
36
-
37
- def noise_like(shape, device, repeat=False):
38
- repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
39
- noise = lambda: torch.randn(shape, device=device)
40
- return repeat_noise() if repeat else noise()
41
-
42
-
43
- def linear_beta_schedule(timesteps, max_beta=hparams.get('max_beta', 0.01)):
44
- """
45
- linear schedule
46
- """
47
- betas = np.linspace(1e-4, max_beta, timesteps)
48
- return betas
49
-
50
-
51
- def cosine_beta_schedule(timesteps, s=0.008):
52
- """
53
- cosine schedule
54
- as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
55
- """
56
- steps = timesteps + 1
57
- x = np.linspace(0, steps, steps)
58
- alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
59
- alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
60
- betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
61
- return np.clip(betas, a_min=0, a_max=0.999)
62
-
63
-
64
- beta_schedule = {
65
- "cosine": cosine_beta_schedule,
66
- "linear": linear_beta_schedule,
67
- }
68
-
69
-
70
- class GaussianDiffusion(nn.Module):
71
- def __init__(self, phone_encoder, out_dims, denoise_fn,
72
- timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None, spec_min=None, spec_max=None):
73
- super().__init__()
74
- self.denoise_fn = denoise_fn
75
- if hparams.get('use_midi') is not None and hparams['use_midi']:
76
- self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims)
77
- else:
78
- self.fs2 = FastSpeech2(phone_encoder, out_dims)
79
- self.mel_bins = out_dims
80
-
81
- if exists(betas):
82
- betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
83
- else:
84
- if 'schedule_type' in hparams.keys():
85
- betas = beta_schedule[hparams['schedule_type']](timesteps)
86
- else:
87
- betas = cosine_beta_schedule(timesteps)
88
-
89
- alphas = 1. - betas
90
- alphas_cumprod = np.cumprod(alphas, axis=0)
91
- alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
92
-
93
- timesteps, = betas.shape
94
- self.num_timesteps = int(timesteps)
95
- self.K_step = K_step
96
- self.loss_type = loss_type
97
-
98
- to_torch = partial(torch.tensor, dtype=torch.float32)
99
-
100
- self.register_buffer('betas', to_torch(betas))
101
- self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
102
- self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
103
-
104
- # calculations for diffusion q(x_t | x_{t-1}) and others
105
- self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
106
- self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
107
- self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
108
- self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
109
- self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
110
-
111
- # calculations for posterior q(x_{t-1} | x_t, x_0)
112
- posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
113
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
114
- self.register_buffer('posterior_variance', to_torch(posterior_variance))
115
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
116
- self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
117
- self.register_buffer('posterior_mean_coef1', to_torch(
118
- betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
119
- self.register_buffer('posterior_mean_coef2', to_torch(
120
- (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
121
-
122
- self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
123
- self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])
124
-
125
- def q_mean_variance(self, x_start, t):
126
- mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
127
- variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
128
- log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
129
- return mean, variance, log_variance
130
-
131
- def predict_start_from_noise(self, x_t, t, noise):
132
- return (
133
- extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
134
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
135
- )
136
-
137
- def q_posterior(self, x_start, x_t, t):
138
- posterior_mean = (
139
- extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
140
- extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
141
- )
142
- posterior_variance = extract(self.posterior_variance, t, x_t.shape)
143
- posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
144
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
145
-
146
- def p_mean_variance(self, x, t, cond, clip_denoised: bool):
147
- noise_pred = self.denoise_fn(x, t, cond=cond)
148
- x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
149
-
150
- if clip_denoised:
151
- x_recon.clamp_(-1., 1.)
152
-
153
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
154
- return model_mean, posterior_variance, posterior_log_variance
155
-
156
- @torch.no_grad()
157
- def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
158
- b, *_, device = *x.shape, x.device
159
- model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
160
- noise = noise_like(x.shape, device, repeat_noise)
161
- # no noise when t == 0
162
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
163
- return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
164
-
165
- def q_sample(self, x_start, t, noise=None):
166
- noise = default(noise, lambda: torch.randn_like(x_start))
167
- return (
168
- extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
169
- extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
170
- )
171
-
172
- def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
173
- noise = default(noise, lambda: torch.randn_like(x_start))
174
-
175
- x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
176
- x_recon = self.denoise_fn(x_noisy, t, cond)
177
-
178
- if self.loss_type == 'l1':
179
- if nonpadding is not None:
180
- loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
181
- else:
182
- # print('are you sure w/o nonpadding?')
183
- loss = (noise - x_recon).abs().mean()
184
-
185
- elif self.loss_type == 'l2':
186
- loss = F.mse_loss(noise, x_recon)
187
- else:
188
- raise NotImplementedError()
189
-
190
- return loss
191
-
192
- def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
193
- ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
194
- b, *_, device = *txt_tokens.shape, txt_tokens.device
195
- ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
196
- skip_decoder=(not infer), infer=infer, **kwargs)
197
- cond = ret['decoder_inp'].transpose(1, 2)
198
-
199
- if not infer:
200
- t = torch.randint(0, self.K_step, (b,), device=device).long()
201
- x = ref_mels
202
- x = self.norm_spec(x)
203
- x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
204
- ret['diff_loss'] = self.p_losses(x, t, cond)
205
- # nonpadding = (mel2ph != 0).float()
206
- # ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding)
207
- else:
208
- ret['fs2_mel'] = ret['mel_out']
209
- fs2_mels = ret['mel_out']
210
- t = self.K_step
211
- fs2_mels = self.norm_spec(fs2_mels)
212
- fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
213
-
214
- x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
215
- if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
216
- print('===> gaussion start.')
217
- shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
218
- x = torch.randn(shape, device=device)
219
- for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
220
- x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
221
- x = x[:, 0].transpose(1, 2)
222
- if mel2ph is not None: # for singing
223
- ret['mel_out'] = self.denorm_spec(x) * ((mel2ph > 0).float()[:, :, None])
224
- else:
225
- ret['mel_out'] = self.denorm_spec(x)
226
- return ret
227
-
228
- def norm_spec(self, x):
229
- return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
230
-
231
- def denorm_spec(self, x):
232
- return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
233
-
234
- def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
235
- return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
236
-
237
- def out2mel(self, x):
238
- return x
239
-
240
-
241
- class OfflineGaussianDiffusion(GaussianDiffusion):
242
- def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
243
- ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
244
- b, *_, device = *txt_tokens.shape, txt_tokens.device
245
-
246
- ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
247
- skip_decoder=True, infer=True, **kwargs)
248
- cond = ret['decoder_inp'].transpose(1, 2)
249
- fs2_mels = ref_mels[1]
250
- ref_mels = ref_mels[0]
251
-
252
- if not infer:
253
- t = torch.randint(0, self.K_step, (b,), device=device).long()
254
- x = ref_mels
255
- x = self.norm_spec(x)
256
- x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
257
- ret['diff_loss'] = self.p_losses(x, t, cond)
258
- else:
259
- t = self.K_step
260
- fs2_mels = self.norm_spec(fs2_mels)
261
- fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
262
-
263
- x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
264
-
265
- if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
266
- print('===> gaussion start.')
267
- shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
268
- x = torch.randn(shape, device=device)
269
- for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
270
- x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
271
- x = x[:, 0].transpose(1, 2)
272
- ret['mel_out'] = self.denorm_spec(x)
273
- return ret
1
+ import math
2
+ import random
3
+ from collections import deque
4
+ from functools import partial
5
+ from inspect import isfunction
6
+ from pathlib import Path
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from torch import nn
11
+ from tqdm import tqdm
12
+ from einops import rearrange
13
+
14
+ from modules.fastspeech.fs2 import FastSpeech2
15
+ from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
16
+ from utils.hparams import hparams
17
+
18
+
19
+
20
+ def exists(x):
21
+ return x is not None
22
+
23
+
24
+ def default(val, d):
25
+ if exists(val):
26
+ return val
27
+ return d() if isfunction(d) else d
28
+
29
+
30
+ # gaussian diffusion trainer class
31
+
32
+ def extract(a, t, x_shape):
33
+ b, *_ = t.shape
34
+ out = a.gather(-1, t)
35
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
36
+
37
+
38
+ def noise_like(shape, device, repeat=False):
39
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
40
+ noise = lambda: torch.randn(shape, device=device)
41
+ return repeat_noise() if repeat else noise()
42
+
43
+
44
+ def linear_beta_schedule(timesteps, max_beta=hparams.get('max_beta', 0.01)):
45
+ """
46
+ linear schedule
47
+ """
48
+ betas = np.linspace(1e-4, max_beta, timesteps)
49
+ return betas
50
+
51
+
52
+ def cosine_beta_schedule(timesteps, s=0.008):
53
+ """
54
+ cosine schedule
55
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
56
+ """
57
+ steps = timesteps + 1
58
+ x = np.linspace(0, steps, steps)
59
+ alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
60
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
61
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
62
+ return np.clip(betas, a_min=0, a_max=0.999)
63
+
64
+
65
+ beta_schedule = {
66
+ "cosine": cosine_beta_schedule,
67
+ "linear": linear_beta_schedule,
68
+ }
69
+
70
+
71
+ class GaussianDiffusion(nn.Module):
72
+ def __init__(self, phone_encoder, out_dims, denoise_fn,
73
+ timesteps=1000, K_step=1000, loss_type=hparams.get('diff_loss_type', 'l1'), betas=None, spec_min=None, spec_max=None):
74
+ super().__init__()
75
+ self.denoise_fn = denoise_fn
76
+ if hparams.get('use_midi') is not None and hparams['use_midi']:
77
+ self.fs2 = FastSpeech2MIDI(phone_encoder, out_dims)
78
+ else:
79
+ self.fs2 = FastSpeech2(phone_encoder, out_dims)
80
+ self.mel_bins = out_dims
81
+
82
+ if exists(betas):
83
+ betas = betas.detach().cpu().numpy() if isinstance(betas, torch.Tensor) else betas
84
+ else:
85
+ if 'schedule_type' in hparams.keys():
86
+ betas = beta_schedule[hparams['schedule_type']](timesteps)
87
+ else:
88
+ betas = cosine_beta_schedule(timesteps)
89
+
90
+ alphas = 1. - betas
91
+ alphas_cumprod = np.cumprod(alphas, axis=0)
92
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
93
+
94
+ timesteps, = betas.shape
95
+ self.num_timesteps = int(timesteps)
96
+ self.K_step = K_step
97
+ self.loss_type = loss_type
98
+
99
+ self.noise_list = deque(maxlen=4)
100
+
101
+ to_torch = partial(torch.tensor, dtype=torch.float32)
102
+
103
+ self.register_buffer('betas', to_torch(betas))
104
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
105
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
106
+
107
+ # calculations for diffusion q(x_t | x_{t-1}) and others
108
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
109
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
110
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
111
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
112
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
113
+
114
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
115
+ posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
116
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
117
+ self.register_buffer('posterior_variance', to_torch(posterior_variance))
118
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
119
+ self.register_buffer('posterior_log_variance_clipped', to_torch(np.log(np.maximum(posterior_variance, 1e-20))))
120
+ self.register_buffer('posterior_mean_coef1', to_torch(
121
+ betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
122
+ self.register_buffer('posterior_mean_coef2', to_torch(
123
+ (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
124
+
125
+ self.register_buffer('spec_min', torch.FloatTensor(spec_min)[None, None, :hparams['keep_bins']])
126
+ self.register_buffer('spec_max', torch.FloatTensor(spec_max)[None, None, :hparams['keep_bins']])
127
+
128
+ def q_mean_variance(self, x_start, t):
129
+ mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
130
+ variance = extract(1. - self.alphas_cumprod, t, x_start.shape)
131
+ log_variance = extract(self.log_one_minus_alphas_cumprod, t, x_start.shape)
132
+ return mean, variance, log_variance
133
+
134
+ def predict_start_from_noise(self, x_t, t, noise):
135
+ return (
136
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
137
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
138
+ )
139
+
140
+ def q_posterior(self, x_start, x_t, t):
141
+ posterior_mean = (
142
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
143
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
144
+ )
145
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
146
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
147
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
148
+
149
+ def p_mean_variance(self, x, t, cond, clip_denoised: bool):
150
+ noise_pred = self.denoise_fn(x, t, cond=cond)
151
+ x_recon = self.predict_start_from_noise(x, t=t, noise=noise_pred)
152
+
153
+ if clip_denoised:
154
+ x_recon.clamp_(-1., 1.)
155
+
156
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
157
+ return model_mean, posterior_variance, posterior_log_variance
158
+
159
+ @torch.no_grad()
160
+ def p_sample(self, x, t, cond, clip_denoised=True, repeat_noise=False):
161
+ b, *_, device = *x.shape, x.device
162
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, cond=cond, clip_denoised=clip_denoised)
163
+ noise = noise_like(x.shape, device, repeat_noise)
164
+ # no noise when t == 0
165
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
166
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
167
+
168
+ @torch.no_grad()
169
+ def p_sample_plms(self, x, t, interval, cond, clip_denoised=True, repeat_noise=False):
170
+ """
171
+ Use the PLMS method from [Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778).
172
+ """
173
+
174
+ def get_x_pred(x, noise_t, t):
175
+ a_t = extract(self.alphas_cumprod, t, x.shape)
176
+ if t[0] < interval:
177
+ a_prev = torch.ones_like(a_t)
178
+ else:
179
+ a_prev = extract(self.alphas_cumprod, torch.max(t-interval, torch.zeros_like(t)), x.shape)
180
+ a_t_sq, a_prev_sq = a_t.sqrt(), a_prev.sqrt()
181
+
182
+ x_delta = (a_prev - a_t) * ((1 / (a_t_sq * (a_t_sq + a_prev_sq))) * x - 1 / (a_t_sq * (((1 - a_prev) * a_t).sqrt() + ((1 - a_t) * a_prev).sqrt())) * noise_t)
183
+ x_pred = x + x_delta
184
+
185
+ return x_pred
186
+
187
+ noise_list = self.noise_list
188
+ noise_pred = self.denoise_fn(x, t, cond=cond)
189
+
190
+ if len(noise_list) == 0:
191
+ x_pred = get_x_pred(x, noise_pred, t)
192
+ noise_pred_prev = self.denoise_fn(x_pred, max(t-interval, 0), cond=cond)
193
+ noise_pred_prime = (noise_pred + noise_pred_prev) / 2
194
+ elif len(noise_list) == 1:
195
+ noise_pred_prime = (3 * noise_pred - noise_list[-1]) / 2
196
+ elif len(noise_list) == 2:
197
+ noise_pred_prime = (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
198
+ elif len(noise_list) >= 3:
199
+ noise_pred_prime = (55 * noise_pred - 59 * noise_list[-1] + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24
200
+
201
+ x_prev = get_x_pred(x, noise_pred_prime, t)
202
+ noise_list.append(noise_pred)
203
+
204
+ return x_prev
205
+
206
+ def q_sample(self, x_start, t, noise=None):
207
+ noise = default(noise, lambda: torch.randn_like(x_start))
208
+ return (
209
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
210
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
211
+ )
212
+
213
+ def p_losses(self, x_start, t, cond, noise=None, nonpadding=None):
214
+ noise = default(noise, lambda: torch.randn_like(x_start))
215
+
216
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
217
+ x_recon = self.denoise_fn(x_noisy, t, cond)
218
+
219
+ if self.loss_type == 'l1':
220
+ if nonpadding is not None:
221
+ loss = ((noise - x_recon).abs() * nonpadding.unsqueeze(1)).mean()
222
+ else:
223
+ # print('are you sure w/o nonpadding?')
224
+ loss = (noise - x_recon).abs().mean()
225
+
226
+ elif self.loss_type == 'l2':
227
+ loss = F.mse_loss(noise, x_recon)
228
+ else:
229
+ raise NotImplementedError()
230
+
231
+ return loss
232
+
233
+ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
234
+ ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
235
+ b, *_, device = *txt_tokens.shape, txt_tokens.device
236
+ ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
237
+ skip_decoder=(not infer), infer=infer, **kwargs)
238
+ cond = ret['decoder_inp'].transpose(1, 2)
239
+
240
+ if not infer:
241
+ t = torch.randint(0, self.K_step, (b,), device=device).long()
242
+ x = ref_mels
243
+ x = self.norm_spec(x)
244
+ x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
245
+ ret['diff_loss'] = self.p_losses(x, t, cond)
246
+ # nonpadding = (mel2ph != 0).float()
247
+ # ret['diff_loss'] = self.p_losses(x, t, cond, nonpadding=nonpadding)
248
+ else:
249
+ ret['fs2_mel'] = ret['mel_out']
250
+ fs2_mels = ret['mel_out']
251
+ t = self.K_step
252
+ fs2_mels = self.norm_spec(fs2_mels)
253
+ fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
254
+
255
+ x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
256
+ if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
257
+ print('===> gaussian start.')
258
+ shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
259
+ x = torch.randn(shape, device=device)
260
+
261
+ if hparams.get('pndm_speedup'):
262
+ print('===> pndm speed:', hparams['pndm_speedup'])
263
+ self.noise_list = deque(maxlen=4)
264
+ iteration_interval = hparams['pndm_speedup']
265
+ for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step',
266
+ total=t // iteration_interval):
267
+ x = self.p_sample_plms(x, torch.full((b,), i, device=device, dtype=torch.long), iteration_interval,
268
+ cond)
269
+ else:
270
+ for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
271
+ x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
272
+ x = x[:, 0].transpose(1, 2)
273
+ if mel2ph is not None: # for singing
274
+ ret['mel_out'] = self.denorm_spec(x) * ((mel2ph > 0).float()[:, :, None])
275
+ else:
276
+ ret['mel_out'] = self.denorm_spec(x)
277
+ return ret
278
+
279
+ def norm_spec(self, x):
280
+ return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
281
+
282
+ def denorm_spec(self, x):
283
+ return (x + 1) / 2 * (self.spec_max - self.spec_min) + self.spec_min
284
+
285
+ def cwt2f0_norm(self, cwt_spec, mean, std, mel2ph):
286
+ return self.fs2.cwt2f0_norm(cwt_spec, mean, std, mel2ph)
287
+
288
+ def out2mel(self, x):
289
+ return x
290
+
291
+
292
+ class OfflineGaussianDiffusion(GaussianDiffusion):
293
+ def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
294
+ ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
295
+ b, *_, device = *txt_tokens.shape, txt_tokens.device
296
+
297
+ ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
298
+ skip_decoder=True, infer=True, **kwargs)
299
+ cond = ret['decoder_inp'].transpose(1, 2)
300
+ fs2_mels = ref_mels[1]
301
+ ref_mels = ref_mels[0]
302
+
303
+ if not infer:
304
+ t = torch.randint(0, self.K_step, (b,), device=device).long()
305
+ x = ref_mels
306
+ x = self.norm_spec(x)
307
+ x = x.transpose(1, 2)[:, None, :, :] # [B, 1, M, T]
308
+ ret['diff_loss'] = self.p_losses(x, t, cond)
309
+ else:
310
+ t = self.K_step
311
+ fs2_mels = self.norm_spec(fs2_mels)
312
+ fs2_mels = fs2_mels.transpose(1, 2)[:, None, :, :]
313
+
314
+ x = self.q_sample(x_start=fs2_mels, t=torch.tensor([t - 1], device=device).long())
315
+
316
+ if hparams.get('gaussian_start') is not None and hparams['gaussian_start']:
317
+ print('===> gaussian start.')
318
+ shape = (cond.shape[0], 1, self.mel_bins, cond.shape[2])
319
+ x = torch.randn(shape, device=device)
320
+ for i in tqdm(reversed(range(0, t)), desc='sample time step', total=t):
321
+ x = self.p_sample(x, torch.full((b,), i, device=device, dtype=torch.long), cond)
322
+ x = x[:, 0].transpose(1, 2)
323
+ ret['mel_out'] = self.denorm_spec(x)
324
+ return ret
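The substantive addition in this file is PLMS sampling from PNDM ([Pseudo Numerical Methods for Diffusion Models on Manifolds](https://arxiv.org/abs/2202.09778)): instead of one denoiser call per diffusion step, the sampler strides `pndm_speedup` steps at a time and combines the current noise prediction with up to three cached ones using linear multistep (Adams-Bashforth-style) coefficients. With `K_step: 60` and, say, `pndm_speedup: 10`, inference drops from 60 denoiser calls to 6 strided ones plus one warm-up evaluation. A toy sketch of just the coefficient logic mirrored from `p_sample_plms` above:

```python
from collections import deque
import torch

def plms_combine(noise_pred, noise_list):
    """Linear multistep combination used by p_sample_plms (orders 2-4).

    The empty-cache case is omitted: in the diff above it runs the denoiser a
    second time at the predicted point and averages (a Heun-style warm-up).
    """
    if len(noise_list) == 1:
        return (3 * noise_pred - noise_list[-1]) / 2
    if len(noise_list) == 2:
        return (23 * noise_pred - 16 * noise_list[-1] + 5 * noise_list[-2]) / 12
    return (55 * noise_pred - 59 * noise_list[-1]
            + 37 * noise_list[-2] - 9 * noise_list[-3]) / 24

# For a linearly drifting noise estimate, the 4th-order rule returns the
# average of the trend over the next step:
history = deque([torch.full((1,), v) for v in (0.9, 1.0, 1.1)], maxlen=4)
print(plms_combine(torch.full((1,), 1.2), history))  # tensor([1.2500])
```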
utils/hparams.py CHANGED
@@ -21,35 +21,30 @@ def override_config(old_config: dict, new_config: dict):
21
 
22
 
23
  def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
24
- if config == '' and exp_name == '':
25
- parser = argparse.ArgumentParser(description='')
26
  parser.add_argument('--config', type=str, default='',
27
help='path to the config file')
28
  parser.add_argument('--exp_name', type=str, default='', help='exp_name')
29
- parser.add_argument('-hp', '--hparams', type=str, default='',
30
help='hparams overrides, e.g. "k1=v1,k2=v2"')
31
  parser.add_argument('--infer', action='store_true', help='infer')
32
  parser.add_argument('--validate', action='store_true', help='validate')
33
  parser.add_argument('--reset', action='store_true', help='reset hparams')
34
- parser.add_argument('--remove', action='store_true', help='remove old ckpt')
35
  parser.add_argument('--debug', action='store_true', help='debug')
36
  args, unknown = parser.parse_known_args()
37
- print("| Unknow hparams: ", unknown)
38
  else:
39
  args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
40
- infer=False, validate=False, reset=False, debug=False, remove=False)
41
- global hparams
42
- assert args.config != '' or args.exp_name != ''
43
- if args.config != '':
44
- assert os.path.exists(args.config)
45
 
46
  config_chains = []
47
  loaded_config = set()
48
 
49
- def load_config(config_fn):
50
- # depth-first inheritance, avoiding a second visit of any node
51
- if not os.path.exists(config_fn):
52
- return {}
53
  with open(config_fn) as f:
54
  hparams_ = yaml.safe_load(f)
55
  loaded_config.add(config_fn)
@@ -58,10 +53,10 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
58
  if not isinstance(hparams_['base_config'], list):
59
  hparams_['base_config'] = [hparams_['base_config']]
60
  for c in hparams_['base_config']:
61
- if c.startswith('.'):
62
- c = f'{os.path.dirname(config_fn)}/{c}'
63
- c = os.path.normpath(c)
64
  if c not in loaded_config:
65
  override_config(ret_hparams, load_config(c))
66
  override_config(ret_hparams, hparams_)
67
  else:
@@ -69,43 +64,36 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
69
  config_chains.append(config_fn)
70
  return ret_hparams
71
 
 
 
72
  saved_hparams = {}
73
- args_work_dir = ''
74
- if args.exp_name != '':
75
- args_work_dir = f'checkpoints/{args.exp_name}'
76
  ckpt_config_path = f'{args_work_dir}/config.yaml'
77
  if os.path.exists(ckpt_config_path):
78
- with open(ckpt_config_path) as f:
79
- saved_hparams_ = yaml.safe_load(f)
80
- if saved_hparams_ is not None:
81
- saved_hparams.update(saved_hparams_)
82
  hparams_ = {}
83
- if args.config != '':
84
- hparams_.update(load_config(args.config))
 
85
  if not args.reset:
86
  hparams_.update(saved_hparams)
87
  hparams_['work_dir'] = args_work_dir
88
 
89
- # Support config overriding in command line. Support list type config overriding.
90
- # Examples: --hparams="a=1,b.c=2,d=[1 1 1]"
91
  if args.hparams != "":
92
  for new_hparam in args.hparams.split(","):
93
  k, v = new_hparam.split("=")
94
- v = v.strip("\'\" ")
95
- config_node = hparams_
96
- for k_ in k.split(".")[:-1]:
97
- config_node = config_node[k_]
98
- k = k.split(".")[-1]
99
- if v in ['True', 'False'] or type(config_node[k]) in [bool, list, dict]:
100
- if type(config_node[k]) == list:
101
- v = v.replace(" ", ",")
102
- config_node[k] = eval(v)
103
  else:
104
- config_node[k] = type(config_node[k])(v)
105
- if args_work_dir != '' and args.remove:
106
- answer = input("REMOVE old checkpoint? Y/N [Default: N]: ")
107
- if answer.lower() == "y":
108
- remove_file(args_work_dir)
109
  if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
110
  os.makedirs(hparams_['work_dir'], exist_ok=True)
111
  with open(ckpt_config_path, 'w') as f:
@@ -114,11 +102,11 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
114
  hparams_['infer'] = args.infer
115
  hparams_['debug'] = args.debug
116
  hparams_['validate'] = args.validate
117
- hparams_['exp_name'] = args.exp_name
118
  global global_print_hparams
119
  if global_hparams:
120
  hparams.clear()
121
  hparams.update(hparams_)
 
122
  if print_hparams and global_print_hparams and global_hparams:
123
  print('| Hparams chains: ', config_chains)
124
  print('| Hparams: ')
@@ -126,5 +114,9 @@ def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, glob
126
  print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
127
  print("")
128
  global_print_hparams = False
129
  return hparams_
130
-
 
21
 
22
 
23
  def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
24
+ if config == '':
25
+ parser = argparse.ArgumentParser(description='neural music')
26
  parser.add_argument('--config', type=str, default='',
27
help='path to the config file')
28
  parser.add_argument('--exp_name', type=str, default='', help='exp_name')
29
+ parser.add_argument('--hparams', type=str, default='',
30
help='hparams overrides, e.g. "k1=v1,k2=v2"')
31
  parser.add_argument('--infer', action='store_true', help='infer')
32
  parser.add_argument('--validate', action='store_true', help='validate')
33
  parser.add_argument('--reset', action='store_true', help='reset hparams')
 
34
  parser.add_argument('--debug', action='store_true', help='debug')
35
  args, unknown = parser.parse_known_args()
 
36
  else:
37
  args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
38
+ infer=False, validate=False, reset=False, debug=False)
39
+ args_work_dir = ''
40
+ if args.exp_name != '':
41
+ args.work_dir = args.exp_name
42
+ args_work_dir = f'checkpoints/{args.work_dir}'
43
 
44
  config_chains = []
45
  loaded_config = set()
46
 
47
+ def load_config(config_fn): # depth-first
48
  with open(config_fn) as f:
49
  hparams_ = yaml.safe_load(f)
50
  loaded_config.add(config_fn)
 
53
  if not isinstance(hparams_['base_config'], list):
54
  hparams_['base_config'] = [hparams_['base_config']]
55
  for c in hparams_['base_config']:
56
  if c not in loaded_config:
57
+ if c.startswith('.'):
58
+ c = f'{os.path.dirname(config_fn)}/{c}'
59
+ c = os.path.normpath(c)
60
  override_config(ret_hparams, load_config(c))
61
  override_config(ret_hparams, hparams_)
62
  else:
 
64
  config_chains.append(config_fn)
65
  return ret_hparams
66
 
67
+ global hparams
68
+ assert args.config != '' or args_work_dir != ''
69
  saved_hparams = {}
70
+ if args_work_dir != 'checkpoints/':
 
 
71
  ckpt_config_path = f'{args_work_dir}/config.yaml'
72
  if os.path.exists(ckpt_config_path):
73
+ try:
74
+ with open(ckpt_config_path) as f:
75
+ saved_hparams.update(yaml.safe_load(f))
76
+ except:
77
+ pass
78
+ if args.config == '':
79
+ args.config = ckpt_config_path
80
+
81
  hparams_ = {}
82
+
83
+ hparams_.update(load_config(args.config))
84
+
85
  if not args.reset:
86
  hparams_.update(saved_hparams)
87
  hparams_['work_dir'] = args_work_dir
88
 
 
 
89
  if args.hparams != "":
90
  for new_hparam in args.hparams.split(","):
91
  k, v = new_hparam.split("=")
92
+ if v in ['True', 'False'] or type(hparams_[k]) == bool:
93
+ hparams_[k] = eval(v)
94
  else:
95
+ hparams_[k] = type(hparams_[k])(v)
96
+
97
  if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not args.infer:
98
  os.makedirs(hparams_['work_dir'], exist_ok=True)
99
  with open(ckpt_config_path, 'w') as f:
 
102
  hparams_['infer'] = args.infer
103
  hparams_['debug'] = args.debug
104
  hparams_['validate'] = args.validate
 
105
  global global_print_hparams
106
  if global_hparams:
107
  hparams.clear()
108
  hparams.update(hparams_)
109
+
110
  if print_hparams and global_print_hparams and global_hparams:
111
  print('| Hparams chains: ', config_chains)
112
  print('| Hparams: ')
 
114
  print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
115
  print("")
116
  global_print_hparams = False
117
+ # print(hparams_.keys())
118
+ if hparams.get('exp_name') is None:
119
+ hparams['exp_name'] = args.exp_name
120
+ if hparams_.get('exp_name') is None:
121
+ hparams_['exp_name'] = args.exp_name
122
  return hparams_
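Net effect of the `set_hparams` rewrite: the function falls back to argparse only when `config` is empty; when `--config` is omitted but an experiment exists, the config is picked up from `checkpoints/<exp_name>/config.yaml`; and `exp_name` is recorded only if the loaded config did not already carry one. Note also that the new flat override loop drops the old dotted-key support (`b.c=2` style), so `--hparams` values must now be top-level keys. A hypothetical programmatic call (the YAML path is the one from this commit; the experiment name is made up):

```python
from utils.hparams import set_hparams, hparams

# Programmatic use must pass a non-empty `config`: with config == '' the new
# code goes down the argparse branch and ignores the keyword arguments.
set_hparams(config='usr/configs/midi/cascade/opencs/ds60_rel.yaml',
            exp_name='my_ds_exp')  # hypothetical experiment name

# Side effect: since infer/reset default to False and no saved config exists,
# this writes checkpoints/my_ds_exp/config.yaml on first use.
print(hparams['work_dir'])   # checkpoints/my_ds_exp
print(hparams['K_step'])     # 60, from the YAML above
```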