Yeserumo committed on
Commit
c653355
1 Parent(s): 63e9d93
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. output/audio.wav +0 -0
  2. samples/1320_00000.mp3 +0 -0
  3. samples/3575_00000.mp3 +0 -0
  4. samples/6829_00000.mp3 +0 -0
  5. samples/8230_00000.mp3 +0 -0
  6. samples/README.md +22 -0
  7. samples/VCTK.txt +94 -0
  8. samples/p240_00000.mp3 +0 -0
  9. samples/p260_00000.mp3 +0 -0
  10. saved_models/default/encoder.pt +3 -0
  11. saved_models/default/synthesizer.pt +3 -0
  12. saved_models/default/vocoder.pt +3 -0
  13. saved_models/default/zh_synthesizer.pt +3 -0
  14. synthesizer/LICENSE.txt +24 -0
  15. synthesizer/__init__.py +1 -0
  16. synthesizer/__pycache__/__init__.cpython-37.pyc +0 -0
  17. synthesizer/__pycache__/audio.cpython-37.pyc +0 -0
  18. synthesizer/__pycache__/hparams.cpython-37.pyc +0 -0
  19. synthesizer/__pycache__/inference.cpython-37.pyc +0 -0
  20. synthesizer/audio.py +206 -0
  21. synthesizer/hparams.py +92 -0
  22. synthesizer/inference.py +165 -0
  23. synthesizer/models/__pycache__/tacotron.cpython-37.pyc +0 -0
  24. synthesizer/models/tacotron.py +519 -0
  25. synthesizer/preprocess.py +258 -0
  26. synthesizer/synthesize.py +92 -0
  27. synthesizer/synthesizer_dataset.py +92 -0
  28. synthesizer/train.py +258 -0
  29. synthesizer/utils/__init__.py +45 -0
  30. synthesizer/utils/__pycache__/__init__.cpython-37.pyc +0 -0
  31. synthesizer/utils/__pycache__/cleaners.cpython-37.pyc +0 -0
  32. synthesizer/utils/__pycache__/numbers.cpython-37.pyc +0 -0
  33. synthesizer/utils/__pycache__/symbols.cpython-37.pyc +0 -0
  34. synthesizer/utils/__pycache__/text.cpython-37.pyc +0 -0
  35. synthesizer/utils/_cmudict.py +62 -0
  36. synthesizer/utils/cleaners.py +88 -0
  37. synthesizer/utils/numbers.py +69 -0
  38. synthesizer/utils/plot.py +82 -0
  39. synthesizer/utils/symbols.py +21 -0
  40. synthesizer/utils/text.py +75 -0
  41. toolbox/__init__.py +347 -0
  42. toolbox/ui.py +607 -0
  43. toolbox/utterance.py +5 -0
  44. utils/__init__.py +0 -0
  45. utils/__pycache__/__init__.cpython-37.pyc +0 -0
  46. utils/__pycache__/argutils.cpython-37.pyc +0 -0
  47. utils/__pycache__/default_models.cpython-37.pyc +0 -0
  48. utils/argutils.py +40 -0
  49. utils/default_models.py +56 -0
  50. utils/logmmse.py +247 -0
output/audio.wav ADDED
Binary file (189 kB).
 
samples/1320_00000.mp3 ADDED
Binary file (15.5 kB).
 
samples/3575_00000.mp3 ADDED
Binary file (15.5 kB).
 
samples/6829_00000.mp3 ADDED
Binary file (15.6 kB).
 
samples/8230_00000.mp3 ADDED
Binary file (16.1 kB).
 
samples/README.md ADDED
@@ -0,0 +1,22 @@
+ The audio files in this folder are provided for toolbox testing and
+ benchmarking purposes. These are the same reference utterances
+ used by the SV2TTS authors to generate the audio samples located at:
+ https://google.github.io/tacotron/publications/speaker_adaptation/index.html
+
+ The `p240_00000.mp3` and `p260_00000.mp3` files are compressed
+ versions of audio files from the VCTK corpus available at:
+ https://datashare.is.ed.ac.uk/handle/10283/3443
+ VCTK.txt contains the copyright notices and licensing information.
+
+ The `1320_00000.mp3`, `3575_00000.mp3`, `6829_00000.mp3`
+ and `8230_00000.mp3` files are compressed versions of audio files
+ from the LibriSpeech dataset available at: https://openslr.org/12
+ For these files, the following notice applies:
+ ```
+ LibriSpeech (c) 2014 by Vassil Panayotov
+
+ LibriSpeech ASR corpus is licensed under a
+ Creative Commons Attribution 4.0 International License.
+
+ See <http://creativecommons.org/licenses/by/4.0/>.
+ ```
samples/VCTK.txt ADDED
@@ -0,0 +1,94 @@
1
+ ---------------------------------------------------------------------
2
+ CSTR VCTK Corpus
3
+ English Multi-speaker Corpus for CSTR Voice Cloning Toolkit
4
+
5
+ (Version 0.92)
6
+ RELEASE September 2019
7
+ The Centre for Speech Technology Research
8
+ University of Edinburgh
9
+ Copyright (c) 2019
10
+
11
+ Junichi Yamagishi
12
+ jyamagis@inf.ed.ac.uk
13
+ ---------------------------------------------------------------------
14
+
15
+ Overview
16
+
17
+ This CSTR VCTK Corpus includes speech data uttered by 110 English
18
+ speakers with various accents. Each speaker reads out about 400
19
+ sentences, which were selected from a newspaper, the rainbow passage
20
+ and an elicitation paragraph used for the speech accent archive.
21
+
22
+ The newspaper texts were taken from Herald Glasgow, with permission
23
+ from Herald & Times Group. Each speaker has a different set of the
24
+ newspaper texts selected based a greedy algorithm that increases the
25
+ contextual and phonetic coverage. The details of the text selection
26
+ algorithms are described in the following paper:
27
+
28
+ C. Veaux, J. Yamagishi and S. King,
29
+ "The voice bank corpus: Design, collection and data analysis of
30
+ a large regional accent speech database,"
31
+ https://doi.org/10.1109/ICSDA.2013.6709856
32
+
33
+ The rainbow passage and elicitation paragraph are the same for all
34
+ speakers. The rainbow passage can be found at International Dialects
35
+ of English Archive:
36
+ (http://web.ku.edu/~idea/readings/rainbow.htm). The elicitation
37
+ paragraph is identical to the one used for the speech accent archive
38
+ (http://accent.gmu.edu). The details of the the speech accent archive
39
+ can be found at
40
+ http://www.ualberta.ca/~aacl2009/PDFs/WeinbergerKunath2009AACL.pdf
41
+
42
+ All speech data was recorded using an identical recording setup: an
43
+ omni-directional microphone (DPA 4035) and a small diaphragm condenser
44
+ microphone with very wide bandwidth (Sennheiser MKH 800), 96kHz
45
+ sampling frequency at 24 bits and in a hemi-anechoic chamber of
46
+ the University of Edinburgh. (However, two speakers, p280 and p315
47
+ had technical issues of the audio recordings using MKH 800).
48
+ All recordings were converted into 16 bits, were downsampled to
49
+ 48 kHz, and were manually end-pointed.
50
+
51
+ This corpus was originally aimed for HMM-based text-to-speech synthesis
52
+ systems, especially for speaker-adaptive HMM-based speech synthesis
53
+ that uses average voice models trained on multiple speakers and speaker
54
+ adaptation technologies. This corpus is also suitable for DNN-based
55
+ multi-speaker text-to-speech synthesis systems and waveform modeling.
56
+
57
+ COPYING
58
+
59
+ This corpus is licensed under the Creative Commons License: Attribution 4.0 International
60
+ http://creativecommons.org/licenses/by/4.0/legalcode
61
+
62
+ VCTK VARIANTS
63
+ There are several variants of the VCTK corpus:
64
+ Speech enhancement
65
+ - Noisy speech database for training speech enhancement algorithms and TTS models where we added various types of noises to VCTK artificially: http://dx.doi.org/10.7488/ds/2117
66
+ - Reverberant speech database for training speech dereverberation algorithms and TTS models where we added various types of reverberantion to VCTK artificially http://dx.doi.org/10.7488/ds/1425
67
+ - Noisy reverberant speech database for training speech enhancement algorithms and TTS models http://dx.doi.org/10.7488/ds/2139
68
+ - Device Recorded VCTK where speech signals of the VCTK corpus were played back and re-recorded in office environments using relatively inexpensive consumer devices http://dx.doi.org/10.7488/ds/2316
69
+ - The Microsoft Scalable Noisy Speech Dataset (MS-SNSD) https://github.com/microsoft/MS-SNSD
70
+
71
+ ASV and anti-spoofing
72
+ - Spoofing and Anti-Spoofing (SAS) corpus, which is a collection of synthetic speech signals produced by nine techniques, two of which are speech synthesis, and seven are voice conversion. All of them were built using the VCTK corpus. http://dx.doi.org/10.7488/ds/252
73
+ - Automatic Speaker Verification Spoofing and Countermeasures Challenge (ASVspoof 2015) Database. This database consists of synthetic speech signals produced by ten techniques and this has been used in the first Automatic Speaker Verification Spoofing and Countermeasures Challenge (ASVspoof 2015) http://dx.doi.org/10.7488/ds/298
74
+ - ASVspoof 2019: The 3rd Automatic Speaker Verification Spoofing and Countermeasures Challenge database. This database has been used in the 3rd Automatic Speaker Verification Spoofing and Countermeasures Challenge (ASVspoof 2019) https://doi.org/10.7488/ds/2555
75
+
76
+
77
+ ACKNOWLEDGEMENTS
78
+
79
+ The CSTR VCTK Corpus was constructed by:
80
+
81
+ Christophe Veaux (University of Edinburgh)
82
+ Junichi Yamagishi (University of Edinburgh)
83
+ Kirsten MacDonald
84
+
85
+ The research leading to these results was partly funded from EPSRC
86
+ grants EP/I031022/1 (NST) and EP/J002526/1 (CAF), from the RSE-NSFC
87
+ grant (61111130120), and from the JST CREST (uDialogue).
88
+
89
+ Please cite this corpus as follows:
90
+ Christophe Veaux, Junichi Yamagishi, Kirsten MacDonald,
91
+ "CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit",
92
+ The Centre for Speech Technology Research (CSTR),
93
+ University of Edinburgh
94
+
samples/p240_00000.mp3 ADDED
Binary file (20.2 kB).
 
samples/p260_00000.mp3 ADDED
Binary file (20.5 kB).
 
saved_models/default/encoder.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:39373b86598fa3da9fcddee6142382efe09777e8d37dc9c0561f41f0070f134e
+ size 17090379
saved_models/default/synthesizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c05e07428f95d0ed8755e1ef54cc8ae251300413d94ce5867a56afe39c499d94
+ size 370554559
saved_models/default/vocoder.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1d7a6861589e927e0fbdaa5849ca022258fe2b58a20cc7bfb8fb598ccf936169
+ size 53845290
saved_models/default/zh_synthesizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:27de1bfd98fe7f99f99399c0349b35e213673d8181412deb914bc5593460dfb2
+ size 370667477
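
The four `saved_models/default/*.pt` entries above are Git LFS pointer files (a `version` line, an `oid sha256:` line and a `size` in bytes), not the checkpoints themselves. Below is a minimal sketch, not part of the commit, that checks whether the real weights have been fetched; it relies only on the pointer format shown above and the paths added in this commit:

```python
from pathlib import Path

LFS_MAGIC = b"version https://git-lfs.github.com/spec/v1"

def is_lfs_pointer(fpath: Path, probe_bytes: int = 256) -> bool:
    # A pointer file is a tiny text file that starts with the LFS spec URL.
    with open(fpath, "rb") as f:
        return f.read(probe_bytes).startswith(LFS_MAGIC)

for name in ("encoder.pt", "synthesizer.pt", "vocoder.pt", "zh_synthesizer.pt"):
    fpath = Path("saved_models/default") / name
    if not fpath.exists():
        print(f"{name}: missing")
    elif is_lfs_pointer(fpath):
        print(f"{name}: still an LFS pointer -- fetch the weights with `git lfs pull`")
    else:
        print(f"{name}: checkpoint present ({fpath.stat().st_size / 1e6:.1f} MB)")
```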
synthesizer/LICENSE.txt ADDED
@@ -0,0 +1,24 @@
+ MIT License
+
+ Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
+ Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
+ Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish)
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
synthesizer/__init__.py ADDED
@@ -0,0 +1 @@
+ #
synthesizer/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (163 Bytes).
 
synthesizer/__pycache__/audio.cpython-37.pyc ADDED
Binary file (6.74 kB).
 
synthesizer/__pycache__/hparams.cpython-37.pyc ADDED
Binary file (2.77 kB).
 
synthesizer/__pycache__/inference.cpython-37.pyc ADDED
Binary file (6.3 kB).
 
synthesizer/audio.py ADDED
@@ -0,0 +1,206 @@
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ from scipy import signal
5
+ from scipy.io import wavfile
6
+ import soundfile as sf
7
+
8
+
9
+ def load_wav(path, sr):
10
+ return librosa.core.load(path, sr=sr)[0]
11
+
12
+ def save_wav(wav, path, sr):
13
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
14
+ #proposed by @dsmiller
15
+ wavfile.write(path, sr, wav.astype(np.int16))
16
+
17
+ def save_wavenet_wav(wav, path, sr):
18
+ sf.write(path, wav.astype(np.float32), sr)
19
+
20
+ def preemphasis(wav, k, preemphasize=True):
21
+ if preemphasize:
22
+ return signal.lfilter([1, -k], [1], wav)
23
+ return wav
24
+
25
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
26
+ if inv_preemphasize:
27
+ return signal.lfilter([1], [1, -k], wav)
28
+ return wav
29
+
30
+ #From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
31
+ def start_and_end_indices(quantized, silence_threshold=2):
32
+ for start in range(quantized.size):
33
+ if abs(quantized[start] - 127) > silence_threshold:
34
+ break
35
+ for end in range(quantized.size - 1, 1, -1):
36
+ if abs(quantized[end] - 127) > silence_threshold:
37
+ break
38
+
39
+ assert abs(quantized[start] - 127) > silence_threshold
40
+ assert abs(quantized[end] - 127) > silence_threshold
41
+
42
+ return start, end
43
+
44
+ def get_hop_size(hparams):
45
+ hop_size = hparams.hop_size
46
+ if hop_size is None:
47
+ assert hparams.frame_shift_ms is not None
48
+ hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
49
+ return hop_size
50
+
51
+ def linearspectrogram(wav, hparams):
52
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
53
+ S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
54
+
55
+ if hparams.signal_normalization:
56
+ return _normalize(S, hparams)
57
+ return S
58
+
59
+ def melspectrogram(wav, hparams):
60
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
61
+ S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
62
+
63
+ if hparams.signal_normalization:
64
+ return _normalize(S, hparams)
65
+ return S
66
+
67
+ def inv_linear_spectrogram(linear_spectrogram, hparams):
68
+ """Converts linear spectrogram to waveform using librosa"""
69
+ if hparams.signal_normalization:
70
+ D = _denormalize(linear_spectrogram, hparams)
71
+ else:
72
+ D = linear_spectrogram
73
+
74
+ S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
75
+
76
+ if hparams.use_lws:
77
+ processor = _lws_processor(hparams)
78
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
79
+ y = processor.istft(D).astype(np.float32)
80
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
81
+ else:
82
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
83
+
84
+ def inv_mel_spectrogram(mel_spectrogram, hparams):
85
+ """Converts mel spectrogram to waveform using librosa"""
86
+ if hparams.signal_normalization:
87
+ D = _denormalize(mel_spectrogram, hparams)
88
+ else:
89
+ D = mel_spectrogram
90
+
91
+ S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear
92
+
93
+ if hparams.use_lws:
94
+ processor = _lws_processor(hparams)
95
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
96
+ y = processor.istft(D).astype(np.float32)
97
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
98
+ else:
99
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
100
+
101
+ def _lws_processor(hparams):
102
+ import lws
103
+ return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
104
+
105
+ def _griffin_lim(S, hparams):
106
+ """librosa implementation of Griffin-Lim
107
+ Based on https://github.com/librosa/librosa/issues/434
108
+ """
109
+ angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
110
+ S_complex = np.abs(S).astype(np.complex)
111
+ y = _istft(S_complex * angles, hparams)
112
+ for i in range(hparams.griffin_lim_iters):
113
+ angles = np.exp(1j * np.angle(_stft(y, hparams)))
114
+ y = _istft(S_complex * angles, hparams)
115
+ return y
116
+
117
+ def _stft(y, hparams):
118
+ if hparams.use_lws:
119
+ return _lws_processor(hparams).stft(y).T
120
+ else:
121
+ return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
122
+
123
+ def _istft(y, hparams):
124
+ return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
125
+
126
+ ##########################################################
127
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
128
+ def num_frames(length, fsize, fshift):
129
+ """Compute number of time frames of spectrogram
130
+ """
131
+ pad = (fsize - fshift)
132
+ if length % fshift == 0:
133
+ M = (length + pad * 2 - fsize) // fshift + 1
134
+ else:
135
+ M = (length + pad * 2 - fsize) // fshift + 2
136
+ return M
137
+
138
+
139
+ def pad_lr(x, fsize, fshift):
140
+ """Compute left and right padding
141
+ """
142
+ M = num_frames(len(x), fsize, fshift)
143
+ pad = (fsize - fshift)
144
+ T = len(x) + 2 * pad
145
+ r = (M - 1) * fshift + fsize - T
146
+ return pad, pad + r
147
+ ##########################################################
148
+ #Librosa correct padding
149
+ def librosa_pad_lr(x, fsize, fshift):
150
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
151
+
152
+ # Conversions
153
+ _mel_basis = None
154
+ _inv_mel_basis = None
155
+
156
+ def _linear_to_mel(spectogram, hparams):
157
+ global _mel_basis
158
+ if _mel_basis is None:
159
+ _mel_basis = _build_mel_basis(hparams)
160
+ return np.dot(_mel_basis, spectogram)
161
+
162
+ def _mel_to_linear(mel_spectrogram, hparams):
163
+ global _inv_mel_basis
164
+ if _inv_mel_basis is None:
165
+ _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
166
+ return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
167
+
168
+ def _build_mel_basis(hparams):
169
+ assert hparams.fmax <= hparams.sample_rate // 2
170
+ return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
171
+ fmin=hparams.fmin, fmax=hparams.fmax)
172
+
173
+ def _amp_to_db(x, hparams):
174
+ min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
175
+ return 20 * np.log10(np.maximum(min_level, x))
176
+
177
+ def _db_to_amp(x):
178
+ return np.power(10.0, (x) * 0.05)
179
+
180
+ def _normalize(S, hparams):
181
+ if hparams.allow_clipping_in_normalization:
182
+ if hparams.symmetric_mels:
183
+ return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
184
+ -hparams.max_abs_value, hparams.max_abs_value)
185
+ else:
186
+ return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
187
+
188
+ assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
189
+ if hparams.symmetric_mels:
190
+ return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
191
+ else:
192
+ return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
193
+
194
+ def _denormalize(D, hparams):
195
+ if hparams.allow_clipping_in_normalization:
196
+ if hparams.symmetric_mels:
197
+ return (((np.clip(D, -hparams.max_abs_value,
198
+ hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
199
+ + hparams.min_level_db)
200
+ else:
201
+ return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
202
+
203
+ if hparams.symmetric_mels:
204
+ return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
205
+ else:
206
+ return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
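
A minimal round-trip sketch, not part of the commit, for the `synthesizer/audio.py` module above: it computes a mel spectrogram with `melspectrogram` and inverts it with `inv_mel_spectrogram` (Griffin-Lim, since `use_lws` is False in the default hparams added below). The input path is illustrative; any 16 kHz mono WAV would do.

```python
from synthesizer import audio
from synthesizer.hparams import hparams

# Load a waveform at the synthesizer's sample rate (path is illustrative).
wav = audio.load_wav("some_utterance.wav", sr=hparams.sample_rate)

# (num_mels, T) spectrogram, normalized to [-max_abs_value, max_abs_value].
mel = audio.melspectrogram(wav, hparams)
print("mel shape:", mel.shape)

# Approximate inverse via Griffin-Lim (griffin_lim_iters iterations).
wav_hat = audio.inv_mel_spectrogram(mel, hparams)
audio.save_wav(wav_hat, "griffin_lim_reconstruction.wav", sr=hparams.sample_rate)
```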
synthesizer/hparams.py ADDED
@@ -0,0 +1,92 @@
+ import ast
+ import pprint
+
+ class HParams(object):
+     def __init__(self, **kwargs): self.__dict__.update(kwargs)
+     def __setitem__(self, key, value): setattr(self, key, value)
+     def __getitem__(self, key): return getattr(self, key)
+     def __repr__(self): return pprint.pformat(self.__dict__)
+
+     def parse(self, string):
+         # Overrides hparams from a comma-separated string of name=value pairs
+         if len(string) > 0:
+             overrides = [s.split("=") for s in string.split(",")]
+             keys, values = zip(*overrides)
+             keys = list(map(str.strip, keys))
+             values = list(map(str.strip, values))
+             for k in keys:
+                 self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
+         return self
+
+ hparams = HParams(
+     ### Signal Processing (used in both synthesizer and vocoder)
+     sample_rate = 16000,
+     n_fft = 800,
+     num_mels = 80,
+     hop_size = 200,                         # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
+     win_size = 800,                         # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
+     fmin = 55,
+     min_level_db = -100,
+     ref_level_db = 20,
+     max_abs_value = 4.,                     # Gradient explodes if too big, premature convergence if too small.
+     preemphasis = 0.97,                     # Filter coefficient to use if preemphasize is True
+     preemphasize = True,
+
+     ### Tacotron Text-to-Speech (TTS)
+     tts_embed_dims = 512,                   # Embedding dimension for the graphemes/phoneme inputs
+     tts_encoder_dims = 256,
+     tts_decoder_dims = 128,
+     tts_postnet_dims = 512,
+     tts_encoder_K = 5,
+     tts_lstm_dims = 1024,
+     tts_postnet_K = 5,
+     tts_num_highways = 4,
+     tts_dropout = 0.5,
+     tts_cleaner_names = ["english_cleaners"],
+     tts_stop_threshold = -3.4,              # Value below which audio generation ends.
+                                             # For example, for a range of [-4, 4], this
+                                             # will terminate the sequence at the first
+                                             # frame that has all values < -3.4
+
+     ### Tacotron Training
+     tts_schedule = [(2, 1e-3, 20_000, 12),  # Progressive training schedule
+                     (2, 5e-4, 40_000, 12),  # (r, lr, step, batch_size)
+                     (2, 2e-4, 80_000, 12),  #
+                     (2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
+                     (2, 3e-5, 320_000, 12), #     synthesized for each decoder iteration)
+                     (2, 1e-5, 640_000, 12)],# lr = learning rate
+
+     tts_clip_grad_norm = 1.0,               # clips the gradient norm to prevent explosion - set to None if not needed
+     tts_eval_interval = 500,                # Number of steps between model evaluation (sample generation)
+                                             # Set to -1 to generate after completing epoch, or 0 to disable
+
+     tts_eval_num_samples = 1,               # Makes this number of samples
+
+     ### Data Preprocessing
+     max_mel_frames = 900,
+     rescale = True,
+     rescaling_max = 0.9,
+     synthesis_batch_size = 16,              # For vocoder preprocessing and inference.
+
+     ### Mel Visualization and Griffin-Lim
+     signal_normalization = True,
+     power = 1.5,
+     griffin_lim_iters = 60,
+
+     ### Audio processing options
+     fmax = 7600,                            # Should not exceed (sample_rate // 2)
+     allow_clipping_in_normalization = True, # Used when signal_normalization = True
+     clip_mels_length = True,                # If true, discards samples exceeding max_mel_frames
+     use_lws = False,                        # "Fast spectrogram phase recovery using local weighted sums"
+     symmetric_mels = True,                  # Sets mel range to [-max_abs_value, max_abs_value] if True,
+                                             #               and [0, max_abs_value] if False
+     trim_silence = True,                    # Use with sample_rate of 16000 for best results
+
+     ### SV2TTS
+     speaker_embedding_size = 256,           # Dimension for the speaker embedding
+     silence_min_duration_split = 0.4,       # Duration in seconds of a silence for an utterance to be split
+     utterance_min_duration = 1.6,           # Duration in seconds below which utterances are discarded
+     )
+
+ def hparams_debug_string():
+     return str(hparams)
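
A short usage sketch, not part of the commit, for the `HParams` container above. `parse` splits a comma-separated `name=value` string and runs each value through `ast.literal_eval`, so ints, floats and booleans can be overridden from a single command-line argument:

```python
from synthesizer.hparams import hparams, hparams_debug_string

# Override two fields from a single "name=value,name=value" string.
hparams.parse("tts_eval_num_samples=3,signal_normalization=False")

print(hparams.tts_eval_num_samples)      # 3 (int via ast.literal_eval)
print(hparams["signal_normalization"])   # False, through __getitem__
print(hparams_debug_string())            # pretty-printed dict of all hparams
```

Note that because `parse` splits on every comma, list-valued settings such as `tts_cleaner_names` cannot be overridden this way.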
synthesizer/inference.py ADDED
@@ -0,0 +1,165 @@
1
+ import torch
2
+ from synthesizer import audio
3
+ from synthesizer.hparams import hparams
4
+ from synthesizer.models.tacotron import Tacotron
5
+ from synthesizer.utils.symbols import symbols
6
+ from synthesizer.utils.text import text_to_sequence
7
+ from vocoder.display import simple_table
8
+ from pathlib import Path
9
+ from typing import Union, List
10
+ import numpy as np
11
+ import librosa
12
+
13
+
14
+ class Synthesizer:
15
+ sample_rate = hparams.sample_rate
16
+ hparams = hparams
17
+
18
+ def __init__(self, model_fpath: Path, verbose=True):
19
+ """
20
+ The model isn't instantiated and loaded in memory until needed or until load() is called.
21
+
22
+ :param model_fpath: path to the trained model file
23
+ :param verbose: if False, prints less information when using the model
24
+ """
25
+ self.model_fpath = model_fpath
26
+ self.verbose = verbose
27
+
28
+ # Check for GPU
29
+ if torch.cuda.is_available():
30
+ self.device = torch.device("cuda")
31
+ else:
32
+ self.device = torch.device("cpu")
33
+ if self.verbose:
34
+ print("Synthesizer using device:", self.device)
35
+
36
+ # Tacotron model will be instantiated later on first use.
37
+ self._model = None
38
+
39
+ def is_loaded(self):
40
+ """
41
+ Whether the model is loaded in memory.
42
+ """
43
+ return self._model is not None
44
+
45
+ def load(self):
46
+ """
47
+ Instantiates and loads the model given the weights file that was passed in the constructor.
48
+ """
49
+ self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
50
+ num_chars=len(symbols),
51
+ encoder_dims=hparams.tts_encoder_dims,
52
+ decoder_dims=hparams.tts_decoder_dims,
53
+ n_mels=hparams.num_mels,
54
+ fft_bins=hparams.num_mels,
55
+ postnet_dims=hparams.tts_postnet_dims,
56
+ encoder_K=hparams.tts_encoder_K,
57
+ lstm_dims=hparams.tts_lstm_dims,
58
+ postnet_K=hparams.tts_postnet_K,
59
+ num_highways=hparams.tts_num_highways,
60
+ dropout=hparams.tts_dropout,
61
+ stop_threshold=hparams.tts_stop_threshold,
62
+ speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
63
+
64
+ self._model.load(self.model_fpath)
65
+ self._model.eval()
66
+
67
+ if self.verbose:
68
+ print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
69
+
70
+ def synthesize_spectrograms(self, texts: List[str],
71
+ embeddings: Union[np.ndarray, List[np.ndarray]],
72
+ return_alignments=False):
73
+ """
74
+ Synthesizes mel spectrograms from texts and speaker embeddings.
75
+
76
+ :param texts: a list of N text prompts to be synthesized
77
+ :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
78
+ :param return_alignments: if True, a matrix representing the alignments between the
79
+ characters
80
+ and each decoder output step will be returned for each spectrogram
81
+ :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
82
+ sequence length of spectrogram i, and possibly the alignments.
83
+ """
84
+ # Load the model on the first request.
85
+ if not self.is_loaded():
86
+ self.load()
87
+
88
+ # Preprocess text inputs
89
+ inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
90
+ if not isinstance(embeddings, list):
91
+ embeddings = [embeddings]
92
+
93
+ # Batch inputs
94
+ batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
95
+ for i in range(0, len(inputs), hparams.synthesis_batch_size)]
96
+ batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
97
+ for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
98
+
99
+ specs = []
100
+ for i, batch in enumerate(batched_inputs, 1):
101
+ if self.verbose:
102
+ print(f"\n| Generating {i}/{len(batched_inputs)}")
103
+
104
+ # Pad texts so they are all the same length
105
+ text_lens = [len(text) for text in batch]
106
+ max_text_len = max(text_lens)
107
+ chars = [pad1d(text, max_text_len) for text in batch]
108
+ chars = np.stack(chars)
109
+
110
+ # Stack speaker embeddings into 2D array for batch processing
111
+ speaker_embeds = np.stack(batched_embeds[i-1])
112
+
113
+ # Convert to tensor
114
+ chars = torch.tensor(chars).long().to(self.device)
115
+ speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
116
+
117
+ # Inference
118
+ _, mels, alignments = self._model.generate(chars, speaker_embeddings)
119
+ mels = mels.detach().cpu().numpy()
120
+ for m in mels:
121
+ # Trim silence from end of each spectrogram
122
+ while np.max(m[:, -1]) < hparams.tts_stop_threshold:
123
+ m = m[:, :-1]
124
+ specs.append(m)
125
+
126
+ if self.verbose:
127
+ print("\n\nDone.\n")
128
+ return (specs, alignments) if return_alignments else specs
129
+
130
+ @staticmethod
131
+ def load_preprocess_wav(fpath):
132
+ """
133
+ Loads and preprocesses an audio file under the same conditions the audio files were used to
134
+ train the synthesizer.
135
+ """
136
+ wav = librosa.load(str(fpath), hparams.sample_rate)[0]
137
+ if hparams.rescale:
138
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
139
+ return wav
140
+
141
+ @staticmethod
142
+ def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
143
+ """
144
+ Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
145
+ were fed to the synthesizer when training.
146
+ """
147
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
148
+ wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
149
+ else:
150
+ wav = fpath_or_wav
151
+
152
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
153
+ return mel_spectrogram
154
+
155
+ @staticmethod
156
+ def griffin_lim(mel):
157
+ """
158
+ Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
159
+ with the same parameters present in hparams.py.
160
+ """
161
+ return audio.inv_mel_spectrogram(mel, hparams)
162
+
163
+
164
+ def pad1d(x, max_len, pad_value=0):
165
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
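
A minimal end-to-end sketch, not part of the commit, for the `Synthesizer` class above. It uses the checkpoint added under `saved_models/default/` and a random unit-norm placeholder for the 256-dimensional speaker embedding (`speaker_embedding_size` in hparams.py); in normal use the embedding comes from the encoder.

```python
import numpy as np
from pathlib import Path
from synthesizer.inference import Synthesizer

synth = Synthesizer(Path("saved_models/default/synthesizer.pt"))

texts = ["Welcome to the toolbox."]
embed = np.random.rand(256).astype(np.float32)
embed /= np.linalg.norm(embed)  # placeholder speaker embedding, shape (256,)

# Returns a list of (num_mels, T) spectrograms; the model is lazily loaded here.
specs = synth.synthesize_spectrograms(texts, [embed])

# Rough waveform via Griffin-Lim (a neural vocoder would normally be used instead).
wav = Synthesizer.griffin_lim(specs[0])
print(specs[0].shape, wav.shape)
```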
synthesizer/models/__pycache__/tacotron.cpython-37.pyc ADDED
Binary file (14.2 kB).
 
synthesizer/models/tacotron.py ADDED
@@ -0,0 +1,519 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from pathlib import Path
7
+ from typing import Union
8
+
9
+
10
+ class HighwayNetwork(nn.Module):
11
+ def __init__(self, size):
12
+ super().__init__()
13
+ self.W1 = nn.Linear(size, size)
14
+ self.W2 = nn.Linear(size, size)
15
+ self.W1.bias.data.fill_(0.)
16
+
17
+ def forward(self, x):
18
+ x1 = self.W1(x)
19
+ x2 = self.W2(x)
20
+ g = torch.sigmoid(x2)
21
+ y = g * F.relu(x1) + (1. - g) * x
22
+ return y
23
+
24
+
25
+ class Encoder(nn.Module):
26
+ def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
27
+ super().__init__()
28
+ prenet_dims = (encoder_dims, encoder_dims)
29
+ cbhg_channels = encoder_dims
30
+ self.embedding = nn.Embedding(num_chars, embed_dims)
31
+ self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
32
+ dropout=dropout)
33
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
34
+ proj_channels=[cbhg_channels, cbhg_channels],
35
+ num_highways=num_highways)
36
+
37
+ def forward(self, x, speaker_embedding=None):
38
+ x = self.embedding(x)
39
+ x = self.pre_net(x)
40
+ x.transpose_(1, 2)
41
+ x = self.cbhg(x)
42
+ if speaker_embedding is not None:
43
+ x = self.add_speaker_embedding(x, speaker_embedding)
44
+ return x
45
+
46
+ def add_speaker_embedding(self, x, speaker_embedding):
47
+ # SV2TTS
48
+ # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
49
+ # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
50
+ # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
51
+ # This concats the speaker embedding for each char in the encoder output
52
+
53
+ # Save the dimensions as human-readable names
54
+ batch_size = x.size()[0]
55
+ num_chars = x.size()[1]
56
+
57
+ if speaker_embedding.dim() == 1:
58
+ idx = 0
59
+ else:
60
+ idx = 1
61
+
62
+ # Start by making a copy of each speaker embedding to match the input text length
63
+ # The output of this has size (batch_size, num_chars * tts_embed_dims)
64
+ speaker_embedding_size = speaker_embedding.size()[idx]
65
+ e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
66
+
67
+ # Reshape it and transpose
68
+ e = e.reshape(batch_size, speaker_embedding_size, num_chars)
69
+ e = e.transpose(1, 2)
70
+
71
+ # Concatenate the tiled speaker embedding with the encoder output
72
+ x = torch.cat((x, e), 2)
73
+ return x
74
+
75
+
76
+ class BatchNormConv(nn.Module):
77
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
78
+ super().__init__()
79
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
80
+ self.bnorm = nn.BatchNorm1d(out_channels)
81
+ self.relu = relu
82
+
83
+ def forward(self, x):
84
+ x = self.conv(x)
85
+ x = F.relu(x) if self.relu is True else x
86
+ return self.bnorm(x)
87
+
88
+
89
+ class CBHG(nn.Module):
90
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
91
+ super().__init__()
92
+
93
+ # List of all rnns to call `flatten_parameters()` on
94
+ self._to_flatten = []
95
+
96
+ self.bank_kernels = [i for i in range(1, K + 1)]
97
+ self.conv1d_bank = nn.ModuleList()
98
+ for k in self.bank_kernels:
99
+ conv = BatchNormConv(in_channels, channels, k)
100
+ self.conv1d_bank.append(conv)
101
+
102
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
103
+
104
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
105
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
106
+
107
+ # Fix the highway input if necessary
108
+ if proj_channels[-1] != channels:
109
+ self.highway_mismatch = True
110
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
111
+ else:
112
+ self.highway_mismatch = False
113
+
114
+ self.highways = nn.ModuleList()
115
+ for i in range(num_highways):
116
+ hn = HighwayNetwork(channels)
117
+ self.highways.append(hn)
118
+
119
+ self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
120
+ self._to_flatten.append(self.rnn)
121
+
122
+ # Avoid fragmentation of RNN parameters and associated warning
123
+ self._flatten_parameters()
124
+
125
+ def forward(self, x):
126
+ # Although we `_flatten_parameters()` on init, when using DataParallel
127
+ # the model gets replicated, making it no longer guaranteed that the
128
+ # weights are contiguous in GPU memory. Hence, we must call it again
129
+ self._flatten_parameters()
130
+
131
+ # Save these for later
132
+ residual = x
133
+ seq_len = x.size(-1)
134
+ conv_bank = []
135
+
136
+ # Convolution Bank
137
+ for conv in self.conv1d_bank:
138
+ c = conv(x) # Convolution
139
+ conv_bank.append(c[:, :, :seq_len])
140
+
141
+ # Stack along the channel axis
142
+ conv_bank = torch.cat(conv_bank, dim=1)
143
+
144
+ # dump the last padding to fit residual
145
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
146
+
147
+ # Conv1d projections
148
+ x = self.conv_project1(x)
149
+ x = self.conv_project2(x)
150
+
151
+ # Residual Connect
152
+ x = x + residual
153
+
154
+ # Through the highways
155
+ x = x.transpose(1, 2)
156
+ if self.highway_mismatch is True:
157
+ x = self.pre_highway(x)
158
+ for h in self.highways: x = h(x)
159
+
160
+ # And then the RNN
161
+ x, _ = self.rnn(x)
162
+ return x
163
+
164
+ def _flatten_parameters(self):
165
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
166
+ to improve efficiency and avoid PyTorch yelling at us."""
167
+ [m.flatten_parameters() for m in self._to_flatten]
168
+
169
+ class PreNet(nn.Module):
170
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
171
+ super().__init__()
172
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
173
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
174
+ self.p = dropout
175
+
176
+ def forward(self, x):
177
+ x = self.fc1(x)
178
+ x = F.relu(x)
179
+ x = F.dropout(x, self.p, training=True)
180
+ x = self.fc2(x)
181
+ x = F.relu(x)
182
+ x = F.dropout(x, self.p, training=True)
183
+ return x
184
+
185
+
186
+ class Attention(nn.Module):
187
+ def __init__(self, attn_dims):
188
+ super().__init__()
189
+ self.W = nn.Linear(attn_dims, attn_dims, bias=False)
190
+ self.v = nn.Linear(attn_dims, 1, bias=False)
191
+
192
+ def forward(self, encoder_seq_proj, query, t):
193
+
194
+ # print(encoder_seq_proj.shape)
195
+ # Transform the query vector
196
+ query_proj = self.W(query).unsqueeze(1)
197
+
198
+ # Compute the scores
199
+ u = self.v(torch.tanh(encoder_seq_proj + query_proj))
200
+ scores = F.softmax(u, dim=1)
201
+
202
+ return scores.transpose(1, 2)
203
+
204
+
205
+ class LSA(nn.Module):
206
+ def __init__(self, attn_dim, kernel_size=31, filters=32):
207
+ super().__init__()
208
+ self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
209
+ self.L = nn.Linear(filters, attn_dim, bias=False)
210
+ self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
211
+ self.v = nn.Linear(attn_dim, 1, bias=False)
212
+ self.cumulative = None
213
+ self.attention = None
214
+
215
+ def init_attention(self, encoder_seq_proj):
216
+ device = next(self.parameters()).device # use same device as parameters
217
+ b, t, c = encoder_seq_proj.size()
218
+ self.cumulative = torch.zeros(b, t, device=device)
219
+ self.attention = torch.zeros(b, t, device=device)
220
+
221
+ def forward(self, encoder_seq_proj, query, t, chars):
222
+
223
+ if t == 0: self.init_attention(encoder_seq_proj)
224
+
225
+ processed_query = self.W(query).unsqueeze(1)
226
+
227
+ location = self.cumulative.unsqueeze(1)
228
+ processed_loc = self.L(self.conv(location).transpose(1, 2))
229
+
230
+ u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
231
+ u = u.squeeze(-1)
232
+
233
+ # Mask zero padding chars
234
+ u = u * (chars != 0).float()
235
+
236
+ # Smooth Attention
237
+ # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
238
+ scores = F.softmax(u, dim=1)
239
+ self.attention = scores
240
+ self.cumulative = self.cumulative + self.attention
241
+
242
+ return scores.unsqueeze(-1).transpose(1, 2)
243
+
244
+
245
+ class Decoder(nn.Module):
246
+ # Class variable because its value doesn't change between classes
247
+ # yet ought to be scoped by class because its a property of a Decoder
248
+ max_r = 20
249
+ def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
250
+ dropout, speaker_embedding_size):
251
+ super().__init__()
252
+ self.register_buffer("r", torch.tensor(1, dtype=torch.int))
253
+ self.n_mels = n_mels
254
+ prenet_dims = (decoder_dims * 2, decoder_dims * 2)
255
+ self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
256
+ dropout=dropout)
257
+ self.attn_net = LSA(decoder_dims)
258
+ self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
259
+ self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
260
+ self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
261
+ self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
262
+ self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
263
+ self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
264
+
265
+ def zoneout(self, prev, current, p=0.1):
266
+ device = next(self.parameters()).device # Use same device as parameters
267
+ mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
268
+ return prev * mask + current * (1 - mask)
269
+
270
+ def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
271
+ hidden_states, cell_states, context_vec, t, chars):
272
+
273
+ # Need this for reshaping mels
274
+ batch_size = encoder_seq.size(0)
275
+
276
+ # Unpack the hidden and cell states
277
+ attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
278
+ rnn1_cell, rnn2_cell = cell_states
279
+
280
+ # PreNet for the Attention RNN
281
+ prenet_out = self.prenet(prenet_in)
282
+
283
+ # Compute the Attention RNN hidden state
284
+ attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
285
+ attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
286
+
287
+ # Compute the attention scores
288
+ scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
289
+
290
+ # Dot product to create the context vector
291
+ context_vec = scores @ encoder_seq
292
+ context_vec = context_vec.squeeze(1)
293
+
294
+ # Concat Attention RNN output w. Context Vector & project
295
+ x = torch.cat([context_vec, attn_hidden], dim=1)
296
+ x = self.rnn_input(x)
297
+
298
+ # Compute first Residual RNN
299
+ rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
300
+ if self.training:
301
+ rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
302
+ else:
303
+ rnn1_hidden = rnn1_hidden_next
304
+ x = x + rnn1_hidden
305
+
306
+ # Compute second Residual RNN
307
+ rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
308
+ if self.training:
309
+ rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
310
+ else:
311
+ rnn2_hidden = rnn2_hidden_next
312
+ x = x + rnn2_hidden
313
+
314
+ # Project Mels
315
+ mels = self.mel_proj(x)
316
+ mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
317
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
318
+ cell_states = (rnn1_cell, rnn2_cell)
319
+
320
+ # Stop token prediction
321
+ s = torch.cat((x, context_vec), dim=1)
322
+ s = self.stop_proj(s)
323
+ stop_tokens = torch.sigmoid(s)
324
+
325
+ return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
326
+
327
+
328
+ class Tacotron(nn.Module):
329
+ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
330
+ fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
331
+ dropout, stop_threshold, speaker_embedding_size):
332
+ super().__init__()
333
+ self.n_mels = n_mels
334
+ self.lstm_dims = lstm_dims
335
+ self.encoder_dims = encoder_dims
336
+ self.decoder_dims = decoder_dims
337
+ self.speaker_embedding_size = speaker_embedding_size
338
+ self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
339
+ encoder_K, num_highways, dropout)
340
+ self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
341
+ self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
342
+ dropout, speaker_embedding_size)
343
+ self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
344
+ [postnet_dims, fft_bins], num_highways)
345
+ self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
346
+
347
+ self.init_model()
348
+ self.num_params()
349
+
350
+ self.register_buffer("step", torch.zeros(1, dtype=torch.long))
351
+ self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
352
+
353
+ @property
354
+ def r(self):
355
+ return self.decoder.r.item()
356
+
357
+ @r.setter
358
+ def r(self, value):
359
+ self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
360
+
361
+ def forward(self, x, m, speaker_embedding):
362
+ device = next(self.parameters()).device # use same device as parameters
363
+
364
+ self.step += 1
365
+ batch_size, _, steps = m.size()
366
+
367
+ # Initialise all hidden states and pack into tuple
368
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
369
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
370
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
371
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
372
+
373
+ # Initialise all lstm cell states and pack into tuple
374
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
375
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
376
+ cell_states = (rnn1_cell, rnn2_cell)
377
+
378
+ # <GO> Frame for start of decoder loop
379
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
380
+
381
+ # Need an initial context vector
382
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
383
+
384
+ # SV2TTS: Run the encoder with the speaker embedding
385
+ # The projection avoids unnecessary matmuls in the decoder loop
386
+ encoder_seq = self.encoder(x, speaker_embedding)
387
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
388
+
389
+ # Need a couple of lists for outputs
390
+ mel_outputs, attn_scores, stop_outputs = [], [], []
391
+
392
+ # Run the decoder loop
393
+ for t in range(0, steps, self.r):
394
+ prenet_in = m[:, :, t - 1] if t > 0 else go_frame
395
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
396
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
397
+ hidden_states, cell_states, context_vec, t, x)
398
+ mel_outputs.append(mel_frames)
399
+ attn_scores.append(scores)
400
+ stop_outputs.extend([stop_tokens] * self.r)
401
+
402
+ # Concat the mel outputs into sequence
403
+ mel_outputs = torch.cat(mel_outputs, dim=2)
404
+
405
+ # Post-Process for Linear Spectrograms
406
+ postnet_out = self.postnet(mel_outputs)
407
+ linear = self.post_proj(postnet_out)
408
+ linear = linear.transpose(1, 2)
409
+
410
+ # For easy visualisation
411
+ attn_scores = torch.cat(attn_scores, 1)
412
+ # attn_scores = attn_scores.cpu().data.numpy()
413
+ stop_outputs = torch.cat(stop_outputs, 1)
414
+
415
+ return mel_outputs, linear, attn_scores, stop_outputs
416
+
417
+ def generate(self, x, speaker_embedding=None, steps=2000):
418
+ self.eval()
419
+ device = next(self.parameters()).device # use same device as parameters
420
+
421
+ batch_size, _ = x.size()
422
+
423
+ # Need to initialise all hidden states and pack into tuple for tidyness
424
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
425
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
426
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
427
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
428
+
429
+ # Need to initialise all lstm cell states and pack into tuple for tidyness
430
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
431
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
432
+ cell_states = (rnn1_cell, rnn2_cell)
433
+
434
+ # Need a <GO> Frame for start of decoder loop
435
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
436
+
437
+ # Need an initial context vector
438
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
439
+
440
+ # SV2TTS: Run the encoder with the speaker embedding
441
+ # The projection avoids unnecessary matmuls in the decoder loop
442
+ encoder_seq = self.encoder(x, speaker_embedding)
443
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
444
+
445
+ # Need a couple of lists for outputs
446
+ mel_outputs, attn_scores, stop_outputs = [], [], []
447
+
448
+ # Run the decoder loop
449
+ for t in range(0, steps, self.r):
450
+ prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
451
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
452
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
453
+ hidden_states, cell_states, context_vec, t, x)
454
+ mel_outputs.append(mel_frames)
455
+ attn_scores.append(scores)
456
+ stop_outputs.extend([stop_tokens] * self.r)
457
+ # Stop the loop when all stop tokens in batch exceed threshold
458
+ if (stop_tokens > 0.5).all() and t > 10: break
459
+
460
+ # Concat the mel outputs into sequence
461
+ mel_outputs = torch.cat(mel_outputs, dim=2)
462
+
463
+ # Post-Process for Linear Spectrograms
464
+ postnet_out = self.postnet(mel_outputs)
465
+ linear = self.post_proj(postnet_out)
466
+
467
+
468
+ linear = linear.transpose(1, 2)
469
+
470
+ # For easy visualisation
471
+ attn_scores = torch.cat(attn_scores, 1)
472
+ stop_outputs = torch.cat(stop_outputs, 1)
473
+
474
+ self.train()
475
+
476
+ return mel_outputs, linear, attn_scores
477
+
478
+ def init_model(self):
479
+ for p in self.parameters():
480
+ if p.dim() > 1: nn.init.xavier_uniform_(p)
481
+
482
+ def get_step(self):
483
+ return self.step.data.item()
484
+
485
+ def reset_step(self):
486
+ # assignment to parameters or buffers is overloaded, updates internal dict entry
487
+ self.step = self.step.data.new_tensor(1)
488
+
489
+ def log(self, path, msg):
490
+ with open(path, "a") as f:
491
+ print(msg, file=f)
492
+
493
+ def load(self, path, optimizer=None):
494
+ # Use device of model params as location for loaded state
495
+ device = next(self.parameters()).device
496
+ checkpoint = torch.load(str(path), map_location=device)
497
+ self.load_state_dict(checkpoint["model_state"])
498
+
499
+ if "optimizer_state" in checkpoint and optimizer is not None:
500
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
501
+
502
+ def save(self, path, optimizer=None):
503
+ if optimizer is not None:
504
+ torch.save({
505
+ "model_state": self.state_dict(),
506
+ "optimizer_state": optimizer.state_dict(),
507
+ }, str(path))
508
+ else:
509
+ torch.save({
510
+ "model_state": self.state_dict(),
511
+ }, str(path))
512
+
513
+
514
+ def num_params(self, print_out=True):
515
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
516
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
517
+ if print_out:
518
+ print("Trainable Parameters: %.3fM" % parameters)
519
+ return parameters
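
A minimal construction sketch, not part of the commit, showing how the `Tacotron` class above is wired up from the hyperparameters; the keyword arguments mirror those used in `synthesizer/inference.py`, and the reduction factor `r` is set through the property defined on the model.

```python
from synthesizer.hparams import hparams
from synthesizer.models.tacotron import Tacotron
from synthesizer.utils.symbols import symbols

model = Tacotron(embed_dims=hparams.tts_embed_dims,
                 num_chars=len(symbols),
                 encoder_dims=hparams.tts_encoder_dims,
                 decoder_dims=hparams.tts_decoder_dims,
                 n_mels=hparams.num_mels,
                 fft_bins=hparams.num_mels,
                 postnet_dims=hparams.tts_postnet_dims,
                 encoder_K=hparams.tts_encoder_K,
                 lstm_dims=hparams.tts_lstm_dims,
                 postnet_K=hparams.tts_postnet_K,
                 num_highways=hparams.tts_num_highways,
                 dropout=hparams.tts_dropout,
                 stop_threshold=hparams.tts_stop_threshold,
                 speaker_embedding_size=hparams.speaker_embedding_size)

model.r = 2  # mel frames emitted per decoder step, as in the training schedule
print("r =", model.r, "| training step =", model.get_step())
```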
synthesizer/preprocess.py ADDED
@@ -0,0 +1,258 @@
1
+ from multiprocessing.pool import Pool
2
+ from synthesizer import audio
3
+ from functools import partial
4
+ from itertools import chain
5
+ from encoder import inference as encoder
6
+ from pathlib import Path
7
+ from utils import logmmse
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import librosa
11
+
12
+
13
+ def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int, skip_existing: bool, hparams,
14
+ no_alignments: bool, datasets_name: str, subfolders: str):
15
+ # Gather the input directories
16
+ dataset_root = datasets_root.joinpath(datasets_name)
17
+ input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in subfolders.split(",")]
18
+ print("\n ".join(map(str, ["Using data from:"] + input_dirs)))
19
+ assert all(input_dir.exists() for input_dir in input_dirs)
20
+
21
+ # Create the output directories for each output file type
22
+ out_dir.joinpath("mels").mkdir(exist_ok=True)
23
+ out_dir.joinpath("audio").mkdir(exist_ok=True)
24
+
25
+ # Create a metadata file
26
+ metadata_fpath = out_dir.joinpath("train.txt")
27
+ metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8")
28
+
29
+ # Preprocess the dataset
30
+ speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs))
31
+ func = partial(preprocess_speaker, out_dir=out_dir, skip_existing=skip_existing,
32
+ hparams=hparams, no_alignments=no_alignments)
33
+ job = Pool(n_processes).imap(func, speaker_dirs)
34
+ for speaker_metadata in tqdm(job, datasets_name, len(speaker_dirs), unit="speakers"):
35
+ for metadatum in speaker_metadata:
36
+ metadata_file.write("|".join(str(x) for x in metadatum) + "\n")
37
+ metadata_file.close()
38
+
39
+ # Verify the contents of the metadata file
40
+ with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
41
+ metadata = [line.split("|") for line in metadata_file]
42
+ mel_frames = sum([int(m[4]) for m in metadata])
43
+ timesteps = sum([int(m[3]) for m in metadata])
44
+ sample_rate = hparams.sample_rate
45
+ hours = (timesteps / sample_rate) / 3600
46
+ print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." %
47
+ (len(metadata), mel_frames, timesteps, hours))
48
+ print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata))
49
+ print("Max mel frames length: %d" % max(int(m[4]) for m in metadata))
50
+ print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata))
51
+
52
+
53
+ def preprocess_speaker(speaker_dir, out_dir: Path, skip_existing: bool, hparams, no_alignments: bool):
54
+ metadata = []
55
+ for book_dir in speaker_dir.glob("*"):
56
+ if no_alignments:
57
+ # Gather the utterance audios and texts
58
+ # LibriTTS uses .wav but we will include extensions for compatibility with other datasets
59
+ extensions = ["*.wav", "*.flac", "*.mp3"]
60
+ for extension in extensions:
61
+ wav_fpaths = book_dir.glob(extension)
62
+
63
+ for wav_fpath in wav_fpaths:
64
+ # Load the audio waveform
65
+ wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
66
+ if hparams.rescale:
67
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
68
+
69
+ # Get the corresponding text
70
+ # Check for .txt (for compatibility with other datasets)
71
+ text_fpath = wav_fpath.with_suffix(".txt")
72
+ if not text_fpath.exists():
73
+ # Check for .normalized.txt (LibriTTS)
74
+ text_fpath = wav_fpath.with_suffix(".normalized.txt")
75
+ assert text_fpath.exists()
76
+ with text_fpath.open("r") as text_file:
77
+ text = "".join([line for line in text_file])
78
+ text = text.replace("\"", "")
79
+ text = text.strip()
80
+
81
+ # Process the utterance
82
+ metadata.append(process_utterance(wav, text, out_dir, str(wav_fpath.with_suffix("").name),
83
+ skip_existing, hparams))
84
+ else:
85
+ # Process alignment file (LibriSpeech support)
86
+ # Gather the utterance audios and texts
87
+ try:
88
+ alignments_fpath = next(book_dir.glob("*.alignment.txt"))
89
+ with alignments_fpath.open("r") as alignments_file:
90
+ alignments = [line.rstrip().split(" ") for line in alignments_file]
91
+ except StopIteration:
92
+ # A few alignment files will be missing
93
+ continue
94
+
95
+ # Iterate over each entry in the alignments file
96
+ for wav_fname, words, end_times in alignments:
97
+ wav_fpath = book_dir.joinpath(wav_fname + ".flac")
98
+ assert wav_fpath.exists()
99
+ words = words.replace("\"", "").split(",")
100
+ end_times = list(map(float, end_times.replace("\"", "").split(",")))
101
+
102
+ # Process each sub-utterance
103
+ wavs, texts = split_on_silences(wav_fpath, words, end_times, hparams)
104
+ for i, (wav, text) in enumerate(zip(wavs, texts)):
105
+ sub_basename = "%s_%02d" % (wav_fname, i)
106
+ metadata.append(process_utterance(wav, text, out_dir, sub_basename,
107
+ skip_existing, hparams))
108
+
109
+ return [m for m in metadata if m is not None]
110
+
111
+
112
+ def split_on_silences(wav_fpath, words, end_times, hparams):
113
+ # Load the audio waveform
114
+ wav, _ = librosa.load(str(wav_fpath), sr=hparams.sample_rate)
115
+ if hparams.rescale:
116
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
117
+
118
+ words = np.array(words)
119
+ start_times = np.array([0.0] + end_times[:-1])
120
+ end_times = np.array(end_times)
121
+ assert len(words) == len(end_times) == len(start_times)
122
+ assert words[0] == "" and words[-1] == ""
123
+
124
+ # Find pauses that are too long
125
+ mask = (words == "") & (end_times - start_times >= hparams.silence_min_duration_split)
126
+ mask[0] = mask[-1] = True
127
+ breaks = np.where(mask)[0]
128
+
129
+ # Profile the noise from the silences and perform noise reduction on the waveform
130
+ silence_times = [[start_times[i], end_times[i]] for i in breaks]
131
+ silence_times = (np.array(silence_times) * hparams.sample_rate).astype(np.int64)
132
+ noisy_wav = np.concatenate([wav[stime[0]:stime[1]] for stime in silence_times])
133
+ if len(noisy_wav) > hparams.sample_rate * 0.02:
134
+ profile = logmmse.profile_noise(noisy_wav, hparams.sample_rate)
135
+ wav = logmmse.denoise(wav, profile, eta=0)
136
+
137
+ # Re-attach segments that are too short
138
+ segments = list(zip(breaks[:-1], breaks[1:]))
139
+ segment_durations = [start_times[end] - end_times[start] for start, end in segments]
140
+ i = 0
141
+ while i < len(segments) and len(segments) > 1:
142
+ if segment_durations[i] < hparams.utterance_min_duration:
143
+ # See if the segment can be re-attached with the right or the left segment
144
+ left_duration = float("inf") if i == 0 else segment_durations[i - 1]
145
+ right_duration = float("inf") if i == len(segments) - 1 else segment_durations[i + 1]
146
+ joined_duration = segment_durations[i] + min(left_duration, right_duration)
147
+
148
+ # Do not re-attach if it causes the joined utterance to be too long
149
+ if joined_duration > hparams.hop_size * hparams.max_mel_frames / hparams.sample_rate:
150
+ i += 1
151
+ continue
152
+
153
+ # Re-attach the segment with the neighbour of shortest duration
154
+ j = i - 1 if left_duration <= right_duration else i
155
+ segments[j] = (segments[j][0], segments[j + 1][1])
156
+ segment_durations[j] = joined_duration
157
+ del segments[j + 1], segment_durations[j + 1]
158
+ else:
159
+ i += 1
160
+
161
+ # Split the utterance
162
+ segment_times = [[end_times[start], start_times[end]] for start, end in segments]
163
+ segment_times = (np.array(segment_times) * hparams.sample_rate).astype(np.int64)
164
+ wavs = [wav[segment_time[0]:segment_time[1]] for segment_time in segment_times]
165
+ texts = [" ".join(words[start + 1:end]).replace(" ", " ") for start, end in segments]
166
+
167
+ # # DEBUG: play the audio segments (run with -n=1)
168
+ # import sounddevice as sd
169
+ # if len(wavs) > 1:
170
+ # print("This sentence was split in %d segments:" % len(wavs))
171
+ # else:
172
+ # print("There are no silences long enough for this sentence to be split:")
173
+ # for wav, text in zip(wavs, texts):
174
+ # # Pad the waveform with 1 second of silence because sounddevice tends to cut them early
175
+ # # when playing them. You shouldn't need to do that in your parsers.
176
+ # wav = np.concatenate((wav, [0] * 16000))
177
+ # print("\t%s" % text)
178
+ # sd.play(wav, 16000, blocking=True)
179
+ # print("")
180
+
181
+ return wavs, texts
182
+
183
+
184
+ def process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
185
+ skip_existing: bool, hparams):
186
+ ## FOR REFERENCE:
187
+ # Keep the following in mind if you ever wish to change things here or implement your own
188
+ # synthesizer.
189
+ # - Both the audios and the mel spectrograms are saved as numpy arrays
190
+ # - There is no processing done to the audios that will be saved to disk beyond volume
191
+ # normalization (in split_on_silences)
192
+ # - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This
193
+ # is why we re-apply it on the audio on the side of the vocoder.
194
+ # - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved
195
+ # without extra padding. This means that you won't have an exact relation between the length
196
+ # of the wav and of the mel spectrogram. See the vocoder data loader.
197
+
198
+
199
+ # Skip existing utterances if needed
200
+ mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
201
+ wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
202
+ if skip_existing and mel_fpath.exists() and wav_fpath.exists():
203
+ return None
204
+
205
+ # Trim silence
206
+ if hparams.trim_silence:
207
+ wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True)
208
+
209
+ # Skip utterances that are too short
210
+ if len(wav) < hparams.utterance_min_duration * hparams.sample_rate:
211
+ return None
212
+
213
+ # Compute the mel spectrogram
214
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
215
+ mel_frames = mel_spectrogram.shape[1]
216
+
217
+ # Skip utterances that are too long
218
+ if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length:
219
+ return None
220
+
221
+ # Write the spectrogram, embed and audio to disk
222
+ np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
223
+ np.save(wav_fpath, wav, allow_pickle=False)
224
+
225
+ # Return a tuple describing this training example
226
+ return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
227
+
228
+
229
+ def embed_utterance(fpaths, encoder_model_fpath):
230
+ if not encoder.is_loaded():
231
+ encoder.load_model(encoder_model_fpath)
232
+
233
+ # Compute the speaker embedding of the utterance
234
+ wav_fpath, embed_fpath = fpaths
235
+ wav = np.load(wav_fpath)
236
+ wav = encoder.preprocess_wav(wav)
237
+ embed = encoder.embed_utterance(wav)
238
+ np.save(embed_fpath, embed, allow_pickle=False)
239
+
240
+
241
+ def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int):
242
+ wav_dir = synthesizer_root.joinpath("audio")
243
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
244
+ assert wav_dir.exists() and metadata_fpath.exists()
245
+ embed_dir = synthesizer_root.joinpath("embeds")
246
+ embed_dir.mkdir(exist_ok=True)
247
+
248
+ # Gather the input wave filepath and the target output embed filepath
249
+ with metadata_fpath.open("r") as metadata_file:
250
+ metadata = [line.split("|") for line in metadata_file]
251
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
252
+
253
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
254
+ # Embed the utterances in separate threads
255
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
256
+ job = Pool(n_processes).imap(func, fpaths)
257
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
258
+
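For orientation, here is a minimal sketch of how the embedding stage above might be driven once `train.txt`, `mels/` and `audio/` exist. The dataset path is a placeholder; the encoder checkpoint path matches the one shipped in this commit.

    from pathlib import Path
    from synthesizer.preprocess import create_embeddings

    # Compute one speaker embedding per preprocessed utterance (written to <root>/embeds)
    create_embeddings(synthesizer_root=Path("datasets/SV2TTS/synthesizer"),   # placeholder path
                      encoder_model_fpath=Path("saved_models/default/encoder.pt"),
                      n_processes=4)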
synthesizer/synthesize.py ADDED
@@ -0,0 +1,92 @@
1
+ import platform
2
+ from functools import partial
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+ from tqdm import tqdm
9
+
10
+ from synthesizer.hparams import hparams_debug_string
11
+ from synthesizer.models.tacotron import Tacotron
12
+ from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
13
+ from synthesizer.utils import data_parallel_workaround
14
+ from synthesizer.utils.symbols import symbols
15
+
16
+
17
+ def run_synthesis(in_dir: Path, out_dir: Path, syn_model_fpath: Path, hparams):
18
+ # This generates ground truth-aligned mels for vocoder training
19
+ synth_dir = out_dir / "mels_gta"
20
+ synth_dir.mkdir(exist_ok=True, parents=True)
21
+ print(hparams_debug_string())
22
+
23
+ # Check for GPU
24
+ if torch.cuda.is_available():
25
+ device = torch.device("cuda")
26
+ if hparams.synthesis_batch_size % torch.cuda.device_count() != 0:
27
+ raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible by n_gpus!")
28
+ else:
29
+ device = torch.device("cpu")
30
+ print("Synthesizer using device:", device)
31
+
32
+ # Instantiate Tacotron model
33
+ model = Tacotron(embed_dims=hparams.tts_embed_dims,
34
+ num_chars=len(symbols),
35
+ encoder_dims=hparams.tts_encoder_dims,
36
+ decoder_dims=hparams.tts_decoder_dims,
37
+ n_mels=hparams.num_mels,
38
+ fft_bins=hparams.num_mels,
39
+ postnet_dims=hparams.tts_postnet_dims,
40
+ encoder_K=hparams.tts_encoder_K,
41
+ lstm_dims=hparams.tts_lstm_dims,
42
+ postnet_K=hparams.tts_postnet_K,
43
+ num_highways=hparams.tts_num_highways,
44
+ dropout=0., # Use zero dropout for gta mels
45
+ stop_threshold=hparams.tts_stop_threshold,
46
+ speaker_embedding_size=hparams.speaker_embedding_size).to(device)
47
+
48
+ # Load the weights
49
+ print("\nLoading weights at %s" % syn_model_fpath)
50
+ model.load(syn_model_fpath)
51
+ print("Tacotron weights loaded from step %d" % model.step)
52
+
53
+ # Synthesize using same reduction factor as the model is currently trained
54
+ r = np.int32(model.r)
55
+
56
+ # Set model to eval mode (disable gradient and zoneout)
57
+ model.eval()
58
+
59
+ # Initialize the dataset
60
+ metadata_fpath = in_dir.joinpath("train.txt")
61
+ mel_dir = in_dir.joinpath("mels")
62
+ embed_dir = in_dir.joinpath("embeds")
63
+
64
+ dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
65
+ collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
66
+ data_loader = DataLoader(dataset, hparams.synthesis_batch_size, collate_fn=collate_fn, num_workers=2)
67
+
68
+ # Generate GTA mels
69
+ meta_out_fpath = out_dir / "synthesized.txt"
70
+ with meta_out_fpath.open("w") as file:
71
+ for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)):
72
+ texts, mels, embeds = texts.to(device), mels.to(device), embeds.to(device)
73
+
74
+ # Parallelize model onto GPUS using workaround due to python bug
75
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
76
+ _, mels_out, _ = data_parallel_workaround(model, texts, mels, embeds)
77
+ else:
78
+ _, mels_out, _, _ = model(texts, mels, embeds)
79
+
80
+ for j, k in enumerate(idx):
81
+ # Note: the output mel-spectrogram files have the same names as the target ones, just in a different folder
82
+ mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1])
83
+ mel_out = mels_out[j].detach().cpu().numpy().T
84
+
85
+ # Use the length of the ground truth mel to remove padding from the generated mels
86
+ mel_out = mel_out[:int(dataset.metadata[k][4])]
87
+
88
+ # Write the spectrogram to disk
89
+ np.save(mel_filename, mel_out, allow_pickle=False)
90
+
91
+ # Write metadata into the synthesized file
92
+ file.write("|".join(dataset.metadata[k]))
synthesizer/synthesizer_dataset.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from synthesizer.utils.text import text_to_sequence
6
+
7
+
8
+ class SynthesizerDataset(Dataset):
9
+ def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams):
10
+ print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir))
11
+
12
+ with metadata_fpath.open("r") as metadata_file:
13
+ metadata = [line.split("|") for line in metadata_file]
14
+
15
+ mel_fnames = [x[1] for x in metadata if int(x[4])]
16
+ mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames]
17
+ embed_fnames = [x[2] for x in metadata if int(x[4])]
18
+ embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames]
19
+ self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths))
20
+ self.samples_texts = [x[5].strip() for x in metadata if int(x[4])]
21
+ self.metadata = metadata
22
+ self.hparams = hparams
23
+
24
+ print("Found %d samples" % len(self.samples_fpaths))
25
+
26
+ def __getitem__(self, index):
27
+ # Sometimes index may be a list of 2 (not sure why this happens)
28
+ # If that is the case, return a single item corresponding to first element in index
29
+ if isinstance(index, list):
30
+ index = index[0]
31
+
32
+ mel_path, embed_path = self.samples_fpaths[index]
33
+ mel = np.load(mel_path).T.astype(np.float32)
34
+
35
+ # Load the embed
36
+ embed = np.load(embed_path)
37
+
38
+ # Get the text and clean it
39
+ text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names)
40
+
41
+ # Convert the list returned by text_to_sequence to a numpy array
42
+ text = np.asarray(text).astype(np.int32)
43
+
44
+ return text, mel.astype(np.float32), embed.astype(np.float32), index
45
+
46
+ def __len__(self):
47
+ return len(self.samples_fpaths)
48
+
49
+
50
+ def collate_synthesizer(batch, r, hparams):
51
+ # Text
52
+ x_lens = [len(x[0]) for x in batch]
53
+ max_x_len = max(x_lens)
54
+
55
+ chars = [pad1d(x[0], max_x_len) for x in batch]
56
+ chars = np.stack(chars)
57
+
58
+ # Mel spectrogram
59
+ spec_lens = [x[1].shape[-1] for x in batch]
60
+ max_spec_len = max(spec_lens) + 1
61
+ if max_spec_len % r != 0:
62
+ max_spec_len += r - max_spec_len % r
63
+
64
+ # WaveRNN mel spectrograms are normalized to [0, 1] so zero padding adds silence
65
+ # By default, SV2TTS uses symmetric mels, where -1*max_abs_value is silence.
66
+ if hparams.symmetric_mels:
67
+ mel_pad_value = -1 * hparams.max_abs_value
68
+ else:
69
+ mel_pad_value = 0
70
+
71
+ mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch]
72
+ mel = np.stack(mel)
73
+
74
+ # Speaker embedding (SV2TTS)
75
+ embeds = np.array([x[2] for x in batch])
76
+
77
+ # Index (for vocoder preprocessing)
78
+ indices = [x[3] for x in batch]
79
+
80
+
81
+ # Convert all to tensor
82
+ chars = torch.tensor(chars).long()
83
+ mel = torch.tensor(mel)
84
+ embeds = torch.tensor(embeds)
85
+
86
+ return chars, mel, embeds, indices
87
+
88
+ def pad1d(x, max_len, pad_value=0):
89
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
90
+
91
+ def pad2d(x, max_len, pad_value=0):
92
+ return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value)
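For reference, the dataset and collate function above are consumed roughly as follows (this mirrors the DataLoader setup in synthesizer/train.py below; the reduction factor r=2 and the batch size are illustrative values, and `hparams` is assumed to be the default object from synthesizer/hparams.py):

    from functools import partial
    from pathlib import Path
    from torch.utils.data import DataLoader
    from synthesizer.hparams import hparams  # assumed default hparams object
    from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer

    syn_dir = Path("datasets/SV2TTS/synthesizer")  # placeholder path
    dataset = SynthesizerDataset(syn_dir / "train.txt", syn_dir / "mels", syn_dir / "embeds", hparams)
    collate_fn = partial(collate_synthesizer, r=2, hparams=hparams)
    loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=collate_fn)
    texts, mels, embeds, indices = next(iter(loader))  # padded text IDs, mel batch, embeddings, sample indices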
synthesizer/train.py ADDED
@@ -0,0 +1,258 @@
1
+ from datetime import datetime
2
+ from functools import partial
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import optim
8
+ from torch.utils.data import DataLoader
9
+
10
+ from synthesizer import audio
11
+ from synthesizer.models.tacotron import Tacotron
12
+ from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
13
+ from synthesizer.utils import ValueWindow, data_parallel_workaround
14
+ from synthesizer.utils.plot import plot_spectrogram
15
+ from synthesizer.utils.symbols import symbols
16
+ from synthesizer.utils.text import sequence_to_text
17
+ from vocoder.display import *
18
+
19
+
20
+ def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
21
+
22
+
23
+ def time_string():
24
+ return datetime.now().strftime("%Y-%m-%d %H:%M")
25
+
26
+
27
+ def train(run_id: str, syn_dir: Path, models_dir: Path, save_every: int, backup_every: int, force_restart: bool,
28
+ hparams):
29
+ models_dir.mkdir(exist_ok=True)
30
+
31
+ model_dir = models_dir.joinpath(run_id)
32
+ plot_dir = model_dir.joinpath("plots")
33
+ wav_dir = model_dir.joinpath("wavs")
34
+ mel_output_dir = model_dir.joinpath("mel-spectrograms")
35
+ meta_folder = model_dir.joinpath("metas")
36
+ model_dir.mkdir(exist_ok=True)
37
+ plot_dir.mkdir(exist_ok=True)
38
+ wav_dir.mkdir(exist_ok=True)
39
+ mel_output_dir.mkdir(exist_ok=True)
40
+ meta_folder.mkdir(exist_ok=True)
41
+
42
+ weights_fpath = model_dir / f"synthesizer.pt"
43
+ metadata_fpath = syn_dir.joinpath("train.txt")
44
+
45
+ print("Checkpoint path: {}".format(weights_fpath))
46
+ print("Loading training data from: {}".format(metadata_fpath))
47
+ print("Using model: Tacotron")
48
+
49
+ # Bookkeeping
50
+ time_window = ValueWindow(100)
51
+ loss_window = ValueWindow(100)
52
+
53
+ # From WaveRNN/train_tacotron.py
54
+ if torch.cuda.is_available():
55
+ device = torch.device("cuda")
56
+
57
+ for session in hparams.tts_schedule:
58
+ _, _, _, batch_size = session
59
+ if batch_size % torch.cuda.device_count() != 0:
60
+ raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
61
+ else:
62
+ device = torch.device("cpu")
63
+ print("Using device:", device)
64
+
65
+ # Instantiate Tacotron Model
66
+ print("\nInitialising Tacotron Model...\n")
67
+ model = Tacotron(embed_dims=hparams.tts_embed_dims,
68
+ num_chars=len(symbols),
69
+ encoder_dims=hparams.tts_encoder_dims,
70
+ decoder_dims=hparams.tts_decoder_dims,
71
+ n_mels=hparams.num_mels,
72
+ fft_bins=hparams.num_mels,
73
+ postnet_dims=hparams.tts_postnet_dims,
74
+ encoder_K=hparams.tts_encoder_K,
75
+ lstm_dims=hparams.tts_lstm_dims,
76
+ postnet_K=hparams.tts_postnet_K,
77
+ num_highways=hparams.tts_num_highways,
78
+ dropout=hparams.tts_dropout,
79
+ stop_threshold=hparams.tts_stop_threshold,
80
+ speaker_embedding_size=hparams.speaker_embedding_size).to(device)
81
+
82
+ # Initialize the optimizer
83
+ optimizer = optim.Adam(model.parameters())
84
+
85
+ # Load the weights
86
+ if force_restart or not weights_fpath.exists():
87
+ print("\nStarting the training of Tacotron from scratch\n")
88
+ model.save(weights_fpath)
89
+
90
+ # Embeddings metadata
91
+ char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
92
+ with open(char_embedding_fpath, "w", encoding="utf-8") as f:
93
+ for symbol in symbols:
94
+ if symbol == " ":
95
+ symbol = "\\s" # For visual purposes, swap space with \s
96
+
97
+ f.write("{}\n".format(symbol))
98
+
99
+ else:
100
+ print("\nLoading weights at %s" % weights_fpath)
101
+ model.load(weights_fpath, optimizer)
102
+ print("Tacotron weights loaded from step %d" % model.step)
103
+
104
+ # Initialize the dataset
105
+ metadata_fpath = syn_dir.joinpath("train.txt")
106
+ mel_dir = syn_dir.joinpath("mels")
107
+ embed_dir = syn_dir.joinpath("embeds")
108
+ dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
109
+
110
+ for i, session in enumerate(hparams.tts_schedule):
111
+ current_step = model.get_step()
112
+
113
+ r, lr, max_step, batch_size = session
114
+
115
+ training_steps = max_step - current_step
116
+
117
+ # Do we need to change to the next session?
118
+ if current_step >= max_step:
119
+ # Are there no further sessions than the current one?
120
+ if i == len(hparams.tts_schedule) - 1:
121
+ # We have completed training. Save the model and exit
122
+ model.save(weights_fpath, optimizer)
123
+ break
124
+ else:
125
+ # There is a following session, go to it
126
+ continue
127
+
128
+ model.r = r
129
+
130
+ # Begin the training
131
+ simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
132
+ ("Batch Size", batch_size),
133
+ ("Learning Rate", lr),
134
+ ("Outputs/Step (r)", model.r)])
135
+
136
+ for p in optimizer.param_groups:
137
+ p["lr"] = lr
138
+
139
+ collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
140
+ data_loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)
141
+
142
+ total_iters = len(dataset)
143
+ steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
144
+ epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)
145
+
146
+ for epoch in range(1, epochs+1):
147
+ for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
148
+ start_time = time.time()
149
+
150
+ # Generate stop tokens for training
151
+ stop = torch.ones(mels.shape[0], mels.shape[2])
152
+ for j, k in enumerate(idx):
153
+ stop[j, :int(dataset.metadata[k][4])-1] = 0
154
+
155
+ texts = texts.to(device)
156
+ mels = mels.to(device)
157
+ embeds = embeds.to(device)
158
+ stop = stop.to(device)
159
+
160
+ # Forward pass
161
+ # Parallelize model onto GPUS using workaround due to python bug
162
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
163
+ m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts, mels, embeds)
164
+ else:
165
+ m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)
166
+
167
+ # Backward pass
168
+ m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
169
+ m2_loss = F.mse_loss(m2_hat, mels)
170
+ stop_loss = F.binary_cross_entropy(stop_pred, stop)
171
+
172
+ loss = m1_loss + m2_loss + stop_loss
173
+
174
+ optimizer.zero_grad()
175
+ loss.backward()
176
+
177
+ if hparams.tts_clip_grad_norm is not None:
178
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
179
+ if np.isnan(grad_norm.cpu()):
180
+ print("grad_norm was NaN!")
181
+
182
+ optimizer.step()
183
+
184
+ time_window.append(time.time() - start_time)
185
+ loss_window.append(loss.item())
186
+
187
+ step = model.get_step()
188
+ k = step // 1000
189
+
190
+ msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | " \
191
+ f"{1./time_window.average:#.2} steps/s | Step: {k}k | "
192
+ stream(msg)
193
+
194
+ # Backup or save model as appropriate
195
+ if backup_every != 0 and step % backup_every == 0 :
196
+ backup_fpath = weights_fpath.parent / f"synthesizer_{k:06d}.pt"
197
+ model.save(backup_fpath, optimizer)
198
+
199
+ if save_every != 0 and step % save_every == 0 :
200
+ # Must save latest optimizer state to ensure that resuming training
201
+ # doesn't produce artifacts
202
+ model.save(weights_fpath, optimizer)
203
+
204
+ # Evaluate model to generate samples
205
+ epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
206
+ step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0 # Every N steps
207
+ if epoch_eval or step_eval:
208
+ for sample_idx in range(hparams.tts_eval_num_samples):
209
+ # At most, generate samples equal to number in the batch
210
+ if sample_idx + 1 <= len(texts):
211
+ # Remove padding from mels using frame length in metadata
212
+ mel_length = int(dataset.metadata[idx[sample_idx]][4])
213
+ mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
214
+ target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
215
+ attention_len = mel_length // model.r
216
+
217
+ eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
218
+ mel_prediction=mel_prediction,
219
+ target_spectrogram=target_spectrogram,
220
+ input_seq=np_now(texts[sample_idx]),
221
+ step=step,
222
+ plot_dir=plot_dir,
223
+ mel_output_dir=mel_output_dir,
224
+ wav_dir=wav_dir,
225
+ sample_num=sample_idx + 1,
226
+ loss=loss,
227
+ hparams=hparams)
228
+
229
+ # Break out of loop to update training schedule
230
+ if step >= max_step:
231
+ break
232
+
233
+ # Add line break after every epoch
234
+ print("")
235
+
236
+
237
+ def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
238
+ plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
239
+ # Save some results for evaluation
240
+ attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
241
+ save_attention(attention, attention_path)
242
+
243
+ # save predicted mel spectrogram to disk (debug)
244
+ mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
245
+ np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False)
246
+
247
+ # save griffin lim inverted wav for debug (mel -> wav)
248
+ wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
249
+ wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num))
250
+ audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate)
251
+
252
+ # save real and predicted mel-spectrogram plot to disk (control purposes)
253
+ spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
254
+ title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
255
+ plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
256
+ target_spectrogram=target_spectrogram,
257
+ max_len=target_spectrogram.size // hparams.num_mels)
258
+ print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
synthesizer/utils/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ import torch
2
+
3
+
4
+ _output_ref = None
5
+ _replicas_ref = None
6
+
7
+ def data_parallel_workaround(model, *input):
8
+ global _output_ref
9
+ global _replicas_ref
10
+ device_ids = list(range(torch.cuda.device_count()))
11
+ output_device = device_ids[0]
12
+ replicas = torch.nn.parallel.replicate(model, device_ids)
13
+ # input.shape = (num_args, batch, ...)
14
+ inputs = torch.nn.parallel.scatter(input, device_ids)
15
+ # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
16
+ replicas = replicas[:len(inputs)]
17
+ outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
18
+ y_hat = torch.nn.parallel.gather(outputs, output_device)
19
+ _output_ref = outputs
20
+ _replicas_ref = replicas
21
+ return y_hat
22
+
23
+
24
+ class ValueWindow():
25
+ def __init__(self, window_size=100):
26
+ self._window_size = window_size
27
+ self._values = []
28
+
29
+ def append(self, x):
30
+ self._values = self._values[-(self._window_size - 1):] + [x]
31
+
32
+ @property
33
+ def sum(self):
34
+ return sum(self._values)
35
+
36
+ @property
37
+ def count(self):
38
+ return len(self._values)
39
+
40
+ @property
41
+ def average(self):
42
+ return self.sum / max(1, self.count)
43
+
44
+ def reset(self):
45
+ self._values = []
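`ValueWindow` is a small running-average buffer over the last `window_size` values, used above for the loss and step-time readouts. For example:

    w = ValueWindow(window_size=3)
    for x in (1.0, 2.0, 3.0, 4.0):
        w.append(x)
    print(w.count, w.sum, w.average)  # 3 9.0 3.0 -- only the last three values are kept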
synthesizer/utils/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (1.68 kB). View file
 
synthesizer/utils/__pycache__/cleaners.cpython-37.pyc ADDED
Binary file (2.81 kB). View file
 
synthesizer/utils/__pycache__/numbers.cpython-37.pyc ADDED
Binary file (2.18 kB). View file
 
synthesizer/utils/__pycache__/symbols.cpython-37.pyc ADDED
Binary file (582 Bytes). View file
 
synthesizer/utils/__pycache__/text.cpython-37.pyc ADDED
Binary file (2.71 kB). View file
 
synthesizer/utils/_cmudict.py ADDED
@@ -0,0 +1,62 @@
1
+ import re
2
+
3
+ valid_symbols = [
4
+ "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
5
+ "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
6
+ "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY",
7
+ "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1",
8
+ "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0",
9
+ "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW",
10
+ "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH"
11
+ ]
12
+
13
+ _valid_symbol_set = set(valid_symbols)
14
+
15
+
16
+ class CMUDict:
17
+ """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
18
+ def __init__(self, file_or_path, keep_ambiguous=True):
19
+ if isinstance(file_or_path, str):
20
+ with open(file_or_path, encoding="latin-1") as f:
21
+ entries = _parse_cmudict(f)
22
+ else:
23
+ entries = _parse_cmudict(file_or_path)
24
+ if not keep_ambiguous:
25
+ entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
26
+ self._entries = entries
27
+
28
+
29
+ def __len__(self):
30
+ return len(self._entries)
31
+
32
+
33
+ def lookup(self, word):
34
+ """Returns list of ARPAbet pronunciations of the given word."""
35
+ return self._entries.get(word.upper())
36
+
37
+
38
+
39
+ _alt_re = re.compile(r"\([0-9]+\)")
40
+
41
+
42
+ def _parse_cmudict(file):
43
+ cmudict = {}
44
+ for line in file:
45
+ if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
46
+ parts = line.split(" ")
47
+ word = re.sub(_alt_re, "", parts[0])
48
+ pronunciation = _get_pronunciation(parts[1])
49
+ if pronunciation:
50
+ if word in cmudict:
51
+ cmudict[word].append(pronunciation)
52
+ else:
53
+ cmudict[word] = [pronunciation]
54
+ return cmudict
55
+
56
+
57
+ def _get_pronunciation(s):
58
+ parts = s.strip().split(" ")
59
+ for part in parts:
60
+ if part not in _valid_symbol_set:
61
+ return None
62
+ return " ".join(parts)
synthesizer/utils/cleaners.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ Cleaners are transformations that run over the input text at both training and eval time.
3
+
4
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
5
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
6
+ 1. "english_cleaners" for English text
7
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
8
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
9
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
10
+ the symbols in symbols.py to match your data).
11
+ """
12
+ import re
13
+ from unidecode import unidecode
14
+ from synthesizer.utils.numbers import normalize_numbers
15
+
16
+
17
+ # Regular expression matching whitespace:
18
+ _whitespace_re = re.compile(r"\s+")
19
+
20
+ # List of (regular expression, replacement) pairs for abbreviations:
21
+ _abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
22
+ ("mrs", "misess"),
23
+ ("mr", "mister"),
24
+ ("dr", "doctor"),
25
+ ("st", "saint"),
26
+ ("co", "company"),
27
+ ("jr", "junior"),
28
+ ("maj", "major"),
29
+ ("gen", "general"),
30
+ ("drs", "doctors"),
31
+ ("rev", "reverend"),
32
+ ("lt", "lieutenant"),
33
+ ("hon", "honorable"),
34
+ ("sgt", "sergeant"),
35
+ ("capt", "captain"),
36
+ ("esq", "esquire"),
37
+ ("ltd", "limited"),
38
+ ("col", "colonel"),
39
+ ("ft", "fort"),
40
+ ]]
41
+
42
+
43
+ def expand_abbreviations(text):
44
+ for regex, replacement in _abbreviations:
45
+ text = re.sub(regex, replacement, text)
46
+ return text
47
+
48
+
49
+ def expand_numbers(text):
50
+ return normalize_numbers(text)
51
+
52
+
53
+ def lowercase(text):
54
+ """lowercase input tokens."""
55
+ return text.lower()
56
+
57
+
58
+ def collapse_whitespace(text):
59
+ return re.sub(_whitespace_re, " ", text)
60
+
61
+
62
+ def convert_to_ascii(text):
63
+ return unidecode(text)
64
+
65
+
66
+ def basic_cleaners(text):
67
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
68
+ text = lowercase(text)
69
+ text = collapse_whitespace(text)
70
+ return text
71
+
72
+
73
+ def transliteration_cleaners(text):
74
+ """Pipeline for non-English text that transliterates to ASCII."""
75
+ text = convert_to_ascii(text)
76
+ text = lowercase(text)
77
+ text = collapse_whitespace(text)
78
+ return text
79
+
80
+
81
+ def english_cleaners(text):
82
+ """Pipeline for English text, including number and abbreviation expansion."""
83
+ text = convert_to_ascii(text)
84
+ text = lowercase(text)
85
+ text = expand_numbers(text)
86
+ text = expand_abbreviations(text)
87
+ text = collapse_whitespace(text)
88
+ return text
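A quick illustration of the full English pipeline above; the output should be roughly the string shown in the comment:

    print(english_cleaners("Dr. Müller bought 2 apples for $3.50."))
    # -> "doctor muller bought two apples for three dollars, fifty cents."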
synthesizer/utils/numbers.py ADDED
@@ -0,0 +1,69 @@
1
+ import re
2
+ import inflect
3
+
4
+
5
+ _inflect = inflect.engine()
6
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
7
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
8
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
9
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
10
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
11
+ _number_re = re.compile(r"[0-9]+")
12
+
13
+
14
+ def _remove_commas(m):
15
+ return m.group(1).replace(",", "")
16
+
17
+
18
+ def _expand_decimal_point(m):
19
+ return m.group(1).replace(".", " point ")
20
+
21
+
22
+ def _expand_dollars(m):
23
+ match = m.group(1)
24
+ parts = match.split(".")
25
+ if len(parts) > 2:
26
+ return match + " dollars" # Unexpected format
27
+ dollars = int(parts[0]) if parts[0] else 0
28
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
29
+ if dollars and cents:
30
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
31
+ cent_unit = "cent" if cents == 1 else "cents"
32
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
33
+ elif dollars:
34
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
35
+ return "%s %s" % (dollars, dollar_unit)
36
+ elif cents:
37
+ cent_unit = "cent" if cents == 1 else "cents"
38
+ return "%s %s" % (cents, cent_unit)
39
+ else:
40
+ return "zero dollars"
41
+
42
+
43
+ def _expand_ordinal(m):
44
+ return _inflect.number_to_words(m.group(0))
45
+
46
+
47
+ def _expand_number(m):
48
+ num = int(m.group(0))
49
+ if num > 1000 and num < 3000:
50
+ if num == 2000:
51
+ return "two thousand"
52
+ elif num > 2000 and num < 2010:
53
+ return "two thousand " + _inflect.number_to_words(num % 100)
54
+ elif num % 100 == 0:
55
+ return _inflect.number_to_words(num // 100) + " hundred"
56
+ else:
57
+ return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
58
+ else:
59
+ return _inflect.number_to_words(num, andword="")
60
+
61
+
62
+ def normalize_numbers(text):
63
+ text = re.sub(_comma_number_re, _remove_commas, text)
64
+ text = re.sub(_pounds_re, r"\1 pounds", text)
65
+ text = re.sub(_dollars_re, _expand_dollars, text)
66
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
67
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
68
+ text = re.sub(_number_re, _expand_number, text)
69
+ return text
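For example, the normalization above expands currency amounts and plain digits into words:

    print(normalize_numbers("I have $2 and 15 minutes."))
    # roughly -> "I have two dollars and fifteen minutes."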
synthesizer/utils/plot.py ADDED
@@ -0,0 +1,82 @@
1
+ import numpy as np
2
+
3
+
4
+ def split_title_line(title_text, max_words=5):
5
+ """
6
+ Split a title string into several lines, keeping at most max_words words per line
7
+ (the word groups are re-joined with newline characters between them)
8
+ """
9
+ seq = title_text.split()
10
+ return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)])
11
+
12
+
13
+ def plot_alignment(alignment, path, title=None, split_title=False, max_len=None):
14
+ import matplotlib
15
+ matplotlib.use("Agg")
16
+ import matplotlib.pyplot as plt
17
+
18
+ if max_len is not None:
19
+ alignment = alignment[:, :max_len]
20
+
21
+ fig = plt.figure(figsize=(8, 6))
22
+ ax = fig.add_subplot(111)
23
+
24
+ im = ax.imshow(
25
+ alignment,
26
+ aspect="auto",
27
+ origin="lower",
28
+ interpolation="none")
29
+ fig.colorbar(im, ax=ax)
30
+ xlabel = "Decoder timestep"
31
+
32
+ if split_title:
33
+ title = split_title_line(title)
34
+
35
+ plt.xlabel(xlabel)
36
+ plt.title(title)
37
+ plt.ylabel("Encoder timestep")
38
+ plt.tight_layout()
39
+ plt.savefig(path, format="png")
40
+ plt.close()
41
+
42
+
43
+ def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
44
+ import matplotlib
45
+ matplotlib.use("Agg")
46
+ import matplotlib.pyplot as plt
47
+
48
+ if max_len is not None:
49
+ target_spectrogram = target_spectrogram[:max_len]
50
+ pred_spectrogram = pred_spectrogram[:max_len]
51
+
52
+ if split_title:
53
+ title = split_title_line(title)
54
+
55
+ fig = plt.figure(figsize=(10, 8))
56
+ # Set common labels
57
+ fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)
58
+
59
+ #target spectrogram subplot
60
+ if target_spectrogram is not None:
61
+ ax1 = fig.add_subplot(311)
62
+ ax2 = fig.add_subplot(312)
63
+
64
+ if auto_aspect:
65
+ im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
66
+ else:
67
+ im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
68
+ ax1.set_title("Target Mel-Spectrogram")
69
+ fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
70
+ ax2.set_title("Predicted Mel-Spectrogram")
71
+ else:
72
+ ax2 = fig.add_subplot(211)
73
+
74
+ if auto_aspect:
75
+ im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
76
+ else:
77
+ im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
78
+ fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
79
+
80
+ plt.tight_layout()
81
+ plt.savefig(path, format="png")
82
+ plt.close()
synthesizer/utils/symbols.py ADDED
@@ -0,0 +1,21 @@
1
+ """
2
+ Defines the set of symbols used in text input to the model.
3
+
4
+ The default is a set of ASCII characters that works well for English or text that has been run
5
+ through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
6
+ """
7
+ # from . import cmudict
8
+
9
+ _pad = "_"
10
+ _eos = "~"
11
+
12
+ # for zh
13
+ # _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890!\'(),-.:;? "
14
+
15
+ _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? "
16
+
17
+ # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
18
+ #_arpabet = ["@' + s for s in cmudict.valid_symbols]
19
+
20
+ # Export all symbols:
21
+ symbols = [_pad, _eos] + list(_characters) #+ _arpabet
synthesizer/utils/text.py ADDED
@@ -0,0 +1,75 @@
1
+ from synthesizer.utils.symbols import symbols
2
+ from synthesizer.utils import cleaners
3
+ import re
4
+
5
+
6
+ # Mappings from symbol to numeric ID and vice versa:
7
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9
+
10
+ # Regular expression matching text enclosed in curly braces:
11
+ _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
12
+
13
+
14
+ def text_to_sequence(text, cleaner_names):
15
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
16
+
17
+ The text can optionally have ARPAbet sequences enclosed in curly braces embedded
18
+ in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
19
+
20
+ Args:
21
+ text: string to convert to a sequence
22
+ cleaner_names: names of the cleaner functions to run the text through
23
+
24
+ Returns:
25
+ List of integers corresponding to the symbols in the text
26
+ """
27
+ sequence = []
28
+
29
+ # Check for curly braces and treat their contents as ARPAbet:
30
+ while len(text):
31
+ m = _curly_re.match(text)
32
+ if not m:
33
+ sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
34
+ break
35
+ sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
36
+ sequence += _arpabet_to_sequence(m.group(2))
37
+ text = m.group(3)
38
+
39
+ # Append EOS token
40
+ sequence.append(_symbol_to_id["~"])
41
+ return sequence
42
+
43
+
44
+ def sequence_to_text(sequence):
45
+ """Converts a sequence of IDs back to a string"""
46
+ result = ""
47
+ for symbol_id in sequence:
48
+ if symbol_id in _id_to_symbol:
49
+ s = _id_to_symbol[symbol_id]
50
+ # Enclose ARPAbet back in curly braces:
51
+ if len(s) > 1 and s[0] == "@":
52
+ s = "{%s}" % s[1:]
53
+ result += s
54
+ return result.replace("}{", " ")
55
+
56
+
57
+ def _clean_text(text, cleaner_names):
58
+ for name in cleaner_names:
59
+ cleaner = getattr(cleaners, name)
60
+ if not cleaner:
61
+ raise Exception("Unknown cleaner: %s" % name)
62
+ text = cleaner(text)
63
+ return text
64
+
65
+
66
+ def _symbols_to_sequence(symbols):
67
+ return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
68
+
69
+
70
+ def _arpabet_to_sequence(text):
71
+ return _symbols_to_sequence(["@" + s for s in text.split()])
72
+
73
+
74
+ def _should_keep_symbol(s):
75
+ return s in _symbol_to_id and s not in ("_", "~")
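A round-trip sketch using the helpers above (the cleaner name follows synthesizer/utils/cleaners.py; note that ARPAbet symbols in curly braces are only kept if `_arpabet` is added to the symbol set in symbols.py):

    seq = text_to_sequence("Hello, world!", ["english_cleaners"])
    print(seq)                    # integer IDs for "hello, world!" plus the EOS id for "~"
    print(sequence_to_text(seq))  # -> "hello, world!~"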
toolbox/__init__.py ADDED
@@ -0,0 +1,347 @@
1
+ import sys
2
+ import traceback
3
+ from pathlib import Path
4
+ from time import perf_counter as timer
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
+ from encoder import inference as encoder
10
+ from synthesizer.inference import Synthesizer
11
+ from toolbox.ui import UI
12
+ from toolbox.utterance import Utterance
13
+ from vocoder import inference as vocoder
14
+
15
+
16
+ # Use this directory structure for your datasets, or modify it to fit your needs
17
+ recognized_datasets = [
18
+ "LibriSpeech/dev-clean",
19
+ "LibriSpeech/dev-other",
20
+ "LibriSpeech/test-clean",
21
+ "LibriSpeech/test-other",
22
+ "LibriSpeech/train-clean-100",
23
+ "LibriSpeech/train-clean-360",
24
+ "LibriSpeech/train-other-500",
25
+ "LibriTTS/dev-clean",
26
+ "LibriTTS/dev-other",
27
+ "LibriTTS/test-clean",
28
+ "LibriTTS/test-other",
29
+ "LibriTTS/train-clean-100",
30
+ "LibriTTS/train-clean-360",
31
+ "LibriTTS/train-other-500",
32
+ "LJSpeech-1.1",
33
+ "VoxCeleb1/wav",
34
+ "VoxCeleb1/test_wav",
35
+ "VoxCeleb2/dev/aac",
36
+ "VoxCeleb2/test/aac",
37
+ "VCTK-Corpus/wav48",
38
+ ]
39
+
40
+ # Maximum number of generated wavs to keep in memory
41
+ MAX_WAVS = 15
42
+
43
+
44
+ class Toolbox:
45
+ def __init__(self, datasets_root: Path, models_dir: Path, seed: int=None):
46
+ sys.excepthook = self.excepthook
47
+ self.datasets_root = datasets_root
48
+ self.utterances = set()
49
+ self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
50
+
51
+ self.synthesizer = None # type: Synthesizer
52
+ self.current_wav = None
53
+ self.waves_list = []
54
+ self.waves_count = 0
55
+ self.waves_namelist = []
56
+
57
+ # Check for webrtcvad (enables removal of silences in vocoder output)
58
+ try:
59
+ import webrtcvad
60
+ self.trim_silences = True
61
+ except ImportError:
62
+ self.trim_silences = False
63
+
64
+ # Initialize the events and the interface
65
+ self.ui = UI()
66
+ self.reset_ui(models_dir, seed)
67
+ self.setup_events()
68
+ self.ui.start()
69
+
70
+ def excepthook(self, exc_type, exc_value, exc_tb):
71
+ traceback.print_exception(exc_type, exc_value, exc_tb)
72
+ self.ui.log("Exception: %s" % exc_value)
73
+
74
+ def setup_events(self):
75
+ # Dataset, speaker and utterance selection
76
+ self.ui.browser_load_button.clicked.connect(lambda: self.load_from_browser())
77
+ random_func = lambda level: lambda: self.ui.populate_browser(self.datasets_root,
78
+ recognized_datasets,
79
+ level)
80
+ self.ui.random_dataset_button.clicked.connect(random_func(0))
81
+ self.ui.random_speaker_button.clicked.connect(random_func(1))
82
+ self.ui.random_utterance_button.clicked.connect(random_func(2))
83
+ self.ui.dataset_box.currentIndexChanged.connect(random_func(1))
84
+ self.ui.speaker_box.currentIndexChanged.connect(random_func(2))
85
+
86
+ # Model selection
87
+ self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
88
+ def func():
89
+ self.synthesizer = None
90
+ self.ui.synthesizer_box.currentIndexChanged.connect(func)
91
+ self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)
92
+
93
+ # Utterance selection
94
+ func = lambda: self.load_from_browser(self.ui.browse_file())
95
+ self.ui.browser_browse_button.clicked.connect(func)
96
+ func = lambda: self.ui.draw_utterance(self.ui.selected_utterance, "current")
97
+ self.ui.utterance_history.currentIndexChanged.connect(func)
98
+ func = lambda: self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate)
99
+ self.ui.play_button.clicked.connect(func)
100
+ self.ui.stop_button.clicked.connect(self.ui.stop)
101
+ self.ui.record_button.clicked.connect(self.record)
102
+
103
+ #Audio
104
+ self.ui.setup_audio_devices(Synthesizer.sample_rate)
105
+
106
+ #Wav playback & save
107
+ func = lambda: self.replay_last_wav()
108
+ self.ui.replay_wav_button.clicked.connect(func)
109
+ func = lambda: self.export_current_wave()
110
+ self.ui.export_wav_button.clicked.connect(func)
111
+ self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
112
+
113
+ # Generation
114
+ func = lambda: self.synthesize() or self.vocode()
115
+ self.ui.generate_button.clicked.connect(func)
116
+ self.ui.synthesize_button.clicked.connect(self.synthesize)
117
+ self.ui.vocode_button.clicked.connect(self.vocode)
118
+ self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)
119
+
120
+ # UMAP legend
121
+ self.ui.clear_button.clicked.connect(self.clear_utterances)
122
+
123
+ def set_current_wav(self, index):
124
+ self.current_wav = self.waves_list[index]
125
+
126
+ def export_current_wave(self):
127
+ self.ui.save_audio_file(self.current_wav, Synthesizer.sample_rate)
128
+
129
+ def replay_last_wav(self):
130
+ self.ui.play(self.current_wav, Synthesizer.sample_rate)
131
+
132
+ def reset_ui(self, models_dir: Path, seed: int=None):
133
+ self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
134
+ self.ui.populate_models(models_dir)
135
+ self.ui.populate_gen_options(seed, self.trim_silences)
136
+
137
+ def load_from_browser(self, fpath=None):
138
+ if fpath is None:
139
+ fpath = Path(self.datasets_root,
140
+ self.ui.current_dataset_name,
141
+ self.ui.current_speaker_name,
142
+ self.ui.current_utterance_name)
143
+ name = str(fpath.relative_to(self.datasets_root))
144
+ speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name
145
+
146
+ # Select the next utterance
147
+ if self.ui.auto_next_checkbox.isChecked():
148
+ self.ui.browser_select_next()
149
+ elif fpath == "":
150
+ return
151
+ else:
152
+ name = fpath.name
153
+ speaker_name = fpath.parent.name
154
+
155
+ # Get the wav from the disk. We take the wav with the vocoder/synthesizer format for
156
+ # playback, so as to have a fair comparison with the generated audio
157
+ wav = Synthesizer.load_preprocess_wav(fpath)
158
+ self.ui.log("Loaded %s" % name)
159
+
160
+ self.add_real_utterance(wav, name, speaker_name)
161
+
162
+ def record(self):
163
+ wav = self.ui.record_one(encoder.sampling_rate, 5)
164
+ if wav is None:
165
+ return
166
+ self.ui.play(wav, encoder.sampling_rate)
167
+
168
+ speaker_name = "user01"
169
+ name = speaker_name + "_rec_%05d" % np.random.randint(100000)
170
+ self.add_real_utterance(wav, name, speaker_name)
171
+
172
+ def add_real_utterance(self, wav, name, speaker_name):
173
+ # Compute the mel spectrogram
174
+ spec = Synthesizer.make_spectrogram(wav)
175
+ self.ui.draw_spec(spec, "current")
176
+
177
+ # Compute the embedding
178
+ if not encoder.is_loaded():
179
+ self.init_encoder()
180
+ encoder_wav = encoder.preprocess_wav(wav)
181
+ embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
182
+
183
+ # Add the utterance
184
+ utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
185
+ self.utterances.add(utterance)
186
+ self.ui.register_utterance(utterance)
187
+
188
+ # Plot it
189
+ self.ui.draw_embed(embed, name, "current")
190
+ self.ui.draw_umap_projections(self.utterances)
191
+
192
+ def clear_utterances(self):
193
+ self.utterances.clear()
194
+ self.ui.draw_umap_projections(self.utterances)
195
+
196
+ def synthesize(self):
197
+ self.ui.log("Generating the mel spectrogram...")
198
+ self.ui.set_loading(1)
199
+
200
+ # Update the synthesizer random seed
201
+ if self.ui.random_seed_checkbox.isChecked():
202
+ seed = int(self.ui.seed_textbox.text())
203
+ self.ui.populate_gen_options(seed, self.trim_silences)
204
+ else:
205
+ seed = None
206
+
207
+ if seed is not None:
208
+ torch.manual_seed(seed)
209
+
210
+ # Synthesize the spectrogram
211
+ if self.synthesizer is None or seed is not None:
212
+ self.init_synthesizer()
213
+
214
+ texts = self.ui.text_prompt.toPlainText().split("\n")
215
+ embed = self.ui.selected_utterance.embed
216
+ embeds = [embed] * len(texts)
217
+ specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
218
+ breaks = [spec.shape[1] for spec in specs]
219
+ spec = np.concatenate(specs, axis=1)
220
+
221
+ self.ui.draw_spec(spec, "generated")
222
+ self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None)
223
+ self.ui.set_loading(0)
224
+
225
+ def vocode(self):
226
+ speaker_name, spec, breaks, _ = self.current_generated
227
+ assert spec is not None
228
+
229
+ # Initialize the vocoder model and make it deterministic, if the user provides a seed
230
+ if self.ui.random_seed_checkbox.isChecked():
231
+ seed = int(self.ui.seed_textbox.text())
232
+ self.ui.populate_gen_options(seed, self.trim_silences)
233
+ else:
234
+ seed = None
235
+
236
+ if seed is not None:
237
+ torch.manual_seed(seed)
238
+
239
+ # Synthesize the waveform
240
+ if not vocoder.is_loaded() or seed is not None:
241
+ self.init_vocoder()
242
+
243
+ def vocoder_progress(i, seq_len, b_size, gen_rate):
244
+ real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
245
+ line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \
246
+ % (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor)
247
+ self.ui.log(line, "overwrite")
248
+ self.ui.set_loading(i, seq_len)
249
+ if self.ui.current_vocoder_fpath is not None:
250
+ self.ui.log("")
251
+ wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
252
+ else:
253
+ self.ui.log("Waveform generation with Griffin-Lim... ")
254
+ wav = Synthesizer.griffin_lim(spec)
255
+ self.ui.set_loading(0)
256
+ self.ui.log(" Done!", "append")
257
+
258
+ # Add breaks
259
+ b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
260
+ b_starts = np.concatenate(([0], b_ends[:-1]))
261
+ wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
262
+ breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
263
+ wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
264
+
265
+ # Trim excessive silences
266
+ if self.ui.trim_silences_checkbox.isChecked():
267
+ wav = encoder.preprocess_wav(wav)
268
+
269
+ # Play it
270
+ wav = wav / np.abs(wav).max() * 0.97
271
+ self.ui.play(wav, Synthesizer.sample_rate)
272
+
273
+ # Name it (history displayed in combobox)
274
+ # TODO better naming for the combobox items?
275
+ wav_name = str(self.waves_count + 1)
276
+
277
+ #Update waves combobox
278
+ self.waves_count += 1
279
+ if self.waves_count > MAX_WAVS:
280
+ self.waves_list.pop()
281
+ self.waves_namelist.pop()
282
+ self.waves_list.insert(0, wav)
283
+ self.waves_namelist.insert(0, wav_name)
284
+
285
+ self.ui.waves_cb.disconnect()
286
+ self.ui.waves_cb_model.setStringList(self.waves_namelist)
287
+ self.ui.waves_cb.setCurrentIndex(0)
288
+ self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
289
+
290
+ # Update current wav
291
+ self.set_current_wav(0)
292
+
293
+ #Enable replay and save buttons:
294
+ self.ui.replay_wav_button.setDisabled(False)
295
+ self.ui.export_wav_button.setDisabled(False)
296
+
297
+ # Compute the embedding
298
+ # TODO: this is problematic with different sampling rates, gotta fix it
299
+ if not encoder.is_loaded():
300
+ self.init_encoder()
301
+ encoder_wav = encoder.preprocess_wav(wav)
302
+ embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
303
+
304
+ # Add the utterance
305
+ name = speaker_name + "_gen_%05d" % np.random.randint(100000)
306
+ utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True)
307
+ self.utterances.add(utterance)
308
+
309
+ # Plot it
310
+ self.ui.draw_embed(embed, name, "generated")
311
+ self.ui.draw_umap_projections(self.utterances)
312
+
313
+ def init_encoder(self):
314
+ model_fpath = self.ui.current_encoder_fpath
315
+
316
+ self.ui.log("Loading the encoder %s... " % model_fpath)
317
+ self.ui.set_loading(1)
318
+ start = timer()
319
+ encoder.load_model(model_fpath)
320
+ self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
321
+ self.ui.set_loading(0)
322
+
323
+ def init_synthesizer(self):
324
+ model_fpath = self.ui.current_synthesizer_fpath
325
+
326
+ self.ui.log("Loading the synthesizer %s... " % model_fpath)
327
+ self.ui.set_loading(1)
328
+ start = timer()
329
+ self.synthesizer = Synthesizer(model_fpath)
330
+ self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
331
+ self.ui.set_loading(0)
332
+
333
+ def init_vocoder(self):
334
+ model_fpath = self.ui.current_vocoder_fpath
335
+ # Case of Griffin-lim
336
+ if model_fpath is None:
337
+ return
338
+
339
+ self.ui.log("Loading the vocoder %s... " % model_fpath)
340
+ self.ui.set_loading(1)
341
+ start = timer()
342
+ vocoder.load_model(model_fpath)
343
+ self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
344
+ self.ui.set_loading(0)
345
+
346
+ def update_seed_textbox(self):
347
+ self.ui.update_seed_textbox()
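`Toolbox` above is the GUI controller; a minimal launcher would look something like this (the script name and folder paths are placeholders, not part of this commit):

    # hypothetical launcher script
    from pathlib import Path
    from toolbox import Toolbox

    if __name__ == "__main__":
        Toolbox(datasets_root=Path("datasets"),    # folder containing the recognized datasets above
                models_dir=Path("saved_models"),   # folder with encoder/synthesizer/vocoder checkpoints
                seed=None)                         # optional fixed seed for reproducible generation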
toolbox/ui.py ADDED
@@ -0,0 +1,607 @@
1
+ import sys
2
+ from pathlib import Path
3
+ from time import sleep
4
+ from typing import List, Set
5
+ from warnings import filterwarnings, warn
6
+
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import sounddevice as sd
10
+ import soundfile as sf
11
+ import umap
12
+ from PyQt5.QtCore import Qt, QStringListModel
13
+ from PyQt5.QtWidgets import *
14
+ from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
15
+
16
+ from encoder.inference import plot_embedding_as_heatmap
17
+ from toolbox.utterance import Utterance
18
+
19
+ filterwarnings("ignore")
20
+
21
+
22
+ colormap = np.array([
23
+ [0, 127, 70],
24
+ [255, 0, 0],
25
+ [255, 217, 38],
26
+ [0, 135, 255],
27
+ [165, 0, 165],
28
+ [255, 167, 255],
29
+ [97, 142, 151],
30
+ [0, 255, 255],
31
+ [255, 96, 38],
32
+ [142, 76, 0],
33
+ [33, 0, 127],
34
+ [0, 0, 0],
35
+ [183, 183, 183],
36
+ [76, 255, 0],
37
+ ], dtype=np.float) / 255
38
+
39
+ default_text = \
40
+ "Welcome to the toolbox! To begin, load an utterance from your datasets or record one " \
41
+ "yourself.\nOnce its embedding has been created, you can synthesize any text written here.\n" \
42
+ "The synthesizer expects to generate " \
43
+ "outputs that are somewhere between 5 and 12 seconds.\nTo mark breaks, write a new line. " \
44
+ "Each line will be treated separately.\nThen, they are joined together to make the final " \
45
+ "spectrogram. Use the vocoder to generate audio.\nThe vocoder generates almost in constant " \
46
+ "time, so it will be more time efficient for longer inputs like this one.\nOn the left you " \
47
+ "have the embedding projections. Load or record more utterances to see them.\nIf you have " \
48
+ "at least 2 or 3 utterances from a same speaker, a cluster should form.\nSynthesized " \
49
+ "utterances are of the same color as the speaker whose voice was used, but they're " \
50
+ "represented with a cross."
51
+
52
+
53
+ class UI(QDialog):
54
+ min_umap_points = 4
55
+ max_log_lines = 5
56
+ max_saved_utterances = 20
57
+
58
+ def draw_utterance(self, utterance: Utterance, which):
59
+ self.draw_spec(utterance.spec, which)
60
+ self.draw_embed(utterance.embed, utterance.name, which)
61
+
62
+ def draw_embed(self, embed, name, which):
63
+ embed_ax, _ = self.current_ax if which == "current" else self.gen_ax
64
+ embed_ax.figure.suptitle("" if embed is None else name)
65
+
66
+ ## Embedding
67
+ # Clear the plot
68
+ if len(embed_ax.images) > 0:
69
+ embed_ax.images[0].colorbar.remove()
70
+ embed_ax.clear()
71
+
72
+ # Draw the embed
73
+ if embed is not None:
74
+ plot_embedding_as_heatmap(embed, embed_ax)
75
+ embed_ax.set_title("embedding")
76
+ embed_ax.set_aspect("equal", "datalim")
77
+ embed_ax.set_xticks([])
78
+ embed_ax.set_yticks([])
79
+ embed_ax.figure.canvas.draw()
80
+
81
+ def draw_spec(self, spec, which):
82
+ _, spec_ax = self.current_ax if which == "current" else self.gen_ax
83
+
84
+ ## Spectrogram
85
+ # Draw the spectrogram
86
+ spec_ax.clear()
87
+ if spec is not None:
88
+ spec_ax.imshow(spec, aspect="auto", interpolation="none")
89
+ spec_ax.set_title("mel spectrogram")
90
+
91
+ spec_ax.set_xticks([])
92
+ spec_ax.set_yticks([])
93
+ spec_ax.figure.canvas.draw()
94
+ if which != "current":
95
+ self.vocode_button.setDisabled(spec is None)
96
+
97
+ def draw_umap_projections(self, utterances: Set[Utterance]):
98
+ self.umap_ax.clear()
99
+
100
+ speakers = np.unique([u.speaker_name for u in utterances])
101
+ colors = {speaker_name: colormap[i] for i, speaker_name in enumerate(speakers)}
102
+ embeds = [u.embed for u in utterances]
103
+
104
+ # Display a message if there aren't enough points
105
+ if len(utterances) < self.min_umap_points:
106
+ self.umap_ax.text(.5, .5, "Add %d more points to\ngenerate the projections" %
107
+ (self.min_umap_points - len(utterances)),
108
+ horizontalalignment='center', fontsize=15)
109
+ self.umap_ax.set_title("")
110
+
111
+ # Compute the projections
112
+ else:
113
+ if not self.umap_hot:
114
+ self.log(
115
+ "Drawing UMAP projections for the first time, this will take a few seconds.")
116
+ self.umap_hot = True
117
+
118
+ reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine")
119
+ projections = reducer.fit_transform(embeds)
120
+
121
+ speakers_done = set()
122
+ for projection, utterance in zip(projections, utterances):
123
+ color = colors[utterance.speaker_name]
124
+ mark = "x" if "_gen_" in utterance.name else "o"
125
+ label = None if utterance.speaker_name in speakers_done else utterance.speaker_name
126
+ speakers_done.add(utterance.speaker_name)
127
+ self.umap_ax.scatter(projection[0], projection[1], c=[color], marker=mark,
128
+ label=label)
129
+ self.umap_ax.legend(prop={'size': 10})
130
+
131
+ # Draw the plot
132
+ self.umap_ax.set_aspect("equal", "datalim")
133
+ self.umap_ax.set_xticks([])
134
+ self.umap_ax.set_yticks([])
135
+ self.umap_ax.figure.canvas.draw()
136
+
137
+ def save_audio_file(self, wav, sample_rate):
138
+ dialog = QFileDialog()
139
+ dialog.setDefaultSuffix(".wav")
140
+ fpath, _ = dialog.getSaveFileName(
141
+ parent=self,
142
+ caption="Select a path to save the audio file",
143
+ filter="Audio Files (*.flac *.wav)"
144
+ )
145
+ if fpath:
146
+ #Default format is wav
147
+ if Path(fpath).suffix == "":
148
+ fpath += ".wav"
149
+ sf.write(fpath, wav, sample_rate)
150
+
151
+ def setup_audio_devices(self, sample_rate):
152
+ input_devices = []
153
+ output_devices = []
154
+ for device in sd.query_devices():
155
+ # Check if valid input
156
+ try:
157
+ sd.check_input_settings(device=device["name"], samplerate=sample_rate)
158
+ input_devices.append(device["name"])
159
+ except:
160
+ pass
161
+
162
+ # Check if valid output
163
+ try:
164
+ sd.check_output_settings(device=device["name"], samplerate=sample_rate)
165
+ output_devices.append(device["name"])
166
+ except Exception as e:
167
+ # Log a warning only if the device is not an input
168
+ if not device["name"] in input_devices:
169
+ warn("Unsupported output device %s for the sample rate: %d \nError: %s" % (device["name"], sample_rate, str(e)))
170
+
171
+ if len(input_devices) == 0:
172
+ self.log("No audio input device detected. Recording may not work.")
173
+ self.audio_in_device = None
174
+ else:
175
+ self.audio_in_device = input_devices[0]
176
+
177
+ if len(output_devices) == 0:
178
+ self.log("No supported output audio devices were found! Audio output may not work.")
179
+ self.audio_out_devices_cb.addItems(["None"])
180
+ self.audio_out_devices_cb.setDisabled(True)
181
+ else:
182
+ self.audio_out_devices_cb.clear()
183
+ self.audio_out_devices_cb.addItems(output_devices)
184
+ self.audio_out_devices_cb.currentTextChanged.connect(self.set_audio_device)
185
+
186
+ self.set_audio_device()
187
+
188
+ def set_audio_device(self):
189
+
190
+ output_device = self.audio_out_devices_cb.currentText()
191
+ if output_device == "None":
192
+ output_device = None
193
+
194
+ # If None, sounddevice queries portaudio
195
+ sd.default.device = (self.audio_in_device, output_device)
196
+
197
+ def play(self, wav, sample_rate):
198
+ try:
199
+ sd.stop()
200
+ sd.play(wav, sample_rate)
201
+ except Exception as e:
202
+ print(e)
203
+ self.log("Error in audio playback. Try selecting a different audio output device.")
204
+ self.log("Your device must be connected before you start the toolbox.")
205
+
206
+ def stop(self):
207
+ sd.stop()
208
+
209
+ def record_one(self, sample_rate, duration):
210
+ self.record_button.setText("Recording...")
211
+ self.record_button.setDisabled(True)
212
+
213
+ self.log("Recording %d seconds of audio" % duration)
214
+ sd.stop()
215
+ try:
216
+ wav = sd.rec(duration * sample_rate, sample_rate, 1)
217
+ except Exception as e:
218
+ print(e)
219
+ self.log("Could not record anything. Is your recording device enabled?")
220
+ self.log("Your device must be connected before you start the toolbox.")
221
+ return None
222
+
223
+ for i in np.arange(0, duration, 0.1):
224
+ self.set_loading(i, duration)
225
+ sleep(0.1)
226
+ self.set_loading(duration, duration)
227
+ sd.wait()
228
+
229
+ self.log("Done recording.")
230
+ self.record_button.setText("Record")
231
+ self.record_button.setDisabled(False)
232
+
233
+ return wav.squeeze()
234
+
235
+ @property
236
+ def current_dataset_name(self):
237
+ return self.dataset_box.currentText()
238
+
239
+ @property
240
+ def current_speaker_name(self):
241
+ return self.speaker_box.currentText()
242
+
243
+ @property
244
+ def current_utterance_name(self):
245
+ return self.utterance_box.currentText()
246
+
247
+ def browse_file(self):
248
+ fpath = QFileDialog().getOpenFileName(
249
+ parent=self,
250
+ caption="Select an audio file",
251
+ filter="Audio Files (*.mp3 *.flac *.wav *.m4a)"
252
+ )
253
+ return Path(fpath[0]) if fpath[0] != "" else ""
254
+
255
+ @staticmethod
256
+ def repopulate_box(box, items, random=False):
257
+ """
258
+ Resets a box and adds a list of items. Pass a list of (item, data) pairs instead to join
259
+ data to the items
260
+ """
261
+ box.blockSignals(True)
262
+ box.clear()
263
+ for item in items:
264
+ item = list(item) if isinstance(item, tuple) else [item]
265
+ box.addItem(str(item[0]), *item[1:])
266
+ if len(items) > 0:
267
+ box.setCurrentIndex(np.random.randint(len(items)) if random else 0)
268
+ box.setDisabled(len(items) == 0)
269
+ box.blockSignals(False)
270
+
271
+ def populate_browser(self, datasets_root: Path, recognized_datasets: List, level: int,
272
+ random=True):
273
+ # Select a random dataset
274
+ if level <= 0:
275
+ if datasets_root is not None:
276
+ datasets = [datasets_root.joinpath(d) for d in recognized_datasets]
277
+ datasets = [d.relative_to(datasets_root) for d in datasets if d.exists()]
278
+ self.browser_load_button.setDisabled(len(datasets) == 0)
279
+ if datasets_root is None or len(datasets) == 0:
280
+ msg = "Warning: you d" + ("id not pass a root directory for datasets as argument" \
281
+ if datasets_root is None else "o not have any of the recognized datasets" \
282
+ " in %s" % datasets_root)
283
+ self.log(msg)
284
+ msg += ".\nThe recognized datasets are:\n\t%s\nFeel free to add your own. You " \
285
+ "can still use the toolbox by recording samples yourself." % \
286
+ ("\n\t".join(recognized_datasets))
287
+ print(msg, file=sys.stderr)
288
+
289
+ self.random_utterance_button.setDisabled(True)
290
+ self.random_speaker_button.setDisabled(True)
291
+ self.random_dataset_button.setDisabled(True)
292
+ self.utterance_box.setDisabled(True)
293
+ self.speaker_box.setDisabled(True)
294
+ self.dataset_box.setDisabled(True)
295
+ self.browser_load_button.setDisabled(True)
296
+ self.auto_next_checkbox.setDisabled(True)
297
+ return
298
+ self.repopulate_box(self.dataset_box, datasets, random)
299
+
300
+ # Select a random speaker
301
+ if level <= 1:
302
+ speakers_root = datasets_root.joinpath(self.current_dataset_name)
303
+ speaker_names = [d.stem for d in speakers_root.glob("*") if d.is_dir()]
304
+ self.repopulate_box(self.speaker_box, speaker_names, random)
305
+
306
+ # Select a random utterance
307
+ if level <= 2:
308
+ utterances_root = datasets_root.joinpath(
309
+ self.current_dataset_name,
310
+ self.current_speaker_name
311
+ )
312
+ utterances = []
313
+ for extension in ['mp3', 'flac', 'wav', 'm4a']:
314
+ utterances.extend(Path(utterances_root).glob("**/*.%s" % extension))
315
+ utterances = [fpath.relative_to(utterances_root) for fpath in utterances]
316
+ self.repopulate_box(self.utterance_box, utterances, random)
317
+
318
+ def browser_select_next(self):
319
+ index = (self.utterance_box.currentIndex() + 1) % len(self.utterance_box)
320
+ self.utterance_box.setCurrentIndex(index)
321
+
322
+ @property
323
+ def current_encoder_fpath(self):
324
+ return self.encoder_box.itemData(self.encoder_box.currentIndex())
325
+
326
+ @property
327
+ def current_synthesizer_fpath(self):
328
+ return self.synthesizer_box.itemData(self.synthesizer_box.currentIndex())
329
+
330
+ @property
331
+ def current_vocoder_fpath(self):
332
+ return self.vocoder_box.itemData(self.vocoder_box.currentIndex())
333
+
334
+ def populate_models(self, models_dir: Path):
335
+ # Encoder
336
+ encoder_fpaths = list(models_dir.glob("*/encoder.pt"))
337
+ if len(encoder_fpaths) == 0:
338
+ raise Exception("No encoder models found in %s" % models_dir)
339
+ self.repopulate_box(self.encoder_box, [(f.parent.name, f) for f in encoder_fpaths])
340
+
341
+ # Synthesizer
342
+ synthesizer_fpaths = list(models_dir.glob("*/synthesizer.pt"))
343
+ if len(synthesizer_fpaths) == 0:
344
+ raise Exception("No synthesizer models found in %s" % models_dir)
345
+ self.repopulate_box(self.synthesizer_box, [(f.parent.name, f) for f in synthesizer_fpaths])
346
+
347
+ # Vocoder
348
+ vocoder_fpaths = list(models_dir.glob("*/vocoder.pt"))
349
+ vocoder_items = [(f.parent.name, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)]
350
+ self.repopulate_box(self.vocoder_box, vocoder_items)
351
+
352
+ @property
353
+ def selected_utterance(self):
354
+ return self.utterance_history.itemData(self.utterance_history.currentIndex())
355
+
356
+ def register_utterance(self, utterance: Utterance):
357
+ self.utterance_history.blockSignals(True)
358
+ self.utterance_history.insertItem(0, utterance.name, utterance)
359
+ self.utterance_history.setCurrentIndex(0)
360
+ self.utterance_history.blockSignals(False)
361
+
362
+ if len(self.utterance_history) > self.max_saved_utterances:
363
+ self.utterance_history.removeItem(self.max_saved_utterances)
364
+
365
+ self.play_button.setDisabled(False)
366
+ self.generate_button.setDisabled(False)
367
+ self.synthesize_button.setDisabled(False)
368
+
369
+ def log(self, line, mode="newline"):
370
+ if mode == "newline":
371
+ self.logs.append(line)
372
+ if len(self.logs) > self.max_log_lines:
373
+ del self.logs[0]
374
+ elif mode == "append":
375
+ self.logs[-1] += line
376
+ elif mode == "overwrite":
377
+ self.logs[-1] = line
378
+ log_text = '\n'.join(self.logs)
379
+
380
+ self.log_window.setText(log_text)
381
+ self.app.processEvents()
382
+
383
+ def set_loading(self, value, maximum=1):
384
+ self.loading_bar.setValue(value * 100)
385
+ self.loading_bar.setMaximum(maximum * 100)
386
+ self.loading_bar.setTextVisible(value != 0)
387
+ self.app.processEvents()
388
+
389
+ def populate_gen_options(self, seed, trim_silences):
390
+ if seed is not None:
391
+ self.random_seed_checkbox.setChecked(True)
392
+ self.seed_textbox.setText(str(seed))
393
+ self.seed_textbox.setEnabled(True)
394
+ else:
395
+ self.random_seed_checkbox.setChecked(False)
396
+ self.seed_textbox.setText(str(0))
397
+ self.seed_textbox.setEnabled(False)
398
+
399
+ if not trim_silences:
400
+ self.trim_silences_checkbox.setChecked(False)
401
+ self.trim_silences_checkbox.setDisabled(True)
402
+
403
+ def update_seed_textbox(self):
404
+ if self.random_seed_checkbox.isChecked():
405
+ self.seed_textbox.setEnabled(True)
406
+ else:
407
+ self.seed_textbox.setEnabled(False)
408
+
409
+ def reset_interface(self):
410
+ self.draw_embed(None, None, "current")
411
+ self.draw_embed(None, None, "generated")
412
+ self.draw_spec(None, "current")
413
+ self.draw_spec(None, "generated")
414
+ self.draw_umap_projections(set())
415
+ self.set_loading(0)
416
+ self.play_button.setDisabled(True)
417
+ self.generate_button.setDisabled(True)
418
+ self.synthesize_button.setDisabled(True)
419
+ self.vocode_button.setDisabled(True)
420
+ self.replay_wav_button.setDisabled(True)
421
+ self.export_wav_button.setDisabled(True)
422
+ [self.log("") for _ in range(self.max_log_lines)]
423
+
424
+ def __init__(self):
425
+ ## Initialize the application
426
+ self.app = QApplication(sys.argv)
427
+ super().__init__(None)
428
+ self.setWindowTitle("SV2TTS toolbox")
429
+
430
+
431
+ ## Main layouts
432
+ # Root
433
+ root_layout = QGridLayout()
434
+ self.setLayout(root_layout)
435
+
436
+ # Browser
437
+ browser_layout = QGridLayout()
438
+ root_layout.addLayout(browser_layout, 0, 0, 1, 2)
439
+
440
+ # Generation
441
+ gen_layout = QVBoxLayout()
442
+ root_layout.addLayout(gen_layout, 0, 2, 1, 2)
443
+
444
+ # Projections
445
+ self.projections_layout = QVBoxLayout()
446
+ root_layout.addLayout(self.projections_layout, 1, 0, 1, 1)
447
+
448
+ # Visualizations
449
+ vis_layout = QVBoxLayout()
450
+ root_layout.addLayout(vis_layout, 1, 1, 1, 3)
451
+
452
+
453
+ ## Projections
454
+ # UMap
455
+ fig, self.umap_ax = plt.subplots(figsize=(3, 3), facecolor="#F0F0F0")
456
+ fig.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98)
457
+ self.projections_layout.addWidget(FigureCanvas(fig))
458
+ self.umap_hot = False
459
+ self.clear_button = QPushButton("Clear")
460
+ self.projections_layout.addWidget(self.clear_button)
461
+
462
+
463
+ ## Browser
464
+ # Dataset, speaker and utterance selection
465
+ i = 0
466
+ self.dataset_box = QComboBox()
467
+ browser_layout.addWidget(QLabel("<b>Dataset</b>"), i, 0)
468
+ browser_layout.addWidget(self.dataset_box, i + 1, 0)
469
+ self.speaker_box = QComboBox()
470
+ browser_layout.addWidget(QLabel("<b>Speaker</b>"), i, 1)
471
+ browser_layout.addWidget(self.speaker_box, i + 1, 1)
472
+ self.utterance_box = QComboBox()
473
+ browser_layout.addWidget(QLabel("<b>Utterance</b>"), i, 2)
474
+ browser_layout.addWidget(self.utterance_box, i + 1, 2)
475
+ self.browser_load_button = QPushButton("Load")
476
+ browser_layout.addWidget(self.browser_load_button, i + 1, 3)
477
+ i += 2
478
+
479
+ # Random buttons
480
+ self.random_dataset_button = QPushButton("Random")
481
+ browser_layout.addWidget(self.random_dataset_button, i, 0)
482
+ self.random_speaker_button = QPushButton("Random")
483
+ browser_layout.addWidget(self.random_speaker_button, i, 1)
484
+ self.random_utterance_button = QPushButton("Random")
485
+ browser_layout.addWidget(self.random_utterance_button, i, 2)
486
+ self.auto_next_checkbox = QCheckBox("Auto select next")
487
+ self.auto_next_checkbox.setChecked(True)
488
+ browser_layout.addWidget(self.auto_next_checkbox, i, 3)
489
+ i += 1
490
+
491
+ # Utterance box
492
+ browser_layout.addWidget(QLabel("<b>Use embedding from:</b>"), i, 0)
493
+ self.utterance_history = QComboBox()
494
+ browser_layout.addWidget(self.utterance_history, i, 1, 1, 3)
495
+ i += 1
496
+
497
+ # Random & next utterance buttons
498
+ self.browser_browse_button = QPushButton("Browse")
499
+ browser_layout.addWidget(self.browser_browse_button, i, 0)
500
+ self.record_button = QPushButton("Record")
501
+ browser_layout.addWidget(self.record_button, i, 1)
502
+ self.play_button = QPushButton("Play")
503
+ browser_layout.addWidget(self.play_button, i, 2)
504
+ self.stop_button = QPushButton("Stop")
505
+ browser_layout.addWidget(self.stop_button, i, 3)
506
+ i += 1
507
+
508
+
509
+ # Model and audio output selection
510
+ self.encoder_box = QComboBox()
511
+ browser_layout.addWidget(QLabel("<b>Encoder</b>"), i, 0)
512
+ browser_layout.addWidget(self.encoder_box, i + 1, 0)
513
+ self.synthesizer_box = QComboBox()
514
+ browser_layout.addWidget(QLabel("<b>Synthesizer</b>"), i, 1)
515
+ browser_layout.addWidget(self.synthesizer_box, i + 1, 1)
516
+ self.vocoder_box = QComboBox()
517
+ browser_layout.addWidget(QLabel("<b>Vocoder</b>"), i, 2)
518
+ browser_layout.addWidget(self.vocoder_box, i + 1, 2)
519
+
520
+ self.audio_out_devices_cb=QComboBox()
521
+ browser_layout.addWidget(QLabel("<b>Audio Output</b>"), i, 3)
522
+ browser_layout.addWidget(self.audio_out_devices_cb, i + 1, 3)
523
+ i += 2
524
+
525
+ #Replay & Save Audio
526
+ browser_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
527
+ self.waves_cb = QComboBox()
528
+ self.waves_cb_model = QStringListModel()
529
+ self.waves_cb.setModel(self.waves_cb_model)
530
+ self.waves_cb.setToolTip("Select one of the last generated waves in this section for replaying or exporting")
531
+ browser_layout.addWidget(self.waves_cb, i, 1)
532
+ self.replay_wav_button = QPushButton("Replay")
533
+ self.replay_wav_button.setToolTip("Replay last generated vocoder")
534
+ browser_layout.addWidget(self.replay_wav_button, i, 2)
535
+ self.export_wav_button = QPushButton("Export")
536
+ self.export_wav_button.setToolTip("Save last generated vocoder audio in filesystem as a wav file")
537
+ browser_layout.addWidget(self.export_wav_button, i, 3)
538
+ i += 1
539
+
540
+
541
+ ## Embed & spectrograms
542
+ vis_layout.addStretch()
543
+
544
+ gridspec_kw = {"width_ratios": [1, 4]}
545
+ fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
546
+ gridspec_kw=gridspec_kw)
547
+ fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
548
+ vis_layout.addWidget(FigureCanvas(fig))
549
+
550
+ fig, self.gen_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
551
+ gridspec_kw=gridspec_kw)
552
+ fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
553
+ vis_layout.addWidget(FigureCanvas(fig))
554
+
555
+ for ax in self.current_ax.tolist() + self.gen_ax.tolist():
556
+ ax.set_facecolor("#F0F0F0")
557
+ for side in ["top", "right", "bottom", "left"]:
558
+ ax.spines[side].set_visible(False)
559
+
560
+
561
+ ## Generation
562
+ self.text_prompt = QPlainTextEdit(default_text)
563
+ gen_layout.addWidget(self.text_prompt, stretch=1)
564
+
565
+ self.generate_button = QPushButton("Synthesize and vocode")
566
+ gen_layout.addWidget(self.generate_button)
567
+
568
+ layout = QHBoxLayout()
569
+ self.synthesize_button = QPushButton("Synthesize only")
570
+ layout.addWidget(self.synthesize_button)
571
+ self.vocode_button = QPushButton("Vocode only")
572
+ layout.addWidget(self.vocode_button)
573
+ gen_layout.addLayout(layout)
574
+
575
+ layout_seed = QGridLayout()
576
+ self.random_seed_checkbox = QCheckBox("Random seed:")
577
+ self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.")
578
+ layout_seed.addWidget(self.random_seed_checkbox, 0, 0)
579
+ self.seed_textbox = QLineEdit()
580
+ self.seed_textbox.setMaximumWidth(80)
581
+ layout_seed.addWidget(self.seed_textbox, 0, 1)
582
+ self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
583
+ self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
584
+ " This feature requires `webrtcvad` to be installed.")
585
+ layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
586
+ gen_layout.addLayout(layout_seed)
587
+
588
+ self.loading_bar = QProgressBar()
589
+ gen_layout.addWidget(self.loading_bar)
590
+
591
+ self.log_window = QLabel()
592
+ self.log_window.setAlignment(Qt.AlignBottom | Qt.AlignLeft)
593
+ gen_layout.addWidget(self.log_window)
594
+ self.logs = []
595
+ gen_layout.addStretch()
596
+
597
+
598
+ ## Set the size of the window and of the elements
599
+ max_size = QDesktopWidget().availableGeometry(self).size() * 0.8
600
+ self.resize(max_size)
601
+
602
+ ## Finalize the display
603
+ self.reset_interface()
604
+ self.show()
605
+
606
+ def start(self):
607
+ self.app.exec_()
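
A standalone sketch, not part of the commit, of the UMAP reduction used by `draw_umap_projections` above: once at least `min_umap_points` utterances exist, it fits UMAP with a cosine metric and roughly sqrt(N) neighbors. The 256-dimensional random embeddings below are hypothetical stand-ins for real speaker embeddings.

    # Hedged sketch: the same UMAP settings as UI.draw_umap_projections, outside the GUI.
    import numpy as np
    import umap

    embeds = [np.random.rand(256) for _ in range(10)]   # hypothetical speaker embeddings
    reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine")
    projections = reducer.fit_transform(np.array(embeds))   # shape (10, 2): one 2-D point per utterance
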
toolbox/utterance.py ADDED
@@ -0,0 +1,5 @@
+ from collections import namedtuple
+
+ Utterance = namedtuple("Utterance", "name speaker_name wav spec embed partial_embeds synth")
+ Utterance.__eq__ = lambda x, y: x.name == y.name
+ Utterance.__hash__ = lambda x: hash(x.name)
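
Because equality and hashing are overridden to look only at `name`, the toolbox's `Set[Utterance]` deduplicates by utterance name. A small sketch of that behaviour, with hypothetical names and empty payloads:

    # Hedged sketch: two Utterance records with the same name collapse to one set entry.
    u1 = Utterance("p240_demo", "p240", wav=None, spec=None, embed=None, partial_embeds=None, synth=False)
    u2 = Utterance("p240_demo", "p240", wav=None, spec=None, embed=None, partial_embeds=None, synth=True)
    assert u1 == u2
    assert len({u1, u2}) == 1
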
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (157 Bytes). View file
 
utils/__pycache__/argutils.cpython-37.pyc ADDED
Binary file (1.69 kB). View file
 
utils/__pycache__/default_models.cpython-37.pyc ADDED
Binary file (2.26 kB). View file
 
utils/argutils.py ADDED
@@ -0,0 +1,40 @@
+ from pathlib import Path
+ import numpy as np
+ import argparse
+
+ _type_priorities = [    # In decreasing order
+     Path,
+     str,
+     int,
+     float,
+     bool,
+ ]
+
+ def _priority(o):
+     p = next((i for i, t in enumerate(_type_priorities) if type(o) is t), None)
+     if p is not None:
+         return p
+     p = next((i for i, t in enumerate(_type_priorities) if isinstance(o, t)), None)
+     if p is not None:
+         return p
+     return len(_type_priorities)
+
+ def print_args(args: argparse.Namespace, parser=None):
+     args = vars(args)
+     if parser is None:
+         priorities = list(map(_priority, args.values()))
+     else:
+         all_params = [a.dest for g in parser._action_groups for a in g._group_actions]
+         priority = lambda p: all_params.index(p) if p in all_params else len(all_params)
+         priorities = list(map(priority, args.keys()))
+
+     pad = max(map(len, args.keys())) + 3
+     indices = np.lexsort((list(args.keys()), priorities))
+     items = list(args.items())
+
+     print("Arguments:")
+     for i in indices:
+         param, value = items[i]
+         print("    {0}:{1}{2}".format(param, ' ' * (pad - len(param)), value))
+     print("")
+
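
`print_args` is typically called right after `parse_args` to echo the run configuration; with a parser it keeps the declaration order, otherwise it sorts by the type priorities above. A hedged usage sketch with a hypothetical parser:

    # Hedged sketch: echo parsed arguments in declaration order.
    import argparse
    from pathlib import Path
    from utils.argutils import print_args

    parser = argparse.ArgumentParser()
    parser.add_argument("datasets_root", type=Path)         # hypothetical arguments
    parser.add_argument("--seed", type=int, default=None)
    args = parser.parse_args(["/data/voice"])
    print_args(args, parser)   # prints "Arguments:" then one aligned "name:   value" line per argument
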
utils/default_models.py ADDED
@@ -0,0 +1,56 @@
+ import urllib.request
+ from pathlib import Path
+ from threading import Thread
+ from urllib.error import HTTPError
+
+ from tqdm import tqdm
+
+
+ default_models = {
+     "encoder": ("https://drive.google.com/uc?export=download&id=1q8mEGwCkFy23KZsinbuvdKAQLqNKbYf1", 17090379),
+     "synthesizer": ("https://drive.google.com/u/0/uc?id=1EqFMIbvxffxtjiVrtykroF6_mUh-5Z3s&export=download&confirm=t", 370554559),
+     "vocoder": ("https://drive.google.com/uc?export=download&id=1cf2NO6FtI0jDuy8AV3Xgn6leO6dHjIgu", 53845290),
+ }
+
+
+ class DownloadProgressBar(tqdm):
+     def update_to(self, b=1, bsize=1, tsize=None):
+         if tsize is not None:
+             self.total = tsize
+         self.update(b * bsize - self.n)
+
+
+ def download(url: str, target: Path, bar_pos=0):
+     # Ensure the directory exists
+     target.parent.mkdir(exist_ok=True, parents=True)
+
+     desc = f"Downloading {target.name}"
+     with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=desc, position=bar_pos, leave=False) as t:
+         try:
+             urllib.request.urlretrieve(url, filename=target, reporthook=t.update_to)
+         except HTTPError:
+             return
+
+
+ def ensure_default_models(models_dir: Path):
+     # Define download tasks
+     jobs = []
+     for model_name, (url, size) in default_models.items():
+         target_path = models_dir / "default" / f"{model_name}.pt"
+         if target_path.exists():
+             if target_path.stat().st_size != size:
+                 print(f"File {target_path} is not of expected size, redownloading...")
+             else:
+                 continue
+
+         thread = Thread(target=download, args=(url, target_path, len(jobs)))
+         thread.start()
+         jobs.append((thread, target_path, size))
+
+     # Run and join threads
+     for thread, target_path, size in jobs:
+         thread.join()
+
+         assert target_path.exists() and target_path.stat().st_size == size, \
+             f"Download for {target_path.name} failed. You may download models manually instead.\n" \
+             f"https://drive.google.com/drive/folders/1fU6umc5uQAVR2udZdHX-lDgXYzTyqG_j"
utils/logmmse.py ADDED
@@ -0,0 +1,247 @@
1
+ # The MIT License (MIT)
2
+ #
3
+ # Copyright (c) 2015 braindead
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+ #
23
+ #
24
+ # This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I
25
+ # simply modified the interface to meet my needs.
26
+
27
+
28
+ import numpy as np
29
+ import math
30
+ from scipy.special import expn
31
+ from collections import namedtuple
32
+
33
+ NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2")
34
+
35
+
36
+ def profile_noise(noise, sampling_rate, window_size=0):
37
+ """
38
+ Creates a profile of the noise in a given waveform.
39
+
40
+ :param noise: a waveform containing noise ONLY, as a numpy array of floats or ints.
41
+ :param sampling_rate: the sampling rate of the audio
42
+ :param window_size: the size of the window the logmmse algorithm operates on. A default value
43
+ will be picked if left as 0.
44
+ :return: a NoiseProfile object
45
+ """
46
+ noise, dtype = to_float(noise)
47
+ noise += np.finfo(np.float64).eps
48
+
49
+ if window_size == 0:
50
+ window_size = int(math.floor(0.02 * sampling_rate))
51
+
52
+ if window_size % 2 == 1:
53
+ window_size = window_size + 1
54
+
55
+ perc = 50
56
+ len1 = int(math.floor(window_size * perc / 100))
57
+ len2 = int(window_size - len1)
58
+
59
+ win = np.hanning(window_size)
60
+ win = win * len2 / np.sum(win)
61
+ n_fft = 2 * window_size
62
+
63
+ noise_mean = np.zeros(n_fft)
64
+ n_frames = len(noise) // window_size
65
+ for j in range(0, window_size * n_frames, window_size):
66
+ noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0))
67
+ noise_mu2 = (noise_mean / n_frames) ** 2
68
+
69
+ return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2)
70
+
71
+
72
+ def denoise(wav, noise_profile: NoiseProfile, eta=0.15):
73
+ """
74
+ Cleans the noise from a speech waveform given a noise profile. The waveform must have the
75
+ same sampling rate as the one used to create the noise profile.
76
+
77
+ :param wav: a speech waveform as a numpy array of floats or ints.
78
+ :param noise_profile: a NoiseProfile object that was created from a similar (or a segment of
79
+ the same) waveform.
80
+ :param eta: voice threshold for noise update. While the voice activation detection value is
81
+ below this threshold, the noise profile will be continuously updated throughout the audio.
82
+ Set to 0 to disable updating the noise profile.
83
+ :return: the clean wav as a numpy array of floats or ints of the same length.
84
+ """
85
+ wav, dtype = to_float(wav)
86
+ wav += np.finfo(np.float64).eps
87
+ p = noise_profile
88
+
89
+ nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2))
90
+ x_final = np.zeros(nframes * p.len2)
91
+
92
+ aa = 0.98
93
+ mu = 0.98
94
+ ksi_min = 10 ** (-25 / 10)
95
+
96
+ x_old = np.zeros(p.len1)
97
+ xk_prev = np.zeros(p.len1)
98
+ noise_mu2 = p.noise_mu2
99
+ for k in range(0, nframes * p.len2, p.len2):
100
+ insign = p.win * wav[k:k + p.window_size]
101
+
102
+ spec = np.fft.fft(insign, p.n_fft, axis=0)
103
+ sig = np.absolute(spec)
104
+ sig2 = sig ** 2
105
+
106
+ gammak = np.minimum(sig2 / noise_mu2, 40)
107
+
108
+ if xk_prev.all() == 0:
109
+ ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
110
+ else:
111
+ ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
112
+ ksi = np.maximum(ksi_min, ksi)
113
+
114
+ log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi)
115
+ vad_decision = np.sum(log_sigma_k) / p.window_size
116
+ if vad_decision < eta:
117
+ noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
118
+
119
+ a = ksi / (1 + ksi)
120
+ vk = a * gammak
121
+ ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
122
+ hw = a * np.exp(ei_vk)
123
+ sig = sig * hw
124
+ xk_prev = sig ** 2
125
+ xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0)
126
+ xi_w = np.real(xi_w)
127
+
128
+ x_final[k:k + p.len2] = x_old + xi_w[0:p.len1]
129
+ x_old = xi_w[p.len1:p.window_size]
130
+
131
+ output = from_float(x_final, dtype)
132
+ output = np.pad(output, (0, len(wav) - len(output)), mode="constant")
133
+ return output
134
+
135
+
136
+ ## Alternative VAD algorithm to webrctvad. It has the advantage of not requiring to install that
137
+ ## darn package and it also works for any sampling rate. Maybe I'll eventually use it instead of
138
+ ## webrctvad
139
+ # def vad(wav, sampling_rate, eta=0.15, window_size=0):
140
+ # """
141
+ # TODO: fix doc
142
+ # Creates a profile of the noise in a given waveform.
143
+ #
144
+ # :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints.
145
+ # :param sampling_rate: the sampling rate of the audio
146
+ # :param window_size: the size of the window the logmmse algorithm operates on. A default value
147
+ # will be picked if left as 0.
148
+ # :param eta: voice threshold for noise update. While the voice activation detection value is
149
+ # below this threshold, the noise profile will be continuously updated throughout the audio.
150
+ # Set to 0 to disable updating the noise profile.
151
+ # """
152
+ # wav, dtype = to_float(wav)
153
+ # wav += np.finfo(np.float64).eps
154
+ #
155
+ # if window_size == 0:
156
+ # window_size = int(math.floor(0.02 * sampling_rate))
157
+ #
158
+ # if window_size % 2 == 1:
159
+ # window_size = window_size + 1
160
+ #
161
+ # perc = 50
162
+ # len1 = int(math.floor(window_size * perc / 100))
163
+ # len2 = int(window_size - len1)
164
+ #
165
+ # win = np.hanning(window_size)
166
+ # win = win * len2 / np.sum(win)
167
+ # n_fft = 2 * window_size
168
+ #
169
+ # wav_mean = np.zeros(n_fft)
170
+ # n_frames = len(wav) // window_size
171
+ # for j in range(0, window_size * n_frames, window_size):
172
+ # wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0))
173
+ # noise_mu2 = (wav_mean / n_frames) ** 2
174
+ #
175
+ # wav, dtype = to_float(wav)
176
+ # wav += np.finfo(np.float64).eps
177
+ #
178
+ # nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2))
179
+ # vad = np.zeros(nframes * len2, dtype=np.bool)
180
+ #
181
+ # aa = 0.98
182
+ # mu = 0.98
183
+ # ksi_min = 10 ** (-25 / 10)
184
+ #
185
+ # xk_prev = np.zeros(len1)
186
+ # noise_mu2 = noise_mu2
187
+ # for k in range(0, nframes * len2, len2):
188
+ # insign = win * wav[k:k + window_size]
189
+ #
190
+ # spec = np.fft.fft(insign, n_fft, axis=0)
191
+ # sig = np.absolute(spec)
192
+ # sig2 = sig ** 2
193
+ #
194
+ # gammak = np.minimum(sig2 / noise_mu2, 40)
195
+ #
196
+ # if xk_prev.all() == 0:
197
+ # ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
198
+ # else:
199
+ # ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
200
+ # ksi = np.maximum(ksi_min, ksi)
201
+ #
202
+ # log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi)
203
+ # vad_decision = np.sum(log_sigma_k) / window_size
204
+ # if vad_decision < eta:
205
+ # noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
206
+ # print(vad_decision)
207
+ #
208
+ # a = ksi / (1 + ksi)
209
+ # vk = a * gammak
210
+ # ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
211
+ # hw = a * np.exp(ei_vk)
212
+ # sig = sig * hw
213
+ # xk_prev = sig ** 2
214
+ #
215
+ # vad[k:k + len2] = vad_decision >= eta
216
+ #
217
+ # vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant")
218
+ # return vad
219
+
220
+
221
+ def to_float(_input):
222
+ if _input.dtype == np.float64:
223
+ return _input, _input.dtype
224
+ elif _input.dtype == np.float32:
225
+ return _input.astype(np.float64), _input.dtype
226
+ elif _input.dtype == np.uint8:
227
+ return (_input - 128) / 128., _input.dtype
228
+ elif _input.dtype == np.int16:
229
+ return _input / 32768., _input.dtype
230
+ elif _input.dtype == np.int32:
231
+ return _input / 2147483648., _input.dtype
232
+ raise ValueError('Unsupported wave file format')
233
+
234
+
235
+ def from_float(_input, dtype):
236
+ if dtype == np.float64:
237
+ return _input, np.float64
238
+ elif dtype == np.float32:
239
+ return _input.astype(np.float32)
240
+ elif dtype == np.uint8:
241
+ return ((_input * 128) + 128).astype(np.uint8)
242
+ elif dtype == np.int16:
243
+ return (_input * 32768).astype(np.int16)
244
+ elif dtype == np.int32:
245
+ print(_input)
246
+ return (_input * 2147483648).astype(np.int32)
247
+ raise ValueError('Unsupported wave file format')
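
Per the docstrings above, the intended use of this module is to profile a noise-only stretch of audio and then denoise the full waveform at the same sampling rate. A hedged sketch with synthetic float32 audio standing in for a real recording:

    # Hedged sketch: profile ~1 s believed to be noise only, then clean the whole waveform.
    import numpy as np
    from utils.logmmse import profile_noise, denoise

    sampling_rate = 16000
    wav = np.random.randn(5 * sampling_rate).astype(np.float32)   # hypothetical noisy recording
    noise_profile = profile_noise(wav[:sampling_rate], sampling_rate)
    clean = denoise(wav, noise_profile)   # same length and dtype as `wav`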