akhaliq3 commited on
Commit
24829a1
1 Parent(s): f01b5b9

spaces demo

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. 8230_00000.mp3 +0 -0
  2. demo_cli.py +225 -0
  3. demo_toolbox.py +43 -0
  4. encoder/__init__.py +0 -0
  5. encoder/audio.py +117 -0
  6. encoder/config.py +45 -0
  7. encoder/data_objects/__init__.py +2 -0
  8. encoder/data_objects/random_cycler.py +37 -0
  9. encoder/data_objects/speaker.py +40 -0
  10. encoder/data_objects/speaker_batch.py +12 -0
  11. encoder/data_objects/speaker_verification_dataset.py +56 -0
  12. encoder/data_objects/utterance.py +26 -0
  13. encoder/inference.py +178 -0
  14. encoder/model.py +135 -0
  15. encoder/params_data.py +29 -0
  16. encoder/params_model.py +11 -0
  17. encoder/preprocess.py +175 -0
  18. encoder/train.py +123 -0
  19. encoder/visualizations.py +178 -0
  20. encoder_preprocess.py +70 -0
  21. encoder_train.py +47 -0
  22. requirements.txt +16 -0
  23. synthesizer/LICENSE.txt +24 -0
  24. synthesizer/__init__.py +1 -0
  25. synthesizer/audio.py +206 -0
  26. synthesizer/hparams.py +92 -0
  27. synthesizer/inference.py +171 -0
  28. synthesizer/models/tacotron.py +519 -0
  29. synthesizer/preprocess.py +259 -0
  30. synthesizer/synthesize.py +97 -0
  31. synthesizer/synthesizer_dataset.py +92 -0
  32. synthesizer/train.py +269 -0
  33. synthesizer/utils/__init__.py +45 -0
  34. synthesizer/utils/_cmudict.py +62 -0
  35. synthesizer/utils/cleaners.py +88 -0
  36. synthesizer/utils/numbers.py +68 -0
  37. synthesizer/utils/plot.py +76 -0
  38. synthesizer/utils/symbols.py +17 -0
  39. synthesizer/utils/text.py +74 -0
  40. synthesizer_preprocess_audio.py +59 -0
  41. synthesizer_preprocess_embeds.py +25 -0
  42. synthesizer_train.py +35 -0
  43. toolbox/__init__.py +357 -0
  44. toolbox/ui.py +611 -0
  45. toolbox/utterance.py +5 -0
  46. utils/__init__.py +0 -0
  47. utils/argutils.py +40 -0
  48. utils/logmmse.py +247 -0
  49. utils/modelutils.py +17 -0
  50. utils/profiler.py +45 -0
8230_00000.mp3 ADDED
Binary file (16.1 kB). View file
demo_cli.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.params_model import model_embedding_size as speaker_embedding_size
2
+ from utils.argutils import print_args
3
+ from utils.modelutils import check_model_paths
4
+ from synthesizer.inference import Synthesizer
5
+ from encoder import inference as encoder
6
+ from vocoder import inference as vocoder
7
+ from pathlib import Path
8
+ import numpy as np
9
+ import soundfile as sf
10
+ import librosa
11
+ import argparse
12
+ import torch
13
+ import sys
14
+ import os
15
+ from audioread.exceptions import NoBackendError
16
+
17
+ if __name__ == '__main__':
18
+ ## Info & args
19
+ parser = argparse.ArgumentParser(
20
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
21
+ )
22
+ parser.add_argument("-e", "--enc_model_fpath", type=Path,
23
+ default="encoder/saved_models/pretrained.pt",
24
+ help="Path to a saved encoder")
25
+ parser.add_argument("-s", "--syn_model_fpath", type=Path,
26
+ default="synthesizer/saved_models/pretrained/pretrained.pt",
27
+ help="Path to a saved synthesizer")
28
+ parser.add_argument("-v", "--voc_model_fpath", type=Path,
29
+ default="vocoder/saved_models/pretrained/pretrained.pt",
30
+ help="Path to a saved vocoder")
31
+ parser.add_argument("--cpu", action="store_true", help=\
32
+ "If True, processing is done on CPU, even when a GPU is available.")
33
+ parser.add_argument("--no_sound", action="store_true", help=\
34
+ "If True, audio won't be played.")
35
+ parser.add_argument("--seed", type=int, default=None, help=\
36
+ "Optional random number seed value to make toolbox deterministic.")
37
+ parser.add_argument("--no_mp3_support", action="store_true", help=\
38
+ "If True, disallows loading mp3 files to prevent audioread errors when ffmpeg is not installed.")
39
+ args = parser.parse_args()
40
+ print_args(args, parser)
41
+ if not args.no_sound:
42
+ import sounddevice as sd
43
+
44
+ if args.cpu:
45
+ # Hide GPUs from Pytorch to force CPU processing
46
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
47
+
48
+ if not args.no_mp3_support:
49
+ try:
50
+ librosa.load("samples/1320_00000.mp3")
51
+ except NoBackendError:
52
+ print("Librosa will be unable to open mp3 files if additional software is not installed.\n"
53
+ "Please install ffmpeg or add the '--no_mp3_support' option to proceed without support for mp3 files.")
54
+ exit(-1)
55
+
56
+ print("Running a test of your configuration...\n")
57
+
58
+ if torch.cuda.is_available():
59
+ device_id = torch.cuda.current_device()
60
+ gpu_properties = torch.cuda.get_device_properties(device_id)
61
+ ## Print some environment information (for debugging purposes)
62
+ print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
63
+ "%.1fGb total memory.\n" %
64
+ (torch.cuda.device_count(),
65
+ device_id,
66
+ gpu_properties.name,
67
+ gpu_properties.major,
68
+ gpu_properties.minor,
69
+ gpu_properties.total_memory / 1e9))
70
+ else:
71
+ print("Using CPU for inference.\n")
72
+
73
+ ## Remind the user to download pretrained models if needed
74
+ check_model_paths(encoder_path=args.enc_model_fpath,
75
+ synthesizer_path=args.syn_model_fpath,
76
+ vocoder_path=args.voc_model_fpath)
77
+
78
+ ## Load the models one by one.
79
+ print("Preparing the encoder, the synthesizer and the vocoder...")
80
+ encoder.load_model(args.enc_model_fpath)
81
+ synthesizer = Synthesizer(args.syn_model_fpath)
82
+ vocoder.load_model(args.voc_model_fpath)
83
+
84
+
85
+ ## Run a test
86
+ print("Testing your configuration with small inputs.")
87
+ # Forward an audio waveform of zeroes that lasts 1 second. Notice how we can get the encoder's
88
+ # sampling rate, which may differ.
89
+ # If you're unfamiliar with digital audio, know that it is encoded as an array of floats
90
+ # (or sometimes integers, but mostly floats in this projects) ranging from -1 to 1.
91
+ # The sampling rate is the number of values (samples) recorded per second, it is set to
92
+ # 16000 for the encoder. Creating an array of length <sampling_rate> will always correspond
93
+ # to an audio of 1 second.
94
+ print("\tTesting the encoder...")
95
+ encoder.embed_utterance(np.zeros(encoder.sampling_rate))
96
+
97
+ # Create a dummy embedding. You would normally use the embedding that encoder.embed_utterance
98
+ # returns, but here we're going to make one ourselves just for the sake of showing that it's
99
+ # possible.
100
+ embed = np.random.rand(speaker_embedding_size)
101
+ # Embeddings are L2-normalized (this isn't important here, but if you want to make your own
102
+ # embeddings it will be).
103
+ embed /= np.linalg.norm(embed)
104
+ # The synthesizer can handle multiple inputs with batching. Let's create another embedding to
105
+ # illustrate that
106
+ embeds = [embed, np.zeros(speaker_embedding_size)]
107
+ texts = ["test 1", "test 2"]
108
+ print("\tTesting the synthesizer... (loading the model will output a lot of text)")
109
+ mels = synthesizer.synthesize_spectrograms(texts, embeds)
110
+
111
+ # The vocoder synthesizes one waveform at a time, but it's more efficient for long ones. We
112
+ # can concatenate the mel spectrograms to a single one.
113
+ mel = np.concatenate(mels, axis=1)
114
+ # The vocoder can take a callback function to display the generation. More on that later. For
115
+ # now we'll simply hide it like this:
116
+ no_action = lambda *args: None
117
+ print("\tTesting the vocoder...")
118
+ # For the sake of making this test short, we'll pass a short target length. The target length
119
+ # is the length of the wav segments that are processed in parallel. E.g. for audio sampled
120
+ # at 16000 Hertz, a target length of 8000 means that the target audio will be cut in chunks of
121
+ # 0.5 seconds which will all be generated together. The parameters here are absurdly short, and
122
+ # that has a detrimental effect on the quality of the audio. The default parameters are
123
+ # recommended in general.
124
+ vocoder.infer_waveform(mel, target=200, overlap=50, progress_callback=no_action)
125
+
126
+ print("All test passed! You can now synthesize speech.\n\n")
127
+
128
+
129
+ ## Interactive speech generation
130
+ print("This is a GUI-less example of interface to SV2TTS. The purpose of this script is to "
131
+ "show how you can interface this project easily with your own. See the source code for "
132
+ "an explanation of what is happening.\n")
133
+
134
+ print("Interactive generation loop")
135
+ num_generated = 0
136
+ while True:
137
+ try:
138
+ # Get the reference audio filepath
139
+ message = "Reference voice: enter an audio filepath of a voice to be cloned (mp3, " \
140
+ "wav, m4a, flac, ...):\n"
141
+ in_fpath = Path(input(message).replace("\"", "").replace("\'", ""))
142
+
143
+ if in_fpath.suffix.lower() == ".mp3" and args.no_mp3_support:
144
+ print("Can't Use mp3 files please try again:")
145
+ continue
146
+ ## Computing the embedding
147
+ # First, we load the wav using the function that the speaker encoder provides. This is
148
+ # important: there is preprocessing that must be applied.
149
+
150
+ # The following two methods are equivalent:
151
+ # - Directly load from the filepath:
152
+ preprocessed_wav = encoder.preprocess_wav(in_fpath)
153
+ # - If the wav is already loaded:
154
+ original_wav, sampling_rate = librosa.load(str(in_fpath))
155
+ preprocessed_wav = encoder.preprocess_wav(original_wav, sampling_rate)
156
+ print("Loaded file succesfully")
157
+
158
+ # Then we derive the embedding. There are many functions and parameters that the
159
+ # speaker encoder interfaces. These are mostly for in-depth research. You will typically
160
+ # only use this function (with its default parameters):
161
+ embed = encoder.embed_utterance(preprocessed_wav)
162
+ print("Created the embedding")
163
+
164
+
165
+ ## Generating the spectrogram
166
+ text = input("Write a sentence (+-20 words) to be synthesized:\n")
167
+
168
+ # If seed is specified, reset torch seed and force synthesizer reload
169
+ if args.seed is not None:
170
+ torch.manual_seed(args.seed)
171
+ synthesizer = Synthesizer(args.syn_model_fpath)
172
+
173
+ # The synthesizer works in batch, so you need to put your data in a list or numpy array
174
+ texts = [text]
175
+ embeds = [embed]
176
+ # If you know what the attention layer alignments are, you can retrieve them here by
177
+ # passing return_alignments=True
178
+ specs = synthesizer.synthesize_spectrograms(texts, embeds)
179
+ spec = specs[0]
180
+ print("Created the mel spectrogram")
181
+
182
+
183
+ ## Generating the waveform
184
+ print("Synthesizing the waveform:")
185
+
186
+ # If seed is specified, reset torch seed and reload vocoder
187
+ if args.seed is not None:
188
+ torch.manual_seed(args.seed)
189
+ vocoder.load_model(args.voc_model_fpath)
190
+
191
+ # Synthesizing the waveform is fairly straightforward. Remember that the longer the
192
+ # spectrogram, the more time-efficient the vocoder.
193
+ generated_wav = vocoder.infer_waveform(spec)
194
+
195
+
196
+ ## Post-generation
197
+ # There's a bug with sounddevice that makes the audio cut one second earlier, so we
198
+ # pad it.
199
+ generated_wav = np.pad(generated_wav, (0, synthesizer.sample_rate), mode="constant")
200
+
201
+ # Trim excess silences to compensate for gaps in spectrograms (issue #53)
202
+ generated_wav = encoder.preprocess_wav(generated_wav)
203
+
204
+ # Play the audio (non-blocking)
205
+ if not args.no_sound:
206
+ try:
207
+ sd.stop()
208
+ sd.play(generated_wav, synthesizer.sample_rate)
209
+ except sd.PortAudioError as e:
210
+ print("\nCaught exception: %s" % repr(e))
211
+ print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n")
212
+ except:
213
+ raise
214
+
215
+ # Save it on the disk
216
+ filename = "demo_output_%02d.wav" % num_generated
217
+ print(generated_wav.dtype)
218
+ sf.write(filename, generated_wav.astype(np.float32), synthesizer.sample_rate)
219
+ num_generated += 1
220
+ print("\nSaved output as %s\n\n" % filename)
221
+
222
+
223
+ except Exception as e:
224
+ print("Caught exception: %s" % repr(e))
225
+ print("Restarting\n")
demo_toolbox.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from toolbox import Toolbox
3
+ from utils.argutils import print_args
4
+ from utils.modelutils import check_model_paths
5
+ import argparse
6
+ import os
7
+
8
+
9
+ if __name__ == '__main__':
10
+ parser = argparse.ArgumentParser(
11
+ description="Runs the toolbox",
12
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
13
+ )
14
+
15
+ parser.add_argument("-d", "--datasets_root", type=Path, help= \
16
+ "Path to the directory containing your datasets. See toolbox/__init__.py for a list of "
17
+ "supported datasets.", default=None)
18
+ parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
19
+ help="Directory containing saved encoder models")
20
+ parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
21
+ help="Directory containing saved synthesizer models")
22
+ parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
23
+ help="Directory containing saved vocoder models")
24
+ parser.add_argument("--cpu", action="store_true", help=\
25
+ "If True, processing is done on CPU, even when a GPU is available.")
26
+ parser.add_argument("--seed", type=int, default=None, help=\
27
+ "Optional random number seed value to make toolbox deterministic.")
28
+ parser.add_argument("--no_mp3_support", action="store_true", help=\
29
+ "If True, no mp3 files are allowed.")
30
+ args = parser.parse_args()
31
+ print_args(args, parser)
32
+
33
+ if args.cpu:
34
+ # Hide GPUs from Pytorch to force CPU processing
35
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
36
+ del args.cpu
37
+
38
+ ## Remind the user to download pretrained models if needed
39
+ check_model_paths(encoder_path=args.enc_models_dir, synthesizer_path=args.syn_models_dir,
40
+ vocoder_path=args.voc_models_dir)
41
+
42
+ # Launch the toolbox
43
+ Toolbox(**vars(args))
encoder/__init__.py ADDED
File without changes
encoder/audio.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scipy.ndimage.morphology import binary_dilation
2
+ from encoder.params_data import *
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+ from warnings import warn
6
+ import numpy as np
7
+ import librosa
8
+ import struct
9
+
10
+ try:
11
+ import webrtcvad
12
+ except:
13
+ warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
14
+ webrtcvad=None
15
+
16
+ int16_max = (2 ** 15) - 1
17
+
18
+
19
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
20
+ source_sr: Optional[int] = None,
21
+ normalize: Optional[bool] = True,
22
+ trim_silence: Optional[bool] = True):
23
+ """
24
+ Applies the preprocessing operations used in training the Speaker Encoder to a waveform
25
+ either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
26
+
27
+ :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
28
+ just .wav), either the waveform as a numpy array of floats.
29
+ :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
30
+ preprocessing. After preprocessing, the waveform's sampling rate will match the data
31
+ hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
32
+ this argument will be ignored.
33
+ """
34
+ # Load the wav from disk if needed
35
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
36
+ wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
37
+ else:
38
+ wav = fpath_or_wav
39
+
40
+ # Resample the wav if needed
41
+ if source_sr is not None and source_sr != sampling_rate:
42
+ wav = librosa.resample(wav, source_sr, sampling_rate)
43
+
44
+ # Apply the preprocessing: normalize volume and shorten long silences
45
+ if normalize:
46
+ wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
47
+ if webrtcvad and trim_silence:
48
+ wav = trim_long_silences(wav)
49
+
50
+ return wav
51
+
52
+
53
+ def wav_to_mel_spectrogram(wav):
54
+ """
55
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
56
+ Note: this not a log-mel spectrogram.
57
+ """
58
+ frames = librosa.feature.melspectrogram(
59
+ wav,
60
+ sampling_rate,
61
+ n_fft=int(sampling_rate * mel_window_length / 1000),
62
+ hop_length=int(sampling_rate * mel_window_step / 1000),
63
+ n_mels=mel_n_channels
64
+ )
65
+ return frames.astype(np.float32).T
66
+
67
+
68
+ def trim_long_silences(wav):
69
+ """
70
+ Ensures that segments without voice in the waveform remain no longer than a
71
+ threshold determined by the VAD parameters in params.py.
72
+
73
+ :param wav: the raw waveform as a numpy array of floats
74
+ :return: the same waveform with silences trimmed away (length <= original wav length)
75
+ """
76
+ # Compute the voice detection window size
77
+ samples_per_window = (vad_window_length * sampling_rate) // 1000
78
+
79
+ # Trim the end of the audio to have a multiple of the window size
80
+ wav = wav[:len(wav) - (len(wav) % samples_per_window)]
81
+
82
+ # Convert the float waveform to 16-bit mono PCM
83
+ pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
84
+
85
+ # Perform voice activation detection
86
+ voice_flags = []
87
+ vad = webrtcvad.Vad(mode=3)
88
+ for window_start in range(0, len(wav), samples_per_window):
89
+ window_end = window_start + samples_per_window
90
+ voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
91
+ sample_rate=sampling_rate))
92
+ voice_flags = np.array(voice_flags)
93
+
94
+ # Smooth the voice detection with a moving average
95
+ def moving_average(array, width):
96
+ array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
97
+ ret = np.cumsum(array_padded, dtype=float)
98
+ ret[width:] = ret[width:] - ret[:-width]
99
+ return ret[width - 1:] / width
100
+
101
+ audio_mask = moving_average(voice_flags, vad_moving_average_width)
102
+ audio_mask = np.round(audio_mask).astype(np.bool)
103
+
104
+ # Dilate the voiced regions
105
+ audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
106
+ audio_mask = np.repeat(audio_mask, samples_per_window)
107
+
108
+ return wav[audio_mask == True]
109
+
110
+
111
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
112
+ if increase_only and decrease_only:
113
+ raise ValueError("Both increase only and decrease only are set")
114
+ dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
115
+ if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
116
+ return wav
117
+ return wav * (10 ** (dBFS_change / 20))
encoder/config.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ librispeech_datasets = {
2
+ "train": {
3
+ "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
4
+ "other": ["LibriSpeech/train-other-500"]
5
+ },
6
+ "test": {
7
+ "clean": ["LibriSpeech/test-clean"],
8
+ "other": ["LibriSpeech/test-other"]
9
+ },
10
+ "dev": {
11
+ "clean": ["LibriSpeech/dev-clean"],
12
+ "other": ["LibriSpeech/dev-other"]
13
+ },
14
+ }
15
+ libritts_datasets = {
16
+ "train": {
17
+ "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
18
+ "other": ["LibriTTS/train-other-500"]
19
+ },
20
+ "test": {
21
+ "clean": ["LibriTTS/test-clean"],
22
+ "other": ["LibriTTS/test-other"]
23
+ },
24
+ "dev": {
25
+ "clean": ["LibriTTS/dev-clean"],
26
+ "other": ["LibriTTS/dev-other"]
27
+ },
28
+ }
29
+ voxceleb_datasets = {
30
+ "voxceleb1" : {
31
+ "train": ["VoxCeleb1/wav"],
32
+ "test": ["VoxCeleb1/test_wav"]
33
+ },
34
+ "voxceleb2" : {
35
+ "train": ["VoxCeleb2/dev/aac"],
36
+ "test": ["VoxCeleb2/test_wav"]
37
+ }
38
+ }
39
+
40
+ other_datasets = [
41
+ "LJSpeech-1.1",
42
+ "VCTK-Corpus/wav48",
43
+ ]
44
+
45
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
1
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ class RandomCycler:
4
+ """
5
+ Creates an internal copy of a sequence and allows access to its items in a constrained random
6
+ order. For a source sequence of n items and one or several consecutive queries of a total
7
+ of m items, the following guarantees hold (one implies the other):
8
+ - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
9
+ - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
10
+ """
11
+
12
+ def __init__(self, source):
13
+ if len(source) == 0:
14
+ raise Exception("Can't create RandomCycler from an empty collection")
15
+ self.all_items = list(source)
16
+ self.next_items = []
17
+
18
+ def sample(self, count: int):
19
+ shuffle = lambda l: random.sample(l, len(l))
20
+
21
+ out = []
22
+ while count > 0:
23
+ if count >= len(self.all_items):
24
+ out.extend(shuffle(list(self.all_items)))
25
+ count -= len(self.all_items)
26
+ continue
27
+ n = min(count, len(self.next_items))
28
+ out.extend(self.next_items[:n])
29
+ count -= n
30
+ self.next_items = self.next_items[n:]
31
+ if len(self.next_items) == 0:
32
+ self.next_items = shuffle(list(self.all_items))
33
+ return out
34
+
35
+ def __next__(self):
36
+ return self.sample(1)[0]
37
+
encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.data_objects.random_cycler import RandomCycler
2
+ from encoder.data_objects.utterance import Utterance
3
+ from pathlib import Path
4
+
5
+ # Contains the set of utterances of a single speaker
6
+ class Speaker:
7
+ def __init__(self, root: Path):
8
+ self.root = root
9
+ self.name = root.name
10
+ self.utterances = None
11
+ self.utterance_cycler = None
12
+
13
+ def _load_utterances(self):
14
+ with self.root.joinpath("_sources.txt").open("r") as sources_file:
15
+ sources = [l.split(",") for l in sources_file]
16
+ sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17
+ self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18
+ self.utterance_cycler = RandomCycler(self.utterances)
19
+
20
+ def random_partial(self, count, n_frames):
21
+ """
22
+ Samples a batch of <count> unique partial utterances from the disk in a way that all
23
+ utterances come up at least once every two cycles and in a random order every time.
24
+
25
+ :param count: The number of partial utterances to sample from the set of utterances from
26
+ that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
27
+ the number of utterances available.
28
+ :param n_frames: The number of frames in the partial utterance.
29
+ :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30
+ frames are the frames of the partial utterances and range is the range of the partial
31
+ utterance with regard to the complete utterance.
32
+ """
33
+ if self.utterances is None:
34
+ self._load_utterances()
35
+
36
+ utterances = self.utterance_cycler.sample(count)
37
+
38
+ a = [(u,) + u.random_partial(n_frames) for u in utterances]
39
+
40
+ return a
encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List
3
+ from encoder.data_objects.speaker import Speaker
4
+
5
+ class SpeakerBatch:
6
+ def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
7
+ self.speakers = speakers
8
+ self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
9
+
10
+ # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
11
+ # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
12
+ self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.data_objects.random_cycler import RandomCycler
2
+ from encoder.data_objects.speaker_batch import SpeakerBatch
3
+ from encoder.data_objects.speaker import Speaker
4
+ from encoder.params_data import partials_n_frames
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from pathlib import Path
7
+
8
+ # TODO: improve with a pool of speakers for data efficiency
9
+
10
+ class SpeakerVerificationDataset(Dataset):
11
+ def __init__(self, datasets_root: Path):
12
+ self.root = datasets_root
13
+ speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
14
+ if len(speaker_dirs) == 0:
15
+ raise Exception("No speakers found. Make sure you are pointing to the directory "
16
+ "containing all preprocessed speaker directories.")
17
+ self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
18
+ self.speaker_cycler = RandomCycler(self.speakers)
19
+
20
+ def __len__(self):
21
+ return int(1e10)
22
+
23
+ def __getitem__(self, index):
24
+ return next(self.speaker_cycler)
25
+
26
+ def get_logs(self):
27
+ log_string = ""
28
+ for log_fpath in self.root.glob("*.txt"):
29
+ with log_fpath.open("r") as log_file:
30
+ log_string += "".join(log_file.readlines())
31
+ return log_string
32
+
33
+
34
+ class SpeakerVerificationDataLoader(DataLoader):
35
+ def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
36
+ batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
37
+ worker_init_fn=None):
38
+ self.utterances_per_speaker = utterances_per_speaker
39
+
40
+ super().__init__(
41
+ dataset=dataset,
42
+ batch_size=speakers_per_batch,
43
+ shuffle=False,
44
+ sampler=sampler,
45
+ batch_sampler=batch_sampler,
46
+ num_workers=num_workers,
47
+ collate_fn=self.collate,
48
+ pin_memory=pin_memory,
49
+ drop_last=False,
50
+ timeout=timeout,
51
+ worker_init_fn=worker_init_fn
52
+ )
53
+
54
+ def collate(self, speakers):
55
+ return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56
+
encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ class Utterance:
5
+ def __init__(self, frames_fpath, wave_fpath):
6
+ self.frames_fpath = frames_fpath
7
+ self.wave_fpath = wave_fpath
8
+
9
+ def get_frames(self):
10
+ return np.load(self.frames_fpath)
11
+
12
+ def random_partial(self, n_frames):
13
+ """
14
+ Crops the frames into a partial utterance of n_frames
15
+
16
+ :param n_frames: The number of frames of the partial utterance
17
+ :return: the partial utterance frames and a tuple indicating the start and end of the
18
+ partial utterance in the complete utterance.
19
+ """
20
+ frames = self.get_frames()
21
+ if frames.shape[0] == n_frames:
22
+ start = 0
23
+ else:
24
+ start = np.random.randint(0, frames.shape[0] - n_frames)
25
+ end = start + n_frames
26
+ return frames[start:end], (start, end)
encoder/inference.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.params_data import *
2
+ from encoder.model import SpeakerEncoder
3
+ from encoder.audio import preprocess_wav # We want to expose this function from here
4
+ from matplotlib import cm
5
+ from encoder import audio
6
+ from pathlib import Path
7
+ import matplotlib.pyplot as plt
8
+ import numpy as np
9
+ import torch
10
+
11
+ _model = None # type: SpeakerEncoder
12
+ _device = None # type: torch.device
13
+
14
+
15
+ def load_model(weights_fpath: Path, device=None):
16
+ """
17
+ Loads the model in memory. If this function is not explicitely called, it will be run on the
18
+ first call to embed_frames() with the default weights file.
19
+
20
+ :param weights_fpath: the path to saved model weights.
21
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
22
+ model will be loaded and will run on this device. Outputs will however always be on the cpu.
23
+ If None, will default to your GPU if it"s available, otherwise your CPU.
24
+ """
25
+ # TODO: I think the slow loading of the encoder might have something to do with the device it
26
+ # was saved on. Worth investigating.
27
+ global _model, _device
28
+ if device is None:
29
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ elif isinstance(device, str):
31
+ _device = torch.device(device)
32
+ _model = SpeakerEncoder(_device, torch.device("cpu"))
33
+ checkpoint = torch.load(weights_fpath, _device)
34
+ _model.load_state_dict(checkpoint["model_state"])
35
+ _model.eval()
36
+ print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
37
+
38
+
39
+ def is_loaded():
40
+ return _model is not None
41
+
42
+
43
+ def embed_frames_batch(frames_batch):
44
+ """
45
+ Computes embeddings for a batch of mel spectrogram.
46
+
47
+ :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
48
+ (batch_size, n_frames, n_channels)
49
+ :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
50
+ """
51
+ if _model is None:
52
+ raise Exception("Model was not loaded. Call load_model() before inference.")
53
+
54
+ frames = torch.from_numpy(frames_batch).to(_device)
55
+ embed = _model.forward(frames).detach().cpu().numpy()
56
+ return embed
57
+
58
+
59
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
60
+ min_pad_coverage=0.75, overlap=0.5):
61
+ """
62
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
63
+ partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
64
+ spectrogram slices are returned, so as to make each partial utterance waveform correspond to
65
+ its spectrogram. This function assumes that the mel spectrogram parameters used are those
66
+ defined in params_data.py.
67
+
68
+ The returned ranges may be indexing further than the length of the waveform. It is
69
+ recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
70
+
71
+ :param n_samples: the number of samples in the waveform
72
+ :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
73
+ utterance
74
+ :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
75
+ enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
76
+ then the last partial utterance will be considered, as if we padded the audio. Otherwise,
77
+ it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
78
+ utterance, this parameter is ignored so that the function always returns at least 1 slice.
79
+ :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
80
+ utterances are entirely disjoint.
81
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
82
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
83
+ utterances.
84
+ """
85
+ assert 0 <= overlap < 1
86
+ assert 0 < min_pad_coverage <= 1
87
+
88
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
89
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
90
+ frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
91
+
92
+ # Compute the slices
93
+ wav_slices, mel_slices = [], []
94
+ steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
95
+ for i in range(0, steps, frame_step):
96
+ mel_range = np.array([i, i + partial_utterance_n_frames])
97
+ wav_range = mel_range * samples_per_frame
98
+ mel_slices.append(slice(*mel_range))
99
+ wav_slices.append(slice(*wav_range))
100
+
101
+ # Evaluate whether extra padding is warranted or not
102
+ last_wav_range = wav_slices[-1]
103
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
104
+ if coverage < min_pad_coverage and len(mel_slices) > 1:
105
+ mel_slices = mel_slices[:-1]
106
+ wav_slices = wav_slices[:-1]
107
+
108
+ return wav_slices, mel_slices
109
+
110
+
111
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
112
+ """
113
+ Computes an embedding for a single utterance.
114
+
115
+ # TODO: handle multiple wavs to benefit from batching on GPU
116
+ :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
117
+ :param using_partials: if True, then the utterance is split in partial utterances of
118
+ <partial_utterance_n_frames> frames and the utterance embedding is computed from their
119
+ normalized average. If False, the utterance is instead computed from feeding the entire
120
+ spectogram to the network.
121
+ :param return_partials: if True, the partial embeddings will also be returned along with the
122
+ wav slices that correspond to the partial embeddings.
123
+ :param kwargs: additional arguments to compute_partial_splits()
124
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
125
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
126
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
127
+ returned. If <using_partials> is simultaneously set to False, both these values will be None
128
+ instead.
129
+ """
130
+ # Process the entire utterance if not using partials
131
+ if not using_partials:
132
+ frames = audio.wav_to_mel_spectrogram(wav)
133
+ embed = embed_frames_batch(frames[None, ...])[0]
134
+ if return_partials:
135
+ return embed, None, None
136
+ return embed
137
+
138
+ # Compute where to split the utterance into partials and pad if necessary
139
+ wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
140
+ max_wave_length = wave_slices[-1].stop
141
+ if max_wave_length >= len(wav):
142
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
143
+
144
+ # Split the utterance into partials
145
+ frames = audio.wav_to_mel_spectrogram(wav)
146
+ frames_batch = np.array([frames[s] for s in mel_slices])
147
+ partial_embeds = embed_frames_batch(frames_batch)
148
+
149
+ # Compute the utterance embedding from the partial embeddings
150
+ raw_embed = np.mean(partial_embeds, axis=0)
151
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
152
+
153
+ if return_partials:
154
+ return embed, partial_embeds, wave_slices
155
+ return embed
156
+
157
+
158
+ def embed_speaker(wavs, **kwargs):
159
+ raise NotImplemented()
160
+
161
+
162
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
163
+ if ax is None:
164
+ ax = plt.gca()
165
+
166
+ if shape is None:
167
+ height = int(np.sqrt(len(embed)))
168
+ shape = (height, -1)
169
+ embed = embed.reshape(shape)
170
+
171
+ cmap = cm.get_cmap()
172
+ mappable = ax.imshow(embed, cmap=cmap)
173
+ cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174
+ sm = cm.ScalarMappable(cmap=cmap)
175
+ sm.set_clim(*color_range)
176
+
177
+ ax.set_xticks([]), ax.set_yticks([])
178
+ ax.set_title(title)
encoder/model.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.params_model import *
2
+ from encoder.params_data import *
3
+ from scipy.interpolate import interp1d
4
+ from sklearn.metrics import roc_curve
5
+ from torch.nn.utils import clip_grad_norm_
6
+ from scipy.optimize import brentq
7
+ from torch import nn
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class SpeakerEncoder(nn.Module):
13
+ def __init__(self, device, loss_device):
14
+ super().__init__()
15
+ self.loss_device = loss_device
16
+
17
+ # Network defition
18
+ self.lstm = nn.LSTM(input_size=mel_n_channels,
19
+ hidden_size=model_hidden_size,
20
+ num_layers=model_num_layers,
21
+ batch_first=True).to(device)
22
+ self.linear = nn.Linear(in_features=model_hidden_size,
23
+ out_features=model_embedding_size).to(device)
24
+ self.relu = torch.nn.ReLU().to(device)
25
+
26
+ # Cosine similarity scaling (with fixed initial parameter values)
27
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29
+
30
+ # Loss
31
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32
+
33
+ def do_gradient_ops(self):
34
+ # Gradient scale
35
+ self.similarity_weight.grad *= 0.01
36
+ self.similarity_bias.grad *= 0.01
37
+
38
+ # Gradient clipping
39
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
40
+
41
+ def forward(self, utterances, hidden_init=None):
42
+ """
43
+ Computes the embeddings of a batch of utterance spectrograms.
44
+
45
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46
+ (batch_size, n_frames, n_channels)
47
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
49
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50
+ """
51
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52
+ # and the final cell state.
53
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
54
+
55
+ # We take only the hidden state of the last layer
56
+ embeds_raw = self.relu(self.linear(hidden[-1]))
57
+
58
+ # L2-normalize it
59
+ embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
60
+
61
+ return embeds
62
+
63
+ def similarity_matrix(self, embeds):
64
+ """
65
+ Computes the similarity matrix according the section 2.1 of GE2E.
66
+
67
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68
+ utterances_per_speaker, embedding_size)
69
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70
+ utterances_per_speaker, speakers_per_batch)
71
+ """
72
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73
+
74
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76
+ centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
77
+
78
+ # Exclusive centroids (1 per utterance)
79
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80
+ centroids_excl /= (utterances_per_speaker - 1)
81
+ centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
82
+
83
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85
+ # We vectorize the computation for efficiency.
86
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87
+ speakers_per_batch).to(self.loss_device)
88
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
89
+ for j in range(speakers_per_batch):
90
+ mask = np.where(mask_matrix[j])[0]
91
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93
+
94
+ ## Even more vectorized version (slower maybe because of transpose)
95
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96
+ # ).to(self.loss_device)
97
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
98
+ # mask = np.where(1 - eye)
99
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100
+ # mask = np.where(eye)
101
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
103
+
104
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105
+ return sim_matrix
106
+
107
+ def loss(self, embeds):
108
+ """
109
+ Computes the softmax loss according the section 2.1 of GE2E.
110
+
111
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112
+ utterances_per_speaker, embedding_size)
113
+ :return: the loss and the EER for this batch of embeddings.
114
+ """
115
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116
+
117
+ # Loss
118
+ sim_matrix = self.similarity_matrix(embeds)
119
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120
+ speakers_per_batch))
121
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123
+ loss = self.loss_fn(sim_matrix, target)
124
+
125
+ # EER (not backpropagated)
126
+ with torch.no_grad():
127
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
128
+ labels = np.array([inv_argmax(i) for i in ground_truth])
129
+ preds = sim_matrix.detach().cpu().numpy()
130
+
131
+ # Snippet from https://yangcha.github.io/EER-ROC/
132
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134
+
135
+ return loss, eer
encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Mel-filterbank
3
+ mel_window_length = 25 # In milliseconds
4
+ mel_window_step = 10 # In milliseconds
5
+ mel_n_channels = 40
6
+
7
+
8
+ ## Audio
9
+ sampling_rate = 16000
10
+ # Number of spectrogram frames in a partial utterance
11
+ partials_n_frames = 160 # 1600 ms
12
+ # Number of spectrogram frames at inference
13
+ inference_n_frames = 80 # 800 ms
14
+
15
+
16
+ ## Voice Activation Detection
17
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18
+ # This sets the granularity of the VAD. Should not need to be changed.
19
+ vad_window_length = 30 # In milliseconds
20
+ # Number of frames to average together when performing the moving average smoothing.
21
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
22
+ vad_moving_average_width = 8
23
+ # Maximum number of consecutive silent frames a segment can have.
24
+ vad_max_silence_length = 6
25
+
26
+
27
+ ## Audio volume normalization
28
+ audio_norm_target_dBFS = -30
29
+
encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ ## Model parameters
3
+ model_hidden_size = 256
4
+ model_embedding_size = 256
5
+ model_num_layers = 3
6
+
7
+
8
+ ## Training parameters
9
+ learning_rate_init = 1e-4
10
+ speakers_per_batch = 64
11
+ utterances_per_speaker = 10
encoder/preprocess.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocess.pool import ThreadPool
2
+ from encoder.params_data import *
3
+ from encoder.config import librispeech_datasets, anglophone_nationalites
4
+ from datetime import datetime
5
+ from encoder import audio
6
+ from pathlib import Path
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+
10
+
11
+ class DatasetLog:
12
+ """
13
+ Registers metadata about the dataset in a text file.
14
+ """
15
+ def __init__(self, root, name):
16
+ self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
17
+ self.sample_data = dict()
18
+
19
+ start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
20
+ self.write_line("Creating dataset %s on %s" % (name, start_time))
21
+ self.write_line("-----")
22
+ self._log_params()
23
+
24
+ def _log_params(self):
25
+ from encoder import params_data
26
+ self.write_line("Parameter values:")
27
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
28
+ value = getattr(params_data, param_name)
29
+ self.write_line("\t%s: %s" % (param_name, value))
30
+ self.write_line("-----")
31
+
32
+ def write_line(self, line):
33
+ self.text_file.write("%s\n" % line)
34
+
35
+ def add_sample(self, **kwargs):
36
+ for param_name, value in kwargs.items():
37
+ if not param_name in self.sample_data:
38
+ self.sample_data[param_name] = []
39
+ self.sample_data[param_name].append(value)
40
+
41
+ def finalize(self):
42
+ self.write_line("Statistics:")
43
+ for param_name, values in self.sample_data.items():
44
+ self.write_line("\t%s:" % param_name)
45
+ self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
46
+ self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
47
+ self.write_line("-----")
48
+ end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
49
+ self.write_line("Finished on %s" % end_time)
50
+ self.text_file.close()
51
+
52
+
53
+ def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
54
+ dataset_root = datasets_root.joinpath(dataset_name)
55
+ if not dataset_root.exists():
56
+ print("Couldn\'t find %s, skipping this dataset." % dataset_root)
57
+ return None, None
58
+ return dataset_root, DatasetLog(out_dir, dataset_name)
59
+
60
+
61
+ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
62
+ skip_existing, logger):
63
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
64
+
65
+ # Function to preprocess utterances for one speaker
66
+ def preprocess_speaker(speaker_dir: Path):
67
+ # Give a name to the speaker that includes its dataset
68
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
69
+
70
+ # Create an output directory with that name, as well as a txt file containing a
71
+ # reference to each source file.
72
+ speaker_out_dir = out_dir.joinpath(speaker_name)
73
+ speaker_out_dir.mkdir(exist_ok=True)
74
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75
+
76
+ # There's a possibility that the preprocessing was interrupted earlier, check if
77
+ # there already is a sources file.
78
+ if sources_fpath.exists():
79
+ try:
80
+ with sources_fpath.open("r") as sources_file:
81
+ existing_fnames = {line.split(",")[0] for line in sources_file}
82
+ except:
83
+ existing_fnames = {}
84
+ else:
85
+ existing_fnames = {}
86
+
87
+ # Gather all audio files for that speaker recursively
88
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
89
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
90
+ # Check if the target output file already exists
91
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
92
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
93
+ if skip_existing and out_fname in existing_fnames:
94
+ continue
95
+
96
+ # Load and preprocess the waveform
97
+ wav = audio.preprocess_wav(in_fpath)
98
+ if len(wav) == 0:
99
+ continue
100
+
101
+ # Create the mel spectrogram, discard those that are too short
102
+ frames = audio.wav_to_mel_spectrogram(wav)
103
+ if len(frames) < partials_n_frames:
104
+ continue
105
+
106
+ out_fpath = speaker_out_dir.joinpath(out_fname)
107
+ np.save(out_fpath, frames)
108
+ logger.add_sample(duration=len(wav) / sampling_rate)
109
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
110
+
111
+ sources_file.close()
112
+
113
+ # Process the utterances for each speaker
114
+ with ThreadPool(8) as pool:
115
+ list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
116
+ unit="speakers"))
117
+ logger.finalize()
118
+ print("Done preprocessing %s.\n" % dataset_name)
119
+
120
+
121
+ def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
122
+ for dataset_name in librispeech_datasets["train"]["other"]:
123
+ # Initialize the preprocessing
124
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
125
+ if not dataset_root:
126
+ return
127
+
128
+ # Preprocess all speakers
129
+ speaker_dirs = list(dataset_root.glob("*"))
130
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
131
+ skip_existing, logger)
132
+
133
+
134
+ def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
135
+ # Initialize the preprocessing
136
+ dataset_name = "VoxCeleb1"
137
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
138
+ if not dataset_root:
139
+ return
140
+
141
+ # Get the contents of the meta file
142
+ with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
143
+ metadata = [line.split("\t") for line in metafile][1:]
144
+
145
+ # Select the ID and the nationality, filter out non-anglophone speakers
146
+ nationalities = {line[0]: line[3] for line in metadata}
147
+ keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
148
+ nationality.lower() in anglophone_nationalites]
149
+ print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
150
+ (len(keep_speaker_ids), len(nationalities)))
151
+
152
+ # Get the speaker directories for anglophone speakers only
153
+ speaker_dirs = dataset_root.joinpath("wav").glob("*")
154
+ speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
155
+ speaker_dir.name in keep_speaker_ids]
156
+ print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
157
+ (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
158
+
159
+ # Preprocess all speakers
160
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
161
+ skip_existing, logger)
162
+
163
+
164
+ def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
165
+ # Initialize the preprocessing
166
+ dataset_name = "VoxCeleb2"
167
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
168
+ if not dataset_root:
169
+ return
170
+
171
+ # Get the speaker directories
172
+ # Preprocess all speakers
173
+ speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
174
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
175
+ skip_existing, logger)
encoder/train.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.visualizations import Visualizations
2
+ from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
3
+ from encoder.params_model import *
4
+ from encoder.model import SpeakerEncoder
5
+ from utils.profiler import Profiler
6
+ from pathlib import Path
7
+ import torch
8
+
9
+ def sync(device: torch.device):
10
+ # For correct profiling (cuda operations are async)
11
+ if device.type == "cuda":
12
+ torch.cuda.synchronize(device)
13
+
14
+
15
+ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
16
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
17
+ no_visdom: bool):
18
+ # Create a dataset and a dataloader
19
+ dataset = SpeakerVerificationDataset(clean_data_root)
20
+ loader = SpeakerVerificationDataLoader(
21
+ dataset,
22
+ speakers_per_batch,
23
+ utterances_per_speaker,
24
+ num_workers=8,
25
+ )
26
+
27
+ # Setup the device on which to run the forward pass and the loss. These can be different,
28
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
29
+ # hyperparameters) faster on the CPU.
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ # FIXME: currently, the gradient is None if loss_device is cuda
32
+ loss_device = torch.device("cpu")
33
+
34
+ # Create the model and the optimizer
35
+ model = SpeakerEncoder(device, loss_device)
36
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
37
+ init_step = 1
38
+
39
+ # Configure file path for the model
40
+ state_fpath = models_dir.joinpath(run_id + ".pt")
41
+ backup_dir = models_dir.joinpath(run_id + "_backups")
42
+
43
+ # Load any existing model
44
+ if not force_restart:
45
+ if state_fpath.exists():
46
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
47
+ checkpoint = torch.load(state_fpath)
48
+ init_step = checkpoint["step"]
49
+ model.load_state_dict(checkpoint["model_state"])
50
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
51
+ optimizer.param_groups[0]["lr"] = learning_rate_init
52
+ else:
53
+ print("No model \"%s\" found, starting training from scratch." % run_id)
54
+ else:
55
+ print("Starting the training from scratch.")
56
+ model.train()
57
+
58
+ # Initialize the visualization environment
59
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
60
+ vis.log_dataset(dataset)
61
+ vis.log_params()
62
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
63
+ vis.log_implementation({"Device": device_name})
64
+
65
+ # Training loop
66
+ profiler = Profiler(summarize_every=10, disabled=False)
67
+ for step, speaker_batch in enumerate(loader, init_step):
68
+ profiler.tick("Blocking, waiting for batch (threaded)")
69
+
70
+ # Forward pass
71
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
72
+ sync(device)
73
+ profiler.tick("Data to %s" % device)
74
+ embeds = model(inputs)
75
+ sync(device)
76
+ profiler.tick("Forward pass")
77
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
78
+ loss, eer = model.loss(embeds_loss)
79
+ sync(loss_device)
80
+ profiler.tick("Loss")
81
+
82
+ # Backward pass
83
+ model.zero_grad()
84
+ loss.backward()
85
+ profiler.tick("Backward pass")
86
+ model.do_gradient_ops()
87
+ optimizer.step()
88
+ profiler.tick("Parameter update")
89
+
90
+ # Update visualizations
91
+ # learning_rate = optimizer.param_groups[0]["lr"]
92
+ vis.update(loss.item(), eer, step)
93
+
94
+ # Draw projections and save them to the backup folder
95
+ if umap_every != 0 and step % umap_every == 0:
96
+ print("Drawing and saving projections (step %d)" % step)
97
+ backup_dir.mkdir(exist_ok=True)
98
+ projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
99
+ embeds = embeds.detach().cpu().numpy()
100
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
101
+ vis.save()
102
+
103
+ # Overwrite the latest version of the model
104
+ if save_every != 0 and step % save_every == 0:
105
+ print("Saving the model (step %d)" % step)
106
+ torch.save({
107
+ "step": step + 1,
108
+ "model_state": model.state_dict(),
109
+ "optimizer_state": optimizer.state_dict(),
110
+ }, state_fpath)
111
+
112
+ # Make a backup
113
+ if backup_every != 0 and step % backup_every == 0:
114
+ print("Making a backup (step %d)" % step)
115
+ backup_dir.mkdir(exist_ok=True)
116
+ backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
117
+ torch.save({
118
+ "step": step + 1,
119
+ "model_state": model.state_dict(),
120
+ "optimizer_state": optimizer.state_dict(),
121
+ }, backup_fpath)
122
+
123
+ profiler.tick("Extras (visualizations, saving)")
encoder/visualizations.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from datetime import datetime
3
+ from time import perf_counter as timer
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ # import webbrowser
7
+ import visdom
8
+ import umap
9
+
10
+ colormap = np.array([
11
+ [76, 255, 0],
12
+ [0, 127, 70],
13
+ [255, 0, 0],
14
+ [255, 217, 38],
15
+ [0, 135, 255],
16
+ [165, 0, 165],
17
+ [255, 167, 255],
18
+ [0, 255, 255],
19
+ [255, 96, 38],
20
+ [142, 76, 0],
21
+ [33, 0, 127],
22
+ [0, 0, 0],
23
+ [183, 183, 183],
24
+ ], dtype=np.float) / 255
25
+
26
+
27
+ class Visualizations:
28
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
29
+ # Tracking data
30
+ self.last_update_timestamp = timer()
31
+ self.update_every = update_every
32
+ self.step_times = []
33
+ self.losses = []
34
+ self.eers = []
35
+ print("Updating the visualizations every %d steps." % update_every)
36
+
37
+ # If visdom is disabled TODO: use a better paradigm for that
38
+ self.disabled = disabled
39
+ if self.disabled:
40
+ return
41
+
42
+ # Set the environment name
43
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
44
+ if env_name is None:
45
+ self.env_name = now
46
+ else:
47
+ self.env_name = "%s (%s)" % (env_name, now)
48
+
49
+ # Connect to visdom and open the corresponding window in the browser
50
+ try:
51
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
52
+ except ConnectionError:
53
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
54
+ "start it.")
55
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
56
+
57
+ # Create the windows
58
+ self.loss_win = None
59
+ self.eer_win = None
60
+ # self.lr_win = None
61
+ self.implementation_win = None
62
+ self.projection_win = None
63
+ self.implementation_string = ""
64
+
65
+ def log_params(self):
66
+ if self.disabled:
67
+ return
68
+ from encoder import params_data
69
+ from encoder import params_model
70
+ param_string = "<b>Model parameters</b>:<br>"
71
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
72
+ value = getattr(params_model, param_name)
73
+ param_string += "\t%s: %s<br>" % (param_name, value)
74
+ param_string += "<b>Data parameters</b>:<br>"
75
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
76
+ value = getattr(params_data, param_name)
77
+ param_string += "\t%s: %s<br>" % (param_name, value)
78
+ self.vis.text(param_string, opts={"title": "Parameters"})
79
+
80
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
81
+ if self.disabled:
82
+ return
83
+ dataset_string = ""
84
+ dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
85
+ dataset_string += "\n" + dataset.get_logs()
86
+ dataset_string = dataset_string.replace("\n", "<br>")
87
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
88
+
89
+ def log_implementation(self, params):
90
+ if self.disabled:
91
+ return
92
+ implementation_string = ""
93
+ for param, value in params.items():
94
+ implementation_string += "<b>%s</b>: %s\n" % (param, value)
95
+ implementation_string = implementation_string.replace("\n", "<br>")
96
+ self.implementation_string = implementation_string
97
+ self.implementation_win = self.vis.text(
98
+ implementation_string,
99
+ opts={"title": "Training implementation"}
100
+ )
101
+
102
+ def update(self, loss, eer, step):
103
+ # Update the tracking data
104
+ now = timer()
105
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
106
+ self.last_update_timestamp = now
107
+ self.losses.append(loss)
108
+ self.eers.append(eer)
109
+ print(".", end="")
110
+
111
+ # Update the plots every <update_every> steps
112
+ if step % self.update_every != 0:
113
+ return
114
+ time_string = "Step time: mean: %5dms std: %5dms" % \
115
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
116
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
117
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
118
+ if not self.disabled:
119
+ self.loss_win = self.vis.line(
120
+ [np.mean(self.losses)],
121
+ [step],
122
+ win=self.loss_win,
123
+ update="append" if self.loss_win else None,
124
+ opts=dict(
125
+ legend=["Avg. loss"],
126
+ xlabel="Step",
127
+ ylabel="Loss",
128
+ title="Loss",
129
+ )
130
+ )
131
+ self.eer_win = self.vis.line(
132
+ [np.mean(self.eers)],
133
+ [step],
134
+ win=self.eer_win,
135
+ update="append" if self.eer_win else None,
136
+ opts=dict(
137
+ legend=["Avg. EER"],
138
+ xlabel="Step",
139
+ ylabel="EER",
140
+ title="Equal error rate"
141
+ )
142
+ )
143
+ if self.implementation_win is not None:
144
+ self.vis.text(
145
+ self.implementation_string + ("<b>%s</b>" % time_string),
146
+ win=self.implementation_win,
147
+ opts={"title": "Training implementation"},
148
+ )
149
+
150
+ # Reset the tracking
151
+ self.losses.clear()
152
+ self.eers.clear()
153
+ self.step_times.clear()
154
+
155
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
156
+ max_speakers=10):
157
+ max_speakers = min(max_speakers, len(colormap))
158
+ embeds = embeds[:max_speakers * utterances_per_speaker]
159
+
160
+ n_speakers = len(embeds) // utterances_per_speaker
161
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
162
+ colors = [colormap[i] for i in ground_truth]
163
+
164
+ reducer = umap.UMAP()
165
+ projected = reducer.fit_transform(embeds)
166
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
167
+ plt.gca().set_aspect("equal", "datalim")
168
+ plt.title("UMAP projection (step %d)" % step)
169
+ if not self.disabled:
170
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
171
+ if out_fpath is not None:
172
+ plt.savefig(out_fpath)
173
+ plt.clf()
174
+
175
+ def save(self):
176
+ if not self.disabled:
177
+ self.vis.save([self.env_name])
178
+
encoder_preprocess.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2
2
+ from utils.argutils import print_args
3
+ from pathlib import Path
4
+ import argparse
5
+
6
+ if __name__ == "__main__":
7
+ class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
8
+ pass
9
+
10
+ parser = argparse.ArgumentParser(
11
+ description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
12
+ "writes them to the disk. This will allow you to train the encoder. The "
13
+ "datasets required are at least one of VoxCeleb1, VoxCeleb2 and LibriSpeech. "
14
+ "Ideally, you should have all three. You should extract them as they are "
15
+ "after having downloaded them and put them in a same directory, e.g.:\n"
16
+ "-[datasets_root]\n"
17
+ " -LibriSpeech\n"
18
+ " -train-other-500\n"
19
+ " -VoxCeleb1\n"
20
+ " -wav\n"
21
+ " -vox1_meta.csv\n"
22
+ " -VoxCeleb2\n"
23
+ " -dev",
24
+ formatter_class=MyFormatter
25
+ )
26
+ parser.add_argument("datasets_root", type=Path, help=\
27
+ "Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.")
28
+ parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
29
+ "Path to the output directory that will contain the mel spectrograms. If left out, "
30
+ "defaults to <datasets_root>/SV2TTS/encoder/")
31
+ parser.add_argument("-d", "--datasets", type=str,
32
+ default="librispeech_other,voxceleb1,voxceleb2", help=\
33
+ "Comma-separated list of the name of the datasets you want to preprocess. Only the train "
34
+ "set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
35
+ "voxceleb2.")
36
+ parser.add_argument("-s", "--skip_existing", action="store_true", help=\
37
+ "Whether to skip existing output files with the same name. Useful if this script was "
38
+ "interrupted.")
39
+ parser.add_argument("--no_trim", action="store_true", help=\
40
+ "Preprocess audio without trimming silences (not recommended).")
41
+ args = parser.parse_args()
42
+
43
+ # Verify webrtcvad is available
44
+ if not args.no_trim:
45
+ try:
46
+ import webrtcvad
47
+ except:
48
+ raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
49
+ "noise removal and is recommended. Please install and try again. If installation fails, "
50
+ "use --no_trim to disable this error message.")
51
+ del args.no_trim
52
+
53
+ # Process the arguments
54
+ args.datasets = args.datasets.split(",")
55
+ if not hasattr(args, "out_dir"):
56
+ args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
57
+ assert args.datasets_root.exists()
58
+ args.out_dir.mkdir(exist_ok=True, parents=True)
59
+
60
+ # Preprocess the datasets
61
+ print_args(args, parser)
62
+ preprocess_func = {
63
+ "librispeech_other": preprocess_librispeech,
64
+ "voxceleb1": preprocess_voxceleb1,
65
+ "voxceleb2": preprocess_voxceleb2,
66
+ }
67
+ args = vars(args)
68
+ for dataset in args.pop("datasets"):
69
+ print("Preprocessing %s" % dataset)
70
+ preprocess_func[dataset](**args)
encoder_train.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.argutils import print_args
2
+ from encoder.train import train
3
+ from pathlib import Path
4
+ import argparse
5
+
6
+
7
+ if __name__ == "__main__":
8
+ parser = argparse.ArgumentParser(
9
+ description="Trains the speaker encoder. You must have run encoder_preprocess.py first.",
10
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
11
+ )
12
+
13
+ parser.add_argument("run_id", type=str, help= \
14
+ "Name for this model instance. If a model state from the same run ID was previously "
15
+ "saved, the training will restart from there. Pass -f to overwrite saved states and "
16
+ "restart from scratch.")
17
+ parser.add_argument("clean_data_root", type=Path, help= \
18
+ "Path to the output directory of encoder_preprocess.py. If you left the default "
19
+ "output directory when preprocessing, it should be <datasets_root>/SV2TTS/encoder/.")
20
+ parser.add_argument("-m", "--models_dir", type=Path, default="encoder/saved_models/", help=\
21
+ "Path to the output directory that will contain the saved model weights, as well as "
22
+ "backups of those weights and plots generated during training.")
23
+ parser.add_argument("-v", "--vis_every", type=int, default=10, help= \
24
+ "Number of steps between updates of the loss and the plots.")
25
+ parser.add_argument("-u", "--umap_every", type=int, default=100, help= \
26
+ "Number of steps between updates of the umap projection. Set to 0 to never update the "
27
+ "projections.")
28
+ parser.add_argument("-s", "--save_every", type=int, default=500, help= \
29
+ "Number of steps between updates of the model on the disk. Set to 0 to never save the "
30
+ "model.")
31
+ parser.add_argument("-b", "--backup_every", type=int, default=7500, help= \
32
+ "Number of steps between backups of the model. Set to 0 to never make backups of the "
33
+ "model.")
34
+ parser.add_argument("-f", "--force_restart", action="store_true", help= \
35
+ "Do not load any saved model.")
36
+ parser.add_argument("--visdom_server", type=str, default="http://localhost")
37
+ parser.add_argument("--no_visdom", action="store_true", help= \
38
+ "Disable visdom.")
39
+ args = parser.parse_args()
40
+
41
+ # Process the arguments
42
+ args.models_dir.mkdir(exist_ok=True)
43
+
44
+ # Run the training
45
+ print_args(args, parser)
46
+ train(**vars(args))
47
+
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ umap-learn
2
+ visdom
3
+ librosa>=0.8.0
4
+ matplotlib>=3.3.0
5
+ numpy==1.19.3; platform_system == "Windows"
6
+ numpy==1.19.4; platform_system != "Windows"
7
+ scipy>=1.0.0
8
+ tqdm
9
+ sounddevice
10
+ SoundFile
11
+ Unidecode
12
+ inflect
13
+ PyQt5
14
+ multiprocess
15
+ numba
16
+ webrtcvad; platform_system != "Windows"
synthesizer/LICENSE.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
4
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
5
+ Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
6
+ Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish)
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ SOFTWARE.
synthesizer/__init__.py ADDED
@@ -0,0 +1 @@
 
1
+ #
synthesizer/audio.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ from scipy import signal
5
+ from scipy.io import wavfile
6
+ import soundfile as sf
7
+
8
+
9
+ def load_wav(path, sr):
10
+ return librosa.core.load(path, sr=sr)[0]
11
+
12
+ def save_wav(wav, path, sr):
13
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
14
+ #proposed by @dsmiller
15
+ wavfile.write(path, sr, wav.astype(np.int16))
16
+
17
+ def save_wavenet_wav(wav, path, sr):
18
+ sf.write(path, wav.astype(np.float32), sr)
19
+
20
+ def preemphasis(wav, k, preemphasize=True):
21
+ if preemphasize:
22
+ return signal.lfilter([1, -k], [1], wav)
23
+ return wav
24
+
25
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
26
+ if inv_preemphasize:
27
+ return signal.lfilter([1], [1, -k], wav)
28
+ return wav
29
+
30
+ #From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
31
+ def start_and_end_indices(quantized, silence_threshold=2):
32
+ for start in range(quantized.size):
33
+ if abs(quantized[start] - 127) > silence_threshold:
34
+ break
35
+ for end in range(quantized.size - 1, 1, -1):
36
+ if abs(quantized[end] - 127) > silence_threshold:
37
+ break
38
+
39
+ assert abs(quantized[start] - 127) > silence_threshold
40
+ assert abs(quantized[end] - 127) > silence_threshold
41
+
42
+ return start, end
43
+
44
+ def get_hop_size(hparams):
45
+ hop_size = hparams.hop_size
46
+ if hop_size is None:
47
+ assert hparams.frame_shift_ms is not None
48
+ hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
49
+ return hop_size
50
+
51
+ def linearspectrogram(wav, hparams):
52
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
53
+ S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
54
+
55
+ if hparams.signal_normalization:
56
+ return _normalize(S, hparams)
57
+ return S
58
+
59
+ def melspectrogram(wav, hparams):
60
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
61
+ S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
62
+
63
+ if hparams.signal_normalization:
64
+ return _normalize(S, hparams)
65
+ return S
66
+
67
+ def inv_linear_spectrogram(linear_spectrogram, hparams):
68
+ """Converts linear spectrogram to waveform using librosa"""
69
+ if hparams.signal_normalization:
70
+ D = _denormalize(linear_spectrogram, hparams)
71
+ else:
72
+ D = linear_spectrogram
73
+
74
+ S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
75
+
76
+ if hparams.use_lws:
77
+ processor = _lws_processor(hparams)
78
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
79
+ y = processor.istft(D).astype(np.float32)
80
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
81
+ else:
82
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
83
+
84
+ def inv_mel_spectrogram(mel_spectrogram, hparams):
85
+ """Converts mel spectrogram to waveform using librosa"""
86
+ if hparams.signal_normalization:
87
+ D = _denormalize(mel_spectrogram, hparams)
88
+ else:
89
+ D = mel_spectrogram
90
+
91
+ S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear
92
+
93
+ if hparams.use_lws:
94
+ processor = _lws_processor(hparams)
95
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
96
+ y = processor.istft(D).astype(np.float32)
97
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
98
+ else:
99
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
100
+
101
+ def _lws_processor(hparams):
102
+ import lws
103
+ return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
104
+
105
+ def _griffin_lim(S, hparams):
106
+ """librosa implementation of Griffin-Lim
107
+ Based on https://github.com/librosa/librosa/issues/434
108
+ """
109
+ angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
110
+ S_complex = np.abs(S).astype(np.complex)
111
+ y = _istft(S_complex * angles, hparams)
112
+ for i in range(hparams.griffin_lim_iters):
113
+ angles = np.exp(1j * np.angle(_stft(y, hparams)))
114
+ y = _istft(S_complex * angles, hparams)
115
+ return y
116
+
117
+ def _stft(y, hparams):
118
+ if hparams.use_lws:
119
+ return _lws_processor(hparams).stft(y).T
120
+ else:
121
+ return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
122
+
123
+ def _istft(y, hparams):
124
+ return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
125
+
126
+ ##########################################################
127
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
128
+ def num_frames(length, fsize, fshift):
129
+ """Compute number of time frames of spectrogram
130
+ """
131
+ pad = (fsize - fshift)
132
+ if length % fshift == 0:
133
+ M = (length + pad * 2 - fsize) // fshift + 1
134
+ else:
135
+ M = (length + pad * 2 - fsize) // fshift + 2
136
+ return M
137
+
138
+
139
+ def pad_lr(x, fsize, fshift):
140
+ """Compute left and right padding
141
+ """
142
+ M = num_frames(len(x), fsize, fshift)
143
+ pad = (fsize - fshift)
144
+ T = len(x) + 2 * pad
145
+ r = (M - 1) * fshift + fsize - T
146
+ return pad, pad + r
147
+ ##########################################################
148
+ #Librosa correct padding
149
+ def librosa_pad_lr(x, fsize, fshift):
150
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
151
+
152
+ # Conversions
153
+ _mel_basis = None
154
+ _inv_mel_basis = None
155
+
156
+ def _linear_to_mel(spectogram, hparams):
157
+ global _mel_basis
158
+ if _mel_basis is None:
159
+ _mel_basis = _build_mel_basis(hparams)
160
+ return np.dot(_mel_basis, spectogram)
161
+
162
+ def _mel_to_linear(mel_spectrogram, hparams):
163
+ global _inv_mel_basis
164
+ if _inv_mel_basis is None:
165
+ _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
166
+ return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
167
+
168
+ def _build_mel_basis(hparams):
169
+ assert hparams.fmax <= hparams.sample_rate // 2
170
+ return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
171
+ fmin=hparams.fmin, fmax=hparams.fmax)
172
+
173
+ def _amp_to_db(x, hparams):
174
+ min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
175
+ return 20 * np.log10(np.maximum(min_level, x))
176
+
177
+ def _db_to_amp(x):
178
+ return np.power(10.0, (x) * 0.05)
179
+
180
+ def _normalize(S, hparams):
181
+ if hparams.allow_clipping_in_normalization:
182
+ if hparams.symmetric_mels:
183
+ return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
184
+ -hparams.max_abs_value, hparams.max_abs_value)
185
+ else:
186
+ return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
187
+
188
+ assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
189
+ if hparams.symmetric_mels:
190
+ return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
191
+ else:
192
+ return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
193
+
194
+ def _denormalize(D, hparams):
195
+ if hparams.allow_clipping_in_normalization:
196
+ if hparams.symmetric_mels:
197
+ return (((np.clip(D, -hparams.max_abs_value,
198
+ hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
199
+ + hparams.min_level_db)
200
+ else:
201
+ return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
202
+
203
+ if hparams.symmetric_mels:
204
+ return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
205
+ else:
206
+ return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
synthesizer/hparams.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import pprint
3
+
4
+ class HParams(object):
5
+ def __init__(self, **kwargs): self.__dict__.update(kwargs)
6
+ def __setitem__(self, key, value): setattr(self, key, value)
7
+ def __getitem__(self, key): return getattr(self, key)
8
+ def __repr__(self): return pprint.pformat(self.__dict__)
9
+
10
+ def parse(self, string):
11
+ # Overrides hparams from a comma-separated string of name=value pairs
12
+ if len(string) > 0:
13
+ overrides = [s.split("=") for s in string.split(",")]
14
+ keys, values = zip(*overrides)
15
+ keys = list(map(str.strip, keys))
16
+ values = list(map(str.strip, values))
17
+ for k in keys:
18
+ self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
19
+ return self
20
+
21
+ hparams = HParams(
22
+ ### Signal Processing (used in both synthesizer and vocoder)
23
+ sample_rate = 16000,
24
+ n_fft = 800,
25
+ num_mels = 80,
26
+ hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
27
+ win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
28
+ fmin = 55,
29
+ min_level_db = -100,
30
+ ref_level_db = 20,
31
+ max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small.
32
+ preemphasis = 0.97, # Filter coefficient to use if preemphasize is True
33
+ preemphasize = True,
34
+
35
+ ### Tacotron Text-to-Speech (TTS)
36
+ tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs
37
+ tts_encoder_dims = 256,
38
+ tts_decoder_dims = 128,
39
+ tts_postnet_dims = 512,
40
+ tts_encoder_K = 5,
41
+ tts_lstm_dims = 1024,
42
+ tts_postnet_K = 5,
43
+ tts_num_highways = 4,
44
+ tts_dropout = 0.5,
45
+ tts_cleaner_names = ["english_cleaners"],
46
+ tts_stop_threshold = -3.4, # Value below which audio generation ends.
47
+ # For example, for a range of [-4, 4], this
48
+ # will terminate the sequence at the first
49
+ # frame that has all values < -3.4
50
+
51
+ ### Tacotron Training
52
+ tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule
53
+ (2, 5e-4, 40_000, 12), # (r, lr, step, batch_size)
54
+ (2, 2e-4, 80_000, 12), #
55
+ (2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
56
+ (2, 3e-5, 320_000, 12), # synthesized for each decoder iteration)
57
+ (2, 1e-5, 640_000, 12)], # lr = learning rate
58
+
59
+ tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
60
+ tts_eval_interval = 500, # Number of steps between model evaluation (sample generation)
61
+ # Set to -1 to generate after completing epoch, or 0 to disable
62
+
63
+ tts_eval_num_samples = 1, # Makes this number of samples
64
+
65
+ ### Data Preprocessing
66
+ max_mel_frames = 900,
67
+ rescale = True,
68
+ rescaling_max = 0.9,
69
+ synthesis_batch_size = 16, # For vocoder preprocessing and inference.
70
+
71
+ ### Mel Visualization and Griffin-Lim
72
+ signal_normalization = True,
73
+ power = 1.5,
74
+ griffin_lim_iters = 60,
75
+
76
+ ### Audio processing options
77
+ fmax = 7600, # Should not exceed (sample_rate // 2)
78
+ allow_clipping_in_normalization = True, # Used when signal_normalization = True
79
+ clip_mels_length = True, # If true, discards samples exceeding max_mel_frames
80
+ use_lws = False, # "Fast spectrogram phase recovery using local weighted sums"
81
+ symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,
82
+ # and [0, max_abs_value] if False
83
+ trim_silence = True, # Use with sample_rate of 16000 for best results
84
+
85
+ ### SV2TTS
86
+ speaker_embedding_size = 256, # Dimension for the speaker embedding
87
+ silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
88
+ utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded
89
+ )
90
+
91
+ def hparams_debug_string():
92
+ return str(hparams)
synthesizer/inference.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from synthesizer import audio
3
+ from synthesizer.hparams import hparams
4
+ from synthesizer.models.tacotron import Tacotron
5
+ from synthesizer.utils.symbols import symbols
6
+ from synthesizer.utils.text import text_to_sequence
7
+ from vocoder.display import simple_table
8
+ from pathlib import Path
9
+ from typing import Union, List
10
+ import numpy as np
11
+ import librosa
12
+
13
+
14
+ class Synthesizer:
15
+ sample_rate = hparams.sample_rate
16
+ hparams = hparams
17
+
18
+ def __init__(self, model_fpath: Path, verbose=True):
19
+ """
20
+ The model isn't instantiated and loaded in memory until needed or until load() is called.
21
+
22
+ :param model_fpath: path to the trained model file
23
+ :param verbose: if False, prints less information when using the model
24
+ """
25
+ self.model_fpath = model_fpath
26
+ self.verbose = verbose
27
+
28
+ # Check for GPU
29
+ if torch.cuda.is_available():
30
+ self.device = torch.device("cuda")
31
+ else:
32
+ self.device = torch.device("cpu")
33
+ if self.verbose:
34
+ print("Synthesizer using device:", self.device)
35
+
36
+ # Tacotron model will be instantiated later on first use.
37
+ self._model = None
38
+
39
+ def is_loaded(self):
40
+ """
41
+ Whether the model is loaded in memory.
42
+ """
43
+ return self._model is not None
44
+
45
+ def load(self):
46
+ """
47
+ Instantiates and loads the model given the weights file that was passed in the constructor.
48
+ """
49
+ self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
50
+ num_chars=len(symbols),
51
+ encoder_dims=hparams.tts_encoder_dims,
52
+ decoder_dims=hparams.tts_decoder_dims,
53
+ n_mels=hparams.num_mels,
54
+ fft_bins=hparams.num_mels,
55
+ postnet_dims=hparams.tts_postnet_dims,
56
+ encoder_K=hparams.tts_encoder_K,
57
+ lstm_dims=hparams.tts_lstm_dims,
58
+ postnet_K=hparams.tts_postnet_K,
59
+ num_highways=hparams.tts_num_highways,
60
+ dropout=hparams.tts_dropout,
61
+ stop_threshold=hparams.tts_stop_threshold,
62
+ speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
63
+
64
+ self._model.load(self.model_fpath)
65
+ self._model.eval()
66
+
67
+ if self.verbose:
68
+ print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
69
+
70
+ def synthesize_spectrograms(self, texts: List[str],
71
+ embeddings: Union[np.ndarray, List[np.ndarray]],
72
+ return_alignments=False):
73
+ """
74
+ Synthesizes mel spectrograms from texts and speaker embeddings.
75
+
76
+ :param texts: a list of N text prompts to be synthesized
77
+ :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
78
+ :param return_alignments: if True, a matrix representing the alignments between the
79
+ characters
80
+ and each decoder output step will be returned for each spectrogram
81
+ :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
82
+ sequence length of spectrogram i, and possibly the alignments.
83
+ """
84
+ # Load the model on the first request.
85
+ if not self.is_loaded():
86
+ self.load()
87
+
88
+ # Print some info about the model when it is loaded
89
+ tts_k = self._model.get_step() // 1000
90
+
91
+ simple_table([("Tacotron", str(tts_k) + "k"),
92
+ ("r", self._model.r)])
93
+
94
+ # Preprocess text inputs
95
+ inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
96
+ if not isinstance(embeddings, list):
97
+ embeddings = [embeddings]
98
+
99
+ # Batch inputs
100
+ batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
101
+ for i in range(0, len(inputs), hparams.synthesis_batch_size)]
102
+ batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
103
+ for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
104
+
105
+ specs = []
106
+ for i, batch in enumerate(batched_inputs, 1):
107
+ if self.verbose:
108
+ print(f"\n| Generating {i}/{len(batched_inputs)}")
109
+
110
+ # Pad texts so they are all the same length
111
+ text_lens = [len(text) for text in batch]
112
+ max_text_len = max(text_lens)
113
+ chars = [pad1d(text, max_text_len) for text in batch]
114
+ chars = np.stack(chars)
115
+
116
+ # Stack speaker embeddings into 2D array for batch processing
117
+ speaker_embeds = np.stack(batched_embeds[i-1])
118
+
119
+ # Convert to tensor
120
+ chars = torch.tensor(chars).long().to(self.device)
121
+ speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
122
+
123
+ # Inference
124
+ _, mels, alignments = self._model.generate(chars, speaker_embeddings)
125
+ mels = mels.detach().cpu().numpy()
126
+ for m in mels:
127
+ # Trim silence from end of each spectrogram
128
+ while np.max(m[:, -1]) < hparams.tts_stop_threshold:
129
+ m = m[:, :-1]
130
+ specs.append(m)
131
+
132
+ if self.verbose:
133
+ print("\n\nDone.\n")
134
+ return (specs, alignments) if return_alignments else specs
135
+
136
+ @staticmethod
137
+ def load_preprocess_wav(fpath):
138
+ """
139
+ Loads and preprocesses an audio file under the same conditions the audio files were used to
140
+ train the synthesizer.
141
+ """
142
+ wav = librosa.load(str(fpath), hparams.sample_rate)[0]
143
+ if hparams.rescale:
144
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
145
+ return wav
146
+
147
+ @staticmethod
148
+ def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
149
+ """
150
+ Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
151
+ were fed to the synthesizer when training.
152
+ """
153
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
154
+ wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
155
+ else:
156
+ wav = fpath_or_wav
157
+
158
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
159
+ return mel_spectrogram
160
+
161
+ @staticmethod
162
+ def griffin_lim(mel):
163
+ """
164
+ Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
165
+ with the same parameters present in hparams.py.
166
+ """
167
+ return audio.inv_mel_spectrogram(mel, hparams)
168
+
169
+
170
+ def pad1d(x, max_len, pad_value=0):
171
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
synthesizer/models/tacotron.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from pathlib import Path
7
+ from typing import Union
8
+
9
+
10
+ class HighwayNetwork(nn.Module):
11
+ def __init__(self, size):
12
+ super().__init__()
13
+ self.W1 = nn.Linear(size, size)
14
+ self.W2 = nn.Linear(size, size)
15
+ self.W1.bias.data.fill_(0.)
16
+
17
+ def forward(self, x):
18
+ x1 = self.W1(x)
19
+ x2 = self.W2(x)
20
+ g = torch.sigmoid(x2)
21
+ y = g * F.relu(x1) + (1. - g) * x
22
+ return y
23
+
24
+
25
+ class Encoder(nn.Module):
26
+ def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
27
+ super().__init__()
28
+ prenet_dims = (encoder_dims, encoder_dims)
29
+ cbhg_channels = encoder_dims
30
+ self.embedding = nn.Embedding(num_chars, embed_dims)
31
+ self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
32
+ dropout=dropout)
33
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
34
+ proj_channels=[cbhg_channels, cbhg_channels],
35
+ num_highways=num_highways)
36
+
37
+ def forward(self, x, speaker_embedding=None):
38
+ x = self.embedding(x)
39
+ x = self.pre_net(x)
40
+ x.transpose_(1, 2)
41
+ x = self.cbhg(x)
42
+ if speaker_embedding is not None:
43
+ x = self.add_speaker_embedding(x, speaker_embedding)
44
+ return x
45
+
46
+ def add_speaker_embedding(self, x, speaker_embedding):
47
+ # SV2TTS
48
+ # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
49
+ # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
50
+ # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
51
+ # This concats the speaker embedding for each char in the encoder output
52
+
53
+ # Save the dimensions as human-readable names
54
+ batch_size = x.size()[0]
55
+ num_chars = x.size()[1]
56
+
57
+ if speaker_embedding.dim() == 1:
58
+ idx = 0
59
+ else:
60
+ idx = 1
61
+
62
+ # Start by making a copy of each speaker embedding to match the input text length
63
+ # The output of this has size (batch_size, num_chars * tts_embed_dims)
64
+ speaker_embedding_size = speaker_embedding.size()[idx]
65
+ e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
66
+
67
+ # Reshape it and transpose
68
+ e = e.reshape(batch_size, speaker_embedding_size, num_chars)
69
+ e = e.transpose(1, 2)
70
+
71
+ # Concatenate the tiled speaker embedding with the encoder output
72
+ x = torch.cat((x, e), 2)
73
+ return x
74
+
75
+
76
+ class BatchNormConv(nn.Module):
77
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
78
+ super().__init__()
79
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
80
+ self.bnorm = nn.BatchNorm1d(out_channels)
81
+ self.relu = relu
82
+
83
+ def forward(self, x):
84
+ x = self.conv(x)
85
+ x = F.relu(x) if self.relu is True else x
86
+ return self.bnorm(x)
87
+
88
+
89
+ class CBHG(nn.Module):
90
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
91
+ super().__init__()
92
+
93
+ # List of all rnns to call `flatten_parameters()` on
94
+ self._to_flatten = []
95
+
96
+ self.bank_kernels = [i for i in range(1, K + 1)]
97
+ self.conv1d_bank = nn.ModuleList()
98
+ for k in self.bank_kernels:
99
+ conv = BatchNormConv(in_channels, channels, k)
100
+ self.conv1d_bank.append(conv)
101
+
102
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
103
+
104
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
105
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
106
+
107
+ # Fix the highway input if necessary
108
+ if proj_channels[-1] != channels:
109
+ self.highway_mismatch = True
110
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
111
+ else:
112
+ self.highway_mismatch = False
113
+
114
+ self.highways = nn.ModuleList()
115
+ for i in range(num_highways):
116
+ hn = HighwayNetwork(channels)
117
+ self.highways.append(hn)
118
+
119
+ self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
120
+ self._to_flatten.append(self.rnn)
121
+
122
+ # Avoid fragmentation of RNN parameters and associated warning
123
+ self._flatten_parameters()
124
+
125
+ def forward(self, x):
126
+ # Although we `_flatten_parameters()` on init, when using DataParallel
127
+ # the model gets replicated, making it no longer guaranteed that the
128
+ # weights are contiguous in GPU memory. Hence, we must call it again
129
+ self._flatten_parameters()
130
+
131
+ # Save these for later
132
+ residual = x
133
+ seq_len = x.size(-1)
134
+ conv_bank = []
135
+
136
+ # Convolution Bank
137
+ for conv in self.conv1d_bank:
138
+ c = conv(x) # Convolution
139
+ conv_bank.append(c[:, :, :seq_len])
140
+
141
+ # Stack along the channel axis
142
+ conv_bank = torch.cat(conv_bank, dim=1)
143
+
144
+ # dump the last padding to fit residual
145
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
146
+
147
+ # Conv1d projections
148
+ x = self.conv_project1(x)
149
+ x = self.conv_project2(x)
150
+
151
+ # Residual Connect
152
+ x = x + residual
153
+
154
+ # Through the highways
155
+ x = x.transpose(1, 2)
156
+ if self.highway_mismatch is True:
157
+ x = self.pre_highway(x)
158
+ for h in self.highways: x = h(x)
159
+
160
+ # And then the RNN
161
+ x, _ = self.rnn(x)
162
+ return x
163
+
164
+ def _flatten_parameters(self):
165
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
166
+ to improve efficiency and avoid PyTorch yelling at us."""
167
+ [m.flatten_parameters() for m in self._to_flatten]
168
+
169
+ class PreNet(nn.Module):
170
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
171
+ super().__init__()
172
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
173
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
174
+ self.p = dropout
175
+
176
+ def forward(self, x):
177
+ x = self.fc1(x)
178
+ x = F.relu(x)
179
+ x = F.dropout(x, self.p, training=True)
180
+ x = self.fc2(x)
181
+ x = F.relu(x)
182
+ x = F.dropout(x, self.p, training=True)
183
+ return x
184
+
185
+
186
+ class Attention(nn.Module):
187
+ def __init__(self, attn_dims):
188
+ super().__init__()
189
+ self.W = nn.Linear(attn_dims, attn_dims, bias=False)
190
+ self.v = nn.Linear(attn_dims, 1, bias=False)
191
+
192
+ def forward(self, encoder_seq_proj, query, t):
193
+
194
+ # print(encoder_seq_proj.shape)
195
+ # Transform the query vector
196
+ query_proj = self.W(query).unsqueeze(1)
197
+
198
+ # Compute the scores
199
+ u = self.v(torch.tanh(encoder_seq_proj + query_proj))
200
+ scores = F.softmax(u, dim=1)
201
+
202
+ return scores.transpose(1, 2)
203
+
204
+
205
+ class LSA(nn.Module):
206
+ def __init__(self, attn_dim, kernel_size=31, filters=32):
207
+ super().__init__()
208
+ self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
209
+ self.L = nn.Linear(filters, attn_dim, bias=False)
210
+ self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
211
+ self.v = nn.Linear(attn_dim, 1, bias=False)
212
+ self.cumulative = None
213
+ self.attention = None
214
+
215
+ def init_attention(self, encoder_seq_proj):
216
+ device = next(self.parameters()).device # use same device as parameters
217
+ b, t, c = encoder_seq_proj.size()
218
+ self.cumulative = torch.zeros(b, t, device=device)
219
+ self.attention = torch.zeros(b, t, device=device)
220
+
221
+ def forward(self, encoder_seq_proj, query, t, chars):
222
+
223
+ if t == 0: self.init_attention(encoder_seq_proj)
224
+
225
+ processed_query = self.W(query).unsqueeze(1)
226
+
227
+ location = self.cumulative.unsqueeze(1)
228
+ processed_loc = self.L(self.conv(location).transpose(1, 2))
229
+
230
+ u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
231
+ u = u.squeeze(-1)
232
+
233
+ # Mask zero padding chars
234
+ u = u * (chars != 0).float()
235
+
236
+ # Smooth Attention
237
+ # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
238
+ scores = F.softmax(u, dim=1)
239
+ self.attention = scores
240
+ self.cumulative = self.cumulative + self.attention
241
+
242
+ return scores.unsqueeze(-1).transpose(1, 2)
243
+
244
+
245
+ class Decoder(nn.Module):
246
+ # Class variable because its value doesn't change between classes
247
+ # yet ought to be scoped by class because its a property of a Decoder
248
+ max_r = 20
249
+ def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
250
+ dropout, speaker_embedding_size):
251
+ super().__init__()
252
+ self.register_buffer("r", torch.tensor(1, dtype=torch.int))
253
+ self.n_mels = n_mels
254
+ prenet_dims = (decoder_dims * 2, decoder_dims * 2)
255
+ self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
256
+ dropout=dropout)
257
+ self.attn_net = LSA(decoder_dims)
258
+ self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
259
+ self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
260
+ self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
261
+ self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
262
+ self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
263
+ self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
264
+
265
+ def zoneout(self, prev, current, p=0.1):
266
+ device = next(self.parameters()).device # Use same device as parameters
267
+ mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
268
+ return prev * mask + current * (1 - mask)
269
+
270
+ def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
271
+ hidden_states, cell_states, context_vec, t, chars):
272
+
273
+ # Need this for reshaping mels
274
+ batch_size = encoder_seq.size(0)
275
+
276
+ # Unpack the hidden and cell states
277
+ attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
278
+ rnn1_cell, rnn2_cell = cell_states
279
+
280
+ # PreNet for the Attention RNN
281
+ prenet_out = self.prenet(prenet_in)
282
+
283
+ # Compute the Attention RNN hidden state
284
+ attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
285
+ attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
286
+
287
+ # Compute the attention scores
288
+ scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
289
+
290
+ # Dot product to create the context vector
291
+ context_vec = scores @ encoder_seq
292
+ context_vec = context_vec.squeeze(1)
293
+
294
+ # Concat Attention RNN output w. Context Vector & project
295
+ x = torch.cat([context_vec, attn_hidden], dim=1)
296
+ x = self.rnn_input(x)
297
+
298
+ # Compute first Residual RNN
299
+ rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
300
+ if self.training:
301
+ rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
302
+ else:
303
+ rnn1_hidden = rnn1_hidden_next
304
+ x = x + rnn1_hidden
305
+
306
+ # Compute second Residual RNN
307
+ rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
308
+ if self.training:
309
+ rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
310
+ else:
311
+ rnn2_hidden = rnn2_hidden_next
312
+ x = x + rnn2_hidden
313
+
314
+ # Project Mels
315
+ mels = self.mel_proj(x)
316
+ mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
317
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
318
+ cell_states = (rnn1_cell, rnn2_cell)
319
+
320
+ # Stop token prediction
321
+ s = torch.cat((x, context_vec), dim=1)
322
+ s = self.stop_proj(s)
323
+ stop_tokens = torch.sigmoid(s)
324
+
325
+ return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
326
+
327
+
328
+ class Tacotron(nn.Module):
329
+ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
330
+ fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
331
+ dropout, stop_threshold, speaker_embedding_size):
332
+ super().__init__()
333
+ self.n_mels = n_mels
334
+ self.lstm_dims = lstm_dims
335
+ self.encoder_dims = encoder_dims
336
+ self.decoder_dims = decoder_dims
337
+ self.speaker_embedding_size = speaker_embedding_size
338
+ self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
339
+ encoder_K, num_highways, dropout)
340
+ self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
341
+ self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
342
+ dropout, speaker_embedding_size)
343
+ self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
344
+ [postnet_dims, fft_bins], num_highways)
345
+ self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
346
+
347
+ self.init_model()
348
+ self.num_params()
349
+
350
+ self.register_buffer("step", torch.zeros(1, dtype=torch.long))
351
+ self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
352
+
353
+ @property
354
+ def r(self):
355
+ return self.decoder.r.item()
356
+
357
+ @r.setter
358
+ def r(self, value):
359
+ self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
360
+
361
+ def forward(self, x, m, speaker_embedding):
362
+ device = next(self.parameters()).device # use same device as parameters
363
+
364
+ self.step += 1
365
+ batch_size, _, steps = m.size()
366
+
367
+ # Initialise all hidden states and pack into tuple
368
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
369
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
370
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
371
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
372
+
373
+ # Initialise all lstm cell states and pack into tuple
374
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
375
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
376
+ cell_states = (rnn1_cell, rnn2_cell)
377
+
378
+ # <GO> Frame for start of decoder loop
379
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
380
+
381
+ # Need an initial context vector
382
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
383
+
384
+ # SV2TTS: Run the encoder with the speaker embedding
385
+ # The projection avoids unnecessary matmuls in the decoder loop
386
+ encoder_seq = self.encoder(x, speaker_embedding)
387
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
388
+
389
+ # Need a couple of lists for outputs
390
+ mel_outputs, attn_scores, stop_outputs = [], [], []
391
+
392
+ # Run the decoder loop
393
+ for t in range(0, steps, self.r):
394
+ prenet_in = m[:, :, t - 1] if t > 0 else go_frame
395
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
396
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
397
+ hidden_states, cell_states, context_vec, t, x)
398
+ mel_outputs.append(mel_frames)
399
+ attn_scores.append(scores)
400
+ stop_outputs.extend([stop_tokens] * self.r)
401
+
402
+ # Concat the mel outputs into sequence
403
+ mel_outputs = torch.cat(mel_outputs, dim=2)
404
+
405
+ # Post-Process for Linear Spectrograms
406
+ postnet_out = self.postnet(mel_outputs)
407
+ linear = self.post_proj(postnet_out)
408
+ linear = linear.transpose(1, 2)
409
+
410
+ # For easy visualisation
411
+ attn_scores = torch.cat(attn_scores, 1)
412
+ # attn_scores = attn_scores.cpu().data.numpy()
413
+ stop_outputs = torch.cat(stop_outputs, 1)
414
+
415
+ return mel_outputs, linear, attn_scores, stop_outputs
416
+
417
+ def generate(self, x, speaker_embedding=None, steps=2000):
418
+ self.eval()
419
+ device = next(self.parameters()).device # use same device as parameters
420
+
421
+ batch_size, _ = x.size()
422
+
423
+ # Need to initialise all hidden states and pack into tuple for tidyness
424
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
425
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
426
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
427
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
428
+
429
+ # Need to initialise all lstm cell states and pack into tuple for tidyness
430
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
431
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
432
+ cell_states = (rnn1_cell, rnn2_cell)
433
+
434
+ # Need a <GO> Frame for start of decoder loop
435
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
436
+
437
+ # Need an initial context vector
438
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
439
+
440
+ # SV2TTS: Run the encoder with the speaker embedding
441
+ # The projection avoids unnecessary matmuls in the decoder loop
442
+ encoder_seq = self.encoder(x, speaker_embedding)
443
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
444
+
445
+ # Need a couple of lists for outputs
446
+ mel_outputs, attn_scores, stop_outputs = [], [], []
447
+
448
+ # Run the decoder loop
449
+ for t in range(0, steps, self.r):
450
+ prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
451
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
452
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
453
+ hidden_states, cell_states, context_vec, t, x)
454
+ mel_outputs.append(mel_frames)
455
+ attn_scores.append(scores)
456
+ stop_outputs.extend([stop_tokens] * self.r)
457
+ # Stop the loop when all stop tokens in batch exceed threshold
458
+ if (stop_tokens > 0.5).all() and t > 10: break
459
+
460
+ # Concat the mel outputs into sequence
461
+ mel_outputs = torch.cat(mel_outputs, dim=2)
462
+
463
+ # Post-Process for Linear Spectrograms
464
+ postnet_out = self.postnet(mel_outputs)
465
+ linear = self.post_proj(postnet_out)
466
+
467
+
468
+ linear = linear.transpose(1, 2)
469
+
470
+ # For easy visualisation
471
+ attn_scores = torch.cat(attn_scores, 1)
472
+ stop_outputs = torch.cat(stop_outputs, 1)
473
+
474
+ self.train()
475
+
476
+ return mel_outputs, linear, attn_scores
477
+
478
+ def init_model(self):
479
+ for p in self.parameters():
480
+ if p.dim() > 1: nn.init.xavier_uniform_(p)
481
+
482
+ def get_step(self):
483
+ return self.step.data.item()
484
+
485
+ def reset_step(self):
486
+ # assignment to parameters or buffers is overloaded, updates internal dict entry
487
+ self.step = self.step.data.new_tensor(1)
488
+
489
+ def log(self, path, msg):
490
+ with open(path, "a") as f:
491
+ print(msg, file=f)
492
+
493
+ def load(self, path, optimizer=None):
494
+ # Use device of model params as location for loaded state
495
+ device = next(self.parameters()).device
496
+ checkpoint = torch.load(str(path), map_location=device)
497
+ self.load_state_dict(checkpoint["model_state"])
498
+
499
+ if "optimizer_state" in checkpoint and optimizer is not None:
500
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
501
+
502
+ def save(self, path, optimizer=None):
503
+ if optimizer is not None:
504
+ torch.save({
505
+ "model_state": self.state_dict(),
506
+ "optimizer_state": optimizer.state_dict(),
507
+ }, str(path))
508
+ else:
509
+ torch.save({
510
+ "model_state": self.state_dict(),
511
+ }, str(path))
512
+
513
+
514
+ def num_params(self, print_out=True):
515
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
516
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
517
+ if print_out:
518
+ print("Trainable Parameters: %.3fM" % parameters)
519
+ return parameters
synthesizer/preprocess.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocessing.pool import Pool
2
+ from synthesizer import audio
3
+ from functools import partial
4
+ from itertools import chain
5
+ from encoder import inference as encoder
6
+ from pathlib import Path
7
+ from utils import logmmse
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import librosa
11
+
12
+
13
+ def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int,
14
+ skip_existing: bool, hparams, no_alignments: bool,
15
+ datasets_name: str, subfolders: str):
16
+ # Gather the input directories
17
+ dataset_root = datasets_root.joinpath(datasets_name)
18
+ input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in subfolders.split(",")]
19
+ print("\n ".join(map(str, ["Using data from:"] + input_dirs)))
20
+ assert all(input_dir.exists() for input_dir in input_dirs)
21
+
22
+ # Create the output directories for each output file type
23
+ out_dir.joinpath("mels").mkdir(exist_ok=True)
24
+ out_dir.joinpath("audio").mkdir(exist_ok=True)
25
+
26
+ # Create a metadata file
27
+ metadata_fpath = out_dir.joinpath("train.txt")
28
+ metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8")
29
+
30
+ # Preprocess the dataset
31
+ speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs))
32
+ func = partial(preprocess_speaker, out_dir=out_dir, skip_existing=skip_existing,
33
+ hparams=hparams, no_alignments=no_alignments)
34
+ job = Pool(n_processes).imap(func, speaker_dirs)
35
+ for speaker_metadata in tqdm(job, datasets_name, len(speaker_dirs), unit="speakers"):
36
+ for metadatum in speaker_metadata:
37
+ metadata_file.write("|".join(str(x) for x in metadatum) + "\n")
38
+ metadata_file.close()
39
+
40
+ # Verify the contents of the metadata file
41
+ with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
42
+ metadata = [line.split("|") for line in metadata_file]
43
+ mel_frames = sum([int(m[4]) for m in metadata])
44
+ timesteps = sum([int(m[3]) for m in metadata])
45
+ sample_rate = hparams.sample_rate
46
+ hours = (timesteps / sample_rate) / 3600
47
+ print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." %
48
+ (len(metadata), mel_frames, timesteps, hours))
49
+ print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata))
50
+ print("Max mel frames length: %d" % max(int(m[4]) for m in metadata))
51
+ print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata))
52
+
53
+
54
+ def preprocess_speaker(speaker_dir, out_dir: Path, skip_existing: bool, hparams, no_alignments: bool):
55
+ metadata = []
56
+ for book_dir in speaker_dir.glob("*"):
57
+ if no_alignments:
58
+ # Gather the utterance audios and texts
59
+ # LibriTTS uses .wav but we will include extensions for compatibility with other datasets
60
+ extensions = ["*.wav", "*.flac", "*.mp3"]
61
+ for extension in extensions:
62
+ wav_fpaths = book_dir.glob(extension)
63
+
64
+ for wav_fpath in wav_fpaths:
65
+ # Load the audio waveform
66
+ wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
67
+ if hparams.rescale:
68
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
69
+
70
+ # Get the corresponding text
71
+ # Check for .txt (for compatibility with other datasets)
72
+ text_fpath = wav_fpath.with_suffix(".txt")
73
+ if not text_fpath.exists():
74
+ # Check for .normalized.txt (LibriTTS)
75
+ text_fpath = wav_fpath.with_suffix(".normalized.txt")
76
+ assert text_fpath.exists()
77
+ with text_fpath.open("r") as text_file:
78
+ text = "".join([line for line in text_file])
79
+ text = text.replace("\"", "")
80
+ text = text.strip()
81
+
82
+ # Process the utterance
83
+ metadata.append(process_utterance(wav, text, out_dir, str(wav_fpath.with_suffix("").name),
84
+ skip_existing, hparams))
85
+ else:
86
+ # Process alignment file (LibriSpeech support)
87
+ # Gather the utterance audios and texts
88
+ try:
89
+ alignments_fpath = next(book_dir.glob("*.alignment.txt"))
90
+ with alignments_fpath.open("r") as alignments_file:
91
+ alignments = [line.rstrip().split(" ") for line in alignments_file]
92
+ except StopIteration:
93
+ # A few alignment files will be missing
94
+ continue
95
+
96
+ # Iterate over each entry in the alignments file
97
+ for wav_fname, words, end_times in alignments:
98
+ wav_fpath = book_dir.joinpath(wav_fname + ".flac")
99
+ assert wav_fpath.exists()
100
+ words = words.replace("\"", "").split(",")
101
+ end_times = list(map(float, end_times.replace("\"", "").split(",")))
102
+
103
+ # Process each sub-utterance
104
+ wavs, texts = split_on_silences(wav_fpath, words, end_times, hparams)
105
+ for i, (wav, text) in enumerate(zip(wavs, texts)):
106
+ sub_basename = "%s_%02d" % (wav_fname, i)
107
+ metadata.append(process_utterance(wav, text, out_dir, sub_basename,
108
+ skip_existing, hparams))
109
+
110
+ return [m for m in metadata if m is not None]
111
+
112
+
113
+ def split_on_silences(wav_fpath, words, end_times, hparams):
114
+ # Load the audio waveform
115
+ wav, _ = librosa.load(str(wav_fpath), hparams.sample_rate)
116
+ if hparams.rescale:
117
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
118
+
119
+ words = np.array(words)
120
+ start_times = np.array([0.0] + end_times[:-1])
121
+ end_times = np.array(end_times)
122
+ assert len(words) == len(end_times) == len(start_times)
123
+ assert words[0] == "" and words[-1] == ""
124
+
125
+ # Find pauses that are too long
126
+ mask = (words == "") & (end_times - start_times >= hparams.silence_min_duration_split)
127
+ mask[0] = mask[-1] = True
128
+ breaks = np.where(mask)[0]
129
+
130
+ # Profile the noise from the silences and perform noise reduction on the waveform
131
+ silence_times = [[start_times[i], end_times[i]] for i in breaks]
132
+ silence_times = (np.array(silence_times) * hparams.sample_rate).astype(np.int)
133
+ noisy_wav = np.concatenate([wav[stime[0]:stime[1]] for stime in silence_times])
134
+ if len(noisy_wav) > hparams.sample_rate * 0.02:
135
+ profile = logmmse.profile_noise(noisy_wav, hparams.sample_rate)
136
+ wav = logmmse.denoise(wav, profile, eta=0)
137
+
138
+ # Re-attach segments that are too short
139
+ segments = list(zip(breaks[:-1], breaks[1:]))
140
+ segment_durations = [start_times[end] - end_times[start] for start, end in segments]
141
+ i = 0
142
+ while i < len(segments) and len(segments) > 1:
143
+ if segment_durations[i] < hparams.utterance_min_duration:
144
+ # See if the segment can be re-attached with the right or the left segment
145
+ left_duration = float("inf") if i == 0 else segment_durations[i - 1]
146
+ right_duration = float("inf") if i == len(segments) - 1 else segment_durations[i + 1]
147
+ joined_duration = segment_durations[i] + min(left_duration, right_duration)
148
+
149
+ # Do not re-attach if it causes the joined utterance to be too long
150
+ if joined_duration > hparams.hop_size * hparams.max_mel_frames / hparams.sample_rate:
151
+ i += 1
152
+ continue
153
+
154
+ # Re-attach the segment with the neighbour of shortest duration
155
+ j = i - 1 if left_duration <= right_duration else i
156
+ segments[j] = (segments[j][0], segments[j + 1][1])
157
+ segment_durations[j] = joined_duration
158
+ del segments[j + 1], segment_durations[j + 1]
159
+ else:
160
+ i += 1
161
+
162
+ # Split the utterance
163
+ segment_times = [[end_times[start], start_times[end]] for start, end in segments]
164
+ segment_times = (np.array(segment_times) * hparams.sample_rate).astype(np.int)
165
+ wavs = [wav[segment_time[0]:segment_time[1]] for segment_time in segment_times]
166
+ texts = [" ".join(words[start + 1:end]).replace(" ", " ") for start, end in segments]
167
+
168
+ # # DEBUG: play the audio segments (run with -n=1)
169
+ # import sounddevice as sd
170
+ # if len(wavs) > 1:
171
+ # print("This sentence was split in %d segments:" % len(wavs))
172
+ # else:
173
+ # print("There are no silences long enough for this sentence to be split:")
174
+ # for wav, text in zip(wavs, texts):
175
+ # # Pad the waveform with 1 second of silence because sounddevice tends to cut them early
176
+ # # when playing them. You shouldn't need to do that in your parsers.
177
+ # wav = np.concatenate((wav, [0] * 16000))
178
+ # print("\t%s" % text)
179
+ # sd.play(wav, 16000, blocking=True)
180
+ # print("")
181
+
182
+ return wavs, texts
183
+
184
+
185
+ def process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
186
+ skip_existing: bool, hparams):
187
+ ## FOR REFERENCE:
188
+ # For you not to lose your head if you ever wish to change things here or implement your own
189
+ # synthesizer.
190
+ # - Both the audios and the mel spectrograms are saved as numpy arrays
191
+ # - There is no processing done to the audios that will be saved to disk beyond volume
192
+ # normalization (in split_on_silences)
193
+ # - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This
194
+ # is why we re-apply it on the audio on the side of the vocoder.
195
+ # - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved
196
+ # without extra padding. This means that you won't have an exact relation between the length
197
+ # of the wav and of the mel spectrogram. See the vocoder data loader.
198
+
199
+
200
+ # Skip existing utterances if needed
201
+ mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
202
+ wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
203
+ if skip_existing and mel_fpath.exists() and wav_fpath.exists():
204
+ return None
205
+
206
+ # Trim silence
207
+ if hparams.trim_silence:
208
+ wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True)
209
+
210
+ # Skip utterances that are too short
211
+ if len(wav) < hparams.utterance_min_duration * hparams.sample_rate:
212
+ return None
213
+
214
+ # Compute the mel spectrogram
215
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
216
+ mel_frames = mel_spectrogram.shape[1]
217
+
218
+ # Skip utterances that are too long
219
+ if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length:
220
+ return None
221
+
222
+ # Write the spectrogram, embed and audio to disk
223
+ np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
224
+ np.save(wav_fpath, wav, allow_pickle=False)
225
+
226
+ # Return a tuple describing this training example
227
+ return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
228
+
229
+
230
+ def embed_utterance(fpaths, encoder_model_fpath):
231
+ if not encoder.is_loaded():
232
+ encoder.load_model(encoder_model_fpath)
233
+
234
+ # Compute the speaker embedding of the utterance
235
+ wav_fpath, embed_fpath = fpaths
236
+ wav = np.load(wav_fpath)
237
+ wav = encoder.preprocess_wav(wav)
238
+ embed = encoder.embed_utterance(wav)
239
+ np.save(embed_fpath, embed, allow_pickle=False)
240
+
241
+
242
+ def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int):
243
+ wav_dir = synthesizer_root.joinpath("audio")
244
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
245
+ assert wav_dir.exists() and metadata_fpath.exists()
246
+ embed_dir = synthesizer_root.joinpath("embeds")
247
+ embed_dir.mkdir(exist_ok=True)
248
+
249
+ # Gather the input wave filepath and the target output embed filepath
250
+ with metadata_fpath.open("r") as metadata_file:
251
+ metadata = [line.split("|") for line in metadata_file]
252
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
253
+
254
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
255
+ # Embed the utterances in separate threads
256
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
257
+ job = Pool(n_processes).imap(func, fpaths)
258
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
259
+
synthesizer/synthesize.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from synthesizer.hparams import hparams_debug_string
4
+ from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
5
+ from synthesizer.models.tacotron import Tacotron
6
+ from synthesizer.utils.text import text_to_sequence
7
+ from synthesizer.utils.symbols import symbols
8
+ import numpy as np
9
+ from pathlib import Path
10
+ from tqdm import tqdm
11
+ import platform
12
+
13
+ def run_synthesis(in_dir, out_dir, model_dir, hparams):
14
+ # This generates ground truth-aligned mels for vocoder training
15
+ synth_dir = Path(out_dir).joinpath("mels_gta")
16
+ synth_dir.mkdir(exist_ok=True)
17
+ print(hparams_debug_string())
18
+
19
+ # Check for GPU
20
+ if torch.cuda.is_available():
21
+ device = torch.device("cuda")
22
+ if hparams.synthesis_batch_size % torch.cuda.device_count() != 0:
23
+ raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible by n_gpus!")
24
+ else:
25
+ device = torch.device("cpu")
26
+ print("Synthesizer using device:", device)
27
+
28
+ # Instantiate Tacotron model
29
+ model = Tacotron(embed_dims=hparams.tts_embed_dims,
30
+ num_chars=len(symbols),
31
+ encoder_dims=hparams.tts_encoder_dims,
32
+ decoder_dims=hparams.tts_decoder_dims,
33
+ n_mels=hparams.num_mels,
34
+ fft_bins=hparams.num_mels,
35
+ postnet_dims=hparams.tts_postnet_dims,
36
+ encoder_K=hparams.tts_encoder_K,
37
+ lstm_dims=hparams.tts_lstm_dims,
38
+ postnet_K=hparams.tts_postnet_K,
39
+ num_highways=hparams.tts_num_highways,
40
+ dropout=0., # Use zero dropout for gta mels
41
+ stop_threshold=hparams.tts_stop_threshold,
42
+ speaker_embedding_size=hparams.speaker_embedding_size).to(device)
43
+
44
+ # Load the weights
45
+ model_dir = Path(model_dir)
46
+ model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt")
47
+ print("\nLoading weights at %s" % model_fpath)
48
+ model.load(model_fpath)
49
+ print("Tacotron weights loaded from step %d" % model.step)
50
+
51
+ # Synthesize using same reduction factor as the model is currently trained
52
+ r = np.int32(model.r)
53
+
54
+ # Set model to eval mode (disable gradient and zoneout)
55
+ model.eval()
56
+
57
+ # Initialize the dataset
58
+ in_dir = Path(in_dir)
59
+ metadata_fpath = in_dir.joinpath("train.txt")
60
+ mel_dir = in_dir.joinpath("mels")
61
+ embed_dir = in_dir.joinpath("embeds")
62
+
63
+ dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
64
+ data_loader = DataLoader(dataset,
65
+ collate_fn=lambda batch: collate_synthesizer(batch, r, hparams),
66
+ batch_size=hparams.synthesis_batch_size,
67
+ num_workers=2 if platform.system() != "Windows" else 0,
68
+ shuffle=False,
69
+ pin_memory=True)
70
+
71
+ # Generate GTA mels
72
+ meta_out_fpath = Path(out_dir).joinpath("synthesized.txt")
73
+ with open(meta_out_fpath, "w") as file:
74
+ for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)):
75
+ texts = texts.to(device)
76
+ mels = mels.to(device)
77
+ embeds = embeds.to(device)
78
+
79
+ # Parallelize model onto GPUS using workaround due to python bug
80
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
81
+ _, mels_out, _ = data_parallel_workaround(model, texts, mels, embeds)
82
+ else:
83
+ _, mels_out, _, _ = model(texts, mels, embeds)
84
+
85
+ for j, k in enumerate(idx):
86
+ # Note: outputs mel-spectrogram files and target ones have same names, just different folders
87
+ mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1])
88
+ mel_out = mels_out[j].detach().cpu().numpy().T
89
+
90
+ # Use the length of the ground truth mel to remove padding from the generated mels
91
+ mel_out = mel_out[:int(dataset.metadata[k][4])]
92
+
93
+ # Write the spectrogram to disk
94
+ np.save(mel_filename, mel_out, allow_pickle=False)
95
+
96
+ # Write metadata into the synthesized file
97
+ file.write("|".join(dataset.metadata[k]))
synthesizer/synthesizer_dataset.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from synthesizer.utils.text import text_to_sequence
6
+
7
+
8
+ class SynthesizerDataset(Dataset):
9
+ def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams):
10
+ print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir))
11
+
12
+ with metadata_fpath.open("r") as metadata_file:
13
+ metadata = [line.split("|") for line in metadata_file]
14
+
15
+ mel_fnames = [x[1] for x in metadata if int(x[4])]
16
+ mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames]
17
+ embed_fnames = [x[2] for x in metadata if int(x[4])]
18
+ embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames]
19
+ self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths))
20
+ self.samples_texts = [x[5].strip() for x in metadata if int(x[4])]
21
+ self.metadata = metadata
22
+ self.hparams = hparams
23
+
24
+ print("Found %d samples" % len(self.samples_fpaths))
25
+
26
+ def __getitem__(self, index):
27
+ # Sometimes index may be a list of 2 (not sure why this happens)
28
+ # If that is the case, return a single item corresponding to first element in index
29
+ if index is list:
30
+ index = index[0]
31
+
32
+ mel_path, embed_path = self.samples_fpaths[index]
33
+ mel = np.load(mel_path).T.astype(np.float32)
34
+
35
+ # Load the embed
36
+ embed = np.load(embed_path)
37
+
38
+ # Get the text and clean it
39
+ text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names)
40
+
41
+ # Convert the list returned by text_to_sequence to a numpy array
42
+ text = np.asarray(text).astype(np.int32)
43
+
44
+ return text, mel.astype(np.float32), embed.astype(np.float32), index
45
+
46
+ def __len__(self):
47
+ return len(self.samples_fpaths)
48
+
49
+
50
+ def collate_synthesizer(batch, r, hparams):
51
+ # Text
52
+ x_lens = [len(x[0]) for x in batch]
53
+ max_x_len = max(x_lens)
54
+
55
+ chars = [pad1d(x[0], max_x_len) for x in batch]
56
+ chars = np.stack(chars)
57
+
58
+ # Mel spectrogram
59
+ spec_lens = [x[1].shape[-1] for x in batch]
60
+ max_spec_len = max(spec_lens) + 1
61
+ if max_spec_len % r != 0:
62
+ max_spec_len += r - max_spec_len % r
63
+
64
+ # WaveRNN mel spectrograms are normalized to [0, 1] so zero padding adds silence
65
+ # By default, SV2TTS uses symmetric mels, where -1*max_abs_value is silence.
66
+ if hparams.symmetric_mels:
67
+ mel_pad_value = -1 * hparams.max_abs_value
68
+ else:
69
+ mel_pad_value = 0
70
+
71
+ mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch]
72
+ mel = np.stack(mel)
73
+
74
+ # Speaker embedding (SV2TTS)
75
+ embeds = [x[2] for x in batch]
76
+
77
+ # Index (for vocoder preprocessing)
78
+ indices = [x[3] for x in batch]
79
+
80
+
81
+ # Convert all to tensor
82
+ chars = torch.tensor(chars).long()
83
+ mel = torch.tensor(mel)
84
+ embeds = torch.tensor(embeds)
85
+
86
+ return chars, mel, embeds, indices
87
+
88
+ def pad1d(x, max_len, pad_value=0):
89
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
90
+
91
+ def pad2d(x, max_len, pad_value=0):
92
+ return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value)
synthesizer/train.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import optim
4
+ from torch.utils.data import DataLoader
5
+ from synthesizer import audio
6
+ from synthesizer.models.tacotron import Tacotron
7
+ from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
8
+ from synthesizer.utils import ValueWindow, data_parallel_workaround
9
+ from synthesizer.utils.plot import plot_spectrogram
10
+ from synthesizer.utils.symbols import symbols
11
+ from synthesizer.utils.text import sequence_to_text
12
+ from vocoder.display import *
13
+ from datetime import datetime
14
+ import numpy as np
15
+ from pathlib import Path
16
+ import sys
17
+ import time
18
+ import platform
19
+
20
+
21
+ def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
22
+
23
+ def time_string():
24
+ return datetime.now().strftime("%Y-%m-%d %H:%M")
25
+
26
+ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
27
+ backup_every: int, force_restart:bool, hparams):
28
+
29
+ syn_dir = Path(syn_dir)
30
+ models_dir = Path(models_dir)
31
+ models_dir.mkdir(exist_ok=True)
32
+
33
+ model_dir = models_dir.joinpath(run_id)
34
+ plot_dir = model_dir.joinpath("plots")
35
+ wav_dir = model_dir.joinpath("wavs")
36
+ mel_output_dir = model_dir.joinpath("mel-spectrograms")
37
+ meta_folder = model_dir.joinpath("metas")
38
+ model_dir.mkdir(exist_ok=True)
39
+ plot_dir.mkdir(exist_ok=True)
40
+ wav_dir.mkdir(exist_ok=True)
41
+ mel_output_dir.mkdir(exist_ok=True)
42
+ meta_folder.mkdir(exist_ok=True)
43
+
44
+ weights_fpath = model_dir.joinpath(run_id).with_suffix(".pt")
45
+ metadata_fpath = syn_dir.joinpath("train.txt")
46
+
47
+ print("Checkpoint path: {}".format(weights_fpath))
48
+ print("Loading training data from: {}".format(metadata_fpath))
49
+ print("Using model: Tacotron")
50
+
51
+ # Book keeping
52
+ step = 0
53
+ time_window = ValueWindow(100)
54
+ loss_window = ValueWindow(100)
55
+
56
+
57
+ # From WaveRNN/train_tacotron.py
58
+ if torch.cuda.is_available():
59
+ device = torch.device("cuda")
60
+
61
+ for session in hparams.tts_schedule:
62
+ _, _, _, batch_size = session
63
+ if batch_size % torch.cuda.device_count() != 0:
64
+ raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
65
+ else:
66
+ device = torch.device("cpu")
67
+ print("Using device:", device)
68
+
69
+ # Instantiate Tacotron Model
70
+ print("\nInitialising Tacotron Model...\n")
71
+ model = Tacotron(embed_dims=hparams.tts_embed_dims,
72
+ num_chars=len(symbols),
73
+ encoder_dims=hparams.tts_encoder_dims,
74
+ decoder_dims=hparams.tts_decoder_dims,
75
+ n_mels=hparams.num_mels,
76
+ fft_bins=hparams.num_mels,
77
+ postnet_dims=hparams.tts_postnet_dims,
78
+ encoder_K=hparams.tts_encoder_K,
79
+ lstm_dims=hparams.tts_lstm_dims,
80
+ postnet_K=hparams.tts_postnet_K,
81
+ num_highways=hparams.tts_num_highways,
82
+ dropout=hparams.tts_dropout,
83
+ stop_threshold=hparams.tts_stop_threshold,
84
+ speaker_embedding_size=hparams.speaker_embedding_size).to(device)
85
+
86
+ # Initialize the optimizer
87
+ optimizer = optim.Adam(model.parameters())
88
+
89
+ # Load the weights
90
+ if force_restart or not weights_fpath.exists():
91
+ print("\nStarting the training of Tacotron from scratch\n")
92
+ model.save(weights_fpath)
93
+
94
+ # Embeddings metadata
95
+ char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
96
+ with open(char_embedding_fpath, "w", encoding="utf-8") as f:
97
+ for symbol in symbols:
98
+ if symbol == " ":
99
+ symbol = "\\s" # For visual purposes, swap space with \s
100
+
101
+ f.write("{}\n".format(symbol))
102
+
103
+ else:
104
+ print("\nLoading weights at %s" % weights_fpath)
105
+ model.load(weights_fpath, optimizer)
106
+ print("Tacotron weights loaded from step %d" % model.step)
107
+
108
+ # Initialize the dataset
109
+ metadata_fpath = syn_dir.joinpath("train.txt")
110
+ mel_dir = syn_dir.joinpath("mels")
111
+ embed_dir = syn_dir.joinpath("embeds")
112
+ dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
113
+ test_loader = DataLoader(dataset,
114
+ batch_size=1,
115
+ shuffle=True,
116
+ pin_memory=True)
117
+
118
+ for i, session in enumerate(hparams.tts_schedule):
119
+ current_step = model.get_step()
120
+
121
+ r, lr, max_step, batch_size = session
122
+
123
+ training_steps = max_step - current_step
124
+
125
+ # Do we need to change to the next session?
126
+ if current_step >= max_step:
127
+ # Are there no further sessions than the current one?
128
+ if i == len(hparams.tts_schedule) - 1:
129
+ # We have completed training. Save the model and exit
130
+ model.save(weights_fpath, optimizer)
131
+ break
132
+ else:
133
+ # There is a following session, go to it
134
+ continue
135
+
136
+ model.r = r
137
+
138
+ # Begin the training
139
+ simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
140
+ ("Batch Size", batch_size),
141
+ ("Learning Rate", lr),
142
+ ("Outputs/Step (r)", model.r)])
143
+
144
+ for p in optimizer.param_groups:
145
+ p["lr"] = lr
146
+
147
+ data_loader = DataLoader(dataset,
148
+ collate_fn=lambda batch: collate_synthesizer(batch, r, hparams),
149
+ batch_size=batch_size,
150
+ num_workers=2 if platform.system() != "Windows" else 0,
151
+ shuffle=True,
152
+ pin_memory=True)
153
+
154
+ total_iters = len(dataset)
155
+ steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
156
+ epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)
157
+
158
+ for epoch in range(1, epochs+1):
159
+ for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
160
+ start_time = time.time()
161
+
162
+ # Generate stop tokens for training
163
+ stop = torch.ones(mels.shape[0], mels.shape[2])
164
+ for j, k in enumerate(idx):
165
+ stop[j, :int(dataset.metadata[k][4])-1] = 0
166
+
167
+ texts = texts.to(device)
168
+ mels = mels.to(device)
169
+ embeds = embeds.to(device)
170
+ stop = stop.to(device)
171
+
172
+ # Forward pass
173
+ # Parallelize model onto GPUS using workaround due to python bug
174
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
175
+ m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts,
176
+ mels, embeds)
177
+ else:
178
+ m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)
179
+
180
+ # Backward pass
181
+ m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
182
+ m2_loss = F.mse_loss(m2_hat, mels)
183
+ stop_loss = F.binary_cross_entropy(stop_pred, stop)
184
+
185
+ loss = m1_loss + m2_loss + stop_loss
186
+
187
+ optimizer.zero_grad()
188
+ loss.backward()
189
+
190
+ if hparams.tts_clip_grad_norm is not None:
191
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
192
+ if np.isnan(grad_norm.cpu()):
193
+ print("grad_norm was NaN!")
194
+
195
+ optimizer.step()
196
+
197
+ time_window.append(time.time() - start_time)
198
+ loss_window.append(loss.item())
199
+
200
+ step = model.get_step()
201
+ k = step // 1000
202
+
203
+ msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
204
+ stream(msg)
205
+
206
+ # Backup or save model as appropriate
207
+ if backup_every != 0 and step % backup_every == 0 :
208
+ backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
209
+ model.save(backup_fpath, optimizer)
210
+
211
+ if save_every != 0 and step % save_every == 0 :
212
+ # Must save latest optimizer state to ensure that resuming training
213
+ # doesn't produce artifacts
214
+ model.save(weights_fpath, optimizer)
215
+
216
+ # Evaluate model to generate samples
217
+ epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
218
+ step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0 # Every N steps
219
+ if epoch_eval or step_eval:
220
+ for sample_idx in range(hparams.tts_eval_num_samples):
221
+ # At most, generate samples equal to number in the batch
222
+ if sample_idx + 1 <= len(texts):
223
+ # Remove padding from mels using frame length in metadata
224
+ mel_length = int(dataset.metadata[idx[sample_idx]][4])
225
+ mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
226
+ target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
227
+ attention_len = mel_length // model.r
228
+
229
+ eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
230
+ mel_prediction=mel_prediction,
231
+ target_spectrogram=target_spectrogram,
232
+ input_seq=np_now(texts[sample_idx]),
233
+ step=step,
234
+ plot_dir=plot_dir,
235
+ mel_output_dir=mel_output_dir,
236
+ wav_dir=wav_dir,
237
+ sample_num=sample_idx + 1,
238
+ loss=loss,
239
+ hparams=hparams)
240
+
241
+ # Break out of loop to update training schedule
242
+ if step >= max_step:
243
+ break
244
+
245
+ # Add line break after every epoch
246
+ print("")
247
+
248
+ def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
249
+ plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
250
+ # Save some results for evaluation
251
+ attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
252
+ save_attention(attention, attention_path)
253
+
254
+ # save predicted mel spectrogram to disk (debug)
255
+ mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
256
+ np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False)
257
+
258
+ # save griffin lim inverted wav for debug (mel -> wav)
259
+ wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
260
+ wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num))
261
+ audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate)
262
+
263
+ # save real and predicted mel-spectrogram plot to disk (control purposes)
264
+ spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
265
+ title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
266
+ plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
267
+ target_spectrogram=target_spectrogram,
268
+ max_len=target_spectrogram.size // hparams.num_mels)
269
+ print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
synthesizer/utils/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ _output_ref = None
5
+ _replicas_ref = None
6
+
7
+ def data_parallel_workaround(model, *input):
8
+ global _output_ref
9
+ global _replicas_ref
10
+ device_ids = list(range(torch.cuda.device_count()))
11
+ output_device = device_ids[0]
12
+ replicas = torch.nn.parallel.replicate(model, device_ids)
13
+ # input.shape = (num_args, batch, ...)
14
+ inputs = torch.nn.parallel.scatter(input, device_ids)
15
+ # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
16
+ replicas = replicas[:len(inputs)]
17
+ outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
18
+ y_hat = torch.nn.parallel.gather(outputs, output_device)
19
+ _output_ref = outputs
20
+ _replicas_ref = replicas
21
+ return y_hat
22
+
23
+
24
+ class ValueWindow():
25
+ def __init__(self, window_size=100):
26
+ self._window_size = window_size
27
+ self._values = []
28
+
29
+ def append(self, x):
30
+ self._values = self._values[-(self._window_size - 1):] + [x]
31
+
32
+ @property
33
+ def sum(self):
34
+ return sum(self._values)
35
+
36
+ @property
37
+ def count(self):
38
+ return len(self._values)
39
+
40
+ @property
41
+ def average(self):
42
+ return self.sum / max(1, self.count)
43
+
44
+ def reset(self):
45
+ self._values = []
synthesizer/utils/_cmudict.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ valid_symbols = [
4
+ "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
5
+ "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
6
+ "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY",
7
+ "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1",
8
+ "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0",
9
+ "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW",
10
+ "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH"
11
+ ]
12
+
13
+ _valid_symbol_set = set(valid_symbols)
14
+
15
+
16
+ class CMUDict:
17
+ """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
18
+ def __init__(self, file_or_path, keep_ambiguous=True):
19
+ if isinstance(file_or_path, str):
20
+ with open(file_or_path, encoding="latin-1") as f:
21
+ entries = _parse_cmudict(f)
22
+ else:
23
+ entries = _parse_cmudict(file_or_path)
24
+ if not keep_ambiguous:
25
+ entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
26
+ self._entries = entries
27
+
28
+
29
+ def __len__(self):
30
+ return len(self._entries)
31
+
32
+
33
+ def lookup(self, word):
34
+ """Returns list of ARPAbet pronunciations of the given word."""
35
+ return self._entries.get(word.upper())
36
+
37
+
38
+
39
+ _alt_re = re.compile(r"\([0-9]+\)")
40
+
41
+
42
+ def _parse_cmudict(file):
43
+ cmudict = {}
44
+ for line in file:
45
+ if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
46
+ parts = line.split(" ")
47
+ word = re.sub(_alt_re, "", parts[0])
48
+ pronunciation = _get_pronunciation(parts[1])
49
+ if pronunciation:
50
+ if word in cmudict:
51
+ cmudict[word].append(pronunciation)
52
+ else:
53
+ cmudict[word] = [pronunciation]
54
+ return cmudict
55
+
56
+
57
+ def _get_pronunciation(s):
58
+ parts = s.strip().split(" ")
59
+ for part in parts:
60
+ if part not in _valid_symbol_set:
61
+ return None
62
+ return " ".join(parts)
synthesizer/utils/cleaners.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Cleaners are transformations that run over the input text at both training and eval time.
3
+
4
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
5
+ hyperparameter. Some cleaners are English-specific. You"ll typically want to use:
6
+ 1. "english_cleaners" for English text
7
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
8
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
9
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
10
+ the symbols in symbols.py to match your data).
11
+ """
12
+
13
+ import re
14
+ from unidecode import unidecode
15
+ from .numbers import normalize_numbers
16
+
17
+ # Regular expression matching whitespace:
18
+ _whitespace_re = re.compile(r"\s+")
19
+
20
+ # List of (regular expression, replacement) pairs for abbreviations:
21
+ _abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
22
+ ("mrs", "misess"),
23
+ ("mr", "mister"),
24
+ ("dr", "doctor"),
25
+ ("st", "saint"),
26
+ ("co", "company"),
27
+ ("jr", "junior"),
28
+ ("maj", "major"),
29
+ ("gen", "general"),
30
+ ("drs", "doctors"),
31
+ ("rev", "reverend"),
32
+ ("lt", "lieutenant"),
33
+ ("hon", "honorable"),
34
+ ("sgt", "sergeant"),
35
+ ("capt", "captain"),
36
+ ("esq", "esquire"),
37
+ ("ltd", "limited"),
38
+ ("col", "colonel"),
39
+ ("ft", "fort"),
40
+ ]]
41
+
42
+
43
+ def expand_abbreviations(text):
44
+ for regex, replacement in _abbreviations:
45
+ text = re.sub(regex, replacement, text)
46
+ return text
47
+
48
+
49
+ def expand_numbers(text):
50
+ return normalize_numbers(text)
51
+
52
+
53
+ def lowercase(text):
54
+ """lowercase input tokens."""
55
+ return text.lower()
56
+
57
+
58
+ def collapse_whitespace(text):
59
+ return re.sub(_whitespace_re, " ", text)
60
+
61
+
62
+ def convert_to_ascii(text):
63
+ return unidecode(text)
64
+
65
+
66
+ def basic_cleaners(text):
67
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
68
+ text = lowercase(text)
69
+ text = collapse_whitespace(text)
70
+ return text
71
+
72
+
73
+ def transliteration_cleaners(text):
74
+ """Pipeline for non-English text that transliterates to ASCII."""
75
+ text = convert_to_ascii(text)
76
+ text = lowercase(text)
77
+ text = collapse_whitespace(text)
78
+ return text
79
+
80
+
81
+ def english_cleaners(text):
82
+ """Pipeline for English text, including number and abbreviation expansion."""
83
+ text = convert_to_ascii(text)
84
+ text = lowercase(text)
85
+ text = expand_numbers(text)
86
+ text = expand_abbreviations(text)
87
+ text = collapse_whitespace(text)
88
+ return text
synthesizer/utils/numbers.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import inflect
3
+
4
+ _inflect = inflect.engine()
5
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
6
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
7
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
8
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
9
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
10
+ _number_re = re.compile(r"[0-9]+")
11
+
12
+
13
+ def _remove_commas(m):
14
+ return m.group(1).replace(",", "")
15
+
16
+
17
+ def _expand_decimal_point(m):
18
+ return m.group(1).replace(".", " point ")
19
+
20
+
21
+ def _expand_dollars(m):
22
+ match = m.group(1)
23
+ parts = match.split(".")
24
+ if len(parts) > 2:
25
+ return match + " dollars" # Unexpected format
26
+ dollars = int(parts[0]) if parts[0] else 0
27
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
28
+ if dollars and cents:
29
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
30
+ cent_unit = "cent" if cents == 1 else "cents"
31
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
32
+ elif dollars:
33
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
34
+ return "%s %s" % (dollars, dollar_unit)
35
+ elif cents:
36
+ cent_unit = "cent" if cents == 1 else "cents"
37
+ return "%s %s" % (cents, cent_unit)
38
+ else:
39
+ return "zero dollars"
40
+
41
+
42
+ def _expand_ordinal(m):
43
+ return _inflect.number_to_words(m.group(0))
44
+
45
+
46
+ def _expand_number(m):
47
+ num = int(m.group(0))
48
+ if num > 1000 and num < 3000:
49
+ if num == 2000:
50
+ return "two thousand"
51
+ elif num > 2000 and num < 2010:
52
+ return "two thousand " + _inflect.number_to_words(num % 100)
53
+ elif num % 100 == 0:
54
+ return _inflect.number_to_words(num // 100) + " hundred"
55
+ else:
56
+ return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
57
+ else:
58
+ return _inflect.number_to_words(num, andword="")
59
+
60
+
61
+ def normalize_numbers(text):
62
+ text = re.sub(_comma_number_re, _remove_commas, text)
63
+ text = re.sub(_pounds_re, r"\1 pounds", text)
64
+ text = re.sub(_dollars_re, _expand_dollars, text)
65
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
66
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
67
+ text = re.sub(_number_re, _expand_number, text)
68
+ return text
synthesizer/utils/plot.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ matplotlib.use("Agg")
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+
6
+
7
+ def split_title_line(title_text, max_words=5):
8
+ """
9
+ A function that splits any string based on specific character
10
+ (returning it with the string), with maximum number of words on it
11
+ """
12
+ seq = title_text.split()
13
+ return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)])
14
+
15
+ def plot_alignment(alignment, path, title=None, split_title=False, max_len=None):
16
+ if max_len is not None:
17
+ alignment = alignment[:, :max_len]
18
+
19
+ fig = plt.figure(figsize=(8, 6))
20
+ ax = fig.add_subplot(111)
21
+
22
+ im = ax.imshow(
23
+ alignment,
24
+ aspect="auto",
25
+ origin="lower",
26
+ interpolation="none")
27
+ fig.colorbar(im, ax=ax)
28
+ xlabel = "Decoder timestep"
29
+
30
+ if split_title:
31
+ title = split_title_line(title)
32
+
33
+ plt.xlabel(xlabel)
34
+ plt.title(title)
35
+ plt.ylabel("Encoder timestep")
36
+ plt.tight_layout()
37
+ plt.savefig(path, format="png")
38
+ plt.close()
39
+
40
+
41
+ def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
42
+ if max_len is not None:
43
+ target_spectrogram = target_spectrogram[:max_len]
44
+ pred_spectrogram = pred_spectrogram[:max_len]
45
+
46
+ if split_title:
47
+ title = split_title_line(title)
48
+
49
+ fig = plt.figure(figsize=(10, 8))
50
+ # Set common labels
51
+ fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)
52
+
53
+ #target spectrogram subplot
54
+ if target_spectrogram is not None:
55
+ ax1 = fig.add_subplot(311)
56
+ ax2 = fig.add_subplot(312)
57
+
58
+ if auto_aspect:
59
+ im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
60
+ else:
61
+ im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
62
+ ax1.set_title("Target Mel-Spectrogram")
63
+ fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
64
+ ax2.set_title("Predicted Mel-Spectrogram")
65
+ else:
66
+ ax2 = fig.add_subplot(211)
67
+
68
+ if auto_aspect:
69
+ im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
70
+ else:
71
+ im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
72
+ fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
73
+
74
+ plt.tight_layout()
75
+ plt.savefig(path, format="png")
76
+ plt.close()
synthesizer/utils/symbols.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Defines the set of symbols used in text input to the model.
3
+
4
+ The default is a set of ASCII characters that works well for English or text that has been run
5
+ through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
6
+ """
7
+ # from . import cmudict
8
+
9
+ _pad = "_"
10
+ _eos = "~"
11
+ _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? "
12
+
13
+ # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
14
+ #_arpabet = ["@' + s for s in cmudict.valid_symbols]
15
+
16
+ # Export all symbols:
17
+ symbols = [_pad, _eos] + list(_characters) #+ _arpabet
synthesizer/utils/text.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .symbols import symbols
2
+ from . import cleaners
3
+ import re
4
+
5
+ # Mappings from symbol to numeric ID and vice versa:
6
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
7
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
8
+
9
+ # Regular expression matching text enclosed in curly braces:
10
+ _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
11
+
12
+
13
+ def text_to_sequence(text, cleaner_names):
14
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
15
+
16
+ The text can optionally have ARPAbet sequences enclosed in curly braces embedded
17
+ in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
18
+
19
+ Args:
20
+ text: string to convert to a sequence
21
+ cleaner_names: names of the cleaner functions to run the text through
22
+
23
+ Returns:
24
+ List of integers corresponding to the symbols in the text
25
+ """
26
+ sequence = []
27
+
28
+ # Check for curly braces and treat their contents as ARPAbet:
29
+ while len(text):
30
+ m = _curly_re.match(text)
31
+ if not m:
32
+ sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
33
+ break
34
+ sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
35
+ sequence += _arpabet_to_sequence(m.group(2))
36
+ text = m.group(3)
37
+
38
+ # Append EOS token
39
+ sequence.append(_symbol_to_id["~"])
40
+ return sequence
41
+
42
+
43
+ def sequence_to_text(sequence):
44
+ """Converts a sequence of IDs back to a string"""
45
+ result = ""
46
+ for symbol_id in sequence:
47
+ if symbol_id in _id_to_symbol:
48
+ s = _id_to_symbol[symbol_id]
49
+ # Enclose ARPAbet back in curly braces:
50
+ if len(s) > 1 and s[0] == "@":
51
+ s = "{%s}" % s[1:]
52
+ result += s
53
+ return result.replace("}{", " ")
54
+
55
+
56
+ def _clean_text(text, cleaner_names):
57
+ for name in cleaner_names:
58
+ cleaner = getattr(cleaners, name)
59
+ if not cleaner:
60
+ raise Exception("Unknown cleaner: %s" % name)
61
+ text = cleaner(text)
62
+ return text
63
+
64
+
65
+ def _symbols_to_sequence(symbols):
66
+ return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
67
+
68
+
69
+ def _arpabet_to_sequence(text):
70
+ return _symbols_to_sequence(["@" + s for s in text.split()])
71
+
72
+
73
+ def _should_keep_symbol(s):
74
+ return s in _symbol_to_id and s not in ("_", "~")
synthesizer_preprocess_audio.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from synthesizer.preprocess import preprocess_dataset
2
+ from synthesizer.hparams import hparams
3
+ from utils.argutils import print_args
4
+ from pathlib import Path
5
+ import argparse
6
+
7
+
8
+ if __name__ == "__main__":
9
+ parser = argparse.ArgumentParser(
10
+ description="Preprocesses audio files from datasets, encodes them as mel spectrograms "
11
+ "and writes them to the disk. Audio files are also saved, to be used by the "
12
+ "vocoder for training.",
13
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
14
+ )
15
+ parser.add_argument("datasets_root", type=Path, help=\
16
+ "Path to the directory containing your LibriSpeech/TTS datasets.")
17
+ parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
18
+ "Path to the output directory that will contain the mel spectrograms, the audios and the "
19
+ "embeds. Defaults to <datasets_root>/SV2TTS/synthesizer/")
20
+ parser.add_argument("-n", "--n_processes", type=int, default=None, help=\
21
+ "Number of processes in parallel.")
22
+ parser.add_argument("-s", "--skip_existing", action="store_true", help=\
23
+ "Whether to overwrite existing files with the same name. Useful if the preprocessing was "
24
+ "interrupted.")
25
+ parser.add_argument("--hparams", type=str, default="", help=\
26
+ "Hyperparameter overrides as a comma-separated list of name-value pairs")
27
+ parser.add_argument("--no_trim", action="store_true", help=\
28
+ "Preprocess audio without trimming silences (not recommended).")
29
+ parser.add_argument("--no_alignments", action="store_true", help=\
30
+ "Use this option when dataset does not include alignments\
31
+ (these are used to split long audio files into sub-utterances.)")
32
+ parser.add_argument("--datasets_name", type=str, default="LibriSpeech", help=\
33
+ "Name of the dataset directory to process.")
34
+ parser.add_argument("--subfolders", type=str, default="train-clean-100, train-clean-360", help=\
35
+ "Comma-separated list of subfolders to process inside your dataset directory")
36
+ args = parser.parse_args()
37
+
38
+ # Process the arguments
39
+ if not hasattr(args, "out_dir"):
40
+ args.out_dir = args.datasets_root.joinpath("SV2TTS", "synthesizer")
41
+
42
+ # Create directories
43
+ assert args.datasets_root.exists()
44
+ args.out_dir.mkdir(exist_ok=True, parents=True)
45
+
46
+ # Verify webrtcvad is available
47
+ if not args.no_trim:
48
+ try:
49
+ import webrtcvad
50
+ except:
51
+ raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
52
+ "noise removal and is recommended. Please install and try again. If installation fails, "
53
+ "use --no_trim to disable this error message.")
54
+ del args.no_trim
55
+
56
+ # Preprocess the dataset
57
+ print_args(args, parser)
58
+ args.hparams = hparams.parse(args.hparams)
59
+ preprocess_dataset(**vars(args))
synthesizer_preprocess_embeds.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from synthesizer.preprocess import create_embeddings
2
+ from utils.argutils import print_args
3
+ from pathlib import Path
4
+ import argparse
5
+
6
+
7
+ if __name__ == "__main__":
8
+ parser = argparse.ArgumentParser(
9
+ description="Creates embeddings for the synthesizer from the LibriSpeech utterances.",
10
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
11
+ )
12
+ parser.add_argument("synthesizer_root", type=Path, help=\
13
+ "Path to the synthesizer training data that contains the audios and the train.txt file. "
14
+ "If you let everything as default, it should be <datasets_root>/SV2TTS/synthesizer/.")
15
+ parser.add_argument("-e", "--encoder_model_fpath", type=Path,
16
+ default="encoder/saved_models/pretrained.pt", help=\
17
+ "Path your trained encoder model.")
18
+ parser.add_argument("-n", "--n_processes", type=int, default=4, help= \
19
+ "Number of parallel processes. An encoder is created for each, so you may need to lower "
20
+ "this value on GPUs with low memory. Set it to 1 if CUDA is unhappy.")
21
+ args = parser.parse_args()
22
+
23
+ # Preprocess the dataset
24
+ print_args(args, parser)
25
+ create_embeddings(**vars(args))
synthesizer_train.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from synthesizer.hparams import hparams
2
+ from synthesizer.train import train
3
+ from utils.argutils import print_args
4
+ import argparse
5
+
6
+
7
+ if __name__ == "__main__":
8
+ parser = argparse.ArgumentParser()
9
+ parser.add_argument("run_id", type=str, help= \
10
+ "Name for this model instance. If a model state from the same run ID was previously "
11
+ "saved, the training will restart from there. Pass -f to overwrite saved states and "
12
+ "restart from scratch.")
13
+ parser.add_argument("syn_dir", type=str, default=argparse.SUPPRESS, help= \
14
+ "Path to the synthesizer directory that contains the ground truth mel spectrograms, "
15
+ "the wavs and the embeds.")
16
+ parser.add_argument("-m", "--models_dir", type=str, default="synthesizer/saved_models/", help=\
17
+ "Path to the output directory that will contain the saved model weights and the logs.")
18
+ parser.add_argument("-s", "--save_every", type=int, default=1000, help= \
19
+ "Number of steps between updates of the model on the disk. Set to 0 to never save the "
20
+ "model.")
21
+ parser.add_argument("-b", "--backup_every", type=int, default=25000, help= \
22
+ "Number of steps between backups of the model. Set to 0 to never make backups of the "
23
+ "model.")
24
+ parser.add_argument("-f", "--force_restart", action="store_true", help= \
25
+ "Do not load any saved model and restart from scratch.")
26
+ parser.add_argument("--hparams", default="",
27
+ help="Hyperparameter overrides as a comma-separated list of name=value "
28
+ "pairs")
29
+ args = parser.parse_args()
30
+ print_args(args, parser)
31
+
32
+ args.hparams = hparams.parse(args.hparams)
33
+
34
+ # Run the training
35
+ train(**vars(args))
toolbox/__init__.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from toolbox.ui import UI
2
+ from encoder import inference as encoder
3
+ from synthesizer.inference import Synthesizer
4
+ from vocoder import inference as vocoder
5
+ from pathlib import Path
6
+ from time import perf_counter as timer
7
+ from toolbox.utterance import Utterance
8
+ import numpy as np
9
+ import traceback
10
+ import sys
11
+ import torch
12
+ import librosa
13
+ from audioread.exceptions import NoBackendError
14
+
15
+ # Use this directory structure for your datasets, or modify it to fit your needs
16
+ recognized_datasets = [
17
+ "LibriSpeech/dev-clean",
18
+ "LibriSpeech/dev-other",
19
+ "LibriSpeech/test-clean",
20
+ "LibriSpeech/test-other",
21
+ "LibriSpeech/train-clean-100",
22
+ "LibriSpeech/train-clean-360",
23
+ "LibriSpeech/train-other-500",
24
+ "LibriTTS/dev-clean",
25
+ "LibriTTS/dev-other",
26
+ "LibriTTS/test-clean",
27
+ "LibriTTS/test-other",
28
+ "LibriTTS/train-clean-100",
29
+ "LibriTTS/train-clean-360",
30
+ "LibriTTS/train-other-500",
31
+ "LJSpeech-1.1",
32
+ "VoxCeleb1/wav",
33
+ "VoxCeleb1/test_wav",
34
+ "VoxCeleb2/dev/aac",
35
+ "VoxCeleb2/test/aac",
36
+ "VCTK-Corpus/wav48",
37
+ ]
38
+
39
+ #Maximum of generated wavs to keep on memory
40
+ MAX_WAVES = 15
41
+
42
+ class Toolbox:
43
+ def __init__(self, datasets_root, enc_models_dir, syn_models_dir, voc_models_dir, seed, no_mp3_support):
44
+ if not no_mp3_support:
45
+ try:
46
+ librosa.load("samples/6829_00000.mp3")
47
+ except NoBackendError:
48
+ print("Librosa will be unable to open mp3 files if additional software is not installed.\n"
49
+ "Please install ffmpeg or add the '--no_mp3_support' option to proceed without support for mp3 files.")
50
+ exit(-1)
51
+ self.no_mp3_support = no_mp3_support
52
+ sys.excepthook = self.excepthook
53
+ self.datasets_root = datasets_root
54
+ self.utterances = set()
55
+ self.current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
56
+
57
+ self.synthesizer = None # type: Synthesizer
58
+ self.current_wav = None
59
+ self.waves_list = []
60
+ self.waves_count = 0
61
+ self.waves_namelist = []
62
+
63
+ # Check for webrtcvad (enables removal of silences in vocoder output)
64
+ try:
65
+ import webrtcvad
66
+ self.trim_silences = True
67
+ except:
68
+ self.trim_silences = False
69
+
70
+ # Initialize the events and the interface
71
+ self.ui = UI()
72
+ self.reset_ui(enc_models_dir, syn_models_dir, voc_models_dir, seed)
73
+ self.setup_events()
74
+ self.ui.start()
75
+
76
+ def excepthook(self, exc_type, exc_value, exc_tb):
77
+ traceback.print_exception(exc_type, exc_value, exc_tb)
78
+ self.ui.log("Exception: %s" % exc_value)
79
+
80
+ def setup_events(self):
81
+ # Dataset, speaker and utterance selection
82
+ self.ui.browser_load_button.clicked.connect(lambda: self.load_from_browser())
83
+ random_func = lambda level: lambda: self.ui.populate_browser(self.datasets_root,
84
+ recognized_datasets,
85
+ level)
86
+ self.ui.random_dataset_button.clicked.connect(random_func(0))
87
+ self.ui.random_speaker_button.clicked.connect(random_func(1))
88
+ self.ui.random_utterance_button.clicked.connect(random_func(2))
89
+ self.ui.dataset_box.currentIndexChanged.connect(random_func(1))
90
+ self.ui.speaker_box.currentIndexChanged.connect(random_func(2))
91
+
92
+ # Model selection
93
+ self.ui.encoder_box.currentIndexChanged.connect(self.init_encoder)
94
+ def func():
95
+ self.synthesizer = None
96
+ self.ui.synthesizer_box.currentIndexChanged.connect(func)
97
+ self.ui.vocoder_box.currentIndexChanged.connect(self.init_vocoder)
98
+
99
+ # Utterance selection
100
+ func = lambda: self.load_from_browser(self.ui.browse_file())
101
+ self.ui.browser_browse_button.clicked.connect(func)
102
+ func = lambda: self.ui.draw_utterance(self.ui.selected_utterance, "current")
103
+ self.ui.utterance_history.currentIndexChanged.connect(func)
104
+ func = lambda: self.ui.play(self.ui.selected_utterance.wav, Synthesizer.sample_rate)
105
+ self.ui.play_button.clicked.connect(func)
106
+ self.ui.stop_button.clicked.connect(self.ui.stop)
107
+ self.ui.record_button.clicked.connect(self.record)
108
+
109
+ #Audio
110
+ self.ui.setup_audio_devices(Synthesizer.sample_rate)
111
+
112
+ #Wav playback & save
113
+ func = lambda: self.replay_last_wav()
114
+ self.ui.replay_wav_button.clicked.connect(func)
115
+ func = lambda: self.export_current_wave()
116
+ self.ui.export_wav_button.clicked.connect(func)
117
+ self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
118
+
119
+ # Generation
120
+ func = lambda: self.synthesize() or self.vocode()
121
+ self.ui.generate_button.clicked.connect(func)
122
+ self.ui.synthesize_button.clicked.connect(self.synthesize)
123
+ self.ui.vocode_button.clicked.connect(self.vocode)
124
+ self.ui.random_seed_checkbox.clicked.connect(self.update_seed_textbox)
125
+
126
+ # UMAP legend
127
+ self.ui.clear_button.clicked.connect(self.clear_utterances)
128
+
129
+ def set_current_wav(self, index):
130
+ self.current_wav = self.waves_list[index]
131
+
132
+ def export_current_wave(self):
133
+ self.ui.save_audio_file(self.current_wav, Synthesizer.sample_rate)
134
+
135
+ def replay_last_wav(self):
136
+ self.ui.play(self.current_wav, Synthesizer.sample_rate)
137
+
138
+ def reset_ui(self, encoder_models_dir, synthesizer_models_dir, vocoder_models_dir, seed):
139
+ self.ui.populate_browser(self.datasets_root, recognized_datasets, 0, True)
140
+ self.ui.populate_models(encoder_models_dir, synthesizer_models_dir, vocoder_models_dir)
141
+ self.ui.populate_gen_options(seed, self.trim_silences)
142
+
143
+ def load_from_browser(self, fpath=None):
144
+ if fpath is None:
145
+ fpath = Path(self.datasets_root,
146
+ self.ui.current_dataset_name,
147
+ self.ui.current_speaker_name,
148
+ self.ui.current_utterance_name)
149
+ name = str(fpath.relative_to(self.datasets_root))
150
+ speaker_name = self.ui.current_dataset_name + '_' + self.ui.current_speaker_name
151
+
152
+ # Select the next utterance
153
+ if self.ui.auto_next_checkbox.isChecked():
154
+ self.ui.browser_select_next()
155
+ elif fpath == "":
156
+ return
157
+ else:
158
+ name = fpath.name
159
+ speaker_name = fpath.parent.name
160
+
161
+ if fpath.suffix.lower() == ".mp3" and self.no_mp3_support:
162
+ self.ui.log("Error: No mp3 file argument was passed but an mp3 file was used")
163
+ return
164
+
165
+ # Get the wav from the disk. We take the wav with the vocoder/synthesizer format for
166
+ # playback, so as to have a fair comparison with the generated audio
167
+ wav = Synthesizer.load_preprocess_wav(fpath)
168
+ self.ui.log("Loaded %s" % name)
169
+
170
+ self.add_real_utterance(wav, name, speaker_name)
171
+
172
+ def record(self):
173
+ wav = self.ui.record_one(encoder.sampling_rate, 5)
174
+ if wav is None:
175
+ return
176
+ self.ui.play(wav, encoder.sampling_rate)
177
+
178
+ speaker_name = "user01"
179
+ name = speaker_name + "_rec_%05d" % np.random.randint(100000)
180
+ self.add_real_utterance(wav, name, speaker_name)
181
+
182
+ def add_real_utterance(self, wav, name, speaker_name):
183
+ # Compute the mel spectrogram
184
+ spec = Synthesizer.make_spectrogram(wav)
185
+ self.ui.draw_spec(spec, "current")
186
+
187
+ # Compute the embedding
188
+ if not encoder.is_loaded():
189
+ self.init_encoder()
190
+ encoder_wav = encoder.preprocess_wav(wav)
191
+ embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
192
+
193
+ # Add the utterance
194
+ utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, False)
195
+ self.utterances.add(utterance)
196
+ self.ui.register_utterance(utterance)
197
+
198
+ # Plot it
199
+ self.ui.draw_embed(embed, name, "current")
200
+ self.ui.draw_umap_projections(self.utterances)
201
+
202
+ def clear_utterances(self):
203
+ self.utterances.clear()
204
+ self.ui.draw_umap_projections(self.utterances)
205
+
206
+ def synthesize(self):
207
+ self.ui.log("Generating the mel spectrogram...")
208
+ self.ui.set_loading(1)
209
+
210
+ # Update the synthesizer random seed
211
+ if self.ui.random_seed_checkbox.isChecked():
212
+ seed = int(self.ui.seed_textbox.text())
213
+ self.ui.populate_gen_options(seed, self.trim_silences)
214
+ else:
215
+ seed = None
216
+
217
+ if seed is not None:
218
+ torch.manual_seed(seed)
219
+
220
+ # Synthesize the spectrogram
221
+ if self.synthesizer is None or seed is not None:
222
+ self.init_synthesizer()
223
+
224
+ texts = self.ui.text_prompt.toPlainText().split("\n")
225
+ embed = self.ui.selected_utterance.embed
226
+ embeds = [embed] * len(texts)
227
+ specs = self.synthesizer.synthesize_spectrograms(texts, embeds)
228
+ breaks = [spec.shape[1] for spec in specs]
229
+ spec = np.concatenate(specs, axis=1)
230
+
231
+ self.ui.draw_spec(spec, "generated")
232
+ self.current_generated = (self.ui.selected_utterance.speaker_name, spec, breaks, None)
233
+ self.ui.set_loading(0)
234
+
235
+ def vocode(self):
236
+ speaker_name, spec, breaks, _ = self.current_generated
237
+ assert spec is not None
238
+
239
+ # Initialize the vocoder model and make it determinstic, if user provides a seed
240
+ if self.ui.random_seed_checkbox.isChecked():
241
+ seed = int(self.ui.seed_textbox.text())
242
+ self.ui.populate_gen_options(seed, self.trim_silences)
243
+ else:
244
+ seed = None
245
+
246
+ if seed is not None:
247
+ torch.manual_seed(seed)
248
+
249
+ # Synthesize the waveform
250
+ if not vocoder.is_loaded() or seed is not None:
251
+ self.init_vocoder()
252
+
253
+ def vocoder_progress(i, seq_len, b_size, gen_rate):
254
+ real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
255
+ line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \
256
+ % (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor)
257
+ self.ui.log(line, "overwrite")
258
+ self.ui.set_loading(i, seq_len)
259
+ if self.ui.current_vocoder_fpath is not None:
260
+ self.ui.log("")
261
+ wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
262
+ else:
263
+ self.ui.log("Waveform generation with Griffin-Lim... ")
264
+ wav = Synthesizer.griffin_lim(spec)
265
+ self.ui.set_loading(0)
266
+ self.ui.log(" Done!", "append")
267
+
268
+ # Add breaks
269
+ b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
270
+ b_starts = np.concatenate(([0], b_ends[:-1]))
271
+ wavs = [wav[start:end] for start, end, in zip(b_starts, b_ends)]
272
+ breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
273
+ wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
274
+
275
+ # Trim excessive silences
276
+ if self.ui.trim_silences_checkbox.isChecked():
277
+ wav = encoder.preprocess_wav(wav)
278
+
279
+ # Play it
280
+ wav = wav / np.abs(wav).max() * 0.97
281
+ self.ui.play(wav, Synthesizer.sample_rate)
282
+
283
+ # Name it (history displayed in combobox)
284
+ # TODO better naming for the combobox items?
285
+ wav_name = str(self.waves_count + 1)
286
+
287
+ #Update waves combobox
288
+ self.waves_count += 1
289
+ if self.waves_count > MAX_WAVES:
290
+ self.waves_list.pop()
291
+ self.waves_namelist.pop()
292
+ self.waves_list.insert(0, wav)
293
+ self.waves_namelist.insert(0, wav_name)
294
+
295
+ self.ui.waves_cb.disconnect()
296
+ self.ui.waves_cb_model.setStringList(self.waves_namelist)
297
+ self.ui.waves_cb.setCurrentIndex(0)
298
+ self.ui.waves_cb.currentIndexChanged.connect(self.set_current_wav)
299
+
300
+ # Update current wav
301
+ self.set_current_wav(0)
302
+
303
+ #Enable replay and save buttons:
304
+ self.ui.replay_wav_button.setDisabled(False)
305
+ self.ui.export_wav_button.setDisabled(False)
306
+
307
+ # Compute the embedding
308
+ # TODO: this is problematic with different sampling rates, gotta fix it
309
+ if not encoder.is_loaded():
310
+ self.init_encoder()
311
+ encoder_wav = encoder.preprocess_wav(wav)
312
+ embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
313
+
314
+ # Add the utterance
315
+ name = speaker_name + "_gen_%05d" % np.random.randint(100000)
316
+ utterance = Utterance(name, speaker_name, wav, spec, embed, partial_embeds, True)
317
+ self.utterances.add(utterance)
318
+
319
+ # Plot it
320
+ self.ui.draw_embed(embed, name, "generated")
321
+ self.ui.draw_umap_projections(self.utterances)
322
+
323
+ def init_encoder(self):
324
+ model_fpath = self.ui.current_encoder_fpath
325
+
326
+ self.ui.log("Loading the encoder %s... " % model_fpath)
327
+ self.ui.set_loading(1)
328
+ start = timer()
329
+ encoder.load_model(model_fpath)
330
+ self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
331
+ self.ui.set_loading(0)
332
+
333
+ def init_synthesizer(self):
334
+ model_fpath = self.ui.current_synthesizer_fpath
335
+
336
+ self.ui.log("Loading the synthesizer %s... " % model_fpath)
337
+ self.ui.set_loading(1)
338
+ start = timer()
339
+ self.synthesizer = Synthesizer(model_fpath)
340
+ self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
341
+ self.ui.set_loading(0)
342
+
343
+ def init_vocoder(self):
344
+ model_fpath = self.ui.current_vocoder_fpath
345
+ # Case of Griffin-lim
346
+ if model_fpath is None:
347
+ return
348
+
349
+ self.ui.log("Loading the vocoder %s... " % model_fpath)
350
+ self.ui.set_loading(1)
351
+ start = timer()
352
+ vocoder.load_model(model_fpath)
353
+ self.ui.log("Done (%dms)." % int(1000 * (timer() - start)), "append")
354
+ self.ui.set_loading(0)
355
+
356
+ def update_seed_textbox(self):
357
+ self.ui.update_seed_textbox()
toolbox/ui.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
3
+ from matplotlib.figure import Figure
4
+ from PyQt5.QtCore import Qt, QStringListModel
5
+ from PyQt5.QtWidgets import *
6
+ from encoder.inference import plot_embedding_as_heatmap
7
+ from toolbox.utterance import Utterance
8
+ from pathlib import Path
9
+ from typing import List, Set
10
+ import sounddevice as sd
11
+ import soundfile as sf
12
+ import numpy as np
13
+ # from sklearn.manifold import TSNE # You can try with TSNE if you like, I prefer UMAP
14
+ from time import sleep
15
+ import umap
16
+ import sys
17
+ from warnings import filterwarnings, warn
18
+ filterwarnings("ignore")
19
+
20
+
21
+ colormap = np.array([
22
+ [0, 127, 70],
23
+ [255, 0, 0],
24
+ [255, 217, 38],
25
+ [0, 135, 255],
26
+ [165, 0, 165],
27
+ [255, 167, 255],
28
+ [97, 142, 151],
29
+ [0, 255, 255],
30
+ [255, 96, 38],
31
+ [142, 76, 0],
32
+ [33, 0, 127],
33
+ [0, 0, 0],
34
+ [183, 183, 183],
35
+ [76, 255, 0],
36
+ ], dtype=np.float) / 255
37
+
38
+ default_text = \
39
+ "Welcome to the toolbox! To begin, load an utterance from your datasets or record one " \
40
+ "yourself.\nOnce its embedding has been created, you can synthesize any text written here.\n" \
41
+ "The synthesizer expects to generate " \
42
+ "outputs that are somewhere between 5 and 12 seconds.\nTo mark breaks, write a new line. " \
43
+ "Each line will be treated separately.\nThen, they are joined together to make the final " \
44
+ "spectrogram. Use the vocoder to generate audio.\nThe vocoder generates almost in constant " \
45
+ "time, so it will be more time efficient for longer inputs like this one.\nOn the left you " \
46
+ "have the embedding projections. Load or record more utterances to see them.\nIf you have " \
47
+ "at least 2 or 3 utterances from a same speaker, a cluster should form.\nSynthesized " \
48
+ "utterances are of the same color as the speaker whose voice was used, but they're " \
49
+ "represented with a cross."
50
+
51
+
52
+ class UI(QDialog):
53
+ min_umap_points = 4
54
+ max_log_lines = 5
55
+ max_saved_utterances = 20
56
+
57
+ def draw_utterance(self, utterance: Utterance, which):
58
+ self.draw_spec(utterance.spec, which)
59
+ self.draw_embed(utterance.embed, utterance.name, which)
60
+
61
+ def draw_embed(self, embed, name, which):
62
+ embed_ax, _ = self.current_ax if which == "current" else self.gen_ax
63
+ embed_ax.figure.suptitle("" if embed is None else name)
64
+
65
+ ## Embedding
66
+ # Clear the plot
67
+ if len(embed_ax.images) > 0:
68
+ embed_ax.images[0].colorbar.remove()
69
+ embed_ax.clear()
70
+
71
+ # Draw the embed
72
+ if embed is not None:
73
+ plot_embedding_as_heatmap(embed, embed_ax)
74
+ embed_ax.set_title("embedding")
75
+ embed_ax.set_aspect("equal", "datalim")
76
+ embed_ax.set_xticks([])
77
+ embed_ax.set_yticks([])
78
+ embed_ax.figure.canvas.draw()
79
+
80
+ def draw_spec(self, spec, which):
81
+ _, spec_ax = self.current_ax if which == "current" else self.gen_ax
82
+
83
+ ## Spectrogram
84
+ # Draw the spectrogram
85
+ spec_ax.clear()
86
+ if spec is not None:
87
+ im = spec_ax.imshow(spec, aspect="auto", interpolation="none")
88
+ # spec_ax.figure.colorbar(mappable=im, shrink=0.65, orientation="horizontal",
89
+ # spec_ax=spec_ax)
90
+ spec_ax.set_title("mel spectrogram")
91
+
92
+ spec_ax.set_xticks([])
93
+ spec_ax.set_yticks([])
94
+ spec_ax.figure.canvas.draw()
95
+ if which != "current":
96
+ self.vocode_button.setDisabled(spec is None)
97
+
98
+ def draw_umap_projections(self, utterances: Set[Utterance]):
99
+ self.umap_ax.clear()
100
+
101
+ speakers = np.unique([u.speaker_name for u in utterances])
102
+ colors = {speaker_name: colormap[i] for i, speaker_name in enumerate(speakers)}
103
+ embeds = [u.embed for u in utterances]
104
+
105
+ # Display a message if there aren't enough points
106
+ if len(utterances) < self.min_umap_points:
107
+ self.umap_ax.text(.5, .5, "Add %d more points to\ngenerate the projections" %
108
+ (self.min_umap_points - len(utterances)),
109
+ horizontalalignment='center', fontsize=15)
110
+ self.umap_ax.set_title("")
111
+
112
+ # Compute the projections
113
+ else:
114
+ if not self.umap_hot:
115
+ self.log(
116
+ "Drawing UMAP projections for the first time, this will take a few seconds.")
117
+ self.umap_hot = True
118
+
119
+ reducer = umap.UMAP(int(np.ceil(np.sqrt(len(embeds)))), metric="cosine")
120
+ # reducer = TSNE()
121
+ projections = reducer.fit_transform(embeds)
122
+
123
+ speakers_done = set()
124
+ for projection, utterance in zip(projections, utterances):
125
+ color = colors[utterance.speaker_name]
126
+ mark = "x" if "_gen_" in utterance.name else "o"
127
+ label = None if utterance.speaker_name in speakers_done else utterance.speaker_name
128
+ speakers_done.add(utterance.speaker_name)
129
+ self.umap_ax.scatter(projection[0], projection[1], c=[color], marker=mark,
130
+ label=label)
131
+ # self.umap_ax.set_title("UMAP projections")
132
+ self.umap_ax.legend(prop={'size': 10})
133
+
134
+ # Draw the plot
135
+ self.umap_ax.set_aspect("equal", "datalim")
136
+ self.umap_ax.set_xticks([])
137
+ self.umap_ax.set_yticks([])
138
+ self.umap_ax.figure.canvas.draw()
139
+
140
+ def save_audio_file(self, wav, sample_rate):
141
+ dialog = QFileDialog()
142
+ dialog.setDefaultSuffix(".wav")
143
+ fpath, _ = dialog.getSaveFileName(
144
+ parent=self,
145
+ caption="Select a path to save the audio file",
146
+ filter="Audio Files (*.flac *.wav)"
147
+ )
148
+ if fpath:
149
+ #Default format is wav
150
+ if Path(fpath).suffix == "":
151
+ fpath += ".wav"
152
+ sf.write(fpath, wav, sample_rate)
153
+
154
+ def setup_audio_devices(self, sample_rate):
155
+ input_devices = []
156
+ output_devices = []
157
+ for device in sd.query_devices():
158
+ # Check if valid input
159
+ try:
160
+ sd.check_input_settings(device=device["name"], samplerate=sample_rate)
161
+ input_devices.append(device["name"])
162
+ except:
163
+ pass
164
+
165
+ # Check if valid output
166
+ try:
167
+ sd.check_output_settings(device=device["name"], samplerate=sample_rate)
168
+ output_devices.append(device["name"])
169
+ except Exception as e:
170
+ # Log a warning only if the device is not an input
171
+ if not device["name"] in input_devices:
172
+ warn("Unsupported output device %s for the sample rate: %d \nError: %s" % (device["name"], sample_rate, str(e)))
173
+
174
+ if len(input_devices) == 0:
175
+ self.log("No audio input device detected. Recording may not work.")
176
+ self.audio_in_device = None
177
+ else:
178
+ self.audio_in_device = input_devices[0]
179
+
180
+ if len(output_devices) == 0:
181
+ self.log("No supported output audio devices were found! Audio output may not work.")
182
+ self.audio_out_devices_cb.addItems(["None"])
183
+ self.audio_out_devices_cb.setDisabled(True)
184
+ else:
185
+ self.audio_out_devices_cb.clear()
186
+ self.audio_out_devices_cb.addItems(output_devices)
187
+ self.audio_out_devices_cb.currentTextChanged.connect(self.set_audio_device)
188
+
189
+ self.set_audio_device()
190
+
191
+ def set_audio_device(self):
192
+
193
+ output_device = self.audio_out_devices_cb.currentText()
194
+ if output_device == "None":
195
+ output_device = None
196
+
197
+ # If None, sounddevice queries portaudio
198
+ sd.default.device = (self.audio_in_device, output_device)
199
+
200
+ def play(self, wav, sample_rate):
201
+ try:
202
+ sd.stop()
203
+ sd.play(wav, sample_rate)
204
+ except Exception as e:
205
+ print(e)
206
+ self.log("Error in audio playback. Try selecting a different audio output device.")
207
+ self.log("Your device must be connected before you start the toolbox.")
208
+
209
+ def stop(self):
210
+ sd.stop()
211
+
212
+ def record_one(self, sample_rate, duration):
213
+ self.record_button.setText("Recording...")
214
+ self.record_button.setDisabled(True)
215
+
216
+ self.log("Recording %d seconds of audio" % duration)
217
+ sd.stop()
218
+ try:
219
+ wav = sd.rec(duration * sample_rate, sample_rate, 1)
220
+ except Exception as e:
221
+ print(e)
222
+ self.log("Could not record anything. Is your recording device enabled?")
223
+ self.log("Your device must be connected before you start the toolbox.")
224
+ return None
225
+
226
+ for i in np.arange(0, duration, 0.1):
227
+ self.set_loading(i, duration)
228
+ sleep(0.1)
229
+ self.set_loading(duration, duration)
230
+ sd.wait()
231
+
232
+ self.log("Done recording.")
233
+ self.record_button.setText("Record")
234
+ self.record_button.setDisabled(False)
235
+
236
+ return wav.squeeze()
237
+
238
+ @property
239
+ def current_dataset_name(self):
240
+ return self.dataset_box.currentText()
241
+
242
+ @property
243
+ def current_speaker_name(self):
244
+ return self.speaker_box.currentText()
245
+
246
+ @property
247
+ def current_utterance_name(self):
248
+ return self.utterance_box.currentText()
249
+
250
+ def browse_file(self):
251
+ fpath = QFileDialog().getOpenFileName(
252
+ parent=self,
253
+ caption="Select an audio file",
254
+ filter="Audio Files (*.mp3 *.flac *.wav *.m4a)"
255
+ )
256
+ return Path(fpath[0]) if fpath[0] != "" else ""
257
+
258
+ @staticmethod
259
+ def repopulate_box(box, items, random=False):
260
+ """
261
+ Resets a box and adds a list of items. Pass a list of (item, data) pairs instead to join
262
+ data to the items
263
+ """
264
+ box.blockSignals(True)
265
+ box.clear()
266
+ for item in items:
267
+ item = list(item) if isinstance(item, tuple) else [item]
268
+ box.addItem(str(item[0]), *item[1:])
269
+ if len(items) > 0:
270
+ box.setCurrentIndex(np.random.randint(len(items)) if random else 0)
271
+ box.setDisabled(len(items) == 0)
272
+ box.blockSignals(False)
273
+
274
+ def populate_browser(self, datasets_root: Path, recognized_datasets: List, level: int,
275
+ random=True):
276
+ # Select a random dataset
277
+ if level <= 0:
278
+ if datasets_root is not None:
279
+ datasets = [datasets_root.joinpath(d) for d in recognized_datasets]
280
+ datasets = [d.relative_to(datasets_root) for d in datasets if d.exists()]
281
+ self.browser_load_button.setDisabled(len(datasets) == 0)
282
+ if datasets_root is None or len(datasets) == 0:
283
+ msg = "Warning: you d" + ("id not pass a root directory for datasets as argument" \
284
+ if datasets_root is None else "o not have any of the recognized datasets" \
285
+ " in %s" % datasets_root)
286
+ self.log(msg)
287
+ msg += ".\nThe recognized datasets are:\n\t%s\nFeel free to add your own. You " \
288
+ "can still use the toolbox by recording samples yourself." % \
289
+ ("\n\t".join(recognized_datasets))
290
+ print(msg, file=sys.stderr)
291
+
292
+ self.random_utterance_button.setDisabled(True)
293
+ self.random_speaker_button.setDisabled(True)
294
+ self.random_dataset_button.setDisabled(True)
295
+ self.utterance_box.setDisabled(True)
296
+ self.speaker_box.setDisabled(True)
297
+ self.dataset_box.setDisabled(True)
298
+ self.browser_load_button.setDisabled(True)
299
+ self.auto_next_checkbox.setDisabled(True)
300
+ return
301
+ self.repopulate_box(self.dataset_box, datasets, random)
302
+
303
+ # Select a random speaker
304
+ if level <= 1:
305
+ speakers_root = datasets_root.joinpath(self.current_dataset_name)
306
+ speaker_names = [d.stem for d in speakers_root.glob("*") if d.is_dir()]
307
+ self.repopulate_box(self.speaker_box, speaker_names, random)
308
+
309
+ # Select a random utterance
310
+ if level <= 2:
311
+ utterances_root = datasets_root.joinpath(
312
+ self.current_dataset_name,
313
+ self.current_speaker_name
314
+ )
315
+ utterances = []
316
+ for extension in ['mp3', 'flac', 'wav', 'm4a']:
317
+ utterances.extend(Path(utterances_root).glob("**/*.%s" % extension))
318
+ utterances = [fpath.relative_to(utterances_root) for fpath in utterances]
319
+ self.repopulate_box(self.utterance_box, utterances, random)
320
+
321
+ def browser_select_next(self):
322
+ index = (self.utterance_box.currentIndex() + 1) % len(self.utterance_box)
323
+ self.utterance_box.setCurrentIndex(index)
324
+
325
+ @property
326
+ def current_encoder_fpath(self):
327
+ return self.encoder_box.itemData(self.encoder_box.currentIndex())
328
+
329
+ @property
330
+ def current_synthesizer_fpath(self):
331
+ return self.synthesizer_box.itemData(self.synthesizer_box.currentIndex())
332
+
333
+ @property
334
+ def current_vocoder_fpath(self):
335
+ return self.vocoder_box.itemData(self.vocoder_box.currentIndex())
336
+
337
+ def populate_models(self, encoder_models_dir: Path, synthesizer_models_dir: Path,
338
+ vocoder_models_dir: Path):
339
+ # Encoder
340
+ encoder_fpaths = list(encoder_models_dir.glob("*.pt"))
341
+ if len(encoder_fpaths) == 0:
342
+ raise Exception("No encoder models found in %s" % encoder_models_dir)
343
+ self.repopulate_box(self.encoder_box, [(f.stem, f) for f in encoder_fpaths])
344
+
345
+ # Synthesizer
346
+ synthesizer_fpaths = list(synthesizer_models_dir.glob("**/*.pt"))
347
+ if len(synthesizer_fpaths) == 0:
348
+ raise Exception("No synthesizer models found in %s" % synthesizer_models_dir)
349
+ self.repopulate_box(self.synthesizer_box, [(f.stem, f) for f in synthesizer_fpaths])
350
+
351
+ # Vocoder
352
+ vocoder_fpaths = list(vocoder_models_dir.glob("**/*.pt"))
353
+ vocoder_items = [(f.stem, f) for f in vocoder_fpaths] + [("Griffin-Lim", None)]
354
+ self.repopulate_box(self.vocoder_box, vocoder_items)
355
+
356
+ @property
357
+ def selected_utterance(self):
358
+ return self.utterance_history.itemData(self.utterance_history.currentIndex())
359
+
360
+ def register_utterance(self, utterance: Utterance):
361
+ self.utterance_history.blockSignals(True)
362
+ self.utterance_history.insertItem(0, utterance.name, utterance)
363
+ self.utterance_history.setCurrentIndex(0)
364
+ self.utterance_history.blockSignals(False)
365
+
366
+ if len(self.utterance_history) > self.max_saved_utterances:
367
+ self.utterance_history.removeItem(self.max_saved_utterances)
368
+
369
+ self.play_button.setDisabled(False)
370
+ self.generate_button.setDisabled(False)
371
+ self.synthesize_button.setDisabled(False)
372
+
373
+ def log(self, line, mode="newline"):
374
+ if mode == "newline":
375
+ self.logs.append(line)
376
+ if len(self.logs) > self.max_log_lines:
377
+ del self.logs[0]
378
+ elif mode == "append":
379
+ self.logs[-1] += line
380
+ elif mode == "overwrite":
381
+ self.logs[-1] = line
382
+ log_text = '\n'.join(self.logs)
383
+
384
+ self.log_window.setText(log_text)
385
+ self.app.processEvents()
386
+
387
+ def set_loading(self, value, maximum=1):
388
+ self.loading_bar.setValue(value * 100)
389
+ self.loading_bar.setMaximum(maximum * 100)
390
+ self.loading_bar.setTextVisible(value != 0)
391
+ self.app.processEvents()
392
+
393
+ def populate_gen_options(self, seed, trim_silences):
394
+ if seed is not None:
395
+ self.random_seed_checkbox.setChecked(True)
396
+ self.seed_textbox.setText(str(seed))
397
+ self.seed_textbox.setEnabled(True)
398
+ else:
399
+ self.random_seed_checkbox.setChecked(False)
400
+ self.seed_textbox.setText(str(0))
401
+ self.seed_textbox.setEnabled(False)
402
+
403
+ if not trim_silences:
404
+ self.trim_silences_checkbox.setChecked(False)
405
+ self.trim_silences_checkbox.setDisabled(True)
406
+
407
+ def update_seed_textbox(self):
408
+ if self.random_seed_checkbox.isChecked():
409
+ self.seed_textbox.setEnabled(True)
410
+ else:
411
+ self.seed_textbox.setEnabled(False)
412
+
413
+ def reset_interface(self):
414
+ self.draw_embed(None, None, "current")
415
+ self.draw_embed(None, None, "generated")
416
+ self.draw_spec(None, "current")
417
+ self.draw_spec(None, "generated")
418
+ self.draw_umap_projections(set())
419
+ self.set_loading(0)
420
+ self.play_button.setDisabled(True)
421
+ self.generate_button.setDisabled(True)
422
+ self.synthesize_button.setDisabled(True)
423
+ self.vocode_button.setDisabled(True)
424
+ self.replay_wav_button.setDisabled(True)
425
+ self.export_wav_button.setDisabled(True)
426
+ [self.log("") for _ in range(self.max_log_lines)]
427
+
428
+ def __init__(self):
429
+ ## Initialize the application
430
+ self.app = QApplication(sys.argv)
431
+ super().__init__(None)
432
+ self.setWindowTitle("SV2TTS toolbox")
433
+
434
+
435
+ ## Main layouts
436
+ # Root
437
+ root_layout = QGridLayout()
438
+ self.setLayout(root_layout)
439
+
440
+ # Browser
441
+ browser_layout = QGridLayout()
442
+ root_layout.addLayout(browser_layout, 0, 0, 1, 2)
443
+
444
+ # Generation
445
+ gen_layout = QVBoxLayout()
446
+ root_layout.addLayout(gen_layout, 0, 2, 1, 2)
447
+
448
+ # Projections
449
+ self.projections_layout = QVBoxLayout()
450
+ root_layout.addLayout(self.projections_layout, 1, 0, 1, 1)
451
+
452
+ # Visualizations
453
+ vis_layout = QVBoxLayout()
454
+ root_layout.addLayout(vis_layout, 1, 1, 1, 3)
455
+
456
+
457
+ ## Projections
458
+ # UMap
459
+ fig, self.umap_ax = plt.subplots(figsize=(3, 3), facecolor="#F0F0F0")
460
+ fig.subplots_adjust(left=0.02, bottom=0.02, right=0.98, top=0.98)
461
+ self.projections_layout.addWidget(FigureCanvas(fig))
462
+ self.umap_hot = False
463
+ self.clear_button = QPushButton("Clear")
464
+ self.projections_layout.addWidget(self.clear_button)
465
+
466
+
467
+ ## Browser
468
+ # Dataset, speaker and utterance selection
469
+ i = 0
470
+ self.dataset_box = QComboBox()
471
+ browser_layout.addWidget(QLabel("<b>Dataset</b>"), i, 0)
472
+ browser_layout.addWidget(self.dataset_box, i + 1, 0)
473
+ self.speaker_box = QComboBox()
474
+ browser_layout.addWidget(QLabel("<b>Speaker</b>"), i, 1)
475
+ browser_layout.addWidget(self.speaker_box, i + 1, 1)
476
+ self.utterance_box = QComboBox()
477
+ browser_layout.addWidget(QLabel("<b>Utterance</b>"), i, 2)
478
+ browser_layout.addWidget(self.utterance_box, i + 1, 2)
479
+ self.browser_load_button = QPushButton("Load")
480
+ browser_layout.addWidget(self.browser_load_button, i + 1, 3)
481
+ i += 2
482
+
483
+ # Random buttons
484
+ self.random_dataset_button = QPushButton("Random")
485
+ browser_layout.addWidget(self.random_dataset_button, i, 0)
486
+ self.random_speaker_button = QPushButton("Random")
487
+ browser_layout.addWidget(self.random_speaker_button, i, 1)
488
+ self.random_utterance_button = QPushButton("Random")
489
+ browser_layout.addWidget(self.random_utterance_button, i, 2)
490
+ self.auto_next_checkbox = QCheckBox("Auto select next")
491
+ self.auto_next_checkbox.setChecked(True)
492
+ browser_layout.addWidget(self.auto_next_checkbox, i, 3)
493
+ i += 1
494
+
495
+ # Utterance box
496
+ browser_layout.addWidget(QLabel("<b>Use embedding from:</b>"), i, 0)
497
+ self.utterance_history = QComboBox()
498
+ browser_layout.addWidget(self.utterance_history, i, 1, 1, 3)
499
+ i += 1
500
+
501
+ # Random & next utterance buttons
502
+ self.browser_browse_button = QPushButton("Browse")
503
+ browser_layout.addWidget(self.browser_browse_button, i, 0)
504
+ self.record_button = QPushButton("Record")
505
+ browser_layout.addWidget(self.record_button, i, 1)
506
+ self.play_button = QPushButton("Play")
507
+ browser_layout.addWidget(self.play_button, i, 2)
508
+ self.stop_button = QPushButton("Stop")
509
+ browser_layout.addWidget(self.stop_button, i, 3)
510
+ i += 1
511
+
512
+
513
+ # Model and audio output selection
514
+ self.encoder_box = QComboBox()
515
+ browser_layout.addWidget(QLabel("<b>Encoder</b>"), i, 0)
516
+ browser_layout.addWidget(self.encoder_box, i + 1, 0)
517
+ self.synthesizer_box = QComboBox()
518
+ browser_layout.addWidget(QLabel("<b>Synthesizer</b>"), i, 1)
519
+ browser_layout.addWidget(self.synthesizer_box, i + 1, 1)
520
+ self.vocoder_box = QComboBox()
521
+ browser_layout.addWidget(QLabel("<b>Vocoder</b>"), i, 2)
522
+ browser_layout.addWidget(self.vocoder_box, i + 1, 2)
523
+
524
+ self.audio_out_devices_cb=QComboBox()
525
+ browser_layout.addWidget(QLabel("<b>Audio Output</b>"), i, 3)
526
+ browser_layout.addWidget(self.audio_out_devices_cb, i + 1, 3)
527
+ i += 2
528
+
529
+ #Replay & Save Audio
530
+ browser_layout.addWidget(QLabel("<b>Toolbox Output:</b>"), i, 0)
531
+ self.waves_cb = QComboBox()
532
+ self.waves_cb_model = QStringListModel()
533
+ self.waves_cb.setModel(self.waves_cb_model)
534
+ self.waves_cb.setToolTip("Select one of the last generated waves in this section for replaying or exporting")
535
+ browser_layout.addWidget(self.waves_cb, i, 1)
536
+ self.replay_wav_button = QPushButton("Replay")
537
+ self.replay_wav_button.setToolTip("Replay last generated vocoder")
538
+ browser_layout.addWidget(self.replay_wav_button, i, 2)
539
+ self.export_wav_button = QPushButton("Export")
540
+ self.export_wav_button.setToolTip("Save last generated vocoder audio in filesystem as a wav file")
541
+ browser_layout.addWidget(self.export_wav_button, i, 3)
542
+ i += 1
543
+
544
+
545
+ ## Embed & spectrograms
546
+ vis_layout.addStretch()
547
+
548
+ gridspec_kw = {"width_ratios": [1, 4]}
549
+ fig, self.current_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
550
+ gridspec_kw=gridspec_kw)
551
+ fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
552
+ vis_layout.addWidget(FigureCanvas(fig))
553
+
554
+ fig, self.gen_ax = plt.subplots(1, 2, figsize=(10, 2.25), facecolor="#F0F0F0",
555
+ gridspec_kw=gridspec_kw)
556
+ fig.subplots_adjust(left=0, bottom=0.1, right=1, top=0.8)
557
+ vis_layout.addWidget(FigureCanvas(fig))
558
+
559
+ for ax in self.current_ax.tolist() + self.gen_ax.tolist():
560
+ ax.set_facecolor("#F0F0F0")
561
+ for side in ["top", "right", "bottom", "left"]:
562
+ ax.spines[side].set_visible(False)
563
+
564
+
565
+ ## Generation
566
+ self.text_prompt = QPlainTextEdit(default_text)
567
+ gen_layout.addWidget(self.text_prompt, stretch=1)
568
+
569
+ self.generate_button = QPushButton("Synthesize and vocode")
570
+ gen_layout.addWidget(self.generate_button)
571
+
572
+ layout = QHBoxLayout()
573
+ self.synthesize_button = QPushButton("Synthesize only")
574
+ layout.addWidget(self.synthesize_button)
575
+ self.vocode_button = QPushButton("Vocode only")
576
+ layout.addWidget(self.vocode_button)
577
+ gen_layout.addLayout(layout)
578
+
579
+ layout_seed = QGridLayout()
580
+ self.random_seed_checkbox = QCheckBox("Random seed:")
581
+ self.random_seed_checkbox.setToolTip("When checked, makes the synthesizer and vocoder deterministic.")
582
+ layout_seed.addWidget(self.random_seed_checkbox, 0, 0)
583
+ self.seed_textbox = QLineEdit()
584
+ self.seed_textbox.setMaximumWidth(80)
585
+ layout_seed.addWidget(self.seed_textbox, 0, 1)
586
+ self.trim_silences_checkbox = QCheckBox("Enhance vocoder output")
587
+ self.trim_silences_checkbox.setToolTip("When checked, trims excess silence in vocoder output."
588
+ " This feature requires `webrtcvad` to be installed.")
589
+ layout_seed.addWidget(self.trim_silences_checkbox, 0, 2, 1, 2)
590
+ gen_layout.addLayout(layout_seed)
591
+
592
+ self.loading_bar = QProgressBar()
593
+ gen_layout.addWidget(self.loading_bar)
594
+
595
+ self.log_window = QLabel()
596
+ self.log_window.setAlignment(Qt.AlignBottom | Qt.AlignLeft)
597
+ gen_layout.addWidget(self.log_window)
598
+ self.logs = []
599
+ gen_layout.addStretch()
600
+
601
+
602
+ ## Set the size of the window and of the elements
603
+ max_size = QDesktopWidget().availableGeometry(self).size() * 0.8
604
+ self.resize(max_size)
605
+
606
+ ## Finalize the display
607
+ self.reset_interface()
608
+ self.show()
609
+
610
+ def start(self):
611
+ self.app.exec_()
toolbox/utterance.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
+ Utterance = namedtuple("Utterance", "name speaker_name wav spec embed partial_embeds synth")
4
+ Utterance.__eq__ = lambda x, y: x.name == y.name
5
+ Utterance.__hash__ = lambda x: hash(x.name)
utils/__init__.py ADDED
File without changes
utils/argutils.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import numpy as np
3
+ import argparse
4
+
5
+ _type_priorities = [ # In decreasing order
6
+ Path,
7
+ str,
8
+ int,
9
+ float,
10
+ bool,
11
+ ]
12
+
13
+ def _priority(o):
14
+ p = next((i for i, t in enumerate(_type_priorities) if type(o) is t), None)
15
+ if p is not None:
16
+ return p
17
+ p = next((i for i, t in enumerate(_type_priorities) if isinstance(o, t)), None)
18
+ if p is not None:
19
+ return p
20
+ return len(_type_priorities)
21
+
22
+ def print_args(args: argparse.Namespace, parser=None):
23
+ args = vars(args)
24
+ if parser is None:
25
+ priorities = list(map(_priority, args.values()))
26
+ else:
27
+ all_params = [a.dest for g in parser._action_groups for a in g._group_actions ]
28
+ priority = lambda p: all_params.index(p) if p in all_params else len(all_params)
29
+ priorities = list(map(priority, args.keys()))
30
+
31
+ pad = max(map(len, args.keys())) + 3
32
+ indices = np.lexsort((list(args.keys()), priorities))
33
+ items = list(args.items())
34
+
35
+ print("Arguments:")
36
+ for i in indices:
37
+ param, value = items[i]
38
+ print(" {0}:{1}{2}".format(param, ' ' * (pad - len(param)), value))
39
+ print("")
40
+
utils/logmmse.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # The MIT License (MIT)
2
+ #
3
+ # Copyright (c) 2015 braindead
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+ #
23
+ #
24
+ # This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I
25
+ # simply modified the interface to meet my needs.
26
+
27
+
28
+ import numpy as np
29
+ import math
30
+ from scipy.special import expn
31
+ from collections import namedtuple
32
+
33
+ NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2")
34
+
35
+
36
+ def profile_noise(noise, sampling_rate, window_size=0):
37
+ """
38
+ Creates a profile of the noise in a given waveform.
39
+
40
+ :param noise: a waveform containing noise ONLY, as a numpy array of floats or ints.
41
+ :param sampling_rate: the sampling rate of the audio
42
+ :param window_size: the size of the window the logmmse algorithm operates on. A default value
43
+ will be picked if left as 0.
44
+ :return: a NoiseProfile object
45
+ """
46
+ noise, dtype = to_float(noise)
47
+ noise += np.finfo(np.float64).eps
48
+
49
+ if window_size == 0:
50
+ window_size = int(math.floor(0.02 * sampling_rate))
51
+
52
+ if window_size % 2 == 1:
53
+ window_size = window_size + 1
54
+
55
+ perc = 50
56
+ len1 = int(math.floor(window_size * perc / 100))
57
+ len2 = int(window_size - len1)
58
+
59
+ win = np.hanning(window_size)
60
+ win = win * len2 / np.sum(win)
61
+ n_fft = 2 * window_size
62
+
63
+ noise_mean = np.zeros(n_fft)
64
+ n_frames = len(noise) // window_size
65
+ for j in range(0, window_size * n_frames, window_size):
66
+ noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0))
67
+ noise_mu2 = (noise_mean / n_frames) ** 2
68
+
69
+ return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2)
70
+
71
+
72
+ def denoise(wav, noise_profile: NoiseProfile, eta=0.15):
73
+ """
74
+ Cleans the noise from a speech waveform given a noise profile. The waveform must have the
75
+ same sampling rate as the one used to create the noise profile.
76
+
77
+ :param wav: a speech waveform as a numpy array of floats or ints.
78
+ :param noise_profile: a NoiseProfile object that was created from a similar (or a segment of
79
+ the same) waveform.
80
+ :param eta: voice threshold for noise update. While the voice activation detection value is
81
+ below this threshold, the noise profile will be continuously updated throughout the audio.
82
+ Set to 0 to disable updating the noise profile.
83
+ :return: the clean wav as a numpy array of floats or ints of the same length.
84
+ """
85
+ wav, dtype = to_float(wav)
86
+ wav += np.finfo(np.float64).eps
87
+ p = noise_profile
88
+
89
+ nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2))
90
+ x_final = np.zeros(nframes * p.len2)
91
+
92
+ aa = 0.98
93
+ mu = 0.98
94
+ ksi_min = 10 ** (-25 / 10)
95
+
96
+ x_old = np.zeros(p.len1)
97
+ xk_prev = np.zeros(p.len1)
98
+ noise_mu2 = p.noise_mu2
99
+ for k in range(0, nframes * p.len2, p.len2):
100
+ insign = p.win * wav[k:k + p.window_size]
101
+
102
+ spec = np.fft.fft(insign, p.n_fft, axis=0)
103
+ sig = np.absolute(spec)
104
+ sig2 = sig ** 2
105
+
106
+ gammak = np.minimum(sig2 / noise_mu2, 40)
107
+
108
+ if xk_prev.all() == 0:
109
+ ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
110
+ else:
111
+ ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
112
+ ksi = np.maximum(ksi_min, ksi)
113
+
114
+ log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi)
115
+ vad_decision = np.sum(log_sigma_k) / p.window_size
116
+ if vad_decision < eta:
117
+ noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
118
+
119
+ a = ksi / (1 + ksi)
120
+ vk = a * gammak
121
+ ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
122
+ hw = a * np.exp(ei_vk)
123
+ sig = sig * hw
124
+ xk_prev = sig ** 2
125
+ xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0)
126
+ xi_w = np.real(xi_w)
127
+
128
+ x_final[k:k + p.len2] = x_old + xi_w[0:p.len1]
129
+ x_old = xi_w[p.len1:p.window_size]
130
+
131
+ output = from_float(x_final, dtype)
132
+ output = np.pad(output, (0, len(wav) - len(output)), mode="constant")
133
+ return output
134
+
135
+
136
+ ## Alternative VAD algorithm to webrctvad. It has the advantage of not requiring to install that
137
+ ## darn package and it also works for any sampling rate. Maybe I'll eventually use it instead of
138
+ ## webrctvad
139
+ # def vad(wav, sampling_rate, eta=0.15, window_size=0):
140
+ # """
141
+ # TODO: fix doc
142
+ # Creates a profile of the noise in a given waveform.
143
+ #
144
+ # :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints.
145
+ # :param sampling_rate: the sampling rate of the audio
146
+ # :param window_size: the size of the window the logmmse algorithm operates on. A default value
147
+ # will be picked if left as 0.
148
+ # :param eta: voice threshold for noise update. While the voice activation detection value is
149
+ # below this threshold, the noise profile will be continuously updated throughout the audio.
150
+ # Set to 0 to disable updating the noise profile.
151
+ # """
152
+ # wav, dtype = to_float(wav)
153
+ # wav += np.finfo(np.float64).eps
154
+ #
155
+ # if window_size == 0:
156
+ # window_size = int(math.floor(0.02 * sampling_rate))
157
+ #
158
+ # if window_size % 2 == 1:
159
+ # window_size = window_size + 1
160
+ #
161
+ # perc = 50
162
+ # len1 = int(math.floor(window_size * perc / 100))
163
+ # len2 = int(window_size - len1)
164
+ #
165
+ # win = np.hanning(window_size)
166
+ # win = win * len2 / np.sum(win)
167
+ # n_fft = 2 * window_size
168
+ #
169
+ # wav_mean = np.zeros(n_fft)
170
+ # n_frames = len(wav) // window_size
171
+ # for j in range(0, window_size * n_frames, window_size):
172
+ # wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0))
173
+ # noise_mu2 = (wav_mean / n_frames) ** 2
174
+ #
175
+ # wav, dtype = to_float(wav)
176
+ # wav += np.finfo(np.float64).eps
177
+ #
178
+ # nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2))
179
+ # vad = np.zeros(nframes * len2, dtype=np.bool)
180
+ #
181
+ # aa = 0.98
182
+ # mu = 0.98
183
+ # ksi_min = 10 ** (-25 / 10)
184
+ #
185
+ # xk_prev = np.zeros(len1)
186
+ # noise_mu2 = noise_mu2
187
+ # for k in range(0, nframes * len2, len2):
188
+ # insign = win * wav[k:k + window_size]
189
+ #
190
+ # spec = np.fft.fft(insign, n_fft, axis=0)
191
+ # sig = np.absolute(spec)
192
+ # sig2 = sig ** 2
193
+ #
194
+ # gammak = np.minimum(sig2 / noise_mu2, 40)
195
+ #
196
+ # if xk_prev.all() == 0:
197
+ # ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
198
+ # else:
199
+ # ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
200
+ # ksi = np.maximum(ksi_min, ksi)
201
+ #
202
+ # log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi)
203
+ # vad_decision = np.sum(log_sigma_k) / window_size
204
+ # if vad_decision < eta:
205
+ # noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
206
+ # print(vad_decision)
207
+ #
208
+ # a = ksi / (1 + ksi)
209
+ # vk = a * gammak
210
+ # ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
211
+ # hw = a * np.exp(ei_vk)
212
+ # sig = sig * hw
213
+ # xk_prev = sig ** 2
214
+ #
215
+ # vad[k:k + len2] = vad_decision >= eta
216
+ #
217
+ # vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant")
218
+ # return vad
219
+
220
+
221
+ def to_float(_input):
222
+ if _input.dtype == np.float64:
223
+ return _input, _input.dtype
224
+ elif _input.dtype == np.float32:
225
+ return _input.astype(np.float64), _input.dtype
226
+ elif _input.dtype == np.uint8:
227
+ return (_input - 128) / 128., _input.dtype
228
+ elif _input.dtype == np.int16:
229
+ return _input / 32768., _input.dtype
230
+ elif _input.dtype == np.int32:
231
+ return _input / 2147483648., _input.dtype
232
+ raise ValueError('Unsupported wave file format')
233
+
234
+
235
+ def from_float(_input, dtype):
236
+ if dtype == np.float64:
237
+ return _input, np.float64
238
+ elif dtype == np.float32:
239
+ return _input.astype(np.float32)
240
+ elif dtype == np.uint8:
241
+ return ((_input * 128) + 128).astype(np.uint8)
242
+ elif dtype == np.int16:
243
+ return (_input * 32768).astype(np.int16)
244
+ elif dtype == np.int32:
245
+ print(_input)
246
+ return (_input * 2147483648).astype(np.int32)
247
+ raise ValueError('Unsupported wave file format')
utils/modelutils.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ def check_model_paths(encoder_path: Path, synthesizer_path: Path, vocoder_path: Path):
4
+ # This function tests the model paths and makes sure at least one is valid.
5
+ if encoder_path.is_file() or encoder_path.is_dir():
6
+ return
7
+ if synthesizer_path.is_file() or synthesizer_path.is_dir():
8
+ return
9
+ if vocoder_path.is_file() or vocoder_path.is_dir():
10
+ return
11
+
12
+ # If none of the paths exist, remind the user to download models if needed
13
+ print("********************************************************************************")
14
+ print("Error: Model files not found. Follow these instructions to get and install the models:")
15
+ print("https://github.com/CorentinJ/Real-Time-Voice-Cloning/wiki/Pretrained-models")
16
+ print("********************************************************************************\n")
17
+ quit(-1)
utils/profiler.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from time import perf_counter as timer
2
+ from collections import OrderedDict
3
+ import numpy as np
4
+
5
+
6
+ class Profiler:
7
+ def __init__(self, summarize_every=5, disabled=False):
8
+ self.last_tick = timer()
9
+ self.logs = OrderedDict()
10
+ self.summarize_every = summarize_every
11
+ self.disabled = disabled
12
+
13
+ def tick(self, name):
14
+ if self.disabled:
15
+ return
16
+
17
+ # Log the time needed to execute that function
18
+ if not name in self.logs:
19
+ self.logs[name] = []
20
+ if len(self.logs[name]) >= self.summarize_every:
21
+ self.summarize()
22
+ self.purge_logs()
23
+ self.logs[name].append(timer() - self.last_tick)
24
+
25
+ self.reset_timer()
26
+
27
+ def purge_logs(self):
28
+ for name in self.logs:
29
+ self.logs[name].clear()
30
+
31
+ def reset_timer(self):
32
+ self.last_tick = timer()
33
+
34
+ def summarize(self):
35
+ n = max(map(len, self.logs.values()))
36
+ assert n == self.summarize_every
37
+ print("\nAverage execution time over %d steps:" % n)
38
+
39
+ name_msgs = ["%s (%d/%d):" % (name, len(deltas), n) for name, deltas in self.logs.items()]
40
+ pad = max(map(len, name_msgs))
41
+ for name_msg, deltas in zip(name_msgs, self.logs.values()):
42
+ print(" %s mean: %4.0fms std: %4.0fms" %
43
+ (name_msg.ljust(pad), np.mean(deltas) * 1000, np.std(deltas) * 1000))
44
+ print("", flush=True)
45
+