Serhiy Stetskovych committed on
Commit e9fe7cc
1 Parent(s): 0792429

Remove unused files

config.yaml DELETED
@@ -1,33 +0,0 @@
-# pytorch_lightning==1.8.6
-
-feature_extractor:
-  class_path: vocos.feature_extractors.MelSpectrogramFeatures
-  init_args:
-    sample_rate: 22050
-    n_fft: 1024
-    hop_length: 256
-    n_mels: 80
-    padding: same
-    f_min: 0
-    f_max: 8000
-    norm: "slaney"
-    mel_scale: "slaney"
-
-
-backbone:
-  class_path: vocos.models.VocosBackbone
-  init_args:
-    input_channels: 80
-    dim: 512
-    intermediate_dim: 1536
-    num_layers: 8
-
-head:
-  class_path: vocos.heads.ISTFTHead
-  init_args:
-    dim: 512
-    n_fft: 1024
-    hop_length: 256
-    padding: same
-
-
 
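For context, the deleted `config.yaml` follows the `pytorch_lightning` `class_path`/`init_args` convention. The sketch below is illustrative only and is not part of this commit: a minimal, generic way such a config can be turned into the three Vocos modules it names. The `instantiate` helper is hypothetical; only the class paths and arguments come from the file above.

```python
# Illustrative sketch, not code from this repository.
import importlib

import yaml


def instantiate(spec):
    """Build an object from a {"class_path": ..., "init_args": ...} mapping."""
    module_name, class_name = spec["class_path"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(**spec.get("init_args", {}))


with open("config.yaml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

feature_extractor = instantiate(cfg["feature_extractor"])  # MelSpectrogramFeatures
backbone = instantiate(cfg["backbone"])                     # VocosBackbone
head = instantiate(cfg["head"])                             # ISTFTHead
```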
 
hifigan/LICENSE DELETED
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2020 Jungil Kong
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
 
hifigan/README.md DELETED
@@ -1,101 +0,0 @@
-# HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
-
-### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae
-
-In our [paper](https://arxiv.org/abs/2010.05646),
-we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.<br/>
-We provide our implementation and pretrained models as open source in this repository.
-
-**Abstract :**
-Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
-Although such methods improve the sampling efficiency and memory usage,
-their sample quality has not yet reached that of autoregressive and flow-based generative models.
-In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
-As speech audio consists of sinusoidal signals with various periods,
-we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality.
-A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
-demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
-real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
-speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
-faster than real-time on CPU with comparable quality to an autoregressive counterpart.
-
-Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.
-
-## Pre-requisites
-
-1. Python >= 3.6
-2. Clone this repository.
-3. Install python requirements. Please refer [requirements.txt](requirements.txt)
-4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
-And move all wav files to `LJSpeech-1.1/wavs`
-
-## Training
-
-```
-python train.py --config config_v1.json
-```
-
-To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.<br>
-Checkpoints and copy of the configuration file are saved in `cp_hifigan` directory by default.<br>
-You can change the path by adding `--checkpoint_path` option.
-
-Validation loss during training with V1 generator.<br>
-![validation loss](./validation_loss.png)
-
-## Pretrained Model
-
-You can also use pretrained models we provide.<br/>
-[Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)<br/>
-Details of each folder are as in follows:
-
-| Folder Name | Generator | Dataset | Fine-Tuned |
-| ------------ | --------- | --------- | ------------------------------------------------------ |
-| LJ_V1 | V1 | LJSpeech | No |
-| LJ_V2 | V2 | LJSpeech | No |
-| LJ_V3 | V3 | LJSpeech | No |
-| LJ_FT_T2_V1 | V1 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
-| LJ_FT_T2_V2 | V2 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
-| LJ_FT_T2_V3 | V3 | LJSpeech | Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2)) |
-| VCTK_V1 | V1 | VCTK | No |
-| VCTK_V2 | V2 | VCTK | No |
-| VCTK_V3 | V3 | VCTK | No |
-| UNIVERSAL_V1 | V1 | Universal | No |
-
-We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
-
-## Fine-Tuning
-
-1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.<br/>
-The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.<br/>
-Example:
-` Audio File : LJ001-0001.wav
-Mel-Spectrogram File : LJ001-0001.npy`
-2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.<br/>
-3. Run the following command.
-```
-python train.py --fine_tuning True --config config_v1.json
-```
-For other command line options, please refer to the training section.
-
-## Inference from wav file
-
-1. Make `test_files` directory and copy wav files into the directory.
-2. Run the following command.
-` python inference.py --checkpoint_file [generator checkpoint file path]`
-Generated wav files are saved in `generated_files` by default.<br>
-You can change the path by adding `--output_dir` option.
-
-## Inference for end-to-end speech synthesis
-
-1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.<br>
-You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
-[Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
-2. Run the following command.
-` python inference_e2e.py --checkpoint_file [generator checkpoint file path]`
-Generated wav files are saved in `generated_files_from_mel` by default.<br>
-You can change the path by adding `--output_dir` option.
-
-## Acknowledgements
-
-We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
-and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.
 
hifigan/__init__.py DELETED
File without changes
hifigan/config.py DELETED
@@ -1,28 +0,0 @@
-v1 = {
-    "resblock": "1",
-    "num_gpus": 0,
-    "batch_size": 16,
-    "learning_rate": 0.0004,
-    "adam_b1": 0.8,
-    "adam_b2": 0.99,
-    "lr_decay": 0.999,
-    "seed": 1234,
-    "upsample_rates": [8, 8, 2, 2],
-    "upsample_kernel_sizes": [16, 16, 4, 4],
-    "upsample_initial_channel": 512,
-    "resblock_kernel_sizes": [3, 7, 11],
-    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
-    "resblock_initial_channel": 256,
-    "segment_size": 8192,
-    "num_mels": 80,
-    "num_freq": 1025,
-    "n_fft": 1024,
-    "hop_size": 256,
-    "win_size": 1024,
-    "sampling_rate": 22050,
-    "fmin": 0,
-    "fmax": 8000,
-    "fmax_loss": None,
-    "num_workers": 4,
-    "dist_config": {"dist_backend": "nccl", "dist_url": "tcp://localhost:54321", "world_size": 1},
-}
 
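Illustrative only, not part of this commit: the `v1` dictionary above holds the HiFi-GAN V1 hyperparameters, and the `Generator` in `hifigan/models.py` reads them via attribute access, which `AttrDict` from `hifigan/env.py` provides. A minimal sketch, assuming the `hifigan` package is still importable:

```python
from hifigan.config import v1
from hifigan.env import AttrDict
from hifigan.models import Generator

h = AttrDict(v1)          # lets the dict keys be read as attributes (h.resblock, h.upsample_rates, ...)
generator = Generator(h)  # V1 generator: upsample_rates [8, 8, 2, 2], 512 initial channels
```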
 
hifigan/denoiser.py DELETED
@@ -1,64 +0,0 @@
-# Code modified from Rafael Valle's implementation https://github.com/NVIDIA/waveglow/blob/5bc2a53e20b3b533362f974cfa1ea0267ae1c2b1/denoiser.py
-
-"""Waveglow style denoiser can be used to remove the artifacts from the HiFiGAN generated audio."""
-import torch
-
-
-class Denoiser(torch.nn.Module):
-    """Removes model bias from audio produced with waveglow"""
-
-    def __init__(self, vocoder, filter_length=1024, n_overlap=4, win_length=1024, mode="zeros"):
-        super().__init__()
-        self.filter_length = filter_length
-        self.hop_length = int(filter_length / n_overlap)
-        self.win_length = win_length
-
-        dtype, device = next(vocoder.parameters()).dtype, next(vocoder.parameters()).device
-        self.device = device
-        if mode == "zeros":
-            mel_input = torch.zeros((1, 80, 88), dtype=dtype, device=device)
-        elif mode == "normal":
-            mel_input = torch.randn((1, 80, 88), dtype=dtype, device=device)
-        else:
-            raise Exception(f"Mode {mode} if not supported")
-
-        def stft_fn(audio, n_fft, hop_length, win_length, window):
-            spec = torch.stft(
-                audio,
-                n_fft=n_fft,
-                hop_length=hop_length,
-                win_length=win_length,
-                window=window,
-                return_complex=True,
-            )
-            spec = torch.view_as_real(spec)
-            return torch.sqrt(spec.pow(2).sum(-1)), torch.atan2(spec[..., -1], spec[..., 0])
-
-        self.stft = lambda x: stft_fn(
-            audio=x,
-            n_fft=self.filter_length,
-            hop_length=self.hop_length,
-            win_length=self.win_length,
-            window=torch.hann_window(self.win_length, device=device),
-        )
-        self.istft = lambda x, y: torch.istft(
-            torch.complex(x * torch.cos(y), x * torch.sin(y)),
-            n_fft=self.filter_length,
-            hop_length=self.hop_length,
-            win_length=self.win_length,
-            window=torch.hann_window(self.win_length, device=device),
-        )
-
-        with torch.no_grad():
-            bias_audio = vocoder(mel_input).float().squeeze(0)
-            bias_spec, _ = self.stft(bias_audio)
-
-        self.register_buffer("bias_spec", bias_spec[:, :, 0][:, :, None])
-
-    @torch.inference_mode()
-    def forward(self, audio, strength=0.0005):
-        audio_spec, audio_angles = self.stft(audio)
-        audio_spec_denoised = audio_spec - self.bias_spec.to(audio.device) * strength
-        audio_spec_denoised = torch.clamp(audio_spec_denoised, 0.0)
-        audio_denoised = self.istft(audio_spec_denoised, audio_angles)
-        return audio_denoised
 
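Illustrative only, not part of this commit: a minimal sketch of how the `Denoiser` above is typically applied to audio produced by the HiFi-GAN generator. The dummy mel input and the `hifigan` imports are assumptions.

```python
import torch

from hifigan.config import v1
from hifigan.denoiser import Denoiser
from hifigan.env import AttrDict
from hifigan.models import Generator

generator = Generator(AttrDict(v1)).eval()
denoiser = Denoiser(generator, mode="zeros")

mel = torch.randn(1, 80, 100)          # dummy mel-spectrogram batch
with torch.no_grad():
    wav = generator(mel).squeeze(1)    # (1, T) waveform
wav_denoised = denoiser(wav, strength=0.0005)
```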
 
hifigan/env.py DELETED
@@ -1,17 +0,0 @@
-""" from https://github.com/jik876/hifi-gan """
-
-import os
-import shutil
-
-
-class AttrDict(dict):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.__dict__ = self
-
-
-def build_env(config, config_name, path):
-    t_path = os.path.join(path, config_name)
-    if config != t_path:
-        os.makedirs(path, exist_ok=True)
-        shutil.copyfile(config, os.path.join(path, config_name))
 
hifigan/meldataset.py DELETED
@@ -1,217 +0,0 @@
-""" from https://github.com/jik876/hifi-gan """
-
-import math
-import os
-import random
-
-import numpy as np
-import torch
-import torch.utils.data
-from librosa.filters import mel as librosa_mel_fn
-from librosa.util import normalize
-from scipy.io.wavfile import read
-
-MAX_WAV_VALUE = 32768.0
-
-
-def load_wav(full_path):
-    sampling_rate, data = read(full_path)
-    return data, sampling_rate
-
-
-def dynamic_range_compression(x, C=1, clip_val=1e-5):
-    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
-
-
-def dynamic_range_decompression(x, C=1):
-    return np.exp(x) / C
-
-
-def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
-    return torch.log(torch.clamp(x, min=clip_val) * C)
-
-
-def dynamic_range_decompression_torch(x, C=1):
-    return torch.exp(x) / C
-
-
-def spectral_normalize_torch(magnitudes):
-    output = dynamic_range_compression_torch(magnitudes)
-    return output
-
-
-def spectral_de_normalize_torch(magnitudes):
-    output = dynamic_range_decompression_torch(magnitudes)
-    return output
-
-
-mel_basis = {}
-hann_window = {}
-
-
-def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
-    if torch.min(y) < -1.0:
-        print("min value is ", torch.min(y))
-    if torch.max(y) > 1.0:
-        print("max value is ", torch.max(y))
-
-    global mel_basis, hann_window  # pylint: disable=global-statement
-    if fmax not in mel_basis:
-        mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
-        mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
-        hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
-
-    y = torch.nn.functional.pad(
-        y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
-    )
-    y = y.squeeze(1)
-
-    spec = torch.view_as_real(
-        torch.stft(
-            y,
-            n_fft,
-            hop_length=hop_size,
-            win_length=win_size,
-            window=hann_window[str(y.device)],
-            center=center,
-            pad_mode="reflect",
-            normalized=False,
-            onesided=True,
-            return_complex=True,
-        )
-    )
-
-    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
-
-    spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
-    spec = spectral_normalize_torch(spec)
-
-    return spec
-
-
-def get_dataset_filelist(a):
-    with open(a.input_training_file, encoding="utf-8") as fi:
-        training_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
-        ]
-
-    with open(a.input_validation_file, encoding="utf-8") as fi:
-        validation_files = [
-            os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
-        ]
-    return training_files, validation_files
-
-
-class MelDataset(torch.utils.data.Dataset):
-    def __init__(
-        self,
-        training_files,
-        segment_size,
-        n_fft,
-        num_mels,
-        hop_size,
-        win_size,
-        sampling_rate,
-        fmin,
-        fmax,
-        split=True,
-        shuffle=True,
-        n_cache_reuse=1,
-        device=None,
-        fmax_loss=None,
-        fine_tuning=False,
-        base_mels_path=None,
-    ):
-        self.audio_files = training_files
-        random.seed(1234)
-        if shuffle:
-            random.shuffle(self.audio_files)
-        self.segment_size = segment_size
-        self.sampling_rate = sampling_rate
-        self.split = split
-        self.n_fft = n_fft
-        self.num_mels = num_mels
-        self.hop_size = hop_size
-        self.win_size = win_size
-        self.fmin = fmin
-        self.fmax = fmax
-        self.fmax_loss = fmax_loss
-        self.cached_wav = None
-        self.n_cache_reuse = n_cache_reuse
-        self._cache_ref_count = 0
-        self.device = device
-        self.fine_tuning = fine_tuning
-        self.base_mels_path = base_mels_path
-
-    def __getitem__(self, index):
-        filename = self.audio_files[index]
-        if self._cache_ref_count == 0:
-            audio, sampling_rate = load_wav(filename)
-            audio = audio / MAX_WAV_VALUE
-            if not self.fine_tuning:
-                audio = normalize(audio) * 0.95
-            self.cached_wav = audio
-            if sampling_rate != self.sampling_rate:
-                raise ValueError(f"{sampling_rate} SR doesn't match target {self.sampling_rate} SR")
-            self._cache_ref_count = self.n_cache_reuse
-        else:
-            audio = self.cached_wav
-            self._cache_ref_count -= 1
-
-        audio = torch.FloatTensor(audio)
-        audio = audio.unsqueeze(0)
-
-        if not self.fine_tuning:
-            if self.split:
-                if audio.size(1) >= self.segment_size:
-                    max_audio_start = audio.size(1) - self.segment_size
-                    audio_start = random.randint(0, max_audio_start)
-                    audio = audio[:, audio_start : audio_start + self.segment_size]
-                else:
-                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
-
-            mel = mel_spectrogram(
-                audio,
-                self.n_fft,
-                self.num_mels,
-                self.sampling_rate,
-                self.hop_size,
-                self.win_size,
-                self.fmin,
-                self.fmax,
-                center=False,
-            )
-        else:
-            mel = np.load(os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + ".npy"))
-            mel = torch.from_numpy(mel)
-
-            if len(mel.shape) < 3:
-                mel = mel.unsqueeze(0)
-
-            if self.split:
-                frames_per_seg = math.ceil(self.segment_size / self.hop_size)
-
-                if audio.size(1) >= self.segment_size:
-                    mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
-                    mel = mel[:, :, mel_start : mel_start + frames_per_seg]
-                    audio = audio[:, mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size]
-                else:
-                    mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
-                    audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
-
-        mel_loss = mel_spectrogram(
-            audio,
-            self.n_fft,
-            self.num_mels,
-            self.sampling_rate,
-            self.hop_size,
-            self.win_size,
-            self.fmin,
-            self.fmax_loss,
-            center=False,
-        )
-
-        return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
-
-    def __len__(self):
-        return len(self.audio_files)
 
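Illustrative only, not part of this commit: how the deleted `MelDataset` is typically wired into a PyTorch `DataLoader` using the `v1` hyperparameters from `hifigan/config.py`. The wav file list is a placeholder.

```python
from torch.utils.data import DataLoader

from hifigan.config import v1
from hifigan.meldataset import MelDataset

wav_files = ["LJSpeech-1.1/wavs/LJ001-0001.wav"]  # placeholder file list

dataset = MelDataset(
    wav_files,
    v1["segment_size"],
    v1["n_fft"],
    v1["num_mels"],
    v1["hop_size"],
    v1["win_size"],
    v1["sampling_rate"],
    v1["fmin"],
    v1["fmax"],
    fmax_loss=v1["fmax_loss"],
)
# Each item is (mel, audio, filename, mel_for_loss), as returned by __getitem__ above.
loader = DataLoader(dataset, batch_size=v1["batch_size"], num_workers=v1["num_workers"], shuffle=True)
```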
 
hifigan/models.py DELETED
@@ -1,368 +0,0 @@
-""" from https://github.com/jik876/hifi-gan """
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
-from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
-
-from .xutils import get_padding, init_weights
-
-LRELU_SLOPE = 0.1
-
-
-class ResBlock1(torch.nn.Module):
-    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
-        super().__init__()
-        self.h = h
-        self.convs1 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[2],
-                        padding=get_padding(kernel_size, dilation[2]),
-                    )
-                ),
-            ]
-        )
-        self.convs1.apply(init_weights)
-
-        self.convs2 = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=1,
-                        padding=get_padding(kernel_size, 1),
-                    )
-                ),
-            ]
-        )
-        self.convs2.apply(init_weights)
-
-    def forward(self, x):
-        for c1, c2 in zip(self.convs1, self.convs2):
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            xt = c1(xt)
-            xt = F.leaky_relu(xt, LRELU_SLOPE)
-            xt = c2(xt)
-            x = xt + x
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs1:
-            remove_weight_norm(l)
-        for l in self.convs2:
-            remove_weight_norm(l)
-
-
-class ResBlock2(torch.nn.Module):
-    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
-        super().__init__()
-        self.h = h
-        self.convs = nn.ModuleList(
-            [
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[0],
-                        padding=get_padding(kernel_size, dilation[0]),
-                    )
-                ),
-                weight_norm(
-                    Conv1d(
-                        channels,
-                        channels,
-                        kernel_size,
-                        1,
-                        dilation=dilation[1],
-                        padding=get_padding(kernel_size, dilation[1]),
-                    )
-                ),
-            ]
-        )
-        self.convs.apply(init_weights)
-
-    def forward(self, x):
-        for c in self.convs:
-            xt = F.leaky_relu(x, LRELU_SLOPE)
-            xt = c(xt)
-            x = xt + x
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.convs:
-            remove_weight_norm(l)
-
-
-class Generator(torch.nn.Module):
-    def __init__(self, h):
-        super().__init__()
-        self.h = h
-        self.num_kernels = len(h.resblock_kernel_sizes)
-        self.num_upsamples = len(h.upsample_rates)
-        self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
-        resblock = ResBlock1 if h.resblock == "1" else ResBlock2
-
-        self.ups = nn.ModuleList()
-        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
-            self.ups.append(
-                weight_norm(
-                    ConvTranspose1d(
-                        h.upsample_initial_channel // (2**i),
-                        h.upsample_initial_channel // (2 ** (i + 1)),
-                        k,
-                        u,
-                        padding=(k - u) // 2,
-                    )
-                )
-            )
-
-        self.resblocks = nn.ModuleList()
-        for i in range(len(self.ups)):
-            ch = h.upsample_initial_channel // (2 ** (i + 1))
-            for _, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
-                self.resblocks.append(resblock(h, ch, k, d))
-
-        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
-        self.ups.apply(init_weights)
-        self.conv_post.apply(init_weights)
-
-    def forward(self, x):
-        x = self.conv_pre(x)
-        for i in range(self.num_upsamples):
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            x = self.ups[i](x)
-            xs = None
-            for j in range(self.num_kernels):
-                if xs is None:
-                    xs = self.resblocks[i * self.num_kernels + j](x)
-                else:
-                    xs += self.resblocks[i * self.num_kernels + j](x)
-            x = xs / self.num_kernels
-        x = F.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def remove_weight_norm(self):
-        print("Removing weight norm...")
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-        remove_weight_norm(self.conv_pre)
-        remove_weight_norm(self.conv_post)
-
-
-class DiscriminatorP(torch.nn.Module):
-    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
-        super().__init__()
-        self.period = period
-        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
-                norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
-            ]
-        )
-        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
-
-    def forward(self, x):
-        fmap = []
-
-        # 1d to 2d
-        b, c, t = x.shape
-        if t % self.period != 0:  # pad first
-            n_pad = self.period - (t % self.period)
-            x = F.pad(x, (0, n_pad), "reflect")
-            t = t + n_pad
-        x = x.view(b, c, t // self.period, self.period)
-
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class MultiPeriodDiscriminator(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorP(2),
-                DiscriminatorP(3),
-                DiscriminatorP(5),
-                DiscriminatorP(7),
-                DiscriminatorP(11),
-            ]
-        )
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for _, d in enumerate(self.discriminators):
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            fmap_rs.append(fmap_r)
-            y_d_gs.append(y_d_g)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-class DiscriminatorS(torch.nn.Module):
-    def __init__(self, use_spectral_norm=False):
-        super().__init__()
-        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
-        self.convs = nn.ModuleList(
-            [
-                norm_f(Conv1d(1, 128, 15, 1, padding=7)),
-                norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
-                norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
-                norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
-                norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
-                norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
-                norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
-            ]
-        )
-        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
-
-    def forward(self, x):
-        fmap = []
-        for l in self.convs:
-            x = l(x)
-            x = F.leaky_relu(x, LRELU_SLOPE)
-            fmap.append(x)
-        x = self.conv_post(x)
-        fmap.append(x)
-        x = torch.flatten(x, 1, -1)
-
-        return x, fmap
-
-
-class MultiScaleDiscriminator(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-        self.discriminators = nn.ModuleList(
-            [
-                DiscriminatorS(use_spectral_norm=True),
-                DiscriminatorS(),
-                DiscriminatorS(),
-            ]
-        )
-        self.meanpools = nn.ModuleList([AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)])
-
-    def forward(self, y, y_hat):
-        y_d_rs = []
-        y_d_gs = []
-        fmap_rs = []
-        fmap_gs = []
-        for i, d in enumerate(self.discriminators):
-            if i != 0:
-                y = self.meanpools[i - 1](y)
-                y_hat = self.meanpools[i - 1](y_hat)
-            y_d_r, fmap_r = d(y)
-            y_d_g, fmap_g = d(y_hat)
-            y_d_rs.append(y_d_r)
-            fmap_rs.append(fmap_r)
-            y_d_gs.append(y_d_g)
-            fmap_gs.append(fmap_g)
-
-        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
-
-
-def feature_loss(fmap_r, fmap_g):
-    loss = 0
-    for dr, dg in zip(fmap_r, fmap_g):
-        for rl, gl in zip(dr, dg):
-            loss += torch.mean(torch.abs(rl - gl))
-
-    return loss * 2
-
-
-def discriminator_loss(disc_real_outputs, disc_generated_outputs):
-    loss = 0
-    r_losses = []
-    g_losses = []
-    for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
-        r_loss = torch.mean((1 - dr) ** 2)
-        g_loss = torch.mean(dg**2)
-        loss += r_loss + g_loss
-        r_losses.append(r_loss.item())
-        g_losses.append(g_loss.item())
-
-    return loss, r_losses, g_losses
-
-
-def generator_loss(disc_outputs):
-    loss = 0
-    gen_losses = []
-    for dg in disc_outputs:
-        l = torch.mean((1 - dg) ** 2)
-        gen_losses.append(l)
-        loss += l
-
-    return loss, gen_losses
 
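Illustrative only, not part of this commit: one adversarial step with the modules and loss functions above, mirroring how HiFi-GAN training typically uses them. The tensors are dummies, and the hop size of 256 comes from the v1 config.

```python
import torch

from hifigan.config import v1
from hifigan.env import AttrDict
from hifigan.models import (
    Generator,
    MultiPeriodDiscriminator,
    discriminator_loss,
    feature_loss,
    generator_loss,
)

h = AttrDict(v1)
generator = Generator(h)
mpd = MultiPeriodDiscriminator()

mel = torch.randn(2, 80, 32)        # dummy mel batch
y = torch.randn(2, 1, 32 * 256)     # dummy reference audio (hop size 256)
y_hat = generator(mel)              # generated audio, same length as y

# Discriminator side: real vs. detached fake
y_d_rs, y_d_gs, _, _ = mpd(y, y_hat.detach())
loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)

# Generator side: adversarial + feature-matching terms
_, y_d_gs, fmap_rs, fmap_gs = mpd(y, y_hat)
loss_gen, _ = generator_loss(y_d_gs)
loss_fm = feature_loss(fmap_rs, fmap_gs)
```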
 
hifigan/xutils.py DELETED
@@ -1,60 +0,0 @@
-""" from https://github.com/jik876/hifi-gan """
-
-import glob
-import os
-
-import matplotlib
-import torch
-from torch.nn.utils import weight_norm
-
-matplotlib.use("Agg")
-import matplotlib.pylab as plt
-
-
-def plot_spectrogram(spectrogram):
-    fig, ax = plt.subplots(figsize=(10, 2))
-    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
-    plt.colorbar(im, ax=ax)
-
-    fig.canvas.draw()
-    plt.close()
-
-    return fig
-
-
-def init_weights(m, mean=0.0, std=0.01):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        m.weight.data.normal_(mean, std)
-
-
-def apply_weight_norm(m):
-    classname = m.__class__.__name__
-    if classname.find("Conv") != -1:
-        weight_norm(m)
-
-
-def get_padding(kernel_size, dilation=1):
-    return int((kernel_size * dilation - dilation) / 2)
-
-
-def load_checkpoint(filepath, device):
-    assert os.path.isfile(filepath)
-    print(f"Loading '{filepath}'")
-    checkpoint_dict = torch.load(filepath, map_location=device)
-    print("Complete.")
-    return checkpoint_dict
-
-
-def save_checkpoint(filepath, obj):
-    print(f"Saving checkpoint to {filepath}")
-    torch.save(obj, filepath)
-    print("Complete.")
-
-
-def scan_checkpoint(cp_dir, prefix):
-    pattern = os.path.join(cp_dir, prefix + "????????")
-    cp_list = glob.glob(pattern)
-    if len(cp_list) == 0:
-        return None
-    return sorted(cp_list)[-1]
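
Illustrative only, not part of this commit: resuming from the newest generator checkpoint with the helpers above. The `cp_hifigan` directory and the `g_` filename prefix follow the conventions in the deleted README and are assumptions here.

```python
import torch

from hifigan.xutils import load_checkpoint, scan_checkpoint

cp_g = scan_checkpoint("cp_hifigan", "g_")  # newest file matching cp_hifigan/g_????????
if cp_g is not None:
    checkpoint = load_checkpoint(cp_g, torch.device("cpu"))
```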