Kevin676 and ruslanmv committed on
Commit
b6b5ece
0 Parent(s):

Duplicate from ruslanmv/Clone-Your-Voice


Co-authored-by: Ruslan Magana Vsevolodovna <ruslanmv@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +50 -0
  2. .gitignore +5 -0
  3. README.md +14 -0
  4. app.py +377 -0
  5. encoder/__init__.py +0 -0
  6. encoder/audio.py +117 -0
  7. encoder/config.py +45 -0
  8. encoder/data_objects/__init__.py +2 -0
  9. encoder/data_objects/random_cycler.py +37 -0
  10. encoder/data_objects/speaker.py +40 -0
  11. encoder/data_objects/speaker_batch.py +13 -0
  12. encoder/data_objects/speaker_verification_dataset.py +56 -0
  13. encoder/data_objects/utterance.py +26 -0
  14. encoder/inference.py +178 -0
  15. encoder/model.py +135 -0
  16. encoder/params_data.py +29 -0
  17. encoder/params_model.py +11 -0
  18. encoder/preprocess.py +184 -0
  19. encoder/train.py +125 -0
  20. encoder/visualizations.py +179 -0
  21. musk.mp3 +0 -0
  22. queen.mp3 +0 -0
  23. requirements.txt +0 -0
  24. synthesizer/LICENSE.txt +24 -0
  25. synthesizer/__init__.py +1 -0
  26. synthesizer/audio.py +206 -0
  27. synthesizer/hparams.py +92 -0
  28. synthesizer/inference.py +165 -0
  29. synthesizer/models/tacotron.py +519 -0
  30. synthesizer/preprocess.py +258 -0
  31. synthesizer/synthesize.py +92 -0
  32. synthesizer/synthesizer_dataset.py +92 -0
  33. synthesizer/train.py +258 -0
  34. synthesizer/utils/__init__.py +45 -0
  35. synthesizer/utils/_cmudict.py +62 -0
  36. synthesizer/utils/cleaners.py +88 -0
  37. synthesizer/utils/numbers.py +69 -0
  38. synthesizer/utils/plot.py +82 -0
  39. synthesizer/utils/symbols.py +17 -0
  40. synthesizer/utils/text.py +75 -0
  41. trump.mp3 +0 -0
  42. utils/__init__.py +0 -0
  43. utils/argutils.py +40 -0
  44. utils/default_models.py +56 -0
  45. utils/logmmse.py +247 -0
  46. utils/profiler.py +45 -0
  47. vocoder/LICENSE.txt +22 -0
  48. vocoder/audio.py +108 -0
  49. vocoder/display.py +127 -0
  50. vocoder/distribution.py +132 -0
.gitattributes ADDED
@@ -0,0 +1,50 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pyc
+ *.aux
+ *.log
+ *.out
+ *.synctex.gz
+ *.suo
+ *__pycache__
+ *.idea
+ *.ipynb_checkpoints
+ *.pickle
+ *.npy
+ *.blg
+ *.bbl
+ *.bcf
+ *.toc
+ *.sh
+ encoder/saved_models/*
+ synthesizer/saved_models/*
+ vocoder/saved_models/*
.gitignore ADDED
@@ -0,0 +1,5 @@
+
+ *.pyc
+ *.pt
+ *.ipynb
+ *.wav
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: Clone Your Voice
+ emoji: 📚
+ colorFrom: blue
+ colorTo: yellow
+ python_version: 3.8.4
+ sdk: gradio
+ sdk_version: 3.0.4
+ app_file: app.py
+ pinned: false
+ duplicated_from: ruslanmv/Clone-Your-Voice
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,377 @@
1
+ import gradio as gr
2
+ import os
3
+ from utils.default_models import ensure_default_models
4
+ import sys
5
+ import traceback
6
+ from pathlib import Path
7
+ from time import perf_counter as timer
8
+ import numpy as np
9
+ import torch
10
+ from encoder import inference as encoder
11
+ from synthesizer.inference import Synthesizer
12
+ #from toolbox.utterance import Utterance
13
+ from vocoder import inference as vocoder
14
+ import time
15
+ import librosa
16
+ import numpy as np
17
+ #import sounddevice as sd
18
+ import soundfile as sf
19
+ import argparse
20
+ from utils.argutils import print_args
21
+
22
+ parser = argparse.ArgumentParser(
23
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
24
+ )
25
+ parser.add_argument("-e", "--enc_model_fpath", type=Path,
26
+ default="saved_models/default/encoder.pt",
27
+ help="Path to a saved encoder")
28
+ parser.add_argument("-s", "--syn_model_fpath", type=Path,
29
+ default="saved_models/default/synthesizer.pt",
30
+ help="Path to a saved synthesizer")
31
+ parser.add_argument("-v", "--voc_model_fpath", type=Path,
32
+ default="saved_models/default/vocoder.pt",
33
+ help="Path to a saved vocoder")
34
+ parser.add_argument("--cpu", action="store_true", help=\
35
+ "If True, processing is done on CPU, even when a GPU is available.")
36
+ parser.add_argument("--no_sound", action="store_true", help=\
37
+ "If True, audio won't be played.")
38
+ parser.add_argument("--seed", type=int, default=None, help=\
39
+ "Optional random number seed value to make toolbox deterministic.")
40
+ args = parser.parse_args()
41
+ arg_dict = vars(args)
42
+ print_args(args, parser)
43
+
44
+ # Maximum number of generated wavs to keep in memory
45
+ MAX_WAVS = 15
46
+ utterances = set()
47
+ current_generated = (None, None, None, None) # speaker_name, spec, breaks, wav
48
+ synthesizer = None # type: Synthesizer
49
+ current_wav = None
50
+ waves_list = []
51
+ waves_count = 0
52
+ waves_namelist = []
53
+
54
+ # Hide GPUs from Pytorch to force CPU processing
55
+ if arg_dict.pop("cpu"):
56
+ os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
57
+
58
+ print("Running a test of your configuration...\n")
59
+
60
+ if torch.cuda.is_available():
61
+ device_id = torch.cuda.current_device()
62
+ gpu_properties = torch.cuda.get_device_properties(device_id)
63
+ ## Print some environment information (for debugging purposes)
64
+ print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
65
+ "%.1fGb total memory.\n" %
66
+ (torch.cuda.device_count(),
67
+ device_id,
68
+ gpu_properties.name,
69
+ gpu_properties.major,
70
+ gpu_properties.minor,
71
+ gpu_properties.total_memory / 1e9))
72
+ else:
73
+ print("Using CPU for inference.\n")
74
+
75
+ ## Load the models one by one.
76
+ print("Preparing the encoder, the synthesizer and the vocoder...")
77
+ ensure_default_models(Path("saved_models"))
78
+ #encoder.load_model(args.enc_model_fpath)
79
+ #synthesizer = Synthesizer(args.syn_model_fpath)
80
+ #vocoder.load_model(args.voc_model_fpath)
81
+
82
+ def compute_embedding(in_fpath):
83
+
84
+ if not encoder.is_loaded():
85
+ model_fpath = args.enc_model_fpath
86
+ print("Loading the encoder %s... " % model_fpath)
87
+ start = time.time()
88
+ encoder.load_model(model_fpath)
89
+ print("Done (%dms)." % int(1000 * (time.time() - start)), "append")
90
+
91
+
92
+ ## Computing the embedding
93
+ # First, we load the wav using the function that the speaker encoder provides. This is
94
+
95
+ # Get the wav from the disk. We take the wav with the vocoder/synthesizer format for
96
+ # playback, so as to have a fair comparison with the generated audio
97
+ wav = Synthesizer.load_preprocess_wav(in_fpath)
98
+
99
+ # important: there is preprocessing that must be applied.
100
+
101
+ # The following two methods are equivalent:
102
+ # - Directly load from the filepath:
103
+ preprocessed_wav = encoder.preprocess_wav(wav)
104
+
105
+ # - If the wav is already loaded:
106
+ #original_wav, sampling_rate = librosa.load(str(in_fpath))
107
+ #preprocessed_wav = encoder.preprocess_wav(original_wav, sampling_rate)
108
+
109
+ # Compute the embedding
110
+ embed, partial_embeds, _ = encoder.embed_utterance(preprocessed_wav, return_partials=True)
111
+
112
+
113
+ print("Loaded file succesfully")
114
+
115
+ # Then we derive the embedding. There are many functions and parameters that the
116
+ # speaker encoder interfaces. These are mostly for in-depth research. You will typically
117
+ # only use this function (with its default parameters):
118
+ #embed = encoder.embed_utterance(preprocessed_wav)
119
+
120
+ return embed
121
+ def create_spectrogram(text,embed):
122
+ # If seed is specified, reset torch seed and force synthesizer reload
123
+ if args.seed is not None:
124
+ torch.manual_seed(args.seed)
125
+ synthesizer = Synthesizer(args.syn_model_fpath)
126
+
127
+
128
+ # Synthesize the spectrogram
129
+ model_fpath = args.syn_model_fpath
130
+ print("Loading the synthesizer %s... " % model_fpath)
131
+ start = time.time()
132
+ synthesizer = Synthesizer(model_fpath)
133
+ print("Done (%dms)." % int(1000 * (time.time()- start)), "append")
134
+
135
+
136
+ # The synthesizer works in batch, so you need to put your data in a list or numpy array
137
+ texts = [text]
138
+ embeds = [embed]
139
+ # If you know what the attention layer alignments are, you can retrieve them here by
140
+ # passing return_alignments=True
141
+ specs = synthesizer.synthesize_spectrograms(texts, embeds)
142
+ breaks = [spec.shape[1] for spec in specs]
143
+ spec = np.concatenate(specs, axis=1)
144
+ sample_rate=synthesizer.sample_rate
145
+ return spec, breaks , sample_rate
146
+
147
+
148
+ def generate_waveform(current_generated):
149
+
150
+ speaker_name, spec, breaks = current_generated
151
+ assert spec is not None
152
+
153
+ ## Generating the waveform
154
+ print("Synthesizing the waveform:")
155
+ # If seed is specified, reset torch seed and reload vocoder
156
+ if args.seed is not None:
157
+ torch.manual_seed(args.seed)
158
+ vocoder.load_model(args.voc_model_fpath)
159
+
160
+ model_fpath = args.voc_model_fpath
161
+ # Synthesize the waveform
162
+ if not vocoder.is_loaded():
163
+ print("Loading the vocoder %s... " % model_fpath)
164
+ start = time.time()
165
+ vocoder.load_model(model_fpath)
166
+ print("Done (%dms)." % int(1000 * (time.time()- start)), "append")
167
+
168
+ current_vocoder_fpath= model_fpath
169
+ def vocoder_progress(i, seq_len, b_size, gen_rate):
170
+ real_time_factor = (gen_rate / Synthesizer.sample_rate) * 1000
171
+ line = "Waveform generation: %d/%d (batch size: %d, rate: %.1fkHz - %.2fx real time)" \
172
+ % (i * b_size, seq_len * b_size, b_size, gen_rate, real_time_factor)
173
+ print(line, "overwrite")
174
+
175
+
176
+ # Synthesizing the waveform is fairly straightforward. Remember that the longer the
177
+ # spectrogram, the more time-efficient the vocoder.
178
+ if current_vocoder_fpath is not None:
179
+ print("")
180
+ generated_wav = vocoder.infer_waveform(spec, progress_callback=vocoder_progress)
181
+ else:
182
+ print("Waveform generation with Griffin-Lim... ")
183
+ generated_wav = Synthesizer.griffin_lim(spec)
184
+
185
+ print(" Done!", "append")
186
+
187
+
188
+ ## Post-generation
189
+ # There's a bug with sounddevice that makes the audio cut one second earlier, so we
190
+ # pad it.
191
+ generated_wav = np.pad(generated_wav, (0, Synthesizer.sample_rate), mode="constant")
192
+
193
+ # Add breaks
194
+ b_ends = np.cumsum(np.array(breaks) * Synthesizer.hparams.hop_size)
195
+ b_starts = np.concatenate(([0], b_ends[:-1]))
196
+ wavs = [generated_wav[start:end] for start, end, in zip(b_starts, b_ends)]
197
+ breaks = [np.zeros(int(0.15 * Synthesizer.sample_rate))] * len(breaks)
198
+ generated_wav = np.concatenate([i for w, b in zip(wavs, breaks) for i in (w, b)])
199
+
200
+
201
+ # Trim excess silences to compensate for gaps in spectrograms (issue #53)
202
+ generated_wav = encoder.preprocess_wav(generated_wav)
203
+
204
+
205
+ return generated_wav
206
+
207
+
208
+ def save_on_disk(generated_wav,sample_rate):
209
+ # Save it on the disk
210
+ filename = "cloned_voice.wav"
211
+ print(generated_wav.dtype)
212
+ #OUT=os.environ['OUT_PATH']
213
+ # Returns `None` if key doesn't exist
214
+ #OUT=os.environ.get('OUT_PATH')
215
+ #result = os.path.join(OUT, filename)
216
+ result = filename
217
+ print(" > Saving output to {}".format(result))
218
+ sf.write(result, generated_wav.astype(np.float32), sample_rate)
219
+ print("\nSaved output as %s\n\n" % result)
220
+
221
+ return result
222
+ def play_audio(generated_wav,sample_rate):
223
+ # Play the audio (non-blocking)
224
+ if not args.no_sound:
225
+
226
+ try:
227
+ sd.stop()
228
+ sd.play(generated_wav, sample_rate)
229
+ except sd.PortAudioError as e:
230
+ print("\nCaught exception: %s" % repr(e))
231
+ print("Continuing without audio playback. Suppress this message with the \"--no_sound\" flag.\n")
232
+ except:
233
+ raise
234
+
235
+
236
+ def clean_memory():
237
+ import gc
238
+ #import GPUtil
239
+ # To see memory usage
240
+ print('Before clean ')
241
+ #GPUtil.showUtilization()
242
+ #cleaning memory 1
243
+ gc.collect()
244
+ torch.cuda.empty_cache()
245
+ time.sleep(2)
246
+ print('After Clean GPU')
247
+ #GPUtil.showUtilization()
248
+
249
+ def clone_voice(in_fpath, text):
250
+ try:
251
+ speaker_name = "output"
252
+ # Compute embedding
253
+ embed=compute_embedding(in_fpath)
254
+ print("Created the embedding")
255
+ # Generating the spectrogram
256
+ spec, breaks, sample_rate = create_spectrogram(text,embed)
257
+ current_generated = (speaker_name, spec, breaks)
258
+ print("Created the mel spectrogram")
259
+
260
+ # Create waveform
261
+ generated_wav=generate_waveform(current_generated)
262
+ print("Created the the waveform ")
263
+
264
+ # Save it on the disk
265
+ save_on_disk(generated_wav,sample_rate)
266
+
267
+ #Play the audio
268
+ #play_audio(generated_wav,sample_rate)
269
+
270
+ return
271
+ except Exception as e:
272
+ print("Caught exception: %s" % repr(e))
273
+ print("Restarting\n")
274
+
275
+ # Set environment variables
276
+ home_dir = os.getcwd()
277
+ OUT_PATH=os.path.join(home_dir, "out/")
278
+ os.environ['OUT_PATH'] = OUT_PATH
279
+
280
+ # create output path
281
+ os.makedirs(OUT_PATH, exist_ok=True)
282
+
283
+ USE_CUDA = torch.cuda.is_available()
284
+
285
+ os.system('pip install -q pydub ffmpeg-normalize')
286
+ CONFIG_SE_PATH = "config_se.json"
287
+ CHECKPOINT_SE_PATH = "SE_checkpoint.pth.tar"
288
+ def greet(Text,Voicetoclone ,input_mic=None):
289
+ text= "%s" % (Text)
290
+ #reference_files= "%s" % (Voicetoclone)
291
+
292
+ clean_memory()
293
+ print(text,len(text),type(text))
294
+ print(Voicetoclone,type(Voicetoclone))
295
+
296
+ if len(text) == 0 :
297
+ print("Please add text to the program")
298
+ Text="Please add text to the program, thank you."
299
+ is_no_text=True
300
+ else:
301
+ is_no_text=False
302
+
303
+
304
+ if Voicetoclone==None and input_mic==None:
305
+ print("There is no input audio")
306
+ Text="Please add audio input, to the program, thank you."
307
+ Voicetoclone='trump.mp3'
308
+ if is_no_text:
309
+ Text="Please add text and audio, to the program, thank you."
310
+
311
+ if input_mic != "" and input_mic != None :
312
+ # Get the wav file from the microphone
313
+ print('The value of MIC IS :',input_mic,type(input_mic))
314
+ Voicetoclone= input_mic
315
+
316
+ text= "%s" % (Text)
317
+ reference_files= Voicetoclone
318
+ print("path url")
319
+ print(Voicetoclone)
320
+ sample= str(Voicetoclone)
321
+ os.environ['sample'] = sample
322
+ size= len(reference_files)*sys.getsizeof(reference_files)
323
+ size2= size / 1000000
324
+ if (size2 > 0.012) or len(text)>2000:
325
+ message="File is greater than 30mb or Text inserted is longer than 2000 characters. Please re-try with smaller sizes."
326
+ print(message)
327
+ raise SystemExit("File is greater than 30mb. Please re-try or Text inserted is longer than 2000 characters. Please re-try with smaller sizes.")
328
+ else:
329
+
330
+ env_var = 'sample'
331
+ if env_var in os.environ:
332
+ print(f'{env_var} value is {os.environ[env_var]}')
333
+ else:
334
+ print(f'{env_var} does not exist')
335
+ #os.system(f'ffmpeg-normalize {os.environ[env_var]} -nt rms -t=-27 -o {os.environ[env_var]} -ar 16000 -f')
336
+ in_fpath = Path(Voicetoclone)
337
+ #in_fpath= in_fpath.replace("\"", "").replace("\'", "")
338
+
339
+ out_path=clone_voice(in_fpath, text)
340
+
341
+ print(" > text: {}".format(text))
342
+
343
+ print("Generated Audio")
344
+ return "cloned_voice.wav"
345
+
346
+ demo = gr.Interface(
347
+ fn=greet,
348
+ inputs=[gr.inputs.Textbox(label='What would you like the voice to say? (max. 2000 characters per request)'),
349
+ gr.Audio(
350
+ type="filepath",
351
+ source="upload",
352
+ label='Please upload a voice to clone (max. 30mb)'),
353
+ gr.inputs.Audio(
354
+ source="microphone",
355
+ label='or record',
356
+ type="filepath",
357
+ optional=True)
358
+ ],
359
+ outputs="audio",
360
+
361
+ title = 'Clone Your Voice',
362
+ description = 'A simple application that clones your voice. Please allow about one minute for processing.',
363
+ article =
364
+ '''<div>
365
+ <p style="text-align: center"> All you need to do is record your voice, type what you want be say
366
+ then wait for processing. After that, click Play/Pause to listen to the audio. The audio is saved in WAV format.
367
+ For more information visit <a href="https://ruslanmv.com/">ruslanmv.com</a>
368
+ </p>
369
+ </div>''',
370
+
371
+ examples = [["I am the cloned version of Donald Trump. Well. I think what's happening to this country is unbelievably bad. We're no longer a respected country","trump.mp3","trump.mp3"],
372
+ ["I am the cloned version of Elon Musk. Persistence is very important. You should not give up unless you are forced to give up.","musk.mp3","musk.mp3"] ,
373
+ ["I am the cloned version of Elizabeth. It has always been easy to hate and destroy. To build and to cherish is much more difficult." ,"queen.mp3","queen.mp3"]
374
+ ]
375
+
376
+ )
377
+ demo.launch()
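Reviewer note: the three helpers above (compute_embedding, create_spectrogram, generate_waveform) form the encoder → synthesizer → vocoder pipeline that `greet` drives through Gradio. A minimal sketch of the same flow outside the UI, assuming the default pretrained models have already been fetched into `saved_models/default/` and using one of the bundled reference recordings:

```python
# Hedged sketch of the pipeline used by app.py (assumes saved_models/default/ exists).
from pathlib import Path

import numpy as np
import soundfile as sf

from encoder import inference as encoder
from synthesizer.inference import Synthesizer
from vocoder import inference as vocoder

# Load the three models once.
encoder.load_model(Path("saved_models/default/encoder.pt"))
synthesizer = Synthesizer(Path("saved_models/default/synthesizer.pt"))
vocoder.load_model(Path("saved_models/default/vocoder.pt"))

# 1) Speaker embedding from a reference recording.
wav = Synthesizer.load_preprocess_wav("trump.mp3")
embed = encoder.embed_utterance(encoder.preprocess_wav(wav))

# 2) Mel spectrogram conditioned on the embedding.
specs = synthesizer.synthesize_spectrograms(["Hello, this is a cloned voice."], [embed])
spec = np.concatenate(specs, axis=1)

# 3) Waveform from the vocoder, then save to disk.
generated_wav = vocoder.infer_waveform(spec)
generated_wav = encoder.preprocess_wav(generated_wav)  # trim excess silence
sf.write("cloned_voice.wav", generated_wav.astype(np.float32), synthesizer.sample_rate)
```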
encoder/__init__.py ADDED
File without changes
encoder/audio.py ADDED
@@ -0,0 +1,117 @@
1
+ from scipy.ndimage.morphology import binary_dilation
2
+ from encoder.params_data import *
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+ from warnings import warn
6
+ import numpy as np
7
+ import librosa
8
+ import struct
9
+
10
+ try:
11
+ import webrtcvad
12
+ except:
13
+ warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
14
+ webrtcvad=None
15
+
16
+ int16_max = (2 ** 15) - 1
17
+
18
+
19
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
20
+ source_sr: Optional[int] = None,
21
+ normalize: Optional[bool] = True,
22
+ trim_silence: Optional[bool] = True):
23
+ """
24
+ Applies the preprocessing operations used in training the Speaker Encoder to a waveform
25
+ either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
26
+
27
+ :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
28
+ just .wav), or the waveform as a numpy array of floats.
29
+ :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
30
+ preprocessing. After preprocessing, the waveform's sampling rate will match the data
31
+ hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
32
+ this argument will be ignored.
33
+ """
34
+ # Load the wav from disk if needed
35
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
36
+ wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
37
+ else:
38
+ wav = fpath_or_wav
39
+
40
+ # Resample the wav if needed
41
+ if source_sr is not None and source_sr != sampling_rate:
42
+ wav = librosa.resample(wav, source_sr, sampling_rate)
43
+
44
+ # Apply the preprocessing: normalize volume and shorten long silences
45
+ if normalize:
46
+ wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
47
+ if webrtcvad and trim_silence:
48
+ wav = trim_long_silences(wav)
49
+
50
+ return wav
51
+
52
+
53
+ def wav_to_mel_spectrogram(wav):
54
+ """
55
+ Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
56
+ Note: this is not a log-mel spectrogram.
57
+ """
58
+ frames = librosa.feature.melspectrogram(
59
+ wav,
60
+ sampling_rate,
61
+ n_fft=int(sampling_rate * mel_window_length / 1000),
62
+ hop_length=int(sampling_rate * mel_window_step / 1000),
63
+ n_mels=mel_n_channels
64
+ )
65
+ return frames.astype(np.float32).T
66
+
67
+
68
+ def trim_long_silences(wav):
69
+ """
70
+ Ensures that segments without voice in the waveform remain no longer than a
71
+ threshold determined by the VAD parameters in params.py.
72
+
73
+ :param wav: the raw waveform as a numpy array of floats
74
+ :return: the same waveform with silences trimmed away (length <= original wav length)
75
+ """
76
+ # Compute the voice detection window size
77
+ samples_per_window = (vad_window_length * sampling_rate) // 1000
78
+
79
+ # Trim the end of the audio to have a multiple of the window size
80
+ wav = wav[:len(wav) - (len(wav) % samples_per_window)]
81
+
82
+ # Convert the float waveform to 16-bit mono PCM
83
+ pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
84
+
85
+ # Perform voice activation detection
86
+ voice_flags = []
87
+ vad = webrtcvad.Vad(mode=3)
88
+ for window_start in range(0, len(wav), samples_per_window):
89
+ window_end = window_start + samples_per_window
90
+ voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
91
+ sample_rate=sampling_rate))
92
+ voice_flags = np.array(voice_flags)
93
+
94
+ # Smooth the voice detection with a moving average
95
+ def moving_average(array, width):
96
+ array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
97
+ ret = np.cumsum(array_padded, dtype=float)
98
+ ret[width:] = ret[width:] - ret[:-width]
99
+ return ret[width - 1:] / width
100
+
101
+ audio_mask = moving_average(voice_flags, vad_moving_average_width)
102
+ audio_mask = np.round(audio_mask).astype(np.bool)
103
+
104
+ # Dilate the voiced regions
105
+ audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
106
+ audio_mask = np.repeat(audio_mask, samples_per_window)
107
+
108
+ return wav[audio_mask == True]
109
+
110
+
111
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
112
+ if increase_only and decrease_only:
113
+ raise ValueError("Both increase only and decrease only are set")
114
+ dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
115
+ if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
116
+ return wav
117
+ return wav * (10 ** (dBFS_change / 20))
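Reviewer note: a small, hedged usage sketch of the two entry points above (the input path is a placeholder). Preprocessing resamples, volume-normalizes, and trims the waveform; the mel conversion then yields frames of shape (n_frames, mel_n_channels).

```python
# Hedged usage sketch for encoder/audio.py (placeholder input path).
from encoder import audio
from encoder.params_data import sampling_rate, mel_n_channels

wav = audio.preprocess_wav("some_recording.wav")  # resampled to 16 kHz, normalized, silences trimmed
frames = audio.wav_to_mel_spectrogram(wav)        # shape: (n_frames, mel_n_channels)

print(len(wav) / sampling_rate, "seconds of voiced audio")
print(frames.shape, "-> expected second dimension:", mel_n_channels)
```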
encoder/config.py ADDED
@@ -0,0 +1,45 @@
1
+ librispeech_datasets = {
2
+ "train": {
3
+ "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
4
+ "other": ["LibriSpeech/train-other-500"]
5
+ },
6
+ "test": {
7
+ "clean": ["LibriSpeech/test-clean"],
8
+ "other": ["LibriSpeech/test-other"]
9
+ },
10
+ "dev": {
11
+ "clean": ["LibriSpeech/dev-clean"],
12
+ "other": ["LibriSpeech/dev-other"]
13
+ },
14
+ }
15
+ libritts_datasets = {
16
+ "train": {
17
+ "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
18
+ "other": ["LibriTTS/train-other-500"]
19
+ },
20
+ "test": {
21
+ "clean": ["LibriTTS/test-clean"],
22
+ "other": ["LibriTTS/test-other"]
23
+ },
24
+ "dev": {
25
+ "clean": ["LibriTTS/dev-clean"],
26
+ "other": ["LibriTTS/dev-other"]
27
+ },
28
+ }
29
+ voxceleb_datasets = {
30
+ "voxceleb1" : {
31
+ "train": ["VoxCeleb1/wav"],
32
+ "test": ["VoxCeleb1/test_wav"]
33
+ },
34
+ "voxceleb2" : {
35
+ "train": ["VoxCeleb2/dev/aac"],
36
+ "test": ["VoxCeleb2/test_wav"]
37
+ }
38
+ }
39
+
40
+ other_datasets = [
41
+ "LJSpeech-1.1",
42
+ "VCTK-Corpus/wav48",
43
+ ]
44
+
45
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
1
+ import random
2
+
3
+ class RandomCycler:
4
+ """
5
+ Creates an internal copy of a sequence and allows access to its items in a constrained random
6
+ order. For a source sequence of n items and one or several consecutive queries of a total
7
+ of m items, the following guarantees hold (one implies the other):
8
+ - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
9
+ - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
10
+ """
11
+
12
+ def __init__(self, source):
13
+ if len(source) == 0:
14
+ raise Exception("Can't create RandomCycler from an empty collection")
15
+ self.all_items = list(source)
16
+ self.next_items = []
17
+
18
+ def sample(self, count: int):
19
+ shuffle = lambda l: random.sample(l, len(l))
20
+
21
+ out = []
22
+ while count > 0:
23
+ if count >= len(self.all_items):
24
+ out.extend(shuffle(list(self.all_items)))
25
+ count -= len(self.all_items)
26
+ continue
27
+ n = min(count, len(self.next_items))
28
+ out.extend(self.next_items[:n])
29
+ count -= n
30
+ self.next_items = self.next_items[n:]
31
+ if len(self.next_items) == 0:
32
+ self.next_items = shuffle(list(self.all_items))
33
+ return out
34
+
35
+ def __next__(self):
36
+ return self.sample(1)[0]
37
+
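Reviewer note: to illustrate the guarantee in the docstring above, over m = 2n consecutive samples every item comes back exactly twice. A quick hedged check on a toy sequence:

```python
# Sketch: checking the RandomCycler guarantee on a toy sequence.
from collections import Counter
from encoder.data_objects.random_cycler import RandomCycler

cycler = RandomCycler("abcde")        # n = 5 items
counts = Counter(cycler.sample(10))   # m = 10 = 2n queries
assert all(c == 2 for c in counts.values())  # each item returned m // n = 2 times
print(counts)
```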
encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
1
+ from encoder.data_objects.random_cycler import RandomCycler
2
+ from encoder.data_objects.utterance import Utterance
3
+ from pathlib import Path
4
+
5
+ # Contains the set of utterances of a single speaker
6
+ class Speaker:
7
+ def __init__(self, root: Path):
8
+ self.root = root
9
+ self.name = root.name
10
+ self.utterances = None
11
+ self.utterance_cycler = None
12
+
13
+ def _load_utterances(self):
14
+ with self.root.joinpath("_sources.txt").open("r") as sources_file:
15
+ sources = [l.split(",") for l in sources_file]
16
+ sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
17
+ self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
18
+ self.utterance_cycler = RandomCycler(self.utterances)
19
+
20
+ def random_partial(self, count, n_frames):
21
+ """
22
+ Samples a batch of <count> unique partial utterances from the disk in a way that all
23
+ utterances come up at least once every two cycles and in a random order every time.
24
+
25
+ :param count: The number of partial utterances to sample from the set of utterances from
26
+ that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
27
+ the number of utterances available.
28
+ :param n_frames: The number of frames in the partial utterance.
29
+ :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
30
+ frames are the frames of the partial utterances and range is the range of the partial
31
+ utterance with regard to the complete utterance.
32
+ """
33
+ if self.utterances is None:
34
+ self._load_utterances()
35
+
36
+ utterances = self.utterance_cycler.sample(count)
37
+
38
+ a = [(u,) + u.random_partial(n_frames) for u in utterances]
39
+
40
+ return a
encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,13 @@
1
+ import numpy as np
2
+ from typing import List
3
+ from encoder.data_objects.speaker import Speaker
4
+
5
+
6
+ class SpeakerBatch:
7
+ def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
8
+ self.speakers = speakers
9
+ self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
10
+
11
+ # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
12
+ # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
13
+ self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
1
+ from encoder.data_objects.random_cycler import RandomCycler
2
+ from encoder.data_objects.speaker_batch import SpeakerBatch
3
+ from encoder.data_objects.speaker import Speaker
4
+ from encoder.params_data import partials_n_frames
5
+ from torch.utils.data import Dataset, DataLoader
6
+ from pathlib import Path
7
+
8
+ # TODO: improve with a pool of speakers for data efficiency
9
+
10
+ class SpeakerVerificationDataset(Dataset):
11
+ def __init__(self, datasets_root: Path):
12
+ self.root = datasets_root
13
+ speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
14
+ if len(speaker_dirs) == 0:
15
+ raise Exception("No speakers found. Make sure you are pointing to the directory "
16
+ "containing all preprocessed speaker directories.")
17
+ self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
18
+ self.speaker_cycler = RandomCycler(self.speakers)
19
+
20
+ def __len__(self):
21
+ return int(1e10)
22
+
23
+ def __getitem__(self, index):
24
+ return next(self.speaker_cycler)
25
+
26
+ def get_logs(self):
27
+ log_string = ""
28
+ for log_fpath in self.root.glob("*.txt"):
29
+ with log_fpath.open("r") as log_file:
30
+ log_string += "".join(log_file.readlines())
31
+ return log_string
32
+
33
+
34
+ class SpeakerVerificationDataLoader(DataLoader):
35
+ def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
36
+ batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
37
+ worker_init_fn=None):
38
+ self.utterances_per_speaker = utterances_per_speaker
39
+
40
+ super().__init__(
41
+ dataset=dataset,
42
+ batch_size=speakers_per_batch,
43
+ shuffle=False,
44
+ sampler=sampler,
45
+ batch_sampler=batch_sampler,
46
+ num_workers=num_workers,
47
+ collate_fn=self.collate,
48
+ pin_memory=pin_memory,
49
+ drop_last=False,
50
+ timeout=timeout,
51
+ worker_init_fn=worker_init_fn
52
+ )
53
+
54
+ def collate(self, speakers):
55
+ return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
56
+
encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
1
+ import numpy as np
2
+
3
+
4
+ class Utterance:
5
+ def __init__(self, frames_fpath, wave_fpath):
6
+ self.frames_fpath = frames_fpath
7
+ self.wave_fpath = wave_fpath
8
+
9
+ def get_frames(self):
10
+ return np.load(self.frames_fpath)
11
+
12
+ def random_partial(self, n_frames):
13
+ """
14
+ Crops the frames into a partial utterance of n_frames
15
+
16
+ :param n_frames: The number of frames of the partial utterance
17
+ :return: the partial utterance frames and a tuple indicating the start and end of the
18
+ partial utterance in the complete utterance.
19
+ """
20
+ frames = self.get_frames()
21
+ if frames.shape[0] == n_frames:
22
+ start = 0
23
+ else:
24
+ start = np.random.randint(0, frames.shape[0] - n_frames)
25
+ end = start + n_frames
26
+ return frames[start:end], (start, end)
encoder/inference.py ADDED
@@ -0,0 +1,178 @@
1
+ from encoder.params_data import *
2
+ from encoder.model import SpeakerEncoder
3
+ from encoder.audio import preprocess_wav # We want to expose this function from here
4
+ from matplotlib import cm
5
+ from encoder import audio
6
+ from pathlib import Path
7
+ import numpy as np
8
+ import torch
9
+
10
+ _model = None # type: SpeakerEncoder
11
+ _device = None # type: torch.device
12
+
13
+
14
+ def load_model(weights_fpath: Path, device=None):
15
+ """
16
+ Loads the model in memory. If this function is not explicitely called, it will be run on the
17
+ first call to embed_frames() with the default weights file.
18
+
19
+ :param weights_fpath: the path to saved model weights.
20
+ :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
21
+ model will be loaded and will run on this device. Outputs will however always be on the cpu.
22
+ If None, will default to your GPU if it's available, otherwise your CPU.
23
+ """
24
+ # TODO: I think the slow loading of the encoder might have something to do with the device it
25
+ # was saved on. Worth investigating.
26
+ global _model, _device
27
+ if device is None:
28
+ _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ elif isinstance(device, str):
30
+ _device = torch.device(device)
31
+ _model = SpeakerEncoder(_device, torch.device("cpu"))
32
+ checkpoint = torch.load(weights_fpath, _device)
33
+ _model.load_state_dict(checkpoint["model_state"])
34
+ _model.eval()
35
+ print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
36
+
37
+
38
+ def is_loaded():
39
+ return _model is not None
40
+
41
+
42
+ def embed_frames_batch(frames_batch):
43
+ """
44
+ Computes embeddings for a batch of mel spectrogram.
45
+
46
+ :param frames_batch: a batch mel of spectrogram as a numpy array of float32 of shape
47
+ (batch_size, n_frames, n_channels)
48
+ :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
49
+ """
50
+ if _model is None:
51
+ raise Exception("Model was not loaded. Call load_model() before inference.")
52
+
53
+ frames = torch.from_numpy(frames_batch).to(_device)
54
+ embed = _model.forward(frames).detach().cpu().numpy()
55
+ return embed
56
+
57
+
58
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
59
+ min_pad_coverage=0.75, overlap=0.5):
60
+ """
61
+ Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
62
+ partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
63
+ spectrogram slices are returned, so as to make each partial utterance waveform correspond to
64
+ its spectrogram. This function assumes that the mel spectrogram parameters used are those
65
+ defined in params_data.py.
66
+
67
+ The returned ranges may be indexing further than the length of the waveform. It is
68
+ recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
69
+
70
+ :param n_samples: the number of samples in the waveform
71
+ :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
72
+ utterance
73
+ :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
74
+ enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
75
+ then the last partial utterance will be considered, as if we padded the audio. Otherwise,
76
+ it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
77
+ utterance, this parameter is ignored so that the function always returns at least 1 slice.
78
+ :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
79
+ utterances are entirely disjoint.
80
+ :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
81
+ respectively the waveform and the mel spectrogram with these slices to obtain the partial
82
+ utterances.
83
+ """
84
+ assert 0 <= overlap < 1
85
+ assert 0 < min_pad_coverage <= 1
86
+
87
+ samples_per_frame = int((sampling_rate * mel_window_step / 1000))
88
+ n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
89
+ frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
90
+
91
+ # Compute the slices
92
+ wav_slices, mel_slices = [], []
93
+ steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
94
+ for i in range(0, steps, frame_step):
95
+ mel_range = np.array([i, i + partial_utterance_n_frames])
96
+ wav_range = mel_range * samples_per_frame
97
+ mel_slices.append(slice(*mel_range))
98
+ wav_slices.append(slice(*wav_range))
99
+
100
+ # Evaluate whether extra padding is warranted or not
101
+ last_wav_range = wav_slices[-1]
102
+ coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
103
+ if coverage < min_pad_coverage and len(mel_slices) > 1:
104
+ mel_slices = mel_slices[:-1]
105
+ wav_slices = wav_slices[:-1]
106
+
107
+ return wav_slices, mel_slices
108
+
109
+
110
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
111
+ """
112
+ Computes an embedding for a single utterance.
113
+
114
+ # TODO: handle multiple wavs to benefit from batching on GPU
115
+ :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
116
+ :param using_partials: if True, then the utterance is split in partial utterances of
117
+ <partial_utterance_n_frames> frames and the utterance embedding is computed from their
118
+ normalized average. If False, the utterance is instead computed from feeding the entire
119
+ spectrogram to the network.
120
+ :param return_partials: if True, the partial embeddings will also be returned along with the
121
+ wav slices that correspond to the partial embeddings.
122
+ :param kwargs: additional arguments to compute_partial_slices()
123
+ :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
124
+ <return_partials> is True, the partial utterances as a numpy array of float32 of shape
125
+ (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
126
+ returned. If <using_partials> is simultaneously set to False, both these values will be None
127
+ instead.
128
+ """
129
+ # Process the entire utterance if not using partials
130
+ if not using_partials:
131
+ frames = audio.wav_to_mel_spectrogram(wav)
132
+ embed = embed_frames_batch(frames[None, ...])[0]
133
+ if return_partials:
134
+ return embed, None, None
135
+ return embed
136
+
137
+ # Compute where to split the utterance into partials and pad if necessary
138
+ wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
139
+ max_wave_length = wave_slices[-1].stop
140
+ if max_wave_length >= len(wav):
141
+ wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
142
+
143
+ # Split the utterance into partials
144
+ frames = audio.wav_to_mel_spectrogram(wav)
145
+ frames_batch = np.array([frames[s] for s in mel_slices])
146
+ partial_embeds = embed_frames_batch(frames_batch)
147
+
148
+ # Compute the utterance embedding from the partial embeddings
149
+ raw_embed = np.mean(partial_embeds, axis=0)
150
+ embed = raw_embed / np.linalg.norm(raw_embed, 2)
151
+
152
+ if return_partials:
153
+ return embed, partial_embeds, wave_slices
154
+ return embed
155
+
156
+
157
+ def embed_speaker(wavs, **kwargs):
158
+ raise NotImplementedError()
159
+
160
+
161
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
162
+ import matplotlib.pyplot as plt
163
+ if ax is None:
164
+ ax = plt.gca()
165
+
166
+ if shape is None:
167
+ height = int(np.sqrt(len(embed)))
168
+ shape = (height, -1)
169
+ embed = embed.reshape(shape)
170
+
171
+ cmap = cm.get_cmap()
172
+ mappable = ax.imshow(embed, cmap=cmap)
173
+ cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
174
+ sm = cm.ScalarMappable(cmap=cmap)
175
+ sm.set_clim(*color_range)
176
+
177
+ ax.set_xticks([]), ax.set_yticks([])
178
+ ax.set_title(title)
encoder/model.py ADDED
@@ -0,0 +1,135 @@
1
+ from encoder.params_model import *
2
+ from encoder.params_data import *
3
+ from scipy.interpolate import interp1d
4
+ from sklearn.metrics import roc_curve
5
+ from torch.nn.utils import clip_grad_norm_
6
+ from scipy.optimize import brentq
7
+ from torch import nn
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class SpeakerEncoder(nn.Module):
13
+ def __init__(self, device, loss_device):
14
+ super().__init__()
15
+ self.loss_device = loss_device
16
+
17
+ # Network defition
18
+ self.lstm = nn.LSTM(input_size=mel_n_channels,
19
+ hidden_size=model_hidden_size,
20
+ num_layers=model_num_layers,
21
+ batch_first=True).to(device)
22
+ self.linear = nn.Linear(in_features=model_hidden_size,
23
+ out_features=model_embedding_size).to(device)
24
+ self.relu = torch.nn.ReLU().to(device)
25
+
26
+ # Cosine similarity scaling (with fixed initial parameter values)
27
+ self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
28
+ self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
29
+
30
+ # Loss
31
+ self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
32
+
33
+ def do_gradient_ops(self):
34
+ # Gradient scale
35
+ self.similarity_weight.grad *= 0.01
36
+ self.similarity_bias.grad *= 0.01
37
+
38
+ # Gradient clipping
39
+ clip_grad_norm_(self.parameters(), 3, norm_type=2)
40
+
41
+ def forward(self, utterances, hidden_init=None):
42
+ """
43
+ Computes the embeddings of a batch of utterance spectrograms.
44
+
45
+ :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
46
+ (batch_size, n_frames, n_channels)
47
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
48
+ batch_size, hidden_size). Will default to a tensor of zeros if None.
49
+ :return: the embeddings as a tensor of shape (batch_size, embedding_size)
50
+ """
51
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
52
+ # and the final cell state.
53
+ out, (hidden, cell) = self.lstm(utterances, hidden_init)
54
+
55
+ # We take only the hidden state of the last layer
56
+ embeds_raw = self.relu(self.linear(hidden[-1]))
57
+
58
+ # L2-normalize it
59
+ embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
60
+
61
+ return embeds
62
+
63
+ def similarity_matrix(self, embeds):
64
+ """
65
+ Computes the similarity matrix according the section 2.1 of GE2E.
66
+
67
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
68
+ utterances_per_speaker, embedding_size)
69
+ :return: the similarity matrix as a tensor of shape (speakers_per_batch,
70
+ utterances_per_speaker, speakers_per_batch)
71
+ """
72
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
73
+
74
+ # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
75
+ centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
76
+ centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
77
+
78
+ # Exclusive centroids (1 per utterance)
79
+ centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
80
+ centroids_excl /= (utterances_per_speaker - 1)
81
+ centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
82
+
83
+ # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
84
+ # product of these vectors (which is just an element-wise multiplication reduced by a sum).
85
+ # We vectorize the computation for efficiency.
86
+ sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
87
+ speakers_per_batch).to(self.loss_device)
88
+ mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
89
+ for j in range(speakers_per_batch):
90
+ mask = np.where(mask_matrix[j])[0]
91
+ sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
92
+ sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
93
+
94
+ ## Even more vectorized version (slower maybe because of transpose)
95
+ # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
96
+ # ).to(self.loss_device)
97
+ # eye = np.eye(speakers_per_batch, dtype=np.int)
98
+ # mask = np.where(1 - eye)
99
+ # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
100
+ # mask = np.where(eye)
101
+ # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
102
+ # sim_matrix2 = sim_matrix2.transpose(1, 2)
103
+
104
+ sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
105
+ return sim_matrix
106
+
107
+ def loss(self, embeds):
108
+ """
109
+ Computes the softmax loss according the section 2.1 of GE2E.
110
+
111
+ :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
112
+ utterances_per_speaker, embedding_size)
113
+ :return: the loss and the EER for this batch of embeddings.
114
+ """
115
+ speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
116
+
117
+ # Loss
118
+ sim_matrix = self.similarity_matrix(embeds)
119
+ sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
120
+ speakers_per_batch))
121
+ ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
122
+ target = torch.from_numpy(ground_truth).long().to(self.loss_device)
123
+ loss = self.loss_fn(sim_matrix, target)
124
+
125
+ # EER (not backpropagated)
126
+ with torch.no_grad():
127
+ inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
128
+ labels = np.array([inv_argmax(i) for i in ground_truth])
129
+ preds = sim_matrix.detach().cpu().numpy()
130
+
131
+ # Snippet from https://yangcha.github.io/EER-ROC/
132
+ fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
133
+ eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
134
+
135
+ return loss, eer
encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
1
+
2
+ ## Mel-filterbank
3
+ mel_window_length = 25 # In milliseconds
4
+ mel_window_step = 10 # In milliseconds
5
+ mel_n_channels = 40
6
+
7
+
8
+ ## Audio
9
+ sampling_rate = 16000
10
+ # Number of spectrogram frames in a partial utterance
11
+ partials_n_frames = 160 # 1600 ms
12
+ # Number of spectrogram frames at inference
13
+ inference_n_frames = 80 # 800 ms
14
+
15
+
16
+ ## Voice Activation Detection
17
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
18
+ # This sets the granularity of the VAD. Should not need to be changed.
19
+ vad_window_length = 30 # In milliseconds
20
+ # Number of frames to average together when performing the moving average smoothing.
21
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
22
+ vad_moving_average_width = 8
23
+ # Maximum number of consecutive silent frames a segment can have.
24
+ vad_max_silence_length = 6
25
+
26
+
27
+ ## Audio volume normalization
28
+ audio_norm_target_dBFS = -30
29
+
encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
+
+ ## Model parameters
+ model_hidden_size = 256
+ model_embedding_size = 256
+ model_num_layers = 3
+
+
+ ## Training parameters
+ learning_rate_init = 1e-4
+ speakers_per_batch = 64
+ utterances_per_speaker = 10
encoder/preprocess.py ADDED
@@ -0,0 +1,184 @@
1
+ from datetime import datetime
2
+ from functools import partial
3
+ from multiprocessing import Pool
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ from tqdm import tqdm
8
+
9
+ from encoder import audio
10
+ from encoder.config import librispeech_datasets, anglophone_nationalites
11
+ from encoder.params_data import *
12
+
13
+
14
+ _AUDIO_EXTENSIONS = ("wav", "flac", "m4a", "mp3")
15
+
16
+ class DatasetLog:
17
+ """
18
+ Registers metadata about the dataset in a text file.
19
+ """
20
+ def __init__(self, root, name):
21
+ self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
22
+ self.sample_data = dict()
23
+
24
+ start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
25
+ self.write_line("Creating dataset %s on %s" % (name, start_time))
26
+ self.write_line("-----")
27
+ self._log_params()
28
+
29
+ def _log_params(self):
30
+ from encoder import params_data
31
+ self.write_line("Parameter values:")
32
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
33
+ value = getattr(params_data, param_name)
34
+ self.write_line("\t%s: %s" % (param_name, value))
35
+ self.write_line("-----")
36
+
37
+ def write_line(self, line):
38
+ self.text_file.write("%s\n" % line)
39
+
40
+ def add_sample(self, **kwargs):
41
+ for param_name, value in kwargs.items():
42
+ if not param_name in self.sample_data:
43
+ self.sample_data[param_name] = []
44
+ self.sample_data[param_name].append(value)
45
+
46
+ def finalize(self):
47
+ self.write_line("Statistics:")
48
+ for param_name, values in self.sample_data.items():
49
+ self.write_line("\t%s:" % param_name)
50
+ self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
51
+ self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
52
+ self.write_line("-----")
53
+ end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
54
+ self.write_line("Finished on %s" % end_time)
55
+ self.text_file.close()
56
+
57
+
58
+ def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
59
+ dataset_root = datasets_root.joinpath(dataset_name)
60
+ if not dataset_root.exists():
61
+ print("Couldn\'t find %s, skipping this dataset." % dataset_root)
62
+ return None, None
63
+ return dataset_root, DatasetLog(out_dir, dataset_name)
64
+
65
+
66
+ def _preprocess_speaker(speaker_dir: Path, datasets_root: Path, out_dir: Path, skip_existing: bool):
67
+ # Give a name to the speaker that includes its dataset
68
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
69
+
70
+ # Create an output directory with that name, as well as a txt file containing a
71
+ # reference to each source file.
72
+ speaker_out_dir = out_dir.joinpath(speaker_name)
73
+ speaker_out_dir.mkdir(exist_ok=True)
74
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75
+
76
+ # There's a possibility that the preprocessing was interrupted earlier, check if
77
+ # there already is a sources file.
78
+ if sources_fpath.exists():
79
+ try:
80
+ with sources_fpath.open("r") as sources_file:
81
+ existing_fnames = {line.split(",")[0] for line in sources_file}
82
+ except:
83
+ existing_fnames = {}
84
+ else:
85
+ existing_fnames = {}
86
+
87
+ # Gather all audio files for that speaker recursively
88
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
89
+ audio_durs = []
90
+ for extension in _AUDIO_EXTENSIONS:
91
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
92
+ # Check if the target output file already exists
93
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
94
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
95
+ if skip_existing and out_fname in existing_fnames:
96
+ continue
97
+
98
+ # Load and preprocess the waveform
99
+ wav = audio.preprocess_wav(in_fpath)
100
+ if len(wav) == 0:
101
+ continue
102
+
103
+ # Create the mel spectrogram, discard those that are too short
104
+ frames = audio.wav_to_mel_spectrogram(wav)
105
+ if len(frames) < partials_n_frames:
106
+ continue
107
+
108
+ out_fpath = speaker_out_dir.joinpath(out_fname)
109
+ np.save(out_fpath, frames)
110
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
111
+ audio_durs.append(len(wav) / sampling_rate)
112
+
113
+ sources_file.close()
114
+
115
+ return audio_durs
116
+
117
+
118
+ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger):
119
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
120
+
121
+ # Process the utterances for each speaker
122
+ work_fn = partial(_preprocess_speaker, datasets_root=datasets_root, out_dir=out_dir, skip_existing=skip_existing)
123
+ with Pool(4) as pool:
124
+ tasks = pool.imap(work_fn, speaker_dirs)
125
+ for sample_durs in tqdm(tasks, dataset_name, len(speaker_dirs), unit="speakers"):
126
+ for sample_dur in sample_durs:
127
+ logger.add_sample(duration=sample_dur)
128
+
129
+ logger.finalize()
130
+ print("Done preprocessing %s.\n" % dataset_name)
131
+
132
+
133
+ def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
134
+ for dataset_name in librispeech_datasets["train"]["other"]:
135
+ # Initialize the preprocessing
136
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
137
+ if not dataset_root:
138
+ return
139
+
140
+ # Preprocess all speakers
141
+ speaker_dirs = list(dataset_root.glob("*"))
142
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
143
+
144
+
145
+ def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
146
+ # Initialize the preprocessing
147
+ dataset_name = "VoxCeleb1"
148
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
149
+ if not dataset_root:
150
+ return
151
+
152
+ # Get the contents of the meta file
153
+ with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
154
+ metadata = [line.split("\t") for line in metafile][1:]
155
+
156
+ # Select the ID and the nationality, filter out non-anglophone speakers
157
+ nationalities = {line[0]: line[3] for line in metadata}
158
+ keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
159
+ nationality.lower() in anglophone_nationalites]
160
+ print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
161
+ (len(keep_speaker_ids), len(nationalities)))
162
+
163
+ # Get the speaker directories for anglophone speakers only
164
+ speaker_dirs = dataset_root.joinpath("wav").glob("*")
165
+ speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
166
+ speaker_dir.name in keep_speaker_ids]
167
+ print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
168
+ (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
169
+
170
+ # Preprocess all speakers
171
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
172
+
173
+
174
+ def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
175
+ # Initialize the preprocessing
176
+ dataset_name = "VoxCeleb2"
177
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
178
+ if not dataset_root:
179
+ return
180
+
181
+ # Get the speaker directories
182
+ # Preprocess all speakers
183
+ speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
184
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, skip_existing, logger)
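For orientation, a minimal sketch of how these preprocessing entry points might be invoked directly (the dataset layout and output directory are assumptions; upstream they are normally driven by a small CLI script):

from pathlib import Path
from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2

datasets_root = Path("datasets")      # assumed to contain LibriSpeech/, VoxCeleb1/, VoxCeleb2/
out_dir = Path("SV2TTS/encoder")      # assumed destination for the per-utterance .npy mel frames
out_dir.mkdir(parents=True, exist_ok=True)

preprocess_librispeech(datasets_root, out_dir, skip_existing=True)
preprocess_voxceleb1(datasets_root, out_dir, skip_existing=True)
preprocess_voxceleb2(datasets_root, out_dir, skip_existing=True)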
encoder/train.py ADDED
@@ -0,0 +1,125 @@
1
+ from pathlib import Path
2
+
3
+ import torch
4
+
5
+ from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
6
+ from encoder.model import SpeakerEncoder
7
+ from encoder.params_model import *
8
+ from encoder.visualizations import Visualizations
9
+ from utils.profiler import Profiler
10
+
11
+
12
+ def sync(device: torch.device):
13
+ # For correct profiling (cuda operations are async)
14
+ if device.type == "cuda":
15
+ torch.cuda.synchronize(device)
16
+
17
+
18
+ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
19
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
20
+ no_visdom: bool):
21
+ # Create a dataset and a dataloader
22
+ dataset = SpeakerVerificationDataset(clean_data_root)
23
+ loader = SpeakerVerificationDataLoader(
24
+ dataset,
25
+ speakers_per_batch,
26
+ utterances_per_speaker,
27
+ num_workers=4,
28
+ )
29
+
30
+ # Setup the device on which to run the forward pass and the loss. These can be different,
31
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
32
+ # hyperparameters) faster on the CPU.
33
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ # FIXME: currently, the gradient is None if loss_device is cuda
35
+ loss_device = torch.device("cpu")
36
+
37
+ # Create the model and the optimizer
38
+ model = SpeakerEncoder(device, loss_device)
39
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
40
+ init_step = 1
41
+
42
+ # Configure file path for the model
43
+ model_dir = models_dir / run_id
44
+ model_dir.mkdir(exist_ok=True, parents=True)
45
+ state_fpath = model_dir / "encoder.pt"
46
+
47
+ # Load any existing model
48
+ if not force_restart:
49
+ if state_fpath.exists():
50
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
51
+ checkpoint = torch.load(state_fpath)
52
+ init_step = checkpoint["step"]
53
+ model.load_state_dict(checkpoint["model_state"])
54
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
55
+ optimizer.param_groups[0]["lr"] = learning_rate_init
56
+ else:
57
+ print("No model \"%s\" found, starting training from scratch." % run_id)
58
+ else:
59
+ print("Starting the training from scratch.")
60
+ model.train()
61
+
62
+ # Initialize the visualization environment
63
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
64
+ vis.log_dataset(dataset)
65
+ vis.log_params()
66
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
67
+ vis.log_implementation({"Device": device_name})
68
+
69
+ # Training loop
70
+ profiler = Profiler(summarize_every=10, disabled=False)
71
+ for step, speaker_batch in enumerate(loader, init_step):
72
+ profiler.tick("Blocking, waiting for batch (threaded)")
73
+
74
+ # Forward pass
75
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
76
+ sync(device)
77
+ profiler.tick("Data to %s" % device)
78
+ embeds = model(inputs)
79
+ sync(device)
80
+ profiler.tick("Forward pass")
81
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
82
+ loss, eer = model.loss(embeds_loss)
83
+ sync(loss_device)
84
+ profiler.tick("Loss")
85
+
86
+ # Backward pass
87
+ model.zero_grad()
88
+ loss.backward()
89
+ profiler.tick("Backward pass")
90
+ model.do_gradient_ops()
91
+ optimizer.step()
92
+ profiler.tick("Parameter update")
93
+
94
+ # Update visualizations
95
+ # learning_rate = optimizer.param_groups[0]["lr"]
96
+ vis.update(loss.item(), eer, step)
97
+
98
+ # Draw projections and save them to the backup folder
99
+ if umap_every != 0 and step % umap_every == 0:
100
+ print("Drawing and saving projections (step %d)" % step)
101
+ projection_fpath = model_dir / f"umap_{step:06d}.png"
102
+ embeds = embeds.detach().cpu().numpy()
103
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
104
+ vis.save()
105
+
106
+ # Overwrite the latest version of the model
107
+ if save_every != 0 and step % save_every == 0:
108
+ print("Saving the model (step %d)" % step)
109
+ torch.save({
110
+ "step": step + 1,
111
+ "model_state": model.state_dict(),
112
+ "optimizer_state": optimizer.state_dict(),
113
+ }, state_fpath)
114
+
115
+ # Make a backup
116
+ if backup_every != 0 and step % backup_every == 0:
117
+ print("Making a backup (step %d)" % step)
118
+ backup_fpath = model_dir / f"encoder_{step:06d}.bak"
119
+ torch.save({
120
+ "step": step + 1,
121
+ "model_state": model.state_dict(),
122
+ "optimizer_state": optimizer.state_dict(),
123
+ }, backup_fpath)
124
+
125
+ profiler.tick("Extras (visualizations, saving)")
encoder/visualizations.py ADDED
@@ -0,0 +1,179 @@
1
+ from datetime import datetime
2
+ from time import perf_counter as timer
3
+
4
+ import numpy as np
5
+ import umap
6
+ import visdom
7
+
8
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
9
+
10
+
11
+ colormap = np.array([
12
+ [76, 255, 0],
13
+ [0, 127, 70],
14
+ [255, 0, 0],
15
+ [255, 217, 38],
16
+ [0, 135, 255],
17
+ [165, 0, 165],
18
+ [255, 167, 255],
19
+ [0, 255, 255],
20
+ [255, 96, 38],
21
+ [142, 76, 0],
22
+ [33, 0, 127],
23
+ [0, 0, 0],
24
+ [183, 183, 183],
25
+ ], dtype=float) / 255
26
+
27
+
28
+ class Visualizations:
29
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
30
+ # Tracking data
31
+ self.last_update_timestamp = timer()
32
+ self.update_every = update_every
33
+ self.step_times = []
34
+ self.losses = []
35
+ self.eers = []
36
+ print("Updating the visualizations every %d steps." % update_every)
37
+
38
+ # If visdom is disabled TODO: use a better paradigm for that
39
+ self.disabled = disabled
40
+ if self.disabled:
41
+ return
42
+
43
+ # Set the environment name
44
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
45
+ if env_name is None:
46
+ self.env_name = now
47
+ else:
48
+ self.env_name = "%s (%s)" % (env_name, now)
49
+
50
+ # Connect to visdom and open the corresponding window in the browser
51
+ try:
52
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
53
+ except ConnectionError:
54
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
55
+ "start it.")
56
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
57
+
58
+ # Create the windows
59
+ self.loss_win = None
60
+ self.eer_win = None
61
+ # self.lr_win = None
62
+ self.implementation_win = None
63
+ self.projection_win = None
64
+ self.implementation_string = ""
65
+
66
+ def log_params(self):
67
+ if self.disabled:
68
+ return
69
+ from encoder import params_data
70
+ from encoder import params_model
71
+ param_string = "<b>Model parameters</b>:<br>"
72
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
73
+ value = getattr(params_model, param_name)
74
+ param_string += "\t%s: %s<br>" % (param_name, value)
75
+ param_string += "<b>Data parameters</b>:<br>"
76
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
77
+ value = getattr(params_data, param_name)
78
+ param_string += "\t%s: %s<br>" % (param_name, value)
79
+ self.vis.text(param_string, opts={"title": "Parameters"})
80
+
81
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
82
+ if self.disabled:
83
+ return
84
+ dataset_string = ""
85
+ dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
86
+ dataset_string += "\n" + dataset.get_logs()
87
+ dataset_string = dataset_string.replace("\n", "<br>")
88
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
89
+
90
+ def log_implementation(self, params):
91
+ if self.disabled:
92
+ return
93
+ implementation_string = ""
94
+ for param, value in params.items():
95
+ implementation_string += "<b>%s</b>: %s\n" % (param, value)
96
+ implementation_string = implementation_string.replace("\n", "<br>")
97
+ self.implementation_string = implementation_string
98
+ self.implementation_win = self.vis.text(
99
+ implementation_string,
100
+ opts={"title": "Training implementation"}
101
+ )
102
+
103
+ def update(self, loss, eer, step):
104
+ # Update the tracking data
105
+ now = timer()
106
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
107
+ self.last_update_timestamp = now
108
+ self.losses.append(loss)
109
+ self.eers.append(eer)
110
+ print(".", end="")
111
+
112
+ # Update the plots every <update_every> steps
113
+ if step % self.update_every != 0:
114
+ return
115
+ time_string = "Step time: mean: %5dms std: %5dms" % \
116
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
117
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
118
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
119
+ if not self.disabled:
120
+ self.loss_win = self.vis.line(
121
+ [np.mean(self.losses)],
122
+ [step],
123
+ win=self.loss_win,
124
+ update="append" if self.loss_win else None,
125
+ opts=dict(
126
+ legend=["Avg. loss"],
127
+ xlabel="Step",
128
+ ylabel="Loss",
129
+ title="Loss",
130
+ )
131
+ )
132
+ self.eer_win = self.vis.line(
133
+ [np.mean(self.eers)],
134
+ [step],
135
+ win=self.eer_win,
136
+ update="append" if self.eer_win else None,
137
+ opts=dict(
138
+ legend=["Avg. EER"],
139
+ xlabel="Step",
140
+ ylabel="EER",
141
+ title="Equal error rate"
142
+ )
143
+ )
144
+ if self.implementation_win is not None:
145
+ self.vis.text(
146
+ self.implementation_string + ("<b>%s</b>" % time_string),
147
+ win=self.implementation_win,
148
+ opts={"title": "Training implementation"},
149
+ )
150
+
151
+ # Reset the tracking
152
+ self.losses.clear()
153
+ self.eers.clear()
154
+ self.step_times.clear()
155
+
156
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None, max_speakers=10):
157
+ import matplotlib.pyplot as plt
158
+
159
+ max_speakers = min(max_speakers, len(colormap))
160
+ embeds = embeds[:max_speakers * utterances_per_speaker]
161
+
162
+ n_speakers = len(embeds) // utterances_per_speaker
163
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
164
+ colors = [colormap[i] for i in ground_truth]
165
+
166
+ reducer = umap.UMAP()
167
+ projected = reducer.fit_transform(embeds)
168
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
169
+ plt.gca().set_aspect("equal", "datalim")
170
+ plt.title("UMAP projection (step %d)" % step)
171
+ if not self.disabled:
172
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
173
+ if out_fpath is not None:
174
+ plt.savefig(out_fpath)
175
+ plt.clf()
176
+
177
+ def save(self):
178
+ if not self.disabled:
179
+ self.vis.save([self.env_name])
musk.mp3 ADDED
Binary file (761 kB).
 
queen.mp3 ADDED
Binary file (551 kB).
 
requirements.txt ADDED
Binary file (546 Bytes).
 
synthesizer/LICENSE.txt ADDED
@@ -0,0 +1,24 @@
1
+ MIT License
2
+
3
+ Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
4
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
5
+ Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
6
+ Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish)
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ SOFTWARE.
synthesizer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ #
synthesizer/audio.py ADDED
@@ -0,0 +1,206 @@
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ from scipy import signal
5
+ from scipy.io import wavfile
6
+ import soundfile as sf
7
+
8
+
9
+ def load_wav(path, sr):
10
+ return librosa.core.load(path, sr=sr)[0]
11
+
12
+ def save_wav(wav, path, sr):
13
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
14
+ #proposed by @dsmiller
15
+ wavfile.write(path, sr, wav.astype(np.int16))
16
+
17
+ def save_wavenet_wav(wav, path, sr):
18
+ sf.write(path, wav.astype(np.float32), sr)
19
+
20
+ def preemphasis(wav, k, preemphasize=True):
21
+ if preemphasize:
22
+ return signal.lfilter([1, -k], [1], wav)
23
+ return wav
24
+
25
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
26
+ if inv_preemphasize:
27
+ return signal.lfilter([1], [1, -k], wav)
28
+ return wav
29
+
30
+ #From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
31
+ def start_and_end_indices(quantized, silence_threshold=2):
32
+ for start in range(quantized.size):
33
+ if abs(quantized[start] - 127) > silence_threshold:
34
+ break
35
+ for end in range(quantized.size - 1, 1, -1):
36
+ if abs(quantized[end] - 127) > silence_threshold:
37
+ break
38
+
39
+ assert abs(quantized[start] - 127) > silence_threshold
40
+ assert abs(quantized[end] - 127) > silence_threshold
41
+
42
+ return start, end
43
+
44
+ def get_hop_size(hparams):
45
+ hop_size = hparams.hop_size
46
+ if hop_size is None:
47
+ assert hparams.frame_shift_ms is not None
48
+ hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
49
+ return hop_size
50
+
51
+ def linearspectrogram(wav, hparams):
52
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
53
+ S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
54
+
55
+ if hparams.signal_normalization:
56
+ return _normalize(S, hparams)
57
+ return S
58
+
59
+ def melspectrogram(wav, hparams):
60
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
61
+ S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
62
+
63
+ if hparams.signal_normalization:
64
+ return _normalize(S, hparams)
65
+ return S
66
+
67
+ def inv_linear_spectrogram(linear_spectrogram, hparams):
68
+ """Converts linear spectrogram to waveform using librosa"""
69
+ if hparams.signal_normalization:
70
+ D = _denormalize(linear_spectrogram, hparams)
71
+ else:
72
+ D = linear_spectrogram
73
+
74
+ S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
75
+
76
+ if hparams.use_lws:
77
+ processor = _lws_processor(hparams)
78
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
79
+ y = processor.istft(D).astype(np.float32)
80
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
81
+ else:
82
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
83
+
84
+ def inv_mel_spectrogram(mel_spectrogram, hparams):
85
+ """Converts mel spectrogram to waveform using librosa"""
86
+ if hparams.signal_normalization:
87
+ D = _denormalize(mel_spectrogram, hparams)
88
+ else:
89
+ D = mel_spectrogram
90
+
91
+ S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear
92
+
93
+ if hparams.use_lws:
94
+ processor = _lws_processor(hparams)
95
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
96
+ y = processor.istft(D).astype(np.float32)
97
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
98
+ else:
99
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
100
+
101
+ def _lws_processor(hparams):
102
+ import lws
103
+ return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
104
+
105
+ def _griffin_lim(S, hparams):
106
+ """librosa implementation of Griffin-Lim
107
+ Based on https://github.com/librosa/librosa/issues/434
108
+ """
109
+ angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
110
+ S_complex = np.abs(S).astype(complex)
111
+ y = _istft(S_complex * angles, hparams)
112
+ for i in range(hparams.griffin_lim_iters):
113
+ angles = np.exp(1j * np.angle(_stft(y, hparams)))
114
+ y = _istft(S_complex * angles, hparams)
115
+ return y
116
+
117
+ def _stft(y, hparams):
118
+ if hparams.use_lws:
119
+ return _lws_processor(hparams).stft(y).T
120
+ else:
121
+ return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
122
+
123
+ def _istft(y, hparams):
124
+ return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
125
+
126
+ ##########################################################
127
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
128
+ def num_frames(length, fsize, fshift):
129
+ """Compute number of time frames of spectrogram
130
+ """
131
+ pad = (fsize - fshift)
132
+ if length % fshift == 0:
133
+ M = (length + pad * 2 - fsize) // fshift + 1
134
+ else:
135
+ M = (length + pad * 2 - fsize) // fshift + 2
136
+ return M
137
+
138
+
139
+ def pad_lr(x, fsize, fshift):
140
+ """Compute left and right padding
141
+ """
142
+ M = num_frames(len(x), fsize, fshift)
143
+ pad = (fsize - fshift)
144
+ T = len(x) + 2 * pad
145
+ r = (M - 1) * fshift + fsize - T
146
+ return pad, pad + r
147
+ ##########################################################
148
+ #Librosa correct padding
149
+ def librosa_pad_lr(x, fsize, fshift):
150
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
151
+
152
+ # Conversions
153
+ _mel_basis = None
154
+ _inv_mel_basis = None
155
+
156
+ def _linear_to_mel(spectrogram, hparams):
157
+ global _mel_basis
158
+ if _mel_basis is None:
159
+ _mel_basis = _build_mel_basis(hparams)
160
+ return np.dot(_mel_basis, spectrogram)
161
+
162
+ def _mel_to_linear(mel_spectrogram, hparams):
163
+ global _inv_mel_basis
164
+ if _inv_mel_basis is None:
165
+ _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
166
+ return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
167
+
168
+ def _build_mel_basis(hparams):
169
+ assert hparams.fmax <= hparams.sample_rate // 2
170
+ return librosa.filters.mel(sr=hparams.sample_rate, n_fft=hparams.n_fft, n_mels=hparams.num_mels,
171
+ fmin=hparams.fmin, fmax=hparams.fmax)
172
+
173
+ def _amp_to_db(x, hparams):
174
+ min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
175
+ return 20 * np.log10(np.maximum(min_level, x))
176
+
177
+ def _db_to_amp(x):
178
+ return np.power(10.0, (x) * 0.05)
179
+
180
+ def _normalize(S, hparams):
181
+ if hparams.allow_clipping_in_normalization:
182
+ if hparams.symmetric_mels:
183
+ return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
184
+ -hparams.max_abs_value, hparams.max_abs_value)
185
+ else:
186
+ return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
187
+
188
+ assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
189
+ if hparams.symmetric_mels:
190
+ return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
191
+ else:
192
+ return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
193
+
194
+ def _denormalize(D, hparams):
195
+ if hparams.allow_clipping_in_normalization:
196
+ if hparams.symmetric_mels:
197
+ return (((np.clip(D, -hparams.max_abs_value,
198
+ hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
199
+ + hparams.min_level_db)
200
+ else:
201
+ return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
202
+
203
+ if hparams.symmetric_mels:
204
+ return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
205
+ else:
206
+ return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
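As a quick illustration of the analysis/synthesis pair above, a minimal round-trip sketch; it assumes the hparams object defined in synthesizer/hparams.py below and uses noise as a stand-in signal:

import numpy as np
from synthesizer import audio
from synthesizer.hparams import hparams

wav = np.random.uniform(-0.5, 0.5, hparams.sample_rate).astype(np.float32)  # one second of noise
mel = audio.melspectrogram(wav, hparams)                 # (num_mels, frames), normalized to [-4, 4] by default
reconstructed = audio.inv_mel_spectrogram(mel, hparams)  # Griffin-Lim estimate of the waveform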
synthesizer/hparams.py ADDED
@@ -0,0 +1,92 @@
1
+ import ast
2
+ import pprint
3
+
4
+ class HParams(object):
5
+ def __init__(self, **kwargs): self.__dict__.update(kwargs)
6
+ def __setitem__(self, key, value): setattr(self, key, value)
7
+ def __getitem__(self, key): return getattr(self, key)
8
+ def __repr__(self): return pprint.pformat(self.__dict__)
9
+
10
+ def parse(self, string):
11
+ # Overrides hparams from a comma-separated string of name=value pairs
12
+ if len(string) > 0:
13
+ overrides = [s.split("=") for s in string.split(",")]
14
+ keys, values = zip(*overrides)
15
+ keys = list(map(str.strip, keys))
16
+ values = list(map(str.strip, values))
17
+ for k in keys:
18
+ self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
19
+ return self
20
+
21
+ hparams = HParams(
22
+ ### Signal Processing (used in both synthesizer and vocoder)
23
+ sample_rate = 16000,
24
+ n_fft = 800,
25
+ num_mels = 80,
26
+ hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
27
+ win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
28
+ fmin = 55,
29
+ min_level_db = -100,
30
+ ref_level_db = 20,
31
+ max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small.
32
+ preemphasis = 0.97, # Filter coefficient to use if preemphasize is True
33
+ preemphasize = True,
34
+
35
+ ### Tacotron Text-to-Speech (TTS)
36
+ tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs
37
+ tts_encoder_dims = 256,
38
+ tts_decoder_dims = 128,
39
+ tts_postnet_dims = 512,
40
+ tts_encoder_K = 5,
41
+ tts_lstm_dims = 1024,
42
+ tts_postnet_K = 5,
43
+ tts_num_highways = 4,
44
+ tts_dropout = 0.5,
45
+ tts_cleaner_names = ["english_cleaners"],
46
+ tts_stop_threshold = -3.4, # Value below which audio generation ends.
47
+ # For example, for a range of [-4, 4], this
48
+ # will terminate the sequence at the first
49
+ # frame that has all values < -3.4
50
+
51
+ ### Tacotron Training
52
+ tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule
53
+ (2, 5e-4, 40_000, 12), # (r, lr, step, batch_size)
54
+ (2, 2e-4, 80_000, 12), #
55
+ (2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
56
+ (2, 3e-5, 320_000, 12), # synthesized for each decoder iteration)
57
+ (2, 1e-5, 640_000, 12)], # lr = learning rate
58
+
59
+ tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
60
+ tts_eval_interval = 500, # Number of steps between model evaluation (sample generation)
61
+ # Set to -1 to generate after completing epoch, or 0 to disable
62
+
63
+ tts_eval_num_samples = 1, # Makes this number of samples
64
+
65
+ ### Data Preprocessing
66
+ max_mel_frames = 900,
67
+ rescale = True,
68
+ rescaling_max = 0.9,
69
+ synthesis_batch_size = 16, # For vocoder preprocessing and inference.
70
+
71
+ ### Mel Visualization and Griffin-Lim
72
+ signal_normalization = True,
73
+ power = 1.5,
74
+ griffin_lim_iters = 60,
75
+
76
+ ### Audio processing options
77
+ fmax = 7600, # Should not exceed (sample_rate // 2)
78
+ allow_clipping_in_normalization = True, # Used when signal_normalization = True
79
+ clip_mels_length = True, # If true, discards samples exceeding max_mel_frames
80
+ use_lws = False, # "Fast spectrogram phase recovery using local weighted sums"
81
+ symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,
82
+ # and [0, max_abs_value] if False
83
+ trim_silence = True, # Use with sample_rate of 16000 for best results
84
+
85
+ ### SV2TTS
86
+ speaker_embedding_size = 256, # Dimension for the speaker embedding
87
+ silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
88
+ utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded
89
+ )
90
+
91
+ def hparams_debug_string():
92
+ return str(hparams)
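The parse helper accepts a comma-separated string of name=value overrides, each value run through ast.literal_eval; a quick illustration with arbitrary values:

from synthesizer.hparams import hparams

hparams.parse("tts_dropout=0.4,griffin_lim_iters=30")
print(hparams.tts_dropout, hparams.griffin_lim_iters)   # 0.4 30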
synthesizer/inference.py ADDED
@@ -0,0 +1,165 @@
1
+ import torch
2
+ from synthesizer import audio
3
+ from synthesizer.hparams import hparams
4
+ from synthesizer.models.tacotron import Tacotron
5
+ from synthesizer.utils.symbols import symbols
6
+ from synthesizer.utils.text import text_to_sequence
7
+ from vocoder.display import simple_table
8
+ from pathlib import Path
9
+ from typing import Union, List
10
+ import numpy as np
11
+ import librosa
12
+
13
+
14
+ class Synthesizer:
15
+ sample_rate = hparams.sample_rate
16
+ hparams = hparams
17
+
18
+ def __init__(self, model_fpath: Path, verbose=True):
19
+ """
20
+ The model isn't instantiated and loaded in memory until needed or until load() is called.
21
+
22
+ :param model_fpath: path to the trained model file
23
+ :param verbose: if False, prints less information when using the model
24
+ """
25
+ self.model_fpath = model_fpath
26
+ self.verbose = verbose
27
+
28
+ # Check for GPU
29
+ if torch.cuda.is_available():
30
+ self.device = torch.device("cuda")
31
+ else:
32
+ self.device = torch.device("cpu")
33
+ if self.verbose:
34
+ print("Synthesizer using device:", self.device)
35
+
36
+ # Tacotron model will be instantiated later on first use.
37
+ self._model = None
38
+
39
+ def is_loaded(self):
40
+ """
41
+ Whether the model is loaded in memory.
42
+ """
43
+ return self._model is not None
44
+
45
+ def load(self):
46
+ """
47
+ Instantiates and loads the model given the weights file that was passed in the constructor.
48
+ """
49
+ self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
50
+ num_chars=len(symbols),
51
+ encoder_dims=hparams.tts_encoder_dims,
52
+ decoder_dims=hparams.tts_decoder_dims,
53
+ n_mels=hparams.num_mels,
54
+ fft_bins=hparams.num_mels,
55
+ postnet_dims=hparams.tts_postnet_dims,
56
+ encoder_K=hparams.tts_encoder_K,
57
+ lstm_dims=hparams.tts_lstm_dims,
58
+ postnet_K=hparams.tts_postnet_K,
59
+ num_highways=hparams.tts_num_highways,
60
+ dropout=hparams.tts_dropout,
61
+ stop_threshold=hparams.tts_stop_threshold,
62
+ speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
63
+
64
+ self._model.load(self.model_fpath)
65
+ self._model.eval()
66
+
67
+ if self.verbose:
68
+ print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
69
+
70
+ def synthesize_spectrograms(self, texts: List[str],
71
+ embeddings: Union[np.ndarray, List[np.ndarray]],
72
+ return_alignments=False):
73
+ """
74
+ Synthesizes mel spectrograms from texts and speaker embeddings.
75
+
76
+ :param texts: a list of N text prompts to be synthesized
77
+ :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
78
+ :param return_alignments: if True, a matrix representing the alignments between the
79
+ characters
80
+ and each decoder output step will be returned for each spectrogram
81
+ :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
82
+ sequence length of spectrogram i, and possibly the alignments.
83
+ """
84
+ # Load the model on the first request.
85
+ if not self.is_loaded():
86
+ self.load()
87
+
88
+ # Preprocess text inputs
89
+ inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
90
+ if not isinstance(embeddings, list):
91
+ embeddings = [embeddings]
92
+
93
+ # Batch inputs
94
+ batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
95
+ for i in range(0, len(inputs), hparams.synthesis_batch_size)]
96
+ batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
97
+ for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
98
+
99
+ specs = []
100
+ for i, batch in enumerate(batched_inputs, 1):
101
+ if self.verbose:
102
+ print(f"\n| Generating {i}/{len(batched_inputs)}")
103
+
104
+ # Pad texts so they are all the same length
105
+ text_lens = [len(text) for text in batch]
106
+ max_text_len = max(text_lens)
107
+ chars = [pad1d(text, max_text_len) for text in batch]
108
+ chars = np.stack(chars)
109
+
110
+ # Stack speaker embeddings into 2D array for batch processing
111
+ speaker_embeds = np.stack(batched_embeds[i-1])
112
+
113
+ # Convert to tensor
114
+ chars = torch.tensor(chars).long().to(self.device)
115
+ speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
116
+
117
+ # Inference
118
+ _, mels, alignments = self._model.generate(chars, speaker_embeddings)
119
+ mels = mels.detach().cpu().numpy()
120
+ for m in mels:
121
+ # Trim silence from end of each spectrogram
122
+ while np.max(m[:, -1]) < hparams.tts_stop_threshold:
123
+ m = m[:, :-1]
124
+ specs.append(m)
125
+
126
+ if self.verbose:
127
+ print("\n\nDone.\n")
128
+ return (specs, alignments) if return_alignments else specs
129
+
130
+ @staticmethod
131
+ def load_preprocess_wav(fpath):
132
+ """
133
+ Loads and preprocesses an audio file under the same conditions that were applied to the
134
+ audio files used to train the synthesizer.
135
+ """
136
+ wav = librosa.load(str(fpath), sr=hparams.sample_rate)[0]
137
+ if hparams.rescale:
138
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
139
+ return wav
140
+
141
+ @staticmethod
142
+ def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
143
+ """
144
+ Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
145
+ were fed to the synthesizer when training.
146
+ """
147
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
148
+ wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
149
+ else:
150
+ wav = fpath_or_wav
151
+
152
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
153
+ return mel_spectrogram
154
+
155
+ @staticmethod
156
+ def griffin_lim(mel):
157
+ """
158
+ Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
159
+ with the same parameters present in hparams.py.
160
+ """
161
+ return audio.inv_mel_spectrogram(mel, hparams)
162
+
163
+
164
+ def pad1d(x, max_len, pad_value=0):
165
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
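A minimal usage sketch for this class; the checkpoint path is an assumption, and the zero vector stands in for a real 256-dimensional speaker embedding produced by the encoder:

from pathlib import Path
import numpy as np
from synthesizer.inference import Synthesizer

synthesizer = Synthesizer(Path("synthesizer.pt"))    # assumed path to a trained checkpoint
embed = np.zeros(Synthesizer.hparams.speaker_embedding_size, dtype=np.float32)
specs = synthesizer.synthesize_spectrograms(["Hello world."], [embed])
wav = Synthesizer.griffin_lim(specs[0])              # rough waveform without a neural vocoder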
synthesizer/models/tacotron.py ADDED
@@ -0,0 +1,519 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from pathlib import Path
7
+ from typing import Union
8
+
9
+
10
+ class HighwayNetwork(nn.Module):
11
+ def __init__(self, size):
12
+ super().__init__()
13
+ self.W1 = nn.Linear(size, size)
14
+ self.W2 = nn.Linear(size, size)
15
+ self.W1.bias.data.fill_(0.)
16
+
17
+ def forward(self, x):
18
+ x1 = self.W1(x)
19
+ x2 = self.W2(x)
20
+ g = torch.sigmoid(x2)
21
+ y = g * F.relu(x1) + (1. - g) * x
22
+ return y
23
+
24
+
25
+ class Encoder(nn.Module):
26
+ def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
27
+ super().__init__()
28
+ prenet_dims = (encoder_dims, encoder_dims)
29
+ cbhg_channels = encoder_dims
30
+ self.embedding = nn.Embedding(num_chars, embed_dims)
31
+ self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
32
+ dropout=dropout)
33
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
34
+ proj_channels=[cbhg_channels, cbhg_channels],
35
+ num_highways=num_highways)
36
+
37
+ def forward(self, x, speaker_embedding=None):
38
+ x = self.embedding(x)
39
+ x = self.pre_net(x)
40
+ x.transpose_(1, 2)
41
+ x = self.cbhg(x)
42
+ if speaker_embedding is not None:
43
+ x = self.add_speaker_embedding(x, speaker_embedding)
44
+ return x
45
+
46
+ def add_speaker_embedding(self, x, speaker_embedding):
47
+ # SV2TTS
48
+ # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
49
+ # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
50
+ # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
51
+ # This concats the speaker embedding for each char in the encoder output
52
+
53
+ # Save the dimensions as human-readable names
54
+ batch_size = x.size()[0]
55
+ num_chars = x.size()[1]
56
+
57
+ if speaker_embedding.dim() == 1:
58
+ idx = 0
59
+ else:
60
+ idx = 1
61
+
62
+ # Start by making a copy of each speaker embedding to match the input text length
63
+ # The output of this has size (batch_size, num_chars * speaker_embedding_size)
64
+ speaker_embedding_size = speaker_embedding.size()[idx]
65
+ e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
66
+
67
+ # Reshape it and transpose
68
+ e = e.reshape(batch_size, speaker_embedding_size, num_chars)
69
+ e = e.transpose(1, 2)
70
+
71
+ # Concatenate the tiled speaker embedding with the encoder output
72
+ x = torch.cat((x, e), 2)
73
+ return x
74
+
75
+
76
+ class BatchNormConv(nn.Module):
77
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
78
+ super().__init__()
79
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
80
+ self.bnorm = nn.BatchNorm1d(out_channels)
81
+ self.relu = relu
82
+
83
+ def forward(self, x):
84
+ x = self.conv(x)
85
+ x = F.relu(x) if self.relu is True else x
86
+ return self.bnorm(x)
87
+
88
+
89
+ class CBHG(nn.Module):
90
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
91
+ super().__init__()
92
+
93
+ # List of all rnns to call `flatten_parameters()` on
94
+ self._to_flatten = []
95
+
96
+ self.bank_kernels = [i for i in range(1, K + 1)]
97
+ self.conv1d_bank = nn.ModuleList()
98
+ for k in self.bank_kernels:
99
+ conv = BatchNormConv(in_channels, channels, k)
100
+ self.conv1d_bank.append(conv)
101
+
102
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
103
+
104
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
105
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
106
+
107
+ # Fix the highway input if necessary
108
+ if proj_channels[-1] != channels:
109
+ self.highway_mismatch = True
110
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
111
+ else:
112
+ self.highway_mismatch = False
113
+
114
+ self.highways = nn.ModuleList()
115
+ for i in range(num_highways):
116
+ hn = HighwayNetwork(channels)
117
+ self.highways.append(hn)
118
+
119
+ self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
120
+ self._to_flatten.append(self.rnn)
121
+
122
+ # Avoid fragmentation of RNN parameters and associated warning
123
+ self._flatten_parameters()
124
+
125
+ def forward(self, x):
126
+ # Although we `_flatten_parameters()` on init, when using DataParallel
127
+ # the model gets replicated, making it no longer guaranteed that the
128
+ # weights are contiguous in GPU memory. Hence, we must call it again
129
+ self._flatten_parameters()
130
+
131
+ # Save these for later
132
+ residual = x
133
+ seq_len = x.size(-1)
134
+ conv_bank = []
135
+
136
+ # Convolution Bank
137
+ for conv in self.conv1d_bank:
138
+ c = conv(x) # Convolution
139
+ conv_bank.append(c[:, :, :seq_len])
140
+
141
+ # Stack along the channel axis
142
+ conv_bank = torch.cat(conv_bank, dim=1)
143
+
144
+ # dump the last padding to fit residual
145
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
146
+
147
+ # Conv1d projections
148
+ x = self.conv_project1(x)
149
+ x = self.conv_project2(x)
150
+
151
+ # Residual Connect
152
+ x = x + residual
153
+
154
+ # Through the highways
155
+ x = x.transpose(1, 2)
156
+ if self.highway_mismatch is True:
157
+ x = self.pre_highway(x)
158
+ for h in self.highways: x = h(x)
159
+
160
+ # And then the RNN
161
+ x, _ = self.rnn(x)
162
+ return x
163
+
164
+ def _flatten_parameters(self):
165
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
166
+ to improve efficiency and avoid PyTorch yelling at us."""
167
+ [m.flatten_parameters() for m in self._to_flatten]
168
+
169
+ class PreNet(nn.Module):
170
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
171
+ super().__init__()
172
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
173
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
174
+ self.p = dropout
175
+
176
+ def forward(self, x):
177
+ x = self.fc1(x)
178
+ x = F.relu(x)
179
+ x = F.dropout(x, self.p, training=True)
180
+ x = self.fc2(x)
181
+ x = F.relu(x)
182
+ x = F.dropout(x, self.p, training=True)
183
+ return x
184
+
185
+
186
+ class Attention(nn.Module):
187
+ def __init__(self, attn_dims):
188
+ super().__init__()
189
+ self.W = nn.Linear(attn_dims, attn_dims, bias=False)
190
+ self.v = nn.Linear(attn_dims, 1, bias=False)
191
+
192
+ def forward(self, encoder_seq_proj, query, t):
193
+
194
+ # print(encoder_seq_proj.shape)
195
+ # Transform the query vector
196
+ query_proj = self.W(query).unsqueeze(1)
197
+
198
+ # Compute the scores
199
+ u = self.v(torch.tanh(encoder_seq_proj + query_proj))
200
+ scores = F.softmax(u, dim=1)
201
+
202
+ return scores.transpose(1, 2)
203
+
204
+
205
+ class LSA(nn.Module):
206
+ def __init__(self, attn_dim, kernel_size=31, filters=32):
207
+ super().__init__()
208
+ self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
209
+ self.L = nn.Linear(filters, attn_dim, bias=False)
210
+ self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
211
+ self.v = nn.Linear(attn_dim, 1, bias=False)
212
+ self.cumulative = None
213
+ self.attention = None
214
+
215
+ def init_attention(self, encoder_seq_proj):
216
+ device = next(self.parameters()).device # use same device as parameters
217
+ b, t, c = encoder_seq_proj.size()
218
+ self.cumulative = torch.zeros(b, t, device=device)
219
+ self.attention = torch.zeros(b, t, device=device)
220
+
221
+ def forward(self, encoder_seq_proj, query, t, chars):
222
+
223
+ if t == 0: self.init_attention(encoder_seq_proj)
224
+
225
+ processed_query = self.W(query).unsqueeze(1)
226
+
227
+ location = self.cumulative.unsqueeze(1)
228
+ processed_loc = self.L(self.conv(location).transpose(1, 2))
229
+
230
+ u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
231
+ u = u.squeeze(-1)
232
+
233
+ # Mask zero padding chars
234
+ u = u * (chars != 0).float()
235
+
236
+ # Smooth Attention
237
+ # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
238
+ scores = F.softmax(u, dim=1)
239
+ self.attention = scores
240
+ self.cumulative = self.cumulative + self.attention
241
+
242
+ return scores.unsqueeze(-1).transpose(1, 2)
243
+
244
+
245
+ class Decoder(nn.Module):
246
+ # Class variable because its value doesn't change between instances
248
+ # yet ought to be scoped by the class because it's a property of a Decoder
248
+ max_r = 20
249
+ def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
250
+ dropout, speaker_embedding_size):
251
+ super().__init__()
252
+ self.register_buffer("r", torch.tensor(1, dtype=torch.int))
253
+ self.n_mels = n_mels
254
+ prenet_dims = (decoder_dims * 2, decoder_dims * 2)
255
+ self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
256
+ dropout=dropout)
257
+ self.attn_net = LSA(decoder_dims)
258
+ self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
259
+ self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
260
+ self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
261
+ self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
262
+ self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
263
+ self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
264
+
265
+ def zoneout(self, prev, current, p=0.1):
266
+ device = next(self.parameters()).device # Use same device as parameters
267
+ mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
268
+ return prev * mask + current * (1 - mask)
269
+
270
+ def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
271
+ hidden_states, cell_states, context_vec, t, chars):
272
+
273
+ # Need this for reshaping mels
274
+ batch_size = encoder_seq.size(0)
275
+
276
+ # Unpack the hidden and cell states
277
+ attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
278
+ rnn1_cell, rnn2_cell = cell_states
279
+
280
+ # PreNet for the Attention RNN
281
+ prenet_out = self.prenet(prenet_in)
282
+
283
+ # Compute the Attention RNN hidden state
284
+ attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
285
+ attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
286
+
287
+ # Compute the attention scores
288
+ scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
289
+
290
+ # Dot product to create the context vector
291
+ context_vec = scores @ encoder_seq
292
+ context_vec = context_vec.squeeze(1)
293
+
294
+ # Concat Attention RNN output w. Context Vector & project
295
+ x = torch.cat([context_vec, attn_hidden], dim=1)
296
+ x = self.rnn_input(x)
297
+
298
+ # Compute first Residual RNN
299
+ rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
300
+ if self.training:
301
+ rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
302
+ else:
303
+ rnn1_hidden = rnn1_hidden_next
304
+ x = x + rnn1_hidden
305
+
306
+ # Compute second Residual RNN
307
+ rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
308
+ if self.training:
309
+ rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
310
+ else:
311
+ rnn2_hidden = rnn2_hidden_next
312
+ x = x + rnn2_hidden
313
+
314
+ # Project Mels
315
+ mels = self.mel_proj(x)
316
+ mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
317
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
318
+ cell_states = (rnn1_cell, rnn2_cell)
319
+
320
+ # Stop token prediction
321
+ s = torch.cat((x, context_vec), dim=1)
322
+ s = self.stop_proj(s)
323
+ stop_tokens = torch.sigmoid(s)
324
+
325
+ return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
326
+
327
+
328
+ class Tacotron(nn.Module):
329
+ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
330
+ fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
331
+ dropout, stop_threshold, speaker_embedding_size):
332
+ super().__init__()
333
+ self.n_mels = n_mels
334
+ self.lstm_dims = lstm_dims
335
+ self.encoder_dims = encoder_dims
336
+ self.decoder_dims = decoder_dims
337
+ self.speaker_embedding_size = speaker_embedding_size
338
+ self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
339
+ encoder_K, num_highways, dropout)
340
+ self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
341
+ self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
342
+ dropout, speaker_embedding_size)
343
+ self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
344
+ [postnet_dims, fft_bins], num_highways)
345
+ self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
346
+
347
+ self.init_model()
348
+ self.num_params()
349
+
350
+ self.register_buffer("step", torch.zeros(1, dtype=torch.long))
351
+ self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
352
+
353
+ @property
354
+ def r(self):
355
+ return self.decoder.r.item()
356
+
357
+ @r.setter
358
+ def r(self, value):
359
+ self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
360
+
361
+ def forward(self, x, m, speaker_embedding):
362
+ device = next(self.parameters()).device # use same device as parameters
363
+
364
+ self.step += 1
365
+ batch_size, _, steps = m.size()
366
+
367
+ # Initialise all hidden states and pack into tuple
368
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
369
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
370
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
371
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
372
+
373
+ # Initialise all lstm cell states and pack into tuple
374
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
375
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
376
+ cell_states = (rnn1_cell, rnn2_cell)
377
+
378
+ # <GO> Frame for start of decoder loop
379
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
380
+
381
+ # Need an initial context vector
382
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
383
+
384
+ # SV2TTS: Run the encoder with the speaker embedding
385
+ # The projection avoids unnecessary matmuls in the decoder loop
386
+ encoder_seq = self.encoder(x, speaker_embedding)
387
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
388
+
389
+ # Need a couple of lists for outputs
390
+ mel_outputs, attn_scores, stop_outputs = [], [], []
391
+
392
+ # Run the decoder loop
393
+ for t in range(0, steps, self.r):
394
+ prenet_in = m[:, :, t - 1] if t > 0 else go_frame
395
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
396
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
397
+ hidden_states, cell_states, context_vec, t, x)
398
+ mel_outputs.append(mel_frames)
399
+ attn_scores.append(scores)
400
+ stop_outputs.extend([stop_tokens] * self.r)
401
+
402
+ # Concat the mel outputs into sequence
403
+ mel_outputs = torch.cat(mel_outputs, dim=2)
404
+
405
+ # Post-Process for Linear Spectrograms
406
+ postnet_out = self.postnet(mel_outputs)
407
+ linear = self.post_proj(postnet_out)
408
+ linear = linear.transpose(1, 2)
409
+
410
+ # For easy visualisation
411
+ attn_scores = torch.cat(attn_scores, 1)
412
+ # attn_scores = attn_scores.cpu().data.numpy()
413
+ stop_outputs = torch.cat(stop_outputs, 1)
414
+
415
+ return mel_outputs, linear, attn_scores, stop_outputs
416
+
417
+ def generate(self, x, speaker_embedding=None, steps=2000):
418
+ self.eval()
419
+ device = next(self.parameters()).device # use same device as parameters
420
+
421
+ batch_size, _ = x.size()
422
+
423
+ # Need to initialise all hidden states and pack into tuple for tidiness
424
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
425
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
426
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
427
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
428
+
429
+ # Need to initialise all lstm cell states and pack into tuple for tidiness
430
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
431
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
432
+ cell_states = (rnn1_cell, rnn2_cell)
433
+
434
+ # Need a <GO> Frame for start of decoder loop
435
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
436
+
437
+ # Need an initial context vector
438
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
439
+
440
+ # SV2TTS: Run the encoder with the speaker embedding
441
+ # The projection avoids unnecessary matmuls in the decoder loop
442
+ encoder_seq = self.encoder(x, speaker_embedding)
443
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
444
+
445
+ # Need a couple of lists for outputs
446
+ mel_outputs, attn_scores, stop_outputs = [], [], []
447
+
448
+ # Run the decoder loop
449
+ for t in range(0, steps, self.r):
450
+ prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
451
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
452
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
453
+ hidden_states, cell_states, context_vec, t, x)
454
+ mel_outputs.append(mel_frames)
455
+ attn_scores.append(scores)
456
+ stop_outputs.extend([stop_tokens] * self.r)
457
+ # Stop the loop when all stop tokens in batch exceed threshold
458
+ if (stop_tokens > 0.5).all() and t > 10: break
459
+
460
+ # Concat the mel outputs into sequence
461
+ mel_outputs = torch.cat(mel_outputs, dim=2)
462
+
463
+ # Post-Process for Linear Spectrograms
464
+ postnet_out = self.postnet(mel_outputs)
465
+ linear = self.post_proj(postnet_out)
466
+
467
+
468
+ linear = linear.transpose(1, 2)
469
+
470
+ # For easy visualisation
471
+ attn_scores = torch.cat(attn_scores, 1)
472
+ stop_outputs = torch.cat(stop_outputs, 1)
473
+
474
+ self.train()
475
+
476
+ return mel_outputs, linear, attn_scores
477
+
478
+ def init_model(self):
479
+ for p in self.parameters():
480
+ if p.dim() > 1: nn.init.xavier_uniform_(p)
481
+
482
+ def get_step(self):
483
+ return self.step.data.item()
484
+
485
+ def reset_step(self):
486
+ # assignment to parameters or buffers is overloaded, updates internal dict entry
487
+ self.step = self.step.data.new_tensor(1)
488
+
489
+ def log(self, path, msg):
490
+ with open(path, "a") as f:
491
+ print(msg, file=f)
492
+
493
+ def load(self, path, optimizer=None):
494
+ # Use device of model params as location for loaded state
495
+ device = next(self.parameters()).device
496
+ checkpoint = torch.load(str(path), map_location=device)
497
+ self.load_state_dict(checkpoint["model_state"])
498
+
499
+ if "optimizer_state" in checkpoint and optimizer is not None:
500
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
501
+
502
+ def save(self, path, optimizer=None):
503
+ if optimizer is not None:
504
+ torch.save({
505
+ "model_state": self.state_dict(),
506
+ "optimizer_state": optimizer.state_dict(),
507
+ }, str(path))
508
+ else:
509
+ torch.save({
510
+ "model_state": self.state_dict(),
511
+ }, str(path))
512
+
513
+
514
+ def num_params(self, print_out=True):
515
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
516
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
517
+ if print_out:
518
+ print("Trainable Parameters: %.3fM" % parameters)
519
+ return parameters
synthesizer/preprocess.py ADDED
@@ -0,0 +1,258 @@
1
+ from multiprocessing.pool import Pool
2
+ from synthesizer import audio
3
+ from functools import partial
4
+ from itertools import chain
5
+ from encoder import inference as encoder
6
+ from pathlib import Path
7
+ from utils import logmmse
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import librosa
11
+
12
+
13
+ def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int, skip_existing: bool, hparams,
14
+ no_alignments: bool, datasets_name: str, subfolders: str):
15
+ # Gather the input directories
16
+ dataset_root = datasets_root.joinpath(datasets_name)
17
+ input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in subfolders.split(",")]
18
+ print("\n ".join(map(str, ["Using data from:"] + input_dirs)))
19
+ assert all(input_dir.exists() for input_dir in input_dirs)
20
+
21
+ # Create the output directories for each output file type
22
+ out_dir.joinpath("mels").mkdir(exist_ok=True)
23
+ out_dir.joinpath("audio").mkdir(exist_ok=True)
24
+
25
+ # Create a metadata file
26
+ metadata_fpath = out_dir.joinpath("train.txt")
27
+ metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8")
28
+
29
+ # Preprocess the dataset
30
+ speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs))
31
+ func = partial(preprocess_speaker, out_dir=out_dir, skip_existing=skip_existing,
32
+ hparams=hparams, no_alignments=no_alignments)
33
+ job = Pool(n_processes).imap(func, speaker_dirs)
34
+ for speaker_metadata in tqdm(job, datasets_name, len(speaker_dirs), unit="speakers"):
35
+ for metadatum in speaker_metadata:
36
+ metadata_file.write("|".join(str(x) for x in metadatum) + "\n")
37
+ metadata_file.close()
38
+
39
+ # Verify the contents of the metadata file
40
+ with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
41
+ metadata = [line.split("|") for line in metadata_file]
42
+ mel_frames = sum([int(m[4]) for m in metadata])
43
+ timesteps = sum([int(m[3]) for m in metadata])
44
+ sample_rate = hparams.sample_rate
45
+ hours = (timesteps / sample_rate) / 3600
46
+ print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." %
47
+ (len(metadata), mel_frames, timesteps, hours))
48
+ print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata))
49
+ print("Max mel frames length: %d" % max(int(m[4]) for m in metadata))
50
+ print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata))
51
+
52
+
53
+ def preprocess_speaker(speaker_dir, out_dir: Path, skip_existing: bool, hparams, no_alignments: bool):
54
+ metadata = []
55
+ for book_dir in speaker_dir.glob("*"):
56
+ if no_alignments:
57
+ # Gather the utterance audios and texts
58
+ # LibriTTS uses .wav but we will include extensions for compatibility with other datasets
59
+ extensions = ["*.wav", "*.flac", "*.mp3"]
60
+ for extension in extensions:
61
+ wav_fpaths = book_dir.glob(extension)
62
+
63
+ for wav_fpath in wav_fpaths:
64
+ # Load the audio waveform
65
+ wav, _ = librosa.load(str(wav_fpath), sr=hparams.sample_rate)
66
+ if hparams.rescale:
67
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
68
+
69
+ # Get the corresponding text
70
+ # Check for .txt (for compatibility with other datasets)
71
+ text_fpath = wav_fpath.with_suffix(".txt")
72
+ if not text_fpath.exists():
73
+ # Check for .normalized.txt (LibriTTS)
74
+ text_fpath = wav_fpath.with_suffix(".normalized.txt")
75
+ assert text_fpath.exists()
76
+ with text_fpath.open("r") as text_file:
77
+ text = "".join([line for line in text_file])
78
+ text = text.replace("\"", "")
79
+ text = text.strip()
80
+
81
+ # Process the utterance
82
+ metadata.append(process_utterance(wav, text, out_dir, str(wav_fpath.with_suffix("").name),
83
+ skip_existing, hparams))
84
+ else:
85
+ # Process alignment file (LibriSpeech support)
86
+ # Gather the utterance audios and texts
87
+ try:
88
+ alignments_fpath = next(book_dir.glob("*.alignment.txt"))
89
+ with alignments_fpath.open("r") as alignments_file:
90
+ alignments = [line.rstrip().split(" ") for line in alignments_file]
91
+ except StopIteration:
92
+ # A few alignment files will be missing
93
+ continue
94
+
95
+ # Iterate over each entry in the alignments file
96
+ for wav_fname, words, end_times in alignments:
97
+ wav_fpath = book_dir.joinpath(wav_fname + ".flac")
98
+ assert wav_fpath.exists()
99
+ words = words.replace("\"", "").split(",")
100
+ end_times = list(map(float, end_times.replace("\"", "").split(",")))
101
+
102
+ # Process each sub-utterance
103
+ wavs, texts = split_on_silences(wav_fpath, words, end_times, hparams)
104
+ for i, (wav, text) in enumerate(zip(wavs, texts)):
105
+ sub_basename = "%s_%02d" % (wav_fname, i)
106
+ metadata.append(process_utterance(wav, text, out_dir, sub_basename,
107
+ skip_existing, hparams))
108
+
109
+ return [m for m in metadata if m is not None]
110
+
111
+
112
+ def split_on_silences(wav_fpath, words, end_times, hparams):
113
+ # Load the audio waveform
114
+ wav, _ = librosa.load(str(wav_fpath), sr=hparams.sample_rate)
115
+ if hparams.rescale:
116
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
117
+
118
+ words = np.array(words)
119
+ start_times = np.array([0.0] + end_times[:-1])
120
+ end_times = np.array(end_times)
121
+ assert len(words) == len(end_times) == len(start_times)
122
+ assert words[0] == "" and words[-1] == ""
123
+
124
+ # Find pauses that are too long
125
+ mask = (words == "") & (end_times - start_times >= hparams.silence_min_duration_split)
126
+ mask[0] = mask[-1] = True
127
+ breaks = np.where(mask)[0]
128
+
129
+ # Profile the noise from the silences and perform noise reduction on the waveform
130
+ silence_times = [[start_times[i], end_times[i]] for i in breaks]
131
+ silence_times = (np.array(silence_times) * hparams.sample_rate).astype(int)  # np.int was removed in recent NumPy
132
+ noisy_wav = np.concatenate([wav[stime[0]:stime[1]] for stime in silence_times])
133
+ if len(noisy_wav) > hparams.sample_rate * 0.02:
134
+ profile = logmmse.profile_noise(noisy_wav, hparams.sample_rate)
135
+ wav = logmmse.denoise(wav, profile, eta=0)
136
+
137
+ # Re-attach segments that are too short
138
+ segments = list(zip(breaks[:-1], breaks[1:]))
139
+ segment_durations = [start_times[end] - end_times[start] for start, end in segments]
140
+ i = 0
141
+ while i < len(segments) and len(segments) > 1:
142
+ if segment_durations[i] < hparams.utterance_min_duration:
143
+ # See if the segment can be re-attached with the right or the left segment
144
+ left_duration = float("inf") if i == 0 else segment_durations[i - 1]
145
+ right_duration = float("inf") if i == len(segments) - 1 else segment_durations[i + 1]
146
+ joined_duration = segment_durations[i] + min(left_duration, right_duration)
147
+
148
+ # Do not re-attach if it causes the joined utterance to be too long
149
+ if joined_duration > hparams.hop_size * hparams.max_mel_frames / hparams.sample_rate:
150
+ i += 1
151
+ continue
152
+
153
+ # Re-attach the segment with the neighbour of shortest duration
154
+ j = i - 1 if left_duration <= right_duration else i
155
+ segments[j] = (segments[j][0], segments[j + 1][1])
156
+ segment_durations[j] = joined_duration
157
+ del segments[j + 1], segment_durations[j + 1]
158
+ else:
159
+ i += 1
160
+
161
+ # Split the utterance
162
+ segment_times = [[end_times[start], start_times[end]] for start, end in segments]
163
+ segment_times = (np.array(segment_times) * hparams.sample_rate).astype(int)
164
+ wavs = [wav[segment_time[0]:segment_time[1]] for segment_time in segment_times]
165
+ texts = [" ".join(words[start + 1:end]).replace("  ", " ") for start, end in segments]
166
+
167
+ # # DEBUG: play the audio segments (run with -n=1)
168
+ # import sounddevice as sd
169
+ # if len(wavs) > 1:
170
+ # print("This sentence was split in %d segments:" % len(wavs))
171
+ # else:
172
+ # print("There are no silences long enough for this sentence to be split:")
173
+ # for wav, text in zip(wavs, texts):
174
+ # # Pad the waveform with 1 second of silence because sounddevice tends to cut them early
175
+ # # when playing them. You shouldn't need to do that in your parsers.
176
+ # wav = np.concatenate((wav, [0] * 16000))
177
+ # print("\t%s" % text)
178
+ # sd.play(wav, 16000, blocking=True)
179
+ # print("")
180
+
181
+ return wavs, texts
182
+
183
+
184
+ def process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
185
+ skip_existing: bool, hparams):
186
+ ## FOR REFERENCE:
187
+ # For you not to lose your head if you ever wish to change things here or implement your own
188
+ # synthesizer.
189
+ # - Both the audios and the mel spectrograms are saved as numpy arrays
190
+ # - There is no processing done to the audios that will be saved to disk beyond volume
191
+ # normalization (in split_on_silences)
192
+ # - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This
193
+ # is why we re-apply it on the audio on the side of the vocoder.
194
+ # - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved
195
+ # without extra padding. This means that you won't have an exact relation between the length
196
+ # of the wav and of the mel spectrogram. See the vocoder data loader.
197
+
198
+
199
+ # Skip existing utterances if needed
200
+ mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
201
+ wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
202
+ if skip_existing and mel_fpath.exists() and wav_fpath.exists():
203
+ return None
204
+
205
+ # Trim silence
206
+ if hparams.trim_silence:
207
+ wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True)
208
+
209
+ # Skip utterances that are too short
210
+ if len(wav) < hparams.utterance_min_duration * hparams.sample_rate:
211
+ return None
212
+
213
+ # Compute the mel spectrogram
214
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
215
+ mel_frames = mel_spectrogram.shape[1]
216
+
217
+ # Skip utterances that are too long
218
+ if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length:
219
+ return None
220
+
221
+ # Write the spectrogram, embed and audio to disk
222
+ np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
223
+ np.save(wav_fpath, wav, allow_pickle=False)
224
+
225
+ # Return a tuple describing this training example
226
+ return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
227
+
228
+
229
+ def embed_utterance(fpaths, encoder_model_fpath):
230
+ if not encoder.is_loaded():
231
+ encoder.load_model(encoder_model_fpath)
232
+
233
+ # Compute the speaker embedding of the utterance
234
+ wav_fpath, embed_fpath = fpaths
235
+ wav = np.load(wav_fpath)
236
+ wav = encoder.preprocess_wav(wav)
237
+ embed = encoder.embed_utterance(wav)
238
+ np.save(embed_fpath, embed, allow_pickle=False)
239
+
240
+
241
+ def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int):
242
+ wav_dir = synthesizer_root.joinpath("audio")
243
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
244
+ assert wav_dir.exists() and metadata_fpath.exists()
245
+ embed_dir = synthesizer_root.joinpath("embeds")
246
+ embed_dir.mkdir(exist_ok=True)
247
+
248
+ # Gather the input wave filepath and the target output embed filepath
249
+ with metadata_fpath.open("r") as metadata_file:
250
+ metadata = [line.split("|") for line in metadata_file]
251
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
252
+
253
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
254
+ # Embed the utterances in separate threads
255
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
256
+ job = Pool(n_processes).imap(func, fpaths)
257
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
258
+
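Note: the two entry points above (preprocess_dataset and create_embeddings) are normally driven by small CLI wrappers. A hedged sketch of how they fit together — the dataset paths, the saved_models/default/encoder.pt checkpoint location and the synthesizer.hparams import are assumptions based on this repository's layout, not verified commands:

    from pathlib import Path
    from synthesizer.hparams import hparams
    from synthesizer.preprocess import preprocess_dataset, create_embeddings

    datasets_root = Path("datasets")              # hypothetical folder containing LibriSpeech/
    out_dir = Path("SV2TTS/synthesizer")          # hypothetical output folder
    out_dir.mkdir(parents=True, exist_ok=True)

    # Pass 1: write mels/, audio/ and the train.txt metadata file
    # (expects LibriSpeech-style *.alignment.txt files unless no_alignments=True)
    preprocess_dataset(datasets_root=datasets_root, out_dir=out_dir, n_processes=4,
                       skip_existing=True, hparams=hparams, no_alignments=False,
                       datasets_name="LibriSpeech", subfolders="train-clean-100")

    # Pass 2: write one speaker embedding per utterance into embeds/
    create_embeddings(synthesizer_root=out_dir,
                      encoder_model_fpath=Path("saved_models/default/encoder.pt"),
                      n_processes=2)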
synthesizer/synthesize.py ADDED
@@ -0,0 +1,92 @@
1
+ import platform
2
+ from functools import partial
3
+ from pathlib import Path
4
+
5
+ import numpy as np
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+ from tqdm import tqdm
9
+
10
+ from synthesizer.hparams import hparams_debug_string
11
+ from synthesizer.models.tacotron import Tacotron
12
+ from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
13
+ from synthesizer.utils import data_parallel_workaround
14
+ from synthesizer.utils.symbols import symbols
15
+
16
+
17
+ def run_synthesis(in_dir: Path, out_dir: Path, syn_model_fpath: Path, hparams):
18
+ # This generates ground truth-aligned mels for vocoder training
19
+ synth_dir = out_dir / "mels_gta"
20
+ synth_dir.mkdir(exist_ok=True, parents=True)
21
+ print(hparams_debug_string())
22
+
23
+ # Check for GPU
24
+ if torch.cuda.is_available():
25
+ device = torch.device("cuda")
26
+ if hparams.synthesis_batch_size % torch.cuda.device_count() != 0:
27
+ raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible by n_gpus!")
28
+ else:
29
+ device = torch.device("cpu")
30
+ print("Synthesizer using device:", device)
31
+
32
+ # Instantiate Tacotron model
33
+ model = Tacotron(embed_dims=hparams.tts_embed_dims,
34
+ num_chars=len(symbols),
35
+ encoder_dims=hparams.tts_encoder_dims,
36
+ decoder_dims=hparams.tts_decoder_dims,
37
+ n_mels=hparams.num_mels,
38
+ fft_bins=hparams.num_mels,
39
+ postnet_dims=hparams.tts_postnet_dims,
40
+ encoder_K=hparams.tts_encoder_K,
41
+ lstm_dims=hparams.tts_lstm_dims,
42
+ postnet_K=hparams.tts_postnet_K,
43
+ num_highways=hparams.tts_num_highways,
44
+ dropout=0., # Use zero dropout for gta mels
45
+ stop_threshold=hparams.tts_stop_threshold,
46
+ speaker_embedding_size=hparams.speaker_embedding_size).to(device)
47
+
48
+ # Load the weights
49
+ print("\nLoading weights at %s" % syn_model_fpath)
50
+ model.load(syn_model_fpath)
51
+ print("Tacotron weights loaded from step %d" % model.step)
52
+
53
+ # Synthesize using same reduction factor as the model is currently trained
54
+ r = np.int32(model.r)
55
+
56
+ # Set model to eval mode (disable gradient and zoneout)
57
+ model.eval()
58
+
59
+ # Initialize the dataset
60
+ metadata_fpath = in_dir.joinpath("train.txt")
61
+ mel_dir = in_dir.joinpath("mels")
62
+ embed_dir = in_dir.joinpath("embeds")
63
+
64
+ dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
65
+ collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
66
+ data_loader = DataLoader(dataset, hparams.synthesis_batch_size, collate_fn=collate_fn, num_workers=2)
67
+
68
+ # Generate GTA mels
69
+ meta_out_fpath = out_dir / "synthesized.txt"
70
+ with meta_out_fpath.open("w") as file:
71
+ for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)):
72
+ texts, mels, embeds = texts.to(device), mels.to(device), embeds.to(device)
73
+
74
+ # Parallelize model onto GPUS using workaround due to python bug
75
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
76
+ _, mels_out, _ = data_parallel_workaround(model, texts, mels, embeds)
77
+ else:
78
+ _, mels_out, _, _ = model(texts, mels, embeds)
79
+
80
+ for j, k in enumerate(idx):
81
+ # Note: outputs mel-spectrogram files and target ones have same names, just different folders
82
+ mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1])
83
+ mel_out = mels_out[j].detach().cpu().numpy().T
84
+
85
+ # Use the length of the ground truth mel to remove padding from the generated mels
86
+ mel_out = mel_out[:int(dataset.metadata[k][4])]
87
+
88
+ # Write the spectrogram to disk
89
+ np.save(mel_filename, mel_out, allow_pickle=False)
90
+
91
+ # Write metadata into the synthesized file
92
+ file.write("|".join(dataset.metadata[k]))
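Note: run_synthesis is the step that produces ground-truth-aligned (GTA) mels for vocoder training. A hedged call sketch — the directory names and the checkpoint path are placeholders mirroring utils/default_models.py, not verified defaults:

    from pathlib import Path
    from synthesizer.hparams import hparams
    from synthesizer.synthesize import run_synthesis

    run_synthesis(in_dir=Path("SV2TTS/synthesizer"),    # output of the preprocessing step
                  out_dir=Path("SV2TTS/vocoder"),       # will receive mels_gta/ and synthesized.txt
                  syn_model_fpath=Path("saved_models/default/synthesizer.pt"),
                  hparams=hparams)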
synthesizer/synthesizer_dataset.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from synthesizer.utils.text import text_to_sequence
6
+
7
+
8
+ class SynthesizerDataset(Dataset):
9
+ def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams):
10
+ print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir))
11
+
12
+ with metadata_fpath.open("r") as metadata_file:
13
+ metadata = [line.split("|") for line in metadata_file]
14
+
15
+ mel_fnames = [x[1] for x in metadata if int(x[4])]
16
+ mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames]
17
+ embed_fnames = [x[2] for x in metadata if int(x[4])]
18
+ embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames]
19
+ self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths))
20
+ self.samples_texts = [x[5].strip() for x in metadata if int(x[4])]
21
+ self.metadata = metadata
22
+ self.hparams = hparams
23
+
24
+ print("Found %d samples" % len(self.samples_fpaths))
25
+
26
+ def __getitem__(self, index):
27
+ # Sometimes index may be a list of 2 (not sure why this happens)
28
+ # If that is the case, return a single item corresponding to first element in index
29
+ if isinstance(index, list):
30
+ index = index[0]
31
+
32
+ mel_path, embed_path = self.samples_fpaths[index]
33
+ mel = np.load(mel_path).T.astype(np.float32)
34
+
35
+ # Load the embed
36
+ embed = np.load(embed_path)
37
+
38
+ # Get the text and clean it
39
+ text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names)
40
+
41
+ # Convert the list returned by text_to_sequence to a numpy array
42
+ text = np.asarray(text).astype(np.int32)
43
+
44
+ return text, mel.astype(np.float32), embed.astype(np.float32), index
45
+
46
+ def __len__(self):
47
+ return len(self.samples_fpaths)
48
+
49
+
50
+ def collate_synthesizer(batch, r, hparams):
51
+ # Text
52
+ x_lens = [len(x[0]) for x in batch]
53
+ max_x_len = max(x_lens)
54
+
55
+ chars = [pad1d(x[0], max_x_len) for x in batch]
56
+ chars = np.stack(chars)
57
+
58
+ # Mel spectrogram
59
+ spec_lens = [x[1].shape[-1] for x in batch]
60
+ max_spec_len = max(spec_lens) + 1
61
+ if max_spec_len % r != 0:
62
+ max_spec_len += r - max_spec_len % r
63
+
64
+ # WaveRNN mel spectrograms are normalized to [0, 1] so zero padding adds silence
65
+ # By default, SV2TTS uses symmetric mels, where -1*max_abs_value is silence.
66
+ if hparams.symmetric_mels:
67
+ mel_pad_value = -1 * hparams.max_abs_value
68
+ else:
69
+ mel_pad_value = 0
70
+
71
+ mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch]
72
+ mel = np.stack(mel)
73
+
74
+ # Speaker embedding (SV2TTS)
75
+ embeds = np.array([x[2] for x in batch])
76
+
77
+ # Index (for vocoder preprocessing)
78
+ indices = [x[3] for x in batch]
79
+
80
+
81
+ # Convert all to tensor
82
+ chars = torch.tensor(chars).long()
83
+ mel = torch.tensor(mel)
84
+ embeds = torch.tensor(embeds)
85
+
86
+ return chars, mel, embeds, indices
87
+
88
+ def pad1d(x, max_len, pad_value=0):
89
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
90
+
91
+ def pad2d(x, max_len, pad_value=0):
92
+ return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value)
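Note: to make the padding behaviour that collate_synthesizer relies on concrete, here is a small sketch with made-up shapes (the pad value -4.0 assumes symmetric mels with max_abs_value = 4):

    import numpy as np
    from synthesizer.synthesizer_dataset import pad1d, pad2d

    text = np.array([5, 12, 9], dtype=np.int32)
    mel = np.random.rand(80, 37).astype(np.float32)        # (num_mels, frames)

    r = 4                                                  # reduction factor
    max_spec_len = mel.shape[-1] + 1
    if max_spec_len % r != 0:
        max_spec_len += r - max_spec_len % r               # round up to a multiple of r

    padded_text = pad1d(text, 6)                           # zero-padded to length 6
    padded_mel = pad2d(mel, max_spec_len, pad_value=-4.0)  # "silence" padding on the time axis
    print(padded_text.shape, padded_mel.shape)             # (6,) (80, 40)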
synthesizer/train.py ADDED
@@ -0,0 +1,258 @@
1
+ from datetime import datetime
2
+ from functools import partial
3
+ from pathlib import Path
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from torch import optim
8
+ from torch.utils.data import DataLoader
9
+
10
+ from synthesizer import audio
11
+ from synthesizer.models.tacotron import Tacotron
12
+ from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
13
+ from synthesizer.utils import ValueWindow, data_parallel_workaround
14
+ from synthesizer.utils.plot import plot_spectrogram
15
+ from synthesizer.utils.symbols import symbols
16
+ from synthesizer.utils.text import sequence_to_text
17
+ from vocoder.display import *
18
+
19
+
20
+ def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
21
+
22
+
23
+ def time_string():
24
+ return datetime.now().strftime("%Y-%m-%d %H:%M")
25
+
26
+
27
+ def train(run_id: str, syn_dir: Path, models_dir: Path, save_every: int, backup_every: int, force_restart: bool,
28
+ hparams):
29
+ models_dir.mkdir(exist_ok=True)
30
+
31
+ model_dir = models_dir.joinpath(run_id)
32
+ plot_dir = model_dir.joinpath("plots")
33
+ wav_dir = model_dir.joinpath("wavs")
34
+ mel_output_dir = model_dir.joinpath("mel-spectrograms")
35
+ meta_folder = model_dir.joinpath("metas")
36
+ model_dir.mkdir(exist_ok=True)
37
+ plot_dir.mkdir(exist_ok=True)
38
+ wav_dir.mkdir(exist_ok=True)
39
+ mel_output_dir.mkdir(exist_ok=True)
40
+ meta_folder.mkdir(exist_ok=True)
41
+
42
+ weights_fpath = model_dir / "synthesizer.pt"
43
+ metadata_fpath = syn_dir.joinpath("train.txt")
44
+
45
+ print("Checkpoint path: {}".format(weights_fpath))
46
+ print("Loading training data from: {}".format(metadata_fpath))
47
+ print("Using model: Tacotron")
48
+
49
+ # Bookkeeping
50
+ time_window = ValueWindow(100)
51
+ loss_window = ValueWindow(100)
52
+
53
+ # From WaveRNN/train_tacotron.py
54
+ if torch.cuda.is_available():
55
+ device = torch.device("cuda")
56
+
57
+ for session in hparams.tts_schedule:
58
+ _, _, _, batch_size = session
59
+ if batch_size % torch.cuda.device_count() != 0:
60
+ raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
61
+ else:
62
+ device = torch.device("cpu")
63
+ print("Using device:", device)
64
+
65
+ # Instantiate Tacotron Model
66
+ print("\nInitialising Tacotron Model...\n")
67
+ model = Tacotron(embed_dims=hparams.tts_embed_dims,
68
+ num_chars=len(symbols),
69
+ encoder_dims=hparams.tts_encoder_dims,
70
+ decoder_dims=hparams.tts_decoder_dims,
71
+ n_mels=hparams.num_mels,
72
+ fft_bins=hparams.num_mels,
73
+ postnet_dims=hparams.tts_postnet_dims,
74
+ encoder_K=hparams.tts_encoder_K,
75
+ lstm_dims=hparams.tts_lstm_dims,
76
+ postnet_K=hparams.tts_postnet_K,
77
+ num_highways=hparams.tts_num_highways,
78
+ dropout=hparams.tts_dropout,
79
+ stop_threshold=hparams.tts_stop_threshold,
80
+ speaker_embedding_size=hparams.speaker_embedding_size).to(device)
81
+
82
+ # Initialize the optimizer
83
+ optimizer = optim.Adam(model.parameters())
84
+
85
+ # Load the weights
86
+ if force_restart or not weights_fpath.exists():
87
+ print("\nStarting the training of Tacotron from scratch\n")
88
+ model.save(weights_fpath)
89
+
90
+ # Embeddings metadata
91
+ char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
92
+ with open(char_embedding_fpath, "w", encoding="utf-8") as f:
93
+ for symbol in symbols:
94
+ if symbol == " ":
95
+ symbol = "\\s" # For visual purposes, swap space with \s
96
+
97
+ f.write("{}\n".format(symbol))
98
+
99
+ else:
100
+ print("\nLoading weights at %s" % weights_fpath)
101
+ model.load(weights_fpath, optimizer)
102
+ print("Tacotron weights loaded from step %d" % model.step)
103
+
104
+ # Initialize the dataset
105
+ metadata_fpath = syn_dir.joinpath("train.txt")
106
+ mel_dir = syn_dir.joinpath("mels")
107
+ embed_dir = syn_dir.joinpath("embeds")
108
+ dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
109
+
110
+ for i, session in enumerate(hparams.tts_schedule):
111
+ current_step = model.get_step()
112
+
113
+ r, lr, max_step, batch_size = session
114
+
115
+ training_steps = max_step - current_step
116
+
117
+ # Do we need to change to the next session?
118
+ if current_step >= max_step:
119
+ # Are there no further sessions than the current one?
120
+ if i == len(hparams.tts_schedule) - 1:
121
+ # We have completed training. Save the model and exit
122
+ model.save(weights_fpath, optimizer)
123
+ break
124
+ else:
125
+ # There is a following session, go to it
126
+ continue
127
+
128
+ model.r = r
129
+
130
+ # Begin the training
131
+ simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
132
+ ("Batch Size", batch_size),
133
+ ("Learning Rate", lr),
134
+ ("Outputs/Step (r)", model.r)])
135
+
136
+ for p in optimizer.param_groups:
137
+ p["lr"] = lr
138
+
139
+ collate_fn = partial(collate_synthesizer, r=r, hparams=hparams)
140
+ data_loader = DataLoader(dataset, batch_size, shuffle=True, num_workers=2, collate_fn=collate_fn)
141
+
142
+ total_iters = len(dataset)
143
+ steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
144
+ epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)
145
+
146
+ for epoch in range(1, epochs+1):
147
+ for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
148
+ start_time = time.time()
149
+
150
+ # Generate stop tokens for training
151
+ stop = torch.ones(mels.shape[0], mels.shape[2])
152
+ for j, k in enumerate(idx):
153
+ stop[j, :int(dataset.metadata[k][4])-1] = 0
154
+
155
+ texts = texts.to(device)
156
+ mels = mels.to(device)
157
+ embeds = embeds.to(device)
158
+ stop = stop.to(device)
159
+
160
+ # Forward pass
161
+ # Parallelize model onto GPUS using workaround due to python bug
162
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
163
+ m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts, mels, embeds)
164
+ else:
165
+ m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)
166
+
167
+ # Backward pass
168
+ m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
169
+ m2_loss = F.mse_loss(m2_hat, mels)
170
+ stop_loss = F.binary_cross_entropy(stop_pred, stop)
171
+
172
+ loss = m1_loss + m2_loss + stop_loss
173
+
174
+ optimizer.zero_grad()
175
+ loss.backward()
176
+
177
+ if hparams.tts_clip_grad_norm is not None:
178
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
179
+ if np.isnan(grad_norm.cpu()):
180
+ print("grad_norm was NaN!")
181
+
182
+ optimizer.step()
183
+
184
+ time_window.append(time.time() - start_time)
185
+ loss_window.append(loss.item())
186
+
187
+ step = model.get_step()
188
+ k = step // 1000
189
+
190
+ msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | " \
191
+ f"{1./time_window.average:#.2} steps/s | Step: {k}k | "
192
+ stream(msg)
193
+
194
+ # Backup or save model as appropriate
195
+ if backup_every != 0 and step % backup_every == 0 :
196
+ backup_fpath = weights_fpath.parent / f"synthesizer_{k:06d}.pt"
197
+ model.save(backup_fpath, optimizer)
198
+
199
+ if save_every != 0 and step % save_every == 0 :
200
+ # Must save latest optimizer state to ensure that resuming training
201
+ # doesn't produce artifacts
202
+ model.save(weights_fpath, optimizer)
203
+
204
+ # Evaluate model to generate samples
205
+ epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
206
+ step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0 # Every N steps
207
+ if epoch_eval or step_eval:
208
+ for sample_idx in range(hparams.tts_eval_num_samples):
209
+ # At most, generate samples equal to number in the batch
210
+ if sample_idx + 1 <= len(texts):
211
+ # Remove padding from mels using frame length in metadata
212
+ mel_length = int(dataset.metadata[idx[sample_idx]][4])
213
+ mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
214
+ target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
215
+ attention_len = mel_length // model.r
216
+
217
+ eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
218
+ mel_prediction=mel_prediction,
219
+ target_spectrogram=target_spectrogram,
220
+ input_seq=np_now(texts[sample_idx]),
221
+ step=step,
222
+ plot_dir=plot_dir,
223
+ mel_output_dir=mel_output_dir,
224
+ wav_dir=wav_dir,
225
+ sample_num=sample_idx + 1,
226
+ loss=loss,
227
+ hparams=hparams)
228
+
229
+ # Break out of loop to update training schedule
230
+ if step >= max_step:
231
+ break
232
+
233
+ # Add line break after every epoch
234
+ print("")
235
+
236
+
237
+ def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
238
+ plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
239
+ # Save some results for evaluation
240
+ attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
241
+ save_attention(attention, attention_path)
242
+
243
+ # save predicted mel spectrogram to disk (debug)
244
+ mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
245
+ np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False)
246
+
247
+ # save griffin lim inverted wav for debug (mel -> wav)
248
+ wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
249
+ wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num))
250
+ audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate)
251
+
252
+ # save real and predicted mel-spectrogram plot to disk (control purposes)
253
+ spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
254
+ title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
255
+ plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
256
+ target_spectrogram=target_spectrogram,
257
+ max_len=target_spectrogram.size // hparams.num_mels)
258
+ print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
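Note: train() steps through hparams.tts_schedule, where every session unpacks as (r, lr, max_step, batch_size). The numbers below are placeholders illustrating the format only, not the schedule actually shipped in synthesizer/hparams.py:

    example_tts_schedule = [
        (7, 1e-3,  20_000, 12),   # begin with a coarse reduction factor r
        (5, 3e-4,  80_000, 12),
        (2, 1e-4, 160_000, 12),   # finish with r=2 for finer mel outputs
    ]

    for session in example_tts_schedule:
        r, lr, max_step, batch_size = session   # same unpacking as in train()
        print(f"r={r}, lr={lr}, train until step {max_step}, batch size {batch_size}")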
synthesizer/utils/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ import torch
2
+
3
+
4
+ _output_ref = None
5
+ _replicas_ref = None
6
+
7
+ def data_parallel_workaround(model, *input):
8
+ global _output_ref
9
+ global _replicas_ref
10
+ device_ids = list(range(torch.cuda.device_count()))
11
+ output_device = device_ids[0]
12
+ replicas = torch.nn.parallel.replicate(model, device_ids)
13
+ # input.shape = (num_args, batch, ...)
14
+ inputs = torch.nn.parallel.scatter(input, device_ids)
15
+ # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
16
+ replicas = replicas[:len(inputs)]
17
+ outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
18
+ y_hat = torch.nn.parallel.gather(outputs, output_device)
19
+ _output_ref = outputs
20
+ _replicas_ref = replicas
21
+ return y_hat
22
+
23
+
24
+ class ValueWindow():
25
+ def __init__(self, window_size=100):
26
+ self._window_size = window_size
27
+ self._values = []
28
+
29
+ def append(self, x):
30
+ self._values = self._values[-(self._window_size - 1):] + [x]
31
+
32
+ @property
33
+ def sum(self):
34
+ return sum(self._values)
35
+
36
+ @property
37
+ def count(self):
38
+ return len(self._values)
39
+
40
+ @property
41
+ def average(self):
42
+ return self.sum / max(1, self.count)
43
+
44
+ def reset(self):
45
+ self._values = []
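Note: a minimal usage sketch of ValueWindow, which keeps only the last window_size values:

    from synthesizer.utils import ValueWindow

    loss_window = ValueWindow(window_size=3)
    for loss in (1.0, 0.8, 0.6, 0.4):
        loss_window.append(loss)

    print(loss_window.count)      # 3 -- only the last three values are kept
    print(loss_window.average)    # roughly 0.6, i.e. (0.8 + 0.6 + 0.4) / 3
    loss_window.reset()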
synthesizer/utils/_cmudict.py ADDED
@@ -0,0 +1,62 @@
1
+ import re
2
+
3
+ valid_symbols = [
4
+ "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
5
+ "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
6
+ "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY",
7
+ "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1",
8
+ "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0",
9
+ "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW",
10
+ "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH"
11
+ ]
12
+
13
+ _valid_symbol_set = set(valid_symbols)
14
+
15
+
16
+ class CMUDict:
17
+ """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
18
+ def __init__(self, file_or_path, keep_ambiguous=True):
19
+ if isinstance(file_or_path, str):
20
+ with open(file_or_path, encoding="latin-1") as f:
21
+ entries = _parse_cmudict(f)
22
+ else:
23
+ entries = _parse_cmudict(file_or_path)
24
+ if not keep_ambiguous:
25
+ entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
26
+ self._entries = entries
27
+
28
+
29
+ def __len__(self):
30
+ return len(self._entries)
31
+
32
+
33
+ def lookup(self, word):
34
+ """Returns list of ARPAbet pronunciations of the given word."""
35
+ return self._entries.get(word.upper())
36
+
37
+
38
+
39
+ _alt_re = re.compile(r"\([0-9]+\)")
40
+
41
+
42
+ def _parse_cmudict(file):
43
+ cmudict = {}
44
+ for line in file:
45
+ if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
46
+ parts = line.split("  ")  # word and pronunciation are separated by two spaces
47
+ word = re.sub(_alt_re, "", parts[0])
48
+ pronunciation = _get_pronunciation(parts[1])
49
+ if pronunciation:
50
+ if word in cmudict:
51
+ cmudict[word].append(pronunciation)
52
+ else:
53
+ cmudict[word] = [pronunciation]
54
+ return cmudict
55
+
56
+
57
+ def _get_pronunciation(s):
58
+ parts = s.strip().split(" ")
59
+ for part in parts:
60
+ if part not in _valid_symbol_set:
61
+ return None
62
+ return " ".join(parts)
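Note: a hedged lookup sketch; "cmudict-0.7b.txt" is a hypothetical local copy of the CMU pronouncing dictionary, not a file shipped with this repository:

    from synthesizer.utils._cmudict import CMUDict

    cmu = CMUDict("cmudict-0.7b.txt", keep_ambiguous=True)
    print(len(cmu))                # number of words parsed from the file
    print(cmu.lookup("voice"))     # e.g. ["V OY1 S"], or None if the word is absent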
synthesizer/utils/cleaners.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ Cleaners are transformations that run over the input text at both training and eval time.
3
+
4
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
5
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
6
+ 1. "english_cleaners" for English text
7
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
8
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
9
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
10
+ the symbols in symbols.py to match your data).
11
+ """
12
+ import re
13
+ from unidecode import unidecode
14
+ from synthesizer.utils.numbers import normalize_numbers
15
+
16
+
17
+ # Regular expression matching whitespace:
18
+ _whitespace_re = re.compile(r"\s+")
19
+
20
+ # List of (regular expression, replacement) pairs for abbreviations:
21
+ _abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
22
+ ("mrs", "misess"),
23
+ ("mr", "mister"),
24
+ ("dr", "doctor"),
25
+ ("st", "saint"),
26
+ ("co", "company"),
27
+ ("jr", "junior"),
28
+ ("maj", "major"),
29
+ ("gen", "general"),
30
+ ("drs", "doctors"),
31
+ ("rev", "reverend"),
32
+ ("lt", "lieutenant"),
33
+ ("hon", "honorable"),
34
+ ("sgt", "sergeant"),
35
+ ("capt", "captain"),
36
+ ("esq", "esquire"),
37
+ ("ltd", "limited"),
38
+ ("col", "colonel"),
39
+ ("ft", "fort"),
40
+ ]]
41
+
42
+
43
+ def expand_abbreviations(text):
44
+ for regex, replacement in _abbreviations:
45
+ text = re.sub(regex, replacement, text)
46
+ return text
47
+
48
+
49
+ def expand_numbers(text):
50
+ return normalize_numbers(text)
51
+
52
+
53
+ def lowercase(text):
54
+ """lowercase input tokens."""
55
+ return text.lower()
56
+
57
+
58
+ def collapse_whitespace(text):
59
+ return re.sub(_whitespace_re, " ", text)
60
+
61
+
62
+ def convert_to_ascii(text):
63
+ return unidecode(text)
64
+
65
+
66
+ def basic_cleaners(text):
67
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
68
+ text = lowercase(text)
69
+ text = collapse_whitespace(text)
70
+ return text
71
+
72
+
73
+ def transliteration_cleaners(text):
74
+ """Pipeline for non-English text that transliterates to ASCII."""
75
+ text = convert_to_ascii(text)
76
+ text = lowercase(text)
77
+ text = collapse_whitespace(text)
78
+ return text
79
+
80
+
81
+ def english_cleaners(text):
82
+ """Pipeline for English text, including number and abbreviation expansion."""
83
+ text = convert_to_ascii(text)
84
+ text = lowercase(text)
85
+ text = expand_numbers(text)
86
+ text = expand_abbreviations(text)
87
+ text = collapse_whitespace(text)
88
+ return text
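Note: a quick illustration of the cleaning pipelines (requires the unidecode and inflect packages):

    from synthesizer.utils.cleaners import english_cleaners, transliteration_cleaners

    print(english_cleaners("Dr. Smith paid $20 on March 1st."))
    # -> "doctor smith paid twenty dollars on march first."
    print(transliteration_cleaners("Crème brûlée"))
    # -> "creme brulee"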
synthesizer/utils/numbers.py ADDED
@@ -0,0 +1,69 @@
1
+ import re
2
+ import inflect
3
+
4
+
5
+ _inflect = inflect.engine()
6
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
7
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
8
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
9
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
10
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
11
+ _number_re = re.compile(r"[0-9]+")
12
+
13
+
14
+ def _remove_commas(m):
15
+ return m.group(1).replace(",", "")
16
+
17
+
18
+ def _expand_decimal_point(m):
19
+ return m.group(1).replace(".", " point ")
20
+
21
+
22
+ def _expand_dollars(m):
23
+ match = m.group(1)
24
+ parts = match.split(".")
25
+ if len(parts) > 2:
26
+ return match + " dollars" # Unexpected format
27
+ dollars = int(parts[0]) if parts[0] else 0
28
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
29
+ if dollars and cents:
30
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
31
+ cent_unit = "cent" if cents == 1 else "cents"
32
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
33
+ elif dollars:
34
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
35
+ return "%s %s" % (dollars, dollar_unit)
36
+ elif cents:
37
+ cent_unit = "cent" if cents == 1 else "cents"
38
+ return "%s %s" % (cents, cent_unit)
39
+ else:
40
+ return "zero dollars"
41
+
42
+
43
+ def _expand_ordinal(m):
44
+ return _inflect.number_to_words(m.group(0))
45
+
46
+
47
+ def _expand_number(m):
48
+ num = int(m.group(0))
49
+ if num > 1000 and num < 3000:
50
+ if num == 2000:
51
+ return "two thousand"
52
+ elif num > 2000 and num < 2010:
53
+ return "two thousand " + _inflect.number_to_words(num % 100)
54
+ elif num % 100 == 0:
55
+ return _inflect.number_to_words(num // 100) + " hundred"
56
+ else:
57
+ return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
58
+ else:
59
+ return _inflect.number_to_words(num, andword="")
60
+
61
+
62
+ def normalize_numbers(text):
63
+ text = re.sub(_comma_number_re, _remove_commas, text)
64
+ text = re.sub(_pounds_re, r"\1 pounds", text)
65
+ text = re.sub(_dollars_re, _expand_dollars, text)
66
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
67
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
68
+ text = re.sub(_number_re, _expand_number, text)
69
+ return text
synthesizer/utils/plot.py ADDED
@@ -0,0 +1,82 @@
1
+ import numpy as np
2
+
3
+
4
+ def split_title_line(title_text, max_words=5):
5
+ """
6
+ A function that splits any string based on specific character
7
+ (returning it with the string), with maximum number of words on it
8
+ """
9
+ seq = title_text.split()
10
+ return "\n".join([" ".join(seq[i:i + max_words]) for i in range(0, len(seq), max_words)])
11
+
12
+
13
+ def plot_alignment(alignment, path, title=None, split_title=False, max_len=None):
14
+ import matplotlib
15
+ matplotlib.use("Agg")
16
+ import matplotlib.pyplot as plt
17
+
18
+ if max_len is not None:
19
+ alignment = alignment[:, :max_len]
20
+
21
+ fig = plt.figure(figsize=(8, 6))
22
+ ax = fig.add_subplot(111)
23
+
24
+ im = ax.imshow(
25
+ alignment,
26
+ aspect="auto",
27
+ origin="lower",
28
+ interpolation="none")
29
+ fig.colorbar(im, ax=ax)
30
+ xlabel = "Decoder timestep"
31
+
32
+ if split_title:
33
+ title = split_title_line(title)
34
+
35
+ plt.xlabel(xlabel)
36
+ plt.title(title)
37
+ plt.ylabel("Encoder timestep")
38
+ plt.tight_layout()
39
+ plt.savefig(path, format="png")
40
+ plt.close()
41
+
42
+
43
+ def plot_spectrogram(pred_spectrogram, path, title=None, split_title=False, target_spectrogram=None, max_len=None, auto_aspect=False):
44
+ import matplotlib
45
+ matplotlib.use("Agg")
46
+ import matplotlib.pyplot as plt
47
+
48
+ if max_len is not None:
49
+ target_spectrogram = target_spectrogram[:max_len]
50
+ pred_spectrogram = pred_spectrogram[:max_len]
51
+
52
+ if split_title:
53
+ title = split_title_line(title)
54
+
55
+ fig = plt.figure(figsize=(10, 8))
56
+ # Set common labels
57
+ fig.text(0.5, 0.18, title, horizontalalignment="center", fontsize=16)
58
+
59
+ #target spectrogram subplot
60
+ if target_spectrogram is not None:
61
+ ax1 = fig.add_subplot(311)
62
+ ax2 = fig.add_subplot(312)
63
+
64
+ if auto_aspect:
65
+ im = ax1.imshow(np.rot90(target_spectrogram), aspect="auto", interpolation="none")
66
+ else:
67
+ im = ax1.imshow(np.rot90(target_spectrogram), interpolation="none")
68
+ ax1.set_title("Target Mel-Spectrogram")
69
+ fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax1)
70
+ ax2.set_title("Predicted Mel-Spectrogram")
71
+ else:
72
+ ax2 = fig.add_subplot(211)
73
+
74
+ if auto_aspect:
75
+ im = ax2.imshow(np.rot90(pred_spectrogram), aspect="auto", interpolation="none")
76
+ else:
77
+ im = ax2.imshow(np.rot90(pred_spectrogram), interpolation="none")
78
+ fig.colorbar(mappable=im, shrink=0.65, orientation="horizontal", ax=ax2)
79
+
80
+ plt.tight_layout()
81
+ plt.savefig(path, format="png")
82
+ plt.close()
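Note: the helpers above only need numpy arrays and an output path; a sketch on random data (the shapes are arbitrary):

    import numpy as np
    from synthesizer.utils.plot import plot_alignment, plot_spectrogram

    attention = np.random.rand(60, 200)       # (encoder steps, decoder steps)
    pred_mel = np.random.rand(400, 80)        # (frames, num_mels)
    target_mel = np.random.rand(400, 80)

    plot_alignment(attention, "alignment.png", title="step 1000")
    plot_spectrogram(pred_mel, "mel.png", title="step 1000",
                     target_spectrogram=target_mel, max_len=target_mel.shape[0])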
synthesizer/utils/symbols.py ADDED
@@ -0,0 +1,17 @@
1
+ """
2
+ Defines the set of symbols used in text input to the model.
3
+
4
+ The default is a set of ASCII characters that works well for English or text that has been run
5
+ through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details.
6
+ """
7
+ # from . import cmudict
8
+
9
+ _pad = "_"
10
+ _eos = "~"
11
+ _characters = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? "
12
+
13
+ # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
14
+ #_arpabet = ["@" + s for s in cmudict.valid_symbols]
15
+
16
+ # Export all symbols:
17
+ symbols = [_pad, _eos] + list(_characters) #+ _arpabet
synthesizer/utils/text.py ADDED
@@ -0,0 +1,75 @@
1
+ from synthesizer.utils.symbols import symbols
2
+ from synthesizer.utils import cleaners
3
+ import re
4
+
5
+
6
+ # Mappings from symbol to numeric ID and vice versa:
7
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9
+
10
+ # Regular expression matching text enclosed in curly braces:
11
+ _curly_re = re.compile(r"(.*?)\{(.+?)\}(.*)")
12
+
13
+
14
+ def text_to_sequence(text, cleaner_names):
15
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
16
+
17
+ The text can optionally have ARPAbet sequences enclosed in curly braces embedded
18
+ in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."
19
+
20
+ Args:
21
+ text: string to convert to a sequence
22
+ cleaner_names: names of the cleaner functions to run the text through
23
+
24
+ Returns:
25
+ List of integers corresponding to the symbols in the text
26
+ """
27
+ sequence = []
28
+
29
+ # Check for curly braces and treat their contents as ARPAbet:
30
+ while len(text):
31
+ m = _curly_re.match(text)
32
+ if not m:
33
+ sequence += _symbols_to_sequence(_clean_text(text, cleaner_names))
34
+ break
35
+ sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
36
+ sequence += _arpabet_to_sequence(m.group(2))
37
+ text = m.group(3)
38
+
39
+ # Append EOS token
40
+ sequence.append(_symbol_to_id["~"])
41
+ return sequence
42
+
43
+
44
+ def sequence_to_text(sequence):
45
+ """Converts a sequence of IDs back to a string"""
46
+ result = ""
47
+ for symbol_id in sequence:
48
+ if symbol_id in _id_to_symbol:
49
+ s = _id_to_symbol[symbol_id]
50
+ # Enclose ARPAbet back in curly braces:
51
+ if len(s) > 1 and s[0] == "@":
52
+ s = "{%s}" % s[1:]
53
+ result += s
54
+ return result.replace("}{", " ")
55
+
56
+
57
+ def _clean_text(text, cleaner_names):
58
+ for name in cleaner_names:
59
+ cleaner = getattr(cleaners, name)
60
+ if not cleaner:
61
+ raise Exception("Unknown cleaner: %s" % name)
62
+ text = cleaner(text)
63
+ return text
64
+
65
+
66
+ def _symbols_to_sequence(symbols):
67
+ return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
68
+
69
+
70
+ def _arpabet_to_sequence(text):
71
+ return _symbols_to_sequence(["@" + s for s in text.split()])
72
+
73
+
74
+ def _should_keep_symbol(s):
75
+ return s in _symbol_to_id and s not in ("_", "~")
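Note: a round-trip sketch for the text front-end; the integer IDs depend on the symbol table order, so they are not hard-coded here:

    from synthesizer.utils.text import text_to_sequence, sequence_to_text

    seq = text_to_sequence("Dr. Who?", ["english_cleaners"])
    print(seq)                     # list of ints, ending with the ID of the EOS symbol "~"
    print(sequence_to_text(seq))   # "doctor who?~"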
trump.mp3 ADDED
Binary file (239 kB).
 
utils/__init__.py ADDED
File without changes
utils/argutils.py ADDED
@@ -0,0 +1,40 @@
1
+ from pathlib import Path
2
+ import numpy as np
3
+ import argparse
4
+
5
+ _type_priorities = [ # In decreasing order
6
+ Path,
7
+ str,
8
+ int,
9
+ float,
10
+ bool,
11
+ ]
12
+
13
+ def _priority(o):
14
+ p = next((i for i, t in enumerate(_type_priorities) if type(o) is t), None)
15
+ if p is not None:
16
+ return p
17
+ p = next((i for i, t in enumerate(_type_priorities) if isinstance(o, t)), None)
18
+ if p is not None:
19
+ return p
20
+ return len(_type_priorities)
21
+
22
+ def print_args(args: argparse.Namespace, parser=None):
23
+ args = vars(args)
24
+ if parser is None:
25
+ priorities = list(map(_priority, args.values()))
26
+ else:
27
+ all_params = [a.dest for g in parser._action_groups for a in g._group_actions ]
28
+ priority = lambda p: all_params.index(p) if p in all_params else len(all_params)
29
+ priorities = list(map(priority, args.keys()))
30
+
31
+ pad = max(map(len, args.keys())) + 3
32
+ indices = np.lexsort((list(args.keys()), priorities))
33
+ items = list(args.items())
34
+
35
+ print("Arguments:")
36
+ for i in indices:
37
+ param, value = items[i]
38
+ print(" {0}:{1}{2}".format(param, ' ' * (pad - len(param)), value))
39
+ print("")
40
+
utils/default_models.py ADDED
@@ -0,0 +1,56 @@
1
+ import urllib.request
2
+ from pathlib import Path
3
+ from threading import Thread
4
+ from urllib.error import HTTPError
5
+
6
+ from tqdm import tqdm
7
+
8
+
9
+ default_models = {
10
+ "encoder": ("https://drive.google.com/uc?export=download&id=1q8mEGwCkFy23KZsinbuvdKAQLqNKbYf1", 17090379),
11
+ "synthesizer": ("https://drive.google.com/u/0/uc?id=1EqFMIbvxffxtjiVrtykroF6_mUh-5Z3s&export=download&confirm=t", 370554559),
12
+ "vocoder": ("https://drive.google.com/uc?export=download&id=1cf2NO6FtI0jDuy8AV3Xgn6leO6dHjIgu", 53845290),
13
+ }
14
+
15
+
16
+ class DownloadProgressBar(tqdm):
17
+ def update_to(self, b=1, bsize=1, tsize=None):
18
+ if tsize is not None:
19
+ self.total = tsize
20
+ self.update(b * bsize - self.n)
21
+
22
+
23
+ def download(url: str, target: Path, bar_pos=0):
24
+ # Ensure the directory exists
25
+ target.parent.mkdir(exist_ok=True, parents=True)
26
+
27
+ desc = f"Downloading {target.name}"
28
+ with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=desc, position=bar_pos, leave=False) as t:
29
+ try:
30
+ urllib.request.urlretrieve(url, filename=target, reporthook=t.update_to)
31
+ except HTTPError:
32
+ return
33
+
34
+
35
+ def ensure_default_models(models_dir: Path):
36
+ # Define download tasks
37
+ jobs = []
38
+ for model_name, (url, size) in default_models.items():
39
+ target_path = models_dir / "default" / f"{model_name}.pt"
40
+ if target_path.exists():
41
+ if target_path.stat().st_size != size:
42
+ print(f"File {target_path} is not of expected size, redownloading...")
43
+ else:
44
+ continue
45
+
46
+ thread = Thread(target=download, args=(url, target_path, len(jobs)))
47
+ thread.start()
48
+ jobs.append((thread, target_path, size))
49
+
50
+ # Run and join threads
51
+ for thread, target_path, size in jobs:
52
+ thread.join()
53
+
54
+ assert target_path.exists() and target_path.stat().st_size == size, \
55
+ f"Download for {target_path.name} failed. You may download models manually instead.\n" \
56
+ f"https://drive.google.com/drive/folders/1fU6umc5uQAVR2udZdHX-lDgXYzTyqG_j"
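Note: a short sketch of fetching the pretrained models before inference; "saved_models" is an assumed target directory, adjust as needed:

    from pathlib import Path
    from utils.default_models import ensure_default_models

    ensure_default_models(Path("saved_models"))
    # On success, saved_models/default/{encoder,synthesizer,vocoder}.pt exist with the expected sizes.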
utils/logmmse.py ADDED
@@ -0,0 +1,247 @@
1
+ # The MIT License (MIT)
2
+ #
3
+ # Copyright (c) 2015 braindead
4
+ #
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ # of this software and associated documentation files (the "Software"), to deal
7
+ # in the Software without restriction, including without limitation the rights
8
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ # copies of the Software, and to permit persons to whom the Software is
10
+ # furnished to do so, subject to the following conditions:
11
+ #
12
+ # The above copyright notice and this permission notice shall be included in all
13
+ # copies or substantial portions of the Software.
14
+ #
15
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ # SOFTWARE.
22
+ #
23
+ #
24
+ # This code was extracted from the logmmse package (https://pypi.org/project/logmmse/) and I
25
+ # simply modified the interface to meet my needs.
26
+
27
+
28
+ import numpy as np
29
+ import math
30
+ from scipy.special import expn
31
+ from collections import namedtuple
32
+
33
+ NoiseProfile = namedtuple("NoiseProfile", "sampling_rate window_size len1 len2 win n_fft noise_mu2")
34
+
35
+
36
+ def profile_noise(noise, sampling_rate, window_size=0):
37
+ """
38
+ Creates a profile of the noise in a given waveform.
39
+
40
+ :param noise: a waveform containing noise ONLY, as a numpy array of floats or ints.
41
+ :param sampling_rate: the sampling rate of the audio
42
+ :param window_size: the size of the window the logmmse algorithm operates on. A default value
43
+ will be picked if left as 0.
44
+ :return: a NoiseProfile object
45
+ """
46
+ noise, dtype = to_float(noise)
47
+ noise += np.finfo(np.float64).eps
48
+
49
+ if window_size == 0:
50
+ window_size = int(math.floor(0.02 * sampling_rate))
51
+
52
+ if window_size % 2 == 1:
53
+ window_size = window_size + 1
54
+
55
+ perc = 50
56
+ len1 = int(math.floor(window_size * perc / 100))
57
+ len2 = int(window_size - len1)
58
+
59
+ win = np.hanning(window_size)
60
+ win = win * len2 / np.sum(win)
61
+ n_fft = 2 * window_size
62
+
63
+ noise_mean = np.zeros(n_fft)
64
+ n_frames = len(noise) // window_size
65
+ for j in range(0, window_size * n_frames, window_size):
66
+ noise_mean += np.absolute(np.fft.fft(win * noise[j:j + window_size], n_fft, axis=0))
67
+ noise_mu2 = (noise_mean / n_frames) ** 2
68
+
69
+ return NoiseProfile(sampling_rate, window_size, len1, len2, win, n_fft, noise_mu2)
70
+
71
+
72
+ def denoise(wav, noise_profile: NoiseProfile, eta=0.15):
73
+ """
74
+ Cleans the noise from a speech waveform given a noise profile. The waveform must have the
75
+ same sampling rate as the one used to create the noise profile.
76
+
77
+ :param wav: a speech waveform as a numpy array of floats or ints.
78
+ :param noise_profile: a NoiseProfile object that was created from a similar (or a segment of
79
+ the same) waveform.
80
+ :param eta: voice threshold for noise update. While the voice activation detection value is
81
+ below this threshold, the noise profile will be continuously updated throughout the audio.
82
+ Set to 0 to disable updating the noise profile.
83
+ :return: the clean wav as a numpy array of floats or ints of the same length.
84
+ """
85
+ wav, dtype = to_float(wav)
86
+ wav += np.finfo(np.float64).eps
87
+ p = noise_profile
88
+
89
+ nframes = int(math.floor(len(wav) / p.len2) - math.floor(p.window_size / p.len2))
90
+ x_final = np.zeros(nframes * p.len2)
91
+
92
+ aa = 0.98
93
+ mu = 0.98
94
+ ksi_min = 10 ** (-25 / 10)
95
+
96
+ x_old = np.zeros(p.len1)
97
+ xk_prev = np.zeros(p.len1)
98
+ noise_mu2 = p.noise_mu2
99
+ for k in range(0, nframes * p.len2, p.len2):
100
+ insign = p.win * wav[k:k + p.window_size]
101
+
102
+ spec = np.fft.fft(insign, p.n_fft, axis=0)
103
+ sig = np.absolute(spec)
104
+ sig2 = sig ** 2
105
+
106
+ gammak = np.minimum(sig2 / noise_mu2, 40)
107
+
108
+ if xk_prev.all() == 0:
109
+ ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
110
+ else:
111
+ ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
112
+ ksi = np.maximum(ksi_min, ksi)
113
+
114
+ log_sigma_k = gammak * ksi/(1 + ksi) - np.log(1 + ksi)
115
+ vad_decision = np.sum(log_sigma_k) / p.window_size
116
+ if vad_decision < eta:
117
+ noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
118
+
119
+ a = ksi / (1 + ksi)
120
+ vk = a * gammak
121
+ ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
122
+ hw = a * np.exp(ei_vk)
123
+ sig = sig * hw
124
+ xk_prev = sig ** 2
125
+ xi_w = np.fft.ifft(hw * spec, p.n_fft, axis=0)
126
+ xi_w = np.real(xi_w)
127
+
128
+ x_final[k:k + p.len2] = x_old + xi_w[0:p.len1]
129
+ x_old = xi_w[p.len1:p.window_size]
130
+
131
+ output = from_float(x_final, dtype)
132
+ output = np.pad(output, (0, len(wav) - len(output)), mode="constant")
133
+ return output
134
+
135
+
136
+ ## Alternative VAD algorithm to webrctvad. It has the advantage of not requiring to install that
137
+ ## darn package and it also works for any sampling rate. Maybe I'll eventually use it instead of
138
+ ## webrctvad
139
+ # def vad(wav, sampling_rate, eta=0.15, window_size=0):
140
+ # """
141
+ # TODO: fix doc
142
+ # Creates a profile of the noise in a given waveform.
143
+ #
144
+ # :param wav: a waveform containing noise ONLY, as a numpy array of floats or ints.
145
+ # :param sampling_rate: the sampling rate of the audio
146
+ # :param window_size: the size of the window the logmmse algorithm operates on. A default value
147
+ # will be picked if left as 0.
148
+ # :param eta: voice threshold for noise update. While the voice activation detection value is
149
+ # below this threshold, the noise profile will be continuously updated throughout the audio.
150
+ # Set to 0 to disable updating the noise profile.
151
+ # """
152
+ # wav, dtype = to_float(wav)
153
+ # wav += np.finfo(np.float64).eps
154
+ #
155
+ # if window_size == 0:
156
+ # window_size = int(math.floor(0.02 * sampling_rate))
157
+ #
158
+ # if window_size % 2 == 1:
159
+ # window_size = window_size + 1
160
+ #
161
+ # perc = 50
162
+ # len1 = int(math.floor(window_size * perc / 100))
163
+ # len2 = int(window_size - len1)
164
+ #
165
+ # win = np.hanning(window_size)
166
+ # win = win * len2 / np.sum(win)
167
+ # n_fft = 2 * window_size
168
+ #
169
+ # wav_mean = np.zeros(n_fft)
170
+ # n_frames = len(wav) // window_size
171
+ # for j in range(0, window_size * n_frames, window_size):
172
+ # wav_mean += np.absolute(np.fft.fft(win * wav[j:j + window_size], n_fft, axis=0))
173
+ # noise_mu2 = (wav_mean / n_frames) ** 2
174
+ #
175
+ # wav, dtype = to_float(wav)
176
+ # wav += np.finfo(np.float64).eps
177
+ #
178
+ # nframes = int(math.floor(len(wav) / len2) - math.floor(window_size / len2))
179
+ # vad = np.zeros(nframes * len2, dtype=np.bool)
180
+ #
181
+ # aa = 0.98
182
+ # mu = 0.98
183
+ # ksi_min = 10 ** (-25 / 10)
184
+ #
185
+ # xk_prev = np.zeros(len1)
186
+ # noise_mu2 = noise_mu2
187
+ # for k in range(0, nframes * len2, len2):
188
+ # insign = win * wav[k:k + window_size]
189
+ #
190
+ # spec = np.fft.fft(insign, n_fft, axis=0)
191
+ # sig = np.absolute(spec)
192
+ # sig2 = sig ** 2
193
+ #
194
+ # gammak = np.minimum(sig2 / noise_mu2, 40)
195
+ #
196
+ # if xk_prev.all() == 0:
197
+ # ksi = aa + (1 - aa) * np.maximum(gammak - 1, 0)
198
+ # else:
199
+ # ksi = aa * xk_prev / noise_mu2 + (1 - aa) * np.maximum(gammak - 1, 0)
200
+ # ksi = np.maximum(ksi_min, ksi)
201
+ #
202
+ # log_sigma_k = gammak * ksi / (1 + ksi) - np.log(1 + ksi)
203
+ # vad_decision = np.sum(log_sigma_k) / window_size
204
+ # if vad_decision < eta:
205
+ # noise_mu2 = mu * noise_mu2 + (1 - mu) * sig2
206
+ # print(vad_decision)
207
+ #
208
+ # a = ksi / (1 + ksi)
209
+ # vk = a * gammak
210
+ # ei_vk = 0.5 * expn(1, np.maximum(vk, 1e-8))
211
+ # hw = a * np.exp(ei_vk)
212
+ # sig = sig * hw
213
+ # xk_prev = sig ** 2
214
+ #
215
+ # vad[k:k + len2] = vad_decision >= eta
216
+ #
217
+ # vad = np.pad(vad, (0, len(wav) - len(vad)), mode="constant")
218
+ # return vad
219
+
220
+
221
+ def to_float(_input):
222
+ if _input.dtype == np.float64:
223
+ return _input, _input.dtype
224
+ elif _input.dtype == np.float32:
225
+ return _input.astype(np.float64), _input.dtype
226
+ elif _input.dtype == np.uint8:
227
+ return (_input - 128) / 128., _input.dtype
228
+ elif _input.dtype == np.int16:
229
+ return _input / 32768., _input.dtype
230
+ elif _input.dtype == np.int32:
231
+ return _input / 2147483648., _input.dtype
232
+ raise ValueError('Unsupported wave file format')
233
+
234
+
235
+ def from_float(_input, dtype):
236
+ if dtype == np.float64:
237
+ return _input, np.float64
238
+ elif dtype == np.float32:
239
+ return _input.astype(np.float32)
240
+ elif dtype == np.uint8:
241
+ return ((_input * 128) + 128).astype(np.uint8)
242
+ elif dtype == np.int16:
243
+ return (_input * 32768).astype(np.int16)
244
+ elif dtype == np.int32:
245
+ print(_input)
246
+ return (_input * 2147483648).astype(np.int32)
247
+ raise ValueError('Unsupported wave file format')
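Note: an end-to-end sketch of the denoiser on synthetic audio (all signal parameters are made up for illustration):

    import numpy as np
    from utils.logmmse import profile_noise, denoise

    sr = 16000
    t = np.arange(2 * sr) / sr
    clean = 0.3 * np.sin(2 * np.pi * 220 * t)                  # 220 Hz tone
    noise = 0.05 * np.random.default_rng(0).standard_normal(t.shape)
    noisy = (clean + noise).astype(np.float32)

    profile = profile_noise(noise[:sr // 2].astype(np.float32), sr)  # noise-only segment
    denoised = denoise(noisy, profile, eta=0.15)
    print(noisy.shape, denoised.shape)                         # same length in and out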
utils/profiler.py ADDED
@@ -0,0 +1,45 @@
1
+ from time import perf_counter as timer
2
+ from collections import OrderedDict
3
+ import numpy as np
4
+
5
+
6
+ class Profiler:
7
+ def __init__(self, summarize_every=5, disabled=False):
8
+ self.last_tick = timer()
9
+ self.logs = OrderedDict()
10
+ self.summarize_every = summarize_every
11
+ self.disabled = disabled
12
+
13
+ def tick(self, name):
14
+ if self.disabled:
15
+ return
16
+
17
+ # Log the time needed to execute that function
18
+ if name not in self.logs:
19
+ self.logs[name] = []
20
+ if len(self.logs[name]) >= self.summarize_every:
21
+ self.summarize()
22
+ self.purge_logs()
23
+ self.logs[name].append(timer() - self.last_tick)
24
+
25
+ self.reset_timer()
26
+
27
+ def purge_logs(self):
28
+ for name in self.logs:
29
+ self.logs[name].clear()
30
+
31
+ def reset_timer(self):
32
+ self.last_tick = timer()
33
+
34
+ def summarize(self):
35
+ n = max(map(len, self.logs.values()))
36
+ assert n == self.summarize_every
37
+ print("\nAverage execution time over %d steps:" % n)
38
+
39
+ name_msgs = ["%s (%d/%d):" % (name, len(deltas), n) for name, deltas in self.logs.items()]
40
+ pad = max(map(len, name_msgs))
41
+ for name_msg, deltas in zip(name_msgs, self.logs.values()):
42
+ print(" %s mean: %4.0fms std: %4.0fms" %
43
+ (name_msg.ljust(pad), np.mean(deltas) * 1000, np.std(deltas) * 1000))
44
+ print("", flush=True)
45
+
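
A minimal usage sketch for the Profiler above; the sleep call stands in for real work and the tick label is arbitrary:

    from time import sleep
    from utils.profiler import Profiler

    profiler = Profiler(summarize_every=5)
    for _ in range(6):
        sleep(0.01)              # stand-in for the work being timed
        profiler.tick("sleep")   # the mean/std summary prints on the sixth tick
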
vocoder/LICENSE.txt ADDED
@@ -0,0 +1,22 @@
1
+ MIT License
2
+
3
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
4
+ Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
5
+
6
+ Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ of this software and associated documentation files (the "Software"), to deal
8
+ in the Software without restriction, including without limitation the rights
9
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ copies of the Software, and to permit persons to whom the Software is
11
+ furnished to do so, subject to the following conditions:
12
+
13
+ The above copyright notice and this permission notice shall be included in all
14
+ copies or substantial portions of the Software.
15
+
16
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ SOFTWARE.
vocoder/audio.py ADDED
@@ -0,0 +1,108 @@
1
+ import math
2
+ import numpy as np
3
+ import librosa
4
+ import vocoder.hparams as hp
5
+ from scipy.signal import lfilter
6
+ import soundfile as sf
7
+
8
+
9
+ def label_2_float(x, bits) :
10
+ return 2 * x / (2**bits - 1.) - 1.
11
+
12
+
13
+ def float_2_label(x, bits) :
14
+ assert abs(x).max() <= 1.0
15
+ x = (x + 1.) * (2**bits - 1) / 2
16
+ return x.clip(0, 2**bits - 1)
17
+
18
+
19
+ def load_wav(path) :
20
+ return librosa.load(str(path), sr=hp.sample_rate)[0]
21
+
22
+
23
+ def save_wav(x, path) :
24
+ sf.write(path, x.astype(np.float32), hp.sample_rate)
25
+
26
+
27
+ def split_signal(x) :
28
+ unsigned = x + 2**15
29
+ coarse = unsigned // 256
30
+ fine = unsigned % 256
31
+ return coarse, fine
32
+
33
+
34
+ def combine_signal(coarse, fine) :
35
+ return coarse * 256 + fine - 2**15
36
+
37
+
38
+ def encode_16bits(x) :
39
+ return np.clip(x * 2**15, -2**15, 2**15 - 1).astype(np.int16)
40
+
41
+
42
+ mel_basis = None
43
+
44
+
45
+ def linear_to_mel(spectrogram):
46
+ global mel_basis
47
+ if mel_basis is None:
48
+ mel_basis = build_mel_basis()
49
+ return np.dot(mel_basis, spectrogram)
50
+
51
+
52
+ def build_mel_basis():
53
+ return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels, fmin=hp.fmin)
54
+
55
+
56
+ def normalize(S):
57
+ return np.clip((S - hp.min_level_db) / -hp.min_level_db, 0, 1)
58
+
59
+
60
+ def denormalize(S):
61
+ return (np.clip(S, 0, 1) * -hp.min_level_db) + hp.min_level_db
62
+
63
+
64
+ def amp_to_db(x):
65
+ return 20 * np.log10(np.maximum(1e-5, x))
66
+
67
+
68
+ def db_to_amp(x):
69
+ return np.power(10.0, x * 0.05)
70
+
71
+
72
+ def spectrogram(y):
73
+ D = stft(y)
74
+ S = amp_to_db(np.abs(D)) - hp.ref_level_db
75
+ return normalize(S)
76
+
77
+
78
+ def melspectrogram(y):
79
+ D = stft(y)
80
+ S = amp_to_db(linear_to_mel(np.abs(D)))
81
+ return normalize(S)
82
+
83
+
84
+ def stft(y):
85
+ return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=hp.hop_length, win_length=hp.win_length)
86
+
87
+
88
+ def pre_emphasis(x):
89
+ return lfilter([1, -hp.preemphasis], [1], x)
90
+
91
+
92
+ def de_emphasis(x):
93
+ return lfilter([1], [1, -hp.preemphasis], x)
94
+
95
+
96
+ def encode_mu_law(x, mu) :
97
+ mu = mu - 1
98
+ fx = np.sign(x) * np.log(1 + mu * np.abs(x)) / np.log(1 + mu)
99
+ return np.floor((fx + 1) / 2 * mu + 0.5)
100
+
101
+
102
+ def decode_mu_law(y, mu, from_labels=True) :
103
+ if from_labels:
104
+ y = label_2_float(y, math.log2(mu))
105
+ mu = mu - 1
106
+ x = np.sign(y) / mu * ((1 + mu) ** np.abs(y) - 1)
107
+ return x
108
+
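
A quick numerical check of the mu-law companding pair above; the 9-bit depth is an arbitrary choice for the sketch:

    import numpy as np
    from vocoder.audio import encode_mu_law, decode_mu_law

    x = np.linspace(-1.0, 1.0, 5)
    labels = encode_mu_law(x, mu=2 ** 9)                        # integer labels in [0, 511]
    x_hat = decode_mu_law(labels, mu=2 ** 9, from_labels=True)  # back to [-1, 1]
    print(np.abs(x - x_hat).max())                              # small quantisation error
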
vocoder/display.py ADDED
@@ -0,0 +1,127 @@
1
+ import time
2
+ import numpy as np
3
+ import sys
4
+
5
+
6
+ def progbar(i, n, size=16):
7
+ done = (i * size) // n
8
+ bar = ''
9
+ for idx in range(size):
10
+ bar += '█' if idx <= done else '░'
11
+ return bar
12
+
13
+
14
+ def stream(message) :
15
+ try:
16
+ sys.stdout.write("\r{%s}" % message)
17
+ except UnicodeEncodeError:
18
+ # Remove non-ASCII characters from message
19
+ message = ''.join(i for i in message if ord(i)<128)
20
+ sys.stdout.write("\r{%s}" % message)
21
+
22
+
23
+ def simple_table(item_tuples) :
24
+
25
+ border_pattern = '+---------------------------------------'
26
+ whitespace = ' '
27
+
28
+ headings, cells, = [], []
29
+
30
+ for item in item_tuples :
31
+
32
+ heading, cell = str(item[0]), str(item[1])
33
+
34
+ pad_head = True if len(heading) < len(cell) else False
35
+
36
+ pad = abs(len(heading) - len(cell))
37
+ pad = whitespace * pad
38
+
39
+ pad_left = pad[:len(pad)//2]
40
+ pad_right = pad[len(pad)//2:]
41
+
42
+ if pad_head :
43
+ heading = pad_left + heading + pad_right
44
+ else :
45
+ cell = pad_left + cell + pad_right
46
+
47
+ headings += [heading]
48
+ cells += [cell]
49
+
50
+ border, head, body = '', '', ''
51
+
52
+ for i in range(len(item_tuples)) :
53
+
54
+ temp_head = f'| {headings[i]} '
55
+ temp_body = f'| {cells[i]} '
56
+
57
+ border += border_pattern[:len(temp_head)]
58
+ head += temp_head
59
+ body += temp_body
60
+
61
+ if i == len(item_tuples) - 1 :
62
+ head += '|'
63
+ body += '|'
64
+ border += '+'
65
+
66
+ print(border)
67
+ print(head)
68
+ print(border)
69
+ print(body)
70
+ print(border)
71
+ print(' ')
72
+
73
+
74
+ def time_since(started) :
75
+ elapsed = time.time() - started
76
+ m = int(elapsed // 60)
77
+ s = int(elapsed % 60)
78
+ if m >= 60 :
79
+ h = int(m // 60)
80
+ m = m % 60
81
+ return f'{h}h {m}m {s}s'
82
+ else :
83
+ return f'{m}m {s}s'
84
+
85
+
86
+ def save_attention(attn, path):
87
+ import matplotlib.pyplot as plt
88
+
89
+ fig = plt.figure(figsize=(12, 6))
90
+ plt.imshow(attn.T, interpolation='nearest', aspect='auto')
91
+ fig.savefig(f'{path}.png', bbox_inches='tight')
92
+ plt.close(fig)
93
+
94
+
95
+ def save_spectrogram(M, path, length=None):
96
+ import matplotlib.pyplot as plt
97
+
98
+ M = np.flip(M, axis=0)
99
+ if length : M = M[:, :length]
100
+ fig = plt.figure(figsize=(12, 6))
101
+ plt.imshow(M, interpolation='nearest', aspect='auto')
102
+ fig.savefig(f'{path}.png', bbox_inches='tight')
103
+ plt.close(fig)
104
+
105
+
106
+ def plot(array):
107
+ import matplotlib.pyplot as plt
108
+
109
+ fig = plt.figure(figsize=(30, 5))
110
+ ax = fig.add_subplot(111)
111
+ ax.xaxis.label.set_color('grey')
112
+ ax.yaxis.label.set_color('grey')
113
+ ax.xaxis.label.set_fontsize(23)
114
+ ax.yaxis.label.set_fontsize(23)
115
+ ax.tick_params(axis='x', colors='grey', labelsize=23)
116
+ ax.tick_params(axis='y', colors='grey', labelsize=23)
117
+ plt.plot(array)
118
+
119
+
120
+ def plot_spec(M):
121
+ import matplotlib.pyplot as plt
122
+
123
+ M = np.flip(M, axis=0)
124
+ plt.figure(figsize=(18,4))
125
+ plt.imshow(M, interpolation='nearest', aspect='auto')
126
+ plt.show()
127
+
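
A small terminal sketch for the progress helpers above; the step count and table labels are arbitrary:

    import time
    from vocoder.display import progbar, stream, simple_table

    n = 50
    for step in range(1, n + 1):
        stream("Step %d/%d %s" % (step, n, progbar(step, n, size=16)))
        time.sleep(0.01)
    print()
    simple_table([("Steps", n), ("Bar size", 16)])
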
vocoder/distribution.py ADDED
@@ -0,0 +1,132 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def log_sum_exp(x):
7
+ """ numerically stable log_sum_exp implementation that prevents overflow """
8
+ # TF ordering
9
+ axis = len(x.size()) - 1
10
+ m, _ = torch.max(x, dim=axis)
11
+ m2, _ = torch.max(x, dim=axis, keepdim=True)
12
+ return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
13
+
14
+
15
+ # It is adapted from https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py
16
+ def discretized_mix_logistic_loss(y_hat, y, num_classes=65536,
17
+ log_scale_min=None, reduce=True):
18
+ if log_scale_min is None:
19
+ log_scale_min = float(np.log(1e-14))
20
+ y_hat = y_hat.permute(0,2,1)
21
+ assert y_hat.dim() == 3
22
+ assert y_hat.size(1) % 3 == 0
23
+ nr_mix = y_hat.size(1) // 3
24
+
25
+ # (B x T x C)
26
+ y_hat = y_hat.transpose(1, 2)
27
+
28
+ # unpack parameters. (B, T, num_mixtures) x 3
29
+ logit_probs = y_hat[:, :, :nr_mix]
30
+ means = y_hat[:, :, nr_mix:2 * nr_mix]
31
+ log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix:3 * nr_mix], min=log_scale_min)
32
+
33
+ # B x T x 1 -> B x T x num_mixtures
34
+ y = y.expand_as(means)
35
+
36
+ centered_y = y - means
37
+ inv_stdv = torch.exp(-log_scales)
38
+ plus_in = inv_stdv * (centered_y + 1. / (num_classes - 1))
39
+ cdf_plus = torch.sigmoid(plus_in)
40
+ min_in = inv_stdv * (centered_y - 1. / (num_classes - 1))
41
+ cdf_min = torch.sigmoid(min_in)
42
+
43
+ # log probability for edge case of 0 (before scaling)
44
+ # equivalent: torch.log(F.sigmoid(plus_in))
45
+ log_cdf_plus = plus_in - F.softplus(plus_in)
46
+
47
+ # log probability for edge case of 255 (before scaling)
48
+ # equivalent: (1 - F.sigmoid(min_in)).log()
49
+ log_one_minus_cdf_min = -F.softplus(min_in)
50
+
51
+ # probability for all other cases
52
+ cdf_delta = cdf_plus - cdf_min
53
+
54
+ mid_in = inv_stdv * centered_y
55
+ # log probability in the center of the bin, to be used in extreme cases
56
+ # (not actually used in our code)
57
+ log_pdf_mid = mid_in - log_scales - 2. * F.softplus(mid_in)
58
+
59
+ # tf equivalent
60
+ """
61
+ log_probs = tf.where(x < -0.999, log_cdf_plus,
62
+ tf.where(x > 0.999, log_one_minus_cdf_min,
63
+ tf.where(cdf_delta > 1e-5,
64
+ tf.log(tf.maximum(cdf_delta, 1e-12)),
65
+ log_pdf_mid - np.log(127.5))))
66
+ """
67
+ # TODO: cdf_delta <= 1e-5 actually can happen. How can we choose the value
68
+ # for num_classes=65536 case? 1e-7? not sure..
69
+ inner_inner_cond = (cdf_delta > 1e-5).float()
70
+
71
+ inner_inner_out = inner_inner_cond * \
72
+ torch.log(torch.clamp(cdf_delta, min=1e-12)) + \
73
+ (1. - inner_inner_cond) * (log_pdf_mid - np.log((num_classes - 1) / 2))
74
+ inner_cond = (y > 0.999).float()
75
+ inner_out = inner_cond * log_one_minus_cdf_min + (1. - inner_cond) * inner_inner_out
76
+ cond = (y < -0.999).float()
77
+ log_probs = cond * log_cdf_plus + (1. - cond) * inner_out
78
+
79
+ log_probs = log_probs + F.log_softmax(logit_probs, -1)
80
+
81
+ if reduce:
82
+ return -torch.mean(log_sum_exp(log_probs))
83
+ else:
84
+ return -log_sum_exp(log_probs).unsqueeze(-1)
85
+
86
+
87
+ def sample_from_discretized_mix_logistic(y, log_scale_min=None):
88
+ """
89
+ Sample from discretized mixture of logistic distributions
90
+ Args:
91
+ y (Tensor): B x C x T
92
+ log_scale_min (float): Log scale minimum value
93
+ Returns:
94
+ Tensor: sample in range of [-1, 1].
95
+ """
96
+ if log_scale_min is None:
97
+ log_scale_min = float(np.log(1e-14))
98
+ assert y.size(1) % 3 == 0
99
+ nr_mix = y.size(1) // 3
100
+
101
+ # B x T x C
102
+ y = y.transpose(1, 2)
103
+ logit_probs = y[:, :, :nr_mix]
104
+
105
+ # sample mixture indicator from softmax
106
+ temp = logit_probs.data.new(logit_probs.size()).uniform_(1e-5, 1.0 - 1e-5)
107
+ temp = logit_probs.data - torch.log(- torch.log(temp))
108
+ _, argmax = temp.max(dim=-1)
109
+
110
+ # (B, T) -> (B, T, nr_mix)
111
+ one_hot = to_one_hot(argmax, nr_mix)
112
+ # select logistic parameters
113
+ means = torch.sum(y[:, :, nr_mix:2 * nr_mix] * one_hot, dim=-1)
114
+ log_scales = torch.clamp(torch.sum(
115
+ y[:, :, 2 * nr_mix:3 * nr_mix] * one_hot, dim=-1), min=log_scale_min)
116
+ # sample from logistic & clip to interval
117
+ # we don't actually round to the nearest 8bit value when sampling
118
+ u = means.data.new(means.size()).uniform_(1e-5, 1.0 - 1e-5)
119
+ x = means + torch.exp(log_scales) * (torch.log(u) - torch.log(1. - u))
120
+
121
+ x = torch.clamp(torch.clamp(x, min=-1.), max=1.)
122
+
123
+ return x
124
+
125
+
126
+ def to_one_hot(tensor, n, fill_with=1.):
127
+ # we perform one-hot encoding with respect to the last axis
128
+ one_hot = torch.FloatTensor(tensor.size() + (n,)).zero_()
129
+ if tensor.is_cuda:
130
+ one_hot = one_hot.cuda()
131
+ one_hot.scatter_(len(tensor.size()), tensor.unsqueeze(-1), fill_with)
132
+ return one_hot
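
A shape-level sketch of the two entry points above, driven by random tensors rather than a trained vocoder; batch size, sequence length and mixture count are arbitrary:

    import torch
    from vocoder.distribution import (discretized_mix_logistic_loss,
                                      sample_from_discretized_mix_logistic)

    batch, timesteps, n_mix = 2, 100, 10
    y_hat = torch.randn(batch, 3 * n_mix, timesteps)       # B x C x T with C = 3 * num_mixtures
    samples = sample_from_discretized_mix_logistic(y_hat)  # B x T, values in [-1, 1]

    target = torch.rand(batch, timesteps, 1) * 2 - 1       # B x T x 1 targets in [-1, 1]
    loss = discretized_mix_logistic_loss(y_hat.transpose(1, 2), target)
    print(samples.shape, loss.item())
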