w11wo committed on
Commit
11120b4
1 Parent(s): 408d582

added demo

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Grad-TTS/LICENSE +19 -0
  2. Grad-TTS/README.md +74 -0
  3. Grad-TTS/THIRD_PARTY_NOTICE +75 -0
  4. Grad-TTS/checkpts/hifigan-config.json +38 -0
  5. Grad-TTS/data.py +186 -0
  6. Grad-TTS/finetune_multi_speaker.py +237 -0
  7. Grad-TTS/finetune_params.py +56 -0
  8. Grad-TTS/hifi-gan/LICENSE +21 -0
  9. Grad-TTS/hifi-gan/README.md +105 -0
  10. Grad-TTS/hifi-gan/env.py +17 -0
  11. Grad-TTS/hifi-gan/meldataset.py +170 -0
  12. Grad-TTS/hifi-gan/models.py +285 -0
  13. Grad-TTS/hifi-gan/xutils.py +60 -0
  14. Grad-TTS/inference.ipynb +199 -0
  15. Grad-TTS/inference.py +85 -0
  16. Grad-TTS/model/__init__.py +9 -0
  17. Grad-TTS/model/base.py +37 -0
  18. Grad-TTS/model/diffusion.py +294 -0
  19. Grad-TTS/model/monotonic_align/LICENCE +21 -0
  20. Grad-TTS/model/monotonic_align/__init__.py +23 -0
  21. Grad-TTS/model/monotonic_align/core.pyx +45 -0
  22. Grad-TTS/model/monotonic_align/setup.py +11 -0
  23. Grad-TTS/model/text_encoder.py +326 -0
  24. Grad-TTS/model/tts.py +181 -0
  25. Grad-TTS/model/utils.py +44 -0
  26. Grad-TTS/out/sample_0.wav +0 -0
  27. Grad-TTS/out/sample_1.wav +0 -0
  28. Grad-TTS/out/sample_2.wav +0 -0
  29. Grad-TTS/params.py +54 -0
  30. Grad-TTS/params_en.py +54 -0
  31. Grad-TTS/requirements.txt +12 -0
  32. Grad-TTS/resources/cmu_dictionary +0 -0
  33. Grad-TTS/resources/cmu_dictionary_id +0 -0
  34. Grad-TTS/resources/cmu_dictionary_id_en +0 -0
  35. Grad-TTS/resources/filelists/libri-tts/train.txt +0 -0
  36. Grad-TTS/resources/filelists/libri-tts/valid.txt +4 -0
  37. Grad-TTS/resources/filelists/ljspeech/test.txt +488 -0
  38. Grad-TTS/resources/filelists/ljspeech/train.txt +0 -0
  39. Grad-TTS/resources/filelists/ljspeech/valid.txt +95 -0
  40. Grad-TTS/resources/filelists/synthesis.txt +3 -0
  41. Grad-TTS/resources/ipa_dictionary_id +0 -0
  42. Grad-TTS/text/LICENSE +30 -0
  43. Grad-TTS/text/__init__.py +96 -0
  44. Grad-TTS/text/cleaners.py +73 -0
  45. Grad-TTS/text/cmudict.py +60 -0
  46. Grad-TTS/text/numbers.py +72 -0
  47. Grad-TTS/text/symbols.py +14 -0
  48. Grad-TTS/train.py +177 -0
  49. Grad-TTS/train_multi_speaker.py +182 -0
  50. Grad-TTS/utils.py +75 -0
Grad-TTS/LICENSE ADDED
@@ -0,0 +1,19 @@
+ Copyright (c) 2021 Huawei Technologies Co., Ltd.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
Grad-TTS/README.md ADDED
@@ -0,0 +1,74 @@
+ <p align="center">
+ <img src="resources/reverse-diffusion.gif" alt="drawing" width="500"/>
+ </p>
+
+
+ # Grad-TTS
+
+ Official implementation of the Grad-TTS model based on Diffusion Probabilistic Modelling. For all details check out our paper accepted to ICML 2021 via [this](https://arxiv.org/abs/2105.06337) link.
+
+ **Authors**: Vadim Popov\*, Ivan Vovk\*, Vladimir Gogoryan, Tasnima Sadekova, Mikhail Kudinov.
+
+ <sup>\*Equal contribution.</sup>
+
+ ## Abstract
+
+ **Demo page** with voiced abstract: [link](https://grad-tts.github.io/).
+
+ Recently, denoising diffusion probabilistic models and generative score matching have shown high potential in modelling complex data distributions while stochastic calculus has provided a unified point of view on these techniques allowing for flexible inference schemes. In this paper we introduce Grad-TTS, a novel text-to-speech model with score-based decoder producing mel-spectrograms by gradually transforming noise predicted by encoder and aligned with text input by means of Monotonic Alignment Search. The framework of stochastic differential equations helps us to generalize conventional diffusion probabilistic models to the case of reconstructing data from noise with different parameters and allows to make this reconstruction flexible by explicitly controlling trade-off between sound quality and inference speed. Subjective human evaluation shows that Grad-TTS is competitive with state-of-the-art text-to-speech approaches in terms of Mean Opinion Score.
+
+ ## Installation
+
+ Firstly, install all Python package requirements:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ Secondly, build the `monotonic_align` code (Cython):
+
+ ```bash
+ cd model/monotonic_align; python setup.py build_ext --inplace; cd ../..
+ ```
+
+ **Note**: the code is tested on Python==3.6.9.
+
+ ## Inference
+
+ You can download Grad-TTS and HiFi-GAN checkpoints trained on LJSpeech* and Libri-TTS datasets (22kHz) from [here](https://drive.google.com/drive/folders/1grsfccJbmEuSBGQExQKr3cVxNV0xEOZ7?usp=sharing).
+
+ ***Note**: we open-source 2 checkpoints of Grad-TTS trained on LJSpeech. They are the same models but trained with different positional encoding scales: **x1** (`"grad-tts-old.pt"`, ICML 2021 submission model) and **x1000** (`"grad-tts.pt"`). To use the former, set `params.pe_scale=1`; to use the latter, set `params.pe_scale=1000`. The Libri-TTS checkpoint was trained with scale **x1000**.
+
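+ For example, to use the newer LJSpeech checkpoint you would keep the setting below, and switch it to `1` for `grad-tts-old.pt`. This is only a sketch of the relevant line in `params.py`:
+
+ ```python
+ # params.py (excerpt): the positional encoding scale must match the checkpoint you load
+ pe_scale = 1000  # for "grad-tts.pt" and the Libri-TTS checkpoint
+ # pe_scale = 1   # for the older "grad-tts-old.pt" (ICML 2021 submission model)
+ ```
+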
+ Put the necessary Grad-TTS and HiFi-GAN checkpoints into the `checkpts` folder in the root Grad-TTS directory (note: in `inference.py` you can change the default HiFi-GAN path).
+
+ 1. Create a text file with the sentences you want to synthesize, like `resources/filelists/synthesis.txt`.
+ 2. For single-speaker inference set `params.n_spks=1`, and for multispeaker (Libri-TTS) inference set `params.n_spks=247`.
+ 3. Run the script `inference.py`, providing the path to the text file, the path to the Grad-TTS checkpoint, the number of iterations to be used for reverse diffusion (default: 10), and a speaker id if you want to perform multispeaker inference:
+ ```bash
+ python inference.py -f <your-text-file> -c <grad-tts-checkpoint> -t <number-of-timesteps> -s <speaker-id-if-multispeaker>
+ ```
+ 4. Check out the folder called `out` for the generated audio.
+
+ You can also perform *interactive inference* by running the Jupyter Notebook `inference.ipynb` or by using our [Google Colab Demo](https://colab.research.google.com/drive/1YNrXtkJQKcYDmIYJeyX8s5eXxB4zgpZI?usp=sharing).
+
+ ## Training
+
+ 1. Make filelists of your audio data like the ones included in the `resources/filelists` folder (example filelist lines are shown after this list). For single-speaker training refer to the `ljspeech` filelists and to the `libri-tts` filelists for multispeaker training.
+ 2. Set the experiment configuration in the `params.py` file.
+ 3. Specify your GPU device and run the training script:
+ ```bash
+ export CUDA_VISIBLE_DEVICES=YOUR_GPU_ID
+ python train.py # if single speaker
+ python train_multi_speaker.py # if multispeaker
+ ```
+ 4. To track your training process, run a tensorboard server on any available port:
+ ```bash
+ tensorboard --logdir=YOUR_LOG_DIR --port=8888
+ ```
+ During training all logging information and checkpoints are stored in `YOUR_LOG_DIR`, which you can specify in `params.py` before training.
+
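+ Example filelist lines (the audio paths, transcripts and speaker id below are placeholders; fields are assumed to be separated by `|`, matching how `data.py` parses them). Single-speaker filelists for `train.py` use `<path-to-wav>|<transcript>`, and multispeaker filelists for `train_multi_speaker.py` use `<path-to-wav>|<transcript>|<speaker-id>`:
+
+ ```
+ /data/LJSpeech-1.1/wavs/LJ001-0001.wav|Printing, in the only sense with which we are at present concerned.
+ /data/libri-tts/1088_134315_000001.wav|There was a time when I would have believed it.|3
+ ```
+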
+ ## References
+
+ * HiFi-GAN model is used as vocoder, official github repository: [link](https://github.com/jik876/hifi-gan).
+ * Monotonic Alignment Search algorithm is used for unsupervised duration modelling, official github repository: [link](https://github.com/jaywalnut310/glow-tts).
+ * Phonemization utilizes CMUdict, official github repository: [link](https://github.com/cmusphinx/cmudict).
Grad-TTS/THIRD_PARTY_NOTICE ADDED
@@ -0,0 +1,75 @@
1
+ Please note we provide an open source software notice for the third party
2
+ open source software along with this software and/or this software component
3
+ contributed by Huawei (in the following just “this SOFTWARE”). The open source
4
+ software licenses are granted by the respective right holders.
5
+
6
+ WARRANTY DISCLAIMER
7
+ THE OPEN SOURCE SOFTWARE IN THIS SOFTWARE IS DISTRIBUTED IN THE HOPE THAT IT WILL
8
+ BE USEFUL, BUT WITHOUT ANY WARRANTY, WITHOUT EVEN THE IMPLIED WARRANTY OF
9
+ MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. SEE THE APPLICABLE LICENSES
10
+ FOR MORE DETAILS.
11
+
12
+ COPYRIGHT NOTICE AND LICENSE TEXTS
13
+
14
+ SOFTWARE: HiFi-GAN
15
+ Copyright (c) 2020 Jungil Kong <henry.k@kakaoenterprise.com>
16
+ License: MIT
17
+ Permission is hereby granted, free of charge, to any person obtaining a copy
18
+ of this software and associated documentation files (the "Software"), to deal
19
+ in the Software without restriction, including without limitation the rights
20
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
21
+ copies of the Software, and to permit persons to whom the Software is
22
+ furnished to do so, subject to the following conditions:
23
+
24
+ The above copyright notice and this permission notice shall be included in all
25
+ copies or substantial portions of the Software.
26
+
27
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
28
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
29
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
30
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
31
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
32
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
33
+ SOFTWARE.
34
+
35
+ SOFTWARE: GLOW-TTS
36
+ Copyright (c) 2020 Jaehyeon Kim <jay.xyz@kakaoenterprise.com>
37
+ License: MIT
38
+ Text: See above
39
+
40
+ SOFTWARE: CMUDict
41
+ Copyright (C) 1993-2015 Carnegie Mellon University <air+cmudict@cs.cmu.edu>
42
+ License text
43
+ Copyright (C) 1993-2015 Carnegie Mellon University. All rights reserved.
44
+
45
+ Redistribution and use in source and binary forms, with or without
46
+ modification, are permitted provided that the following conditions
47
+ are met:
48
+
49
+ 1. Redistributions of source code must retain the above copyright
50
+ notice, this list of conditions and the following disclaimer.
51
+ The contents of this file are deemed to be source code.
52
+
53
+ 2. Redistributions in binary form must reproduce the above copyright
54
+ notice, this list of conditions and the following disclaimer in
55
+ the documentation and/or other materials provided with the
56
+ distribution.
57
+
58
+ This work was supported in part by funding from the Defense Advanced
59
+ Research Projects Agency, the Office of Naval Research and the National
60
+ Science Foundation of the United States of America, and by member
61
+ companies of the Carnegie Mellon Sphinx Speech Consortium. We acknowledge
62
+ the contributions of many volunteers to the expansion and improvement of
63
+ this dictionary.
64
+
65
+ THIS SOFTWARE IS PROVIDED BY CARNEGIE MELLON UNIVERSITY ``AS IS'' AND
66
+ ANY EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
67
+ THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
68
+ PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL CARNEGIE MELLON UNIVERSITY
69
+ NOR ITS EMPLOYEES BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
70
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
71
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
72
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
73
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
74
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
75
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
Grad-TTS/checkpts/hifigan-config.json ADDED
@@ -0,0 +1,38 @@
1
+ {
2
+ "resblock": "1",
3
+ "num_gpus": 0,
4
+ "batch_size": 16,
5
+ "learning_rate": 0.0004,
6
+ "adam_b1": 0.8,
7
+ "adam_b2": 0.99,
8
+ "lr_decay": 0.999,
9
+ "seed": 1234,
10
+
11
+ "upsample_rates": [8,8,2,2],
12
+ "upsample_kernel_sizes": [16,16,4,4],
13
+ "upsample_initial_channel": 512,
14
+ "resblock_kernel_sizes": [3,7,11],
15
+ "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
16
+ "resblock_initial_channel": 256,
17
+
18
+ "segment_size": 8192,
19
+ "num_mels": 80,
20
+ "num_freq": 1025,
21
+ "n_fft": 1024,
22
+ "hop_size": 256,
23
+ "win_size": 1024,
24
+
25
+ "sampling_rate": 22050,
26
+
27
+ "fmin": 0,
28
+ "fmax": 8000,
29
+ "fmax_loss": null,
30
+
31
+ "num_workers": 4,
32
+
33
+ "dist_config": {
34
+ "dist_backend": "nccl",
35
+ "dist_url": "tcp://localhost:54321",
36
+ "world_size": 1
37
+ }
38
+ }
Grad-TTS/data.py ADDED
@@ -0,0 +1,186 @@
1
+ # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
2
+ # This program is free software; you can redistribute it and/or modify
3
+ # it under the terms of the MIT License.
4
+ # This program is distributed in the hope that it will be useful,
5
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
6
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
7
+ # MIT License for more details.
8
+
9
+ import random
10
+ import numpy as np
11
+
12
+ import torch
13
+ import torchaudio as ta
14
+
15
+ from text import text_to_sequence, cmudict
16
+ from text.symbols import symbols
17
+ from utils import parse_filelist, intersperse
18
+ from model.utils import fix_len_compatibility
19
+ from params import seed as random_seed
20
+
21
+ import sys
22
+ sys.path.insert(0, 'hifi-gan')
23
+ from meldataset import mel_spectrogram
24
+
25
+
26
+ class TextMelDataset(torch.utils.data.Dataset):
27
+ def __init__(self, filelist_path, cmudict_path, add_blank=True,
28
+ n_fft=1024, n_mels=80, sample_rate=22050,
29
+ hop_length=256, win_length=1024, f_min=0., f_max=8000):
30
+ self.filepaths_and_text = parse_filelist(filelist_path)
31
+ self.cmudict = cmudict.CMUDict(cmudict_path)
32
+ self.add_blank = add_blank
33
+ self.n_fft = n_fft
34
+ self.n_mels = n_mels
35
+ self.sample_rate = sample_rate
36
+ self.hop_length = hop_length
37
+ self.win_length = win_length
38
+ self.f_min = f_min
39
+ self.f_max = f_max
40
+ random.seed(random_seed)
41
+ random.shuffle(self.filepaths_and_text)
42
+
43
+ def get_pair(self, filepath_and_text):
44
+ filepath, text = filepath_and_text[0], filepath_and_text[1]
45
+ text = self.get_text(text, add_blank=self.add_blank)
46
+ mel = self.get_mel(filepath)
47
+ return (text, mel)
48
+
49
+ def get_mel(self, filepath):
50
+ audio, sr = ta.load(filepath)
51
+ assert sr == self.sample_rate
52
+ mel = mel_spectrogram(audio, self.n_fft, self.n_mels, self.sample_rate, self.hop_length,
53
+ self.win_length, self.f_min, self.f_max, center=False).squeeze()
54
+ return mel
55
+
56
+ def get_text(self, text, add_blank=True):
57
+ text_norm = text_to_sequence(text, dictionary=self.cmudict)
58
+ if self.add_blank:
59
+ text_norm = intersperse(text_norm, len(symbols)) # add a blank token, whose id number is len(symbols)
60
+ text_norm = torch.IntTensor(text_norm)
61
+ return text_norm
62
+
63
+ def __getitem__(self, index):
64
+ text, mel = self.get_pair(self.filepaths_and_text[index])
65
+ item = {'y': mel, 'x': text}
66
+ return item
67
+
68
+ def __len__(self):
69
+ return len(self.filepaths_and_text)
70
+
71
+ def sample_test_batch(self, size):
72
+ idx = np.random.choice(range(len(self)), size=size, replace=False)
73
+ test_batch = []
74
+ for index in idx:
75
+ test_batch.append(self.__getitem__(index))
76
+ return test_batch
77
+
78
+
79
+ class TextMelBatchCollate(object):
80
+ def __call__(self, batch):
81
+ B = len(batch)
82
+ y_max_length = max([item['y'].shape[-1] for item in batch])
83
+ y_max_length = fix_len_compatibility(y_max_length)
84
+ x_max_length = max([item['x'].shape[-1] for item in batch])
85
+ n_feats = batch[0]['y'].shape[-2]
86
+
87
+ y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
88
+ x = torch.zeros((B, x_max_length), dtype=torch.long)
89
+ y_lengths, x_lengths = [], []
90
+
91
+ for i, item in enumerate(batch):
92
+ y_, x_ = item['y'], item['x']
93
+ y_lengths.append(y_.shape[-1])
94
+ x_lengths.append(x_.shape[-1])
95
+ y[i, :, :y_.shape[-1]] = y_
96
+ x[i, :x_.shape[-1]] = x_
97
+
98
+ y_lengths = torch.LongTensor(y_lengths)
99
+ x_lengths = torch.LongTensor(x_lengths)
100
+ return {'x': x, 'x_lengths': x_lengths, 'y': y, 'y_lengths': y_lengths}
101
+
102
+
103
+ class TextMelSpeakerDataset(torch.utils.data.Dataset):
104
+ def __init__(self, filelist_path, cmudict_path, add_blank=True,
105
+ n_fft=1024, n_mels=80, sample_rate=22050,
106
+ hop_length=256, win_length=1024, f_min=0., f_max=8000):
107
+ super().__init__()
108
+ self.filelist = parse_filelist(filelist_path, split_char='|')
109
+ self.cmudict = cmudict.CMUDict(cmudict_path)
110
+ self.n_fft = n_fft
111
+ self.n_mels = n_mels
112
+ self.sample_rate = sample_rate
113
+ self.hop_length = hop_length
114
+ self.win_length = win_length
115
+ self.f_min = f_min
116
+ self.f_max = f_max
117
+ self.add_blank = add_blank
118
+ random.seed(random_seed)
119
+ random.shuffle(self.filelist)
120
+
121
+ def get_triplet(self, line):
122
+ filepath, text, speaker = line[0], line[1], line[2]
123
+ text = self.get_text(text, add_blank=self.add_blank)
124
+ mel = self.get_mel(filepath)
125
+ speaker = self.get_speaker(speaker)
126
+ return (text, mel, speaker)
127
+
128
+ def get_mel(self, filepath):
129
+ audio, sr = ta.load(filepath)
130
+ assert sr == self.sample_rate
131
+ mel = mel_spectrogram(audio, self.n_fft, self.n_mels, self.sample_rate, self.hop_length,
132
+ self.win_length, self.f_min, self.f_max, center=False).squeeze()
133
+ return mel
134
+
135
+ def get_text(self, text, add_blank=True):
136
+ text_norm = text_to_sequence(text, dictionary=self.cmudict)
137
+ if self.add_blank:
138
+ text_norm = intersperse(text_norm, len(symbols)) # add a blank token, whose id number is len(symbols)
139
+ text_norm = torch.LongTensor(text_norm)
140
+ return text_norm
141
+
142
+ def get_speaker(self, speaker):
143
+ speaker = torch.LongTensor([int(speaker)])
144
+ return speaker
145
+
146
+ def __getitem__(self, index):
147
+ text, mel, speaker = self.get_triplet(self.filelist[index])
148
+ item = {'y': mel, 'x': text, 'spk': speaker}
149
+ return item
150
+
151
+ def __len__(self):
152
+ return len(self.filelist)
153
+
154
+ def sample_test_batch(self, size):
155
+ idx = np.random.choice(range(len(self)), size=size, replace=False)
156
+ test_batch = []
157
+ for index in idx:
158
+ test_batch.append(self.__getitem__(index))
159
+ return test_batch
160
+
161
+
162
+ class TextMelSpeakerBatchCollate(object):
163
+ def __call__(self, batch):
164
+ B = len(batch)
165
+ y_max_length = max([item['y'].shape[-1] for item in batch])
166
+ y_max_length = fix_len_compatibility(y_max_length)
167
+ x_max_length = max([item['x'].shape[-1] for item in batch])
168
+ n_feats = batch[0]['y'].shape[-2]
169
+
170
+ y = torch.zeros((B, n_feats, y_max_length), dtype=torch.float32)
171
+ x = torch.zeros((B, x_max_length), dtype=torch.long)
172
+ y_lengths, x_lengths = [], []
173
+ spk = []
174
+
175
+ for i, item in enumerate(batch):
176
+ y_, x_, spk_ = item['y'], item['x'], item['spk']
177
+ y_lengths.append(y_.shape[-1])
178
+ x_lengths.append(x_.shape[-1])
179
+ y[i, :, :y_.shape[-1]] = y_
180
+ x[i, :x_.shape[-1]] = x_
181
+ spk.append(spk_)
182
+
183
+ y_lengths = torch.LongTensor(y_lengths)
184
+ x_lengths = torch.LongTensor(x_lengths)
185
+ spk = torch.cat(spk, dim=0)
186
+ return {'x': x, 'x_lengths': x_lengths, 'y': y, 'y_lengths': y_lengths, 'spk': spk}
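The `add_blank` option above interleaves a blank token (with id `len(symbols)`) between phone ids, as noted in `get_text`. A minimal sketch of what `utils.intersperse` is assumed to do, matching the usual Glow-TTS-style implementation:

```python
def intersperse(lst, item):
    """Insert `item` before, between and after the elements of `lst`."""
    # e.g. intersperse([5, 12, 7], 148) -> [148, 5, 148, 12, 148, 7, 148]
    result = [item] * (2 * len(lst) + 1)
    result[1::2] = lst
    return result
```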
Grad-TTS/finetune_multi_speaker.py ADDED
@@ -0,0 +1,237 @@
1
+ # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
2
+ # This program is free software; you can redistribute it and/or modify
3
+ # it under the terms of the MIT License.
4
+ # This program is distributed in the hope that it will be useful,
5
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
6
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
7
+ # MIT License for more details.
8
+
9
+ import numpy as np
10
+ from tqdm import tqdm
11
+
12
+ import torch
13
+ from torch.utils.data import DataLoader
14
+ from torch.utils.tensorboard import SummaryWriter
15
+
16
+ import finetune_params as params
17
+ from model import GradTTS
18
+ from data import TextMelSpeakerDataset, TextMelSpeakerBatchCollate
19
+ from utils import plot_tensor, save_plot
20
+ from text.symbols import symbols
21
+
22
+
23
+ train_filelist_path = params.train_filelist_path
24
+ valid_filelist_path = params.valid_filelist_path
25
+ cmudict_path = params.cmudict_path
26
+ add_blank = params.add_blank
27
+ n_spks = params.n_spks
28
+ spk_emb_dim = params.spk_emb_dim
29
+
30
+ log_dir = params.log_dir
31
+ n_epochs = params.n_epochs
32
+ batch_size = params.batch_size
33
+ out_size = params.out_size
34
+ learning_rate = params.learning_rate
35
+ random_seed = params.seed
36
+
37
+ nsymbols = len(symbols) + 1 if add_blank else len(symbols)
38
+ n_enc_channels = params.n_enc_channels
39
+ filter_channels = params.filter_channels
40
+ filter_channels_dp = params.filter_channels_dp
41
+ n_enc_layers = params.n_enc_layers
42
+ enc_kernel = params.enc_kernel
43
+ enc_dropout = params.enc_dropout
44
+ n_heads = params.n_heads
45
+ window_size = params.window_size
46
+
47
+ n_feats = params.n_feats
48
+ n_fft = params.n_fft
49
+ sample_rate = params.sample_rate
50
+ hop_length = params.hop_length
51
+ win_length = params.win_length
52
+ f_min = params.f_min
53
+ f_max = params.f_max
54
+
55
+ dec_dim = params.dec_dim
56
+ beta_min = params.beta_min
57
+ beta_max = params.beta_max
58
+ pe_scale = params.pe_scale
59
+
60
+ num_workers = params.num_workers
61
+ checkpoint = params.checkpoint
62
+
63
+ if __name__ == "__main__":
64
+ torch.manual_seed(random_seed)
65
+ np.random.seed(random_seed)
66
+
67
+ print("Initializing logger...")
68
+ logger = SummaryWriter(log_dir=log_dir)
69
+
70
+ print("Initializing data loaders...")
71
+ train_dataset = TextMelSpeakerDataset(
72
+ train_filelist_path,
73
+ cmudict_path,
74
+ add_blank,
75
+ n_fft,
76
+ n_feats,
77
+ sample_rate,
78
+ hop_length,
79
+ win_length,
80
+ f_min,
81
+ f_max,
82
+ )
83
+ batch_collate = TextMelSpeakerBatchCollate()
84
+ loader = DataLoader(
85
+ dataset=train_dataset,
86
+ batch_size=batch_size,
87
+ collate_fn=batch_collate,
88
+ drop_last=True,
89
+ num_workers=num_workers,
90
+ shuffle=True,
91
+ )
92
+ test_dataset = TextMelSpeakerDataset(
93
+ valid_filelist_path,
94
+ cmudict_path,
95
+ add_blank,
96
+ n_fft,
97
+ n_feats,
98
+ sample_rate,
99
+ hop_length,
100
+ win_length,
101
+ f_min,
102
+ f_max,
103
+ )
104
+
105
+ print("Initializing model...")
106
+ model = GradTTS(
107
+ nsymbols,
108
+ n_spks,
109
+ spk_emb_dim,
110
+ n_enc_channels,
111
+ filter_channels,
112
+ filter_channels_dp,
113
+ n_heads,
114
+ n_enc_layers,
115
+ enc_kernel,
116
+ enc_dropout,
117
+ window_size,
118
+ n_feats,
119
+ dec_dim,
120
+ beta_min,
121
+ beta_max,
122
+ pe_scale,
123
+ ).cuda()
124
+ model.load_state_dict(torch.load(checkpoint, map_location=torch.device("cuda")))
125
+ print("Number of encoder parameters = %.2fm" % (model.encoder.nparams / 1e6))
126
+ print("Number of decoder parameters = %.2fm" % (model.decoder.nparams / 1e6))
127
+
128
+ print("Initializing optimizer...")
129
+ optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)
130
+
131
+ print("Logging test batch...")
132
+ test_batch = test_dataset.sample_test_batch(size=params.test_size)
133
+ for item in test_batch:
134
+ mel, spk = item["y"], item["spk"]
135
+ i = int(spk.cpu())
136
+ logger.add_image(
137
+ f"image_{i}/ground_truth",
138
+ plot_tensor(mel.squeeze()),
139
+ global_step=0,
140
+ dataformats="HWC",
141
+ )
142
+ save_plot(mel.squeeze(), f"{log_dir}/original_{i}.png")
143
+
144
+ print("Start training...")
145
+ iteration = 0
146
+ for epoch in range(1, n_epochs + 1):
147
+ model.eval()
148
+ print("Synthesis...")
149
+ with torch.no_grad():
150
+ for item in test_batch:
151
+ x = item["x"].to(torch.long).unsqueeze(0).cuda()
152
+ x_lengths = torch.LongTensor([x.shape[-1]]).cuda()
153
+ spk = item["spk"].to(torch.long).cuda()
154
+ i = int(spk.cpu())
155
+
156
+ y_enc, y_dec, attn = model(x, x_lengths, n_timesteps=50, spk=spk)
157
+ logger.add_image(
158
+ f"image_{i}/generated_enc",
159
+ plot_tensor(y_enc.squeeze().cpu()),
160
+ global_step=iteration,
161
+ dataformats="HWC",
162
+ )
163
+ logger.add_image(
164
+ f"image_{i}/generated_dec",
165
+ plot_tensor(y_dec.squeeze().cpu()),
166
+ global_step=iteration,
167
+ dataformats="HWC",
168
+ )
169
+ logger.add_image(
170
+ f"image_{i}/alignment",
171
+ plot_tensor(attn.squeeze().cpu()),
172
+ global_step=iteration,
173
+ dataformats="HWC",
174
+ )
175
+ save_plot(y_enc.squeeze().cpu(), f"{log_dir}/generated_enc_{i}.png")
176
+ save_plot(y_dec.squeeze().cpu(), f"{log_dir}/generated_dec_{i}.png")
177
+ save_plot(attn.squeeze().cpu(), f"{log_dir}/alignment_{i}.png")
178
+
179
+ model.train()
180
+ dur_losses = []
181
+ prior_losses = []
182
+ diff_losses = []
183
+ with tqdm(loader, total=len(train_dataset) // batch_size) as progress_bar:
184
+ for batch in progress_bar:
185
+ model.zero_grad()
186
+ x, x_lengths = batch["x"].cuda(), batch["x_lengths"].cuda()
187
+ y, y_lengths = batch["y"].cuda(), batch["y_lengths"].cuda()
188
+ spk = batch["spk"].cuda()
189
+ dur_loss, prior_loss, diff_loss = model.compute_loss(
190
+ x, x_lengths, y, y_lengths, spk=spk, out_size=out_size
191
+ )
192
+ loss = sum([dur_loss, prior_loss, diff_loss])
193
+ loss.backward()
194
+
195
+ enc_grad_norm = torch.nn.utils.clip_grad_norm_(
196
+ model.encoder.parameters(), max_norm=1
197
+ )
198
+ dec_grad_norm = torch.nn.utils.clip_grad_norm_(
199
+ model.decoder.parameters(), max_norm=1
200
+ )
201
+ optimizer.step()
202
+
203
+ logger.add_scalar(
204
+ "training/duration_loss", dur_loss, global_step=iteration
205
+ )
206
+ logger.add_scalar(
207
+ "training/prior_loss", prior_loss, global_step=iteration
208
+ )
209
+ logger.add_scalar(
210
+ "training/diffusion_loss", diff_loss, global_step=iteration
211
+ )
212
+ logger.add_scalar(
213
+ "training/encoder_grad_norm", enc_grad_norm, global_step=iteration
214
+ )
215
+ logger.add_scalar(
216
+ "training/decoder_grad_norm", dec_grad_norm, global_step=iteration
217
+ )
218
+
219
+ msg = f"Epoch: {epoch}, iteration: {iteration} | dur_loss: {dur_loss.item()}, prior_loss: {prior_loss.item()}, diff_loss: {diff_loss.item()}"
220
+ progress_bar.set_description(msg)
221
+
222
+ dur_losses.append(dur_loss.item())
223
+ prior_losses.append(prior_loss.item())
224
+ diff_losses.append(diff_loss.item())
225
+ iteration += 1
226
+
227
+ msg = "Epoch %d: duration loss = %.3f " % (epoch, np.mean(dur_losses))
228
+ msg += "| prior loss = %.3f " % np.mean(prior_losses)
229
+ msg += "| diffusion loss = %.3f\n" % np.mean(diff_losses)
230
+ with open(f"{log_dir}/train.log", "a") as f:
231
+ f.write(msg)
232
+
233
+ if epoch % params.save_every > 0:
234
+ continue
235
+
236
+ ckpt = model.state_dict()
237
+ torch.save(ckpt, f=f"{log_dir}/grad_{epoch}.pt")
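Like the original training scripts, this fine-tuning entry point takes no command-line arguments; the data paths, base checkpoint and hyperparameters all come from `finetune_params.py`. A minimal launch sketch (the GPU id is an assumption):

```bash
export CUDA_VISIBLE_DEVICES=0
python finetune_multi_speaker.py
```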
Grad-TTS/finetune_params.py ADDED
@@ -0,0 +1,56 @@
1
+ # Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
2
+ # This program is free software; you can redistribute it and/or modify
3
+ # it under the terms of the MIT License.
4
+ # This program is distributed in the hope that it will be useful,
5
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
6
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
7
+ # MIT License for more details.
8
+
9
+ from model.utils import fix_len_compatibility
10
+
11
+
12
+ # data parameters
13
+ train_filelist_path = "../../id/train.txt"
14
+ valid_filelist_path = "../../id/valid.txt"
15
+ # test_filelist_path = 'resources/filelists/ljspeech/test.txt'
16
+ cmudict_path = "resources/cmu_dictionary_id"
17
+ add_blank = True
18
+ n_feats = 80
19
+ n_spks = 247 # for Libri-TTS filelist and 1 for LJSpeech
20
+ spk_emb_dim = 64
21
+ n_feats = 80
22
+ n_fft = 1024
23
+ sample_rate = 22050
24
+ hop_length = 256
25
+ win_length = 1024
26
+ f_min = 0
27
+ f_max = 8000
28
+
29
+ # encoder parameters
30
+ n_enc_channels = 192
31
+ filter_channels = 768
32
+ filter_channels_dp = 256
33
+ n_enc_layers = 6
34
+ enc_kernel = 3
35
+ enc_dropout = 0.1
36
+ n_heads = 2
37
+ window_size = 4
38
+
39
+ # decoder parameters
40
+ dec_dim = 64
41
+ beta_min = 0.05
42
+ beta_max = 20.0
43
+ pe_scale = 1000 # 1 for `grad-tts-old.pt` checkpoint
44
+
45
+ # training parameters
46
+ log_dir = "logs/grad-tts-bookbot-ft-weildan"
47
+ test_size = 4
48
+ n_epochs = 24000
49
+ batch_size = 8
50
+ learning_rate = 1e-4
51
+ seed = 37
52
+ save_every = 1000
53
+ out_size = fix_len_compatibility(2 * 22050 // 256)
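+ # note: 2 * 22050 // 256 = 172 mel frames (~2 s of audio at sample_rate=22050, hop_length=256); fix_len_compatibility rounds this to a length compatible with the decoder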
54
+ num_workers = 6
55
+
56
+ checkpoint = "checkpts/grad-tts-libri-tts.pt"
Grad-TTS/hifi-gan/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Jungil Kong
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
Grad-TTS/hifi-gan/README.md ADDED
@@ -0,0 +1,105 @@
1
+ # HiFi-GAN: Generative Adversarial Networks for Efficient and High Fidelity Speech Synthesis
2
+
3
+ ### Jungil Kong, Jaehyeon Kim, Jaekyoung Bae
4
+
5
+ In our [paper](https://arxiv.org/abs/2010.05646),
6
+ we proposed HiFi-GAN: a GAN-based model capable of generating high fidelity speech efficiently.<br/>
7
+ We provide our implementation and pretrained models as open source in this repository.
8
+
9
+ **Abstract :**
10
+ Several recent work on speech synthesis have employed generative adversarial networks (GANs) to produce raw waveforms.
11
+ Although such methods improve the sampling efficiency and memory usage,
12
+ their sample quality has not yet reached that of autoregressive and flow-based generative models.
13
+ In this work, we propose HiFi-GAN, which achieves both efficient and high-fidelity speech synthesis.
14
+ As speech audio consists of sinusoidal signals with various periods,
15
+ we demonstrate that modeling periodic patterns of an audio is crucial for enhancing sample quality.
16
+ A subjective human evaluation (mean opinion score, MOS) of a single speaker dataset indicates that our proposed method
17
+ demonstrates similarity to human quality while generating 22.05 kHz high-fidelity audio 167.9 times faster than
18
+ real-time on a single V100 GPU. We further show the generality of HiFi-GAN to the mel-spectrogram inversion of unseen
19
+ speakers and end-to-end speech synthesis. Finally, a small footprint version of HiFi-GAN generates samples 13.4 times
20
+ faster than real-time on CPU with comparable quality to an autoregressive counterpart.
21
+
22
+ Visit our [demo website](https://jik876.github.io/hifi-gan-demo/) for audio samples.
23
+
24
+
25
+ ## Pre-requisites
26
+ 1. Python >= 3.6
27
+ 2. Clone this repository.
28
+ 3. Install Python requirements. Please refer to [requirements.txt](requirements.txt).
29
+ 4. Download and extract the [LJ Speech dataset](https://keithito.com/LJ-Speech-Dataset/).
30
+ And move all wav files to `LJSpeech-1.1/wavs`
31
+
32
+
33
+ ## Training
34
+ ```
35
+ python train.py --config config_v1.json
36
+ ```
37
+ To train V2 or V3 Generator, replace `config_v1.json` with `config_v2.json` or `config_v3.json`.<br>
38
+ Checkpoints and copy of the configuration file are saved in `cp_hifigan` directory by default.<br>
39
+ You can change the path by adding `--checkpoint_path` option.
40
+
41
+ Validation loss during training with V1 generator.<br>
42
+ ![validation loss](./validation_loss.png)
43
+
44
+ ## Pretrained Model
45
+ You can also use pretrained models we provide.<br/>
46
+ [Download pretrained models](https://drive.google.com/drive/folders/1-eEYTB5Av9jNql0WGBlRoi-WH2J7bp5Y?usp=sharing)<br/>
47
+ Details of each folder are as follows:
48
+
49
+ |Folder Name|Generator|Dataset|Fine-Tuned|
50
+ |------|---|---|---|
51
+ |LJ_V1|V1|LJSpeech|No|
52
+ |LJ_V2|V2|LJSpeech|No|
53
+ |LJ_V3|V3|LJSpeech|No|
54
+ |LJ_FT_T2_V1|V1|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
55
+ |LJ_FT_T2_V2|V2|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
56
+ |LJ_FT_T2_V3|V3|LJSpeech|Yes ([Tacotron2](https://github.com/NVIDIA/tacotron2))|
57
+ |VCTK_V1|V1|VCTK|No|
58
+ |VCTK_V2|V2|VCTK|No|
59
+ |VCTK_V3|V3|VCTK|No|
60
+ |UNIVERSAL_V1|V1|Universal|No|
61
+
62
+ We provide the universal model with discriminator weights that can be used as a base for transfer learning to other datasets.
63
+
64
+ ## Fine-Tuning
65
+ 1. Generate mel-spectrograms in numpy format using [Tacotron2](https://github.com/NVIDIA/tacotron2) with teacher-forcing.<br/>
66
+ The file name of the generated mel-spectrogram should match the audio file and the extension should be `.npy`.<br/>
67
+ Example:
68
+ ```
69
+ Audio File : LJ001-0001.wav
70
+ Mel-Spectrogram File : LJ001-0001.npy
71
+ ```
72
+ 2. Create `ft_dataset` folder and copy the generated mel-spectrogram files into it.<br/>
73
+ 3. Run the following command.
74
+ ```
75
+ python train.py --fine_tuning True --config config_v1.json
76
+ ```
77
+ For other command line options, please refer to the training section.
78
+
79
+
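+ The only strict requirement is that each saved mel keeps the base name of its wav file. A minimal sketch of the saving step (the `mel` array and the output directory are placeholders here; producing the teacher-forced mel from Tacotron2 is not shown):
+
+ ```python
+ import os
+ import numpy as np
+
+ def save_finetuning_mel(wav_path, mel, out_dir="ft_dataset"):
+     """Save a mel-spectrogram so fine-tuning can pair it with its wav file."""
+     os.makedirs(out_dir, exist_ok=True)
+     base = os.path.splitext(os.path.basename(wav_path))[0]  # "LJ001-0001.wav" -> "LJ001-0001"
+     np.save(os.path.join(out_dir, base + ".npy"), mel)      # -> "ft_dataset/LJ001-0001.npy"
+ ```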
80
+ ## Inference from wav file
81
+ 1. Make `test_files` directory and copy wav files into the directory.
82
+ 2. Run the following command.
83
+ ```
84
+ python inference.py --checkpoint_file [generator checkpoint file path]
85
+ ```
86
+ Generated wav files are saved in `generated_files` by default.<br>
87
+ You can change the path by adding `--output_dir` option.
88
+
89
+
90
+ ## Inference for end-to-end speech synthesis
91
+ 1. Make `test_mel_files` directory and copy generated mel-spectrogram files into the directory.<br>
92
+ You can generate mel-spectrograms using [Tacotron2](https://github.com/NVIDIA/tacotron2),
93
+ [Glow-TTS](https://github.com/jaywalnut310/glow-tts) and so forth.
94
+ 2. Run the following command.
95
+ ```
96
+ python inference_e2e.py --checkpoint_file [generator checkpoint file path]
97
+ ```
98
+ Generated wav files are saved in `generated_files_from_mel` by default.<br>
99
+ You can change the path by adding `--output_dir` option.
100
+
101
+
102
+ ## Acknowledgements
103
+ We referred to [WaveGlow](https://github.com/NVIDIA/waveglow), [MelGAN](https://github.com/descriptinc/melgan-neurips)
104
+ and [Tacotron2](https://github.com/NVIDIA/tacotron2) to implement this.
105
+
Grad-TTS/hifi-gan/env.py ADDED
@@ -0,0 +1,17 @@
1
+ """ from https://github.com/jik876/hifi-gan """
2
+
3
+ import os
4
+ import shutil
5
+
6
+
7
+ class AttrDict(dict):
8
+ def __init__(self, *args, **kwargs):
9
+ super(AttrDict, self).__init__(*args, **kwargs)
10
+ self.__dict__ = self
11
+
12
+
13
+ def build_env(config, config_name, path):
14
+ t_path = os.path.join(path, config_name)
15
+ if config != t_path:
16
+ os.makedirs(path, exist_ok=True)
17
+ shutil.copyfile(config, os.path.join(path, config_name))
Grad-TTS/hifi-gan/meldataset.py ADDED
@@ -0,0 +1,170 @@
1
+ """ from https://github.com/jik876/hifi-gan """
2
+
3
+ import math
4
+ import os
5
+ import random
6
+ import torch
7
+ import torch.utils.data
8
+ import numpy as np
9
+ from librosa.util import normalize
10
+ from scipy.io.wavfile import read
11
+ from librosa.filters import mel as librosa_mel_fn
12
+
13
+ MAX_WAV_VALUE = 32768.0
14
+
15
+
16
+ def load_wav(full_path):
17
+ sampling_rate, data = read(full_path)
18
+ return data, sampling_rate
19
+
20
+
21
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
22
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
23
+
24
+
25
+ def dynamic_range_decompression(x, C=1):
26
+ return np.exp(x) / C
27
+
28
+
29
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
30
+ return torch.log(torch.clamp(x, min=clip_val) * C)
31
+
32
+
33
+ def dynamic_range_decompression_torch(x, C=1):
34
+ return torch.exp(x) / C
35
+
36
+
37
+ def spectral_normalize_torch(magnitudes):
38
+ output = dynamic_range_compression_torch(magnitudes)
39
+ return output
40
+
41
+
42
+ def spectral_de_normalize_torch(magnitudes):
43
+ output = dynamic_range_decompression_torch(magnitudes)
44
+ return output
45
+
46
+
47
+ mel_basis = {}
48
+ hann_window = {}
49
+
50
+
51
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
52
+ if torch.min(y) < -1.:
53
+ print('min value is ', torch.min(y))
54
+ if torch.max(y) > 1.:
55
+ print('max value is ', torch.max(y))
56
+
57
+ global mel_basis, hann_window
58
+ if fmax not in mel_basis:
59
+ mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
60
+ mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
61
+ hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
62
+
63
+ y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')
64
+ y = y.squeeze(1)
65
+
66
+ spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
67
+ center=center, pad_mode='reflect', normalized=False, onesided=True)
68
+
69
+ spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
70
+
71
+ spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)
72
+ spec = spectral_normalize_torch(spec)
73
+
74
+ return spec
75
+
76
+
77
+ def get_dataset_filelist(a):
78
+ with open(a.input_training_file, 'r', encoding='utf-8') as fi:
79
+ training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
80
+ for x in fi.read().split('\n') if len(x) > 0]
81
+
82
+ with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
83
+ validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + '.wav')
84
+ for x in fi.read().split('\n') if len(x) > 0]
85
+ return training_files, validation_files
86
+
87
+
88
+ class MelDataset(torch.utils.data.Dataset):
89
+ def __init__(self, training_files, segment_size, n_fft, num_mels,
90
+ hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
91
+ device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
92
+ self.audio_files = training_files
93
+ random.seed(1234)
94
+ if shuffle:
95
+ random.shuffle(self.audio_files)
96
+ self.segment_size = segment_size
97
+ self.sampling_rate = sampling_rate
98
+ self.split = split
99
+ self.n_fft = n_fft
100
+ self.num_mels = num_mels
101
+ self.hop_size = hop_size
102
+ self.win_size = win_size
103
+ self.fmin = fmin
104
+ self.fmax = fmax
105
+ self.fmax_loss = fmax_loss
106
+ self.cached_wav = None
107
+ self.n_cache_reuse = n_cache_reuse
108
+ self._cache_ref_count = 0
109
+ self.device = device
110
+ self.fine_tuning = fine_tuning
111
+ self.base_mels_path = base_mels_path
112
+
113
+ def __getitem__(self, index):
114
+ filename = self.audio_files[index]
115
+ if self._cache_ref_count == 0:
116
+ audio, sampling_rate = load_wav(filename)
117
+ audio = audio / MAX_WAV_VALUE
118
+ if not self.fine_tuning:
119
+ audio = normalize(audio) * 0.95
120
+ self.cached_wav = audio
121
+ if sampling_rate != self.sampling_rate:
122
+ raise ValueError("{} SR doesn't match target {} SR".format(
123
+ sampling_rate, self.sampling_rate))
124
+ self._cache_ref_count = self.n_cache_reuse
125
+ else:
126
+ audio = self.cached_wav
127
+ self._cache_ref_count -= 1
128
+
129
+ audio = torch.FloatTensor(audio)
130
+ audio = audio.unsqueeze(0)
131
+
132
+ if not self.fine_tuning:
133
+ if self.split:
134
+ if audio.size(1) >= self.segment_size:
135
+ max_audio_start = audio.size(1) - self.segment_size
136
+ audio_start = random.randint(0, max_audio_start)
137
+ audio = audio[:, audio_start:audio_start+self.segment_size]
138
+ else:
139
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
140
+
141
+ mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
142
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
143
+ center=False)
144
+ else:
145
+ mel = np.load(
146
+ os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
147
+ mel = torch.from_numpy(mel)
148
+
149
+ if len(mel.shape) < 3:
150
+ mel = mel.unsqueeze(0)
151
+
152
+ if self.split:
153
+ frames_per_seg = math.ceil(self.segment_size / self.hop_size)
154
+
155
+ if audio.size(1) >= self.segment_size:
156
+ mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
157
+ mel = mel[:, :, mel_start:mel_start + frames_per_seg]
158
+ audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
159
+ else:
160
+ mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
161
+ audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
162
+
163
+ mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
164
+ self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
165
+ center=False)
166
+
167
+ return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
168
+
169
+ def __len__(self):
170
+ return len(self.audio_files)
Grad-TTS/hifi-gan/models.py ADDED
@@ -0,0 +1,285 @@
1
+ """ from https://github.com/jik876/hifi-gan """
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.nn as nn
6
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
7
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
8
+ from xutils import init_weights, get_padding
9
+
10
+ LRELU_SLOPE = 0.1
11
+
12
+
13
+ class ResBlock1(torch.nn.Module):
14
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
15
+ super(ResBlock1, self).__init__()
16
+ self.h = h
17
+ self.convs1 = nn.ModuleList([
18
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
19
+ padding=get_padding(kernel_size, dilation[0]))),
20
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
21
+ padding=get_padding(kernel_size, dilation[1]))),
22
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
23
+ padding=get_padding(kernel_size, dilation[2])))
24
+ ])
25
+ self.convs1.apply(init_weights)
26
+
27
+ self.convs2 = nn.ModuleList([
28
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
29
+ padding=get_padding(kernel_size, 1))),
30
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
31
+ padding=get_padding(kernel_size, 1))),
32
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
33
+ padding=get_padding(kernel_size, 1)))
34
+ ])
35
+ self.convs2.apply(init_weights)
36
+
37
+ def forward(self, x):
38
+ for c1, c2 in zip(self.convs1, self.convs2):
39
+ xt = F.leaky_relu(x, LRELU_SLOPE)
40
+ xt = c1(xt)
41
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
42
+ xt = c2(xt)
43
+ x = xt + x
44
+ return x
45
+
46
+ def remove_weight_norm(self):
47
+ for l in self.convs1:
48
+ remove_weight_norm(l)
49
+ for l in self.convs2:
50
+ remove_weight_norm(l)
51
+
52
+
53
+ class ResBlock2(torch.nn.Module):
54
+ def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
55
+ super(ResBlock2, self).__init__()
56
+ self.h = h
57
+ self.convs = nn.ModuleList([
58
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
59
+ padding=get_padding(kernel_size, dilation[0]))),
60
+ weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
61
+ padding=get_padding(kernel_size, dilation[1])))
62
+ ])
63
+ self.convs.apply(init_weights)
64
+
65
+ def forward(self, x):
66
+ for c in self.convs:
67
+ xt = F.leaky_relu(x, LRELU_SLOPE)
68
+ xt = c(xt)
69
+ x = xt + x
70
+ return x
71
+
72
+ def remove_weight_norm(self):
73
+ for l in self.convs:
74
+ remove_weight_norm(l)
75
+
76
+
77
+ class Generator(torch.nn.Module):
78
+ def __init__(self, h):
79
+ super(Generator, self).__init__()
80
+ self.h = h
81
+ self.num_kernels = len(h.resblock_kernel_sizes)
82
+ self.num_upsamples = len(h.upsample_rates)
83
+ self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
84
+ resblock = ResBlock1 if h.resblock == '1' else ResBlock2
85
+
86
+ self.ups = nn.ModuleList()
87
+ for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
88
+ self.ups.append(weight_norm(
89
+ ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
90
+ k, u, padding=(k-u)//2)))
91
+
92
+ self.resblocks = nn.ModuleList()
93
+ for i in range(len(self.ups)):
94
+ ch = h.upsample_initial_channel//(2**(i+1))
95
+ for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
96
+ self.resblocks.append(resblock(h, ch, k, d))
97
+
98
+ self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
99
+ self.ups.apply(init_weights)
100
+ self.conv_post.apply(init_weights)
101
+
102
+ def forward(self, x):
103
+ x = self.conv_pre(x)
104
+ for i in range(self.num_upsamples):
105
+ x = F.leaky_relu(x, LRELU_SLOPE)
106
+ x = self.ups[i](x)
107
+ xs = None
108
+ for j in range(self.num_kernels):
109
+ if xs is None:
110
+ xs = self.resblocks[i*self.num_kernels+j](x)
111
+ else:
112
+ xs += self.resblocks[i*self.num_kernels+j](x)
113
+ x = xs / self.num_kernels
114
+ x = F.leaky_relu(x)
115
+ x = self.conv_post(x)
116
+ x = torch.tanh(x)
117
+
118
+ return x
119
+
120
+ def remove_weight_norm(self):
121
+ print('Removing weight norm...')
122
+ for l in self.ups:
123
+ remove_weight_norm(l)
124
+ for l in self.resblocks:
125
+ l.remove_weight_norm()
126
+ remove_weight_norm(self.conv_pre)
127
+ remove_weight_norm(self.conv_post)
128
+
129
+
130
+ class DiscriminatorP(torch.nn.Module):
131
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
132
+ super(DiscriminatorP, self).__init__()
133
+ self.period = period
134
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
135
+ self.convs = nn.ModuleList([
136
+ norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
137
+ norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
138
+ norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
139
+ norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
140
+ norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
141
+ ])
142
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
143
+
144
+ def forward(self, x):
145
+ fmap = []
146
+
147
+ # 1d to 2d
148
+ b, c, t = x.shape
149
+ if t % self.period != 0: # pad first
150
+ n_pad = self.period - (t % self.period)
151
+ x = F.pad(x, (0, n_pad), "reflect")
152
+ t = t + n_pad
153
+ x = x.view(b, c, t // self.period, self.period)
154
+
155
+ for l in self.convs:
156
+ x = l(x)
157
+ x = F.leaky_relu(x, LRELU_SLOPE)
158
+ fmap.append(x)
159
+ x = self.conv_post(x)
160
+ fmap.append(x)
161
+ x = torch.flatten(x, 1, -1)
162
+
163
+ return x, fmap
164
+
165
+
166
+ class MultiPeriodDiscriminator(torch.nn.Module):
167
+ def __init__(self):
168
+ super(MultiPeriodDiscriminator, self).__init__()
169
+ self.discriminators = nn.ModuleList([
170
+ DiscriminatorP(2),
171
+ DiscriminatorP(3),
172
+ DiscriminatorP(5),
173
+ DiscriminatorP(7),
174
+ DiscriminatorP(11),
175
+ ])
176
+
177
+ def forward(self, y, y_hat):
178
+ y_d_rs = []
179
+ y_d_gs = []
180
+ fmap_rs = []
181
+ fmap_gs = []
182
+ for i, d in enumerate(self.discriminators):
183
+ y_d_r, fmap_r = d(y)
184
+ y_d_g, fmap_g = d(y_hat)
185
+ y_d_rs.append(y_d_r)
186
+ fmap_rs.append(fmap_r)
187
+ y_d_gs.append(y_d_g)
188
+ fmap_gs.append(fmap_g)
189
+
190
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
191
+
192
+
193
+ class DiscriminatorS(torch.nn.Module):
194
+ def __init__(self, use_spectral_norm=False):
195
+ super(DiscriminatorS, self).__init__()
196
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
197
+ self.convs = nn.ModuleList([
198
+ norm_f(Conv1d(1, 128, 15, 1, padding=7)),
199
+ norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
200
+ norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
201
+ norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
202
+ norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
203
+ norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
204
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
205
+ ])
206
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
207
+
208
+ def forward(self, x):
209
+ fmap = []
210
+ for l in self.convs:
211
+ x = l(x)
212
+ x = F.leaky_relu(x, LRELU_SLOPE)
213
+ fmap.append(x)
214
+ x = self.conv_post(x)
215
+ fmap.append(x)
216
+ x = torch.flatten(x, 1, -1)
217
+
218
+ return x, fmap
219
+
220
+
221
+ class MultiScaleDiscriminator(torch.nn.Module):
222
+ def __init__(self):
223
+ super(MultiScaleDiscriminator, self).__init__()
224
+ self.discriminators = nn.ModuleList([
225
+ DiscriminatorS(use_spectral_norm=True),
226
+ DiscriminatorS(),
227
+ DiscriminatorS(),
228
+ ])
229
+ self.meanpools = nn.ModuleList([
230
+ AvgPool1d(4, 2, padding=2),
231
+ AvgPool1d(4, 2, padding=2)
232
+ ])
233
+
234
+ def forward(self, y, y_hat):
235
+ y_d_rs = []
236
+ y_d_gs = []
237
+ fmap_rs = []
238
+ fmap_gs = []
239
+ for i, d in enumerate(self.discriminators):
240
+ if i != 0:
241
+ y = self.meanpools[i-1](y)
242
+ y_hat = self.meanpools[i-1](y_hat)
243
+ y_d_r, fmap_r = d(y)
244
+ y_d_g, fmap_g = d(y_hat)
245
+ y_d_rs.append(y_d_r)
246
+ fmap_rs.append(fmap_r)
247
+ y_d_gs.append(y_d_g)
248
+ fmap_gs.append(fmap_g)
249
+
250
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
251
+
252
+
253
+ def feature_loss(fmap_r, fmap_g):
254
+ loss = 0
255
+ for dr, dg in zip(fmap_r, fmap_g):
256
+ for rl, gl in zip(dr, dg):
257
+ loss += torch.mean(torch.abs(rl - gl))
258
+
259
+ return loss*2
260
+
261
+
262
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
263
+ loss = 0
264
+ r_losses = []
265
+ g_losses = []
266
+ for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
267
+ r_loss = torch.mean((1-dr)**2)
268
+ g_loss = torch.mean(dg**2)
269
+ loss += (r_loss + g_loss)
270
+ r_losses.append(r_loss.item())
271
+ g_losses.append(g_loss.item())
272
+
273
+ return loss, r_losses, g_losses
274
+
275
+
276
+ def generator_loss(disc_outputs):
277
+ loss = 0
278
+ gen_losses = []
279
+ for dg in disc_outputs:
280
+ l = torch.mean((1-dg)**2)
281
+ gen_losses.append(l)
282
+ loss += l
283
+
284
+ return loss, gen_losses
285
+
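For completeness, a minimal sketch of how this `Generator` is typically driven as a vocoder on a Grad-TTS mel-spectrogram, run from the `hifi-gan` directory. The checkpoint path and the `'generator'` key inside the saved state dict are assumptions based on the official HiFi-GAN training script:

```python
import json
import torch
from env import AttrDict
from models import Generator

with open("../checkpts/hifigan-config.json") as f:
    h = AttrDict(json.load(f))

vocoder = Generator(h)
state = torch.load("../checkpts/hifigan.pt", map_location="cpu")  # assumed checkpoint path
vocoder.load_state_dict(state["generator"])                       # assumed key in the checkpoint dict
vocoder.eval()
vocoder.remove_weight_norm()

mel = torch.randn(1, 80, 200)  # placeholder; in practice the (1, 80, T) decoder output of Grad-TTS
with torch.no_grad():
    audio = vocoder(mel).squeeze().clamp(-1.0, 1.0)  # waveform at 22.05 kHz
```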
Grad-TTS/hifi-gan/xutils.py ADDED
@@ -0,0 +1,60 @@
1
+ """ from https://github.com/jik876/hifi-gan """
2
+
3
+ import glob
4
+ import os
5
+ import matplotlib
6
+ import torch
7
+ from torch.nn.utils import weight_norm
8
+ matplotlib.use("Agg")
9
+ import matplotlib.pylab as plt
10
+
11
+
12
+ def plot_spectrogram(spectrogram):
13
+ fig, ax = plt.subplots(figsize=(10, 2))
14
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower",
15
+ interpolation='none')
16
+ plt.colorbar(im, ax=ax)
17
+
18
+ fig.canvas.draw()
19
+ plt.close()
20
+
21
+ return fig
22
+
23
+
24
+ def init_weights(m, mean=0.0, std=0.01):
25
+ classname = m.__class__.__name__
26
+ if classname.find("Conv") != -1:
27
+ m.weight.data.normal_(mean, std)
28
+
29
+
30
+ def apply_weight_norm(m):
31
+ classname = m.__class__.__name__
32
+ if classname.find("Conv") != -1:
33
+ weight_norm(m)
34
+
35
+
36
+ def get_padding(kernel_size, dilation=1):
37
+ return int((kernel_size*dilation - dilation)/2)
38
+
39
+
40
+ def load_checkpoint(filepath, device):
41
+ assert os.path.isfile(filepath)
42
+ print("Loading '{}'".format(filepath))
43
+ checkpoint_dict = torch.load(filepath, map_location=device)
44
+ print("Complete.")
45
+ return checkpoint_dict
46
+
47
+
48
+ def save_checkpoint(filepath, obj):
49
+ print("Saving checkpoint to {}".format(filepath))
50
+ torch.save(obj, filepath)
51
+ print("Complete.")
52
+
53
+
54
+ def scan_checkpoint(cp_dir, prefix):
55
+ pattern = os.path.join(cp_dir, prefix + '????????')
56
+ cp_list = glob.glob(pattern)
57
+ if len(cp_list) == 0:
58
+ return None
59
+ return sorted(cp_list)[-1]
60
+
Grad-TTS/inference.ipynb ADDED
@@ -0,0 +1,199 @@