GipAdonimus, akhaliq (HF staff) committed on
Commit 65f61dd
0 Parent(s):

Duplicate from akhaliq/Real-Time-Voice-Cloning


Co-authored-by: AK <akhaliq@users.noreply.huggingface.co>

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +27 -0
  2. 8230_00000.mp3 +0 -0
  3. README.md +39 -0
  4. app.py +61 -0
  5. demo_cli.py +206 -0
  6. demo_toolbox.py +43 -0
  7. encoder/__init__.py +0 -0
  8. encoder/audio.py +117 -0
  9. encoder/config.py +45 -0
  10. encoder/data_objects/__init__.py +2 -0
  11. encoder/data_objects/random_cycler.py +37 -0
  12. encoder/data_objects/speaker.py +40 -0
  13. encoder/data_objects/speaker_batch.py +12 -0
  14. encoder/data_objects/speaker_verification_dataset.py +56 -0
  15. encoder/data_objects/utterance.py +26 -0
  16. encoder/inference.py +178 -0
  17. encoder/model.py +135 -0
  18. encoder/params_data.py +29 -0
  19. encoder/params_model.py +11 -0
  20. encoder/preprocess.py +175 -0
  21. encoder/saved_models/text.txt +0 -0
  22. encoder/train.py +123 -0
  23. encoder/visualizations.py +178 -0
  24. encoder_preprocess.py +70 -0
  25. encoder_train.py +47 -0
  26. packages.txt +5 -0
  27. requirements.txt +19 -0
  28. samples/.DS_Store +0 -0
  29. samples/1320_00000.mp3 +0 -0
  30. samples/3575_00000.mp3 +0 -0
  31. samples/6829_00000.mp3 +0 -0
  32. samples/README.md +22 -0
  33. samples/VCTK.txt +94 -0
  34. samples/p240_00000.mp3 +0 -0
  35. samples/p260_00000.mp3 +0 -0
  36. synthesizer/LICENSE.txt +24 -0
  37. synthesizer/__init__.py +1 -0
  38. synthesizer/audio.py +206 -0
  39. synthesizer/hparams.py +92 -0
  40. synthesizer/inference.py +171 -0
  41. synthesizer/models/tacotron.py +519 -0
  42. synthesizer/preprocess.py +259 -0
  43. synthesizer/saved_models/pretrained/text.txt +0 -0
  44. synthesizer/synthesize.py +97 -0
  45. synthesizer/synthesizer_dataset.py +92 -0
  46. synthesizer/train.py +269 -0
  47. synthesizer/utils/__init__.py +45 -0
  48. synthesizer/utils/_cmudict.py +62 -0
  49. synthesizer/utils/cleaners.py +88 -0
  50. synthesizer/utils/numbers.py +68 -0
.gitattributes ADDED
@@ -0,0 +1,27 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
8230_00000.mp3 ADDED
Binary file (16.1 kB).
 
README.md ADDED
@@ -0,0 +1,39 @@
+ ---
+ title: Real Time Voice Cloning
+ emoji: 📈
+ colorFrom: blue
+ colorTo: red
+ sdk: gradio
+ app_file: app.py
+ sdk_version: 3.17.1
+ pinned: false
+ duplicated_from: akhaliq/Real-Time-Voice-Cloning
+ ---
+
+ # Configuration
+
+ `title`: _string_
+ Display title for the Space
+
+ `emoji`: _string_
+ Space emoji (emoji-only character allowed)
+
+ `colorFrom`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `colorTo`: _string_
+ Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
+
+ `sdk`: _string_
+ Can be either `gradio` or `streamlit`
+
+ `sdk_version`: _string_
+ Only applicable for `streamlit` SDK.
+ See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
+
+ `app_file`: _string_
+ Path to your main application file (which contains either `gradio` or `streamlit` Python code).
+ Path is relative to the root of the repository.
+
+ `pinned`: _boolean_
+ Whether the Space stays on top of your list.
app.py ADDED
@@ -0,0 +1,61 @@
+ import gradio as gr
+ import os
+ import shlex
+ import gdown
+ import uuid
+ import torch
+
+ cpu_param = "--cpu" if not torch.cuda.is_available() else ""
+
+ if (not os.path.exists("synpretrained.pt")):
+     gdown.download("https://drive.google.com/u/0/uc?id=1EqFMIbvxffxtjiVrtykroF6_mUh-5Z3s&export=download&confirm=t",
+                    "synpretrained.pt", quiet=False)
+     gdown.download("https://drive.google.com/uc?export=download&id=1q8mEGwCkFy23KZsinbuvdKAQLqNKbYf1",
+                    "encpretrained.pt", quiet=False)
+     gdown.download("https://drive.google.com/uc?export=download&id=1cf2NO6FtI0jDuy8AV3Xgn6leO6dHjIgu",
+                    "vocpretrained.pt", quiet=False)
+
+
+ def inference(audio_path, text, mic_path=None):
+     if mic_path:
+         audio_path = mic_path
+     output_path = f"/tmp/output_{uuid.uuid4()}.wav"
+     os.system(
+         f"python demo_cli.py --no_sound {cpu_param} --audio_path {audio_path} --text {shlex.quote(text.strip())} --output_path {output_path}")
+     return output_path
+
+
+ title = "Real-Time-Voice-Cloning"
+ description = "Gradio demo for Real-Time-Voice-Cloning: Clone a voice in 5 seconds to generate arbitrary speech in real-time. To use it, simply upload your audio, or click one of the examples to load them. Read more at the links below."
+ article = "<p style='text-align: center'><a href='https://matheo.uliege.be/handle/2268.2/6801' target='_blank'>Real-Time Voice Cloning</a> | <a href='https://github.com/CorentinJ/Real-Time-Voice-Cloning' target='_blank'>Github Repo</a></p>"
+
+ examples = [['test.wav', "This is real time voice cloning on huggingface spaces"]]
+
+
+ def toggle(choice):
+     if choice == "mic":
+         return gr.update(visible=True, value=None), gr.update(visible=False, value=None)
+     else:
+         return gr.update(visible=False, value=None), gr.update(visible=True, value=None)
+
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             radio = gr.Radio(["mic", "file"], value="mic",
+                              label="How would you like to upload your audio?")
+             mic_input = gr.Mic(label="Input", type="filepath", visible=False)
+             audio_file = gr.Audio(
+                 type="filepath", label="Input", visible=True)
+             text_input = gr.Textbox(label="Text")
+         with gr.Column():
+             audio_output = gr.Audio(label="Output")
+
+     gr.Examples(examples, fn=inference, inputs=[audio_file, text_input],
+                 outputs=audio_output, cache_examples=True)
+     btn = gr.Button("Generate")
+     btn.click(inference, inputs=[audio_file,
+                                  text_input, mic_input], outputs=audio_output)
+     radio.change(toggle, radio, [mic_input, audio_file])
+
+ demo.launch(enable_queue=True)
demo_cli.py ADDED
@@ -0,0 +1,206 @@
+ from encoder.params_model import model_embedding_size as speaker_embedding_size
+ from utils.argutils import print_args
+ from utils.modelutils import check_model_paths
+ from synthesizer.inference import Synthesizer
+ from encoder import inference as encoder
+ from vocoder import inference as vocoder
+ from pathlib import Path
+ import numpy as np
+ import soundfile as sf
+ import librosa
+ import argparse
+ import torch
+ import sys
+ import os
+ from audioread.exceptions import NoBackendError
+
+
+ if __name__ == '__main__':
+     ## Info & args
+     parser = argparse.ArgumentParser(
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+     parser.add_argument("-e", "--enc_model_fpath", type=Path,
+                         default="encpretrained.pt",
+                         help="Path to a saved encoder")
+     parser.add_argument("-s", "--syn_model_fpath", type=Path,
+                         default="synpretrained.pt",
+                         help="Path to a saved synthesizer")
+     parser.add_argument("-v", "--voc_model_fpath", type=Path,
+                         default="vocpretrained.pt",
+                         help="Path to a saved vocoder")
+     parser.add_argument("--cpu", action="store_true", help="If True, processing is done on CPU, even when a GPU is available.")
+     parser.add_argument("--no_sound", action="store_true", help="If True, audio won't be played.")
+     parser.add_argument("--seed", type=int, default=None, help="Optional random number seed value to make toolbox deterministic.")
+     parser.add_argument("--no_mp3_support", action="store_true", help="If True, disallows loading mp3 files to prevent audioread errors when ffmpeg is not installed.")
+     parser.add_argument("-audio", "--audio_path", type=Path, required=True,
+                         help="Path to an audio file")
+     parser.add_argument("--text", type=str, required=True, help="Text Input")
+     parser.add_argument("--output_path", type=str, required=True, help="output file path")
+
+     args = parser.parse_args()
+     print_args(args, parser)
+     if not args.no_sound:
+         import sounddevice as sd
+
+     if args.cpu:
+         # Hide GPUs from Pytorch to force CPU processing
+         os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+
+     if not args.no_mp3_support:
+         try:
+             librosa.load("samples/1320_00000.mp3")
+         except NoBackendError:
+             print("Librosa will be unable to open mp3 files if additional software is not installed.\n"
+                   "Please install ffmpeg or add the '--no_mp3_support' option to proceed without support for mp3 files.")
+             exit(-1)
+
+     print("Running a test of your configuration...\n")
+
+     if torch.cuda.is_available():
+         device_id = torch.cuda.current_device()
+         gpu_properties = torch.cuda.get_device_properties(device_id)
+         ## Print some environment information (for debugging purposes)
+         print("Found %d GPUs available. Using GPU %d (%s) of compute capability %d.%d with "
+               "%.1fGb total memory.\n" %
+               (torch.cuda.device_count(),
+                device_id,
+                gpu_properties.name,
+                gpu_properties.major,
+                gpu_properties.minor,
+                gpu_properties.total_memory / 1e9))
+     else:
+         print("Using CPU for inference.\n")
+
+     ## Remind the user to download pretrained models if needed
+     check_model_paths(encoder_path=args.enc_model_fpath,
+                       synthesizer_path=args.syn_model_fpath,
+                       vocoder_path=args.voc_model_fpath)
+
+     ## Load the models one by one.
+     print("Preparing the encoder, the synthesizer and the vocoder...")
+     encoder.load_model(args.enc_model_fpath)
+     synthesizer = Synthesizer(args.syn_model_fpath)
+     vocoder.load_model(args.voc_model_fpath)
+
+
+     ## Run a test
+     # print("Testing your configuration with small inputs.")
+     # # Forward an audio waveform of zeroes that lasts 1 second. Notice how we can get the encoder's
+     # # sampling rate, which may differ.
+     # # If you're unfamiliar with digital audio, know that it is encoded as an array of floats
+     # # (or sometimes integers, but mostly floats in this project) ranging from -1 to 1.
+     # # The sampling rate is the number of values (samples) recorded per second, it is set to
+     # # 16000 for the encoder. Creating an array of length <sampling_rate> will always correspond
+     # # to an audio of 1 second.
+     # print("    Testing the encoder...")
+     # encoder.embed_utterance(np.zeros(encoder.sampling_rate))
+
+     # # Create a dummy embedding. You would normally use the embedding that encoder.embed_utterance
+     # # returns, but here we're going to make one ourselves just for the sake of showing that it's
+     # # possible.
+     # embed = np.random.rand(speaker_embedding_size)
+     # # Embeddings are L2-normalized (this isn't important here, but if you want to make your own
+     # # embeddings it will be).
+     # embed /= np.linalg.norm(embed)
+     # # The synthesizer can handle multiple inputs with batching. Let's create another embedding to
+     # # illustrate that
+     # embeds = [embed, np.zeros(speaker_embedding_size)]
+     # texts = ["test 1", "test 2"]
+     # print("    Testing the synthesizer... (loading the model will output a lot of text)")
+     # mels = synthesizer.synthesize_spectrograms(texts, embeds)
+
+     # # The vocoder synthesizes one waveform at a time, but it's more efficient for long ones. We
+     # # can concatenate the mel spectrograms to a single one.
+     # mel = np.concatenate(mels, axis=1)
+     # # The vocoder can take a callback function to display the generation. More on that later. For
+     # # now we'll simply hide it like this:
+     # no_action = lambda *args: None
+     # print("    Testing the vocoder...")
+     # # For the sake of making this test short, we'll pass a short target length. The target length
+     # # is the length of the wav segments that are processed in parallel. E.g. for audio sampled
+     # # at 16000 Hertz, a target length of 8000 means that the target audio will be cut in chunks of
+     # # 0.5 seconds which will all be generated together. The parameters here are absurdly short, and
+     # # that has a detrimental effect on the quality of the audio. The default parameters are
+     # # recommended in general.
+     # vocoder.infer_waveform(mel, target=200, overlap=50, progress_callback=no_action)
+
+     print("All tests passed! You can now synthesize speech.\n\n")
+
+
+     ## Interactive speech generation
+     print("This is a GUI-less example of interface to SV2TTS. The purpose of this script is to "
+           "show how you can interface this project easily with your own. See the source code for "
+           "an explanation of what is happening.\n")
+
+     print("Interactive generation loop")
+     # while True:
+     # Get the reference audio filepath
+     message = "Reference voice: enter an audio filepath of a voice to be cloned (mp3, " "wav, m4a, flac, ...):\n"
+     in_fpath = args.audio_path
+
+     if in_fpath.suffix.lower() == ".mp3" and args.no_mp3_support:
+         print("Can't use mp3 files, please try again:")
+     ## Computing the embedding
+     # First, we load the wav using the function that the speaker encoder provides. This is
+     # important: there is preprocessing that must be applied.
+
+     # The following two methods are equivalent:
+     # - Directly load from the filepath:
+     preprocessed_wav = encoder.preprocess_wav(in_fpath)
+     # - If the wav is already loaded:
+     original_wav, sampling_rate = librosa.load(str(in_fpath))
+     preprocessed_wav = encoder.preprocess_wav(original_wav, sampling_rate)
+     print("Loaded file successfully")
+
+     # Then we derive the embedding. There are many functions and parameters that the
+     # speaker encoder interfaces. These are mostly for in-depth research. You will typically
+     # only use this function (with its default parameters):
+     embed = encoder.embed_utterance(preprocessed_wav)
+     print("Created the embedding")
+
+
+     ## Generating the spectrogram
+     text = args.text
+
+     # If seed is specified, reset torch seed and force synthesizer reload
+     if args.seed is not None:
+         torch.manual_seed(args.seed)
+         synthesizer = Synthesizer(args.syn_model_fpath)
+
+     # The synthesizer works in batch, so you need to put your data in a list or numpy array
+     texts = [text]
+     embeds = [embed]
+     # If you know what the attention layer alignments are, you can retrieve them here by
+     # passing return_alignments=True
+     specs = synthesizer.synthesize_spectrograms(texts, embeds)
+     spec = specs[0]
+     print("Created the mel spectrogram")
+
+
+     ## Generating the waveform
+     print("Synthesizing the waveform:")
+
+     # If seed is specified, reset torch seed and reload vocoder
+     if args.seed is not None:
+         torch.manual_seed(args.seed)
+         vocoder.load_model(args.voc_model_fpath)
+
+     # Synthesizing the waveform is fairly straightforward. Remember that the longer the
+     # spectrogram, the more time-efficient the vocoder.
+     generated_wav = vocoder.infer_waveform(spec)
+
+
+     ## Post-generation
+     # There's a bug with sounddevice that makes the audio cut one second earlier, so we
+     # pad it.
+     generated_wav = np.pad(generated_wav, (0, synthesizer.sample_rate), mode="constant")
+
+     # Trim excess silences to compensate for gaps in spectrograms (issue #53)
+     generated_wav = encoder.preprocess_wav(generated_wav)
+
+     # Save it on the disk
+     filename = args.output_path
+     print(generated_wav.dtype)
+     sf.write(filename, generated_wav.astype(np.float32), synthesizer.sample_rate)
+     print("\nSaved output as %s\n\n" % filename)
demo_toolbox.py ADDED
@@ -0,0 +1,43 @@
+ from pathlib import Path
+ from toolbox import Toolbox
+ from utils.argutils import print_args
+ from utils.modelutils import check_model_paths
+ import argparse
+ import os
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(
+         description="Runs the toolbox",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter
+     )
+
+     parser.add_argument("-d", "--datasets_root", type=Path, help= \
+         "Path to the directory containing your datasets. See toolbox/__init__.py for a list of "
+         "supported datasets.", default=None)
+     parser.add_argument("-e", "--enc_models_dir", type=Path, default="encoder/saved_models",
+                         help="Directory containing saved encoder models")
+     parser.add_argument("-s", "--syn_models_dir", type=Path, default="synthesizer/saved_models",
+                         help="Directory containing saved synthesizer models")
+     parser.add_argument("-v", "--voc_models_dir", type=Path, default="vocoder/saved_models",
+                         help="Directory containing saved vocoder models")
+     parser.add_argument("--cpu", action="store_true", help=\
+         "If True, processing is done on CPU, even when a GPU is available.")
+     parser.add_argument("--seed", type=int, default=None, help=\
+         "Optional random number seed value to make toolbox deterministic.")
+     parser.add_argument("--no_mp3_support", action="store_true", help=\
+         "If True, no mp3 files are allowed.")
+     args = parser.parse_args()
+     print_args(args, parser)
+
+     if args.cpu:
+         # Hide GPUs from Pytorch to force CPU processing
+         os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+     del args.cpu
+
+     ## Remind the user to download pretrained models if needed
+     check_model_paths(encoder_path=args.enc_models_dir, synthesizer_path=args.syn_models_dir,
+                       vocoder_path=args.voc_models_dir)
+
+     # Launch the toolbox
+     Toolbox(**vars(args))
encoder/__init__.py ADDED
File without changes
encoder/audio.py ADDED
@@ -0,0 +1,117 @@
+ from scipy.ndimage.morphology import binary_dilation
+ from encoder.params_data import *
+ from pathlib import Path
+ from typing import Optional, Union
+ from warnings import warn
+ import numpy as np
+ import librosa
+ import struct
+
+ try:
+     import webrtcvad
+ except:
+     warn("Unable to import 'webrtcvad'. This package enables noise removal and is recommended.")
+     webrtcvad = None
+
+ int16_max = (2 ** 15) - 1
+
+
+ def preprocess_wav(fpath_or_wav: Union[str, Path, np.ndarray],
+                    source_sr: Optional[int] = None,
+                    normalize: Optional[bool] = True,
+                    trim_silence: Optional[bool] = True):
+     """
+     Applies the preprocessing operations used in training the Speaker Encoder to a waveform
+     either on disk or in memory. The waveform will be resampled to match the data hyperparameters.
+
+     :param fpath_or_wav: either a filepath to an audio file (many extensions are supported, not
+     just .wav), either the waveform as a numpy array of floats.
+     :param source_sr: if passing an audio waveform, the sampling rate of the waveform before
+     preprocessing. After preprocessing, the waveform's sampling rate will match the data
+     hyperparameters. If passing a filepath, the sampling rate will be automatically detected and
+     this argument will be ignored.
+     """
+     # Load the wav from disk if needed
+     if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
+         wav, source_sr = librosa.load(str(fpath_or_wav), sr=None)
+     else:
+         wav = fpath_or_wav
+
+     # Resample the wav if needed
+     if source_sr is not None and source_sr != sampling_rate:
+         wav = librosa.resample(wav, source_sr, sampling_rate)
+
+     # Apply the preprocessing: normalize volume and shorten long silences
+     if normalize:
+         wav = normalize_volume(wav, audio_norm_target_dBFS, increase_only=True)
+     if webrtcvad and trim_silence:
+         wav = trim_long_silences(wav)
+
+     return wav
+
+
+ def wav_to_mel_spectrogram(wav):
+     """
+     Derives a mel spectrogram ready to be used by the encoder from a preprocessed audio waveform.
+     Note: this is not a log-mel spectrogram.
+     """
+     frames = librosa.feature.melspectrogram(
+         wav,
+         sampling_rate,
+         n_fft=int(sampling_rate * mel_window_length / 1000),
+         hop_length=int(sampling_rate * mel_window_step / 1000),
+         n_mels=mel_n_channels
+     )
+     return frames.astype(np.float32).T
+
+
+ def trim_long_silences(wav):
+     """
+     Ensures that segments without voice in the waveform remain no longer than a
+     threshold determined by the VAD parameters in params.py.
+
+     :param wav: the raw waveform as a numpy array of floats
+     :return: the same waveform with silences trimmed away (length <= original wav length)
+     """
+     # Compute the voice detection window size
+     samples_per_window = (vad_window_length * sampling_rate) // 1000
+
+     # Trim the end of the audio to have a multiple of the window size
+     wav = wav[:len(wav) - (len(wav) % samples_per_window)]
+
+     # Convert the float waveform to 16-bit mono PCM
+     pcm_wave = struct.pack("%dh" % len(wav), *(np.round(wav * int16_max)).astype(np.int16))
+
+     # Perform voice activation detection
+     voice_flags = []
+     vad = webrtcvad.Vad(mode=3)
+     for window_start in range(0, len(wav), samples_per_window):
+         window_end = window_start + samples_per_window
+         voice_flags.append(vad.is_speech(pcm_wave[window_start * 2:window_end * 2],
+                                          sample_rate=sampling_rate))
+     voice_flags = np.array(voice_flags)
+
+     # Smooth the voice detection with a moving average
+     def moving_average(array, width):
+         array_padded = np.concatenate((np.zeros((width - 1) // 2), array, np.zeros(width // 2)))
+         ret = np.cumsum(array_padded, dtype=float)
+         ret[width:] = ret[width:] - ret[:-width]
+         return ret[width - 1:] / width
+
+     audio_mask = moving_average(voice_flags, vad_moving_average_width)
+     audio_mask = np.round(audio_mask).astype(np.bool)
+
+     # Dilate the voiced regions
+     audio_mask = binary_dilation(audio_mask, np.ones(vad_max_silence_length + 1))
+     audio_mask = np.repeat(audio_mask, samples_per_window)
+
+     return wav[audio_mask == True]
+
+
+ def normalize_volume(wav, target_dBFS, increase_only=False, decrease_only=False):
+     if increase_only and decrease_only:
+         raise ValueError("Both increase only and decrease only are set")
+     dBFS_change = target_dBFS - 10 * np.log10(np.mean(wav ** 2))
+     if (dBFS_change < 0 and increase_only) or (dBFS_change > 0 and decrease_only):
+         return wav
+     return wav * (10 ** (dBFS_change / 20))
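As a worked restatement (not part of the commit) of normalize_volume above: the waveform's current level is 10·log10(mean(wav²)) dBFS, the required change is ΔdB = target_dBFS − current level, and the signal is scaled by the amplitude gain 10^(ΔdB/20). For example, a waveform at −40 dBFS with the default target of −30 dBFS (audio_norm_target_dBFS) gives ΔdB = 10 and a gain of 10^0.5 ≈ 3.16; with increase_only=True, as used in preprocess_wav, a negative ΔdB leaves the signal untouched.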
encoder/config.py ADDED
@@ -0,0 +1,45 @@
+ librispeech_datasets = {
+     "train": {
+         "clean": ["LibriSpeech/train-clean-100", "LibriSpeech/train-clean-360"],
+         "other": ["LibriSpeech/train-other-500"]
+     },
+     "test": {
+         "clean": ["LibriSpeech/test-clean"],
+         "other": ["LibriSpeech/test-other"]
+     },
+     "dev": {
+         "clean": ["LibriSpeech/dev-clean"],
+         "other": ["LibriSpeech/dev-other"]
+     },
+ }
+ libritts_datasets = {
+     "train": {
+         "clean": ["LibriTTS/train-clean-100", "LibriTTS/train-clean-360"],
+         "other": ["LibriTTS/train-other-500"]
+     },
+     "test": {
+         "clean": ["LibriTTS/test-clean"],
+         "other": ["LibriTTS/test-other"]
+     },
+     "dev": {
+         "clean": ["LibriTTS/dev-clean"],
+         "other": ["LibriTTS/dev-other"]
+     },
+ }
+ voxceleb_datasets = {
+     "voxceleb1": {
+         "train": ["VoxCeleb1/wav"],
+         "test": ["VoxCeleb1/test_wav"]
+     },
+     "voxceleb2": {
+         "train": ["VoxCeleb2/dev/aac"],
+         "test": ["VoxCeleb2/test_wav"]
+     }
+ }
+
+ other_datasets = [
+     "LJSpeech-1.1",
+     "VCTK-Corpus/wav48",
+ ]
+
+ anglophone_nationalites = ["australia", "canada", "ireland", "uk", "usa"]
encoder/data_objects/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataLoader
encoder/data_objects/random_cycler.py ADDED
@@ -0,0 +1,37 @@
+ import random
+
+ class RandomCycler:
+     """
+     Creates an internal copy of a sequence and allows access to its items in a constrained random
+     order. For a source sequence of n items and one or several consecutive queries of a total
+     of m items, the following guarantees hold (one implies the other):
+         - Each item will be returned between m // n and ((m - 1) // n) + 1 times.
+         - Between two appearances of the same item, there may be at most 2 * (n - 1) other items.
+     """
+
+     def __init__(self, source):
+         if len(source) == 0:
+             raise Exception("Can't create RandomCycler from an empty collection")
+         self.all_items = list(source)
+         self.next_items = []
+
+     def sample(self, count: int):
+         shuffle = lambda l: random.sample(l, len(l))
+
+         out = []
+         while count > 0:
+             if count >= len(self.all_items):
+                 out.extend(shuffle(list(self.all_items)))
+                 count -= len(self.all_items)
+                 continue
+             n = min(count, len(self.next_items))
+             out.extend(self.next_items[:n])
+             count -= n
+             self.next_items = self.next_items[n:]
+             if len(self.next_items) == 0:
+                 self.next_items = shuffle(list(self.all_items))
+         return out
+
+     def __next__(self):
+         return self.sample(1)[0]
+
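The guarantee in the docstring above can be checked with a small, hypothetical snippet (not part of the commit): over m draws from n items, each item comes back between m // n and ((m - 1) // n) + 1 times.

```python
# Hypothetical check of the RandomCycler guarantee (illustration only).
from collections import Counter
from encoder.data_objects.random_cycler import RandomCycler

cycler = RandomCycler(["a", "b", "c"])   # n = 3 items
counts = Counter(cycler.sample(7))       # m = 7 draws in one query
# Each item should appear 7 // 3 = 2 to ((7 - 1) // 3) + 1 = 3 times.
assert all(2 <= c <= 3 for c in counts.values())
```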
encoder/data_objects/speaker.py ADDED
@@ -0,0 +1,40 @@
+ from encoder.data_objects.random_cycler import RandomCycler
+ from encoder.data_objects.utterance import Utterance
+ from pathlib import Path
+
+ # Contains the set of utterances of a single speaker
+ class Speaker:
+     def __init__(self, root: Path):
+         self.root = root
+         self.name = root.name
+         self.utterances = None
+         self.utterance_cycler = None
+
+     def _load_utterances(self):
+         with self.root.joinpath("_sources.txt").open("r") as sources_file:
+             sources = [l.split(",") for l in sources_file]
+         sources = {frames_fname: wave_fpath for frames_fname, wave_fpath in sources}
+         self.utterances = [Utterance(self.root.joinpath(f), w) for f, w in sources.items()]
+         self.utterance_cycler = RandomCycler(self.utterances)
+
+     def random_partial(self, count, n_frames):
+         """
+         Samples a batch of <count> unique partial utterances from the disk in a way that all
+         utterances come up at least once every two cycles and in a random order every time.
+
+         :param count: The number of partial utterances to sample from the set of utterances from
+         that speaker. Utterances are guaranteed not to be repeated if <count> is not larger than
+         the number of utterances available.
+         :param n_frames: The number of frames in the partial utterance.
+         :return: A list of tuples (utterance, frames, range) where utterance is an Utterance,
+         frames are the frames of the partial utterances and range is the range of the partial
+         utterance with regard to the complete utterance.
+         """
+         if self.utterances is None:
+             self._load_utterances()
+
+         utterances = self.utterance_cycler.sample(count)
+
+         a = [(u,) + u.random_partial(n_frames) for u in utterances]
+
+         return a
encoder/data_objects/speaker_batch.py ADDED
@@ -0,0 +1,12 @@
+ import numpy as np
+ from typing import List
+ from encoder.data_objects.speaker import Speaker
+
+ class SpeakerBatch:
+     def __init__(self, speakers: List[Speaker], utterances_per_speaker: int, n_frames: int):
+         self.speakers = speakers
+         self.partials = {s: s.random_partial(utterances_per_speaker, n_frames) for s in speakers}
+
+         # Array of shape (n_speakers * n_utterances, n_frames, mel_n), e.g. for 3 speakers with
+         # 4 utterances each of 160 frames of 40 mel coefficients: (12, 160, 40)
+         self.data = np.array([frames for s in speakers for _, frames, _ in self.partials[s]])
encoder/data_objects/speaker_verification_dataset.py ADDED
@@ -0,0 +1,56 @@
+ from encoder.data_objects.random_cycler import RandomCycler
+ from encoder.data_objects.speaker_batch import SpeakerBatch
+ from encoder.data_objects.speaker import Speaker
+ from encoder.params_data import partials_n_frames
+ from torch.utils.data import Dataset, DataLoader
+ from pathlib import Path
+
+ # TODO: improve with a pool of speakers for data efficiency
+
+ class SpeakerVerificationDataset(Dataset):
+     def __init__(self, datasets_root: Path):
+         self.root = datasets_root
+         speaker_dirs = [f for f in self.root.glob("*") if f.is_dir()]
+         if len(speaker_dirs) == 0:
+             raise Exception("No speakers found. Make sure you are pointing to the directory "
+                             "containing all preprocessed speaker directories.")
+         self.speakers = [Speaker(speaker_dir) for speaker_dir in speaker_dirs]
+         self.speaker_cycler = RandomCycler(self.speakers)
+
+     def __len__(self):
+         return int(1e10)
+
+     def __getitem__(self, index):
+         return next(self.speaker_cycler)
+
+     def get_logs(self):
+         log_string = ""
+         for log_fpath in self.root.glob("*.txt"):
+             with log_fpath.open("r") as log_file:
+                 log_string += "".join(log_file.readlines())
+         return log_string
+
+
+ class SpeakerVerificationDataLoader(DataLoader):
+     def __init__(self, dataset, speakers_per_batch, utterances_per_speaker, sampler=None,
+                  batch_sampler=None, num_workers=0, pin_memory=False, timeout=0,
+                  worker_init_fn=None):
+         self.utterances_per_speaker = utterances_per_speaker
+
+         super().__init__(
+             dataset=dataset,
+             batch_size=speakers_per_batch,
+             shuffle=False,
+             sampler=sampler,
+             batch_sampler=batch_sampler,
+             num_workers=num_workers,
+             collate_fn=self.collate,
+             pin_memory=pin_memory,
+             drop_last=False,
+             timeout=timeout,
+             worker_init_fn=worker_init_fn
+         )
+
+     def collate(self, speakers):
+         return SpeakerBatch(speakers, self.utterances_per_speaker, partials_n_frames)
+
encoder/data_objects/utterance.py ADDED
@@ -0,0 +1,26 @@
+ import numpy as np
+
+
+ class Utterance:
+     def __init__(self, frames_fpath, wave_fpath):
+         self.frames_fpath = frames_fpath
+         self.wave_fpath = wave_fpath
+
+     def get_frames(self):
+         return np.load(self.frames_fpath)
+
+     def random_partial(self, n_frames):
+         """
+         Crops the frames into a partial utterance of n_frames
+
+         :param n_frames: The number of frames of the partial utterance
+         :return: the partial utterance frames and a tuple indicating the start and end of the
+         partial utterance in the complete utterance.
+         """
+         frames = self.get_frames()
+         if frames.shape[0] == n_frames:
+             start = 0
+         else:
+             start = np.random.randint(0, frames.shape[0] - n_frames)
+         end = start + n_frames
+         return frames[start:end], (start, end)
encoder/inference.py ADDED
@@ -0,0 +1,178 @@
+ from encoder.params_data import *
+ from encoder.model import SpeakerEncoder
+ from encoder.audio import preprocess_wav  # We want to expose this function from here
+ from matplotlib import cm
+ from encoder import audio
+ from pathlib import Path
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+
+ _model = None  # type: SpeakerEncoder
+ _device = None  # type: torch.device
+
+
+ def load_model(weights_fpath: Path, device=None):
+     """
+     Loads the model in memory. If this function is not explicitly called, it will be run on the
+     first call to embed_frames() with the default weights file.
+
+     :param weights_fpath: the path to saved model weights.
+     :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda"). The
+     model will be loaded and will run on this device. Outputs will however always be on the cpu.
+     If None, will default to your GPU if it's available, otherwise your CPU.
+     """
+     # TODO: I think the slow loading of the encoder might have something to do with the device it
+     # was saved on. Worth investigating.
+     global _model, _device
+     if device is None:
+         _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     elif isinstance(device, str):
+         _device = torch.device(device)
+     _model = SpeakerEncoder(_device, torch.device("cpu"))
+     checkpoint = torch.load(weights_fpath, _device)
+     _model.load_state_dict(checkpoint["model_state"])
+     _model.eval()
+     print("Loaded encoder \"%s\" trained to step %d" % (weights_fpath.name, checkpoint["step"]))
+
+
+ def is_loaded():
+     return _model is not None
+
+
+ def embed_frames_batch(frames_batch):
+     """
+     Computes embeddings for a batch of mel spectrograms.
+
+     :param frames_batch: a batch of mel spectrograms as a numpy array of float32 of shape
+     (batch_size, n_frames, n_channels)
+     :return: the embeddings as a numpy array of float32 of shape (batch_size, model_embedding_size)
+     """
+     if _model is None:
+         raise Exception("Model was not loaded. Call load_model() before inference.")
+
+     frames = torch.from_numpy(frames_batch).to(_device)
+     embed = _model.forward(frames).detach().cpu().numpy()
+     return embed
+
+
+ def compute_partial_slices(n_samples, partial_utterance_n_frames=partials_n_frames,
+                            min_pad_coverage=0.75, overlap=0.5):
+     """
+     Computes where to split an utterance waveform and its corresponding mel spectrogram to obtain
+     partial utterances of <partial_utterance_n_frames> each. Both the waveform and the mel
+     spectrogram slices are returned, so as to make each partial utterance waveform correspond to
+     its spectrogram. This function assumes that the mel spectrogram parameters used are those
+     defined in params_data.py.
+
+     The returned ranges may be indexing further than the length of the waveform. It is
+     recommended that you pad the waveform with zeros up to wave_slices[-1].stop.
+
+     :param n_samples: the number of samples in the waveform
+     :param partial_utterance_n_frames: the number of mel spectrogram frames in each partial
+     utterance
+     :param min_pad_coverage: when reaching the last partial utterance, it may or may not have
+     enough frames. If at least <min_pad_coverage> of <partial_utterance_n_frames> are present,
+     then the last partial utterance will be considered, as if we padded the audio. Otherwise,
+     it will be discarded, as if we trimmed the audio. If there aren't enough frames for 1 partial
+     utterance, this parameter is ignored so that the function always returns at least 1 slice.
+     :param overlap: by how much the partial utterance should overlap. If set to 0, the partial
+     utterances are entirely disjoint.
+     :return: the waveform slices and mel spectrogram slices as lists of array slices. Index
+     respectively the waveform and the mel spectrogram with these slices to obtain the partial
+     utterances.
+     """
+     assert 0 <= overlap < 1
+     assert 0 < min_pad_coverage <= 1
+
+     samples_per_frame = int((sampling_rate * mel_window_step / 1000))
+     n_frames = int(np.ceil((n_samples + 1) / samples_per_frame))
+     frame_step = max(int(np.round(partial_utterance_n_frames * (1 - overlap))), 1)
+
+     # Compute the slices
+     wav_slices, mel_slices = [], []
+     steps = max(1, n_frames - partial_utterance_n_frames + frame_step + 1)
+     for i in range(0, steps, frame_step):
+         mel_range = np.array([i, i + partial_utterance_n_frames])
+         wav_range = mel_range * samples_per_frame
+         mel_slices.append(slice(*mel_range))
+         wav_slices.append(slice(*wav_range))
+
+     # Evaluate whether extra padding is warranted or not
+     last_wav_range = wav_slices[-1]
+     coverage = (n_samples - last_wav_range.start) / (last_wav_range.stop - last_wav_range.start)
+     if coverage < min_pad_coverage and len(mel_slices) > 1:
+         mel_slices = mel_slices[:-1]
+         wav_slices = wav_slices[:-1]
+
+     return wav_slices, mel_slices
+
+
+ def embed_utterance(wav, using_partials=True, return_partials=False, **kwargs):
+     """
+     Computes an embedding for a single utterance.
+
+     # TODO: handle multiple wavs to benefit from batching on GPU
+     :param wav: a preprocessed (see audio.py) utterance waveform as a numpy array of float32
+     :param using_partials: if True, then the utterance is split in partial utterances of
+     <partial_utterance_n_frames> frames and the utterance embedding is computed from their
+     normalized average. If False, the utterance is instead computed from feeding the entire
+     spectrogram to the network.
+     :param return_partials: if True, the partial embeddings will also be returned along with the
+     wav slices that correspond to the partial embeddings.
+     :param kwargs: additional arguments to compute_partial_splits()
+     :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
+     <return_partials> is True, the partial utterances as a numpy array of float32 of shape
+     (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
+     returned. If <using_partials> is simultaneously set to False, both these values will be None
+     instead.
+     """
+     # Process the entire utterance if not using partials
+     if not using_partials:
+         frames = audio.wav_to_mel_spectrogram(wav)
+         embed = embed_frames_batch(frames[None, ...])[0]
+         if return_partials:
+             return embed, None, None
+         return embed
+
+     # Compute where to split the utterance into partials and pad if necessary
+     wave_slices, mel_slices = compute_partial_slices(len(wav), **kwargs)
+     max_wave_length = wave_slices[-1].stop
+     if max_wave_length >= len(wav):
+         wav = np.pad(wav, (0, max_wave_length - len(wav)), "constant")
+
+     # Split the utterance into partials
+     frames = audio.wav_to_mel_spectrogram(wav)
+     frames_batch = np.array([frames[s] for s in mel_slices])
+     partial_embeds = embed_frames_batch(frames_batch)
+
+     # Compute the utterance embedding from the partial embeddings
+     raw_embed = np.mean(partial_embeds, axis=0)
+     embed = raw_embed / np.linalg.norm(raw_embed, 2)
+
+     if return_partials:
+         return embed, partial_embeds, wave_slices
+     return embed
+
+
+ def embed_speaker(wavs, **kwargs):
+     raise NotImplementedError()
+
+
+ def plot_embedding_as_heatmap(embed, ax=None, title="", shape=None, color_range=(0, 0.30)):
+     if ax is None:
+         ax = plt.gca()
+
+     if shape is None:
+         height = int(np.sqrt(len(embed)))
+         shape = (height, -1)
+     embed = embed.reshape(shape)
+
+     cmap = cm.get_cmap()
+     mappable = ax.imshow(embed, cmap=cmap)
+     cbar = plt.colorbar(mappable, ax=ax, fraction=0.046, pad=0.04)
+     sm = cm.ScalarMappable(cmap=cmap)
+     sm.set_clim(*color_range)
+
+     ax.set_xticks([]), ax.set_yticks([])
+     ax.set_title(title)
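To make the slicing above concrete, a hypothetical check (not part of the commit): with the data parameters used in this encoder (sampling_rate = 16000, mel_window_step = 10 ms, partials_n_frames = 160, overlap = 0.5), a 3-second utterance is covered by three partials whose mel slices advance by 80 frames (0.8 s).

```python
# Hypothetical check of compute_partial_slices for a 3 s utterance (illustration only).
from encoder.inference import compute_partial_slices

wav_slices, mel_slices = compute_partial_slices(3 * 16000)
print([(s.start, s.stop) for s in mel_slices])   # [(0, 160), (80, 240), (160, 320)]
print(wav_slices[-1].stop)                       # 51200 samples: pad the wav up to this length
```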
encoder/model.py ADDED
@@ -0,0 +1,135 @@
+ from encoder.params_model import *
+ from encoder.params_data import *
+ from scipy.interpolate import interp1d
+ from sklearn.metrics import roc_curve
+ from torch.nn.utils import clip_grad_norm_
+ from scipy.optimize import brentq
+ from torch import nn
+ import numpy as np
+ import torch
+
+
+ class SpeakerEncoder(nn.Module):
+     def __init__(self, device, loss_device):
+         super().__init__()
+         self.loss_device = loss_device
+
+         # Network definition
+         self.lstm = nn.LSTM(input_size=mel_n_channels,
+                             hidden_size=model_hidden_size,
+                             num_layers=model_num_layers,
+                             batch_first=True).to(device)
+         self.linear = nn.Linear(in_features=model_hidden_size,
+                                 out_features=model_embedding_size).to(device)
+         self.relu = torch.nn.ReLU().to(device)
+
+         # Cosine similarity scaling (with fixed initial parameter values)
+         self.similarity_weight = nn.Parameter(torch.tensor([10.])).to(loss_device)
+         self.similarity_bias = nn.Parameter(torch.tensor([-5.])).to(loss_device)
+
+         # Loss
+         self.loss_fn = nn.CrossEntropyLoss().to(loss_device)
+
+     def do_gradient_ops(self):
+         # Gradient scale
+         self.similarity_weight.grad *= 0.01
+         self.similarity_bias.grad *= 0.01
+
+         # Gradient clipping
+         clip_grad_norm_(self.parameters(), 3, norm_type=2)
+
+     def forward(self, utterances, hidden_init=None):
+         """
+         Computes the embeddings of a batch of utterance spectrograms.
+
+         :param utterances: batch of mel-scale filterbanks of same duration as a tensor of shape
+         (batch_size, n_frames, n_channels)
+         :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers,
+         batch_size, hidden_size). Will default to a tensor of zeros if None.
+         :return: the embeddings as a tensor of shape (batch_size, embedding_size)
+         """
+         # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
+         # and the final cell state.
+         out, (hidden, cell) = self.lstm(utterances, hidden_init)
+
+         # We take only the hidden state of the last layer
+         embeds_raw = self.relu(self.linear(hidden[-1]))
+
+         # L2-normalize it
+         embeds = embeds_raw / (torch.norm(embeds_raw, dim=1, keepdim=True) + 1e-5)
+
+         return embeds
+
+     def similarity_matrix(self, embeds):
+         """
+         Computes the similarity matrix according to section 2.1 of GE2E.
+
+         :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
+         utterances_per_speaker, embedding_size)
+         :return: the similarity matrix as a tensor of shape (speakers_per_batch,
+         utterances_per_speaker, speakers_per_batch)
+         """
+         speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
+
+         # Inclusive centroids (1 per speaker). Cloning is needed for reverse differentiation
+         centroids_incl = torch.mean(embeds, dim=1, keepdim=True)
+         centroids_incl = centroids_incl.clone() / (torch.norm(centroids_incl, dim=2, keepdim=True) + 1e-5)
+
+         # Exclusive centroids (1 per utterance)
+         centroids_excl = (torch.sum(embeds, dim=1, keepdim=True) - embeds)
+         centroids_excl /= (utterances_per_speaker - 1)
+         centroids_excl = centroids_excl.clone() / (torch.norm(centroids_excl, dim=2, keepdim=True) + 1e-5)
+
+         # Similarity matrix. The cosine similarity of already 2-normed vectors is simply the dot
+         # product of these vectors (which is just an element-wise multiplication reduced by a sum).
+         # We vectorize the computation for efficiency.
+         sim_matrix = torch.zeros(speakers_per_batch, utterances_per_speaker,
+                                  speakers_per_batch).to(self.loss_device)
+         mask_matrix = 1 - np.eye(speakers_per_batch, dtype=np.int)
+         for j in range(speakers_per_batch):
+             mask = np.where(mask_matrix[j])[0]
+             sim_matrix[mask, :, j] = (embeds[mask] * centroids_incl[j]).sum(dim=2)
+             sim_matrix[j, :, j] = (embeds[j] * centroids_excl[j]).sum(dim=1)
+
+         ## Even more vectorized version (slower maybe because of transpose)
+         # sim_matrix2 = torch.zeros(speakers_per_batch, speakers_per_batch, utterances_per_speaker
+         #                           ).to(self.loss_device)
+         # eye = np.eye(speakers_per_batch, dtype=np.int)
+         # mask = np.where(1 - eye)
+         # sim_matrix2[mask] = (embeds[mask[0]] * centroids_incl[mask[1]]).sum(dim=2)
+         # mask = np.where(eye)
+         # sim_matrix2[mask] = (embeds * centroids_excl).sum(dim=2)
+         # sim_matrix2 = sim_matrix2.transpose(1, 2)
+
+         sim_matrix = sim_matrix * self.similarity_weight + self.similarity_bias
+         return sim_matrix
+
+     def loss(self, embeds):
+         """
+         Computes the softmax loss according to section 2.1 of GE2E.
+
+         :param embeds: the embeddings as a tensor of shape (speakers_per_batch,
+         utterances_per_speaker, embedding_size)
+         :return: the loss and the EER for this batch of embeddings.
+         """
+         speakers_per_batch, utterances_per_speaker = embeds.shape[:2]
+
+         # Loss
+         sim_matrix = self.similarity_matrix(embeds)
+         sim_matrix = sim_matrix.reshape((speakers_per_batch * utterances_per_speaker,
+                                          speakers_per_batch))
+         ground_truth = np.repeat(np.arange(speakers_per_batch), utterances_per_speaker)
+         target = torch.from_numpy(ground_truth).long().to(self.loss_device)
+         loss = self.loss_fn(sim_matrix, target)
+
+         # EER (not backpropagated)
+         with torch.no_grad():
+             inv_argmax = lambda i: np.eye(1, speakers_per_batch, i, dtype=np.int)[0]
+             labels = np.array([inv_argmax(i) for i in ground_truth])
+             preds = sim_matrix.detach().cpu().numpy()
+
+             # Snippet from https://yangcha.github.io/EER-ROC/
+             fpr, tpr, thresholds = roc_curve(labels.flatten(), preds.flatten())
+             eer = brentq(lambda x: 1. - x - interp1d(fpr, tpr)(x), 0., 1.)
+
+         return loss, eer
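Restated as equations (a summary of the code above, not part of the commit): for speaker j, utterance i and candidate speaker k, the scaled cosine similarity and the softmax loss of GE2E are

$$
S_{j,i,k} = \begin{cases} w\,\cos\!\big(e_{j,i},\, c_j^{(-i)}\big) + b, & k = j \\ w\,\cos\!\big(e_{j,i},\, c_k\big) + b, & k \neq j \end{cases}
\qquad
L = \sum_{j,i} -\log \frac{\exp(S_{j,i,j})}{\sum_{k} \exp(S_{j,i,k})}
$$

where c_k is the length-normalized mean (inclusive centroid) of speaker k's embeddings, c_j^(-i) is the exclusive centroid that leaves out utterance i, and w, b are the learned similarity_weight and similarity_bias, initialized above to 10 and −5.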
encoder/params_data.py ADDED
@@ -0,0 +1,29 @@
+
+ ## Mel-filterbank
+ mel_window_length = 25  # In milliseconds
+ mel_window_step = 10    # In milliseconds
+ mel_n_channels = 40
+
+
+ ## Audio
+ sampling_rate = 16000
+ # Number of spectrogram frames in a partial utterance
+ partials_n_frames = 160  # 1600 ms
+ # Number of spectrogram frames at inference
+ inference_n_frames = 80  # 800 ms
+
+
+ ## Voice Activation Detection
+ # Window size of the VAD. Must be either 10, 20 or 30 milliseconds.
+ # This sets the granularity of the VAD. Should not need to be changed.
+ vad_window_length = 30  # In milliseconds
+ # Number of frames to average together when performing the moving average smoothing.
+ # The larger this value, the larger the VAD variations must be to not get smoothed out.
+ vad_moving_average_width = 8
+ # Maximum number of consecutive silent frames a segment can have.
+ vad_max_silence_length = 6
+
+
+ ## Audio volume normalization
+ audio_norm_target_dBFS = -30
+
encoder/params_model.py ADDED
@@ -0,0 +1,11 @@
+
+ ## Model parameters
+ model_hidden_size = 256
+ model_embedding_size = 256
+ model_num_layers = 3
+
+
+ ## Training parameters
+ learning_rate_init = 1e-4
+ speakers_per_batch = 64
+ utterances_per_speaker = 10
encoder/preprocess.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from multiprocess.pool import ThreadPool
2
+ from encoder.params_data import *
3
+ from encoder.config import librispeech_datasets, anglophone_nationalites
4
+ from datetime import datetime
5
+ from encoder import audio
6
+ from pathlib import Path
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+
10
+
11
+ class DatasetLog:
12
+ """
13
+ Registers metadata about the dataset in a text file.
14
+ """
15
+ def __init__(self, root, name):
16
+ self.text_file = open(Path(root, "Log_%s.txt" % name.replace("/", "_")), "w")
17
+ self.sample_data = dict()
18
+
19
+ start_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
20
+ self.write_line("Creating dataset %s on %s" % (name, start_time))
21
+ self.write_line("-----")
22
+ self._log_params()
23
+
24
+ def _log_params(self):
25
+ from encoder import params_data
26
+ self.write_line("Parameter values:")
27
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
28
+ value = getattr(params_data, param_name)
29
+ self.write_line("\t%s: %s" % (param_name, value))
30
+ self.write_line("-----")
31
+
32
+ def write_line(self, line):
33
+ self.text_file.write("%s\n" % line)
34
+
35
+ def add_sample(self, **kwargs):
36
+ for param_name, value in kwargs.items():
37
+ if not param_name in self.sample_data:
38
+ self.sample_data[param_name] = []
39
+ self.sample_data[param_name].append(value)
40
+
41
+ def finalize(self):
42
+ self.write_line("Statistics:")
43
+ for param_name, values in self.sample_data.items():
44
+ self.write_line("\t%s:" % param_name)
45
+ self.write_line("\t\tmin %.3f, max %.3f" % (np.min(values), np.max(values)))
46
+ self.write_line("\t\tmean %.3f, median %.3f" % (np.mean(values), np.median(values)))
47
+ self.write_line("-----")
48
+ end_time = str(datetime.now().strftime("%A %d %B %Y at %H:%M"))
49
+ self.write_line("Finished on %s" % end_time)
50
+ self.text_file.close()
51
+
52
+
53
+ def _init_preprocess_dataset(dataset_name, datasets_root, out_dir) -> (Path, DatasetLog):
54
+ dataset_root = datasets_root.joinpath(dataset_name)
55
+ if not dataset_root.exists():
56
+ print("Couldn\'t find %s, skipping this dataset." % dataset_root)
57
+ return None, None
58
+ return dataset_root, DatasetLog(out_dir, dataset_name)
59
+
60
+
61
+ def _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, extension,
62
+ skip_existing, logger):
63
+ print("%s: Preprocessing data for %d speakers." % (dataset_name, len(speaker_dirs)))
64
+
65
+ # Function to preprocess utterances for one speaker
66
+ def preprocess_speaker(speaker_dir: Path):
67
+ # Give a name to the speaker that includes its dataset
68
+ speaker_name = "_".join(speaker_dir.relative_to(datasets_root).parts)
69
+
70
+ # Create an output directory with that name, as well as a txt file containing a
71
+ # reference to each source file.
72
+ speaker_out_dir = out_dir.joinpath(speaker_name)
73
+ speaker_out_dir.mkdir(exist_ok=True)
74
+ sources_fpath = speaker_out_dir.joinpath("_sources.txt")
75
+
76
+ # There's a possibility that the preprocessing was interrupted earlier, check if
77
+ # there already is a sources file.
78
+ if sources_fpath.exists():
79
+ try:
80
+ with sources_fpath.open("r") as sources_file:
81
+ existing_fnames = {line.split(",")[0] for line in sources_file}
82
+ except:
83
+ existing_fnames = {}
84
+ else:
85
+ existing_fnames = {}
86
+
87
+ # Gather all audio files for that speaker recursively
88
+ sources_file = sources_fpath.open("a" if skip_existing else "w")
89
+ for in_fpath in speaker_dir.glob("**/*.%s" % extension):
90
+ # Check if the target output file already exists
91
+ out_fname = "_".join(in_fpath.relative_to(speaker_dir).parts)
92
+ out_fname = out_fname.replace(".%s" % extension, ".npy")
93
+ if skip_existing and out_fname in existing_fnames:
94
+ continue
95
+
96
+ # Load and preprocess the waveform
97
+ wav = audio.preprocess_wav(in_fpath)
98
+ if len(wav) == 0:
99
+ continue
100
+
101
+ # Create the mel spectrogram, discard those that are too short
102
+ frames = audio.wav_to_mel_spectrogram(wav)
103
+ if len(frames) < partials_n_frames:
104
+ continue
105
+
106
+ out_fpath = speaker_out_dir.joinpath(out_fname)
107
+ np.save(out_fpath, frames)
108
+ logger.add_sample(duration=len(wav) / sampling_rate)
109
+ sources_file.write("%s,%s\n" % (out_fname, in_fpath))
110
+
111
+ sources_file.close()
112
+
113
+ # Process the utterances for each speaker
114
+ with ThreadPool(8) as pool:
115
+ list(tqdm(pool.imap(preprocess_speaker, speaker_dirs), dataset_name, len(speaker_dirs),
116
+ unit="speakers"))
117
+ logger.finalize()
118
+ print("Done preprocessing %s.\n" % dataset_name)
119
+
120
+
121
+ def preprocess_librispeech(datasets_root: Path, out_dir: Path, skip_existing=False):
122
+ for dataset_name in librispeech_datasets["train"]["other"]:
123
+ # Initialize the preprocessing
124
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
125
+ if not dataset_root:
126
+ return
127
+
128
+ # Preprocess all speakers
129
+ speaker_dirs = list(dataset_root.glob("*"))
130
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "flac",
131
+ skip_existing, logger)
132
+
133
+
134
+ def preprocess_voxceleb1(datasets_root: Path, out_dir: Path, skip_existing=False):
135
+ # Initialize the preprocessing
136
+ dataset_name = "VoxCeleb1"
137
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
138
+ if not dataset_root:
139
+ return
140
+
141
+ # Get the contents of the meta file
142
+ with dataset_root.joinpath("vox1_meta.csv").open("r") as metafile:
143
+ metadata = [line.split("\t") for line in metafile][1:]
144
+
145
+ # Select the ID and the nationality, filter out non-anglophone speakers
146
+ nationalities = {line[0]: line[3] for line in metadata}
147
+ keep_speaker_ids = [speaker_id for speaker_id, nationality in nationalities.items() if
148
+ nationality.lower() in anglophone_nationalites]
149
+ print("VoxCeleb1: using samples from %d (presumed anglophone) speakers out of %d." %
150
+ (len(keep_speaker_ids), len(nationalities)))
151
+
152
+ # Get the speaker directories for anglophone speakers only
153
+ speaker_dirs = dataset_root.joinpath("wav").glob("*")
154
+ speaker_dirs = [speaker_dir for speaker_dir in speaker_dirs if
155
+ speaker_dir.name in keep_speaker_ids]
156
+ print("VoxCeleb1: found %d anglophone speakers on the disk, %d missing (this is normal)." %
157
+ (len(speaker_dirs), len(keep_speaker_ids) - len(speaker_dirs)))
158
+
159
+ # Preprocess all speakers
160
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "wav",
161
+ skip_existing, logger)
162
+
163
+
164
+ def preprocess_voxceleb2(datasets_root: Path, out_dir: Path, skip_existing=False):
165
+ # Initialize the preprocessing
166
+ dataset_name = "VoxCeleb2"
167
+ dataset_root, logger = _init_preprocess_dataset(dataset_name, datasets_root, out_dir)
168
+ if not dataset_root:
169
+ return
170
+
171
+ # Get the speaker directories
172
+ # Preprocess all speakers
173
+ speaker_dirs = list(dataset_root.joinpath("dev", "aac").glob("*"))
174
+ _preprocess_speaker_dirs(speaker_dirs, dataset_name, datasets_root, out_dir, "m4a",
175
+ skip_existing, logger)
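
For orientation, here is a minimal sketch of driving these preprocessing entry points directly from Python; the `encoder_preprocess.py` script further down wraps the same calls behind a CLI. The dataset root path below is hypothetical.

```
from pathlib import Path
from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1

# Hypothetical layout: <datasets_root>/LibriSpeech/train-other-500 and <datasets_root>/VoxCeleb1/wav
datasets_root = Path("/data/datasets")
out_dir = datasets_root.joinpath("SV2TTS", "encoder")
out_dir.mkdir(parents=True, exist_ok=True)

# Each call writes one .npy mel spectrogram per utterance under out_dir,
# skipping utterances that are too short and logging source files as it goes.
preprocess_librispeech(datasets_root, out_dir, skip_existing=True)
preprocess_voxceleb1(datasets_root, out_dir, skip_existing=True)
```
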
encoder/saved_models/text.txt ADDED
File without changes
encoder/train.py ADDED
@@ -0,0 +1,123 @@
1
+ from encoder.visualizations import Visualizations
2
+ from encoder.data_objects import SpeakerVerificationDataLoader, SpeakerVerificationDataset
3
+ from encoder.params_model import *
4
+ from encoder.model import SpeakerEncoder
5
+ from utils.profiler import Profiler
6
+ from pathlib import Path
7
+ import torch
8
+
9
+ def sync(device: torch.device):
10
+ # For correct profiling (cuda operations are async)
11
+ if device.type == "cuda":
12
+ torch.cuda.synchronize(device)
13
+
14
+
15
+ def train(run_id: str, clean_data_root: Path, models_dir: Path, umap_every: int, save_every: int,
16
+ backup_every: int, vis_every: int, force_restart: bool, visdom_server: str,
17
+ no_visdom: bool):
18
+ # Create a dataset and a dataloader
19
+ dataset = SpeakerVerificationDataset(clean_data_root)
20
+ loader = SpeakerVerificationDataLoader(
21
+ dataset,
22
+ speakers_per_batch,
23
+ utterances_per_speaker,
24
+ num_workers=8,
25
+ )
26
+
27
+ # Setup the device on which to run the forward pass and the loss. These can be different,
28
+ # because the forward pass is faster on the GPU whereas the loss is often (depending on your
29
+ # hyperparameters) faster on the CPU.
30
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
31
+ # FIXME: currently, the gradient is None if loss_device is cuda
32
+ loss_device = torch.device("cpu")
33
+
34
+ # Create the model and the optimizer
35
+ model = SpeakerEncoder(device, loss_device)
36
+ optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate_init)
37
+ init_step = 1
38
+
39
+ # Configure file path for the model
40
+ state_fpath = models_dir.joinpath(run_id + ".pt")
41
+ backup_dir = models_dir.joinpath(run_id + "_backups")
42
+
43
+ # Load any existing model
44
+ if not force_restart:
45
+ if state_fpath.exists():
46
+ print("Found existing model \"%s\", loading it and resuming training." % run_id)
47
+ checkpoint = torch.load(state_fpath)
48
+ init_step = checkpoint["step"]
49
+ model.load_state_dict(checkpoint["model_state"])
50
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
51
+ optimizer.param_groups[0]["lr"] = learning_rate_init
52
+ else:
53
+ print("No model \"%s\" found, starting training from scratch." % run_id)
54
+ else:
55
+ print("Starting the training from scratch.")
56
+ model.train()
57
+
58
+ # Initialize the visualization environment
59
+ vis = Visualizations(run_id, vis_every, server=visdom_server, disabled=no_visdom)
60
+ vis.log_dataset(dataset)
61
+ vis.log_params()
62
+ device_name = str(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
63
+ vis.log_implementation({"Device": device_name})
64
+
65
+ # Training loop
66
+ profiler = Profiler(summarize_every=10, disabled=False)
67
+ for step, speaker_batch in enumerate(loader, init_step):
68
+ profiler.tick("Blocking, waiting for batch (threaded)")
69
+
70
+ # Forward pass
71
+ inputs = torch.from_numpy(speaker_batch.data).to(device)
72
+ sync(device)
73
+ profiler.tick("Data to %s" % device)
74
+ embeds = model(inputs)
75
+ sync(device)
76
+ profiler.tick("Forward pass")
77
+ embeds_loss = embeds.view((speakers_per_batch, utterances_per_speaker, -1)).to(loss_device)
78
+ loss, eer = model.loss(embeds_loss)
79
+ sync(loss_device)
80
+ profiler.tick("Loss")
81
+
82
+ # Backward pass
83
+ model.zero_grad()
84
+ loss.backward()
85
+ profiler.tick("Backward pass")
86
+ model.do_gradient_ops()
87
+ optimizer.step()
88
+ profiler.tick("Parameter update")
89
+
90
+ # Update visualizations
91
+ # learning_rate = optimizer.param_groups[0]["lr"]
92
+ vis.update(loss.item(), eer, step)
93
+
94
+ # Draw projections and save them to the backup folder
95
+ if umap_every != 0 and step % umap_every == 0:
96
+ print("Drawing and saving projections (step %d)" % step)
97
+ backup_dir.mkdir(exist_ok=True)
98
+ projection_fpath = backup_dir.joinpath("%s_umap_%06d.png" % (run_id, step))
99
+ embeds = embeds.detach().cpu().numpy()
100
+ vis.draw_projections(embeds, utterances_per_speaker, step, projection_fpath)
101
+ vis.save()
102
+
103
+ # Overwrite the latest version of the model
104
+ if save_every != 0 and step % save_every == 0:
105
+ print("Saving the model (step %d)" % step)
106
+ torch.save({
107
+ "step": step + 1,
108
+ "model_state": model.state_dict(),
109
+ "optimizer_state": optimizer.state_dict(),
110
+ }, state_fpath)
111
+
112
+ # Make a backup
113
+ if backup_every != 0 and step % backup_every == 0:
114
+ print("Making a backup (step %d)" % step)
115
+ backup_dir.mkdir(exist_ok=True)
116
+ backup_fpath = backup_dir.joinpath("%s_bak_%06d.pt" % (run_id, step))
117
+ torch.save({
118
+ "step": step + 1,
119
+ "model_state": model.state_dict(),
120
+ "optimizer_state": optimizer.state_dict(),
121
+ }, backup_fpath)
122
+
123
+ profiler.tick("Extras (visualizations, saving)")
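
The checkpoints written above are plain dictionaries with `step`, `model_state` and `optimizer_state` keys, so they can be reloaded outside the training loop. A minimal sketch, assuming a run id of `my_run` and the default models directory (both hypothetical here):

```
from pathlib import Path
import torch
from encoder.model import SpeakerEncoder

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SpeakerEncoder(device, torch.device("cpu"))  # same (device, loss_device) pair as in train()

# Hypothetical checkpoint produced by train() with run_id="my_run"
checkpoint = torch.load(Path("encoder/saved_models/my_run.pt"), map_location=device)
model.load_state_dict(checkpoint["model_state"])
model.eval()
print("Encoder restored at step", checkpoint["step"])
```
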
encoder/visualizations.py ADDED
@@ -0,0 +1,178 @@
1
+ from encoder.data_objects.speaker_verification_dataset import SpeakerVerificationDataset
2
+ from datetime import datetime
3
+ from time import perf_counter as timer
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ # import webbrowser
7
+ import visdom
8
+ import umap
9
+
10
+ colormap = np.array([
11
+ [76, 255, 0],
12
+ [0, 127, 70],
13
+ [255, 0, 0],
14
+ [255, 217, 38],
15
+ [0, 135, 255],
16
+ [165, 0, 165],
17
+ [255, 167, 255],
18
+ [0, 255, 255],
19
+ [255, 96, 38],
20
+ [142, 76, 0],
21
+ [33, 0, 127],
22
+ [0, 0, 0],
23
+ [183, 183, 183],
24
+ ], dtype=float) / 255  # np.float was removed in NumPy 1.24; the builtin float is equivalent here
25
+
26
+
27
+ class Visualizations:
28
+ def __init__(self, env_name=None, update_every=10, server="http://localhost", disabled=False):
29
+ # Tracking data
30
+ self.last_update_timestamp = timer()
31
+ self.update_every = update_every
32
+ self.step_times = []
33
+ self.losses = []
34
+ self.eers = []
35
+ print("Updating the visualizations every %d steps." % update_every)
36
+
37
+ # If visdom is disabled TODO: use a better paradigm for that
38
+ self.disabled = disabled
39
+ if self.disabled:
40
+ return
41
+
42
+ # Set the environment name
43
+ now = str(datetime.now().strftime("%d-%m %Hh%M"))
44
+ if env_name is None:
45
+ self.env_name = now
46
+ else:
47
+ self.env_name = "%s (%s)" % (env_name, now)
48
+
49
+ # Connect to visdom and open the corresponding window in the browser
50
+ try:
51
+ self.vis = visdom.Visdom(server, env=self.env_name, raise_exceptions=True)
52
+ except ConnectionError:
53
+ raise Exception("No visdom server detected. Run the command \"visdom\" in your CLI to "
54
+ "start it.")
55
+ # webbrowser.open("http://localhost:8097/env/" + self.env_name)
56
+
57
+ # Create the windows
58
+ self.loss_win = None
59
+ self.eer_win = None
60
+ # self.lr_win = None
61
+ self.implementation_win = None
62
+ self.projection_win = None
63
+ self.implementation_string = ""
64
+
65
+ def log_params(self):
66
+ if self.disabled:
67
+ return
68
+ from encoder import params_data
69
+ from encoder import params_model
70
+ param_string = "<b>Model parameters</b>:<br>"
71
+ for param_name in (p for p in dir(params_model) if not p.startswith("__")):
72
+ value = getattr(params_model, param_name)
73
+ param_string += "\t%s: %s<br>" % (param_name, value)
74
+ param_string += "<b>Data parameters</b>:<br>"
75
+ for param_name in (p for p in dir(params_data) if not p.startswith("__")):
76
+ value = getattr(params_data, param_name)
77
+ param_string += "\t%s: %s<br>" % (param_name, value)
78
+ self.vis.text(param_string, opts={"title": "Parameters"})
79
+
80
+ def log_dataset(self, dataset: SpeakerVerificationDataset):
81
+ if self.disabled:
82
+ return
83
+ dataset_string = ""
84
+ dataset_string += "<b>Speakers</b>: %s\n" % len(dataset.speakers)
85
+ dataset_string += "\n" + dataset.get_logs()
86
+ dataset_string = dataset_string.replace("\n", "<br>")
87
+ self.vis.text(dataset_string, opts={"title": "Dataset"})
88
+
89
+ def log_implementation(self, params):
90
+ if self.disabled:
91
+ return
92
+ implementation_string = ""
93
+ for param, value in params.items():
94
+ implementation_string += "<b>%s</b>: %s\n" % (param, value)
95
+ implementation_string = implementation_string.replace("\n", "<br>")
96
+ self.implementation_string = implementation_string
97
+ self.implementation_win = self.vis.text(
98
+ implementation_string,
99
+ opts={"title": "Training implementation"}
100
+ )
101
+
102
+ def update(self, loss, eer, step):
103
+ # Update the tracking data
104
+ now = timer()
105
+ self.step_times.append(1000 * (now - self.last_update_timestamp))
106
+ self.last_update_timestamp = now
107
+ self.losses.append(loss)
108
+ self.eers.append(eer)
109
+ print(".", end="")
110
+
111
+ # Update the plots every <update_every> steps
112
+ if step % self.update_every != 0:
113
+ return
114
+ time_string = "Step time: mean: %5dms std: %5dms" % \
115
+ (int(np.mean(self.step_times)), int(np.std(self.step_times)))
116
+ print("\nStep %6d Loss: %.4f EER: %.4f %s" %
117
+ (step, np.mean(self.losses), np.mean(self.eers), time_string))
118
+ if not self.disabled:
119
+ self.loss_win = self.vis.line(
120
+ [np.mean(self.losses)],
121
+ [step],
122
+ win=self.loss_win,
123
+ update="append" if self.loss_win else None,
124
+ opts=dict(
125
+ legend=["Avg. loss"],
126
+ xlabel="Step",
127
+ ylabel="Loss",
128
+ title="Loss",
129
+ )
130
+ )
131
+ self.eer_win = self.vis.line(
132
+ [np.mean(self.eers)],
133
+ [step],
134
+ win=self.eer_win,
135
+ update="append" if self.eer_win else None,
136
+ opts=dict(
137
+ legend=["Avg. EER"],
138
+ xlabel="Step",
139
+ ylabel="EER",
140
+ title="Equal error rate"
141
+ )
142
+ )
143
+ if self.implementation_win is not None:
144
+ self.vis.text(
145
+ self.implementation_string + ("<b>%s</b>" % time_string),
146
+ win=self.implementation_win,
147
+ opts={"title": "Training implementation"},
148
+ )
149
+
150
+ # Reset the tracking
151
+ self.losses.clear()
152
+ self.eers.clear()
153
+ self.step_times.clear()
154
+
155
+ def draw_projections(self, embeds, utterances_per_speaker, step, out_fpath=None,
156
+ max_speakers=10):
157
+ max_speakers = min(max_speakers, len(colormap))
158
+ embeds = embeds[:max_speakers * utterances_per_speaker]
159
+
160
+ n_speakers = len(embeds) // utterances_per_speaker
161
+ ground_truth = np.repeat(np.arange(n_speakers), utterances_per_speaker)
162
+ colors = [colormap[i] for i in ground_truth]
163
+
164
+ reducer = umap.UMAP()
165
+ projected = reducer.fit_transform(embeds)
166
+ plt.scatter(projected[:, 0], projected[:, 1], c=colors)
167
+ plt.gca().set_aspect("equal", "datalim")
168
+ plt.title("UMAP projection (step %d)" % step)
169
+ if not self.disabled:
170
+ self.projection_win = self.vis.matplot(plt, win=self.projection_win)
171
+ if out_fpath is not None:
172
+ plt.savefig(out_fpath)
173
+ plt.clf()
174
+
175
+ def save(self):
176
+ if not self.disabled:
177
+ self.vis.save([self.env_name])
178
+
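
`Visualizations` can also be exercised without a visdom server: in disabled mode it still prints running averages and can save UMAP projections to disk. A small sketch with random stand-in embeddings (file name and sizes are illustrative):

```
import numpy as np
from encoder.visualizations import Visualizations

vis = Visualizations(env_name="debug", update_every=10, disabled=True)  # no visdom server needed
for step in range(1, 101):
    vis.update(float(np.random.rand()), float(np.random.rand()), step)  # (loss, eer, step)

# 4 "speakers" x 5 utterances of random 256-dim embeddings, stand-ins for real model output
embeds = np.random.rand(20, 256).astype(np.float32)
vis.draw_projections(embeds, utterances_per_speaker=5, step=100, out_fpath="umap_debug.png")
```
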
encoder_preprocess.py ADDED
@@ -0,0 +1,70 @@
1
+ from encoder.preprocess import preprocess_librispeech, preprocess_voxceleb1, preprocess_voxceleb2
2
+ from utils.argutils import print_args
3
+ from pathlib import Path
4
+ import argparse
5
+
6
+ if __name__ == "__main__":
7
+ class MyFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter):
8
+ pass
9
+
10
+ parser = argparse.ArgumentParser(
11
+ description="Preprocesses audio files from datasets, encodes them as mel spectrograms and "
12
+ "writes them to the disk. This will allow you to train the encoder. The "
13
+ "datasets required are at least one of VoxCeleb1, VoxCeleb2 and LibriSpeech. "
14
+ "Ideally, you should have all three. You should extract them as they are "
15
+ "after having downloaded them and put them in a same directory, e.g.:\n"
16
+ "-[datasets_root]\n"
17
+ " -LibriSpeech\n"
18
+ " -train-other-500\n"
19
+ " -VoxCeleb1\n"
20
+ " -wav\n"
21
+ " -vox1_meta.csv\n"
22
+ " -VoxCeleb2\n"
23
+ " -dev",
24
+ formatter_class=MyFormatter
25
+ )
26
+ parser.add_argument("datasets_root", type=Path, help=\
27
+ "Path to the directory containing your LibriSpeech/TTS and VoxCeleb datasets.")
28
+ parser.add_argument("-o", "--out_dir", type=Path, default=argparse.SUPPRESS, help=\
29
+ "Path to the output directory that will contain the mel spectrograms. If left out, "
30
+ "defaults to <datasets_root>/SV2TTS/encoder/")
31
+ parser.add_argument("-d", "--datasets", type=str,
32
+ default="librispeech_other,voxceleb1,voxceleb2", help=\
33
+ "Comma-separated list of the name of the datasets you want to preprocess. Only the train "
34
+ "set of these datasets will be used. Possible names: librispeech_other, voxceleb1, "
35
+ "voxceleb2.")
36
+ parser.add_argument("-s", "--skip_existing", action="store_true", help=\
37
+ "Whether to skip existing output files with the same name. Useful if this script was "
38
+ "interrupted.")
39
+ parser.add_argument("--no_trim", action="store_true", help=\
40
+ "Preprocess audio without trimming silences (not recommended).")
41
+ args = parser.parse_args()
42
+
43
+ # Verify webrtcvad is available
44
+ if not args.no_trim:
45
+ try:
46
+ import webrtcvad
47
+ except ImportError:
48
+ raise ModuleNotFoundError("Package 'webrtcvad' not found. This package enables "
49
+ "noise removal and is recommended. Please install and try again. If installation fails, "
50
+ "use --no_trim to disable this error message.")
51
+ del args.no_trim
52
+
53
+ # Process the arguments
54
+ args.datasets = args.datasets.split(",")
55
+ if not hasattr(args, "out_dir"):
56
+ args.out_dir = args.datasets_root.joinpath("SV2TTS", "encoder")
57
+ assert args.datasets_root.exists()
58
+ args.out_dir.mkdir(exist_ok=True, parents=True)
59
+
60
+ # Preprocess the datasets
61
+ print_args(args, parser)
62
+ preprocess_func = {
63
+ "librispeech_other": preprocess_librispeech,
64
+ "voxceleb1": preprocess_voxceleb1,
65
+ "voxceleb2": preprocess_voxceleb2,
66
+ }
67
+ args = vars(args)
68
+ for dataset in args.pop("datasets"):
69
+ print("Preprocessing %s" % dataset)
70
+ preprocess_func[dataset](**args)
encoder_train.py ADDED
@@ -0,0 +1,47 @@
1
+ from utils.argutils import print_args
2
+ from encoder.train import train
3
+ from pathlib import Path
4
+ import argparse
5
+
6
+
7
+ if __name__ == "__main__":
8
+ parser = argparse.ArgumentParser(
9
+ description="Trains the speaker encoder. You must have run encoder_preprocess.py first.",
10
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
11
+ )
12
+
13
+ parser.add_argument("run_id", type=str, help= \
14
+ "Name for this model instance. If a model state from the same run ID was previously "
15
+ "saved, the training will restart from there. Pass -f to overwrite saved states and "
16
+ "restart from scratch.")
17
+ parser.add_argument("clean_data_root", type=Path, help= \
18
+ "Path to the output directory of encoder_preprocess.py. If you left the default "
19
+ "output directory when preprocessing, it should be <datasets_root>/SV2TTS/encoder/.")
20
+ parser.add_argument("-m", "--models_dir", type=Path, default="encoder/saved_models/", help=\
21
+ "Path to the output directory that will contain the saved model weights, as well as "
22
+ "backups of those weights and plots generated during training.")
23
+ parser.add_argument("-v", "--vis_every", type=int, default=10, help= \
24
+ "Number of steps between updates of the loss and the plots.")
25
+ parser.add_argument("-u", "--umap_every", type=int, default=100, help= \
26
+ "Number of steps between updates of the umap projection. Set to 0 to never update the "
27
+ "projections.")
28
+ parser.add_argument("-s", "--save_every", type=int, default=500, help= \
29
+ "Number of steps between updates of the model on the disk. Set to 0 to never save the "
30
+ "model.")
31
+ parser.add_argument("-b", "--backup_every", type=int, default=7500, help= \
32
+ "Number of steps between backups of the model. Set to 0 to never make backups of the "
33
+ "model.")
34
+ parser.add_argument("-f", "--force_restart", action="store_true", help= \
35
+ "Do not load any saved model.")
36
+ parser.add_argument("--visdom_server", type=str, default="http://localhost")
37
+ parser.add_argument("--no_visdom", action="store_true", help= \
38
+ "Disable visdom.")
39
+ args = parser.parse_args()
40
+
41
+ # Process the arguments
42
+ args.models_dir.mkdir(exist_ok=True)
43
+
44
+ # Run the training
45
+ print_args(args, parser)
46
+ train(**vars(args))
47
+
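
For completeness, the same run can be started programmatically with the defaults the parser above would supply. A minimal sketch, assuming `encoder_preprocess.py` has already populated the data directory; the run id and data path are hypothetical, and training loops until interrupted:

```
from pathlib import Path
from encoder.train import train

train(run_id="my_run",
      clean_data_root=Path("/data/datasets/SV2TTS/encoder"),
      models_dir=Path("encoder/saved_models"),
      umap_every=100, save_every=500, backup_every=7500, vis_every=10,
      force_restart=False, visdom_server="http://localhost", no_visdom=True)
```
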
packages.txt ADDED
@@ -0,0 +1,5 @@
1
+ libportaudio2
2
+ libsndfile1
3
+ ffmpeg
4
+ wget
5
+ unzip
requirements.txt ADDED
@@ -0,0 +1,19 @@
1
+ inflect==5.3.0
2
+ librosa==0.8.1
3
+ matplotlib==3.5.1
4
+ numpy
5
+ Pillow==8.4.0
6
+ PyQt5==5.15.6
7
+ scikit-learn==1.0.2
8
+ scipy==1.7.3
9
+ sounddevice==0.4.3
10
+ SoundFile==0.10.3.post1
11
+ tqdm==4.62.3
12
+ umap-learn==0.5.2
13
+ Unidecode==1.3.2
14
+ urllib3==1.26.7
15
+ visdom==0.1.8.9
16
+ webrtcvad==2.0.10
17
+ gradio==3.17.1
18
+ gdown
19
+ torch
samples/.DS_Store ADDED
Binary file (6.15 kB). View file
 
samples/1320_00000.mp3 ADDED
Binary file (15.5 kB). View file
 
samples/3575_00000.mp3 ADDED
Binary file (15.5 kB). View file
 
samples/6829_00000.mp3 ADDED
Binary file (15.6 kB). View file
 
samples/README.md ADDED
@@ -0,0 +1,22 @@
1
+ The audio files in this folder are provided for toolbox testing and
2
+ benchmarking purposes. These are the same reference utterances
3
+ used by the SV2TTS authors to generate the audio samples located at:
4
+ https://google.github.io/tacotron/publications/speaker_adaptation/index.html
5
+
6
+ The `p240_00000.mp3` and `p260_00000.mp3` files are compressed
7
+ versions of audio files from the VCTK corpus available at:
8
+ https://datashare.is.ed.ac.uk/handle/10283/3443
9
+ VCTK.txt contains the copyright notices and licensing information.
10
+
11
+ The `1320_00000.mp3`, `3575_00000.mp3`, `6829_00000.mp3`
12
+ and `8230_00000.mp3` files are compressed versions of audio files
13
+ from the LibriSpeech dataset available at: https://openslr.org/12
14
+ For these files, the following notice applies:
15
+ ```
16
+ LibriSpeech (c) 2014 by Vassil Panayotov
17
+
18
+ LibriSpeech ASR corpus is licensed under a
19
+ Creative Commons Attribution 4.0 International License.
20
+
21
+ See <http://creativecommons.org/licenses/by/4.0/>.
22
+ ```
samples/VCTK.txt ADDED
@@ -0,0 +1,94 @@
1
+ ---------------------------------------------------------------------
2
+ CSTR VCTK Corpus
3
+ English Multi-speaker Corpus for CSTR Voice Cloning Toolkit
4
+
5
+ (Version 0.92)
6
+ RELEASE September 2019
7
+ The Centre for Speech Technology Research
8
+ University of Edinburgh
9
+ Copyright (c) 2019
10
+
11
+ Junichi Yamagishi
12
+ jyamagis@inf.ed.ac.uk
13
+ ---------------------------------------------------------------------
14
+
15
+ Overview
16
+
17
+ This CSTR VCTK Corpus includes speech data uttered by 110 English
18
+ speakers with various accents. Each speaker reads out about 400
19
+ sentences, which were selected from a newspaper, the rainbow passage
20
+ and an elicitation paragraph used for the speech accent archive.
21
+
22
+ The newspaper texts were taken from Herald Glasgow, with permission
23
+ from Herald & Times Group. Each speaker has a different set of the
24
+ newspaper texts selected based a greedy algorithm that increases the
25
+ contextual and phonetic coverage. The details of the text selection
26
+ algorithms are described in the following paper:
27
+
28
+ C. Veaux, J. Yamagishi and S. King,
29
+ "The voice bank corpus: Design, collection and data analysis of
30
+ a large regional accent speech database,"
31
+ https://doi.org/10.1109/ICSDA.2013.6709856
32
+
33
+ The rainbow passage and elicitation paragraph are the same for all
34
+ speakers. The rainbow passage can be found at International Dialects
35
+ of English Archive:
36
+ (http://web.ku.edu/~idea/readings/rainbow.htm). The elicitation
37
+ paragraph is identical to the one used for the speech accent archive
38
+ (http://accent.gmu.edu). The details of the the speech accent archive
39
+ can be found at
40
+ http://www.ualberta.ca/~aacl2009/PDFs/WeinbergerKunath2009AACL.pdf
41
+
42
+ All speech data was recorded using an identical recording setup: an
43
+ omni-directional microphone (DPA 4035) and a small diaphragm condenser
44
+ microphone with very wide bandwidth (Sennheiser MKH 800), 96kHz
45
+ sampling frequency at 24 bits and in a hemi-anechoic chamber of
46
+ the University of Edinburgh. (However, two speakers, p280 and p315
47
+ had technical issues of the audio recordings using MKH 800).
48
+ All recordings were converted into 16 bits, were downsampled to
49
+ 48 kHz, and were manually end-pointed.
50
+
51
+ This corpus was originally aimed for HMM-based text-to-speech synthesis
52
+ systems, especially for speaker-adaptive HMM-based speech synthesis
53
+ that uses average voice models trained on multiple speakers and speaker
54
+ adaptation technologies. This corpus is also suitable for DNN-based
55
+ multi-speaker text-to-speech synthesis systems and waveform modeling.
56
+
57
+ COPYING
58
+
59
+ This corpus is licensed under the Creative Commons License: Attribution 4.0 International
60
+ http://creativecommons.org/licenses/by/4.0/legalcode
61
+
62
+ VCTK VARIANTS
63
+ There are several variants of the VCTK corpus:
64
+ Speech enhancement
65
+ - Noisy speech database for training speech enhancement algorithms and TTS models where we added various types of noises to VCTK artificially: http://dx.doi.org/10.7488/ds/2117
66
+ - Reverberant speech database for training speech dereverberation algorithms and TTS models where we added various types of reverberantion to VCTK artificially http://dx.doi.org/10.7488/ds/1425
67
+ - Noisy reverberant speech database for training speech enhancement algorithms and TTS models http://dx.doi.org/10.7488/ds/2139
68
+ - Device Recorded VCTK where speech signals of the VCTK corpus were played back and re-recorded in office environments using relatively inexpensive consumer devices http://dx.doi.org/10.7488/ds/2316
69
+ - The Microsoft Scalable Noisy Speech Dataset (MS-SNSD) https://github.com/microsoft/MS-SNSD
70
+
71
+ ASV and anti-spoofing
72
+ - Spoofing and Anti-Spoofing (SAS) corpus, which is a collection of synthetic speech signals produced by nine techniques, two of which are speech synthesis, and seven are voice conversion. All of them were built using the VCTK corpus. http://dx.doi.org/10.7488/ds/252
73
+ - Automatic Speaker Verification Spoofing and Countermeasures Challenge (ASVspoof 2015) Database. This database consists of synthetic speech signals produced by ten techniques and this has been used in the first Automatic Speaker Verification Spoofing and Countermeasures Challenge (ASVspoof 2015) http://dx.doi.org/10.7488/ds/298
74
+ - ASVspoof 2019: The 3rd Automatic Speaker Verification Spoofing and Countermeasures Challenge database. This database has been used in the 3rd Automatic Speaker Verification Spoofing and Countermeasures Challenge (ASVspoof 2019) https://doi.org/10.7488/ds/2555
75
+
76
+
77
+ ACKNOWLEDGEMENTS
78
+
79
+ The CSTR VCTK Corpus was constructed by:
80
+
81
+ Christophe Veaux (University of Edinburgh)
82
+ Junichi Yamagishi (University of Edinburgh)
83
+ Kirsten MacDonald
84
+
85
+ The research leading to these results was partly funded from EPSRC
86
+ grants EP/I031022/1 (NST) and EP/J002526/1 (CAF), from the RSE-NSFC
87
+ grant (61111130120), and from the JST CREST (uDialogue).
88
+
89
+ Please cite this corpus as follows:
90
+ Christophe Veaux, Junichi Yamagishi, Kirsten MacDonald,
91
+ "CSTR VCTK Corpus: English Multi-speaker Corpus for CSTR Voice Cloning Toolkit",
92
+ The Centre for Speech Technology Research (CSTR),
93
+ University of Edinburgh
94
+
samples/p240_00000.mp3 ADDED
Binary file (20.2 kB). View file
 
samples/p260_00000.mp3 ADDED
Binary file (20.5 kB). View file
 
synthesizer/LICENSE.txt ADDED
@@ -0,0 +1,24 @@
1
+ MIT License
2
+
3
+ Original work Copyright (c) 2018 Rayhane Mama (https://github.com/Rayhane-mamah)
4
+ Original work Copyright (c) 2019 fatchord (https://github.com/fatchord)
5
+ Modified work Copyright (c) 2019 Corentin Jemine (https://github.com/CorentinJ)
6
+ Modified work Copyright (c) 2020 blue-fish (https://github.com/blue-fish)
7
+
8
+ Permission is hereby granted, free of charge, to any person obtaining a copy
9
+ of this software and associated documentation files (the "Software"), to deal
10
+ in the Software without restriction, including without limitation the rights
11
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12
+ copies of the Software, and to permit persons to whom the Software is
13
+ furnished to do so, subject to the following conditions:
14
+
15
+ The above copyright notice and this permission notice shall be included in all
16
+ copies or substantial portions of the Software.
17
+
18
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
23
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
24
+ SOFTWARE.
synthesizer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ #
synthesizer/audio.py ADDED
@@ -0,0 +1,206 @@
1
+ import librosa
2
+ import librosa.filters
3
+ import numpy as np
4
+ from scipy import signal
5
+ from scipy.io import wavfile
6
+ import soundfile as sf
7
+
8
+
9
+ def load_wav(path, sr):
10
+ return librosa.core.load(path, sr=sr)[0]
11
+
12
+ def save_wav(wav, path, sr):
13
+ wav *= 32767 / max(0.01, np.max(np.abs(wav)))
14
+ #proposed by @dsmiller
15
+ wavfile.write(path, sr, wav.astype(np.int16))
16
+
17
+ def save_wavenet_wav(wav, path, sr):
18
+ sf.write(path, wav.astype(np.float32), sr)
19
+
20
+ def preemphasis(wav, k, preemphasize=True):
21
+ if preemphasize:
22
+ return signal.lfilter([1, -k], [1], wav)
23
+ return wav
24
+
25
+ def inv_preemphasis(wav, k, inv_preemphasize=True):
26
+ if inv_preemphasize:
27
+ return signal.lfilter([1], [1, -k], wav)
28
+ return wav
29
+
30
+ #From https://github.com/r9y9/wavenet_vocoder/blob/master/audio.py
31
+ def start_and_end_indices(quantized, silence_threshold=2):
32
+ for start in range(quantized.size):
33
+ if abs(quantized[start] - 127) > silence_threshold:
34
+ break
35
+ for end in range(quantized.size - 1, 1, -1):
36
+ if abs(quantized[end] - 127) > silence_threshold:
37
+ break
38
+
39
+ assert abs(quantized[start] - 127) > silence_threshold
40
+ assert abs(quantized[end] - 127) > silence_threshold
41
+
42
+ return start, end
43
+
44
+ def get_hop_size(hparams):
45
+ hop_size = hparams.hop_size
46
+ if hop_size is None:
47
+ assert hparams.frame_shift_ms is not None
48
+ hop_size = int(hparams.frame_shift_ms / 1000 * hparams.sample_rate)
49
+ return hop_size
50
+
51
+ def linearspectrogram(wav, hparams):
52
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
53
+ S = _amp_to_db(np.abs(D), hparams) - hparams.ref_level_db
54
+
55
+ if hparams.signal_normalization:
56
+ return _normalize(S, hparams)
57
+ return S
58
+
59
+ def melspectrogram(wav, hparams):
60
+ D = _stft(preemphasis(wav, hparams.preemphasis, hparams.preemphasize), hparams)
61
+ S = _amp_to_db(_linear_to_mel(np.abs(D), hparams), hparams) - hparams.ref_level_db
62
+
63
+ if hparams.signal_normalization:
64
+ return _normalize(S, hparams)
65
+ return S
66
+
67
+ def inv_linear_spectrogram(linear_spectrogram, hparams):
68
+ """Converts linear spectrogram to waveform using librosa"""
69
+ if hparams.signal_normalization:
70
+ D = _denormalize(linear_spectrogram, hparams)
71
+ else:
72
+ D = linear_spectrogram
73
+
74
+ S = _db_to_amp(D + hparams.ref_level_db) #Convert back to linear
75
+
76
+ if hparams.use_lws:
77
+ processor = _lws_processor(hparams)
78
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
79
+ y = processor.istft(D).astype(np.float32)
80
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
81
+ else:
82
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
83
+
84
+ def inv_mel_spectrogram(mel_spectrogram, hparams):
85
+ """Converts mel spectrogram to waveform using librosa"""
86
+ if hparams.signal_normalization:
87
+ D = _denormalize(mel_spectrogram, hparams)
88
+ else:
89
+ D = mel_spectrogram
90
+
91
+ S = _mel_to_linear(_db_to_amp(D + hparams.ref_level_db), hparams) # Convert back to linear
92
+
93
+ if hparams.use_lws:
94
+ processor = _lws_processor(hparams)
95
+ D = processor.run_lws(S.astype(np.float64).T ** hparams.power)
96
+ y = processor.istft(D).astype(np.float32)
97
+ return inv_preemphasis(y, hparams.preemphasis, hparams.preemphasize)
98
+ else:
99
+ return inv_preemphasis(_griffin_lim(S ** hparams.power, hparams), hparams.preemphasis, hparams.preemphasize)
100
+
101
+ def _lws_processor(hparams):
102
+ import lws
103
+ return lws.lws(hparams.n_fft, get_hop_size(hparams), fftsize=hparams.win_size, mode="speech")
104
+
105
+ def _griffin_lim(S, hparams):
106
+ """librosa implementation of Griffin-Lim
107
+ Based on https://github.com/librosa/librosa/issues/434
108
+ """
109
+ angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
110
+ S_complex = np.abs(S).astype(np.complex128)  # np.complex alias was removed in NumPy 1.20+; use the explicit dtype
111
+ y = _istft(S_complex * angles, hparams)
112
+ for i in range(hparams.griffin_lim_iters):
113
+ angles = np.exp(1j * np.angle(_stft(y, hparams)))
114
+ y = _istft(S_complex * angles, hparams)
115
+ return y
116
+
117
+ def _stft(y, hparams):
118
+ if hparams.use_lws:
119
+ return _lws_processor(hparams).stft(y).T
120
+ else:
121
+ return librosa.stft(y=y, n_fft=hparams.n_fft, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
122
+
123
+ def _istft(y, hparams):
124
+ return librosa.istft(y, hop_length=get_hop_size(hparams), win_length=hparams.win_size)
125
+
126
+ ##########################################################
127
+ #Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
128
+ def num_frames(length, fsize, fshift):
129
+ """Compute number of time frames of spectrogram
130
+ """
131
+ pad = (fsize - fshift)
132
+ if length % fshift == 0:
133
+ M = (length + pad * 2 - fsize) // fshift + 1
134
+ else:
135
+ M = (length + pad * 2 - fsize) // fshift + 2
136
+ return M
137
+
138
+
139
+ def pad_lr(x, fsize, fshift):
140
+ """Compute left and right padding
141
+ """
142
+ M = num_frames(len(x), fsize, fshift)
143
+ pad = (fsize - fshift)
144
+ T = len(x) + 2 * pad
145
+ r = (M - 1) * fshift + fsize - T
146
+ return pad, pad + r
147
+ ##########################################################
148
+ #Librosa correct padding
149
+ def librosa_pad_lr(x, fsize, fshift):
150
+ return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
151
+
152
+ # Conversions
153
+ _mel_basis = None
154
+ _inv_mel_basis = None
155
+
156
+ def _linear_to_mel(spectogram, hparams):
157
+ global _mel_basis
158
+ if _mel_basis is None:
159
+ _mel_basis = _build_mel_basis(hparams)
160
+ return np.dot(_mel_basis, spectogram)
161
+
162
+ def _mel_to_linear(mel_spectrogram, hparams):
163
+ global _inv_mel_basis
164
+ if _inv_mel_basis is None:
165
+ _inv_mel_basis = np.linalg.pinv(_build_mel_basis(hparams))
166
+ return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
167
+
168
+ def _build_mel_basis(hparams):
169
+ assert hparams.fmax <= hparams.sample_rate // 2
170
+ return librosa.filters.mel(hparams.sample_rate, hparams.n_fft, n_mels=hparams.num_mels,
171
+ fmin=hparams.fmin, fmax=hparams.fmax)
172
+
173
+ def _amp_to_db(x, hparams):
174
+ min_level = np.exp(hparams.min_level_db / 20 * np.log(10))
175
+ return 20 * np.log10(np.maximum(min_level, x))
176
+
177
+ def _db_to_amp(x):
178
+ return np.power(10.0, (x) * 0.05)
179
+
180
+ def _normalize(S, hparams):
181
+ if hparams.allow_clipping_in_normalization:
182
+ if hparams.symmetric_mels:
183
+ return np.clip((2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value,
184
+ -hparams.max_abs_value, hparams.max_abs_value)
185
+ else:
186
+ return np.clip(hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db)), 0, hparams.max_abs_value)
187
+
188
+ assert S.max() <= 0 and S.min() - hparams.min_level_db >= 0
189
+ if hparams.symmetric_mels:
190
+ return (2 * hparams.max_abs_value) * ((S - hparams.min_level_db) / (-hparams.min_level_db)) - hparams.max_abs_value
191
+ else:
192
+ return hparams.max_abs_value * ((S - hparams.min_level_db) / (-hparams.min_level_db))
193
+
194
+ def _denormalize(D, hparams):
195
+ if hparams.allow_clipping_in_normalization:
196
+ if hparams.symmetric_mels:
197
+ return (((np.clip(D, -hparams.max_abs_value,
198
+ hparams.max_abs_value) + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value))
199
+ + hparams.min_level_db)
200
+ else:
201
+ return ((np.clip(D, 0, hparams.max_abs_value) * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
202
+
203
+ if hparams.symmetric_mels:
204
+ return (((D + hparams.max_abs_value) * -hparams.min_level_db / (2 * hparams.max_abs_value)) + hparams.min_level_db)
205
+ else:
206
+ return ((D * -hparams.min_level_db / hparams.max_abs_value) + hparams.min_level_db)
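
A minimal round-trip sketch of the helpers above, using the hyperparameters defined in `synthesizer/hparams.py` just below; the input and output file names are hypothetical:

```
from synthesizer import audio
from synthesizer.hparams import hparams

wav = audio.load_wav("speech.wav", sr=hparams.sample_rate)   # hypothetical input file, resampled to 16 kHz
mel = audio.melspectrogram(wav, hparams)                     # shape (num_mels, frames), normalized to [-4, 4]
print("mel shape:", mel.shape)

# Lossy Griffin-Lim reconstruction, handy for sanity checks without a neural vocoder
wav_rec = audio.inv_mel_spectrogram(mel, hparams)
audio.save_wav(wav_rec, "speech_griffinlim.wav", sr=hparams.sample_rate)
```
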
synthesizer/hparams.py ADDED
@@ -0,0 +1,92 @@
1
+ import ast
2
+ import pprint
3
+
4
+ class HParams(object):
5
+ def __init__(self, **kwargs): self.__dict__.update(kwargs)
6
+ def __setitem__(self, key, value): setattr(self, key, value)
7
+ def __getitem__(self, key): return getattr(self, key)
8
+ def __repr__(self): return pprint.pformat(self.__dict__)
9
+
10
+ def parse(self, string):
11
+ # Overrides hparams from a comma-separated string of name=value pairs
12
+ if len(string) > 0:
13
+ overrides = [s.split("=") for s in string.split(",")]
14
+ keys, values = zip(*overrides)
15
+ keys = list(map(str.strip, keys))
16
+ values = list(map(str.strip, values))
17
+ for k in keys:
18
+ self.__dict__[k] = ast.literal_eval(values[keys.index(k)])
19
+ return self
20
+
21
+ hparams = HParams(
22
+ ### Signal Processing (used in both synthesizer and vocoder)
23
+ sample_rate = 16000,
24
+ n_fft = 800,
25
+ num_mels = 80,
26
+ hop_size = 200, # Tacotron uses 12.5 ms frame shift (set to sample_rate * 0.0125)
27
+ win_size = 800, # Tacotron uses 50 ms frame length (set to sample_rate * 0.050)
28
+ fmin = 55,
29
+ min_level_db = -100,
30
+ ref_level_db = 20,
31
+ max_abs_value = 4., # Gradient explodes if too big, premature convergence if too small.
32
+ preemphasis = 0.97, # Filter coefficient to use if preemphasize is True
33
+ preemphasize = True,
34
+
35
+ ### Tacotron Text-to-Speech (TTS)
36
+ tts_embed_dims = 512, # Embedding dimension for the graphemes/phoneme inputs
37
+ tts_encoder_dims = 256,
38
+ tts_decoder_dims = 128,
39
+ tts_postnet_dims = 512,
40
+ tts_encoder_K = 5,
41
+ tts_lstm_dims = 1024,
42
+ tts_postnet_K = 5,
43
+ tts_num_highways = 4,
44
+ tts_dropout = 0.5,
45
+ tts_cleaner_names = ["english_cleaners"],
46
+ tts_stop_threshold = -3.4, # Value below which audio generation ends.
47
+ # For example, for a range of [-4, 4], this
48
+ # will terminate the sequence at the first
49
+ # frame that has all values < -3.4
50
+
51
+ ### Tacotron Training
52
+ tts_schedule = [(2, 1e-3, 20_000, 12), # Progressive training schedule
53
+ (2, 5e-4, 40_000, 12), # (r, lr, step, batch_size)
54
+ (2, 2e-4, 80_000, 12), #
55
+ (2, 1e-4, 160_000, 12), # r = reduction factor (# of mel frames
56
+ (2, 3e-5, 320_000, 12), # synthesized for each decoder iteration)
57
+ (2, 1e-5, 640_000, 12)], # lr = learning rate
58
+
59
+ tts_clip_grad_norm = 1.0, # clips the gradient norm to prevent explosion - set to None if not needed
60
+ tts_eval_interval = 500, # Number of steps between model evaluation (sample generation)
61
+ # Set to -1 to generate after completing epoch, or 0 to disable
62
+
63
+ tts_eval_num_samples = 1, # Makes this number of samples
64
+
65
+ ### Data Preprocessing
66
+ max_mel_frames = 900,
67
+ rescale = True,
68
+ rescaling_max = 0.9,
69
+ synthesis_batch_size = 16, # For vocoder preprocessing and inference.
70
+
71
+ ### Mel Visualization and Griffin-Lim
72
+ signal_normalization = True,
73
+ power = 1.5,
74
+ griffin_lim_iters = 60,
75
+
76
+ ### Audio processing options
77
+ fmax = 7600, # Should not exceed (sample_rate // 2)
78
+ allow_clipping_in_normalization = True, # Used when signal_normalization = True
79
+ clip_mels_length = True, # If true, discards samples exceeding max_mel_frames
80
+ use_lws = False, # "Fast spectrogram phase recovery using local weighted sums"
81
+ symmetric_mels = True, # Sets mel range to [-max_abs_value, max_abs_value] if True,
82
+ # and [0, max_abs_value] if False
83
+ trim_silence = True, # Use with sample_rate of 16000 for best results
84
+
85
+ ### SV2TTS
86
+ speaker_embedding_size = 256, # Dimension for the speaker embedding
87
+ silence_min_duration_split = 0.4, # Duration in seconds of a silence for an utterance to be split
88
+ utterance_min_duration = 1.6, # Duration in seconds below which utterances are discarded
89
+ )
90
+
91
+ def hparams_debug_string():
92
+ return str(hparams)
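
`HParams` is a plain attribute container, and `parse` applies overrides from a comma-separated `name=value` string, with each value read by `ast.literal_eval`. A small sketch (the `--hparams` flag mentioned in the comment is illustrative, not defined here):

```
from synthesizer.hparams import hparams

print(hparams.sample_rate)      # 16000
print(hparams["num_mels"])      # 80, via __getitem__

# Apply overrides the way a training CLI might forward an --hparams string
hparams.parse("tts_eval_interval=1000,rescale=False")
print(hparams.tts_eval_interval, hparams.rescale)   # 1000 False
```
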
synthesizer/inference.py ADDED
@@ -0,0 +1,171 @@
1
+ import torch
2
+ from synthesizer import audio
3
+ from synthesizer.hparams import hparams
4
+ from synthesizer.models.tacotron import Tacotron
5
+ from synthesizer.utils.symbols import symbols
6
+ from synthesizer.utils.text import text_to_sequence
7
+ from vocoder.display import simple_table
8
+ from pathlib import Path
9
+ from typing import Union, List
10
+ import numpy as np
11
+ import librosa
12
+
13
+
14
+ class Synthesizer:
15
+ sample_rate = hparams.sample_rate
16
+ hparams = hparams
17
+
18
+ def __init__(self, model_fpath: Path, verbose=True):
19
+ """
20
+ The model isn't instantiated and loaded in memory until needed or until load() is called.
21
+
22
+ :param model_fpath: path to the trained model file
23
+ :param verbose: if False, prints less information when using the model
24
+ """
25
+ self.model_fpath = model_fpath
26
+ self.verbose = verbose
27
+
28
+ # Check for GPU
29
+ if torch.cuda.is_available():
30
+ self.device = torch.device("cuda")
31
+ else:
32
+ self.device = torch.device("cpu")
33
+ if self.verbose:
34
+ print("Synthesizer using device:", self.device)
35
+
36
+ # Tacotron model will be instantiated later on first use.
37
+ self._model = None
38
+
39
+ def is_loaded(self):
40
+ """
41
+ Whether the model is loaded in memory.
42
+ """
43
+ return self._model is not None
44
+
45
+ def load(self):
46
+ """
47
+ Instantiates and loads the model given the weights file that was passed in the constructor.
48
+ """
49
+ self._model = Tacotron(embed_dims=hparams.tts_embed_dims,
50
+ num_chars=len(symbols),
51
+ encoder_dims=hparams.tts_encoder_dims,
52
+ decoder_dims=hparams.tts_decoder_dims,
53
+ n_mels=hparams.num_mels,
54
+ fft_bins=hparams.num_mels,
55
+ postnet_dims=hparams.tts_postnet_dims,
56
+ encoder_K=hparams.tts_encoder_K,
57
+ lstm_dims=hparams.tts_lstm_dims,
58
+ postnet_K=hparams.tts_postnet_K,
59
+ num_highways=hparams.tts_num_highways,
60
+ dropout=hparams.tts_dropout,
61
+ stop_threshold=hparams.tts_stop_threshold,
62
+ speaker_embedding_size=hparams.speaker_embedding_size).to(self.device)
63
+
64
+ self._model.load(self.model_fpath)
65
+ self._model.eval()
66
+
67
+ if self.verbose:
68
+ print("Loaded synthesizer \"%s\" trained to step %d" % (self.model_fpath.name, self._model.state_dict()["step"]))
69
+
70
+ def synthesize_spectrograms(self, texts: List[str],
71
+ embeddings: Union[np.ndarray, List[np.ndarray]],
72
+ return_alignments=False):
73
+ """
74
+ Synthesizes mel spectrograms from texts and speaker embeddings.
75
+
76
+ :param texts: a list of N text prompts to be synthesized
77
+ :param embeddings: a numpy array or list of speaker embeddings of shape (N, 256)
78
+ :param return_alignments: if True, a matrix representing the alignments between the
79
+ characters
80
+ and each decoder output step will be returned for each spectrogram
81
+ :return: a list of N melspectrograms as numpy arrays of shape (80, Mi), where Mi is the
82
+ sequence length of spectrogram i, and possibly the alignments.
83
+ """
84
+ # Load the model on the first request.
85
+ if not self.is_loaded():
86
+ self.load()
87
+
88
+ # Print some info about the model when it is loaded
89
+ tts_k = self._model.get_step() // 1000
90
+
91
+ simple_table([("Tacotron", str(tts_k) + "k"),
92
+ ("r", self._model.r)])
93
+
94
+ # Preprocess text inputs
95
+ inputs = [text_to_sequence(text.strip(), hparams.tts_cleaner_names) for text in texts]
96
+ if not isinstance(embeddings, list):
97
+ embeddings = [embeddings]
98
+
99
+ # Batch inputs
100
+ batched_inputs = [inputs[i:i+hparams.synthesis_batch_size]
101
+ for i in range(0, len(inputs), hparams.synthesis_batch_size)]
102
+ batched_embeds = [embeddings[i:i+hparams.synthesis_batch_size]
103
+ for i in range(0, len(embeddings), hparams.synthesis_batch_size)]
104
+
105
+ specs = []
106
+ for i, batch in enumerate(batched_inputs, 1):
107
+ if self.verbose:
108
+ print(f"\n| Generating {i}/{len(batched_inputs)}")
109
+
110
+ # Pad texts so they are all the same length
111
+ text_lens = [len(text) for text in batch]
112
+ max_text_len = max(text_lens)
113
+ chars = [pad1d(text, max_text_len) for text in batch]
114
+ chars = np.stack(chars)
115
+
116
+ # Stack speaker embeddings into 2D array for batch processing
117
+ speaker_embeds = np.stack(batched_embeds[i-1])
118
+
119
+ # Convert to tensor
120
+ chars = torch.tensor(chars).long().to(self.device)
121
+ speaker_embeddings = torch.tensor(speaker_embeds).float().to(self.device)
122
+
123
+ # Inference
124
+ _, mels, alignments = self._model.generate(chars, speaker_embeddings)
125
+ mels = mels.detach().cpu().numpy()
126
+ for m in mels:
127
+ # Trim silence from end of each spectrogram
128
+ while np.max(m[:, -1]) < hparams.tts_stop_threshold:
129
+ m = m[:, :-1]
130
+ specs.append(m)
131
+
132
+ if self.verbose:
133
+ print("\n\nDone.\n")
134
+ return (specs, alignments) if return_alignments else specs
135
+
136
+ @staticmethod
137
+ def load_preprocess_wav(fpath):
138
+ """
139
+ Loads and preprocesses an audio file under the same conditions the audio files were used to
140
+ train the synthesizer.
141
+ """
142
+ wav = librosa.load(str(fpath), hparams.sample_rate)[0]
143
+ if hparams.rescale:
144
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
145
+ return wav
146
+
147
+ @staticmethod
148
+ def make_spectrogram(fpath_or_wav: Union[str, Path, np.ndarray]):
149
+ """
150
+ Creates a mel spectrogram from an audio file in the same manner as the mel spectrograms that
151
+ were fed to the synthesizer when training.
152
+ """
153
+ if isinstance(fpath_or_wav, str) or isinstance(fpath_or_wav, Path):
154
+ wav = Synthesizer.load_preprocess_wav(fpath_or_wav)
155
+ else:
156
+ wav = fpath_or_wav
157
+
158
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
159
+ return mel_spectrogram
160
+
161
+ @staticmethod
162
+ def griffin_lim(mel):
163
+ """
164
+ Inverts a mel spectrogram using Griffin-Lim. The mel spectrogram is expected to have been built
165
+ with the same parameters present in hparams.py.
166
+ """
167
+ return audio.inv_mel_spectrogram(mel, hparams)
168
+
169
+
170
+ def pad1d(x, max_len, pad_value=0):
171
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
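
A minimal sketch of the `Synthesizer` API above. The weights path is hypothetical, and the speaker embedding is a random stand-in; in the full pipeline it would come from the speaker encoder (256-dimensional, per `speaker_embedding_size`):

```
from pathlib import Path
import numpy as np
from synthesizer.inference import Synthesizer
from synthesizer import audio

# Hypothetical pretrained Tacotron weights
synthesizer = Synthesizer(Path("synthesizer/saved_models/pretrained/pretrained.pt"))

# Random unit-norm stand-in for a real speaker embedding
embed = np.random.rand(Synthesizer.hparams.speaker_embedding_size).astype(np.float32)
embed /= np.linalg.norm(embed)

specs = synthesizer.synthesize_spectrograms(["This is a test sentence."], [embed])

# Quick audible check with Griffin-Lim; a neural vocoder is normally used instead
audio.save_wav(Synthesizer.griffin_lim(specs[0]), "generated.wav", sr=Synthesizer.sample_rate)
```
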
synthesizer/models/tacotron.py ADDED
@@ -0,0 +1,519 @@
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from pathlib import Path
7
+ from typing import Union
8
+
9
+
10
+ class HighwayNetwork(nn.Module):
11
+ def __init__(self, size):
12
+ super().__init__()
13
+ self.W1 = nn.Linear(size, size)
14
+ self.W2 = nn.Linear(size, size)
15
+ self.W1.bias.data.fill_(0.)
16
+
17
+ def forward(self, x):
18
+ x1 = self.W1(x)
19
+ x2 = self.W2(x)
20
+ g = torch.sigmoid(x2)
21
+ y = g * F.relu(x1) + (1. - g) * x
22
+ return y
23
+
24
+
25
+ class Encoder(nn.Module):
26
+ def __init__(self, embed_dims, num_chars, encoder_dims, K, num_highways, dropout):
27
+ super().__init__()
28
+ prenet_dims = (encoder_dims, encoder_dims)
29
+ cbhg_channels = encoder_dims
30
+ self.embedding = nn.Embedding(num_chars, embed_dims)
31
+ self.pre_net = PreNet(embed_dims, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
32
+ dropout=dropout)
33
+ self.cbhg = CBHG(K=K, in_channels=cbhg_channels, channels=cbhg_channels,
34
+ proj_channels=[cbhg_channels, cbhg_channels],
35
+ num_highways=num_highways)
36
+
37
+ def forward(self, x, speaker_embedding=None):
38
+ x = self.embedding(x)
39
+ x = self.pre_net(x)
40
+ x.transpose_(1, 2)
41
+ x = self.cbhg(x)
42
+ if speaker_embedding is not None:
43
+ x = self.add_speaker_embedding(x, speaker_embedding)
44
+ return x
45
+
46
+ def add_speaker_embedding(self, x, speaker_embedding):
47
+ # SV2TTS
48
+ # The input x is the encoder output and is a 3D tensor with size (batch_size, num_chars, tts_embed_dims)
49
+ # When training, speaker_embedding is also a 2D tensor with size (batch_size, speaker_embedding_size)
50
+ # (for inference, speaker_embedding is a 1D tensor with size (speaker_embedding_size))
51
+ # This concats the speaker embedding for each char in the encoder output
52
+
53
+ # Save the dimensions as human-readable names
54
+ batch_size = x.size()[0]
55
+ num_chars = x.size()[1]
56
+
57
+ if speaker_embedding.dim() == 1:
58
+ idx = 0
59
+ else:
60
+ idx = 1
61
+
62
+ # Start by making a copy of each speaker embedding to match the input text length
63
+ # The output of this has size (batch_size, num_chars * tts_embed_dims)
64
+ speaker_embedding_size = speaker_embedding.size()[idx]
65
+ e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
66
+
67
+ # Reshape it and transpose
68
+ e = e.reshape(batch_size, speaker_embedding_size, num_chars)
69
+ e = e.transpose(1, 2)
70
+
71
+ # Concatenate the tiled speaker embedding with the encoder output
72
+ x = torch.cat((x, e), 2)
73
+ return x
74
+
75
+
76
+ class BatchNormConv(nn.Module):
77
+ def __init__(self, in_channels, out_channels, kernel, relu=True):
78
+ super().__init__()
79
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel, stride=1, padding=kernel // 2, bias=False)
80
+ self.bnorm = nn.BatchNorm1d(out_channels)
81
+ self.relu = relu
82
+
83
+ def forward(self, x):
84
+ x = self.conv(x)
85
+ x = F.relu(x) if self.relu is True else x
86
+ return self.bnorm(x)
87
+
88
+
89
+ class CBHG(nn.Module):
90
+ def __init__(self, K, in_channels, channels, proj_channels, num_highways):
91
+ super().__init__()
92
+
93
+ # List of all rnns to call `flatten_parameters()` on
94
+ self._to_flatten = []
95
+
96
+ self.bank_kernels = [i for i in range(1, K + 1)]
97
+ self.conv1d_bank = nn.ModuleList()
98
+ for k in self.bank_kernels:
99
+ conv = BatchNormConv(in_channels, channels, k)
100
+ self.conv1d_bank.append(conv)
101
+
102
+ self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)
103
+
104
+ self.conv_project1 = BatchNormConv(len(self.bank_kernels) * channels, proj_channels[0], 3)
105
+ self.conv_project2 = BatchNormConv(proj_channels[0], proj_channels[1], 3, relu=False)
106
+
107
+ # Fix the highway input if necessary
108
+ if proj_channels[-1] != channels:
109
+ self.highway_mismatch = True
110
+ self.pre_highway = nn.Linear(proj_channels[-1], channels, bias=False)
111
+ else:
112
+ self.highway_mismatch = False
113
+
114
+ self.highways = nn.ModuleList()
115
+ for i in range(num_highways):
116
+ hn = HighwayNetwork(channels)
117
+ self.highways.append(hn)
118
+
119
+ self.rnn = nn.GRU(channels, channels // 2, batch_first=True, bidirectional=True)
120
+ self._to_flatten.append(self.rnn)
121
+
122
+ # Avoid fragmentation of RNN parameters and associated warning
123
+ self._flatten_parameters()
124
+
125
+ def forward(self, x):
126
+ # Although we `_flatten_parameters()` on init, when using DataParallel
127
+ # the model gets replicated, making it no longer guaranteed that the
128
+ # weights are contiguous in GPU memory. Hence, we must call it again
129
+ self._flatten_parameters()
130
+
131
+ # Save these for later
132
+ residual = x
133
+ seq_len = x.size(-1)
134
+ conv_bank = []
135
+
136
+ # Convolution Bank
137
+ for conv in self.conv1d_bank:
138
+ c = conv(x) # Convolution
139
+ conv_bank.append(c[:, :, :seq_len])
140
+
141
+ # Stack along the channel axis
142
+ conv_bank = torch.cat(conv_bank, dim=1)
143
+
144
+ # dump the last padding to fit residual
145
+ x = self.maxpool(conv_bank)[:, :, :seq_len]
146
+
147
+ # Conv1d projections
148
+ x = self.conv_project1(x)
149
+ x = self.conv_project2(x)
150
+
151
+ # Residual Connect
152
+ x = x + residual
153
+
154
+ # Through the highways
155
+ x = x.transpose(1, 2)
156
+ if self.highway_mismatch is True:
157
+ x = self.pre_highway(x)
158
+ for h in self.highways: x = h(x)
159
+
160
+ # And then the RNN
161
+ x, _ = self.rnn(x)
162
+ return x
163
+
164
+ def _flatten_parameters(self):
165
+ """Calls `flatten_parameters` on all the rnns used by the WaveRNN. Used
166
+ to improve efficiency and avoid PyTorch yelling at us."""
167
+ [m.flatten_parameters() for m in self._to_flatten]
168
+
169
+ class PreNet(nn.Module):
170
+ def __init__(self, in_dims, fc1_dims=256, fc2_dims=128, dropout=0.5):
171
+ super().__init__()
172
+ self.fc1 = nn.Linear(in_dims, fc1_dims)
173
+ self.fc2 = nn.Linear(fc1_dims, fc2_dims)
174
+ self.p = dropout
175
+
176
+ def forward(self, x):
177
+ x = self.fc1(x)
178
+ x = F.relu(x)
179
+ x = F.dropout(x, self.p, training=True)
180
+ x = self.fc2(x)
181
+ x = F.relu(x)
182
+ x = F.dropout(x, self.p, training=True)
183
+ return x
184
+
185
+
186
+ class Attention(nn.Module):
187
+ def __init__(self, attn_dims):
188
+ super().__init__()
189
+ self.W = nn.Linear(attn_dims, attn_dims, bias=False)
190
+ self.v = nn.Linear(attn_dims, 1, bias=False)
191
+
192
+ def forward(self, encoder_seq_proj, query, t):
193
+
194
+ # print(encoder_seq_proj.shape)
195
+ # Transform the query vector
196
+ query_proj = self.W(query).unsqueeze(1)
197
+
198
+ # Compute the scores
199
+ u = self.v(torch.tanh(encoder_seq_proj + query_proj))
200
+ scores = F.softmax(u, dim=1)
201
+
202
+ return scores.transpose(1, 2)
203
+
204
+
205
+ class LSA(nn.Module):
206
+ def __init__(self, attn_dim, kernel_size=31, filters=32):
207
+ super().__init__()
208
+ self.conv = nn.Conv1d(1, filters, padding=(kernel_size - 1) // 2, kernel_size=kernel_size, bias=True)
209
+ self.L = nn.Linear(filters, attn_dim, bias=False)
210
+ self.W = nn.Linear(attn_dim, attn_dim, bias=True) # Include the attention bias in this term
211
+ self.v = nn.Linear(attn_dim, 1, bias=False)
212
+ self.cumulative = None
213
+ self.attention = None
214
+
215
+ def init_attention(self, encoder_seq_proj):
216
+ device = next(self.parameters()).device # use same device as parameters
217
+ b, t, c = encoder_seq_proj.size()
218
+ self.cumulative = torch.zeros(b, t, device=device)
219
+ self.attention = torch.zeros(b, t, device=device)
220
+
221
+ def forward(self, encoder_seq_proj, query, t, chars):
222
+
223
+ if t == 0: self.init_attention(encoder_seq_proj)
224
+
225
+ processed_query = self.W(query).unsqueeze(1)
226
+
227
+ location = self.cumulative.unsqueeze(1)
228
+ processed_loc = self.L(self.conv(location).transpose(1, 2))
229
+
230
+ u = self.v(torch.tanh(processed_query + encoder_seq_proj + processed_loc))
231
+ u = u.squeeze(-1)
232
+
233
+ # Mask zero padding chars
234
+ u = u * (chars != 0).float()
235
+
236
+ # Smooth Attention
237
+ # scores = torch.sigmoid(u) / torch.sigmoid(u).sum(dim=1, keepdim=True)
238
+ scores = F.softmax(u, dim=1)
239
+ self.attention = scores
240
+ self.cumulative = self.cumulative + self.attention
241
+
242
+ return scores.unsqueeze(-1).transpose(1, 2)
243
+
244
+
245
+ class Decoder(nn.Module):
246
+ # Class variable because its value doesn't change between instances,
247
+ # yet it ought to be scoped to the class because it is a property of a Decoder
248
+ max_r = 20
249
+ def __init__(self, n_mels, encoder_dims, decoder_dims, lstm_dims,
250
+ dropout, speaker_embedding_size):
251
+ super().__init__()
252
+ self.register_buffer("r", torch.tensor(1, dtype=torch.int))
253
+ self.n_mels = n_mels
254
+ prenet_dims = (decoder_dims * 2, decoder_dims * 2)
255
+ self.prenet = PreNet(n_mels, fc1_dims=prenet_dims[0], fc2_dims=prenet_dims[1],
256
+ dropout=dropout)
257
+ self.attn_net = LSA(decoder_dims)
258
+ self.attn_rnn = nn.GRUCell(encoder_dims + prenet_dims[1] + speaker_embedding_size, decoder_dims)
259
+ self.rnn_input = nn.Linear(encoder_dims + decoder_dims + speaker_embedding_size, lstm_dims)
260
+ self.res_rnn1 = nn.LSTMCell(lstm_dims, lstm_dims)
261
+ self.res_rnn2 = nn.LSTMCell(lstm_dims, lstm_dims)
262
+ self.mel_proj = nn.Linear(lstm_dims, n_mels * self.max_r, bias=False)
263
+ self.stop_proj = nn.Linear(encoder_dims + speaker_embedding_size + lstm_dims, 1)
264
+
265
+ def zoneout(self, prev, current, p=0.1):
266
+ device = next(self.parameters()).device # Use same device as parameters
267
+ mask = torch.zeros(prev.size(), device=device).bernoulli_(p)
268
+ return prev * mask + current * (1 - mask)
269
+
270
+ def forward(self, encoder_seq, encoder_seq_proj, prenet_in,
271
+ hidden_states, cell_states, context_vec, t, chars):
272
+
273
+ # Need this for reshaping mels
274
+ batch_size = encoder_seq.size(0)
275
+
276
+ # Unpack the hidden and cell states
277
+ attn_hidden, rnn1_hidden, rnn2_hidden = hidden_states
278
+ rnn1_cell, rnn2_cell = cell_states
279
+
280
+ # PreNet for the Attention RNN
281
+ prenet_out = self.prenet(prenet_in)
282
+
283
+ # Compute the Attention RNN hidden state
284
+ attn_rnn_in = torch.cat([context_vec, prenet_out], dim=-1)
285
+ attn_hidden = self.attn_rnn(attn_rnn_in.squeeze(1), attn_hidden)
286
+
287
+ # Compute the attention scores
288
+ scores = self.attn_net(encoder_seq_proj, attn_hidden, t, chars)
289
+
290
+ # Dot product to create the context vector
291
+ context_vec = scores @ encoder_seq
292
+ context_vec = context_vec.squeeze(1)
293
+
294
+ # Concat Attention RNN output w. Context Vector & project
295
+ x = torch.cat([context_vec, attn_hidden], dim=1)
296
+ x = self.rnn_input(x)
297
+
298
+ # Compute first Residual RNN
299
+ rnn1_hidden_next, rnn1_cell = self.res_rnn1(x, (rnn1_hidden, rnn1_cell))
300
+ if self.training:
301
+ rnn1_hidden = self.zoneout(rnn1_hidden, rnn1_hidden_next)
302
+ else:
303
+ rnn1_hidden = rnn1_hidden_next
304
+ x = x + rnn1_hidden
305
+
306
+ # Compute second Residual RNN
307
+ rnn2_hidden_next, rnn2_cell = self.res_rnn2(x, (rnn2_hidden, rnn2_cell))
308
+ if self.training:
309
+ rnn2_hidden = self.zoneout(rnn2_hidden, rnn2_hidden_next)
310
+ else:
311
+ rnn2_hidden = rnn2_hidden_next
312
+ x = x + rnn2_hidden
313
+
314
+ # Project Mels
315
+ mels = self.mel_proj(x)
316
+ mels = mels.view(batch_size, self.n_mels, self.max_r)[:, :, :self.r]
317
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
318
+ cell_states = (rnn1_cell, rnn2_cell)
319
+
320
+ # Stop token prediction
321
+ s = torch.cat((x, context_vec), dim=1)
322
+ s = self.stop_proj(s)
323
+ stop_tokens = torch.sigmoid(s)
324
+
325
+ return mels, scores, hidden_states, cell_states, context_vec, stop_tokens
326
+
327
+
328
+ class Tacotron(nn.Module):
329
+ def __init__(self, embed_dims, num_chars, encoder_dims, decoder_dims, n_mels,
330
+ fft_bins, postnet_dims, encoder_K, lstm_dims, postnet_K, num_highways,
331
+ dropout, stop_threshold, speaker_embedding_size):
332
+ super().__init__()
333
+ self.n_mels = n_mels
334
+ self.lstm_dims = lstm_dims
335
+ self.encoder_dims = encoder_dims
336
+ self.decoder_dims = decoder_dims
337
+ self.speaker_embedding_size = speaker_embedding_size
338
+ self.encoder = Encoder(embed_dims, num_chars, encoder_dims,
339
+ encoder_K, num_highways, dropout)
340
+ self.encoder_proj = nn.Linear(encoder_dims + speaker_embedding_size, decoder_dims, bias=False)
341
+ self.decoder = Decoder(n_mels, encoder_dims, decoder_dims, lstm_dims,
342
+ dropout, speaker_embedding_size)
343
+ self.postnet = CBHG(postnet_K, n_mels, postnet_dims,
344
+ [postnet_dims, fft_bins], num_highways)
345
+ self.post_proj = nn.Linear(postnet_dims, fft_bins, bias=False)
346
+
347
+ self.init_model()
348
+ self.num_params()
349
+
350
+ self.register_buffer("step", torch.zeros(1, dtype=torch.long))
351
+ self.register_buffer("stop_threshold", torch.tensor(stop_threshold, dtype=torch.float32))
352
+
353
+ @property
354
+ def r(self):
355
+ return self.decoder.r.item()
356
+
357
+ @r.setter
358
+ def r(self, value):
359
+ self.decoder.r = self.decoder.r.new_tensor(value, requires_grad=False)
360
+
361
+ def forward(self, x, m, speaker_embedding):
362
+ device = next(self.parameters()).device # use same device as parameters
363
+
364
+ self.step += 1
365
+ batch_size, _, steps = m.size()
366
+
367
+ # Initialise all hidden states and pack into tuple
368
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
369
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
370
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
371
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
372
+
373
+ # Initialise all lstm cell states and pack into tuple
374
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
375
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
376
+ cell_states = (rnn1_cell, rnn2_cell)
377
+
378
+ # <GO> Frame for start of decoder loop
379
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
380
+
381
+ # Need an initial context vector
382
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
383
+
384
+ # SV2TTS: Run the encoder with the speaker embedding
385
+ # The projection avoids unnecessary matmuls in the decoder loop
386
+ encoder_seq = self.encoder(x, speaker_embedding)
387
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
388
+
389
+ # Need a couple of lists for outputs
390
+ mel_outputs, attn_scores, stop_outputs = [], [], []
391
+
392
+ # Run the decoder loop
393
+ for t in range(0, steps, self.r):
394
+ prenet_in = m[:, :, t - 1] if t > 0 else go_frame
395
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
396
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
397
+ hidden_states, cell_states, context_vec, t, x)
398
+ mel_outputs.append(mel_frames)
399
+ attn_scores.append(scores)
400
+ stop_outputs.extend([stop_tokens] * self.r)
401
+
402
+ # Concat the mel outputs into sequence
403
+ mel_outputs = torch.cat(mel_outputs, dim=2)
404
+
405
+ # Post-Process for Linear Spectrograms
406
+ postnet_out = self.postnet(mel_outputs)
407
+ linear = self.post_proj(postnet_out)
408
+ linear = linear.transpose(1, 2)
409
+
410
+ # For easy visualisation
411
+ attn_scores = torch.cat(attn_scores, 1)
412
+ # attn_scores = attn_scores.cpu().data.numpy()
413
+ stop_outputs = torch.cat(stop_outputs, 1)
414
+
415
+ return mel_outputs, linear, attn_scores, stop_outputs
416
+
417
+ def generate(self, x, speaker_embedding=None, steps=2000):
418
+ self.eval()
419
+ device = next(self.parameters()).device # use same device as parameters
420
+
421
+ batch_size, _ = x.size()
422
+
423
+ # Need to initialise all hidden states and pack into tuple for tidiness
424
+ attn_hidden = torch.zeros(batch_size, self.decoder_dims, device=device)
425
+ rnn1_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
426
+ rnn2_hidden = torch.zeros(batch_size, self.lstm_dims, device=device)
427
+ hidden_states = (attn_hidden, rnn1_hidden, rnn2_hidden)
428
+
429
+ # Need to initialise all lstm cell states and pack into tuple for tidiness
430
+ rnn1_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
431
+ rnn2_cell = torch.zeros(batch_size, self.lstm_dims, device=device)
432
+ cell_states = (rnn1_cell, rnn2_cell)
433
+
434
+ # Need a <GO> Frame for start of decoder loop
435
+ go_frame = torch.zeros(batch_size, self.n_mels, device=device)
436
+
437
+ # Need an initial context vector
438
+ context_vec = torch.zeros(batch_size, self.encoder_dims + self.speaker_embedding_size, device=device)
439
+
440
+ # SV2TTS: Run the encoder with the speaker embedding
441
+ # The projection avoids unnecessary matmuls in the decoder loop
442
+ encoder_seq = self.encoder(x, speaker_embedding)
443
+ encoder_seq_proj = self.encoder_proj(encoder_seq)
444
+
445
+ # Need a couple of lists for outputs
446
+ mel_outputs, attn_scores, stop_outputs = [], [], []
447
+
448
+ # Run the decoder loop
449
+ for t in range(0, steps, self.r):
450
+ prenet_in = mel_outputs[-1][:, :, -1] if t > 0 else go_frame
451
+ mel_frames, scores, hidden_states, cell_states, context_vec, stop_tokens = \
452
+ self.decoder(encoder_seq, encoder_seq_proj, prenet_in,
453
+ hidden_states, cell_states, context_vec, t, x)
454
+ mel_outputs.append(mel_frames)
455
+ attn_scores.append(scores)
456
+ stop_outputs.extend([stop_tokens] * self.r)
457
+ # Stop the loop when all stop tokens in batch exceed threshold
458
+ if (stop_tokens > 0.5).all() and t > 10: break
459
+
460
+ # Concat the mel outputs into sequence
461
+ mel_outputs = torch.cat(mel_outputs, dim=2)
462
+
463
+ # Post-Process for Linear Spectrograms
464
+ postnet_out = self.postnet(mel_outputs)
465
+ linear = self.post_proj(postnet_out)
466
+
467
+
468
+ linear = linear.transpose(1, 2)
469
+
470
+ # For easy visualisation
471
+ attn_scores = torch.cat(attn_scores, 1)
472
+ stop_outputs = torch.cat(stop_outputs, 1)
473
+
474
+ self.train()
475
+
476
+ return mel_outputs, linear, attn_scores
477
+
478
+ def init_model(self):
479
+ for p in self.parameters():
480
+ if p.dim() > 1: nn.init.xavier_uniform_(p)
481
+
482
+ def get_step(self):
483
+ return self.step.data.item()
484
+
485
+ def reset_step(self):
486
+ # assignment to parameters or buffers is overloaded, updates internal dict entry
487
+ self.step = self.step.data.new_tensor(1)
488
+
489
+ def log(self, path, msg):
490
+ with open(path, "a") as f:
491
+ print(msg, file=f)
492
+
493
+ def load(self, path, optimizer=None):
494
+ # Use device of model params as location for loaded state
495
+ device = next(self.parameters()).device
496
+ checkpoint = torch.load(str(path), map_location=device)
497
+ self.load_state_dict(checkpoint["model_state"])
498
+
499
+ if "optimizer_state" in checkpoint and optimizer is not None:
500
+ optimizer.load_state_dict(checkpoint["optimizer_state"])
501
+
502
+ def save(self, path, optimizer=None):
503
+ if optimizer is not None:
504
+ torch.save({
505
+ "model_state": self.state_dict(),
506
+ "optimizer_state": optimizer.state_dict(),
507
+ }, str(path))
508
+ else:
509
+ torch.save({
510
+ "model_state": self.state_dict(),
511
+ }, str(path))
512
+
513
+
514
+ def num_params(self, print_out=True):
515
+ parameters = filter(lambda p: p.requires_grad, self.parameters())
516
+ parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
517
+ if print_out:
518
+ print("Trainable Parameters: %.3fM" % parameters)
519
+ return parameters
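To make the interface above concrete, a minimal inference sketch. It assumes the hparams object defined in synthesizer/hparams.py and the same tts_* fields that synthesizer/train.py passes in; the character ids and speaker embedding are random placeholders, not a real utterance:
import torch
from synthesizer.hparams import hparams
from synthesizer.utils.symbols import symbols
from synthesizer.models.tacotron import Tacotron

model = Tacotron(embed_dims=hparams.tts_embed_dims, num_chars=len(symbols),
                 encoder_dims=hparams.tts_encoder_dims, decoder_dims=hparams.tts_decoder_dims,
                 n_mels=hparams.num_mels, fft_bins=hparams.num_mels,
                 postnet_dims=hparams.tts_postnet_dims, encoder_K=hparams.tts_encoder_K,
                 lstm_dims=hparams.tts_lstm_dims, postnet_K=hparams.tts_postnet_K,
                 num_highways=hparams.tts_num_highways, dropout=hparams.tts_dropout,
                 stop_threshold=hparams.tts_stop_threshold,
                 speaker_embedding_size=hparams.speaker_embedding_size)
chars = torch.randint(1, len(symbols), (1, 20))                     # placeholder character ids
speaker_embedding = torch.zeros(1, hparams.speaker_embedding_size)  # placeholder speaker embedding
mels, linear, attn = model.generate(chars, speaker_embedding, steps=200)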
synthesizer/preprocess.py ADDED
@@ -0,0 +1,259 @@
1
+ from multiprocessing.pool import Pool
2
+ from synthesizer import audio
3
+ from functools import partial
4
+ from itertools import chain
5
+ from encoder import inference as encoder
6
+ from pathlib import Path
7
+ from utils import logmmse
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import librosa
11
+
12
+
13
+ def preprocess_dataset(datasets_root: Path, out_dir: Path, n_processes: int,
14
+ skip_existing: bool, hparams, no_alignments: bool,
15
+ datasets_name: str, subfolders: str):
16
+ # Gather the input directories
17
+ dataset_root = datasets_root.joinpath(datasets_name)
18
+ input_dirs = [dataset_root.joinpath(subfolder.strip()) for subfolder in subfolders.split(",")]
19
+ print("\n ".join(map(str, ["Using data from:"] + input_dirs)))
20
+ assert all(input_dir.exists() for input_dir in input_dirs)
21
+
22
+ # Create the output directories for each output file type
23
+ out_dir.joinpath("mels").mkdir(exist_ok=True)
24
+ out_dir.joinpath("audio").mkdir(exist_ok=True)
25
+
26
+ # Create a metadata file
27
+ metadata_fpath = out_dir.joinpath("train.txt")
28
+ metadata_file = metadata_fpath.open("a" if skip_existing else "w", encoding="utf-8")
29
+
30
+ # Preprocess the dataset
31
+ speaker_dirs = list(chain.from_iterable(input_dir.glob("*") for input_dir in input_dirs))
32
+ func = partial(preprocess_speaker, out_dir=out_dir, skip_existing=skip_existing,
33
+ hparams=hparams, no_alignments=no_alignments)
34
+ job = Pool(n_processes).imap(func, speaker_dirs)
35
+ for speaker_metadata in tqdm(job, datasets_name, len(speaker_dirs), unit="speakers"):
36
+ for metadatum in speaker_metadata:
37
+ metadata_file.write("|".join(str(x) for x in metadatum) + "\n")
38
+ metadata_file.close()
39
+
40
+ # Verify the contents of the metadata file
41
+ with metadata_fpath.open("r", encoding="utf-8") as metadata_file:
42
+ metadata = [line.split("|") for line in metadata_file]
43
+ mel_frames = sum([int(m[4]) for m in metadata])
44
+ timesteps = sum([int(m[3]) for m in metadata])
45
+ sample_rate = hparams.sample_rate
46
+ hours = (timesteps / sample_rate) / 3600
47
+ print("The dataset consists of %d utterances, %d mel frames, %d audio timesteps (%.2f hours)." %
48
+ (len(metadata), mel_frames, timesteps, hours))
49
+ print("Max input length (text chars): %d" % max(len(m[5]) for m in metadata))
50
+ print("Max mel frames length: %d" % max(int(m[4]) for m in metadata))
51
+ print("Max audio timesteps length: %d" % max(int(m[3]) for m in metadata))
52
+
53
+
54
+ def preprocess_speaker(speaker_dir, out_dir: Path, skip_existing: bool, hparams, no_alignments: bool):
55
+ metadata = []
56
+ for book_dir in speaker_dir.glob("*"):
57
+ if no_alignments:
58
+ # Gather the utterance audios and texts
59
+ # LibriTTS uses .wav but we will include extensions for compatibility with other datasets
60
+ extensions = ["*.wav", "*.flac", "*.mp3"]
61
+ for extension in extensions:
62
+ wav_fpaths = book_dir.glob(extension)
63
+
64
+ for wav_fpath in wav_fpaths:
65
+ # Load the audio waveform
66
+ wav, _ = librosa.load(str(wav_fpath), sr=hparams.sample_rate)  # pass sr by keyword for newer librosa versions
67
+ if hparams.rescale:
68
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
69
+
70
+ # Get the corresponding text
71
+ # Check for .txt (for compatibility with other datasets)
72
+ text_fpath = wav_fpath.with_suffix(".txt")
73
+ if not text_fpath.exists():
74
+ # Check for .normalized.txt (LibriTTS)
75
+ text_fpath = wav_fpath.with_suffix(".normalized.txt")
76
+ assert text_fpath.exists()
77
+ with text_fpath.open("r") as text_file:
78
+ text = "".join([line for line in text_file])
79
+ text = text.replace("\"", "")
80
+ text = text.strip()
81
+
82
+ # Process the utterance
83
+ metadata.append(process_utterance(wav, text, out_dir, str(wav_fpath.with_suffix("").name),
84
+ skip_existing, hparams))
85
+ else:
86
+ # Process alignment file (LibriSpeech support)
87
+ # Gather the utterance audios and texts
88
+ try:
89
+ alignments_fpath = next(book_dir.glob("*.alignment.txt"))
90
+ with alignments_fpath.open("r") as alignments_file:
91
+ alignments = [line.rstrip().split(" ") for line in alignments_file]
92
+ except StopIteration:
93
+ # A few alignment files will be missing
94
+ continue
95
+
96
+ # Iterate over each entry in the alignments file
97
+ for wav_fname, words, end_times in alignments:
98
+ wav_fpath = book_dir.joinpath(wav_fname + ".flac")
99
+ assert wav_fpath.exists()
100
+ words = words.replace("\"", "").split(",")
101
+ end_times = list(map(float, end_times.replace("\"", "").split(",")))
102
+
103
+ # Process each sub-utterance
104
+ wavs, texts = split_on_silences(wav_fpath, words, end_times, hparams)
105
+ for i, (wav, text) in enumerate(zip(wavs, texts)):
106
+ sub_basename = "%s_%02d" % (wav_fname, i)
107
+ metadata.append(process_utterance(wav, text, out_dir, sub_basename,
108
+ skip_existing, hparams))
109
+
110
+ return [m for m in metadata if m is not None]
111
+
112
+
113
+ def split_on_silences(wav_fpath, words, end_times, hparams):
114
+ # Load the audio waveform
115
+ wav, _ = librosa.load(str(wav_fpath), sr=hparams.sample_rate)  # pass sr by keyword for newer librosa versions
116
+ if hparams.rescale:
117
+ wav = wav / np.abs(wav).max() * hparams.rescaling_max
118
+
119
+ words = np.array(words)
120
+ start_times = np.array([0.0] + end_times[:-1])
121
+ end_times = np.array(end_times)
122
+ assert len(words) == len(end_times) == len(start_times)
123
+ assert words[0] == "" and words[-1] == ""
124
+
125
+ # Find pauses that are too long
126
+ mask = (words == "") & (end_times - start_times >= hparams.silence_min_duration_split)
127
+ mask[0] = mask[-1] = True
128
+ breaks = np.where(mask)[0]
129
+
130
+ # Profile the noise from the silences and perform noise reduction on the waveform
131
+ silence_times = [[start_times[i], end_times[i]] for i in breaks]
132
+ silence_times = (np.array(silence_times) * hparams.sample_rate).astype(int)  # np.int is removed in recent numpy
133
+ noisy_wav = np.concatenate([wav[stime[0]:stime[1]] for stime in silence_times])
134
+ if len(noisy_wav) > hparams.sample_rate * 0.02:
135
+ profile = logmmse.profile_noise(noisy_wav, hparams.sample_rate)
136
+ wav = logmmse.denoise(wav, profile, eta=0)
137
+
138
+ # Re-attach segments that are too short
139
+ segments = list(zip(breaks[:-1], breaks[1:]))
140
+ segment_durations = [start_times[end] - end_times[start] for start, end in segments]
141
+ i = 0
142
+ while i < len(segments) and len(segments) > 1:
143
+ if segment_durations[i] < hparams.utterance_min_duration:
144
+ # See if the segment can be re-attached with the right or the left segment
145
+ left_duration = float("inf") if i == 0 else segment_durations[i - 1]
146
+ right_duration = float("inf") if i == len(segments) - 1 else segment_durations[i + 1]
147
+ joined_duration = segment_durations[i] + min(left_duration, right_duration)
148
+
149
+ # Do not re-attach if it causes the joined utterance to be too long
150
+ if joined_duration > hparams.hop_size * hparams.max_mel_frames / hparams.sample_rate:
151
+ i += 1
152
+ continue
153
+
154
+ # Re-attach the segment with the neighbour of shortest duration
155
+ j = i - 1 if left_duration <= right_duration else i
156
+ segments[j] = (segments[j][0], segments[j + 1][1])
157
+ segment_durations[j] = joined_duration
158
+ del segments[j + 1], segment_durations[j + 1]
159
+ else:
160
+ i += 1
161
+
162
+ # Split the utterance
163
+ segment_times = [[end_times[start], start_times[end]] for start, end in segments]
164
+ segment_times = (np.array(segment_times) * hparams.sample_rate).astype(int)
165
+ wavs = [wav[segment_time[0]:segment_time[1]] for segment_time in segment_times]
166
+ texts = [" ".join(words[start + 1:end]).replace(" ", " ") for start, end in segments]
167
+
168
+ # # DEBUG: play the audio segments (run with -n=1)
169
+ # import sounddevice as sd
170
+ # if len(wavs) > 1:
171
+ # print("This sentence was split in %d segments:" % len(wavs))
172
+ # else:
173
+ # print("There are no silences long enough for this sentence to be split:")
174
+ # for wav, text in zip(wavs, texts):
175
+ # # Pad the waveform with 1 second of silence because sounddevice tends to cut them early
176
+ # # when playing them. You shouldn't need to do that in your parsers.
177
+ # wav = np.concatenate((wav, [0] * 16000))
178
+ # print("\t%s" % text)
179
+ # sd.play(wav, 16000, blocking=True)
180
+ # print("")
181
+
182
+ return wavs, texts
183
+
184
+
185
+ def process_utterance(wav: np.ndarray, text: str, out_dir: Path, basename: str,
186
+ skip_existing: bool, hparams):
187
+ ## FOR REFERENCE:
188
+ # For you not to lose your head if you ever wish to change things here or implement your own
189
+ # synthesizer.
190
+ # - Both the audios and the mel spectrograms are saved as numpy arrays
191
+ # - There is no processing done to the audios that will be saved to disk beyond volume
192
+ # normalization (in split_on_silences)
193
+ # - However, pre-emphasis is applied to the audios before computing the mel spectrogram. This
194
+ # is why we re-apply it on the audio on the side of the vocoder.
195
+ # - Librosa pads the waveform before computing the mel spectrogram. Here, the waveform is saved
196
+ # without extra padding. This means that you won't have an exact relation between the length
197
+ # of the wav and of the mel spectrogram. See the vocoder data loader.
198
+
199
+
200
+ # Skip existing utterances if needed
201
+ mel_fpath = out_dir.joinpath("mels", "mel-%s.npy" % basename)
202
+ wav_fpath = out_dir.joinpath("audio", "audio-%s.npy" % basename)
203
+ if skip_existing and mel_fpath.exists() and wav_fpath.exists():
204
+ return None
205
+
206
+ # Trim silence
207
+ if hparams.trim_silence:
208
+ wav = encoder.preprocess_wav(wav, normalize=False, trim_silence=True)
209
+
210
+ # Skip utterances that are too short
211
+ if len(wav) < hparams.utterance_min_duration * hparams.sample_rate:
212
+ return None
213
+
214
+ # Compute the mel spectrogram
215
+ mel_spectrogram = audio.melspectrogram(wav, hparams).astype(np.float32)
216
+ mel_frames = mel_spectrogram.shape[1]
217
+
218
+ # Skip utterances that are too long
219
+ if mel_frames > hparams.max_mel_frames and hparams.clip_mels_length:
220
+ return None
221
+
222
+ # Write the spectrogram, embed and audio to disk
223
+ np.save(mel_fpath, mel_spectrogram.T, allow_pickle=False)
224
+ np.save(wav_fpath, wav, allow_pickle=False)
225
+
226
+ # Return a tuple describing this training example
227
+ return wav_fpath.name, mel_fpath.name, "embed-%s.npy" % basename, len(wav), mel_frames, text
228
+
229
+
230
+ def embed_utterance(fpaths, encoder_model_fpath):
231
+ if not encoder.is_loaded():
232
+ encoder.load_model(encoder_model_fpath)
233
+
234
+ # Compute the speaker embedding of the utterance
235
+ wav_fpath, embed_fpath = fpaths
236
+ wav = np.load(wav_fpath)
237
+ wav = encoder.preprocess_wav(wav)
238
+ embed = encoder.embed_utterance(wav)
239
+ np.save(embed_fpath, embed, allow_pickle=False)
240
+
241
+
242
+ def create_embeddings(synthesizer_root: Path, encoder_model_fpath: Path, n_processes: int):
243
+ wav_dir = synthesizer_root.joinpath("audio")
244
+ metadata_fpath = synthesizer_root.joinpath("train.txt")
245
+ assert wav_dir.exists() and metadata_fpath.exists()
246
+ embed_dir = synthesizer_root.joinpath("embeds")
247
+ embed_dir.mkdir(exist_ok=True)
248
+
249
+ # Gather the input wave filepath and the target output embed filepath
250
+ with metadata_fpath.open("r") as metadata_file:
251
+ metadata = [line.split("|") for line in metadata_file]
252
+ fpaths = [(wav_dir.joinpath(m[0]), embed_dir.joinpath(m[2])) for m in metadata]
253
+
254
+ # TODO: improve on the multiprocessing, it's terrible. Disk I/O is the bottleneck here.
255
+ # Embed the utterances in separate processes
256
+ func = partial(embed_utterance, encoder_model_fpath=encoder_model_fpath)
257
+ job = Pool(n_processes).imap(func, fpaths)
258
+ list(tqdm(job, "Embedding", len(fpaths), unit="utterances"))
259
+
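For orientation, a sketch of how these entry points fit together. The paths are placeholders and the repo's command-line preprocessing scripts normally supply them via argparse:
from pathlib import Path
from synthesizer.hparams import hparams
from synthesizer.preprocess import preprocess_dataset, create_embeddings

datasets_root = Path("datasets")                 # assumed layout: datasets/LibriSpeech/train-clean-100/...
out_dir = Path("datasets/SV2TTS/synthesizer")
out_dir.mkdir(parents=True, exist_ok=True)

preprocess_dataset(datasets_root, out_dir, n_processes=4, skip_existing=True, hparams=hparams,
                   no_alignments=False, datasets_name="LibriSpeech",
                   subfolders="train-clean-100, train-clean-360")
create_embeddings(out_dir, encoder_model_fpath=Path("encoder/saved_models/pretrained.pt"),
                  n_processes=4)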
synthesizer/saved_models/pretrained/text.txt ADDED
File without changes
synthesizer/synthesize.py ADDED
@@ -0,0 +1,97 @@
1
+ import torch
2
+ from torch.utils.data import DataLoader
3
+ from synthesizer.hparams import hparams_debug_string
4
+ from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
5
+ from synthesizer.models.tacotron import Tacotron
6
+ from synthesizer.utils.text import text_to_sequence
7
+ from synthesizer.utils.symbols import symbols
+ from synthesizer.utils import data_parallel_workaround  # used in the multi-GPU branch below
8
+ import numpy as np
9
+ from pathlib import Path
10
+ from tqdm import tqdm
11
+ import platform
12
+
13
+ def run_synthesis(in_dir, out_dir, model_dir, hparams):
14
+ # This generates ground truth-aligned mels for vocoder training
15
+ synth_dir = Path(out_dir).joinpath("mels_gta")
16
+ synth_dir.mkdir(exist_ok=True)
17
+ print(hparams_debug_string())
18
+
19
+ # Check for GPU
20
+ if torch.cuda.is_available():
21
+ device = torch.device("cuda")
22
+ if hparams.synthesis_batch_size % torch.cuda.device_count() != 0:
23
+ raise ValueError("`hparams.synthesis_batch_size` must be evenly divisible by n_gpus!")
24
+ else:
25
+ device = torch.device("cpu")
26
+ print("Synthesizer using device:", device)
27
+
28
+ # Instantiate Tacotron model
29
+ model = Tacotron(embed_dims=hparams.tts_embed_dims,
30
+ num_chars=len(symbols),
31
+ encoder_dims=hparams.tts_encoder_dims,
32
+ decoder_dims=hparams.tts_decoder_dims,
33
+ n_mels=hparams.num_mels,
34
+ fft_bins=hparams.num_mels,
35
+ postnet_dims=hparams.tts_postnet_dims,
36
+ encoder_K=hparams.tts_encoder_K,
37
+ lstm_dims=hparams.tts_lstm_dims,
38
+ postnet_K=hparams.tts_postnet_K,
39
+ num_highways=hparams.tts_num_highways,
40
+ dropout=0., # Use zero dropout for gta mels
41
+ stop_threshold=hparams.tts_stop_threshold,
42
+ speaker_embedding_size=hparams.speaker_embedding_size).to(device)
43
+
44
+ # Load the weights
45
+ model_dir = Path(model_dir)
46
+ model_fpath = model_dir.joinpath(model_dir.stem).with_suffix(".pt")
47
+ print("\nLoading weights at %s" % model_fpath)
48
+ model.load(model_fpath)
49
+ print("Tacotron weights loaded from step %d" % model.step)
50
+
51
+ # Synthesize using same reduction factor as the model is currently trained
52
+ r = np.int32(model.r)
53
+
54
+ # Set model to eval mode (disable gradient and zoneout)
55
+ model.eval()
56
+
57
+ # Initialize the dataset
58
+ in_dir = Path(in_dir)
59
+ metadata_fpath = in_dir.joinpath("train.txt")
60
+ mel_dir = in_dir.joinpath("mels")
61
+ embed_dir = in_dir.joinpath("embeds")
62
+
63
+ dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
64
+ data_loader = DataLoader(dataset,
65
+ collate_fn=lambda batch: collate_synthesizer(batch, r, hparams),
66
+ batch_size=hparams.synthesis_batch_size,
67
+ num_workers=2 if platform.system() != "Windows" else 0,
68
+ shuffle=False,
69
+ pin_memory=True)
70
+
71
+ # Generate GTA mels
72
+ meta_out_fpath = Path(out_dir).joinpath("synthesized.txt")
73
+ with open(meta_out_fpath, "w") as file:
74
+ for i, (texts, mels, embeds, idx) in tqdm(enumerate(data_loader), total=len(data_loader)):
75
+ texts = texts.to(device)
76
+ mels = mels.to(device)
77
+ embeds = embeds.to(device)
78
+
79
+ # Parallelize model onto GPUS using workaround due to python bug
80
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
81
+ _, mels_out, _ = data_parallel_workaround(model, texts, mels, embeds)
82
+ else:
83
+ _, mels_out, _, _ = model(texts, mels, embeds)
84
+
85
+ for j, k in enumerate(idx):
86
+ # Note: outputs mel-spectrogram files and target ones have same names, just different folders
87
+ mel_filename = Path(synth_dir).joinpath(dataset.metadata[k][1])
88
+ mel_out = mels_out[j].detach().cpu().numpy().T
89
+
90
+ # Use the length of the ground truth mel to remove padding from the generated mels
91
+ mel_out = mel_out[:int(dataset.metadata[k][4])]
92
+
93
+ # Write the spectrogram to disk
94
+ np.save(mel_filename, mel_out, allow_pickle=False)
95
+
96
+ # Write metadata into the synthesized file
97
+ file.write("|".join(dataset.metadata[k]))
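A hedged example call; the directory names are placeholders that follow the layout produced by the synthesizer preprocessing above, and the model directory is expected to contain a .pt file named after the directory itself:
from synthesizer.hparams import hparams
from synthesizer.synthesize import run_synthesis

run_synthesis(in_dir="datasets/SV2TTS/synthesizer",
              out_dir="datasets/SV2TTS/vocoder",
              model_dir="synthesizer/saved_models/pretrained",   # expects pretrained/pretrained.pt
              hparams=hparams)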
synthesizer/synthesizer_dataset.py ADDED
@@ -0,0 +1,92 @@
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ from pathlib import Path
5
+ from synthesizer.utils.text import text_to_sequence
6
+
7
+
8
+ class SynthesizerDataset(Dataset):
9
+ def __init__(self, metadata_fpath: Path, mel_dir: Path, embed_dir: Path, hparams):
10
+ print("Using inputs from:\n\t%s\n\t%s\n\t%s" % (metadata_fpath, mel_dir, embed_dir))
11
+
12
+ with metadata_fpath.open("r") as metadata_file:
13
+ metadata = [line.split("|") for line in metadata_file]
14
+
15
+ mel_fnames = [x[1] for x in metadata if int(x[4])]
16
+ mel_fpaths = [mel_dir.joinpath(fname) for fname in mel_fnames]
17
+ embed_fnames = [x[2] for x in metadata if int(x[4])]
18
+ embed_fpaths = [embed_dir.joinpath(fname) for fname in embed_fnames]
19
+ self.samples_fpaths = list(zip(mel_fpaths, embed_fpaths))
20
+ self.samples_texts = [x[5].strip() for x in metadata if int(x[4])]
21
+ self.metadata = metadata
22
+ self.hparams = hparams
23
+
24
+ print("Found %d samples" % len(self.samples_fpaths))
25
+
26
+ def __getitem__(self, index):
27
+ # Sometimes index may be a list of 2 (not sure why this happens)
28
+ # If that is the case, return a single item corresponding to first element in index
29
+ if isinstance(index, list):
30
+ index = index[0]
31
+
32
+ mel_path, embed_path = self.samples_fpaths[index]
33
+ mel = np.load(mel_path).T.astype(np.float32)
34
+
35
+ # Load the embed
36
+ embed = np.load(embed_path)
37
+
38
+ # Get the text and clean it
39
+ text = text_to_sequence(self.samples_texts[index], self.hparams.tts_cleaner_names)
40
+
41
+ # Convert the list returned by text_to_sequence to a numpy array
42
+ text = np.asarray(text).astype(np.int32)
43
+
44
+ return text, mel.astype(np.float32), embed.astype(np.float32), index
45
+
46
+ def __len__(self):
47
+ return len(self.samples_fpaths)
48
+
49
+
50
+ def collate_synthesizer(batch, r, hparams):
51
+ # Text
52
+ x_lens = [len(x[0]) for x in batch]
53
+ max_x_len = max(x_lens)
54
+
55
+ chars = [pad1d(x[0], max_x_len) for x in batch]
56
+ chars = np.stack(chars)
57
+
58
+ # Mel spectrogram
59
+ spec_lens = [x[1].shape[-1] for x in batch]
60
+ max_spec_len = max(spec_lens) + 1
61
+ if max_spec_len % r != 0:
62
+ max_spec_len += r - max_spec_len % r
63
+
64
+ # WaveRNN mel spectrograms are normalized to [0, 1] so zero padding adds silence
65
+ # By default, SV2TTS uses symmetric mels, where -1*max_abs_value is silence.
66
+ if hparams.symmetric_mels:
67
+ mel_pad_value = -1 * hparams.max_abs_value
68
+ else:
69
+ mel_pad_value = 0
70
+
71
+ mel = [pad2d(x[1], max_spec_len, pad_value=mel_pad_value) for x in batch]
72
+ mel = np.stack(mel)
73
+
74
+ # Speaker embedding (SV2TTS)
75
+ embeds = [x[2] for x in batch]
76
+
77
+ # Index (for vocoder preprocessing)
78
+ indices = [x[3] for x in batch]
79
+
80
+
81
+ # Convert all to tensor
82
+ chars = torch.tensor(chars).long()
83
+ mel = torch.tensor(mel)
84
+ embeds = torch.tensor(embeds)
85
+
86
+ return chars, mel, embeds, indices
87
+
88
+ def pad1d(x, max_len, pad_value=0):
89
+ return np.pad(x, (0, max_len - len(x)), mode="constant", constant_values=pad_value)
90
+
91
+ def pad2d(x, max_len, pad_value=0):
92
+ return np.pad(x, ((0, 0), (0, max_len - x.shape[-1])), mode="constant", constant_values=pad_value)
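A small sketch of wiring this dataset to a DataLoader, mirroring what train.py and synthesize.py do; the path, batch size and reduction factor r are placeholders:
from pathlib import Path
from functools import partial
from torch.utils.data import DataLoader
from synthesizer.hparams import hparams
from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer

syn_dir = Path("datasets/SV2TTS/synthesizer")
dataset = SynthesizerDataset(syn_dir.joinpath("train.txt"), syn_dir.joinpath("mels"),
                             syn_dir.joinpath("embeds"), hparams)
loader = DataLoader(dataset, batch_size=12, shuffle=True,
                    collate_fn=partial(collate_synthesizer, r=2, hparams=hparams))
chars, mels, embeds, indices = next(iter(loader))   # chars: (B, T_text), mels: (B, n_mels, T_mel)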
synthesizer/train.py ADDED
@@ -0,0 +1,269 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import optim
4
+ from torch.utils.data import DataLoader
5
+ from synthesizer import audio
6
+ from synthesizer.models.tacotron import Tacotron
7
+ from synthesizer.synthesizer_dataset import SynthesizerDataset, collate_synthesizer
8
+ from synthesizer.utils import ValueWindow, data_parallel_workaround
9
+ from synthesizer.utils.plot import plot_spectrogram
10
+ from synthesizer.utils.symbols import symbols
11
+ from synthesizer.utils.text import sequence_to_text
12
+ from vocoder.display import *
13
+ from datetime import datetime
14
+ import numpy as np
15
+ from pathlib import Path
16
+ import sys
17
+ import time
18
+ import platform
19
+
20
+
21
+ def np_now(x: torch.Tensor): return x.detach().cpu().numpy()
22
+
23
+ def time_string():
24
+ return datetime.now().strftime("%Y-%m-%d %H:%M")
25
+
26
+ def train(run_id: str, syn_dir: str, models_dir: str, save_every: int,
27
+ backup_every: int, force_restart:bool, hparams):
28
+
29
+ syn_dir = Path(syn_dir)
30
+ models_dir = Path(models_dir)
31
+ models_dir.mkdir(exist_ok=True)
32
+
33
+ model_dir = models_dir.joinpath(run_id)
34
+ plot_dir = model_dir.joinpath("plots")
35
+ wav_dir = model_dir.joinpath("wavs")
36
+ mel_output_dir = model_dir.joinpath("mel-spectrograms")
37
+ meta_folder = model_dir.joinpath("metas")
38
+ model_dir.mkdir(exist_ok=True)
39
+ plot_dir.mkdir(exist_ok=True)
40
+ wav_dir.mkdir(exist_ok=True)
41
+ mel_output_dir.mkdir(exist_ok=True)
42
+ meta_folder.mkdir(exist_ok=True)
43
+
44
+ weights_fpath = model_dir.joinpath(run_id).with_suffix(".pt")
45
+ metadata_fpath = syn_dir.joinpath("train.txt")
46
+
47
+ print("Checkpoint path: {}".format(weights_fpath))
48
+ print("Loading training data from: {}".format(metadata_fpath))
49
+ print("Using model: Tacotron")
50
+
51
+ # Book keeping
52
+ step = 0
53
+ time_window = ValueWindow(100)
54
+ loss_window = ValueWindow(100)
55
+
56
+
57
+ # From WaveRNN/train_tacotron.py
58
+ if torch.cuda.is_available():
59
+ device = torch.device("cuda")
60
+
61
+ for session in hparams.tts_schedule:
62
+ _, _, _, batch_size = session
63
+ if batch_size % torch.cuda.device_count() != 0:
64
+ raise ValueError("`batch_size` must be evenly divisible by n_gpus!")
65
+ else:
66
+ device = torch.device("cpu")
67
+ print("Using device:", device)
68
+
69
+ # Instantiate Tacotron Model
70
+ print("\nInitialising Tacotron Model...\n")
71
+ model = Tacotron(embed_dims=hparams.tts_embed_dims,
72
+ num_chars=len(symbols),
73
+ encoder_dims=hparams.tts_encoder_dims,
74
+ decoder_dims=hparams.tts_decoder_dims,
75
+ n_mels=hparams.num_mels,
76
+ fft_bins=hparams.num_mels,
77
+ postnet_dims=hparams.tts_postnet_dims,
78
+ encoder_K=hparams.tts_encoder_K,
79
+ lstm_dims=hparams.tts_lstm_dims,
80
+ postnet_K=hparams.tts_postnet_K,
81
+ num_highways=hparams.tts_num_highways,
82
+ dropout=hparams.tts_dropout,
83
+ stop_threshold=hparams.tts_stop_threshold,
84
+ speaker_embedding_size=hparams.speaker_embedding_size).to(device)
85
+
86
+ # Initialize the optimizer
87
+ optimizer = optim.Adam(model.parameters())
88
+
89
+ # Load the weights
90
+ if force_restart or not weights_fpath.exists():
91
+ print("\nStarting the training of Tacotron from scratch\n")
92
+ model.save(weights_fpath)
93
+
94
+ # Embeddings metadata
95
+ char_embedding_fpath = meta_folder.joinpath("CharacterEmbeddings.tsv")
96
+ with open(char_embedding_fpath, "w", encoding="utf-8") as f:
97
+ for symbol in symbols:
98
+ if symbol == " ":
99
+ symbol = "\\s" # For visual purposes, swap space with \s
100
+
101
+ f.write("{}\n".format(symbol))
102
+
103
+ else:
104
+ print("\nLoading weights at %s" % weights_fpath)
105
+ model.load(weights_fpath, optimizer)
106
+ print("Tacotron weights loaded from step %d" % model.step)
107
+
108
+ # Initialize the dataset
109
+ metadata_fpath = syn_dir.joinpath("train.txt")
110
+ mel_dir = syn_dir.joinpath("mels")
111
+ embed_dir = syn_dir.joinpath("embeds")
112
+ dataset = SynthesizerDataset(metadata_fpath, mel_dir, embed_dir, hparams)
113
+ test_loader = DataLoader(dataset,
114
+ batch_size=1,
115
+ shuffle=True,
116
+ pin_memory=True)
117
+
118
+ for i, session in enumerate(hparams.tts_schedule):
119
+ current_step = model.get_step()
120
+
121
+ r, lr, max_step, batch_size = session
122
+
123
+ training_steps = max_step - current_step
124
+
125
+ # Do we need to change to the next session?
126
+ if current_step >= max_step:
127
+ # Are there no further sessions than the current one?
128
+ if i == len(hparams.tts_schedule) - 1:
129
+ # We have completed training. Save the model and exit
130
+ model.save(weights_fpath, optimizer)
131
+ break
132
+ else:
133
+ # There is a following session, go to it
134
+ continue
135
+
136
+ model.r = r
137
+
138
+ # Begin the training
139
+ simple_table([(f"Steps with r={r}", str(training_steps // 1000) + "k Steps"),
140
+ ("Batch Size", batch_size),
141
+ ("Learning Rate", lr),
142
+ ("Outputs/Step (r)", model.r)])
143
+
144
+ for p in optimizer.param_groups:
145
+ p["lr"] = lr
146
+
147
+ data_loader = DataLoader(dataset,
148
+ collate_fn=lambda batch: collate_synthesizer(batch, r, hparams),
149
+ batch_size=batch_size,
150
+ num_workers=2 if platform.system() != "Windows" else 0,
151
+ shuffle=True,
152
+ pin_memory=True)
153
+
154
+ total_iters = len(dataset)
155
+ steps_per_epoch = np.ceil(total_iters / batch_size).astype(np.int32)
156
+ epochs = np.ceil(training_steps / steps_per_epoch).astype(np.int32)
157
+
158
+ for epoch in range(1, epochs+1):
159
+ for i, (texts, mels, embeds, idx) in enumerate(data_loader, 1):
160
+ start_time = time.time()
161
+
162
+ # Generate stop tokens for training
163
+ stop = torch.ones(mels.shape[0], mels.shape[2])
164
+ for j, k in enumerate(idx):
165
+ stop[j, :int(dataset.metadata[k][4])-1] = 0
166
+
167
+ texts = texts.to(device)
168
+ mels = mels.to(device)
169
+ embeds = embeds.to(device)
170
+ stop = stop.to(device)
171
+
172
+ # Forward pass
173
+ # Parallelize model onto GPUS using workaround due to python bug
174
+ if device.type == "cuda" and torch.cuda.device_count() > 1:
175
+ m1_hat, m2_hat, attention, stop_pred = data_parallel_workaround(model, texts,
176
+ mels, embeds)
177
+ else:
178
+ m1_hat, m2_hat, attention, stop_pred = model(texts, mels, embeds)
179
+
180
+ # Backward pass
181
+ m1_loss = F.mse_loss(m1_hat, mels) + F.l1_loss(m1_hat, mels)
182
+ m2_loss = F.mse_loss(m2_hat, mels)
183
+ stop_loss = F.binary_cross_entropy(stop_pred, stop)
184
+
185
+ loss = m1_loss + m2_loss + stop_loss
186
+
187
+ optimizer.zero_grad()
188
+ loss.backward()
189
+
190
+ if hparams.tts_clip_grad_norm is not None:
191
+ grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), hparams.tts_clip_grad_norm)
192
+ if np.isnan(grad_norm.cpu()):
193
+ print("grad_norm was NaN!")
194
+
195
+ optimizer.step()
196
+
197
+ time_window.append(time.time() - start_time)
198
+ loss_window.append(loss.item())
199
+
200
+ step = model.get_step()
201
+ k = step // 1000
202
+
203
+ msg = f"| Epoch: {epoch}/{epochs} ({i}/{steps_per_epoch}) | Loss: {loss_window.average:#.4} | {1./time_window.average:#.2} steps/s | Step: {k}k | "
204
+ stream(msg)
205
+
206
+ # Backup or save model as appropriate
207
+ if backup_every != 0 and step % backup_every == 0 :
208
+ backup_fpath = Path("{}/{}_{}k.pt".format(str(weights_fpath.parent), run_id, k))
209
+ model.save(backup_fpath, optimizer)
210
+
211
+ if save_every != 0 and step % save_every == 0 :
212
+ # Must save latest optimizer state to ensure that resuming training
213
+ # doesn't produce artifacts
214
+ model.save(weights_fpath, optimizer)
215
+
216
+ # Evaluate model to generate samples
217
+ epoch_eval = hparams.tts_eval_interval == -1 and i == steps_per_epoch # If epoch is done
218
+ step_eval = hparams.tts_eval_interval > 0 and step % hparams.tts_eval_interval == 0 # Every N steps
219
+ if epoch_eval or step_eval:
220
+ for sample_idx in range(hparams.tts_eval_num_samples):
221
+ # At most, generate samples equal to number in the batch
222
+ if sample_idx + 1 <= len(texts):
223
+ # Remove padding from mels using frame length in metadata
224
+ mel_length = int(dataset.metadata[idx[sample_idx]][4])
225
+ mel_prediction = np_now(m2_hat[sample_idx]).T[:mel_length]
226
+ target_spectrogram = np_now(mels[sample_idx]).T[:mel_length]
227
+ attention_len = mel_length // model.r
228
+
229
+ eval_model(attention=np_now(attention[sample_idx][:, :attention_len]),
230
+ mel_prediction=mel_prediction,
231
+ target_spectrogram=target_spectrogram,
232
+ input_seq=np_now(texts[sample_idx]),
233
+ step=step,
234
+ plot_dir=plot_dir,
235
+ mel_output_dir=mel_output_dir,
236
+ wav_dir=wav_dir,
237
+ sample_num=sample_idx + 1,
238
+ loss=loss,
239
+ hparams=hparams)
240
+
241
+ # Break out of loop to update training schedule
242
+ if step >= max_step:
243
+ break
244
+
245
+ # Add line break after every epoch
246
+ print("")
247
+
248
+ def eval_model(attention, mel_prediction, target_spectrogram, input_seq, step,
249
+ plot_dir, mel_output_dir, wav_dir, sample_num, loss, hparams):
250
+ # Save some results for evaluation
251
+ attention_path = str(plot_dir.joinpath("attention_step_{}_sample_{}".format(step, sample_num)))
252
+ save_attention(attention, attention_path)
253
+
254
+ # save predicted mel spectrogram to disk (debug)
255
+ mel_output_fpath = mel_output_dir.joinpath("mel-prediction-step-{}_sample_{}.npy".format(step, sample_num))
256
+ np.save(str(mel_output_fpath), mel_prediction, allow_pickle=False)
257
+
258
+ # save griffin lim inverted wav for debug (mel -> wav)
259
+ wav = audio.inv_mel_spectrogram(mel_prediction.T, hparams)
260
+ wav_fpath = wav_dir.joinpath("step-{}-wave-from-mel_sample_{}.wav".format(step, sample_num))
261
+ audio.save_wav(wav, str(wav_fpath), sr=hparams.sample_rate)
262
+
263
+ # save real and predicted mel-spectrogram plot to disk (control purposes)
264
+ spec_fpath = plot_dir.joinpath("step-{}-mel-spectrogram_sample_{}.png".format(step, sample_num))
265
+ title_str = "{}, {}, step={}, loss={:.5f}".format("Tacotron", time_string(), step, loss)
266
+ plot_spectrogram(mel_prediction, str(spec_fpath), title=title_str,
267
+ target_spectrogram=target_spectrogram,
268
+ max_len=target_spectrogram.size // hparams.num_mels)
269
+ print("Input at step {}: {}".format(step, sequence_to_text(input_seq)))
synthesizer/utils/__init__.py ADDED
@@ -0,0 +1,45 @@
1
+ import torch
2
+
3
+
4
+ _output_ref = None
5
+ _replicas_ref = None
6
+
7
+ def data_parallel_workaround(model, *input):
8
+ global _output_ref
9
+ global _replicas_ref
10
+ device_ids = list(range(torch.cuda.device_count()))
11
+ output_device = device_ids[0]
12
+ replicas = torch.nn.parallel.replicate(model, device_ids)
13
+ # input.shape = (num_args, batch, ...)
14
+ inputs = torch.nn.parallel.scatter(input, device_ids)
15
+ # inputs.shape = (num_gpus, num_args, batch/num_gpus, ...)
16
+ replicas = replicas[:len(inputs)]
17
+ outputs = torch.nn.parallel.parallel_apply(replicas, inputs)
18
+ y_hat = torch.nn.parallel.gather(outputs, output_device)
19
+ _output_ref = outputs
20
+ _replicas_ref = replicas
21
+ return y_hat
22
+
23
+
24
+ class ValueWindow():
25
+ def __init__(self, window_size=100):
26
+ self._window_size = window_size
27
+ self._values = []
28
+
29
+ def append(self, x):
30
+ self._values = self._values[-(self._window_size - 1):] + [x]
31
+
32
+ @property
33
+ def sum(self):
34
+ return sum(self._values)
35
+
36
+ @property
37
+ def count(self):
38
+ return len(self._values)
39
+
40
+ @property
41
+ def average(self):
42
+ return self.sum / max(1, self.count)
43
+
44
+ def reset(self):
45
+ self._values = []
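ValueWindow is a fixed-size running average that train.py uses for loss and step-time smoothing; for instance:
window = ValueWindow(window_size=3)
for value in [1.0, 2.0, 3.0, 4.0]:
    window.append(value)
print(window.count, window.average)   # 3 3.0 -- only the last three values are kept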
synthesizer/utils/_cmudict.py ADDED
@@ -0,0 +1,62 @@
1
+ import re
2
+
3
+ valid_symbols = [
4
+ "AA", "AA0", "AA1", "AA2", "AE", "AE0", "AE1", "AE2", "AH", "AH0", "AH1", "AH2",
5
+ "AO", "AO0", "AO1", "AO2", "AW", "AW0", "AW1", "AW2", "AY", "AY0", "AY1", "AY2",
6
+ "B", "CH", "D", "DH", "EH", "EH0", "EH1", "EH2", "ER", "ER0", "ER1", "ER2", "EY",
7
+ "EY0", "EY1", "EY2", "F", "G", "HH", "IH", "IH0", "IH1", "IH2", "IY", "IY0", "IY1",
8
+ "IY2", "JH", "K", "L", "M", "N", "NG", "OW", "OW0", "OW1", "OW2", "OY", "OY0",
9
+ "OY1", "OY2", "P", "R", "S", "SH", "T", "TH", "UH", "UH0", "UH1", "UH2", "UW",
10
+ "UW0", "UW1", "UW2", "V", "W", "Y", "Z", "ZH"
11
+ ]
12
+
13
+ _valid_symbol_set = set(valid_symbols)
14
+
15
+
16
+ class CMUDict:
17
+ """Thin wrapper around CMUDict data. http://www.speech.cs.cmu.edu/cgi-bin/cmudict"""
18
+ def __init__(self, file_or_path, keep_ambiguous=True):
19
+ if isinstance(file_or_path, str):
20
+ with open(file_or_path, encoding="latin-1") as f:
21
+ entries = _parse_cmudict(f)
22
+ else:
23
+ entries = _parse_cmudict(file_or_path)
24
+ if not keep_ambiguous:
25
+ entries = {word: pron for word, pron in entries.items() if len(pron) == 1}
26
+ self._entries = entries
27
+
28
+
29
+ def __len__(self):
30
+ return len(self._entries)
31
+
32
+
33
+ def lookup(self, word):
34
+ """Returns list of ARPAbet pronunciations of the given word."""
35
+ return self._entries.get(word.upper())
36
+
37
+
38
+
39
+ _alt_re = re.compile(r"\([0-9]+\)")
40
+
41
+
42
+ def _parse_cmudict(file):
43
+ cmudict = {}
44
+ for line in file:
45
+ if len(line) and (line[0] >= "A" and line[0] <= "Z" or line[0] == "'"):
46
+ parts = line.split(" ")
47
+ word = re.sub(_alt_re, "", parts[0])
48
+ pronunciation = _get_pronunciation(parts[1])
49
+ if pronunciation:
50
+ if word in cmudict:
51
+ cmudict[word].append(pronunciation)
52
+ else:
53
+ cmudict[word] = [pronunciation]
54
+ return cmudict
55
+
56
+
57
+ def _get_pronunciation(s):
58
+ parts = s.strip().split(" ")
59
+ for part in parts:
60
+ if part not in _valid_symbol_set:
61
+ return None
62
+ return " ".join(parts)
synthesizer/utils/cleaners.py ADDED
@@ -0,0 +1,88 @@
1
+ """
2
+ Cleaners are transformations that run over the input text at both training and eval time.
3
+
4
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
5
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
6
+ 1. "english_cleaners" for English text
7
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
8
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
9
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
10
+ the symbols in symbols.py to match your data).
11
+ """
12
+
13
+ import re
14
+ from unidecode import unidecode
15
+ from .numbers import normalize_numbers
16
+
17
+ # Regular expression matching whitespace:
18
+ _whitespace_re = re.compile(r"\s+")
19
+
20
+ # List of (regular expression, replacement) pairs for abbreviations:
21
+ _abbreviations = [(re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1]) for x in [
22
+ ("mrs", "misess"),
23
+ ("mr", "mister"),
24
+ ("dr", "doctor"),
25
+ ("st", "saint"),
26
+ ("co", "company"),
27
+ ("jr", "junior"),
28
+ ("maj", "major"),
29
+ ("gen", "general"),
30
+ ("drs", "doctors"),
31
+ ("rev", "reverend"),
32
+ ("lt", "lieutenant"),
33
+ ("hon", "honorable"),
34
+ ("sgt", "sergeant"),
35
+ ("capt", "captain"),
36
+ ("esq", "esquire"),
37
+ ("ltd", "limited"),
38
+ ("col", "colonel"),
39
+ ("ft", "fort"),
40
+ ]]
41
+
42
+
43
+ def expand_abbreviations(text):
44
+ for regex, replacement in _abbreviations:
45
+ text = re.sub(regex, replacement, text)
46
+ return text
47
+
48
+
49
+ def expand_numbers(text):
50
+ return normalize_numbers(text)
51
+
52
+
53
+ def lowercase(text):
54
+ """lowercase input tokens."""
55
+ return text.lower()
56
+
57
+
58
+ def collapse_whitespace(text):
59
+ return re.sub(_whitespace_re, " ", text)
60
+
61
+
62
+ def convert_to_ascii(text):
63
+ return unidecode(text)
64
+
65
+
66
+ def basic_cleaners(text):
67
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
68
+ text = lowercase(text)
69
+ text = collapse_whitespace(text)
70
+ return text
71
+
72
+
73
+ def transliteration_cleaners(text):
74
+ """Pipeline for non-English text that transliterates to ASCII."""
75
+ text = convert_to_ascii(text)
76
+ text = lowercase(text)
77
+ text = collapse_whitespace(text)
78
+ return text
79
+
80
+
81
+ def english_cleaners(text):
82
+ """Pipeline for English text, including number and abbreviation expansion."""
83
+ text = convert_to_ascii(text)
84
+ text = lowercase(text)
85
+ text = expand_numbers(text)
86
+ text = expand_abbreviations(text)
87
+ text = collapse_whitespace(text)
88
+ return text
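An example of the full English pipeline (unidecode and inflect must be installed); the expected output is traced from the functions above:
print(english_cleaners("Dr. Müller owes $15.50 for his 2nd visit."))
# should print: "doctor muller owes fifteen dollars, fifty cents for his second visit."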
synthesizer/utils/numbers.py ADDED
@@ -0,0 +1,68 @@
1
+ import re
2
+ import inflect
3
+
4
+ _inflect = inflect.engine()
5
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
6
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
7
+ _pounds_re = re.compile(r"£([0-9\,]*[0-9]+)")
8
+ _dollars_re = re.compile(r"\$([0-9\.\,]*[0-9]+)")
9
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
10
+ _number_re = re.compile(r"[0-9]+")
11
+
12
+
13
+ def _remove_commas(m):
14
+ return m.group(1).replace(",", "")
15
+
16
+
17
+ def _expand_decimal_point(m):
18
+ return m.group(1).replace(".", " point ")
19
+
20
+
21
+ def _expand_dollars(m):
22
+ match = m.group(1)
23
+ parts = match.split(".")
24
+ if len(parts) > 2:
25
+ return match + " dollars" # Unexpected format
26
+ dollars = int(parts[0]) if parts[0] else 0
27
+ cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
28
+ if dollars and cents:
29
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
30
+ cent_unit = "cent" if cents == 1 else "cents"
31
+ return "%s %s, %s %s" % (dollars, dollar_unit, cents, cent_unit)
32
+ elif dollars:
33
+ dollar_unit = "dollar" if dollars == 1 else "dollars"
34
+ return "%s %s" % (dollars, dollar_unit)
35
+ elif cents:
36
+ cent_unit = "cent" if cents == 1 else "cents"
37
+ return "%s %s" % (cents, cent_unit)
38
+ else:
39
+ return "zero dollars"
40
+
41
+
42
+ def _expand_ordinal(m):
43
+ return _inflect.number_to_words(m.group(0))
44
+
45
+
46
+ def _expand_number(m):
47
+ num = int(m.group(0))
48
+ if num > 1000 and num < 3000:
49
+ if num == 2000:
50
+ return "two thousand"
51
+ elif num > 2000 and num < 2010:
52
+ return "two thousand " + _inflect.number_to_words(num % 100)
53
+ elif num % 100 == 0:
54
+ return _inflect.number_to_words(num // 100) + " hundred"
55
+ else:
56
+ return _inflect.number_to_words(num, andword="", zero="oh", group=2).replace(", ", " ")
57
+ else:
58
+ return _inflect.number_to_words(num, andword="")
59
+
60
+
61
+ def normalize_numbers(text):
62
+ text = re.sub(_comma_number_re, _remove_commas, text)
63
+ text = re.sub(_pounds_re, r"\1 pounds", text)
64
+ text = re.sub(_dollars_re, _expand_dollars, text)
65
+ text = re.sub(_decimal_number_re, _expand_decimal_point, text)
66
+ text = re.sub(_ordinal_re, _expand_ordinal, text)
67
+ text = re.sub(_number_re, _expand_number, text)
68
+ return text
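A couple of spot checks for normalize_numbers, showing the currency expansion and the year-style reading applied to numbers between 1000 and 3000 (exact wording depends on the installed inflect version):
print(normalize_numbers("$3.50"))      # -> "three dollars, fifty cents"
print(normalize_numbers("in 1984"))    # -> roughly "in nineteen eighty-four"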